Train Network Using Cyclical Learn Rate for Snapshot Ensembling

This example shows how to train a network to classify images of objects using a cyclical learn rate schedule and snapshot ensembling for better test accuracy. In the example, you learn how to use a cosine function for the learn rate schedule, take snapshots of the network during training to create a model ensemble, and add L2-norm regularization (weight decay) to the training loss.

This example trains a residual network [1] on the CIFAR-10 data set [2] with a custom cyclical learn rate: for each iteration, the solver uses the learn rate given by a shifted cosine function [3] alpha(t) = (alpha0/2)*(cos(pi*mod(t-1,T/M)/(T/M))+1), where t is the iteration number, T is the total number of training iterations, alpha0 is the initial learn rate, and M is the number of cycles/snapshots. This learn rate schedule effectively splits the training process into M cycles of T/M iterations each. Each cycle begins with a large learn rate that decays monotonically toward zero. The large learn rate at the start of a cycle lets the network move away from the minimum it reached in the previous cycle, and the decay lets it settle into a different local minimum. At the end of each training cycle, you take a snapshot of the network (that is, you save the model at this iteration). Later, you average the predictions of all the snapshot models, a technique known as snapshot ensembling [4], to improve the final test accuracy.
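To see how the schedule restarts, the following sketch evaluates the formula above for the toy values T = 20 and M = 4. These values are assumptions chosen only to keep the output short; the training later in the example uses its own values. The learn rate starts each cycle at alpha0 and decays toward zero over T/M = 5 iterations.

```matlab
% Toy values chosen only for illustration (not the training settings).
alpha0 = 0.1;        % initial learn rate
T = 20;              % total number of training iterations
M = 4;               % number of cycles/snapshots
cycleLength = T/M;   % iterations per cycle

% Shifted cosine learn rate for every iteration t = 1,...,T.
t = 1:T;
alpha = (alpha0/2)*(cos(pi*mod(t-1,cycleLength)/cycleLength) + 1);

% Display one row per cycle: every row restarts at alpha0 and decays.
disp(reshape(alpha,cycleLength,M).')
```

The learn rate at the last iteration of a cycle is small but not exactly zero: here it is (alpha0/2)*(1 - cos(pi/5)), roughly 0.0095.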

Prepare Data

Download the CIFAR-10 data set [2]. The data set contains 60,000 images. Each image is 32-by-32 in size and has three color channels (RGB). The size of the data set is 175 MB. Depending on your internet connection, the download process can take time.

datadir = tempdir; 
downloadCIFARData(datadir);

Load the CIFAR-10 training and test images as 4-D arrays. The training set contains 50,000 images and the test set contains 10,000 images.

[XTrain,YTrain,XTest,YTest] = loadCIFARData(datadir);
classes = categories(YTrain);
numClasses = numel(classes);
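The image arrays are height-by-width-by-channel-by-observation, so the fourth dimension indexes the images. This is why the example uses size(XTrain,4) to count images and mean(XTrain,4) to compute the average image for the 'Mean' option later. The sketch below illustrates the layout on a small random stand-in array (hypothetical data, not CIFAR-10).

```matlab
% Stand-in for XTrain: 5 random 32-by-32 RGB images (hypothetical data).
X = rand(32,32,3,5,'single');

numImages = size(X,4);      % observations run along dimension 4
firstImage = X(:,:,:,1);    % a single 32-by-32-by-3 image
avgImage = mean(X,4);       % per-pixel, per-channel mean over all images

disp(size(firstImage))
disp(size(avgImage))
```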

You can display a random sample of the training images using the following code.

figure;
idx = randperm(size(XTrain,4),20);
im = imtile(XTrain(:,:,:,idx),'ThumbnailSize',[96,96]);
imshow(im)

Create an augmentedImageDatastore object to use for network training. During training, the datastore randomly flips the training images along the vertical axis and randomly translates them up to four pixels horizontally and vertically. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(imageSize,XTrain,YTrain, ...
    'DataAugmentation',imageAugmenter);
augimdsTest = augmentedImageDatastore(imageSize,XTest,YTest);
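As a rough illustration of what the two augmentations do, the following sketch applies a left-right reflection and a fixed translation with zero fill to a toy 4-by-4 single-channel image. It is a simplified sketch under stated assumptions (integer, nonnegative shifts and a fill value of 0); imageDataAugmenter picks the translation at random from pixelRange and its own implementation may differ in details such as interpolation.

```matlab
% Toy 4-by-4 single-channel "image" (hypothetical data for illustration).
img = reshape(1:16,4,4);

% Reflection across the vertical axis: the columns are reversed,
% analogous to 'RandXReflection'.
reflected = img(:,end:-1:1);

% Translation by dx columns to the right and dy rows down. Vacated
% pixels are filled with zeros. Assumes nonnegative integer shifts.
dx = 1;
dy = 2;
translated = zeros(size(img));
translated(1+dy:end,1+dx:end) = img(1:end-dy,1:end-dx);

disp(reflected)
disp(translated)
```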

Define Network Architecture

Create a residual network [1] with six standard convolutional units (two units per stage) and a width of 16. The total network depth is 2*6+2 = 14. In addition, specify the average image using the 'Mean' option in the image input layer.

netWidth = 16;
layers = [
    imageInputLayer(imageSize,'Name','input','Mean', mean(XTrain,4))
    convolution2dLayer(3,netWidth,'Padding','same','Name','convInp')
    batchNormalizationLayer('Name','BNInp')
    reluLayer('Name','reluInp')
    
    convolutionalUnit(netWidth,1,'S1U1')
    additionLayer(2,'Name','add11')
    reluLayer('Name','relu11')
    convolutionalUnit(netWidth,1,'S1U2')
    additionLayer(2,'Name','add12')
    reluLayer('Name','relu12')
    
    convolutionalUnit(2*netWidth,2,'S2U1')
    additionLayer(2,'Name','add21')
    reluLayer('Name','relu21')
    convolutionalUnit(2*netWidth,1,'S2U2')
    additionLayer(2,'Name','add22')
    reluLayer('Name','relu22')
    
    convolutionalUnit(4*netWidth,2,'S3U1')
    additionLayer(2,'Name','add31')
    reluLayer('Name','relu31')
    convolutionalUnit(4*netWidth,1,'S3U2')
    additionLayer(2,'Name','add32')
    reluLayer('Name','relu32')
    
    averagePooling2dLayer(8,'Name','globalPool')
    fullyConnectedLayer(10,'Name','fcFinal')
    ];

lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph,'reluInp','add11/in2');
lgraph = connectLayers(lgraph,'relu11','add12/in2');
skip1 = [
    convolution2dLayer(1,2*netWidth,'Stride',2,'Name','skipConv1')
    batchNormalizationLayer('Name','skipBN1')];
lgraph = addLayers(lgraph,skip1);
lgraph = connectLayers(lgraph,'relu12','skipConv1');
lgraph = connectLayers(lgraph,'skipBN1','add21/in2');
lgraph = connectLayers(lgraph,'relu21','add22/in2');
skip2 = [
    convolution2dLayer(1,4*netWidth,'Stride',2,'Name','skipConv2')
    batchNormalizationLayer('Name','skipBN2')];
lgraph = addLayers(lgraph,skip2);
lgraph = connectLayers(lgraph,'relu22','skipConv2');
lgraph = connectLayers(lgraph,'skipBN2','add31/in2');
lgraph = connectLayers(lgraph,'relu31','add32/in2');
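The depth of 14 and the pool size of 8 follow from the architecture. Each unit contains two convolutional layers, and the input convolution and the final fully connected layer add two more. With 'same' padding, a convolution with stride s produces an output of size ceil(inputSize/s), so the two stride-2 stages reduce the 32-by-32 input to 8-by-8, and averagePooling2dLayer(8) therefore pools over the whole feature map. The following sketch checks this arithmetic; the variable names are hypothetical.

```matlab
inputSize = 32;          % CIFAR-10 images are 32-by-32
numUnits = 6;            % convolutional units in the network
convLayersPerUnit = 2;   % each unit has two 3-by-3 convolutions

% Depth: convolutions in all units plus the input convolution and the
% final fully connected layer.
depth = convLayersPerUnit*numUnits + 2;

% Stage 1 keeps the resolution; stages 2 and 3 start with a stride-2
% convolution. With 'same' padding the output size is ceil(size/stride).
stageStrides = [1 2 2];
spatialSize = inputSize;
for s = stageStrides
    spatialSize = ceil(spatialSize/s);
end

fprintf('Depth: %d, final feature map: %d-by-%d\n', depth, spatialSize, spatialSize)
```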

Plot the ResNet architecture.

figure;
plot(lgraph)

Create a dlnetwork object from the layer graph.

dlnet = dlnetwork(lgraph);

Define Model Gradients Function

Create the helper function modelGradients, listed at the end of the example. The function takes in a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the gradients of the loss with respect to the learnable parameters in dlnet. This function also returns the loss and the state of the nonlearnable parameters of the network at a given iteration.
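The loss that modelGradients differentiates is the cross-entropy of the softmax of the network output plus an L2 penalty, weightDecay*0.5*l2Norm, on the weights. The following sketch reproduces that computation by hand on toy values, assuming that crossentropy averages the loss over the observations in the mini-batch. The scores, targets, and weights are hypothetical.

```matlab
% Hypothetical output of the final fully connected layer:
% 3 classes (rows) by 2 observations (columns).
scores = [2.0 0.5;
          1.0 2.5;
          0.1 0.3];
% One-hot targets: observation 1 is class 1, observation 2 is class 2.
targets = [1 0;
           0 1;
           0 0];

% Softmax over the class dimension (subtract the column max for stability).
expScores = exp(scores - max(scores,[],1));
probs = expScores ./ sum(expScores,1);

% Cross-entropy averaged over the observations.
numObs = size(scores,2);
crossEntropyLoss = -sum(targets .* log(probs),'all')/numObs;

% L2 penalty on hypothetical weight arrays, as in modelGradients.
weights = {[0.5 -0.5], [1 2; 3 4]};
weightDecay = 1e-4;
l2Norm = sum(cellfun(@(w) sum(w.^2,'all'), weights));
loss = crossEntropyLoss + weightDecay*0.5*l2Norm;

disp(loss)
```

The factor 0.5 makes the gradient of the penalty equal to weightDecay times the weights, which is the usual form of weight decay.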

Specify Training Options

Specify the training options.

velocity = [];
numEpochs = 200;
miniBatchSize = 64;
augimdsTrain.MiniBatchSize = miniBatchSize;
numObservations = numel(YTrain);
numIterationsPerEpoch = ceil(numObservations./miniBatchSize); % the last mini-batch of an epoch can be partial
momentum = 0.9;
weightDecay = 1e-4;
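The velocity, momentum, and learn rate are used by sgdmupdate. The following sketch shows the stochastic gradient descent with momentum (SGDM) rule as we read the sgdmupdate documentation: the velocity accumulates a decaying sum of past gradient steps, and the parameters move by the velocity. It is a hand-written illustration on a hypothetical two-element parameter vector, not the library implementation.

```matlab
% Hypothetical parameters and (constant) loss gradient.
param = [1.0; -2.0];
grad = [0.5; 0.5];

momentum = 0.9;
learnRate = 0.1;
vel = zeros(size(param));   % velocity starts at zero, like velocity = []

% Two SGDM steps with the same gradient.
for step = 1:2
    vel = momentum*vel - learnRate*grad;   % accumulate velocity
    param = param + vel;                   % move parameters by velocity
end

disp(param)
```

Because the velocity carries over between iterations, the second step is larger than the first even though the gradient is the same.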

Specify the training options specific to the cyclical learn rate. alpha0 is the initial learn rate, numSnapshots is the number of cycles or snapshots M taken during training, and iterationsPerSnapshot is the number of iterations in each cycle, T/M.

alpha0 = 0.1;
numSnapshots = 5;
epochsPerSnapshot = numEpochs./numSnapshots;
iterationsPerSnapshot = ceil(numObservations./miniBatchSize)*epochsPerSnapshot;
modelPrefix = "SnapshotEpoch";
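The following sketch works out the cycle length for the settings above, assuming that the datastore also returns the final, partial mini-batch of each epoch (so an epoch has ceil, not floor, of numObservations/miniBatchSize iterations). The variable names mirror the training options.

```matlab
numObservations = 50000;   % CIFAR-10 training images
miniBatchSize = 64;
numEpochs = 200;
numSnapshots = 5;

% 50000/64 = 781.25, so each epoch has 781 full mini-batches and one
% partial mini-batch.
iterationsPerEpoch = ceil(numObservations/miniBatchSize);
epochsPerSnapshot = numEpochs/numSnapshots;
iterationsPerSnapshot = iterationsPerEpoch*epochsPerSnapshot;   % T/M
totalIterations = iterationsPerEpoch*numEpochs;                 % T

fprintf('%d iterations per epoch, %d per cycle, %d in total\n', ...
    iterationsPerEpoch, iterationsPerSnapshot, totalIterations)
```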

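Specify the execution environment. With the value "auto", the example trains on a GPU if one is available; set the value to "cpu" or "gpu" to choose the device explicitly. Training on a GPU requires Parallel Computing Toolbox™ and a supported GPU device.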

executionEnvironment = "auto";

Initialize the training figure.

[lossLine, learnRateLine] = plotLossAndLearnRate();

Train Model

Train the model using a custom training loop.

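For each epoch, shuffle the datastore, loop over mini-batches of data, and save the model as a snapshot if the epoch number is a multiple of epochsPerSnapshot. At the end of each epoch, display the training progress. The displayed loss is the loss of the last mini-batch in the epoch, so it fluctuates from epoch to epoch.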

For each mini-batch:

  • Convert the labels to dummy variables.

  • Convert the data to dlarray objects with underlying type single and specify the dimension labels 'SSCB' (spatial, spatial, channel, batch).

  • For GPU training, convert the mini-batch data to gpuArray objects.

  • Evaluate the model gradients and loss using dlfeval and the modelGradients function.

  • Update the state of the nonlearnable parameters of the network.

  • Determine the learn rate for the cyclical learn rate schedule.

  • Update the network parameters using the sgdmupdate function.

  • Plot the loss and learn rate at each iteration.
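The dummy-variable conversion in the first step produces a numClasses-by-miniBatchSize matrix with a single 1 in each column. The following sketch applies the same loop to a hypothetical three-class mini-batch of four labels.

```matlab
% Hypothetical class names and mini-batch labels.
classes = {'cat'; 'dog'; 'ship'};
numClasses = numel(classes);
TrueClasses = categorical({'dog'; 'ship'; 'dog'; 'cat'}, classes);

% One row per class, one column per observation; a 1 marks the true class.
Y = zeros(numClasses, numel(TrueClasses), 'single');
for c = 1:numClasses
    Y(c,TrueClasses==classes{c}) = 1;
end

disp(Y)
```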

For this example, training took approximately 19 hours on an NVIDIA™ GeForce GTX 1080.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Reset image datastore.
    reset(augimdsTrain);
  
    % Shuffle data.
    augimdsTrain = shuffle(augimdsTrain);
    
    % Loop over mini-batches.
    while hasdata(augimdsTrain)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        data = read(augimdsTrain);
        
        % Concatenate the inputs.
        Xdata = data{:,1};
        X = cat(4,Xdata{:});

        % Convert the labels to dummy variables.
        TrueClasses = data{:,2};
        Y = zeros(numClasses, numel(TrueClasses), 'single');
        for c = 1:numClasses
            Y(c,TrueClasses==classes(c)) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(single(X),'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradients, loss, state] = dlfeval(@modelGradients,dlnet,dlX,Y,weightDecay);
        
        % Update the state of nonlearnable parameters.
        dlnet.State = state;
        
        % Determine the learn rate for the cyclical learn rate schedule.
        learnRate = 0.5*alpha0*(cos((pi*mod(iteration-1,iterationsPerSnapshot)./iterationsPerSnapshot))+1);
        
        % Update the network parameters using the SGDM optimizer.
        [dlnet.Learnables, velocity] = sgdmupdate(dlnet.Learnables, gradients, velocity, learnRate, momentum);
        
        % Plot loss and learn rate for current iteration.
        loss = double(gather(extractdata(loss)));
        addpoints(lossLine, iteration, loss);
        addpoints(learnRateLine, iteration, learnRate);
        drawnow
        
    end
    
    % Save a snapshot of the model at the end of each learn rate cycle,
    % that is, after the last epoch of the cycle has finished training.
    if ~mod(epoch,epochsPerSnapshot)
        save(modelPrefix + epoch + ".mat",'dlnet');
    end
    
    % Display the training progress.
    D = duration(0,0,toc(start),'Format','hh:mm:ss');
    disp( ...
        "Epoch: " + epoch + ", " + ...
        "Loss: " + num2str(loss) + ", " + ...
        "Elapsed: " + string(D))
end
Epoch: 1, Loss: 1.784, Elapsed: 00:05:59
Epoch: 2, Loss: 1.1388, Elapsed: 00:11:32
Epoch: 3, Loss: 1.0594, Elapsed: 00:17:02
Epoch: 4, Loss: 1.3154, Elapsed: 00:22:34
Epoch: 5, Loss: 0.7968, Elapsed: 00:28:08
Epoch: 6, Loss: 0.85319, Elapsed: 00:33:47
Epoch: 7, Loss: 0.5477, Elapsed: 00:39:26
Epoch: 8, Loss: 1.1158, Elapsed: 00:45:14
Epoch: 9, Loss: 0.39734, Elapsed: 00:50:47
Epoch: 10, Loss: 0.47909, Elapsed: 00:56:26
Epoch: 11, Loss: 0.68743, Elapsed: 01:02:08
Epoch: 12, Loss: 0.45266, Elapsed: 01:07:51
Epoch: 13, Loss: 1.0058, Elapsed: 01:13:27
Epoch: 14, Loss: 0.64816, Elapsed: 01:19:14
Epoch: 15, Loss: 1.1039, Elapsed: 01:25:05
Epoch: 16, Loss: 1.2072, Elapsed: 01:30:41
Epoch: 17, Loss: 0.43167, Elapsed: 01:36:24
Epoch: 18, Loss: 1.1376, Elapsed: 01:42:40
Epoch: 19, Loss: 0.76857, Elapsed: 01:48:40
Epoch: 20, Loss: 0.77434, Elapsed: 01:54:22
Epoch: 21, Loss: 0.98595, Elapsed: 01:59:57
Epoch: 22, Loss: 0.78628, Elapsed: 02:05:32
Epoch: 23, Loss: 0.55069, Elapsed: 02:11:07
Epoch: 24, Loss: 0.52066, Elapsed: 02:16:43
Epoch: 25, Loss: 0.44842, Elapsed: 02:22:18
Epoch: 26, Loss: 0.40094, Elapsed: 02:27:55
Epoch: 27, Loss: 0.78839, Elapsed: 02:33:30
Epoch: 28, Loss: 0.47829, Elapsed: 02:39:05
Epoch: 29, Loss: 0.21833, Elapsed: 02:44:41
Epoch: 30, Loss: 0.5759, Elapsed: 02:50:16
Epoch: 31, Loss: 1.1089, Elapsed: 02:55:50
Epoch: 32, Loss: 0.37353, Elapsed: 03:01:25
Epoch: 33, Loss: 0.30851, Elapsed: 03:07:01
Epoch: 34, Loss: 0.34735, Elapsed: 03:12:36
Epoch: 35, Loss: 0.28772, Elapsed: 03:18:11
Epoch: 36, Loss: 0.31045, Elapsed: 03:23:47
Epoch: 37, Loss: 0.28555, Elapsed: 03:29:22
Epoch: 38, Loss: 0.897, Elapsed: 03:34:58
Epoch: 39, Loss: 0.69014, Elapsed: 03:40:33
Epoch: 40, Loss: 0.26282, Elapsed: 03:46:17
Epoch: 41, Loss: 1.0086, Elapsed: 03:51:53
Epoch: 42, Loss: 0.47303, Elapsed: 03:57:27
Epoch: 43, Loss: 1.3765, Elapsed: 04:03:02
Epoch: 44, Loss: 0.54884, Elapsed: 04:08:39
Epoch: 45, Loss: 0.38778, Elapsed: 04:14:14
Epoch: 46, Loss: 0.74121, Elapsed: 04:19:49
Epoch: 47, Loss: 0.78481, Elapsed: 04:25:25
Epoch: 48, Loss: 0.44624, Elapsed: 04:31:01
Epoch: 49, Loss: 0.81747, Elapsed: 04:36:38
Epoch: 50, Loss: 0.40319, Elapsed: 04:42:14
Epoch: 51, Loss: 0.87757, Elapsed: 04:47:51
Epoch: 52, Loss: 1.0567, Elapsed: 04:53:27
Epoch: 53, Loss: 0.29019, Elapsed: 04:59:03
Epoch: 54, Loss: 0.92056, Elapsed: 05:04:40
Epoch: 55, Loss: 0.45776, Elapsed: 05:10:16
Epoch: 56, Loss: 1.0265, Elapsed: 05:15:52
Epoch: 57, Loss: 0.55256, Elapsed: 05:21:29
Epoch: 58, Loss: 1.0822, Elapsed: 05:27:06
Epoch: 59, Loss: 0.78332, Elapsed: 05:32:44
Epoch: 60, Loss: 0.48247, Elapsed: 05:38:20
Epoch: 61, Loss: 0.86749, Elapsed: 05:43:58
Epoch: 62, Loss: 0.64667, Elapsed: 05:49:34
Epoch: 63, Loss: 0.64563, Elapsed: 05:55:10
Epoch: 64, Loss: 0.58239, Elapsed: 06:00:46
Epoch: 65, Loss: 0.29219, Elapsed: 06:06:23
Epoch: 66, Loss: 0.37627, Elapsed: 06:11:59
Epoch: 67, Loss: 0.34035, Elapsed: 06:17:35
Epoch: 68, Loss: 0.34809, Elapsed: 06:23:11
Epoch: 69, Loss: 0.61085, Elapsed: 06:28:47
Epoch: 70, Loss: 0.42018, Elapsed: 06:34:24
Epoch: 71, Loss: 0.3739, Elapsed: 06:40:00
Epoch: 72, Loss: 0.23083, Elapsed: 06:45:37
Epoch: 73, Loss: 0.21324, Elapsed: 06:51:14
Epoch: 74, Loss: 0.18931, Elapsed: 06:56:55
Epoch: 75, Loss: 0.88882, Elapsed: 07:02:31
Epoch: 76, Loss: 0.36844, Elapsed: 07:08:09
Epoch: 77, Loss: 0.76548, Elapsed: 07:13:46
Epoch: 78, Loss: 0.42548, Elapsed: 07:19:24
Epoch: 79, Loss: 0.29112, Elapsed: 07:25:01
Epoch: 80, Loss: 0.17333, Elapsed: 07:30:45
Epoch: 81, Loss: 0.50322, Elapsed: 07:36:22
Epoch: 82, Loss: 0.40387, Elapsed: 07:41:58
Epoch: 83, Loss: 0.3939, Elapsed: 07:47:34
Epoch: 84, Loss: 0.79005, Elapsed: 07:53:11
Epoch: 85, Loss: 0.51953, Elapsed: 07:58:46
Epoch: 86, Loss: 0.65925, Elapsed: 08:04:24
Epoch: 87, Loss: 0.49915, Elapsed: 08:10:01
Epoch: 88, Loss: 0.58721, Elapsed: 08:15:38
Epoch: 89, Loss: 0.57397, Elapsed: 08:21:15
Epoch: 90, Loss: 0.51315, Elapsed: 08:26:53
Epoch: 91, Loss: 0.42037, Elapsed: 08:32:30
Epoch: 92, Loss: 0.41111, Elapsed: 08:38:06
Epoch: 93, Loss: 0.71338, Elapsed: 08:43:43
Epoch: 94, Loss: 0.31452, Elapsed: 08:49:21
Epoch: 95, Loss: 0.35696, Elapsed: 08:54:58
Epoch: 96, Loss: 0.56142, Elapsed: 09:00:36
Epoch: 97, Loss: 0.69246, Elapsed: 09:06:15
Epoch: 98, Loss: 0.40288, Elapsed: 09:11:53
Epoch: 99, Loss: 0.67491, Elapsed: 09:17:31
Epoch: 100, Loss: 0.70555, Elapsed: 09:23:08
Epoch: 101, Loss: 0.45978, Elapsed: 09:28:47
Epoch: 102, Loss: 0.3963, Elapsed: 09:34:27
Epoch: 103, Loss: 0.60798, Elapsed: 09:40:05
Epoch: 104, Loss: 0.41759, Elapsed: 09:45:45
Epoch: 105, Loss: 0.45068, Elapsed: 09:51:23
Epoch: 106, Loss: 1.103, Elapsed: 09:57:02
Epoch: 107, Loss: 0.29916, Elapsed: 10:02:41
Epoch: 108, Loss: 0.64019, Elapsed: 10:08:21
Epoch: 109, Loss: 0.26558, Elapsed: 10:13:59
Epoch: 110, Loss: 0.41303, Elapsed: 10:19:38
Epoch: 111, Loss: 0.74221, Elapsed: 10:25:18
Epoch: 112, Loss: 0.48748, Elapsed: 10:30:56
Epoch: 113, Loss: 0.27348, Elapsed: 10:36:35
Epoch: 114, Loss: 0.51661, Elapsed: 10:42:14
Epoch: 115, Loss: 0.27831, Elapsed: 10:47:54
Epoch: 116, Loss: 0.35103, Elapsed: 10:53:33
Epoch: 117, Loss: 0.19571, Elapsed: 10:59:11
Epoch: 118, Loss: 0.37368, Elapsed: 11:04:50
Epoch: 119, Loss: 0.18644, Elapsed: 11:10:29
Epoch: 120, Loss: 0.48589, Elapsed: 11:16:16
Epoch: 121, Loss: 0.74257, Elapsed: 11:21:57
Epoch: 122, Loss: 0.65423, Elapsed: 11:27:37
Epoch: 123, Loss: 0.35185, Elapsed: 11:33:17
Epoch: 124, Loss: 0.81636, Elapsed: 11:38:55
Epoch: 125, Loss: 0.49292, Elapsed: 11:44:34
Epoch: 126, Loss: 0.9133, Elapsed: 11:50:14
Epoch: 127, Loss: 0.80498, Elapsed: 11:55:53
Epoch: 128, Loss: 0.59473, Elapsed: 12:01:33
Epoch: 129, Loss: 0.60313, Elapsed: 12:07:12
Epoch: 130, Loss: 0.5426, Elapsed: 12:12:50
Epoch: 131, Loss: 1.3471, Elapsed: 12:18:29
Epoch: 132, Loss: 0.35591, Elapsed: 12:24:08
Epoch: 133, Loss: 0.75186, Elapsed: 12:29:49
Epoch: 134, Loss: 0.98765, Elapsed: 12:35:29
Epoch: 135, Loss: 0.65345, Elapsed: 12:41:08
Epoch: 136, Loss: 0.78963, Elapsed: 12:46:48
Epoch: 137, Loss: 0.38269, Elapsed: 12:52:27
Epoch: 138, Loss: 0.5309, Elapsed: 12:58:06
Epoch: 139, Loss: 0.4119, Elapsed: 13:03:45
Epoch: 140, Loss: 0.93898, Elapsed: 13:09:26
Epoch: 141, Loss: 0.45791, Elapsed: 13:15:04
Epoch: 142, Loss: 0.70093, Elapsed: 13:20:43
Epoch: 143, Loss: 0.84997, Elapsed: 13:26:23
Epoch: 144, Loss: 0.27732, Elapsed: 13:32:05
Epoch: 145, Loss: 0.51171, Elapsed: 13:37:44
Epoch: 146, Loss: 0.81123, Elapsed: 13:43:24
Epoch: 147, Loss: 0.5678, Elapsed: 13:49:04
Epoch: 148, Loss: 0.58568, Elapsed: 13:54:44
Epoch: 149, Loss: 0.3952, Elapsed: 14:00:23
Epoch: 150, Loss: 0.31967, Elapsed: 14:06:03
Epoch: 151, Loss: 0.44051, Elapsed: 14:11:46
Epoch: 152, Loss: 0.99278, Elapsed: 14:17:27
Epoch: 153, Loss: 0.87306, Elapsed: 14:23:07
Epoch: 154, Loss: 0.34008, Elapsed: 14:28:47
Epoch: 155, Loss: 0.4687, Elapsed: 14:34:27
Epoch: 156, Loss: 0.22836, Elapsed: 14:40:07
Epoch: 157, Loss: 0.23204, Elapsed: 14:45:48
Epoch: 158, Loss: 0.36854, Elapsed: 14:51:28
Epoch: 159, Loss: 0.35363, Elapsed: 14:57:08
Epoch: 160, Loss: 0.37937, Elapsed: 15:02:55
Epoch: 161, Loss: 0.7725, Elapsed: 15:08:36
Epoch: 162, Loss: 0.59353, Elapsed: 15:14:15
Epoch: 163, Loss: 0.57963, Elapsed: 15:19:54
Epoch: 164, Loss: 0.54625, Elapsed: 15:25:35
Epoch: 165, Loss: 0.65612, Elapsed: 15:31:15
Epoch: 166, Loss: 0.73254, Elapsed: 15:36:56
Epoch: 167, Loss: 0.4483, Elapsed: 15:42:37
Epoch: 168, Loss: 0.36817, Elapsed: 15:48:17
Epoch: 169, Loss: 0.57539, Elapsed: 15:53:57
Epoch: 170, Loss: 1.0026, Elapsed: 15:59:37
Epoch: 171, Loss: 0.95288, Elapsed: 16:05:17
Epoch: 172, Loss: 0.83053, Elapsed: 16:10:59
Epoch: 173, Loss: 0.41976, Elapsed: 16:16:39
Epoch: 174, Loss: 0.44098, Elapsed: 16:22:19
Epoch: 175, Loss: 0.58823, Elapsed: 16:28:00
Epoch: 176, Loss: 0.67325, Elapsed: 16:33:41
Epoch: 177, Loss: 0.27045, Elapsed: 16:39:21
Epoch: 178, Loss: 0.66652, Elapsed: 16:45:02
Epoch: 179, Loss: 1.0097, Elapsed: 16:50:43
Epoch: 180, Loss: 0.40372, Elapsed: 16:56:23
Epoch: 181, Loss: 0.39175, Elapsed: 17:02:04
Epoch: 182, Loss: 0.40741, Elapsed: 17:07:45
Epoch: 183, Loss: 0.35398, Elapsed: 17:13:25
Epoch: 184, Loss: 0.63228, Elapsed: 17:19:05
Epoch: 185, Loss: 0.35308, Elapsed: 17:24:45
Epoch: 186, Loss: 0.46854, Elapsed: 17:30:27
Epoch: 187, Loss: 0.51346, Elapsed: 17:36:08
Epoch: 188, Loss: 0.71886, Elapsed: 17:41:48
Epoch: 189, Loss: 0.73986, Elapsed: 17:47:29
Epoch: 190, Loss: 0.46669, Elapsed: 17:53:10
Epoch: 191, Loss: 0.40962, Elapsed: 17:58:51
Epoch: 192, Loss: 0.25007, Elapsed: 18:04:31
Epoch: 193, Loss: 0.45651, Elapsed: 18:10:12
Epoch: 194, Loss: 0.20788, Elapsed: 18:15:52
Epoch: 195, Loss: 0.32097, Elapsed: 18:21:32
Epoch: 196, Loss: 0.28159, Elapsed: 18:27:15
Epoch: 197, Loss: 0.20396, Elapsed: 18:32:56
Epoch: 198, Loss: 0.30823, Elapsed: 18:38:37
Epoch: 199, Loss: 0.28583, Elapsed: 18:44:18
Epoch: 200, Loss: 0.32877, Elapsed: 18:50:07

Create Snapshot Ensemble

Combine the M snapshots of the network taken during training to form a final ensemble. The ensemble predictions correspond to the average of the output of the fully connected layer from all M individual models.

YPredictions = zeros(numClasses,numel(YTest),numSnapshots);
modelAccuracy = zeros(numSnapshots+1,1);
modelName = cell(numSnapshots+1,1);
for m = 1:numSnapshots
    modelName{m} = modelPrefix + m*epochsPerSnapshot;
    s = load(modelName{m} + ".mat");
    dlnet = s.dlnet;
    YPredictions(:,:,m) = gather(extractdata(predict(dlnet, dlarray(single(XTest),'SSCB'))));
    modelAccuracy(m) = computeAccuracy(YPredictions(:,:,m), YTest, classes);
    disp(modelName{m} + " accuracy: " + modelAccuracy(m) + "%")
end
SnapshotEpoch40 accuracy: 88.04%
SnapshotEpoch80 accuracy: 86.78%
SnapshotEpoch120 accuracy: 87.53%
SnapshotEpoch160 accuracy: 87.07%
SnapshotEpoch200 accuracy: 88.39%
modelAccuracy(end) = computeAccuracy(mean(YPredictions,3), YTest, classes);
modelName{end} = "Ensemble model";
disp("Ensemble accuracy: " + modelAccuracy(end) + "%")
Ensemble accuracy: 91.13%
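To illustrate why averaging helps, the following sketch applies the same averaging and argmax steps to constructed scores from three hypothetical snapshot models on four test images. The scores are made up so that each model misclassifies a different image; they are not results from this example.

```matlab
% Hypothetical scores: classes-by-observations-by-models, like YPredictions.
scores = zeros(2,4,3);
scores(:,:,1) = [2 0 1 0; 0 1 0 3];
scores(:,:,2) = [1 1 0 0; 0 0 2 1];
scores(:,:,3) = [3 0 0 2; 0 2 1 0];
trueLabels = [1 2 2 2];   % index of the true class of each image

% Accuracy of each individual model.
modelAccuracy = zeros(1,3);
for m = 1:3
    [~,pred] = max(scores(:,:,m),[],1);
    modelAccuracy(m) = 100*mean(pred == trueLabels);
end

% Ensemble: average over the model dimension, then take the top class.
[~,ensemblePred] = max(mean(scores,3),[],1);
ensembleAccuracy = 100*mean(ensemblePred == trueLabels);

disp(modelAccuracy)
disp(ensembleAccuracy)
```

Each model alone is right on three of the four images, but because their mistakes differ, the averaged scores classify all four correctly.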

Plot Accuracy

Plot the accuracy on the test data set for all snapshot models and the ensemble model.

figure
bar(modelAccuracy)
ylabel('Accuracy (%)');
xticklabels(modelName)
xtickangle(45)
title('Model accuracy')

Helper Functions

modelGradients Function

The modelGradients function takes in a dlnetwork object dlnet, a mini-batch of input data dlX, the labels Y, and the weight decay factor weightDecay. The function returns the gradients, the loss (the cross-entropy loss plus an L2 penalty on the weights and batch normalization scale factors), and the state of the nonlearnable parameters. To compute the gradients automatically, use the dlgradient function.

function [gradients, loss, state] = modelGradients(dlnet, dlX, Y, weightDecay)

[dlYPred, state] = forward(dlnet, dlX);
dlYPred = softmax(dlYPred);

loss = crossentropy(dlYPred, Y);

% L2-regularization (weight decay)
allParams = dlnet.Learnables(dlnet.Learnables.Parameter == "Weights" | dlnet.Learnables.Parameter == "Scale",:).Value;
l2Norm = cellfun(@(x) sum(x.^2,'All'), allParams, 'UniformOutput', false);
l2Norm = sum(cat(1, l2Norm{:}));
loss = loss + weightDecay*0.5*l2Norm;

gradients = dlgradient(loss, dlnet.Learnables);
end

computeAccuracy Function

The computeAccuracy function uses the network predictions, the true labels, and the list of class names to calculate the classification accuracy as a percentage.

function accuracy = computeAccuracy(YPredictions, YTest, classes)
[~,I] = max(YPredictions,[],1);
C = classes(I);
accuracy = 100*(sum(C==YTest)/numel(C));
end

plotLossAndLearnRate Function

The plotLossAndLearnRate function plots the loss and learn rate at each iteration during training.

function [lossLine, learnRateLine] = plotLossAndLearnRate()
figure('Name','Training Progress');
clf
subplot(2,1,1); lossLine = animatedline;
title('Loss');
xlabel('Iteration')
ylabel('Loss')
grid on
subplot(2,1,2); learnRateLine = animatedline;
title('Learning rate');
xlabel('Iteration')
ylabel('Learning rate')
grid on
end

convolutionalUnit Function

convolutionalUnit(numF,stride,tag) creates an array of layers with two convolutional layers, each followed by a batch normalization layer, and a ReLU layer between them. The ReLU layer after each addition layer is part of the main layer array. numF is the number of convolutional filters, stride is the stride of the first convolutional layer, and tag is a tag that is prepended to all layer names.

function layers = convolutionalUnit(numF,stride,tag)
layers = [
    convolution2dLayer(3,numF,'Padding','same','Stride',stride,'Name',[tag,'conv1'])
    batchNormalizationLayer('Name',[tag,'BN1'])
    reluLayer('Name',[tag,'relu1'])
    convolution2dLayer(3,numF,'Padding','same','Name',[tag,'conv2'])
    batchNormalizationLayer('Name',[tag,'BN2'])];
end

References

[1] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep Residual Learning for Image Recognition." In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770-778. 2016.

[2] Krizhevsky, Alex. "Learning Multiple Layers of Features from Tiny Images." 2009. https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf

[3] Loshchilov, Ilya, and Frank Hutter. "SGDR: Stochastic Gradient Descent with Warm Restarts." arXiv preprint arXiv:1608.03983. 2016.

[4] Huang, Gao, Yixuan Li, Geoff Pleiss, Zhuang Liu, John E. Hopcroft, and Kilian Q. Weinberger. "Snapshot Ensembles: Train 1, Get M for Free." arXiv preprint arXiv:1704.00109. 2017.
