Update parameters using stochastic gradient descent with momentum (SGDM)
Update the network learnable parameters in a custom training loop using the stochastic gradient descent with momentum (SGDM) algorithm.
Note
This function applies the SGDM optimization algorithm to update network parameters in custom training loops that use networks defined as dlnetwork objects or model functions. If you want to train a network defined as a Layer array or as a LayerGraph, use the following functions:
1. Create a TrainingOptionsSGDM object using the trainingOptions function.
2. Use the TrainingOptionsSGDM object with the trainNetwork function.
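For reference, that built-in workflow looks like the following minimal sketch, assuming a layer array layers and training data XTrain and YTrain are defined elsewhere:
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'Momentum',0.9);
net = trainNetwork(XTrain,YTrain,layers,options);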
Update Learnable Parameters Using sgdmupdate
Perform a single SGDM update step with a global learning rate of 0.05 and momentum of 0.95.
Create the parameters and parameter gradients as numeric arrays.
params = rand(3,3,4);
grad = ones(3,3,4);
Initialize the parameter velocities for the first iteration.
vel = [];
Specify custom values for the global learning rate and momentum.
learnRate = 0.05;
momentum = 0.95;
Update the learnable parameters using sgdmupdate.
[params,vel] = sgdmupdate(params,grad,vel,learnRate,momentum);
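On subsequent iterations, pass the returned velocity back in so the momentum term accumulates. A minimal sketch, with newGrad as hypothetical gradients from a later iteration:
% Hypothetical gradients from a later iteration.
newGrad = 0.5*ones(3,3,4);

% Reuse the velocity returned by the previous call.
[params,vel] = sgdmupdate(params,newGrad,vel,learnRate,momentum);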
Train Network Using sgdmupdate
Use sgdmupdate to train a network using the SGDM algorithm.
Load Training Data
Load the digits training data.
[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);
Define Network
Define the network architecture and specify the average image value using the 'Mean' option in the image input layer.
layers = [
    imageInputLayer([28 28 1],'Name','input','Mean',mean(XTrain,4))
    convolution2dLayer(5,20,'Name','conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);
Create a dlnetwork object from the layer graph.
dlnet = dlnetwork(lgraph);
Define Model Gradients Function
Create the helper function modelGradients, listed at the end of the example. The function takes a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the loss and the gradients of the loss with respect to the learnable parameters in dlnet.
Specify Training Options
Specify the options to use during training.
miniBatchSize = 128;
numEpochs = 20;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);
Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a CUDA® enabled NVIDIA® GPU with compute capability 3.0 or higher.
executionEnvironment = "auto";
Visualize the training progress in a plot.
plots = "training-progress";
Train Network
Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. Update the network parameters using the sgdmupdate function. At the end of each epoch, display the training progress.
Initialize the training progress plot.
if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end
Initialize the velocity parameter.
vel = [];
Train the network.
iteration = 0;
start = tic;

for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);

    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);

        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end

        % Convert mini-batch of data to a dlarray.
        dlX = dlarray(single(X),'SSCB');

        % If training on a GPU, then convert data to a gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end

        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients helper function.
        [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);

        % Update the network parameters using the SGDM optimizer.
        [dlnet,vel] = sgdmupdate(dlnet,gradients,vel);

        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end
Test the Network
Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.
[XTest, YTest] = digitTest4DArrayData;
Convert the data to a dlarray with the dimension format 'SSCB'. For GPU prediction, also convert the data to a gpuArray.
dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end
To classify images using a dlnetwork object, use the predict function and find the classes with the highest scores.
dlYPred = predict(dlnet,dlXTest);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);
Evaluate the classification accuracy.
accuracy = mean(YPred==YTest)
accuracy = 0.9916
Model Gradients Function
The modelGradients helper function takes a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the loss and the gradients of the loss with respect to the learnable parameters in dlnet. To compute the gradients automatically, use the dlgradient function.
function [gradients,loss] = modelGradients(dlnet,dlX,Y)
    dlYPred = forward(dlnet,dlX);
    loss = crossentropy(dlYPred,Y);
    gradients = dlgradient(loss,dlnet.Learnables);
end
Input Arguments
dlnet — Network
dlnetwork object
Network, specified as a dlnetwork object.
The function updates the dlnet.Learnables property of the dlnetwork object. dlnet.Learnables is a table with three variables:
Layer — Layer name, specified as a string scalar.
Parameter — Parameter name, specified as a string scalar.
Value — Value of parameter, specified as a cell array containing a dlarray.
The input argument grad must be a table of the same form as dlnet.Learnables.
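For example, for the dlnetwork object created in the training example above, you can display this table directly; a minimal sketch:
% Display the learnables table: one row per learnable parameter,
% with Layer, Parameter, and Value variables.
dlnet.Learnables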
params — Network learnable parameters
dlarray | numeric array | cell array | structure | table
Network learnable parameters, specified as a dlarray, a numeric array, a cell array, a structure, or a table.
If you specify params as a table, it must contain the following three variables:
Layer — Layer name, specified as a string scalar.
Parameter — Parameter name, specified as a string scalar.
Value — Value of parameter, specified as a cell array containing a dlarray.
You can specify params as a container of learnable parameters for your network using a cell array, structure, or table, or nested cell arrays or structures. The learnable parameters inside the cell array, structure, or table must be dlarray or numeric values of data type double or single.
The input argument grad must be provided with exactly the same data type, ordering, and fields (for structures) or variables (for tables) as params.
Data Types: single | double | struct | table | cell
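As an illustration, a model-function workflow might keep its learnables in a structure. A minimal sketch with hypothetical field names and placeholder gradients:
% Hypothetical learnable parameters for a small model function.
params.fc1.Weights = dlarray(0.01*randn(10,784,'single'));
params.fc1.Bias = dlarray(zeros(10,1,'single'));

% grad must mirror params exactly; in a real workflow it comes from
% dlfeval and dlgradient. Ones are used here only for illustration.
grad.fc1.Weights = dlarray(ones(10,784,'single'));
grad.fc1.Bias = dlarray(ones(10,1,'single'));

% Empty velocity on the first update.
[params,vel] = sgdmupdate(params,grad,[]);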
grad — Gradients of the loss
dlarray | numeric array | cell array | structure | table
Gradients of the loss, specified as a dlarray, a numeric array, a cell array, a structure, or a table.
The exact form of grad depends on the input network or learnable parameters. The following table shows the required format for grad for possible inputs to sgdmupdate.
| Input | Learnable Parameters | Gradients |
| --- | --- | --- |
| dlnet | Table dlnet.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as dlnet.Learnables. grad must have a Value variable consisting of cell arrays that contain the gradient of each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params |
| params | Numeric array | Numeric array with the same data type and ordering as params |
| params | Cell array | Cell array with the same data types, structure, and ordering as params |
| params | Structure | Structure with the same data types, fields, and ordering as params |
| params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. grad must have a Value variable consisting of cell arrays that contain the gradient of each learnable parameter. |
You can obtain grad from a call to dlfeval that evaluates a function that contains a call to dlgradient. For more information, see Use Automatic Differentiation In Deep Learning Toolbox.
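For example, for a single dlarray parameter, a minimal sketch with a hypothetical exampleGradients helper:
params = dlarray(rand(3,3));
[grad,loss] = dlfeval(@exampleGradients,params);
[params,vel] = sgdmupdate(params,grad,[]);

function [grad,loss] = exampleGradients(params)
    % Toy loss: sum of squared parameter values.
    loss = sum(params.^2,'all');
    % dlgradient must be called inside a function evaluated by dlfeval.
    grad = dlgradient(loss,params);
end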
vel — Parameter velocities
[] | dlarray | numeric array | cell array | structure | table
Parameter velocities, specified as an empty array, a dlarray, a numeric array, a cell array, a structure, or a table.
The exact form of vel depends on the input network or learnable parameters. The following table shows the required format for vel for possible inputs to sgdmupdate.
| Input | Learnable Parameters | Velocities |
| --- | --- | --- |
| dlnet | Table dlnet.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as dlnet.Learnables. vel must have a Value variable consisting of cell arrays that contain the velocity of each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params |
| params | Numeric array | Numeric array with the same data type and ordering as params |
| params | Cell array | Cell array with the same data types, structure, and ordering as params |
| params | Structure | Structure with the same data types, fields, and ordering as params |
| params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. vel must have a Value variable consisting of cell arrays that contain the velocity of each learnable parameter. |
If you specify vel as an empty array, the function assumes no previous velocities and runs in the same way as for the first update in a series of iterations. To update the learnable parameters iteratively, use the vel output of a previous call to sgdmupdate as the vel input.
learnRate — Global learning rate
0.01 (default) | positive scalar
Learning rate, specified as a positive scalar. The default value of learnRate is 0.01.
If you specify the network parameters as a dlnetwork object, the learning rate for each parameter is the global learning rate multiplied by the corresponding learning rate factor property defined in the network layers.
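Because you pass learnRate on every call, you can also vary it between iterations. A minimal sketch of a hypothetical time-based decay schedule, assuming iteration, dlnet, grad, and vel exist inside a training loop:
initialLearnRate = 0.05;
decay = 0.01;

% Shrink the learning rate as training progresses.
learnRate = initialLearnRate/(1 + decay*iteration);
[dlnet,vel] = sgdmupdate(dlnet,grad,vel,learnRate);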
momentum — Momentum
0.9 (default) | positive scalar between 0 and 1
Momentum, specified as a positive scalar between 0 and 1. The default value of momentum is 0.9.
Output Arguments
dlnet — Updated network
dlnetwork object
Network, returned as a dlnetwork object.
The function updates the dlnet.Learnables property of the dlnetwork object.
params — Updated network learnable parameters
dlarray | numeric array | cell array | structure | table
Updated network learnable parameters, returned as a dlarray, a numeric array, a cell array, a structure, or a table with a Value variable containing the updated learnable parameters of the network.
vel — Updated parameter velocities
dlarray | numeric array | cell array | structure | table
Updated parameter velocities, returned as a dlarray, a numeric array, a cell array, a structure, or a table.
Algorithms
The function uses the stochastic gradient descent with momentum algorithm to update the learnable parameters. For more information, see the definition of the stochastic gradient descent with momentum algorithm under Stochastic Gradient Descent on the trainingOptions reference page.
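For reference, a common textbook formulation of this update (assumed here; see the trainingOptions page for the exact definition the toolbox uses) is

$$\theta_{\ell+1} = \theta_{\ell} - \alpha \nabla E(\theta_{\ell}) + \gamma \left( \theta_{\ell} - \theta_{\ell-1} \right)$$

where $\theta_{\ell}$ denotes the parameters at iteration $\ell$, $\alpha$ is the learning rate (learnRate), $E(\theta)$ is the loss, and $\gamma$ is the momentum. The momentum term $\gamma(\theta_{\ell} - \theta_{\ell-1})$ corresponds to the velocity that the vel input and output carry between calls.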
GPU Arrays
Usage notes and limitations:
When at least one of the following input arguments is a gpuArray or a dlarray with underlying data of type gpuArray, this function runs on the GPU.
grad
params
For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
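For example, creating the parameters and gradients as gpuArray-backed dlarray objects makes the update run on the GPU; a minimal sketch, assuming a supported GPU is available:
% Create gpuArray-backed inputs.
params = dlarray(gpuArray(rand(3,3,'single')));
grad = dlarray(gpuArray(ones(3,3,'single')));

% The update runs on the GPU because the inputs are gpuArray-backed.
[params,vel] = sgdmupdate(params,grad,[]);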
See Also
adamupdate | dlarray | dlfeval | dlgradient | dlnetwork | dlupdate | forward | rmspropupdate