
22.2.2  Training a neural network

The command train is used for training neural networks created by the command neural_network (see Section 22.2.1).
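In its simplest form, the command takes the network to be trained, a matrix of input samples (one sample per row), and the list of corresponding expected outputs, and returns the updated network; an optional fourth argument sets the batch size. Both forms appear in the examples below:

net:=train(net,data,res)
net:=train(net,data,res,batch_size)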

Examples

Predicting values of a nonlinear function.

To demonstrate learning a nonlinear function, let

f(x1,x2) = x1 sin(3x2) cos(2x1x2).

We create a neural network with three hidden layers and train it to predict the value f(x1,x2) given the input vector x=(x1,x2) in the square S=[−1,1]². We use Adam with the default parameters and set the weight decay factor to 10⁻⁴. The block size is set to 100, which means that mini-batch gradient descent is used: the network processes training samples in batches of 100 at a time. The network has 30+20+10 neurons in its hidden layers (not counting the bias neurons). We store the network topology in t.

t:=[2,30,20,10,1]:; net:=neural_network(t,momentum=adaptive,weight_decay=1e-4,block_size=100)
     
a neural network with input of size 2 and output of size 1           

Next we create 5000 training samples in S by using the uniform random variable U(−1,1).

f(x):=x[0]*sin(3*x[1])*cos(2*x[0]*x[1]):;
data:=ranm(5000,2,uniformd(-1,1)):; res:=apply(f,data):;

Now we have data points in data and the corresponding function values in res. In a similar manner, we create a collection of another 100 samples, which will be kept unseen by the network and used for testing its accuracy.

test_data:=ranm(100,2,uniformd(-1,1)):; test_res:=apply(f,test_data):;

Next we train the network for 2500 epochs, checking the accuracy on the test data every 250 epochs.

for epoch from 1 to 2500 do net:=train(net,data,res); if irem(epoch,250)==0 then print(mean(net(test_data,test_res))); fi; od:;

0.00211048030912
0.000199757341385
8.70954607301e-05
6.21486919568e-05
5.22746108944e-05
5.0011469063e-05
4.91138941048e-05
4.81631000381e-05
4.86611973063e-05
4.79773288935e-05

Evaluation time: 16.85

Note that half-MSE is used as the error function by default, since this is a regression model. Now we generate a random point x0 in S and compute the predicted and the exact value of f(x0).

x0:=ranv(2,uniformd(-1,1))
     

−0.402978600934,−0.836934269406
          
net(x0),f(x0)
     
0.18592080555,0.185619807512           
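Since half-MSE is the error function, the error at a single sample x0 is 0.5·(net(x0)−f(x0))²; as a quick sketch (the exact value depends on the trained weights):

0.5*(net(x0)-f(x0))^2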

To plot the learned surface, use the command:

plot3d(quote(net([x1,x2])),x1=-1..1,x2=-1..1)
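For comparison, the exact surface can be plotted in the same way (here f accepts the point as a list, as defined above, so no quoting is needed):

plot3d(f([x1,x2]),x1=-1..1,x2=-1..1)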
Nonlinear separation of data.

Let

f(t) = 2/5 + (3/2)·(t − 1/2)²,

which defines a parabola that splits the unit square S=[0,1]² into two regions. We generate 1024 random points in S and label each of them as "below" or "above", depending on whether it lies below or above the parabola.

f(t):=0.4+1.5*(t-0.5)^2:; g:=unapply(x[1]<f(x[0])?"below":"above",x):; pts:=ranm(1024,2,uniformd(0,1)):; lab:=apply(g,pts):;

Next we create a neural network with four hidden layers which we train to label random points in S. The error function used by default is the log-loss function since we have a binary classifier.
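For reference, the log-loss of a predicted probability ŷ∈(0,1) against a target y∈{0,1} is defined as

E(ŷ,y) = −(y·ln(ŷ) + (1−y)·ln(1−ŷ)).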

params:=seq[momentum=adaptive,weight_decay=1e-4,block_size=128]:; net:=neural_network([2,10$4,1],classes=["below","above"],params)
     
a classifier with input of size 2 and 2 classes           

We train on the generated data with batch size 128 for 500 epochs. The training data is shuffled before each epoch in order to avoid symmetry.

for epoch from 1 to 500 do p:=randperm(size(pts)); net:=train(net,sortperm(pts,p),sortperm(lab,p),128); od:;

Now we test the accuracy of the classifier by using 1000 random test samples, which we store in tst. The number of misclassified points is the Hamming distance between the vector of predicted labels net(tst) and the vector of correct labels, which we obtain with the command apply(g,tst).

tst:=ranm(1000,2,uniformd(0,1)):; (1-hamdist(net(tst),apply(g,tst))/size(tst))*100.0
     
99.8           
Recognizing handwritten digits.

Here we train a neural network on the MNIST dataset in PNG format, which can be obtained here. This dataset contains 60000 grayscale training images of handwritten digits in 28×28 resolution, alongside 10000 testing images (see Figure 22.1).


Figure 22.1: Handwritten digits from the MNIST dataset

Let us assume that the contents of mnist_png.tar.gz are unpacked in the Downloads folder. Now put the files mnist_training.csv and mnist_testing.csv, which can be obtained here, into the subfolders training and testing, respectively. These CSV files contain image paths and labels. We use these files to load and encode training and testing data in Xcas.
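Each line of these CSV files is expected to hold the relative path of one image together with its label; a hypothetical excerpt (the actual file names and the label format shown here are assumptions):

0/16585.png,zero
0/24537.png,zero
1/6101.png,one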

First we load the training data, which takes several minutes. Note that we flatten and normalize the images, so that the training vectors contain numbers in the interval [0,1] (see Section 28.1).

train_path:="/home/luka/Downloads/mnist_png/training/":; train_csv:=csv2gen(train_path+"mnist_training.csv",","):; train_data:=[0$size(train_csv)]:; train_lab:=col(train_csv,1):; for k from 1 to size(train_csv) do train_data[k-1]=<flatten(image(train_path+train_csv[k-1,0]))/255.0; od:;

We load the testing images in a similar manner and store them in test_data and test_lab.
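For instance (a sketch mirroring the training code above, assuming the same folder layout):

test_path:="/home/luka/Downloads/mnist_png/testing/":; test_csv:=csv2gen(test_path+"mnist_testing.csv",","):; test_data:=[0$size(test_csv)]:; test_lab:=col(test_csv,1):; for k from 1 to size(test_csv) do test_data[k-1]=<flatten(image(test_path+test_csv[k-1,0]))/255.0; od:;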

Now we create a neural network with three hidden layers for classifying handwritten digits. It uses ReLU activation in the hidden layers and He normal initialization for the weights.

c:=["zero","one","two","three","four","five","six","seven","eight","nine"]:; params:=seq[func=ReLU,weights="he-normal",momentum=adaptive,weight_decay=1e-4]:; net:=neural_network([28*28,500$3,10],block_size=100,classes=c,params);
     
a classifier with input of size 784 and 10 classes           

We train the network with batch size 100 for 5 epochs. The training data is shuffled before each epoch, and the mean error on the testing data is printed after each epoch. The training takes about a minute.

for epoch from 1 to 5 do p:=randperm(size(train_data)); net:=train(net,sortperm(train_data,p),sortperm(train_lab,p),100); print(mean(net(test_data,test_lab))); od:;

0.110725006182
0.104843072908
0.0859572165559
0.0675629083633
0.0626279369745

The printed error values are computed using the cross-entropy function, which is used by default in multiclass classifiers. To test the accuracy of the network, use the following command.

(1-hamdist(net(test_data),test_lab)/size(test_data))*100.0
     
97.88           
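To classify a single image, pass its vector directly to the network; for instance, the following sketch compares the prediction for the first test image with its true label (the output depends on the trained weights):

net(test_data[0]),test_lab[0]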
