This example shows how to update the network state in a custom training loop.
A batch normalization layer normalizes each input channel across a mini-batch. To speed up training of convolutional neural networks and reduce the sensitivity to network initialization, use batch normalization layers between convolutional layers and nonlinearities, such as ReLU layers.
During training, batch normalization layers first normalize the activations of each channel by subtracting the mini-batch mean and dividing by the mini-batch standard deviation (with a small constant ε added to the variance for numerical stability). Then, the layer scales the normalized input by a learnable scale factor γ and shifts it by a learnable offset β.
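The following base MATLAB code is a minimal sketch of the per-channel arithmetic described above, applied to a plain numeric array in 'SSCB' layout. It is not the layer's internal implementation. The input data, ε = 1e-5, and the initial values γ = 1 and β = 0 are assumptions chosen for illustration.

```matlab
% Sketch: training-mode batch normalization on an 'SSCB' array
% (height x width x channel x batch). Data, epsilon, gamma, and beta
% are illustrative assumptions, not values taken from the layer.
X = reshape(sin(1:4*4*3*8), 4, 4, 3, 8);   % deterministic stand-in data
[H, W, C, N] = size(X);
epsilon = 1e-5;                             % small constant for numerical stability
gamma = ones(1, 1, C);                      % learnable scale factor, one per channel
beta = zeros(1, 1, C);                      % learnable offset, one per channel

% Put channels last and flatten so that column c holds every value of
% channel c across both spatial dimensions and the mini-batch.
Xc = reshape(permute(X, [1 2 4 3]), H*W*N, C);
mu = reshape(mean(Xc, 1), 1, 1, C);         % per-channel mini-batch mean
sigma2 = reshape(var(Xc, 1, 1), 1, 1, C);   % per-channel mini-batch variance

% Normalize, then scale by gamma and shift by beta (implicit expansion).
XHat = (X - mu) ./ sqrt(sigma2 + epsilon);
Y = gamma .* XHat + beta;
```

After normalization, every channel of XHat has a mean of approximately zero and a variance of approximately one, regardless of the channel's original scale.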
When network training finishes, batch normalization layers calculate the mean and variance over the full training set and store the values in the TrainedMean and TrainedVariance properties. When you use a trained network to make predictions on new images, the batch normalization layers use the trained mean and variance instead of the mini-batch mean and variance to normalize the activations.
To compute the data set statistics, batch normalization layers keep track of the mini-batch statistics by using a continually updating state. If you are implementing a custom training loop, then you must update the network state between mini-batches.
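One way to picture the continually updating state is as an exponential moving average of the mini-batch statistics. The sketch below assumes a decay factor of 0.1 and fixed, hypothetical mini-batch statistics so that the result is easy to check. Newer releases of batchNormalizationLayer expose comparable MeanDecay and VarianceDecay properties, but consult the documentation for your release for the exact update it performs.

```matlab
% Sketch: tracking data set statistics with an exponential moving average.
% The decay factor and the mini-batch statistics are illustrative assumptions.
decay = 0.1;
trainedMean = zeros(1, 3);          % initial state for three channels
trainedVariance = ones(1, 3);

for k = 1:200
    % Hypothetical mini-batch statistics, held fixed so the result is predictable.
    batchMean = [1 2 3];
    batchVariance = [4 5 6];

    % Blend the new statistics into the running estimates.
    trainedMean = decay*batchMean + (1 - decay)*trainedMean;
    trainedVariance = decay*batchVariance + (1 - decay)*trainedVariance;
end
disp(trainedMean)
disp(trainedVariance)
```

Because each update builds on the previous value, a custom loop that does not write the returned state back to the network would leave these estimates at their initial values.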
The digitTrain4DArrayData function loads images of handwritten digits and their digit labels. Create an arrayDatastore object for the images and the labels, and then use the combine function to make a single datastore that contains all of the training data. Extract the class names.
[XTrain,YTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsYTrain = arrayDatastore(YTrain);
dsTrain = combine(dsXTrain,dsYTrain);
classNames = categories(YTrain);
numClasses = numel(classNames);
Define the network and specify the average image using the 'Mean' option in the image input layer.
layers = [
    imageInputLayer([28 28 1], 'Name', 'input', 'Mean', mean(XTrain,4))
    convolution2dLayer(5, 20, 'Name', 'conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3, 20, 'Padding', 1, 'Name', 'conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name', 'relu2')
    convolution2dLayer(3, 20, 'Padding', 1, 'Name', 'conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name', 'relu3')
    fullyConnectedLayer(numClasses, 'Name', 'fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);
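With the image input layer's default 'zerocenter' normalization, specifying 'Mean' means the layer subtracts this average image from every input. A minimal base MATLAB sketch of that subtraction, using a small deterministic stand-in for XTrain (an assumption), is:

```matlab
% Sketch: zero-center images by subtracting the average image.
% XDemo is a small deterministic stand-in for XTrain (28x28x1xN).
numImages = 16;
XDemo = reshape(mod(1:28*28*numImages, 7), 28, 28, 1, numImages);
avgImage = mean(XDemo, 4);            % 28x28 average image, like mean(XTrain,4)
XCentered = XDemo - avgImage;         % implicit expansion along the 4th dimension
```

Averaging the centered images over the fourth dimension gives an image of zeros (up to rounding).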
Create a dlnetwork object from the layer graph.
dlnet = dlnetwork(lgraph)
dlnet =
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}
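The sizes in this display follow from the layer definitions. The 12 layers are connected in series, which gives 11 connections. Each convolution, batch normalization, and fully connected layer contributes two learnable parameters (14 rows), and each batch normalization layer contributes two state parameters (6 rows). The worked check below also computes the spatial sizes and the total number of learnable values, assuming 28-by-28-by-1 inputs, ten classes (the digits 0 through 9), and stride 1 for every convolution.

```matlab
% Worked check of the dlnet sizes. Assumptions: 28x28x1 inputs,
% 10 classes, stride 1, and 20 filters in each convolution layer.
inputSize = 28;
numClassesDemo = 10;
numFilters = 20;

% Spatial size: output = input - filterSize + 2*padding + 1
s1 = inputSize - 5 + 2*0 + 1;   % conv1: 5x5, no padding
s2 = s1 - 3 + 2*1 + 1;          % conv2: 3x3, padding 1 (size unchanged)
s3 = s2 - 3 + 2*1 + 1;          % conv3: 3x3, padding 1 (size unchanged)

% Learnable values: weights + bias, or scale + offset
conv1 = 5*5*1*numFilters + numFilters;
conv2 = 3*3*numFilters*numFilters + numFilters;
conv3 = conv2;
bn = 3*(numFilters + numFilters);                  % bn1, bn2, bn3
fc = s3*s3*numFilters*numClassesDemo + numClassesDemo;
totalLearnables = conv1 + conv2 + conv3 + bn + fc;

% Table row counts
numConnections = 12 - 1;            % layers in series
numLearnableRows = 3*2 + 3*2 + 2;   % conv, batch norm, fully connected
numStateRows = 3*2;                 % TrainedMean and TrainedVariance per BN layer

fprintf('Spatial size after conv layers: %d x %d\n', s3, s3);
fprintf('Total learnable values: %d\n', totalLearnables);
```

Under these assumptions, totalLearnables should equal the total number of elements across the Value column of dlnet.Learnables, and each state value is 1-by-1-by-20 because the batch normalization layers track one mean and one variance per channel.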
View the network state. Each batch normalization layer has a TrainedMean parameter and a TrainedVariance parameter containing the data set mean and variance, respectively.
dlnet.State
ans=6×3 table
Layer Parameter Value
_____ _________________ _______________
"bn1" "TrainedMean" {1×1×20 single}
"bn1" "TrainedVariance" {1×1×20 single}
"bn2" "TrainedMean" {1×1×20 single}
"bn2" "TrainedVariance" {1×1×20 single}
"bn3" "TrainedMean" {1×1×20 single}
"bn3" "TrainedVariance" {1×1×20 single}
Create the function modelGradients, listed at the end of the example, which takes as input a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the gradients of the loss with respect to the learnable parameters in dlnet, the updated network state, and the loss.
Train for five epochs using a mini-batch size of 128. For the SGDM optimization, specify a learning rate of 0.01 and a momentum of 0.9.
numEpochs = 5;
miniBatchSize = 128;
learnRate = 0.01;
momentum = 0.9;
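SGD with momentum keeps a velocity that accumulates past gradients. The sketch below applies the update v = momentum*v - learnRate*g followed by w = w + v to a toy quadratic objective whose gradient is w. The objective and starting point are assumptions for illustration. The velocity starts at zero, mirroring the empty initial velocity passed to sgdmupdate in the training loop.

```matlab
% Sketch: SGD with momentum on the toy objective f(w) = 0.5*sum(w.^2),
% whose gradient is w. Objective and starting point are assumptions.
learnRateDemo = 0.01;
momentumDemo = 0.9;
w = [1; -2];              % toy parameters
v = zeros(size(w));       % velocity, initialized to zero

for k = 1:500
    g = w;                                  % gradient of the toy objective
    v = momentumDemo*v - learnRateDemo*g;   % accumulate velocity
    w = w + v;                              % apply the update
end
disp(w)
```

With these settings the parameters spiral in toward the minimum at zero; the momentum term lets the small learning rate still make steady progress.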
Visualize the training progress in a plot.
plots = "training-progress";
Use minibatchqueue to process and manage the mini-batches of images. For each mini-batch:
Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to one-hot encode the class labels.
Format the image data with the dimension labels 'SSCB' (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not add a format to the class labels.
Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a CUDA®-enabled NVIDIA® GPU with compute capability 3.0 or higher.
mbq = minibatchqueue(dsTrain,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});
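A quick count shows how many iterations the training loop runs. This assumes that the training set contains 5000 images and that minibatchqueue returns the final partial mini-batch instead of discarding it, which is assumed here to be the default PartialMiniBatch behavior.

```matlab
% Worked count of mini-batches per epoch and total iterations.
% Assumptions: 5000 training images; the last partial mini-batch is kept.
numObservations = 5000;
miniBatchSizeDemo = 128;
numEpochsDemo = 5;

numFull = floor(numObservations/miniBatchSizeDemo);           % full mini-batches
lastBatchSize = numObservations - numFull*miniBatchSizeDemo;   % leftover images
batchesPerEpoch = ceil(numObservations/miniBatchSizeDemo);     % including the partial one
totalIterations = numEpochsDemo*batchesPerEpoch;

fprintf('%d full mini-batches + 1 of size %d per epoch, %d iterations total\n', ...
    numFull, lastBatchSize, totalIterations);
```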
Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. At the end of each iteration, display the training progress. For each mini-batch:
Evaluate the model gradients, state, and loss using dlfeval and the modelGradients function, and update the network state.
Update the network parameters using the sgdmupdate function.
Initialize the training progress plot.
if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end
Initialize the velocity parameter for the SGDM solver.
velocity = [];
Train the network.
iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    shuffle(mbq)

    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;

        % Read mini-batch of data and convert the labels to dummy
        % variables.
        [dlX,dlY] = next(mbq);

        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function and update the network state.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY);
        dlnet.State = state;

        % Update the network parameters using the SGDM optimizer.
        [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, momentum);

        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end
Test the classification accuracy of the model by comparing the predictions on a test set with the true labels. Manage the test data set using a minibatchqueue object with the same settings as the training data.
[XTest,YTest] = digitTest4DArrayData;
dsXTest = arrayDatastore(XTest,'IterationDimension',4);
dsYTest = arrayDatastore(YTest);
dsTest = combine(dsXTest,dsYTest);

mbqTest = minibatchqueue(dsTest,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});
Classify the images using the modelPredictions function, listed at the end of the example. The function returns the predicted classes and a logical vector that indicates whether each prediction matches the true label.
[classesPredictions,classCorr] = modelPredictions(dlnet,mbqTest,classNames);
Evaluate the classification accuracy.
accuracy = mean(classCorr)
accuracy = 0.9946
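Because classCorr is a logical vector, its mean is the fraction of correct predictions. The sketch below uses a hypothetical ten-element result to show the computation and the matching count of misclassified images.

```matlab
% Sketch: accuracy from a logical vector of per-image comparison results.
classCorrDemo = logical([1 1 0 1 1 1 1 1 0 1]);   % hypothetical results
accuracyDemo = mean(classCorrDemo);              % fraction of true entries
numWrong = sum(~classCorrDemo);                  % number of misclassified images
fprintf('Accuracy: %.2f (%d misclassified)\n', accuracyDemo, numWrong);
```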
The modelGradients function takes as input a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the gradients of the loss with respect to the learnable parameters in dlnet, the network state, and the loss. To compute the gradients automatically, use the dlgradient function.
function [gradients,state,loss] = modelGradients(dlnet,dlX,Y)

[dlYPred,state] = forward(dlnet,dlX);

loss = crossentropy(dlYPred,Y);
gradients = dlgradient(loss,dlnet.Learnables);

end
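To see what the loss measures for one-hot targets, the sketch below applies softmax to hypothetical scores in 'CB' layout (classes by observations) and averages the negative log-probability of the true class over the mini-batch. Averaging over observations is an assumption about the normalization; check the NormalizationFactor option of crossentropy in your release.

```matlab
% Sketch: softmax followed by cross-entropy for one-hot targets.
% Scores and targets are hypothetical; layout is classes x observations.
Z = [2.0 0.5;
     1.0 2.5;
     0.1 0.3];                                % scores for 3 classes, 2 observations
YPred = exp(Z) ./ sum(exp(Z), 1);             % softmax over the class dimension
T = [1 0;
     0 1;
     0 0];                                    % one-hot targets
numObs = size(T, 2);
loss = -sum(T(:) .* log(YPred(:))) / numObs;  % mean negative log-likelihood
disp(loss)
```

Only the probability assigned to the true class of each observation contributes, so the loss falls as the network becomes more confident in the correct class.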
The modelPredictions function takes as input a dlnetwork object dlnet, a minibatchqueue object mbq of input data, and the class names classes, and computes the model predictions by iterating over all data in the minibatchqueue. The function uses the onehotdecode function to find the predicted class with the highest score and then compares the prediction with the true class. The function returns the predictions and a logical vector in which true and false mark correct and incorrect predictions, respectively.
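The decoding step amounts to taking the index of the largest score along the class dimension. A base MATLAB sketch with hypothetical scores, class names, and true labels:

```matlab
% Sketch: decode class scores by taking the maximum along dimension 1.
classesDemo = {'0'; '1'; '2'};        % hypothetical class names
scores = [0.7 0.1 0.2;
          0.2 0.6 0.3;
          0.1 0.3 0.5];               % 3 classes x 3 observations
[~, idx] = max(scores, [], 1);        % row index of the largest score per column
predicted = classesDemo(idx);         % predicted class names
trueIdx = [1 2 2];                    % hypothetical true class indices
isCorrect = idx == trueIdx;           % logical vector, like classCorr
```

Here the third observation scores highest for class '2' while its true class is '1', so isCorrect marks it false.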
function [classesPredictions,classCorr] = modelPredictions(dlnet,mbq,classes)

classesPredictions = [];
classCorr = [];

while hasdata(mbq)
    [dlX,dlY] = next(mbq);

    % Make predictions using the model function.
    dlYPred = predict(dlnet,dlX);

    % Determine predicted classes.
    YPredBatch = onehotdecode(dlYPred,classes,1);
    classesPredictions = [classesPredictions YPredBatch];

    % Compare predicted and true classes
    Y = onehotdecode(dlY,classes,1);
    classCorr = [classCorr YPredBatch == Y];
end

end
The preprocessMiniBatch function preprocesses the data using the following steps:
Extract the image data from the incoming cell array and concatenate it into a numeric array. Concatenating the image data over the fourth dimension adds a third dimension to each image, to be used as a singleton channel dimension.
Extract the label data from the incoming cell array and concatenate it into a categorical array along the second dimension.
One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.
function [X,Y] = preprocessMiniBatch(XCell,YCell)

% Extract image data from cell and concatenate
X = cat(4,XCell{:});

% Extract label data from cell and concatenate
Y = cat(2,YCell{:});

% One-hot encode labels
Y = onehotencode(Y,1);

end
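The same three steps can be sketched with base MATLAB only, using hypothetical images and integer class indices in place of the categorical labels that onehotencode receives. The comparison (1:numClassesDemo)' == labelIdx relies on implicit expansion to build a classes-by-observations matrix.

```matlab
% Sketch: the preprocessing steps in base MATLAB.
% Images and class indices are hypothetical stand-ins.
XCell = {zeros(28), ones(28), 2*ones(28)};      % three 28x28 grayscale images
X = cat(4, XCell{:});                           % 28x28x1x3; dim 3 is the singleton channel
labelIdx = [3 1 2];                             % class index of each image
numClassesDemo = 3;
Y = double((1:numClassesDemo)' == labelIdx);    % one-hot, classes along dimension 1
disp(size(X))
disp(Y)
```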
See also: adamupdate | dlarray | dlfeval | dlgradient | dlnetwork | forward | minibatchqueue | onehotdecode | onehotencode | predict