The data is produced by the Python script `pre-processing_facette.py` below, which downloads the Olivetti faces dataset and saves it as CSV files.

# -*- coding: utf-8 -*-

# Imports
from sklearn.datasets import fetch_olivetti_faces
import numpy as np

# Download Olivetti faces dataset
olivetti = fetch_olivetti_faces()
x = olivetti.images
y = olivetti.target

# Print info on shapes and flatten each image into a single row, so the
# images can be written as a 2D CSV file (one image per line).
print("Original x shape:", x.shape)
# Derive the sizes from the data instead of hard-coding (400, 4096).
X = x.reshape((x.shape[0], -1))
print("New x shape:", X.shape)
print("y shape", y.shape)

# Save the numpy arrays to the current working directory. The paths are
# relative on purpose: the R resizing script reads "olivetti_X.csv" and
# "olivetti_y.csv" by relative path, and an absolute "C://" path would both
# break that hand-off and only work on Windows.
np.savetxt("olivetti_X.csv", X, delimiter=",")
np.savetxt("olivetti_y.csv", y, delimiter=",", fmt="%d")

print("\nDownloading and reshaping done!")

################################################################################
#                               OUTPUT
################################################################################
#
# Original x shape: (400, 64, 64)
# New x shape: (400, 4096)
# y shape (400,)
#
# Downloading and reshaping done!
# This script is used to resize images from 64x64 to 28x28 pixels

# Load EBImage library. library() stops right away if the package is missing,
# whereas require() would only return FALSE and let every image below fail.
library(EBImage)

# Image geometry: each CSV row holds one 64x64 image (4096 values); images are
# downscaled to 28x28 = 784 pixels.
orig_side <- 64
new_side <- 28
n_pixels <- new_side * new_side

# Load data (neither file has a header row).
X <- read.csv("olivetti_X.csv", header = FALSE)
labels <- read.csv("olivetti_y.csv", header = FALSE)

# Convert the pixels to a numeric matrix once, so each row extracted in the
# loop is already a plain numeric vector (no per-row data.frame coercion).
x_mat <- as.matrix(X)
label_vec <- labels[[1]]

# Preallocated result (label + pixels per row) plus a per-image success flag.
# Filling a preallocated matrix avoids the quadratic cost of growing a data
# frame with rbind() on every iteration.
rs_mat <- matrix(NA_real_, nrow = nrow(x_mat), ncol = 1 + n_pixels)
resized_ok <- logical(nrow(x_mat))

# Main loop: for each image, resize it and store it as a greyscale vector.
for (i in seq_len(nrow(x_mat))) {
  resized_ok[i] <- tryCatch({
    # Reshape the 1d row as an orig_side x orig_side EBImage object
    img <- Image(as.numeric(x_mat[i, ]), dim = c(orig_side, orig_side),
                 colormode = "Grayscale")
    # Resize image to new_side x new_side pixels
    img_resized <- resize(img, w = new_side, h = new_side)
    # Pixel matrix -> vector. The transpose is kept so the saved pixel order
    # matches what the downstream reshaping expects.
    img_vector <- as.vector(t(imageData(img_resized)))
    # Store label followed by pixels
    rs_mat[i, ] <- c(label_vec[i], img_vector)
    TRUE
  },
  # Print the error and mark the image as failed; it is dropped below.
  error = function(e) {
    print(e)
    FALSE
  })
}

if (!all(resized_ok)) {
  warning(sum(!resized_ok), " image(s) could not be resized and were dropped",
          call. = FALSE)
}

# Dataframe of resized images (only the successfully resized ones).
rs_df <- as.data.frame(rs_mat[resized_ok, , drop = FALSE])

# Set names. The first column is the label, the other columns are the pixels.
# paste() (sep = " ") keeps the existing "pixel 1", "pixel 2", ... headers,
# which read.csv() later turns into pixel.1, pixel.2, ...
names(rs_df) <- c("label", paste("pixel", seq_len(n_pixels)))
rs_df[1:11, 1:7]
# Train-test split
#-------------------------------------------------------------------------------
# Simple train-test split. No crossvalidation is done in this tutorial.

# Set seed for reproducibility purposes
set.seed(100)

# Shuffle the rows. Drawing a permutation of nrow(rs_df) (rather than of a
# hard-coded 1:400) keeps the split valid if some images were dropped during
# resizing; for 400 rows sample.int(400) consumes the RNG exactly like
# sample(1:400), so the same permutation is produced.
n_images <- nrow(rs_df)
shuffled <- rs_df[sample.int(n_images), ]

# Train-test split: the first 90% of the shuffled rows are used for training,
# the rest for testing (360 / 40 for the full 400-image dataset).
is_train <- seq_len(n_images) <= round(0.9 * n_images)
train_28 <- shuffled[is_train, ]
test_28 <- shuffled[!is_train, ]

# Save train-test datasets
write.csv(train_28, "train_28.csv", row.names = FALSE)
write.csv(test_28, "test_28.csv", row.names = FALSE)

# Done!
test_28[1:11, 1:7]
# Load MXNet. library() fails loudly if the package is not installed, instead
# of returning FALSE like require().
library(mxnet)
# Loading data and set up
#-------------------------------------------------------------------------------

# Side length of the square greyscale images produced by the resizing step.
img_side <- 28

# Load train and test datasets
train <- read.csv("train_28.csv")
test <- read.csv("test_28.csv")

# Set up train and test datasets the same way: convert to a numeric matrix,
# split off the label (first column), transpose the pixels so each column is
# one image, and reshape them into an (img_side, img_side, 1, n_images) array.
train <- data.matrix(train)
train_x <- t(train[, -1])
train_y <- train[, 1]
train_array <- train_x
dim(train_array) <- c(img_side, img_side, 1, ncol(train_x))

test <- data.matrix(test)
test_x <- t(test[, -1])
test_y <- test[, 1]
test_array <- test_x
dim(test_array) <- c(img_side, img_side, 1, ncol(test_x))
test_y[1:11]
##  [1] 25 33  6 34  1 21 27 17 24 39 25
test_x[1:6,1:7]
##              [,1]      [,2]      [,3]      [,4]      [,5]      [,6]
## pixel.1 0.7720737 0.4711376 0.2419674 0.3732290 0.6217111 0.2616377
## pixel.2 0.7425156 0.4970062 0.2866208 0.4025974 0.6606510 0.3320543
## pixel.3 0.7358534 0.5115745 0.3226725 0.2476176 0.6598921 0.3538118
## pixel.4 0.7023318 0.5222845 0.3288075 0.3417946 0.6477273 0.3931523
## pixel.5 0.7160567 0.4895008 0.3616124 0.4795918 0.6201088 0.3807978
## pixel.6 0.7014463 0.2699022 0.5030570 0.5690673 0.5807472 0.4021336
##              [,7]
## pixel.1 0.6333699
## pixel.2 0.6787612
## pixel.3 0.6947841
## pixel.4 0.6914109
## pixel.5 0.6830832
## pixel.6 0.7119666
test_array[1:6, 1:7,1, 1]
##           [,1]      [,2]      [,3]      [,4]      [,5]      [,6]      [,7]
## [1,] 0.7720737 0.7637671 0.7765433 0.7911747 0.7919337 0.7731067 0.7797478
## [2,] 0.7425156 0.7562405 0.7644628 0.7876328 0.7952226 0.7960870 0.7943582
## [3,] 0.7358534 0.7465846 0.7746880 0.7854191 0.7923132 0.8008307 0.8039931
## [4,] 0.7023318 0.7384888 0.7647580 0.7892562 0.8051948 0.7976682 0.7905844
## [5,] 0.7160567 0.7505271 0.7706612 0.7749410 0.7320374 0.6583319 0.5817381
## [6,] 0.7014463 0.6836946 0.6197715 0.5593270 0.6225966 0.6495826 0.6563923
# Set up the symbolic model
#-------------------------------------------------------------------------------
# Network layout: two (convolution -> tanh -> 2x2 max-pool) stages, then a
# 500-unit tanh hidden layer and a 40-unit layer fed to a softmax output.

data <- mx.symbol.Variable("data")

# Stage 1: 20 convolution filters of size 5x5, tanh, 2x2 max-pooling.
conv_1 <- mx.symbol.Convolution(data = data,
                                num_filter = 20,
                                kernel = c(5, 5))
tanh_1 <- mx.symbol.Activation(data = conv_1, act_type = "tanh")
pool_1 <- mx.symbol.Pooling(data = tanh_1,
                            kernel = c(2, 2),
                            stride = c(2, 2),
                            pool_type = "max")

# Stage 2: 50 convolution filters of size 5x5, tanh, 2x2 max-pooling.
conv_2 <- mx.symbol.Convolution(data = pool_1,
                                num_filter = 50,
                                kernel = c(5, 5))
tanh_2 <- mx.symbol.Activation(data = conv_2, act_type = "tanh")
pool_2 <- mx.symbol.Pooling(data = tanh_2,
                            kernel = c(2, 2),
                            stride = c(2, 2),
                            pool_type = "max")

# Dense head: flatten the feature maps, 500 tanh units, then one unit per
# class (40 classes).
flatten <- mx.symbol.Flatten(data = pool_2)
fc_1 <- mx.symbol.FullyConnected(data = flatten, num_hidden = 500)
tanh_3 <- mx.symbol.Activation(data = fc_1, act_type = "tanh")
fc_2 <- mx.symbol.FullyConnected(data = tanh_3, num_hidden = 40)

# Softmax output so the model yields class probabilities.
NN_model <- mx.symbol.SoftmaxOutput(data = fc_2)
# Pre-training set up
#-------------------------------------------------------------------------------

# Seed mxnet's random number generator so training is reproducible.
mx.set.seed(100)

# Train on the CPU.
devices <- mx.cpu()

# Training
#-------------------------------------------------------------------------------

# Fit the network: 480 rounds over the training array in batches of 40 images,
# learning rate 0.01 with momentum 0.9, logging training accuracy.
model <- mx.model.FeedForward.create(
  NN_model,
  X = train_array,
  y = train_y,
  ctx = devices,
  array.batch.size = 40,
  num.round = 480,
  learning.rate = 0.01,
  momentum = 0.9,
  eval.metric = mx.metric.accuracy,
  epoch.end.callback = mx.callback.log.train.metric(100)
)
## Start training with 1 devices
## [1] Train-accuracy=0.025
## [2] Train-accuracy=0.0277777777777778
## [3] Train-accuracy=0.0277777777777778
## [4] Train-accuracy=0.0277777777777778
## [5] Train-accuracy=0.0277777777777778
## [6] Train-accuracy=0.0277777777777778
## [7] Train-accuracy=0.0277777777777778
## [8] Train-accuracy=0.0277777777777778
## [9] Train-accuracy=0.0277777777777778
## [10] Train-accuracy=0.0277777777777778
## [11] Train-accuracy=0.0277777777777778
## [12] Train-accuracy=0.0277777777777778
## [13] Train-accuracy=0.0277777777777778
## [14] Train-accuracy=0.0277777777777778
## [15] Train-accuracy=0.0277777777777778
## [16] Train-accuracy=0.0277777777777778
## [17] Train-accuracy=0.0277777777777778
## [18] Train-accuracy=0.0277777777777778
## [19] Train-accuracy=0.0277777777777778
## [20] Train-accuracy=0.0277777777777778
## [21] Train-accuracy=0.0277777777777778
## [22] Train-accuracy=0.0277777777777778
## [23] Train-accuracy=0.0277777777777778
## [24] Train-accuracy=0.0277777777777778
## [25] Train-accuracy=0.0277777777777778
## [26] Train-accuracy=0.0277777777777778
## [27] Train-accuracy=0.0277777777777778
## [28] Train-accuracy=0.0277777777777778
## [29] Train-accuracy=0.0277777777777778
## [30] Train-accuracy=0.0277777777777778
## [31] Train-accuracy=0.0277777777777778
## [32] Train-accuracy=0.0277777777777778
## [33] Train-accuracy=0.0277777777777778
## [34] Train-accuracy=0.0277777777777778
## [35] Train-accuracy=0.0277777777777778
## [36] Train-accuracy=0.0277777777777778
## [37] Train-accuracy=0.0277777777777778
## [38] Train-accuracy=0.0277777777777778
## [39] Train-accuracy=0.0277777777777778
## [40] Train-accuracy=0.0277777777777778
## [41] Train-accuracy=0.0277777777777778
## [42] Train-accuracy=0.0277777777777778
## [43] Train-accuracy=0.0277777777777778
## [44] Train-accuracy=0.0277777777777778
## [45] Train-accuracy=0.0277777777777778
## [46] Train-accuracy=0.0277777777777778
## [47] Train-accuracy=0.0277777777777778
## [48] Train-accuracy=0.0277777777777778
## [49] Train-accuracy=0.0277777777777778
## [50] Train-accuracy=0.0277777777777778
## [51] Train-accuracy=0.0277777777777778
## [52] Train-accuracy=0.0277777777777778
## [53] Train-accuracy=0.0277777777777778
## [54] Train-accuracy=0.0277777777777778
## [55] Train-accuracy=0.0277777777777778
## [56] Train-accuracy=0.0277777777777778
## [57] Train-accuracy=0.0277777777777778
## [58] Train-accuracy=0.0277777777777778
## [59] Train-accuracy=0.0277777777777778
## [60] Train-accuracy=0.0277777777777778
## [61] Train-accuracy=0.0277777777777778
## [62] Train-accuracy=0.0277777777777778
## [63] Train-accuracy=0.0305555555555556
## [64] Train-accuracy=0.0305555555555556
## [65] Train-accuracy=0.0277777777777778
## [66] Train-accuracy=0.0277777777777778
## [67] Train-accuracy=0.0277777777777778
## [68] Train-accuracy=0.0277777777777778
## [69] Train-accuracy=0.025
## [70] Train-accuracy=0.025
## [71] Train-accuracy=0.025
## [72] Train-accuracy=0.0222222222222222
## [73] Train-accuracy=0.0194444444444444
## [74] Train-accuracy=0.0194444444444444
## [75] Train-accuracy=0.0194444444444444
## [76] Train-accuracy=0.0166666666666667
## [77] Train-accuracy=0.0194444444444444
## [78] Train-accuracy=0.0138888888888889
## [79] Train-accuracy=0.0138888888888889
## [80] Train-accuracy=0.0138888888888889
## [81] Train-accuracy=0.0138888888888889
## [82] Train-accuracy=0.0138888888888889
## [83] Train-accuracy=0.0138888888888889
## [84] Train-accuracy=0.0138888888888889
## [85] Train-accuracy=0.0138888888888889
## [86] Train-accuracy=0.0138888888888889
## [87] Train-accuracy=0.0138888888888889
## [88] Train-accuracy=0.0138888888888889
## [89] Train-accuracy=0.0138888888888889
## [90] Train-accuracy=0.0138888888888889
## [91] Train-accuracy=0.0138888888888889
## [92] Train-accuracy=0.0138888888888889
## [93] Train-accuracy=0.0138888888888889
## [94] Train-accuracy=0.0138888888888889
## [95] Train-accuracy=0.0138888888888889
## [96] Train-accuracy=0.0138888888888889
## [97] Train-accuracy=0.0138888888888889
## [98] Train-accuracy=0.0138888888888889
## [99] Train-accuracy=0.0138888888888889
## [100] Train-accuracy=0.0166666666666667
## [101] Train-accuracy=0.0166666666666667
## [102] Train-accuracy=0.0166666666666667
## [103] Train-accuracy=0.0166666666666667
## [104] Train-accuracy=0.0166666666666667
## [105] Train-accuracy=0.0138888888888889
## [106] Train-accuracy=0.0138888888888889
## [107] Train-accuracy=0.0138888888888889
## [108] Train-accuracy=0.0111111111111111
## [109] Train-accuracy=0.0111111111111111
## [110] Train-accuracy=0.0111111111111111
## [111] Train-accuracy=0.0111111111111111
## [112] Train-accuracy=0.0111111111111111
## [113] Train-accuracy=0.0111111111111111
## [114] Train-accuracy=0.0111111111111111
## [115] Train-accuracy=0.0111111111111111
## [116] Train-accuracy=0.0111111111111111
## [117] Train-accuracy=0.0111111111111111
## [118] Train-accuracy=0.0111111111111111
## [119] Train-accuracy=0.0111111111111111
## [120] Train-accuracy=0.0138888888888889
## [121] Train-accuracy=0.0138888888888889
## [122] Train-accuracy=0.0138888888888889
## [123] Train-accuracy=0.0138888888888889
## [124] Train-accuracy=0.0138888888888889
## [125] Train-accuracy=0.0138888888888889
## [126] Train-accuracy=0.0138888888888889
## [127] Train-accuracy=0.0138888888888889
## [128] Train-accuracy=0.0138888888888889
## [129] Train-accuracy=0.0138888888888889
## [130] Train-accuracy=0.0138888888888889
## [131] Train-accuracy=0.0138888888888889
## [132] Train-accuracy=0.0138888888888889
## [133] Train-accuracy=0.0138888888888889
## [134] Train-accuracy=0.0138888888888889
## [135] Train-accuracy=0.0138888888888889
## [136] Train-accuracy=0.0166666666666667
## [137] Train-accuracy=0.0166666666666667
## [138] Train-accuracy=0.0166666666666667
## [139] Train-accuracy=0.0194444444444444
## [140] Train-accuracy=0.0194444444444444
## [141] Train-accuracy=0.0194444444444444
## [142] Train-accuracy=0.0194444444444444
## [143] Train-accuracy=0.0166666666666667
## [144] Train-accuracy=0.0166666666666667
## [145] Train-accuracy=0.0166666666666667
## [146] Train-accuracy=0.0166666666666667
## [147] Train-accuracy=0.0166666666666667
## [148] Train-accuracy=0.0166666666666667
## [149] Train-accuracy=0.0194444444444444
## [150] Train-accuracy=0.0222222222222222
## [151] Train-accuracy=0.0222222222222222
## [152] Train-accuracy=0.0222222222222222
## [153] Train-accuracy=0.0222222222222222
## [154] Train-accuracy=0.0222222222222222
## [155] Train-accuracy=0.0222222222222222
## [156] Train-accuracy=0.0194444444444444
## [157] Train-accuracy=0.0194444444444444
## [158] Train-accuracy=0.0194444444444444
## [159] Train-accuracy=0.0194444444444444
## [160] Train-accuracy=0.0194444444444444
## [161] Train-accuracy=0.0194444444444444
## [162] Train-accuracy=0.0166666666666667
## [163] Train-accuracy=0.0166666666666667
## [164] Train-accuracy=0.0166666666666667
## [165] Train-accuracy=0.0166666666666667
## [166] Train-accuracy=0.0166666666666667
## [167] Train-accuracy=0.0166666666666667
## [168] Train-accuracy=0.0194444444444444
## [169] Train-accuracy=0.0277777777777778
## [170] Train-accuracy=0.0305555555555556
## [171] Train-accuracy=0.0305555555555556
## [172] Train-accuracy=0.0305555555555556
## [173] Train-accuracy=0.0305555555555556
## [174] Train-accuracy=0.0277777777777778
## [175] Train-accuracy=0.0277777777777778
## [176] Train-accuracy=0.0277777777777778
## [177] Train-accuracy=0.0277777777777778
## [178] Train-accuracy=0.0277777777777778
## [179] Train-accuracy=0.0277777777777778
## [180] Train-accuracy=0.0277777777777778
## [181] Train-accuracy=0.0277777777777778
## [182] Train-accuracy=0.0277777777777778
## [183] Train-accuracy=0.0277777777777778
## [184] Train-accuracy=0.0277777777777778
## [185] Train-accuracy=0.0277777777777778
## [186] Train-accuracy=0.0277777777777778
## [187] Train-accuracy=0.0277777777777778
## [188] Train-accuracy=0.0277777777777778
## [189] Train-accuracy=0.0277777777777778
## [190] Train-accuracy=0.0277777777777778
## [191] Train-accuracy=0.0277777777777778
## [192] Train-accuracy=0.0277777777777778
## [193] Train-accuracy=0.0277777777777778
## [194] Train-accuracy=0.0277777777777778
## [195] Train-accuracy=0.0277777777777778
## [196] Train-accuracy=0.0277777777777778
## [197] Train-accuracy=0.0277777777777778
## [198] Train-accuracy=0.0277777777777778
## [199] Train-accuracy=0.0277777777777778
## [200] Train-accuracy=0.0277777777777778
## [201] Train-accuracy=0.0277777777777778
## [202] Train-accuracy=0.0277777777777778
## [203] Train-accuracy=0.0277777777777778
## [204] Train-accuracy=0.0277777777777778
## [205] Train-accuracy=0.0305555555555556
## [206] Train-accuracy=0.0305555555555556
## [207] Train-accuracy=0.0305555555555556
## [208] Train-accuracy=0.0305555555555556
## [209] Train-accuracy=0.0277777777777778
## [210] Train-accuracy=0.0277777777777778
## [211] Train-accuracy=0.0333333333333333
## [212] Train-accuracy=0.0388888888888889
## [213] Train-accuracy=0.0416666666666667
## [214] Train-accuracy=0.0416666666666667
## [215] Train-accuracy=0.0416666666666667
## [216] Train-accuracy=0.05
## [217] Train-accuracy=0.0583333333333333
## [218] Train-accuracy=0.0611111111111111
## [219] Train-accuracy=0.0777777777777778
## [220] Train-accuracy=0.0861111111111111
## [221] Train-accuracy=0.0861111111111111
## [222] Train-accuracy=0.0861111111111111
## [223] Train-accuracy=0.0888888888888889
## [224] Train-accuracy=0.0694444444444444
## [225] Train-accuracy=0.0666666666666667
## [226] Train-accuracy=0.0666666666666667
## [227] Train-accuracy=0.0666666666666667
## [228] Train-accuracy=0.075
## [229] Train-accuracy=0.0916666666666667
## [230] Train-accuracy=0.0833333333333333
## [231] Train-accuracy=0.0944444444444444
## [232] Train-accuracy=0.15
## [233] Train-accuracy=0.152777777777778
## [234] Train-accuracy=0.147222222222222
## [235] Train-accuracy=0.180555555555556
## [236] Train-accuracy=0.175
## [237] Train-accuracy=0.186111111111111
## [238] Train-accuracy=0.216666666666667
## [239] Train-accuracy=0.233333333333333
## [240] Train-accuracy=0.244444444444444
## [241] Train-accuracy=0.266666666666667
## [242] Train-accuracy=0.322222222222222
## [243] Train-accuracy=0.35
## [244] Train-accuracy=0.358333333333333
## [245] Train-accuracy=0.405555555555556
## [246] Train-accuracy=0.433333333333333
## [247] Train-accuracy=0.413888888888889
## [248] Train-accuracy=0.438888888888889
## [249] Train-accuracy=0.461111111111111
## [250] Train-accuracy=0.483333333333333
## [251] Train-accuracy=0.472222222222222
## [252] Train-accuracy=0.494444444444444
## [253] Train-accuracy=0.519444444444444
## [254] Train-accuracy=0.538888888888889
## [255] Train-accuracy=0.577777777777778
## [256] Train-accuracy=0.597222222222222
## [257] Train-accuracy=0.602777777777778
## [258] Train-accuracy=0.622222222222222
## [259] Train-accuracy=0.622222222222222
## [260] Train-accuracy=0.633333333333333
## [261] Train-accuracy=0.686111111111111
## [262] Train-accuracy=0.727777777777778
## [263] Train-accuracy=0.716666666666667
## [264] Train-accuracy=0.730555555555556
## [265] Train-accuracy=0.725
## [266] Train-accuracy=0.730555555555556
## [267] Train-accuracy=0.780555555555556
## [268] Train-accuracy=0.811111111111111
## [269] Train-accuracy=0.813888888888889
## [270] Train-accuracy=0.816666666666667
## [271] Train-accuracy=0.813888888888889
## [272] Train-accuracy=0.855555555555556
## [273] Train-accuracy=0.869444444444444
## [274] Train-accuracy=0.883333333333333
## [275] Train-accuracy=0.891666666666667
## [276] Train-accuracy=0.880555555555556
## [277] Train-accuracy=0.897222222222222
## [278] Train-accuracy=0.905555555555556
## [279] Train-accuracy=0.922222222222222
## [280] Train-accuracy=0.933333333333333
## [281] Train-accuracy=0.933333333333333
## [282] Train-accuracy=0.944444444444444
## [283] Train-accuracy=0.952777777777778
## [284] Train-accuracy=0.961111111111111
## [285] Train-accuracy=0.966666666666667
## [286] Train-accuracy=0.963888888888889
## [287] Train-accuracy=0.975
## [288] Train-accuracy=0.980555555555556
## [289] Train-accuracy=0.980555555555556
## [290] Train-accuracy=0.983333333333333
## [291] Train-accuracy=0.983333333333333
## [292] Train-accuracy=0.986111111111111
## [293] Train-accuracy=0.986111111111111
## [294] Train-accuracy=0.986111111111111
## [295] Train-accuracy=0.988888888888889
## [296] Train-accuracy=0.988888888888889
## [297] Train-accuracy=0.991666666666667
## [298] Train-accuracy=0.991666666666667
## [299] Train-accuracy=0.994444444444444
## [300] Train-accuracy=0.994444444444444
## [301] Train-accuracy=0.994444444444444
## [302] Train-accuracy=0.994444444444444
## [303] Train-accuracy=0.994444444444444
## [304] Train-accuracy=0.994444444444444
## [305] Train-accuracy=0.994444444444444
## [306] Train-accuracy=0.997222222222222
## [307] Train-accuracy=1
## [308] Train-accuracy=1
## [309] Train-accuracy=0.994444444444444
## [310] Train-accuracy=0.988888888888889
## [311] Train-accuracy=0.986111111111111
## [312] Train-accuracy=0.986111111111111
## [313] Train-accuracy=0.966666666666667
## [314] Train-accuracy=0.941666666666667
## [315] Train-accuracy=0.927777777777778
## [316] Train-accuracy=0.841666666666667
## [317] Train-accuracy=0.725
## [318] Train-accuracy=0.786111111111111
## [319] Train-accuracy=0.905555555555555
## [320] Train-accuracy=0.961111111111111
## [321] Train-accuracy=0.986111111111111
## [322] Train-accuracy=0.991666666666667
## [323] Train-accuracy=0.997222222222222
## [324] Train-accuracy=0.997222222222222
## [325] Train-accuracy=0.997222222222222
## [326] Train-accuracy=0.997222222222222
## [327] Train-accuracy=0.997222222222222
## [328] Train-accuracy=1
## [329] Train-accuracy=1
## [330] Train-accuracy=1
## [331] Train-accuracy=1
## [332] Train-accuracy=1
## [333] Train-accuracy=1
## [334] Train-accuracy=1
## [335] Train-accuracy=1
## [336] Train-accuracy=1
## [337] Train-accuracy=1
## [338] Train-accuracy=1
## [339] Train-accuracy=1
## [340] Train-accuracy=1
## [341] Train-accuracy=1
## [342] Train-accuracy=1
## [343] Train-accuracy=1
## [344] Train-accuracy=1
## [345] Train-accuracy=1
## [346] Train-accuracy=1
## [347] Train-accuracy=1
## [348] Train-accuracy=1
## [349] Train-accuracy=1
## [350] Train-accuracy=1
## [351] Train-accuracy=1
## [352] Train-accuracy=1
## [353] Train-accuracy=1
## [354] Train-accuracy=1
## [355] Train-accuracy=1
## [356] Train-accuracy=1
## [357] Train-accuracy=1
## [358] Train-accuracy=1
## [359] Train-accuracy=1
## [360] Train-accuracy=1
## [361] Train-accuracy=1
## [362] Train-accuracy=1
## [363] Train-accuracy=1
## [364] Train-accuracy=1
## [365] Train-accuracy=1
## [366] Train-accuracy=1
## [367] Train-accuracy=1
## [368] Train-accuracy=1
## [369] Train-accuracy=1
## [370] Train-accuracy=1
## [371] Train-accuracy=1
## [372] Train-accuracy=1
## [373] Train-accuracy=1
## [374] Train-accuracy=1
## [375] Train-accuracy=1
## [376] Train-accuracy=1
## [377] Train-accuracy=1
## [378] Train-accuracy=1
## [379] Train-accuracy=1
## [380] Train-accuracy=1
## [381] Train-accuracy=1
## [382] Train-accuracy=1
## [383] Train-accuracy=1
## [384] Train-accuracy=1
## [385] Train-accuracy=1
## [386] Train-accuracy=1
## [387] Train-accuracy=1
## [388] Train-accuracy=1
## [389] Train-accuracy=1
## [390] Train-accuracy=1
## [391] Train-accuracy=1
## [392] Train-accuracy=1
## [393] Train-accuracy=1
## [394] Train-accuracy=1
## [395] Train-accuracy=1
## [396] Train-accuracy=1
## [397] Train-accuracy=1
## [398] Train-accuracy=1
## [399] Train-accuracy=1
## [400] Train-accuracy=1
## [401] Train-accuracy=1
## [402] Train-accuracy=1
## [403] Train-accuracy=1
## [404] Train-accuracy=1
## [405] Train-accuracy=1
## [406] Train-accuracy=1
## [407] Train-accuracy=1
## [408] Train-accuracy=1
## [409] Train-accuracy=1
## [410] Train-accuracy=1
## [411] Train-accuracy=1
## [412] Train-accuracy=1
## [413] Train-accuracy=1
## [414] Train-accuracy=1
## [415] Train-accuracy=1
## [416] Train-accuracy=1
## [417] Train-accuracy=1
## [418] Train-accuracy=1
## [419] Train-accuracy=1
## [420] Train-accuracy=1
## [421] Train-accuracy=1
## [422] Train-accuracy=1
## [423] Train-accuracy=1
## [424] Train-accuracy=1
## [425] Train-accuracy=1
## [426] Train-accuracy=1
## [427] Train-accuracy=1
## [428] Train-accuracy=1
## [429] Train-accuracy=1
## [430] Train-accuracy=1
## [431] Train-accuracy=1
## [432] Train-accuracy=1
## [433] Train-accuracy=1
## [434] Train-accuracy=1
## [435] Train-accuracy=1
## [436] Train-accuracy=1
## [437] Train-accuracy=1
## [438] Train-accuracy=1
## [439] Train-accuracy=1
## [440] Train-accuracy=1
## [441] Train-accuracy=1
## [442] Train-accuracy=1
## [443] Train-accuracy=1
## [444] Train-accuracy=1
## [445] Train-accuracy=1
## [446] Train-accuracy=1
## [447] Train-accuracy=1
## [448] Train-accuracy=1
## [449] Train-accuracy=1
## [450] Train-accuracy=1
## [451] Train-accuracy=1
## [452] Train-accuracy=1
## [453] Train-accuracy=1
## [454] Train-accuracy=1
## [455] Train-accuracy=1
## [456] Train-accuracy=1
## [457] Train-accuracy=1
## [458] Train-accuracy=1
## [459] Train-accuracy=1
## [460] Train-accuracy=1
## [461] Train-accuracy=1
## [462] Train-accuracy=1
## [463] Train-accuracy=1
## [464] Train-accuracy=1
## [465] Train-accuracy=1
## [466] Train-accuracy=1
## [467] Train-accuracy=1
## [468] Train-accuracy=1
## [469] Train-accuracy=1
## [470] Train-accuracy=1
## [471] Train-accuracy=1
## [472] Train-accuracy=1
## [473] Train-accuracy=1
## [474] Train-accuracy=1
## [475] Train-accuracy=1
## [476] Train-accuracy=1
## [477] Train-accuracy=1
## [478] Train-accuracy=1
## [479] Train-accuracy=1
## [480] Train-accuracy=1
# Predict labels
predicted <- predict(model, test_array)
# Assign labels: each column of `predicted` holds the class probabilities of
# one test image (one row per class), so pick the most probable row for each
# image. Labels are 0-based, hence the "- 1". ties.method = "first" resolves
# exact ties deterministically; max.col()'s default ("random") would not.
predicted_labels <- max.col(t(predicted), ties.method = "first") - 1
# Confusion table of true (rows) vs predicted (columns) labels; only labels
# that actually occur are listed. Correct predictions lie on the diagonal.
table(test[, 1], predicted_labels)
##     predicted_labels
##      0 1 2 5 6 7 9 11 14 16 17 19 21 24 25 27 29 30 33 34 35 36 39
##   0  2 0 0 0 0 0 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##   1  0 3 0 0 0 0 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##   2  0 0 2 0 0 0 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##   5  0 0 0 2 0 0 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##   6  0 0 0 0 1 0 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##   7  0 0 0 0 0 1 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##   9  0 0 0 0 0 1 1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##   11 0 0 0 0 0 0 0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##   14 0 0 0 0 0 0 0  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0
##   16 0 0 0 0 0 0 0  0  0  2  0  0  0  0  0  0  0  0  0  0  0  0  0
##   17 0 0 0 0 0 0 0  0  0  0  2  0  0  0  0  0  0  0  0  0  0  0  0
##   19 0 0 0 0 0 0 0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0
##   21 0 0 0 0 0 0 0  0  0  0  0  0  2  0  0  0  0  0  0  0  0  0  0
##   24 0 0 0 0 0 0 0  0  0  0  0  0  0  2  0  0  0  0  0  0  0  0  0
##   25 0 0 0 0 0 0 0  0  0  0  0  0  0  0  2  0  0  0  0  0  0  0  0
##   27 0 0 0 0 0 0 0  0  0  0  0  0  0  0  0  2  0  0  0  0  0  0  0
##   29 0 0 0 0 0 0 0  0  0  0  0  0  0  0  0  0  2  0  0  0  0  0  0
##   30 0 0 0 0 0 0 0  0  0  0  0  0  0  0  0  0  0  2  0  0  0  0  0
##   33 0 0 0 0 0 0 0  0  0  0  0  0  0  0  0  0  0  0  3  0  0  0  0
##   34 0 0 0 0 0 0 0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0
##   35 0 0 0 0 0 0 0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0
##   36 0 0 0 0 0 0 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  2  0
##   39 0 0 0 0 0 0 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1
# Peek at the predicted probabilities: rows are classes, columns are test
# images.
predicted[1:8, 1:7]
##              [,1]         [,2]         [,3]         [,4]         [,5]
## [1,] 1.707967e-15 2.870667e-13 1.118668e-06 1.136442e-13 5.591594e-09
## [2,] 2.134021e-11 1.403163e-05 8.560136e-06 1.446867e-11 9.993473e-01
## [3,] 1.731831e-03 6.911800e-12 4.676203e-08 3.764968e-02 9.169337e-16
## [4,] 3.403856e-09 4.240436e-09 2.696895e-04 1.099637e-09 5.501808e-05
## [5,] 5.284884e-08 3.574275e-12 1.257824e-07 3.197167e-02 5.135029e-10
## [6,] 3.671228e-08 1.156968e-14 6.240798e-05 9.586072e-06 5.981553e-08
## [7,] 6.742936e-06 3.523119e-12 9.837387e-01 1.945194e-04 1.861117e-08
## [8,] 8.701897e-14 1.688865e-15 3.996787e-04 2.185292e-07 1.690369e-12
##              [,6]         [,7]
## [1,] 3.087849e-13 9.334341e-16
## [2,] 2.874120e-10 1.221662e-09
## [3,] 1.663032e-07 1.169425e-04
## [4,] 1.607566e-09 4.184732e-05
## [5,] 8.264692e-10 4.062958e-12
## [6,] 9.434141e-15 2.098688e-12
## [7,] 2.260977e-13 7.279967e-06
## [8,] 6.864256e-13 1.035769e-11
# Recompute the predictions through the exported predict() generic, which
# dispatches to mxnet's predict.MXFeedForwardModel method. Calling that method
# via `mxnet:::` relies on package internals that may change between versions.
results <- predict(model, test_array)

dim(results)
## [1] 40 40
dim(test_x)
## [1] 784  40
# NOTE: on R >= 4.0 class() of a matrix is c("matrix", "array"); the output
# below was recorded on an older R.
class(results)
## [1] "matrix"
test_y
##  [1] 25 33  6 34  1 21 27 17 24 39 25 33  7  9 21  0 29 36  5 29 36 14 27
## [24] 30 24  1 17  2 16 19  2  0  5 11 16 33 35  1 30  9
results[1:8,1:7]
##              [,1]         [,2]         [,3]         [,4]         [,5]
## [1,] 1.707967e-15 2.870667e-13 1.118668e-06 1.136442e-13 5.591594e-09
## [2,] 2.134021e-11 1.403163e-05 8.560136e-06 1.446867e-11 9.993473e-01
## [3,] 1.731831e-03 6.911800e-12 4.676203e-08 3.764968e-02 9.169337e-16
## [4,] 3.403856e-09 4.240436e-09 2.696895e-04 1.099637e-09 5.501808e-05
## [5,] 5.284884e-08 3.574275e-12 1.257824e-07 3.197167e-02 5.135029e-10
## [6,] 3.671228e-08 1.156968e-14 6.240798e-05 9.586072e-06 5.981553e-08
## [7,] 6.742936e-06 3.523119e-12 9.837387e-01 1.945194e-04 1.861117e-08
## [8,] 8.701897e-14 1.688865e-15 3.996787e-04 2.185292e-07 1.690369e-12
##              [,6]         [,7]
## [1,] 3.087849e-13 9.334341e-16
## [2,] 2.874120e-10 1.221662e-09
## [3,] 1.663032e-07 1.169425e-04
## [4,] 1.607566e-09 4.184732e-05
## [5,] 8.264692e-10 4.062958e-12
## [6,] 9.434141e-15 2.098688e-12
## [7,] 2.260977e-13 7.279967e-06
## [8,] 6.864256e-13 1.035769e-11
# True labels vs predicted labels: in the recorded run they differ only at
# position 14 (true 9, predicted 7).
test[, 1]
##  [1] 25 33  6 34  1 21 27 17 24 39 25 33  7  9 21  0 29 36  5 29 36 14 27
## [24] 30 24  1 17  2 16 19  2  0  5 11 16 33 35  1 30  9
predicted_labels
##  [1] 25 33  6 34  1 21 27 17 24 39 25 33  7  7 21  0 29 36  5 29 36 14 27
## [24] 30 24  1 17  2 16 19  2  0  5 11 16 33 35  1 30  9