This example shows how to update the network state in a network defined as a function.
A batch normalization operation normalizes each input channel across a mini-batch. To speed up training of convolutional neural networks and reduce the sensitivity to network initialization, use batch normalization operations between convolutions and nonlinearities, such as ReLU layers.
During training, batch normalization operations first normalize the activations of each channel by subtracting the mini-batch mean and dividing by the mini-batch standard deviation. Then, the operation shifts the input by a learnable offset β and scales it by a learnable scale factor γ.
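In equation form, for each activation x_i in a mini-batch B, the operation computes

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y_i = \gamma\,\hat{x}_i + \beta,$$

where $\mu_B$ and $\sigma_B^2$ are the mini-batch mean and variance, and $\epsilon$ is a small constant added for numerical stability (a standard implementation detail not shown in the description above).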
When you use a trained network to make predictions on new data, the batch normalization operations use the trained data set mean and variance instead of the mini-batch mean and variance to normalize the activations.
To compute the data set statistics, you must keep track of the mini-batch statistics by using a continually updating state.
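A common way to maintain this state is an exponential moving average of the mini-batch statistics. The sketch below is illustrative only; the decay value and the batchMean and batchVariance variables are hypothetical, and in this example the batchnorm function performs the equivalent update internally and returns the updated statistics.

% Illustrative sketch, not part of the example: exponential moving average
% of mini-batch statistics with a hypothetical decay factor.
decay = 0.1;
trainedMean     = (1-decay)*trainedMean     + decay*batchMean;
trainedVariance = (1-decay)*trainedVariance + decay*batchVariance;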
If you use batch normalization operations in a model function, then you must define the behavior for both training and prediction. For example, you can specify a Boolean option doTraining to control whether the model uses mini-batch statistics for training or data set statistics for prediction.
This piece of example code from a model function shows how to apply a batch normalization operation and update the data set statistics during training only.
if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm1.TrainedMean = trainedMean;
    state.batchnorm1.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end
The digitTrain4DArrayData function loads the images, their digit labels, and their angles of rotation from the vertical.
[XTrain,YTrain,anglesTrain] = digitTrain4DArrayData;

classNames = categories(YTrain);
numClasses = numel(classNames);
numObservations = numel(YTrain);
View some images from the training data.
idx = randperm(numObservations,64);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)
Define the following network, which predicts both labels and angles of rotation.
A block of convolution, batch normalization, and ReLU operations with 16 5-by-5 filters
A branch of two blocks of convolution and batch normalization operations, each with 32 3-by-3 filters and separated by a ReLU operation
A skip connection with a block of convolution and batch normalization operations with 32 1-by-1 convolutions
An addition operation followed by a ReLU operation that combines both branches
For the regression output, a branch with a fully connect operation of size 1 (the number of responses)
For the classification output, a branch with a fully connect operation of size 10 (the number of classes) and a softmax operation
Define the parameters for each of the operations and include them in a struct. Use the format parameters.OperationName.ParameterName, where parameters is the struct, OperationName is the name of the operation (for example, conv1), and ParameterName is the name of the parameter (for example, Weights).
Create a struct parameters containing the model parameters. Initialize the learnable layer weights using the example function initializeGaussian, listed at the end of the example. Initialize the learnable layer biases with zeros. Initialize the batch normalization offset and scale parameters with zeros and ones, respectively.
To perform training and inference using batch normalization operations, you must also manage the network state. Before prediction, you must specify the data set mean and variance derived from the training data. Create a struct state containing the state parameters. Initialize the batch normalization trained mean and trained variance states with zeros and ones, respectively.
parameters.conv1.Weights = dlarray(initializeGaussian([5,5,1,16]));
parameters.conv1.Bias = dlarray(zeros(16,1,'single'));

parameters.batchnorm1.Offset = dlarray(zeros(16,1,'single'));
parameters.batchnorm1.Scale = dlarray(ones(16,1,'single'));
state.batchnorm1.TrainedMean = zeros(16,1,'single');
state.batchnorm1.TrainedVariance = ones(16,1,'single');

parameters.convSkip.Weights = dlarray(initializeGaussian([1,1,16,32]));
parameters.convSkip.Bias = dlarray(zeros(32,1,'single'));

parameters.batchnormSkip.Offset = dlarray(zeros(32,1,'single'));
parameters.batchnormSkip.Scale = dlarray(ones(32,1,'single'));
state.batchnormSkip.TrainedMean = zeros(32,1,'single');
state.batchnormSkip.TrainedVariance = ones(32,1,'single');

parameters.conv2.Weights = dlarray(initializeGaussian([3,3,16,32]));
parameters.conv2.Bias = dlarray(zeros(32,1,'single'));

parameters.batchnorm2.Offset = dlarray(zeros(32,1,'single'));
parameters.batchnorm2.Scale = dlarray(ones(32,1,'single'));
state.batchnorm2.TrainedMean = zeros(32,1,'single');
state.batchnorm2.TrainedVariance = ones(32,1,'single');

parameters.conv3.Weights = dlarray(initializeGaussian([3,3,32,32]));
parameters.conv3.Bias = dlarray(zeros(32,1,'single'));

parameters.batchnorm3.Offset = dlarray(zeros(32,1,'single'));
parameters.batchnorm3.Scale = dlarray(ones(32,1,'single'));
state.batchnorm3.TrainedMean = zeros(32,1,'single');
state.batchnorm3.TrainedVariance = ones(32,1,'single');

parameters.fc2.Weights = dlarray(initializeGaussian([numClasses,6272]));
parameters.fc2.Bias = dlarray(zeros(numClasses,1,'single'));

parameters.fc1.Weights = dlarray(initializeGaussian([1,6272]));
parameters.fc1.Bias = dlarray(zeros(1,1,'single'));
View the state struct.
state
state = struct with fields:
batchnorm1: [1×1 struct]
batchnormSkip: [1×1 struct]
batchnorm2: [1×1 struct]
batchnorm3: [1×1 struct]
View the state parameters for the batchnorm1 operation.
state.batchnorm1
ans = struct with fields:
TrainedMean: [16×1 single]
TrainedVariance: [16×1 single]
Create the function model, listed at the end of the example, which computes the outputs of the deep learning model described earlier.
The function model takes as input the input data dlX, the model parameters parameters, the flag doTraining, which specifies whether the model returns outputs for training or prediction, and the network state state. The function returns the predictions for the labels, the predictions for the angles, and the updated network state.
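For example, the modelGradients function later in this example performs a training-mode forward pass like this:

% Forward pass in training mode. The returned state contains the updated
% batch normalization statistics.
doTraining = true;
[dlY1,dlY2,state] = model(dlX,parameters,doTraining,state);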
Create the function modelGradients, listed at the end of the example, which takes as input a mini-batch of input data dlX with corresponding targets T1 and T2 containing the labels and angles, respectively, and returns the gradients of the loss with respect to the learnable parameters, the updated network state, and the corresponding loss.
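Because modelGradients calls dlgradient, evaluate it through dlfeval so that the computation is traced for automatic differentiation, as in the training loop below:

% Evaluate the model gradients, updated state, and loss inside dlfeval.
[gradients,state,loss] = dlfeval(@modelGradients,dlX,Y1,Y2,parameters,state);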
Specify the training options.
numEpochs = 20;
miniBatchSize = 128;
plots = "training-progress";
numIterationsPerEpoch = floor(numObservations./miniBatchSize);
Train on a GPU if one is available. Using a GPU requires Parallel Computing Toolbox™ and a CUDA®-enabled NVIDIA® GPU with compute capability 3.0 or higher.
executionEnvironment = "auto";
Train the model using a custom training loop.
For each epoch, shuffle the data and loop over mini-batches of data. At the end of each epoch, display the training progress.
For each mini-batch:
Convert the labels to dummy variables.
Convert the data to dlarray objects with underlying type single and specify the dimension labels 'SSCB' (spatial, spatial, channel, batch).
For GPU training, convert the data to gpuArray objects.
Evaluate the model gradients and loss using dlfeval and the modelGradients function.
Update the network parameters using the adamupdate function.
Initialize the training progress plot.
if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end
Initialize the parameters for the Adam solver.
trailingAvg = [];
trailingAvgSq = [];
Train the model.
iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numObservations);
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);
    anglesTrain = anglesTrain(idx);

    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;

        % Read mini-batch of data and convert the labels to dummy variables.
        X = XTrain(:,:,:,idx);

        Y1 = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y1(c,YTrain(idx)==classNames(c)) = 1;
        end

        Y2 = anglesTrain(idx)';
        Y2 = single(Y2);

        % Convert mini-batch of data to dlarray.
        dlX = dlarray(X,'SSCB');

        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end

        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function.
        [gradients,state,loss] = dlfeval(@modelGradients, dlX, Y1, Y2, parameters, state);

        % Update the network parameters using the Adam optimizer.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAvgSq,iteration);

        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end
Test the classification accuracy of the model by comparing the predictions on a test set with the true labels and angles.
[XTest,YTest,anglesTest] = digitTest4DArrayData;
Convert the data to a dlarray object with dimension format 'SSCB'. For GPU prediction, also convert the data to gpuArray.
dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end
To predict the labels and angles of the test data, use the modelPredictions function, listed at the end of the example.
[dlYPred,anglesPred] = modelPredictions(parameters,state,dlXTest,miniBatchSize);
Evaluate the classification accuracy.
[~,idx] = max(extractdata(dlYPred),[],1);
labelsPred = classNames(idx);
accuracy = mean(labelsPred==YTest)
accuracy = 0.9912
Evaluate the regression accuracy.
angleRMSE = sqrt(mean((extractdata(anglesPred) - anglesTest').^2))
angleRMSE = 6.1576
View some of the images with their predictions. Display the predicted angles in red and the correct angles in green.
idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on

    sz = size(I,1);
    offset = sz/2;

    thetaPred = extractdata(anglesPred(idx(i)));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--')

    thetaValidation = anglesTest(idx(i));
    plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--')

    hold off
    label = string(labelsPred(idx(i)));
    title("Label: " + label)
end
The function model takes as input the input data dlX, the model parameters parameters, the flag doTraining, which specifies whether the model returns the outputs for training or prediction, and the network state state. The function returns the predictions for the labels, the predictions for the angles, and the updated network state.
function [dlY1,dlY2,state] = model(dlX,parameters,doTraining,state)

% Convolution
weights = parameters.conv1.Weights;
bias = parameters.conv1.Bias;
dlY = dlconv(dlX,weights,bias,'Padding',2);

% Batch normalization, ReLU
offset = parameters.batchnorm1.Offset;
scale = parameters.batchnorm1.Scale;
trainedMean = state.batchnorm1.TrainedMean;
trainedVariance = state.batchnorm1.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm1.TrainedMean = trainedMean;
    state.batchnorm1.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

dlY = relu(dlY);

% Convolution, batch normalization (skip connection)
weights = parameters.convSkip.Weights;
bias = parameters.convSkip.Bias;
dlYSkip = dlconv(dlY,weights,bias,'Stride',2);

offset = parameters.batchnormSkip.Offset;
scale = parameters.batchnormSkip.Scale;
trainedMean = state.batchnormSkip.TrainedMean;
trainedVariance = state.batchnormSkip.TrainedVariance;

if doTraining
    [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnormSkip.TrainedMean = trainedMean;
    state.batchnormSkip.TrainedVariance = trainedVariance;
else
    dlYSkip = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);
end

% Convolution
weights = parameters.conv2.Weights;
bias = parameters.conv2.Bias;
dlY = dlconv(dlY,weights,bias,'Padding',1,'Stride',2);

% Batch normalization, ReLU
offset = parameters.batchnorm2.Offset;
scale = parameters.batchnorm2.Scale;
trainedMean = state.batchnorm2.TrainedMean;
trainedVariance = state.batchnorm2.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm2.TrainedMean = trainedMean;
    state.batchnorm2.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

dlY = relu(dlY);

% Convolution
weights = parameters.conv3.Weights;
bias = parameters.conv3.Bias;
dlY = dlconv(dlY,weights,bias,'Padding',1);

% Batch normalization
offset = parameters.batchnorm3.Offset;
scale = parameters.batchnorm3.Scale;
trainedMean = state.batchnorm3.TrainedMean;
trainedVariance = state.batchnorm3.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm3.TrainedMean = trainedMean;
    state.batchnorm3.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

% Addition, ReLU
dlY = dlYSkip + dlY;
dlY = relu(dlY);

% Fully connect (angles)
weights = parameters.fc1.Weights;
bias = parameters.fc1.Bias;
dlY2 = fullyconnect(dlY,weights,bias);

% Fully connect, softmax (labels)
weights = parameters.fc2.Weights;
bias = parameters.fc2.Bias;
dlY1 = fullyconnect(dlY,weights,bias);
dlY1 = softmax(dlY1);

end
The modelGradients function takes as input a mini-batch of the input data dlX with corresponding targets T1 and T2 containing the labels and angles, respectively, and returns the gradients of the loss with respect to the learnable parameters, the updated network state, and the corresponding loss.
function [gradients,state,loss] = modelGradients(dlX,T1,T2,parameters,state)

doTraining = true;
[dlY1,dlY2,state] = model(dlX,parameters,doTraining,state);

lossLabels = crossentropy(dlY1,T1);
lossAngles = mse(dlY2,T2);

loss = lossLabels + 0.1*lossAngles;
gradients = dlgradient(loss,parameters);

end
The modelPredictions function takes the model parameters, the network state, an array of input data dlX, and a mini-batch size, and returns the model predictions by iterating over mini-batches of the specified size using the model function with the doTraining option set to false.
function [dlYPred,anglesPred] = modelPredictions(parameters,state,dlX,miniBatchSize)

doTraining = false;
numObservations = size(dlX,4);
numIterations = ceil(numObservations / miniBatchSize);

numClasses = size(parameters.fc2.Weights,1);
dlYPred = zeros(numClasses,numObservations,'like',dlX);
anglesPred = zeros(1,numObservations,'like',dlX);

for i = 1:numIterations
    idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);

    [dlYPred(:,idx),anglesPred(idx)] = model(dlX(:,:,:,idx),parameters,doTraining,state);
end

end
The initializeGaussian function samples weights from a Gaussian distribution with mean 0 and standard deviation 0.01.
function parameter = initializeGaussian(sz)
parameter = randn(sz,'single').*0.01;
end
See also: adamupdate | batchnorm | crossentropy | dlarray | dlconv | dlfeval | dlgradient | fullyconnect | relu | softmax