This example shows how to train a network that classifies handwritten digits with a custom learning rate schedule.
If trainingOptions does not provide the options you need (for example, a custom learning rate schedule), then you can define your own custom training loop using automatic differentiation.
This example trains a network to classify handwritten digits with the time-based decay learning rate schedule: for each iteration, the solver uses the learning rate given by ρ_t = ρ_0/(1 + k·t), where t is the iteration number, ρ_0 is the initial learning rate, and k is the decay.
Load Training Data
Load the digits data.
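A minimal sketch of the data-loading step, assuming the digit sample data set that ships with Deep Learning Toolbox (digitTrain4DArrayData) is the intended source:

```matlab
% Load the digit training images (28-by-28-by-1-by-N array) and
% categorical labels.
[XTrain,YTrain] = digitTrain4DArrayData;

% The class names and class count are used later when converting
% labels to dummy variables and defining the output layer.
classes = categories(YTrain);
numClasses = numel(classes);
```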
Define Network
Define the network and specify the average image using the 'Mean' option in the image input layer.
Create a dlnetwork object from the layer graph.
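The following sketch defines one plausible architecture consistent with the network summary shown below (12 layers, input named 'input', output named 'softmax'); the filter sizes and channel counts are illustrative assumptions:

```matlab
layers = [
    % Specify the average image using the 'Mean' option.
    imageInputLayer([28 28 1],'Name','input','Mean',mean(XTrain,4))
    convolution2dLayer(5,20,'Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];

lgraph = layerGraph(layers);
dlnet = dlnetwork(lgraph)
```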
dlnet =
dlnetwork with properties:
Layers: [12×1 nnet.cnn.layer.Layer]
Connections: [11×2 table]
Learnables: [14×3 table]
State: [6×3 table]
InputNames: {'input'}
OutputNames: {'softmax'}
Define Model Gradients Function
Create the function modelGradients, listed at the end of the example, that takes a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the gradients of the loss with respect to the learnable parameters in dlnet and the corresponding loss.
Specify Training Options
Train with a mini-batch size of 128 for 5 epochs.
Specify the options for SGDM optimization. Specify an initial learning rate of 0.01, a decay of 0.01, and a momentum of 0.9.
Visualize the training progress in a plot.
Train on a GPU if one is available. Using a GPU requires Parallel Computing Toolbox™ and a CUDA® enabled NVIDIA® GPU with compute capability 3.0 or higher.
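A sketch of these options as workspace variables (the variable names are assumptions used by the later code sketches):

```matlab
numEpochs = 5;
miniBatchSize = 128;

% SGDM hyperparameters for the time-based decay schedule.
initialLearnRate = 0.01;
decay = 0.01;
momentum = 0.9;

% Plot progress, and train on a GPU if one is available.
plots = "training-progress";
executionEnvironment = "auto";
```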
Train Model
Train the model using a custom training loop.
For each epoch, shuffle the data and loop over mini-batches of data. At the end of each epoch, display the training progress.
For each mini-batch:
Convert the labels to dummy variables.
Convert the data to dlarray objects with underlying type single and specify the dimension labels 'SSCB' (spatial, spatial, channel, batch).
For GPU training, convert the data to gpuArray objects.
Evaluate the model gradients, state, and loss using dlfeval and the modelGradients function, and update the network state.
Determine the learning rate for the time-based decay learning rate schedule.
Update the network parameters using the sgdmupdate function.
Initialize the training progress plot.
Initialize the velocity parameter for the SGDM solver.
Train the network.
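The steps above can be sketched as the following loop. It assumes the workspace variables described in the earlier sections (XTrain, YTrain, classes, numClasses, dlnet, and the hyperparameters from Specify Training Options) and the modelGradients function listed at the end of the example:

```matlab
% Initialize the training progress plot.
if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

% Initialize the velocity parameter for the SGDM solver.
velocity = [];

numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);

iteration = 0;
start = tic;

for epoch = 1:numEpochs
    % Shuffle the data.
    idx = randperm(numObservations);
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);

    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Read a mini-batch of data and convert the labels to
        % dummy (one-hot) variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);

        Y = zeros(numClasses,miniBatchSize,'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes{c}) = 1;
        end

        % Convert the mini-batch to a dlarray with format 'SSCB'.
        dlX = dlarray(single(X),'SSCB');

        % If training on a GPU, convert the data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || ...
                executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end

        % Evaluate the model gradients, state, and loss, and
        % update the network state.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);
        dlnet.State = state;

        % Time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);

        % Update the network parameters with the SGDM solver.
        [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity, ...
            learnRate,momentum);

        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration, ...
                double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end
```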
Test Model
Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.
Convert the data to a dlarray object with dimension format 'SSCB'. For GPU prediction, also convert the data to gpuArray.
Classify the images using the modelPredictions function, listed at the end of the example, and find the classes with the highest scores.
Evaluate the classification accuracy.
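A sketch of the test step, assuming the digit test set from digitTest4DArrayData and the variables from the earlier sections (dlnet, classes, miniBatchSize, executionEnvironment), plus the modelPredictions function listed at the end of the example:

```matlab
% Load the test images and labels.
[XTest,YTest] = digitTest4DArrayData;

% Convert to a formatted dlarray; move to the GPU if available.
dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || ...
        executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end

% Predict scores in mini-batches and take the highest-scoring class.
dlYPred = modelPredictions(dlnet,dlXTest,miniBatchSize);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);

% Compare the predictions with the true labels.
accuracy = mean(YPred(:) == YTest)
```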
Model Gradients Function
The modelGradients function takes a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the gradients of the loss with respect to the learnable parameters in dlnet, the network state, and the loss. To compute the gradients automatically, use the dlgradient function.
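A minimal sketch of this function, using forward to evaluate the network with state output, crossentropy for the loss, and dlgradient for automatic differentiation:

```matlab
function [gradients,state,loss] = modelGradients(dlnet,dlX,Y)

% Forward pass; also returns the updated network state
% (for example, batch normalization statistics).
[dlYPred,state] = forward(dlnet,dlX);

% Cross-entropy loss against the one-hot labels Y.
loss = crossentropy(dlYPred,Y);

% Gradients of the loss with respect to the learnable parameters.
gradients = dlgradient(loss,dlnet.Learnables);

end
```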
Model Predictions Function
The modelPredictions function takes a dlnetwork object dlnet, an array of input data dlX, and a mini-batch size, and outputs the model predictions by iterating over mini-batches of the specified size.
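A minimal sketch of this function; how the number of classes is obtained is an assumption (here read from the fully connected layer's OutputSize):

```matlab
function dlYPred = modelPredictions(dlnet,dlX,miniBatchSize)

numObservations = size(dlX,4);
numIterations = ceil(numObservations/miniBatchSize);

% Assumption: the fully connected layer determines the class count.
fcIdx = arrayfun(@(l) isa(l,'nnet.cnn.layer.FullyConnectedLayer'), ...
    dlnet.Layers);
numClasses = dlnet.Layers(find(fcIdx,1)).OutputSize;

dlYPred = zeros(numClasses,numObservations,'like',dlX);

% Predict scores one mini-batch at a time.
for i = 1:numIterations
    idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    dlYPred(:,idx) = predict(dlnet,dlX(:,:,:,idx));
end

end
```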