Instead of using the model function for prediction, you can assemble the network into a DAGNetwork ready for prediction using the functionToLayerGraph and assembleNetwork functions. This lets you use the predict function.
Load the model parameters from the MAT file digitsMIMO.mat. The MAT file contains the model parameters in a struct named parameters, the model state in a struct named state, and the class names in classNames.
s = load("digitsMIMO.mat");
parameters = s.parameters;
state = s.state;
classNames = s.classNames;
The model function model, listed at the end of the example, defines the model given the model parameters and state. Define an anonymous function with a fixed set of model parameters and model state, and with the doTraining option set to false.
doTraining = false;
fun = @(dlX) model(dlX,parameters,doTraining,state);
Convert the model function to a layer graph using the functionToLayerGraph function. Create a variable dlX that contains a mini-batch of data with the expected format.
X = rand(28,28,1,128,'single');
dlX = dlarray(X,'SSCB');
lgraph = functionToLayerGraph(fun,dlX);
The layer graph output by the functionToLayerGraph function does not include input and output layers. Add an input layer, a classification layer, and a regression layer to the layer graph using the addLayers and connectLayers functions.
layers = imageInputLayer([28 28 1],'Name','input','Normalization','none');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'input','conv_1');

layers = classificationLayer('Classes',classNames,'Name','coutput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'sm_1','coutput');

layers = regressionLayer('Name','routput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'fc_1','routput');
View a plot of the network.
figure
plot(lgraph)
Assemble the network using the assembleNetwork function.
net = assembleNetwork(lgraph)
net = 
  DAGNetwork with properties:

         Layers: [18×1 nnet.cnn.layer.Layer]
    Connections: [18×2 table]
     InputNames: {'input'}
    OutputNames: {'coutput'  'routput'}
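Optionally, you can confirm that the assembled network has the expected structure, for example by inspecting the Layers property of the DAGNetwork object.

net.Layers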
Load the test data.
[XTest,Y1Test,Y2Test] = digitTest4DArrayData;
To make predictions using the assembled network, use the predict function. To return categorical labels for the classification output, set the 'ReturnCategorical' option to true.
[Y1Pred,Y2Pred] = predict(net,XTest,'ReturnCategorical',true);
Evaluate the classification accuracy.
accuracy = mean(Y1Pred==Y1Test)
accuracy = 0.9644
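To see which classes the network confuses, you can optionally plot a confusion chart of the true and predicted labels using the confusionchart function.

figure
confusionchart(Y1Test,Y1Pred)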
Evaluate the regression accuracy.
angleRMSE = sqrt(mean((Y2Pred - Y2Test).^2))
angleRMSE = single
5.8081
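To check whether the angle predictions are biased in one direction, you can optionally plot a histogram of the prediction errors using the histogram function.

figure
histogram(Y2Pred - Y2Test)
xlabel("Angle Error (degrees)")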
View some of the images with their predictions. Display the predicted angles in red and the correct labels in green.
idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on
    sz = size(I,1);
    offset = sz/2;

    thetaPred = Y2Pred(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--')

    thetaValidation = Y2Test(idx(i));
    plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--')

    hold off
    label = string(Y1Pred(idx(i)));
    title("Label: " + label)
end
The function model takes the input data dlX, the model parameters parameters, the flag doTraining, which specifies whether the model should return outputs for training or prediction, and the network state state. The network outputs the predictions for the labels, the predictions for the angles, and the updated network state.
function [dlY1,dlY2,state] = model(dlX,parameters,doTraining,state)

% Convolution
W = parameters.conv1.Weights;
B = parameters.conv1.Bias;
dlY = dlconv(dlX,W,B,'Padding',2);

% Batch normalization, ReLU
Offset = parameters.batchnorm1.Offset;
Scale = parameters.batchnorm1.Scale;
trainedMean = state.batchnorm1.TrainedMean;
trainedVariance = state.batchnorm1.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm1.TrainedMean = trainedMean;
    state.batchnorm1.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
end

dlY = relu(dlY);

% Convolution, batch normalization (Skip connection)
W = parameters.convSkip.Weights;
B = parameters.convSkip.Bias;
dlYSkip = dlconv(dlY,W,B,'Stride',2);

Offset = parameters.batchnormSkip.Offset;
Scale = parameters.batchnormSkip.Scale;
trainedMean = state.batchnormSkip.TrainedMean;
trainedVariance = state.batchnormSkip.TrainedVariance;

if doTraining
    [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,Offset,Scale,trainedMean,trainedVariance);

    % Update state
    state.batchnormSkip.TrainedMean = trainedMean;
    state.batchnormSkip.TrainedVariance = trainedVariance;
else
    dlYSkip = batchnorm(dlYSkip,Offset,Scale,trainedMean,trainedVariance);
end

% Convolution
W = parameters.conv2.Weights;
B = parameters.conv2.Bias;
dlY = dlconv(dlY,W,B,'Padding',1,'Stride',2);

% Batch normalization, ReLU
Offset = parameters.batchnorm2.Offset;
Scale = parameters.batchnorm2.Scale;
trainedMean = state.batchnorm2.TrainedMean;
trainedVariance = state.batchnorm2.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm2.TrainedMean = trainedMean;
    state.batchnorm2.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
end

dlY = relu(dlY);

% Convolution
W = parameters.conv3.Weights;
B = parameters.conv3.Bias;
dlY = dlconv(dlY,W,B,'Padding',1);

% Batch normalization
Offset = parameters.batchnorm3.Offset;
Scale = parameters.batchnorm3.Scale;
trainedMean = state.batchnorm3.TrainedMean;
trainedVariance = state.batchnorm3.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm3.TrainedMean = trainedMean;
    state.batchnorm3.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
end

% Addition, ReLU
dlY = dlYSkip + dlY;
dlY = relu(dlY);

% Fully connect (angles)
W = parameters.fc1.Weights;
B = parameters.fc1.Bias;
dlY2 = fullyconnect(dlY,W,B);

% Fully connect, softmax (labels)
W = parameters.fc2.Weights;
B = parameters.fc2.Bias;
dlY1 = fullyconnect(dlY,W,B);
dlY1 = softmax(dlY1);

end
See Also

assembleNetwork | batchnorm | dlarray | dlconv | fullyconnect | functionToLayerGraph | predict | relu | softmax