This example shows how to assemble a multiple-output network for prediction. Instead of using the dlnetwork object for prediction, you can assemble the network into a DAGNetwork ready for prediction using the assembleNetwork function. This lets you use the predict function with other data types, such as datastores.
Load the model parameters from the MAT file dlnetDigits.mat. The MAT file contains a dlnetwork object that predicts both the scores for categorical labels and the numeric angles of rotation of images of digits, along with the corresponding class names.
s = load("dlnetDigits.mat");
dlnet = s.dlnet;
classNames = s.classNames;
Extract the layer graph from the dlnetwork object using the layerGraph function.
lgraph = layerGraph(dlnet);
The layer graph does not include output layers. Add a classification layer and a regression layer to the layer graph using the addLayers and connectLayers functions.
% Add a classification output layer for the label scores.
layers = classificationLayer('Classes',classNames,'Name','coutput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'softmax','coutput');

% Add a regression output layer for the angles of rotation.
layers = regressionLayer('Name','routput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'fc2','routput');
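Optionally, you can confirm that the new output layers are wired up before continuing. A minimal check (not required for this example) is to display the layer graph's Connections property, which lists the source and destination of every connection:

lgraph.Connections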
View a plot of the network.
figure
plot(lgraph)
Assemble the network using the assembleNetwork function.
net = assembleNetwork(lgraph)
net = 
  DAGNetwork with properties:

         Layers: [19x1 nnet.cnn.layer.Layer]
    Connections: [19x2 table]
     InputNames: {'in'}
    OutputNames: {'coutput'  'routput'}
Load the test data.
[XTest,Y1Test,Y2Test] = digitTest4DArrayData;
To make predictions using the assembled network, use the predict function. To return categorical labels for the classification output, set the 'ReturnCategorical' option to true.
[Y1Pred,Y2Pred] = predict(net,XTest,'ReturnCategorical',true);
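As noted at the start of this example, assembling the network also lets predict work with datastore inputs. The following is a minimal sketch that wraps the numeric test images in an augmentedImageDatastore; the 28-by-28 image size is an assumption about the digit data, so adjust it to match size(XTest) if needed.

% Sketch: predict from a datastore instead of a numeric array (assumes 28-by-28 images).
augimds = augmentedImageDatastore([28 28],XTest);
[Y1PredDS,Y2PredDS] = predict(net,augimds,'ReturnCategorical',true);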
Evaluate the classification accuracy.
accuracy = mean(Y1Pred==Y1Test)
accuracy = 0.9870
Evaluate the regression accuracy.
angleRMSE = sqrt(mean((Y2Pred - Y2Test).^2))
angleRMSE = single
6.0091
View some of the images with their predictions. Display the predicted angles in red and the correct angles in green.
idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on

    sz = size(I,1);
    offset = sz/2;

    % Predicted angle of rotation (red dashed line)
    thetaPred = Y2Pred(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--')

    % Correct angle of rotation (green dashed line)
    thetaValidation = Y2Test(idx(i));
    plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--')

    hold off
    label = string(Y1Pred(idx(i)));
    title("Label: " + label)
end
See Also: assembleNetwork | batchNormalizationLayer | convolution2dLayer | fullyConnectedLayer | predict | reluLayer | softmaxLayer