When training a deep learning model with a custom training loop, the software minimizes the loss with respect to the learnable parameters. To minimize the loss, the software uses the gradients of the loss with respect to the learnable parameters. To calculate these gradients using automatic differentiation, you must define a model gradients function.
For an example showing how to train a deep learning model with a dlnetwork object, see Train Network Using Custom Training Loop. For an example showing how to train a deep learning model defined as a function, see Train Network Using Model Function.
dlnetwork Object

If you have a deep learning model defined as a dlnetwork object, then create a model gradients function that takes the dlnetwork object as input.
For models specified as a dlnetwork object, create a function of the form gradients = modelGradients(dlnet,dlX,T), where dlnet is the network, dlX contains the input predictors, T contains the targets, and gradients contains the returned gradients. Optionally, you can pass extra arguments to the gradients function (for example, if the loss function requires extra information), or return extra arguments (for example, metrics for plotting the training progress).
For example, this function returns the gradients and the cross-entropy loss for the specified dlnetwork object dlnet, input data dlX, and targets T.
function [gradients, loss] = modelGradients(dlnet, dlX, T)

% Forward data through the dlnetwork object.
dlY = forward(dlnet,dlX);

% Compute loss.
loss = crossentropy(dlY,T);

% Compute gradients with respect to the learnable parameters.
gradients = dlgradient(loss,dlnet.Learnables);

end
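The model gradients function can also take extra inputs and return extra outputs, as described above. As a minimal sketch (not part of the example above), assuming a hypothetical classWeights input for a weighted cross-entropy loss, a function that also returns the network state might look like:

function [gradients, state, loss] = modelGradients(dlnet, dlX, T, classWeights)

% Forward data through the dlnetwork object, also returning the
% updated network state (for example, batch normalization statistics).
[dlY,state] = forward(dlnet,dlX);

% Compute a weighted cross-entropy loss using the extra input.
loss = crossentropy(dlY,T,classWeights);

% Compute gradients with respect to the learnable parameters.
gradients = dlgradient(loss,dlnet.Learnables);

end

You would then pass the extra argument along with the others, for example dlfeval(@modelGradients,dlnet,dlX,T,classWeights).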
Model Function

If you have a deep learning model defined as a function of the form dlY = model(parameters,dlX), then create a function of the form gradients = modelGradients(parameters,dlX,T), where parameters is a struct containing the learnable parameters, dlX contains the input predictors, T contains the targets, and gradients contains the returned gradients. Optionally, you can pass extra arguments to the gradients function (for example, if the loss function requires extra information), or return extra arguments (for example, metrics for plotting the training progress). For models defined as a function, you do not need to pass a network as an input argument.
For example, this function returns the gradients and the cross-entropy loss for the deep learning model function model with the specified learnable parameters parameters, input data dlX, and targets T.
function [gradients, loss] = modelGradients(parameters, dlX, T)

% Forward data through the model function.
dlY = model(parameters,dlX);

% Compute loss.
loss = crossentropy(dlY,T);

% Compute gradients with respect to the learnable parameters.
gradients = dlgradient(loss,parameters);

end
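For reference, the model function itself might look like the following minimal sketch, which assumes that parameters is a struct with hypothetical fields fc1.Weights and fc1.Bias and that dlX is a formatted dlarray:

function dlY = model(parameters, dlX)

% Apply a fully connected operation using the learnable parameters.
dlY = fullyconnect(dlX,parameters.fc1.Weights,parameters.fc1.Bias);

% Apply a softmax activation for classification.
dlY = softmax(dlY);

end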
To evaluate the model gradients using automatic differentiation, use the dlfeval function, which evaluates a function with automatic differentiation enabled. For the first input of dlfeval, pass the model gradients function specified as a function handle, and for the following inputs, pass the required variables for the model gradients function. For the outputs of the dlfeval function, specify the same outputs as the model gradients function.
For example, to evaluate the model gradients function modelGradients with a dlnetwork object dlnet, input data dlX, and targets T, and return the model gradients and loss, use the command:
[gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);
Similarly, to evaluate the model gradients function modelGradients using a model function with learnable parameters specified by the struct parameters, input data dlX, and targets T, and return the model gradients and loss, use the command:
[gradients, loss] = dlfeval(@modelGradients,parameters,dlX,T);
To update the learnable parameters using the gradients, you can use the following functions:
| Function | Description |
| --- | --- |
| adamupdate | Update parameters using adaptive moment estimation (Adam) |
| rmspropupdate | Update parameters using root mean square propagation (RMSProp) |
| sgdmupdate | Update parameters using stochastic gradient descent with momentum (SGDM) |
| dlupdate | Update parameters using a custom function |
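For instance, the dlupdate function applies a custom update function to each parameter-gradient pair. As a minimal sketch of vanilla stochastic gradient descent, assuming a hypothetical learning rate variable learnRate:

learnRate = 0.01;

% Apply the update p := p - learnRate*g to each learnable parameter.
parameters = dlupdate(@(p,g) p - learnRate*g,parameters,gradients);

The same call also works with a dlnetwork object as the first argument instead of a struct of parameters.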
For example, to update the learnable parameters of a dlnetwork object dlnet using the adamupdate function, use the command:
[dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet,gradients, ...
    trailingAvg,trailingAvgSq,iteration);
Here, gradients is the output of the model gradients function, and trailingAvg, trailingAvgSq, and iteration are the additional inputs required by the adamupdate function.

Similarly, to update the learnable parameters of a model function, specified by the struct parameters, using the adamupdate function, use the command:
[parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
    trailingAvg,trailingAvgSq,iteration);
Again, gradients is the output of the model gradients function, and trailingAvg, trailingAvgSq, and iteration are the additional inputs required by the adamupdate function.

When training a deep learning model using a custom training loop, evaluate the model gradients and update the learnable parameters for each mini-batch.
This code snippet shows an example of using the dlfeval and adamupdate functions in a custom training loop.
iteration = 0;

% Loop over epochs.
for epoch = 1:numEpochs

    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Prepare mini-batch.
        % ...

        % Evaluate model gradients.
        [gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);

        % Update learnable parameters.
        [dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet,gradients, ...
            trailingAvg,trailingAvgSq,iteration);
    end
end
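This snippet assumes that the Adam state is initialized before the loop. For the first call to adamupdate, you can initialize the trailing averages as empty arrays, for example:

% Initialize the average gradient and average squared gradient
% before the training loop.
trailingAvg = [];
trailingAvgSq = [];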
For an example showing how to train a deep learning model with a dlnetwork object, see Train Network Using Custom Training Loop. For an example showing how to train a deep learning model defined as a function, see Train Network Using Model Function.
If there is an issue in the implementation of the model gradients function, the call to dlfeval may throw an error. Sometimes, when using the dlfeval function, it is not clear which line of code is throwing the error. To help locate the error, you can try the following:
Try calling the model gradients function directly (that is, without using the dlfeval function) with generated inputs of the expected sizes. If any of the lines of code throw an error, then it should be clear which one did. Note that when not using the dlfeval function, any calls to the dlgradient function are expected to throw an error.
% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(X);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

[gradients, loss] = modelGradients(dlnet,dlX,T);
Run the code inside the model gradients function manually with generated inputs of the expected sizes and inspect the output and any thrown error messages.
For example, to check the model gradients function defined by:
function [gradients, loss] = modelGradients(dlnet, dlX, T)

% Forward data through the dlnetwork object.
dlY = forward(dlnet,dlX);

% Compute loss.
loss = crossentropy(dlY,T);

% Compute gradients with respect to the learnable parameters.
gradients = dlgradient(loss,dlnet.Learnables);

end
run the code:
% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(X);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

% Check forward pass.
dlY = forward(dlnet,dlX);

% Check loss calculation.
loss = crossentropy(dlY,T)