dlupdate

Update parameters using custom function

L1 Regularization with dlupdate

Perform L1 regularization on a structure of parameter gradients.
Create the sample input data.
dlX = dlarray(rand(100,100,3),'SSC');
Initialize the learnable parameters for the convolution operation.
params.Weights = dlarray(rand(10,10,3,50));
params.Bias = dlarray(rand(50,1));
Calculate the gradients for the convolution operation using the helper function convGradients, defined at the end of this example.
grads = dlfeval(@convGradients,dlX,params);
Define the regularization factor.
L1Factor = 0.001;
Create an anonymous function that regularizes the gradients. By using an anonymous function to pass a scalar constant to the function, you can avoid having to expand the constant value to the same size and structure as the parameter variable.
L1Regularizer = @(grad,param) grad + L1Factor.*sign(param);
Use dlupdate to apply the regularization function to each of the gradients.
grads = dlupdate(L1Regularizer,grads,params);
The gradients in grads are now regularized according to the function L1Regularizer.
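If you then want to apply the regularized gradients to the parameters, you can pass them to one of the standard solver update functions. The following is a minimal sketch, not part of the original example, that uses sgdmupdate with its default learning rate and momentum; the velocity vel must be initialized to the empty array for the first call.

% Initialize the solver state for the first call.
vel = [];
% Apply an SGDM step using the regularized gradients.
[params,vel] = sgdmupdate(params,grads,vel);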
convGradients Function

The convGradients helper function takes the learnable parameters of the convolution operation and a mini-batch of input data dlX, and returns the gradients with respect to the learnable parameters.
function grads = convGradients(dlX,params)
dlY = dlconv(dlX,params.Weights,params.Bias);
dlY = sum(dlY,'all');
grads = dlgradient(dlY,params);
end
Use dlupdate to Train Network Using Custom Update Function

Use dlupdate to train a network using a custom update function that implements the stochastic gradient descent algorithm (without momentum).
Load Training Data
Load the digits training data.
[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);
Define the Network
Define the network architecture and specify the average image value using the 'Mean' option in the image input layer.
layers = [
    imageInputLayer([28 28 1],'Name','input','Mean',mean(XTrain,4))
    convolution2dLayer(5,20,'Name','conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')];
lgraph = layerGraph(layers);
Create a dlnetwork object from the layer graph.
dlnet = dlnetwork(lgraph);
Define Model Gradients Function
Create the helper function modelGradients, listed at the end of this example. The function takes a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the loss and the gradients of the loss with respect to the learnable parameters in dlnet.
Define Stochastic Gradient Descent Function
Create the helper function sgdFunction, listed at the end of this example. The function takes a learnable parameter param, the gradient of the loss with respect to that parameter paramGradient, and the learning rate learnRate, and returns the updated parameter using the stochastic gradient descent algorithm, expressed as

$$\theta_{\ell+1} = \theta_{\ell} - \alpha \nabla E(\theta_{\ell})$$

where $\ell$ is the iteration number, $\alpha > 0$ is the learning rate, $\theta$ is the parameter vector, and $E(\theta)$ is the loss function.
Specify Training Options
Specify the options to use during training.
miniBatchSize = 128;
numEpochs = 20;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);
Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a CUDA® enabled NVIDIA® GPU with compute capability 3.0 or higher.
executionEnvironment = "auto";
Specify the learning rate.
learnRate = 0.01;
Initialize the training progress plot.
plots = "training-progress";
if plots == "training-progress"
    iteration = 1;
    figure
    lineLossTrain = animatedline;
    xlabel("Total Iterations")
    ylabel("Loss")
end
Train Network
Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. Update the network parameters by calling dlupdate with the function sgdFunction defined at the end of this example. At the end of each epoch, display the training progress.
for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);

    for i = 1:numIterationsPerEpoch
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);

        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end

        % Convert mini-batch of data to dlarray.
        dlX = dlarray(single(X),'SSCB');

        % If training on a GPU, then convert data to a gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end

        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients helper function.
        [grad,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);

        % Update the network parameters using the SGD algorithm defined in
        % the sgdFunction helper function. Use an anonymous function to
        % pass the learning rate to the helper function.
        dlnet = dlupdate(@(p,g) sgdFunction(p,g,learnRate),dlnet,grad);

        % Display the training progress.
        if plots == "training-progress"
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Loss During Training: Epoch - " + epoch + "; Iteration - " + i)
            drawnow
            iteration = iteration + 1;
        end
    end
end
Test Network
Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.
[XTest, YTest] = digitTest4DArrayData;
Convert the data to a dlarray with the dimension format 'SSCB'. For GPU prediction, also convert the data to a gpuArray.
dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end
To classify images using a dlnetwork object, use the predict function and find the classes with the highest scores.
dlYPred = predict(dlnet,dlXTest);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);
Evaluate the classification accuracy.
accuracy = mean(YPred==YTest)
accuracy = 0.7282
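To see which digits the network misclassifies, you can optionally plot a confusion matrix. This step is an addition, not part of the original example.

% Compare true and predicted labels per class.
confusionchart(YTest,YPred)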
Model Gradients Function
The helper function modelGradients takes a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the loss and the gradients of the loss with respect to the learnable parameters in dlnet. To compute the gradients automatically, use the dlgradient function.
function [gradients,loss] = modelGradients(dlnet,dlX,Y)
dlYPred = forward(dlnet,dlX);
dlYPred = softmax(dlYPred);

loss = crossentropy(dlYPred,Y);
gradients = dlgradient(loss,dlnet.Learnables);
end
Stochastic Gradient Descent Function
The helper function sgdFunction takes a learnable parameter param, the gradient of the loss with respect to that parameter paramGradient, and the learning rate learnRate, and returns the updated parameter using the stochastic gradient descent algorithm, expressed as

$$\theta_{\ell+1} = \theta_{\ell} - \alpha \nabla E(\theta_{\ell})$$

where $\ell$ is the iteration number, $\alpha > 0$ is the learning rate, $\theta$ is the parameter vector, and $E(\theta)$ is the loss function.
function param = sgdFunction(param,paramGradient,learnRate)
param = param - learnRate.*paramGradient;
end
fun — Function to apply
function handle

Function to apply to the learnable parameters, specified as a function handle.

dlupdate evaluates fun with each network learnable parameter as an input. fun is evaluated as many times as there are arrays of learnable parameters in dlnet or params.
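For example, the following minimal sketch (the variable names are illustrative) uses a one-input fun to clip every gradient element to the range [-1, 1]:

grads.Weights = dlarray(randn(5,5));
grads.Bias = dlarray(randn(5,1));
% Apply the clipping function to each array in the structure.
grads = dlupdate(@(g) max(min(g,1),-1),grads);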
dlnet — Network
dlnetwork object

Network, specified as a dlnetwork object.

The function updates the dlnet.Learnables property of the dlnetwork object. dlnet.Learnables is a table with three variables:

Layer — Layer name, specified as a string scalar.
Parameter — Parameter name, specified as a string scalar.
Value — Value of parameter, specified as a cell array containing a dlarray.
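For example, assuming dlnet is the dlnetwork object from the training example above, you can inspect this table and update the network directly. The sketch below applies a simple multiplicative weight decay to every learnable parameter:

% View the first rows of the learnable parameters table.
head(dlnet.Learnables)
% Scale every learnable parameter by a constant factor.
dlnet = dlupdate(@(p) 0.99.*p,dlnet);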
params — Network learnable parameters
dlarray | numeric array | cell array | structure | table

Network learnable parameters, specified as a dlarray, a numeric array, a cell array, a structure, or a table.
If you specify params as a table, it must contain the following three variables:

Layer — Layer name, specified as a string scalar.
Parameter — Parameter name, specified as a string scalar.
Value — Value of parameter, specified as a cell array containing a dlarray.
You can specify params as a container of learnable parameters for your network using a cell array, structure, or table, or using nested cell arrays or structures. The learnable parameters inside the cell array, structure, or table must be dlarray or numeric values of data type double or single.
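For example, a minimal sketch using a nested structure of parameters (the field names and sizes are illustrative):

net.conv.Weights = dlarray(rand(3,3,1,8));
net.conv.Bias = dlarray(zeros(8,1));
net.fc.Weights = dlarray(rand(10,128));
net.fc.Bias = dlarray(zeros(10,1));
% Halve every parameter in the nested structure.
net = dlupdate(@(p) 0.5.*p,net);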
The input argument grad must be provided with exactly the same data type, ordering, and fields (for structures) or variables (for tables) as params.
Data Types: single | double | struct | table | cell
A1,...,An — Additional input arguments
dlarray | numeric array | cell array | structure | table

Additional input arguments to fun, specified as dlarray objects, numeric arrays, cell arrays, structures, or tables with a Value variable.
The exact form of A1,...,An depends on the input network or learnable parameters. The following table shows the required format for A1,...,An for possible inputs to dlupdate.
| Input | Learnable Parameters | A1,...,An |
| --- | --- | --- |
| dlnet | Table dlnet.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as dlnet.Learnables. A1,...,An must have a Value variable consisting of cell arrays that contain the additional input arguments for the function fun to apply to each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params. |
| | Numeric array | Numeric array with the same data type and ordering as params. |
| | Cell array | Cell array with the same data types, structure, and ordering as params. |
| | Structure | Structure with the same data types, fields, and ordering as params. |
| | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. A1,...,An must have a Value variable consisting of cell arrays that contain the additional input argument for the function fun to apply to each learnable parameter. |
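For example, the sketch below passes per-parameter step sizes as an additional input argument. Following the table above, stepSizes mirrors the structure and data types of params; all names and values here are illustrative.

params.Weights = dlarray(rand(10,10));
params.Bias = dlarray(rand(10,1));
grads = dlupdate(@(p) 0.1.*p,params); % placeholder gradients with the same structure
stepSizes.Weights = dlarray(0.01*ones(10,10));
stepSizes.Bias = dlarray(0.1*ones(10,1));
% Each parameter is updated with its own step size array.
params = dlupdate(@(p,g,s) p - s.*g,params,grads,stepSizes);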
dlnet — Updated network
dlnetwork object

Network, returned as a dlnetwork object.

The function updates the dlnet.Learnables property of the dlnetwork object.
params — Updated network learnable parameters
dlarray | numeric array | cell array | structure | table

Updated network learnable parameters, returned as a dlarray, a numeric array, a cell array, a structure, or a table with a Value variable containing the updated learnable parameters of the network.
X1,...,Xm — Additional output arguments
dlarray | numeric array | cell array | structure | table

Additional output arguments from the function fun, where fun is a function handle to a function that returns multiple outputs, returned as dlarray objects, numeric arrays, cell arrays, structures, or tables with a Value variable.
The exact form of X1,...,Xm depends on the input network or learnable parameters. The following table shows the returned format of X1,...,Xm for possible inputs to dlupdate.
| Input | Learnable Parameters | X1,...,Xm |
| --- | --- | --- |
| dlnet | Table dlnet.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as dlnet.Learnables. X1,...,Xm has a Value variable consisting of cell arrays that contain the additional output arguments of the function fun applied to each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params. |
| | Numeric array | Numeric array with the same data type and ordering as params. |
| | Cell array | Cell array with the same data types, structure, and ordering as params. |
| | Structure | Structure with the same data types, fields, and ordering as params. |
| | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. X1,...,Xm has a Value variable consisting of cell arrays that contain the additional output argument of the function fun applied to each learnable parameter. |
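For example, here is a sketch of a two-output update function that also returns the norm of the step applied to each parameter (an illustrative helper, not part of this page's examples):

function [param,stepNorm] = stepWithNorm(param,grad)
% Apply a fixed-rate gradient step and report its Euclidean norm.
step = 0.01.*grad;
param = param - step;
stepNorm = sqrt(sum(step.^2,'all'));
end

Calling [params,stepNorms] = dlupdate(@stepWithNorm,params,grads) then returns stepNorms in the same container format as params.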
Usage notes and limitations:

When at least one of the following input arguments is a gpuArray or a dlarray with underlying data of type gpuArray, this function runs on the GPU:

params
A1,...,An

For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
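For example, a minimal sketch that moves a structure of parameters and gradients to the GPU before the update (this assumes a supported GPU is available; canUseGPU requires Parallel Computing Toolbox):

if canUseGPU
    % Convert the underlying data of each dlarray to gpuArray.
    params = dlupdate(@gpuArray,params);
    grads = dlupdate(@gpuArray,grads);
end
params = dlupdate(@(p,g) p - 0.01.*g,params,grads);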
See Also

adamupdate | dlarray | dlfeval | dlgradient | dlnetwork | rmspropupdate | sgdmupdate