Update Batch Normalization Statistics in Custom Training Loop

This example shows how to update the network state in a custom training loop.

A batch normalization layer normalizes each input channel across a mini-batch. To speed up training of convolutional neural networks and reduce the sensitivity to network initialization, use batch normalization layers between convolutional layers and nonlinearities, such as ReLU layers.

During training, batch normalization layers first normalize the activations of each channel by subtracting the mini-batch mean and dividing by the mini-batch standard deviation. Then, the layer scales the normalized input by a learnable scale factor γ and shifts it by a learnable offset β.
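The per-channel computation described above can be written out directly in plain MATLAB for a 4-D mini-batch (height × width × channels × observations). This is a minimal sketch on random data: the array sizes, the values of γ and β, and the small constant epsilon added for numerical stability (1e-5 here) are assumptions for illustration, and the variance uses the 1/N (biased) form that batch normalization usually applies to a mini-batch.

```matlab
% Sketch: per-channel batch normalization of a 4-D mini-batch.
rng(0);
X = randn(4, 4, 3, 8);      % hypothetical mini-batch: 4x4 images, 3 channels, 8 observations
gamma = ones(1, 1, 3);      % learnable scale, one value per channel (assumed initial value)
beta  = zeros(1, 1, 3);     % learnable offset, one value per channel (assumed initial value)
epsilon = 1e-5;             % assumed small constant for numerical stability

% Statistics are taken over the spatial and batch dimensions (1, 2, and 4),
% leaving one mean and one variance per channel (size 1x1x3).
N = size(X,1)*size(X,2)*size(X,4);
mu = sum(sum(sum(X, 1), 2), 4) / N;
sigma2 = sum(sum(sum((X - mu).^2, 1), 2), 4) / N;

% Normalize, then scale by gamma and shift by beta (implicit expansion).
Xhat = (X - mu) ./ sqrt(sigma2 + epsilon);
Y = gamma .* Xhat + beta;
```

With γ = 1 and β = 0, every channel of Y has mean zero and variance just under one.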

When network training finishes, batch normalization layers calculate the mean and variance over the full training set and store the values in the TrainedMean and TrainedVariance properties. When you use a trained network to make predictions on new images, the batch normalization layers use the trained mean and variance instead of the mini-batch mean and variance to normalize the activations.

To compute the data set statistics, batch normalization layers keep track of the mini-batch statistics by using a continually updating state. If you are implementing a custom training loop, then you must update the network state between mini-batches.
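One common way to keep such a continually updating estimate is an exponential moving average, in which each new mini-batch contributes a fixed fraction of the stored value. The sketch below assumes that scheme, a decay factor of 0.1, and a single channel of synthetic activations; it illustrates the idea and is not taken from the batch normalization layer's implementation. The last lines show how the stored statistics, rather than mini-batch statistics, would be used to normalize new data at prediction time.

```matlab
% Sketch: running estimates of one channel's mean and variance.
% Assumed scheme: exponential moving average with decay 0.1 (illustrative only).
decay = 0.1;
runningMean = 0;            % hypothetical initial state
runningVar  = 1;
rng(1);
for k = 1:2000
    batch = 3 + 2*randn(128, 1);    % synthetic activations: mean 3, standard deviation 2
    runningMean = (1 - decay)*runningMean + decay*mean(batch);
    runningVar  = (1 - decay)*runningVar  + decay*var(batch);
end

% At prediction time, normalize new activations with the stored statistics
% instead of the statistics of the current mini-batch.
xNew = [3; 5; 1];                                   % hypothetical new activations
xNorm = (xNew - runningMean) ./ sqrt(runningVar + 1e-5);
```

The running values settle near the true mean 3 and variance 4, so xNorm is close to [0; 1; -1]. If the updated state were discarded after each mini-batch, these estimates would never move from their initial values.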

Load Training Data

Load the digits data.

[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);

Define Network

Define the network and specify the average image using the 'Mean' option in the image input layer.

layers = [
    imageInputLayer([28 28 1], 'Name', 'input', 'Mean', mean(XTrain,4))
    convolution2dLayer(5, 20, 'Name', 'conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3, 20, 'Padding', 1, 'Name', 'conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name', 'relu2')
    convolution2dLayer(3, 20, 'Padding', 1, 'Name', 'conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name', 'relu3')
    fullyConnectedLayer(numClasses, 'Name', 'fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Create a dlnetwork object from the layer graph.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}

View the network state. Each batch normalization layer has a TrainedMean parameter and a TrainedVariance parameter containing the data set mean and variance, respectively.

dlnet.State
ans=6×3 table
    Layer        Parameter             Value     
    _____    _________________    _______________

    "bn1"    "TrainedMean"        {1×1×20 single}
    "bn1"    "TrainedVariance"    {1×1×20 single}
    "bn2"    "TrainedMean"        {1×1×20 single}
    "bn2"    "TrainedVariance"    {1×1×20 single}
    "bn3"    "TrainedMean"        {1×1×20 single}
    "bn3"    "TrainedVariance"    {1×1×20 single}

Define Model Gradients Function

Create the function modelGradients, listed at the end of the example, which takes as input a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the gradients of the loss with respect to the learnable parameters in dlnet, the network state, and the loss.

Specify Training Options

Train with a mini-batch size of 128 for 5 epochs. For SGDM optimization, specify a learning rate of 0.01 and a momentum of 0.9.

numEpochs = 5;
miniBatchSize = 128;
learnRate = 0.01;
momentum = 0.9;
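To see what the learning rate and momentum control, the sketch below applies the standard SGDM update (velocity = momentum*velocity - learnRate*gradient, then parameters = parameters + velocity) to a toy quadratic loss. It assumes this textbook form of the update for a single parameter vector; the toy loss and its minimizer are hypothetical, and the training loop itself relies on sgdmupdate rather than this code.

```matlab
% Sketch: SGDM on a toy loss L(w) = 0.5*sum((w - target).^2).
learnRate = 0.01;
momentum = 0.9;
target = [2; -1];               % hypothetical minimizer of the toy loss
w = [0; 0];                     % parameters
velocity = zeros(size(w));      % velocity starts at zero
for iter = 1:1000
    g = w - target;                               % gradient of the toy loss
    velocity = momentum*velocity - learnRate*g;   % accumulate a decaying sum of steps
    w = w + velocity;                             % move the parameters
end
```

With momentum 0.9, each step is roughly ten times larger than a plain gradient step of the same learning rate once the velocity builds up, which is why a small learning rate such as 0.01 still converges here.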

Visualize the training progress in a plot.

plots = "training-progress";

Train on a GPU if one is available. Using a GPU requires Parallel Computing Toolbox™ and a CUDA®-enabled NVIDIA® GPU with compute capability 3.0 or higher.

executionEnvironment = "auto";

Train Model

Train the model using a custom training loop.

For each epoch, shuffle the data and loop over mini-batches of data. After each iteration, update the training progress plot.

For each mini-batch:

  • Convert the labels to dummy variables.

  • Convert the data to dlarray objects with underlying type single and specify the dimension labels 'SSCB' (spatial, spatial, channel, batch).

  • For GPU training, convert the data to gpuArray objects.

  • Evaluate the model gradients, state, and loss using dlfeval and the modelGradients function and update the network state.

  • Update the network parameters using the sgdmupdate function.
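The dummy-variable step in the list above produces one column per observation with a single 1 in the row of its class. The training code builds this with a loop over classes; as an alternative sketch, linear indexing with sub2ind gives the same matrix in one assignment. The integer labels here are hypothetical stand-ins for the categorical digit labels.

```matlab
% Sketch: one-hot ("dummy variable") encoding of integer class labels.
numClasses = 4;
labels = [2 4 1 2 3];                         % hypothetical class indices for 5 observations
Y = zeros(numClasses, numel(labels), 'single');
% Set row labels(j) of column j to 1 for every observation j.
Y(sub2ind(size(Y), labels, 1:numel(labels))) = 1;
```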

Initialize the training progress plot.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Initialize the velocity parameter for the SGDM solver.

velocity = [];

Train the network.

numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);
        
        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(single(X),'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function and update the network state.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);
        dlnet.State = state;
                
        % Update the network parameters using the SGDM optimizer.
        [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, momentum);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end
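Because numIterationsPerEpoch uses floor, each epoch processes only complete mini-batches and skips the observations left over at the end of the shuffled data. The arithmetic below assumes the digits training set contains 5000 images; since the data is reshuffled at the start of every epoch, a different small set of observations is skipped each time.

```matlab
% Sketch: how many observations each epoch uses when only
% complete mini-batches are processed.
numObservations = 5000;     % assumed number of training images
miniBatchSize = 128;
numIterationsPerEpoch = floor(numObservations/miniBatchSize);   % 39
numUsed = numIterationsPerEpoch*miniBatchSize;                   % 4992
numSkipped = numObservations - numUsed;                          % 8
```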

Test Model

Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.

[XTest, YTest] = digitTest4DArrayData;

Convert the data to a dlarray object with dimension format 'SSCB'. For GPU prediction, also convert the data to gpuArray.

dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end

Classify the images using the modelPredictions function, listed at the end of the example, and find the classes with the highest scores.

dlYPred = modelPredictions(dlnet,dlXTest,miniBatchSize);
[~,idx] = max(gather(extractdata(dlYPred)),[],1);
YPred = classes(idx);

Evaluate the classification accuracy.

accuracy = mean(YPred == YTest)
accuracy = 0.9948
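The prediction step reduces each column of scores to the row index of its largest entry, maps that index to a class name, and then compares the result with the true labels. The toy example below uses hypothetical scores and a cell array of class names in place of the categorical digit labels.

```matlab
% Sketch: from a score matrix (classes x observations) to accuracy.
scores = [0.7 0.1 0.2;
          0.2 0.8 0.3;
          0.1 0.1 0.5];                 % hypothetical scores for 3 observations
classNames = {'a'; 'b'; 'c'};           % hypothetical class names
[~, idx] = max(scores, [], 1);          % highest-scoring row in each column
predicted = classNames(idx);            % column cell array of predicted names
trueLabels = {'a'; 'b'; 'b'};
accuracy = mean(strcmp(predicted, trueLabels));   % 2 of 3 correct
```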

Model Gradients Function

The modelGradients function takes as input a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the gradients of the loss with respect to the learnable parameters in dlnet, the network state, and the loss. To compute the gradients automatically, use the dlgradient function.

function [gradients,state,loss] = modelGradients(dlnet,dlX,Y)

[dlYPred,state] = forward(dlnet,dlX);

loss = crossentropy(dlYPred,Y);
gradients = dlgradient(loss,dlnet.Learnables);

end
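For one-hot targets, the cross-entropy loss is the negative log of the score assigned to the correct class, averaged over the observations. The sketch below writes this out in plain MATLAB on hypothetical softmax outputs; dividing by the number of observations is assumed to match the default normalization of crossentropy, so check the function's documentation for your release.

```matlab
% Sketch: cross-entropy between softmax scores and one-hot targets,
% averaged over the observations (columns).
scores  = [0.7 0.2;
           0.2 0.5;
           0.1 0.3];          % hypothetical softmax outputs: 3 classes x 2 observations
targets = [1 0;
           0 1;
           0 0];              % one-hot labels
N = size(scores, 2);
loss = -sum(targets(:) .* log(scores(:))) / N;   % -(log(0.7) + log(0.5))/2
```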

Model Predictions Function

The modelPredictions function takes as input a dlnetwork object dlnet, an array of input data dlX, and a mini-batch size, and returns the model predictions by iterating over mini-batches of the specified size.

function dlYPred = modelPredictions(dlnet,dlX,miniBatchSize)

numObservations = size(dlX,4);
numIterations = ceil(numObservations / miniBatchSize);

numClasses = dlnet.Layers(11).OutputSize;
dlYPred = zeros(numClasses,numObservations,'like',dlX);

for i = 1:numIterations
    idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    
    dlYPred(:,idx) = predict(dlnet,dlX(:,:,:,idx));
end

end
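Unlike the training loop, modelPredictions uses ceil for the number of iterations and min for the last index, so the final, shorter mini-batch is still processed and every observation gets a prediction. The sketch below checks this indexing, assuming a test set of 5000 images.

```matlab
% Sketch: mini-batch indices that cover every observation.
numObservations = 5000;     % assumed number of test images
miniBatchSize = 128;
numIterations = ceil(numObservations / miniBatchSize);   % 40
covered = false(1, numObservations);
for i = 1:numIterations
    idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize, numObservations);
    covered(idx) = true;    % mark the observations in this mini-batch
end
lastBatchSize = numel(idx); % the final mini-batch holds 4993:5000
```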
