This example shows how to generate a synthetic image of a scene from a semantic segmentation map using a Pix2PixHD conditional generative adversarial network (CGAN).
Pix2PixHD [1] consists of two networks that are trained simultaneously to maximize the performance of both.
The generator is an encoder-decoder style neural network that generates a scene image from a semantic segmentation map. A CGAN network trains the generator to generate a scene image that the discriminator misclassifies as real.
The discriminator is a fully convolutional neural network that compares a generated scene image and the corresponding real image and attempts to classify them as fake and real, respectively. A CGAN network trains the discriminator to correctly distinguish between generated and real images.
The generator and discriminator networks compete against each other during training. The training converges when neither network can improve further.
This example uses the CamVid data set [2] from the University of Cambridge for training. This data set is a collection of 701 images containing street-level views obtained while driving. The data set provides pixel labels for 32 semantic classes including car, pedestrian, and road.
Download the CamVid data set from these URLs. The download time depends on your internet connection.
imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';
labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip';

dataDir = fullfile(tempdir,'CamVid');
downloadCamVidData(dataDir,imageURL,labelURL);
imgDir = fullfile(dataDir,"images","701_StillsRaw_full");
labelDir = fullfile(dataDir,'labels');
Create an imageDatastore to store the images in the CamVid data set.
imds = imageDatastore(imgDir);
imageSize = [576 768];
Define the class names and pixel label IDs of the 32 classes in the CamVid data set using the helper function defineCamVid32ClassesAndPixelLabelIDs. Get a standard color map for the CamVid data set using the helper function camvid32ColorMap. The helper functions are attached to the example as supporting files.
numClasses = 32;
[classes,labelIDs] = defineCamVid32ClassesAndPixelLabelIDs;
cmap = camvid32ColorMap;
Create a pixelLabelDatastore to store the pixel label images.
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);
Preview a pixel label image and the corresponding ground truth scene image. Convert the labels from categorical labels to RGB colors by using the label2rgb function, then display the pixel label image and ground truth image in a montage.
im = preview(imds);
px = preview(pxds);
px = label2rgb(px,cmap);
montage({px,im})
Partition the data into training and test sets using the helper function partitionCamVidForPix2PixHD. This function is attached to the example as a supporting file. The helper function splits the data into 648 training files and 32 test files.
[imdsTrain,imdsTest,pxdsTrain,pxdsTest] = partitionCamVidForPix2PixHD(imds,pxds,classes,labelIDs);
Use the combine function to combine the pixel label images and ground truth scene images into a single datastore.
dsTrain = combine(pxdsTrain,imdsTrain);
Augment the training data by using the transform function with custom preprocessing operations specified by the helper function preprocessCamVidForPix2PixHD. This helper function is attached to the example as a supporting file.
The preprocessCamVidForPix2PixHD function performs these operations (a minimal sketch of the transform follows the list):
Scale the ground truth data to the range [-1, 1]. This range matches the range of the final tanhLayer (Deep Learning Toolbox) in the generator network.
Resize the image and labels to the output size of the network, 576-by-768 pixels, using bicubic and nearest neighbor downsampling, respectively.
Convert the single-channel segmentation map to a 32-channel one-hot encoded segmentation map using the onehotencode (Deep Learning Toolbox) function.
Randomly flip image and pixel label pairs in the horizontal direction.
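The following is a minimal sketch of the kind of transform that the preprocessCamVidForPix2PixHD supporting file applies to each pixel label image and scene image pair read from the combined datastore. The function name and implementation details here are illustrative assumptions; refer to the supporting file for the exact code.

function dataOut = preprocessForPix2PixHDSketch(data,imageSize)
% Illustrative sketch only; the actual supporting file may differ.
% data is a 1-by-2 cell array: {categorical pixel label image, scene image}.
segMap = data{1};
im = data{2};

% Scale the ground truth image to the range [-1, 1] to match the generator tanhLayer
im = im2single(im)*2 - 1;

% Resize the image with bicubic and the labels with nearest neighbor interpolation
im = imresize(im,imageSize,"bicubic");
segMap = imresize(segMap,imageSize,"nearest");

% Convert the segmentation map to a 32-channel one-hot encoded map
segMap1Hot = onehotencode(segMap,3,"single");

% Randomly flip the image and pixel label pair in the horizontal direction
if rand > 0.5
    segMap1Hot = flip(segMap1Hot,2);
    im = flip(im,2);
end

dataOut = {segMap1Hot,im};
end

The example applies the actual supporting function with the transform call that follows.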
dsTrain = transform(dsTrain,@(x) preprocessCamVidForPix2PixHD(x,imageSize));
Preview the channels of a one-hot encoded segmentation map in a montage. Each channel represents a one-hot map corresponding to pixels of a unique class.
map = preview(dsTrain);
montage(map{1},'Size',[4 8],'Bordersize',5,'BackgroundColor','b')
Define a generator network that generates a scene image from a depth-wise one-hot encoded segmentation map. This input has the same height and width as the original segmentation map and the same number of channels as there are classes.
generatorInputSize = [imageSize numClasses];
Create layers of the initial subnetwork. reflectionPad2dLayer is a custom layer implemented specifically for this example. This layer is attached to the example as a supporting file.
numFiltersFirstConvLayerGenerator = 64;
filterSize = [7 7];
initialLayers = [ ...
    imageInputLayer(generatorInputSize,'Normalization','none','Name','inputLayer') ...
    reflectionPad2dLayer(3,'iPad') ...
    convolution2dLayer(filterSize,numFiltersFirstConvLayerGenerator,'Name','iConv') ...
    groupNormalizationLayer('channel-wise','Name','iGn') ...
    reluLayer('Name','iRelu')];
Add layers of the downsampling subnetwork. Use four downsampling convolutional layers. Each downsampling layer has twice the number of filters as the previous convolutional layer.
numFilters = numFiltersFirstConvLayerGenerator;
numDownsamplingLayers = 4;
filterSize = [3 3];
downsamplingLayers = [];
for idx = 1:numDownsamplingLayers
    % Compute the number of filters in the next convolutional layer
    numFilters = numFilters*2;
    s = int2str(idx);

    downsamplingLayers = [downsamplingLayers ...
        convolution2dLayer(filterSize,numFilters,"Name",strcat("dConv",s), ...
        "Stride",2,"Padding",1) ...
        groupNormalizationLayer("channel-wise",'Name',strcat("dGn",s)) ...
        reluLayer('Name',strcat('dRelu',s))];
end
Create a layer graph from the initial subnetwork and downsampling subnetwork.
generator = layerGraph([initialLayers downsamplingLayers]);
Create layers of the residual subnetwork. Specify nine residual blocks in the generator. The residual connection in the first residual block is between the addition layer of the first block and the final ReLU layer of the downsampling subnetwork. The residual connections in subsequent residual blocks are between the addition layer of the current block and the addition layer of the previous block.
The number of filters in the residual convolutional layers is equal to the number of filters in the last downsampling convolutional layer, 1024.
numResidualBlocks = 9;
for idx = 1:numResidualBlocks
    % Get the name of the layer that acts as the source of the residual connection
    res = generator.Layers(end).Name;

    % Specify the layer names of the residual block
    s = int2str(idx);
    convLayer1Name = strcat("rConv",s,"_1");
    convLayer2Name = strcat("rConv",s,"_2");
    gnLayer1Name = strcat("rGn",s,"_1");
    gnLayer2Name = strcat("rGn",s,"_2");
    pad1Name = strcat("rPad",s,"_1");
    pad2Name = strcat("rPad",s,"_2");

    residualBlockLayers = [ reflectionPad2dLayer(1,pad1Name) ...
        convolution2dLayer(filterSize,numFilters,"Name",convLayer1Name) ...
        groupNormalizationLayer('channel-wise','Name',gnLayer1Name) ...
        reflectionPad2dLayer(1,pad2Name) ...
        convolution2dLayer(filterSize,numFilters,"Name",convLayer2Name) ...
        groupNormalizationLayer('channel-wise','Name',gnLayer2Name) ...
        reluLayer("Name",strcat("rRelu",s)) ...
        additionLayer(2,'Name',strcat("rAdd",s))];

    % Add the layers to the layer graph
    lg = addLayers(generator,residualBlockLayers);
    generator = connectLayers(lg,generator.Layers(end).Name,residualBlockLayers(1).Name);

    % Link the residual connection
    generator = connectLayers(generator,res,strcat("rAdd",s,"/in2"));
end
Create layers of the upsampling subnetwork. Use four upsampling transposed convolutional layers, the same number as the downsampling convolutional layers. Each upsampling layer has half the number of filters of the previous convolutional layer.
for idx = 1:numDownsamplingLayers
    % Compute the number of filters in the next convolutional layer
    numFilters = numFilters/2;
    s = int2str(idx);

    upsamplingLayers = [ transposedConv2dLayer(filterSize,numFilters,"Name",strcat("uConv",s), ...
        "Stride",2,"Cropping","Same") ...
        groupNormalizationLayer('channel-wise','Name',strcat("uGn",s)) ...
        reluLayer('Name',strcat('uRelu',s))];

    % Add the upsampling layers to the layer graph
    lg = addLayers(generator,upsamplingLayers);
    generator = connectLayers(lg,generator.Layers(end).Name,upsamplingLayers(1).Name);
end
Create layers of the final subnetwork. Specify the filter size and number of filters of the final convolutional layer of the generator. The final layer is a hyperbolic tangent layer, which produces activations in the range [-1, 1].
filterSize = [7 7];
numFilters = 3;
finalLayers = [ reflectionPad2dLayer(3,'fPad') ...
    convolution2dLayer(filterSize,numFilters,'Name','fConv') ...
    tanhLayer('Name','ftanh')];
Add the final subnetwork layers to the layer graph.
lg = addLayers(generator,finalLayers);
lgraphGenerator = connectLayers(lg,generator.Layers(end).Name,finalLayers(1).Name);
Visualize the generator network in a plot.
plot(lgraphGenerator)
title("Generator")
To train the network with a custom training loop and to enable automatic differentiation, convert the layer graph to a dlnetwork (Deep Learning Toolbox) object.
dlnetGenerator = dlnetwork(lgraphGenerator);
Visualize the network architecture using the Deep Network Designer (Deep Learning Toolbox) app.
deepNetworkDesigner(lgraphGenerator)
Define a discriminator network that classifies an input image as either real (1) or fake (0).
The input to the discriminator is the depth-wise concatenation of the one-hot encoded segmentation map and the scene image to be classified. Specify the number of channels input to the discriminator as the total number of labeled classes plus the number of image color channels.
numImageChannels = 3;
numChannelsDiscriminator = numClasses + numImageChannels;
discriminatorInputSize = [imageSize numChannelsDiscriminator];
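The depth-wise concatenation itself is performed later, inside the modelGradients helper function, using cat along the third (channel) dimension. A small illustration with placeholder arrays (the variable names are ours):

% Placeholder one-hot segmentation map and scene image of matching spatial size
segMapExample = zeros([imageSize numClasses],"single");
imageExample = zeros([imageSize numImageChannels],"single");

% Depth-wise concatenation produces a 576-by-768-by-35 discriminator input
discriminatorInputExample = cat(3,segMapExample,imageExample);
size(discriminatorInputExample)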
Specify the filter size and number of filters in the first convolutional layer of the discriminator.
filterSize = [4 4];
numFilters = 64;
Define the layers of the discriminator.
discriminator = [
    imageInputLayer(discriminatorInputSize,"Name","inputLayer","Normalization","none")
    convolution2dLayer(filterSize,numFilters,"Name","iConv", ...
        "Padding",2,"Stride",2)
    leakyReluLayer(0.2,"Name","lrelu1")
    convolution2dLayer(filterSize,numFilters*2,"Name","dConv1", ...
        "Padding",2,"Stride",2)
    groupNormalizationLayer('channel-wise',"Name","dGn1")
    leakyReluLayer(0.2,"Name","lrelu2")
    convolution2dLayer(filterSize,numFilters*4,"Name","dConv2", ...
        "Padding",2,"Stride",2)
    groupNormalizationLayer('channel-wise',"Name","dGn2")
    leakyReluLayer(0.2,"Name","lrelu3")
    convolution2dLayer(filterSize,numFilters*8,"Name","dConv3", ...
        "Padding",2)
    groupNormalizationLayer('channel-wise',"Name","dGn3")
    leakyReluLayer(0.2,"Name","lrelu4")
    convolution2dLayer(filterSize,1,"Name","fConv", ...
        "Padding",2)
    ];
Create the layer graph.
lgraphDiscriminator = layerGraph(discriminator);
Visualize the discriminator network in a plot.
plot(lgraphDiscriminator)
title("Discriminator")
To train the network with a custom training loop and to enable automatic differentiation, convert the layer graph to a dlnetwork object.
dlnetDiscriminator = dlnetwork(lgraphDiscriminator);
Visualize the network architecture using the Deep Network Designer (Deep Learning Toolbox) app.
deepNetworkDesigner(lgraphDiscriminator)
The helper function modelGradients calculates the gradients and adversarial loss for the generator and discriminator. The function also calculates the feature matching loss and VGG loss for the generator. This function is defined in the Supporting Functions section of this example.
The objective of the generator is to generate images that the discriminator classifies as real (1). The generator loss consists of three losses.
The adversarial loss is the squared difference between a vector of ones and the discriminator predictions on the generated image, that is, the discriminator outputs for the image produced by the generator. This loss is implemented using part of the pix2pixHDAdverserialLoss helper function defined in the Supporting Functions section of this example.
The feature matching loss penalizes the L1 distance between feature maps of the ground truth images and of the generated images, obtained as predictions from the discriminator network and summed over the total number of discriminator feature layers. This loss is implemented using the pix2pixHDFeatureMatchingLoss helper function defined in the Supporting Functions section of this example.
The perceptual loss penalizes the L1 distance between feature maps of the ground truth images and of the generated images, obtained as predictions from a feature extraction network and summed over the total number of feature layers. This loss is implemented using the pix2pixHDVGGLoss helper function defined in the Supporting Functions section of this example. The feature extraction network is created in the Load Feature Extraction Network section.
The overall generator loss is a weighted sum of all three losses, with separate weight factors for the adversarial loss, the feature matching loss, and the perceptual loss.
Note that the adversarial loss and feature matching loss for the generator are computed for two different scales.
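For reference, the generator objective implemented in the modelGradients helper function can be written compactly as follows. The notation here is ours, not the documentation's: $D^{(s)}$ is the discriminator applied at scale $s$, $D_i^{(s)}$ its $i$-th feature layer, $V_i$ the $i$-th VGG-19 feature layer with weight $w_i$, $Y$ the ground truth image, and $G(S)$ the image generated from segmentation map $S$. In this example, $\lambda_{adv} = 1$, $\lambda_{FM} = 5$, and $\lambda_{VGG} = 5$.

$$
\mathcal{L}_{adv}^{(s)} = \big(1 - D^{(s)}(S, G(S))\big)^2, \qquad
\mathcal{L}_{FM}^{(s)} = \sum_{i=1}^{T} \big\lVert D_i^{(s)}(S, Y) - D_i^{(s)}(S, G(S)) \big\rVert_1, \qquad
\mathcal{L}_{VGG} = \sum_{i=1}^{5} w_i \big\lVert V_i(Y) - V_i(G(S)) \big\rVert_1
$$

$$
\mathcal{L}_{G} = \sum_{s=1}^{2} \Big( \lambda_{adv}\,\mathcal{L}_{adv}^{(s)} + \lambda_{FM}\,\mathcal{L}_{FM}^{(s)} \Big) + \lambda_{VGG}\,\mathcal{L}_{VGG}
$$

As in the code, the squared terms and the L1 distances are averaged over all elements of the corresponding arrays.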
The objective of the discriminator is to correctly distinguish between ground truth images and generated images. The discriminator loss is a sum of two components:
The squared difference between a vector of ones and the predictions of the discriminator on real images
The squared difference between a vector of zeros and the predictions of the discriminator on generated images
The discriminator loss is implemented using part of the pix2pixHDAdverserialLoss helper function defined in the Supporting Functions section of this example. Note that the adversarial loss for the discriminator is computed for two different scales.
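In the same notation as above (ours, not the documentation's), the least-squares discriminator objective implemented in pix2pixHDAdverserialLoss, averaged over the two scales as in modelGradients, is:

$$
\mathcal{L}_{D} = \frac{1}{2} \sum_{s=1}^{2} \Big[ \big(1 - D^{(s)}(S, Y)\big)^2 + \big(D^{(s)}(S, G(S))\big)^2 \Big]
$$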
This example modifies a pretrained VGG-19 deep neural network to extract the features of the real and generated images at various layers. These multilayer features are used to compute the perceptual loss of the generator.
To get a pretrained VGG-19 network, install vgg19 (Deep Learning Toolbox). If you do not have the required support packages installed, then the software provides a download link.
netVGG = vgg19;
Visualize the network architecture using the Deep Network Designer (Deep Learning Toolbox) app.
deepNetworkDesigner(netVGG)
To make the VGG-19 network suitable for feature extraction, keep the layers up to 'pool5' and remove all of the fully connected layers from the network. The resulting network is a fully convolutional network.
netVGG = layerGraph(netVGG.Layers(1:38));
Create a new image input layer with no normalization. Replace the original image input layer with the new layer.
inp = imageInputLayer([imageSize 3],"Normalization","None","Name","Input");
netVGG = replaceLayer(netVGG,"input",inp);
netVGG = dlnetwork(netVGG);
Specify the options for Adam optimization. Train for 60 epochs. Specify identical options for the generator and discriminator networks.
Specify a learning rate of 0.0002 for both networks.
Initialize the trailing average gradient and trailing average squared gradient values with [].
Use a gradient decay factor of 0.5 and a squared gradient decay factor of 0.999.
Use a mini-batch size of 1 for training.
numEpochs = 60;
learningRate = 0.0002;
trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;
miniBatchSize = 1;
Create a minibatchqueue (Deep Learning Toolbox) object that manages the mini-batching of observations in a custom training loop. The minibatchqueue object also casts data to a dlarray (Deep Learning Toolbox) object, which enables automatic differentiation in deep learning applications.
Specify the mini-batch data extraction format as "SSCB" (spatial, spatial, channel, batch). Set the DispatchInBackground name-value argument to the value returned by canUseGPU. If a supported GPU is available for computation, then the minibatchqueue object preprocesses mini-batches in the background in a parallel pool during training.
mbqTrain = minibatchqueue(dsTrain,"MiniBatchSize",miniBatchSize, ...
    "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);
By default, the example downloads a pretrained version of the Pix2PixHD network for the CamVid data set by using the helper function downloadTrainedPix2PixHDNet. The helper function is attached to the example as a supporting file. The pretrained network enables you to run the entire example without waiting for training to complete.
To train the network, set the doTraining variable in the following code to true. Train the model in a custom training loop. For each iteration:
Read the data for the current mini-batch using the next (Deep Learning Toolbox) function.
Evaluate the model gradients using the dlfeval (Deep Learning Toolbox) function and the modelGradients helper function.
Update the network parameters using the adamupdate (Deep Learning Toolbox) function.
Update the training progress plot for every iteration and display various computed losses.
A CUDA-capable NVIDIA™ GPU with compute capability 3.0 or higher is highly recommended for training (requires Parallel Computing Toolbox™). Training takes about 22 hours on an NVIDIA™ Titan RTX and can take even longer depending on your GPU hardware. If your GPU device has less memory, try reducing the size of the input images by specifying the imageSize variable as [480 640] in the Preprocess Training Data section of the example.
doTraining = false;
if doTraining
    fig = figure;
    lossPlotter = configureTrainingProgressPlotter(fig);
    iteration = 0;

    % Loop over epochs
    for epoch = 1:numEpochs

        % Reset and shuffle the data
        reset(mbqTrain);
        shuffle(mbqTrain);

        % Loop over each image
        while hasdata(mbqTrain)
            iteration = iteration + 1;

            % Read data from current mini-batch
            [dlInputSegMap,dlRealImage] = next(mbqTrain);

            % Evaluate the model gradients using dlfeval and the
            % modelGradients function listed at the end of the example
            [gradParamsG,gradParamsD,lossGGAN,lossGFM,lossGVGG,lossD] = dlfeval( ...
                @modelGradients,dlInputSegMap,dlRealImage,dlnetGenerator,dlnetDiscriminator,netVGG);

            % Update the generator parameters
            [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = adamupdate( ...
                dlnetGenerator,gradParamsG, ...
                trailingAvgGenerator,trailingAvgSqGenerator,iteration, ...
                learningRate,gradientDecayFactor,squaredGradientDecayFactor);

            % Update the discriminator parameters
            [dlnetDiscriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = adamupdate( ...
                dlnetDiscriminator,gradParamsD, ...
                trailingAvgDiscriminator,trailingAvgSqDiscriminator,iteration, ...
                learningRate,gradientDecayFactor,squaredGradientDecayFactor);

            % Plot and display various losses
            lossPlotter = updateTrainingProgressPlotter(lossPlotter,iteration, ...
                epoch,numEpochs,lossGGAN,lossGFM,lossGVGG,lossD);
        end
    end
    save('trainedPix2PixHDNet.mat','dlnetGenerator');
else
    trainedPix2PixHDNet_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedPix2PixHDNet.zip';
    netDir = fullfile(tempdir,'CamVid');
    downloadTrainedPix2PixHDNet(trainedPix2PixHDNet_url,netDir);
    load(fullfile(netDir,'trainedPix2PixHDNet.mat'));
end
The performance of this trained Pix2PixHD network is limited because the number of CamVid training images is relatively small. Additionally, some images belong to an image sequence and therefore are correlated with other images in the training set. To improve the effectiveness of the Pix2PixHD network, train the network using a different data set that has a larger number of training images without correlation.
Because of the limitations, this Pix2PixHD network generates more realistic images for some test images than for others. To demonstrate the difference in results, compare the generated images for the first and third test image. The camera angle of the first test image has an uncommon vantage point that faces more perpendicular to the road than the typical training image. In contrast, the camera angle of the third test image has a typical vantage point that faces along the road and shows two lanes with lane markers. The network has significantly better performance generating a realistic image for the third test image than for the first test image.
Get the first ground truth scene image from the test data. Resize the image using bicubic interpolation.
idxToTest = 1;
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");
Get the corresponding pixel label image from the test data. Resize the pixel label image using nearest neighbor interpolation.
segMap = readimage(pxdsTest,idxToTest);
segMap = imresize(segMap,imageSize,"nearest");
Convert the pixel label image to a multichannel one-hot segmentation map by using the onehotencode (Deep Learning Toolbox) function.
segMap1Hot = onehotencode(segMap,3,'single');
Create a dlarray object that inputs data to the generator. If a supported GPU is available for computation, then perform inference on a GPU by converting the data to a gpuArray object.
dlSegMap = dlarray(segMap1Hot,'SSCB');
if canUseGPU
    dlSegMap = gpuArray(dlSegMap);
end
Generate a scene image from the generator and one-hot segmentation map using the predict (Deep Learning Toolbox) function.
dlGeneratedImage = predict(dlnetGenerator,dlSegMap);
generatedImage = extractdata(gather(dlGeneratedImage));
The final layer of the generator network produces activations in the range [-1, 1]. For display, rescale the activations to the range [0, 1].
generatedImage = rescale(generatedImage);
For display, convert the labels from categorical labels to RGB colors by using the label2rgb function.
coloredSegMap = label2rgb(segMap,cmap);
Display the RGB pixel label image, generated scene image, and ground truth scene image in a montage.
figure
montage({coloredSegMap generatedImage gtImage},'Size',[1 3])
title(['Test Pixel Label Image ',num2str(idxToTest),' with Generated and Ground Truth Scene Images'])
Get the third ground truth scene image from the test data. Resize the image using bicubic interpolation.
idxToTest = 3;
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");
To get the third pixel label image from the test data and to generate the corresponding scene image, you can use the helper function evaluatePix2PixHD. This helper function is attached to the example as a supporting file.
The evaluatePix2PixHD function performs the same operations as the evaluation of the first test image (a minimal sketch follows the list):
Get a pixel label image from the test data. Resize the pixel label image using nearest neighbor interpolation.
Convert the pixel label image to a multichannel one-hot segmentation map.
Create a dlarray object to input data to the generator. For GPU inference, convert the data to a gpuArray object.
Generate a scene image from the generator and one-hot segmentation map using the predict (Deep Learning Toolbox) function.
Rescale the activations to the range [0, 1].
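A minimal sketch of these steps is shown below. The function name used here is illustrative, and the actual evaluatePix2PixHD supporting file may differ in details; it is based on the same operations shown earlier for the first test image.

function [generatedImage,segMap] = evaluatePix2PixHDSketch(pxds,idxToTest,imageSize,dlnetGenerator)
% Illustrative sketch only; see the evaluatePix2PixHD supporting file for the exact code.

% Get the pixel label image and resize it using nearest neighbor interpolation
segMap = readimage(pxds,idxToTest);
segMap = imresize(segMap,imageSize,"nearest");

% Convert the pixel label image to a multichannel one-hot segmentation map
segMap1Hot = onehotencode(segMap,3,"single");

% Create a dlarray and, when a supported GPU is available, move the data to the GPU
dlSegMap = dlarray(segMap1Hot,"SSCB");
if canUseGPU
    dlSegMap = gpuArray(dlSegMap);
end

% Generate the scene image and rescale the activations from [-1, 1] to [0, 1]
dlGeneratedImage = predict(dlnetGenerator,dlSegMap);
generatedImage = rescale(extractdata(gather(dlGeneratedImage)));
end

The call below uses the actual supporting function.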
[generatedImage,segMap] = evaluatePix2PixHD(pxdsTest,idxToTest,imageSize,dlnetGenerator);
For display, convert the labels from categorical labels to RGB colors by using the label2rgb function.
coloredSegMap = label2rgb(segMap,cmap);
Display the RGB pixel label image, generated scene image, and ground truth scene image in a montage.
figure
montage({coloredSegMap generatedImage gtImage},'Size',[1 3])
title(['Test Pixel Label Image ',num2str(idxToTest),' with Generated and Ground Truth Scene Images'])
To evaluate how well the network generalizes to pixel label images outside the CamVid data set, generate scene images from custom pixel label images. This example uses pixel label images that were created using the Image Labeler app. The pixel label images are attached to the example as supporting files. No ground truth images are available.
Create a pixel label datastore that reads and processes the pixel label images in the current example directory.
cpxds = pixelLabelDatastore(pwd,classes,labelIDs);
For each pixel label image in the datastore, generate a scene image using the helper function evaluatePix2PixHD.
for idx = 1:length(cpxds.Files)
    % Get the pixel label image and generated scene image
    [generatedImage,segMap] = evaluatePix2PixHD(cpxds,idx,imageSize,dlnetGenerator);

    % For display, convert the labels from categorical labels to RGB colors
    coloredSegMap = label2rgb(segMap);

    % Display the pixel label image and generated scene image in a montage
    figure
    montage({coloredSegMap generatedImage})
    title(['Custom Pixel Label Image ',num2str(idx),' and Generated Scene Image'])
end
The modelGradients helper function calculates the gradients and adversarial loss for the generator and discriminator. The function also calculates the feature matching loss and VGG loss for the generator.
function [gradParamsG,gradParamsD,lossGGAN,lossGFM,lossGVGG,lossD] = modelGradients(inputSegMap,realImage,generator,discriminator,netVGG)

    % Compute the image generated by the generator given the input semantic map
    generatedImage = forward(generator,inputSegMap);

    % Define the loss weights
    lambdaDiscriminator = 1;
    lambdaGenerator = 1;
    lambdaFeatureMatching = 5;
    lambdaVGG = 5;

    % Concatenate the image to be classified and the semantic map
    inpDiscriminatorReal = cat(3,inputSegMap,realImage);
    inpDiscriminatorGenerated = cat(3,inputSegMap,generatedImage);

    % Compute the adversarial loss for the discriminator and the generator
    [DLossScale1,GLossScale1,realPredScale1D,fakePredScale1G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminator);

    % Scale the generated image, the real image, and the input semantic map to
    % half size
    resizedRealImage = dlresize(realImage,'Scale',0.5,'Method',"linear");
    resizedGeneratedImage = dlresize(generatedImage,'Scale',0.5,'Method',"linear");
    resizedinputSegMap = dlresize(inputSegMap,'Scale',0.5,'Method',"nearest");

    % Concatenate the image to be classified and the semantic map
    inpDiscriminatorReal = cat(3,resizedinputSegMap,resizedRealImage);
    inpDiscriminatorGenerated = cat(3,resizedinputSegMap,resizedGeneratedImage);

    % Compute the adversarial loss for the discriminator and the generator
    [DLossScale2,GLossScale2,realPredScale2D,fakePredScale2G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminator);

    % Compute the feature matching loss for scale 1
    FMLossScale1 = pix2pixHDFeatureMatchingLoss(realPredScale1D,fakePredScale1G);
    FMLossScale1 = FMLossScale1 * lambdaFeatureMatching;

    % Compute the feature matching loss for scale 2
    FMLossScale2 = pix2pixHDFeatureMatchingLoss(realPredScale2D,fakePredScale2G);
    FMLossScale2 = FMLossScale2 * lambdaFeatureMatching;

    % Compute the VGG loss
    VGGLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG);
    VGGLoss = VGGLoss * lambdaVGG;

    % Compute the combined generator loss
    lossGCombined = GLossScale1 + GLossScale2 + FMLossScale1 + FMLossScale2 + VGGLoss;
    lossGCombined = lossGCombined * lambdaGenerator;

    % Compute gradients for the generator
    gradParamsG = dlgradient(lossGCombined,generator.Learnables);

    % Compute the combined discriminator loss
    lossDCombined = (DLossScale1 + DLossScale2)/2 * lambdaDiscriminator;

    % Compute gradients for the discriminator
    gradParamsD = dlgradient(lossDCombined,discriminator.Learnables);

    % Log the values for displaying later
    lossD = gather(extractdata(lossDCombined));
    lossGGAN = gather(extractdata(GLossScale1 + GLossScale2));
    lossGFM = gather(extractdata(FMLossScale1 + FMLossScale2));
    lossGVGG = gather(extractdata(VGGLoss));
end
The helper function pix2pixHDAdverserialLoss computes the adversarial losses for the generator and the discriminator. The function also returns the discriminator feature maps of the real and generated images.
function [DLoss,GLoss,realPredFtrsD,genPredFtrsG] = pix2pixHDAdverserialLoss(inpReal,inpGenerated,discriminator)

    % Discriminator layer names containing feature maps
    featureNames = {'lrelu1','lrelu2','lrelu3','lrelu4','fConv'};

    % Get the feature maps for the real image from the discriminator
    realPredFtrsD = cell(size(featureNames));
    [realPredFtrsD{:}] = forward(discriminator,inpReal,"Outputs",featureNames);

    % Get the feature maps for the generated image from the discriminator
    genPredFtrsD = cell(size(featureNames));
    [genPredFtrsD{:}] = forward(discriminator,inpGenerated,"Outputs",featureNames);

    % Get the feature map from the final layer to compute the loss
    realPredD = realPredFtrsD{end};
    genPredD = genPredFtrsD{end};

    % Compute the discriminator loss
    DLoss = (1 - realPredD).^2 + (genPredD).^2;
    DLoss = mean(DLoss,"all");

    % Compute the generator loss
    genPredFtrsG = cell(size(featureNames));
    [genPredFtrsG{:}] = forward(discriminator,inpGenerated,"Outputs",featureNames);
    genPredG = genPredFtrsG{end};
    GLoss = (1 - genPredG).^2;
    GLoss = mean(GLoss,"all");
end
The helper function pix2pixHDFeatureMatchingLoss computes the feature matching loss between a real image and a synthetic image generated by the generator.
function featureMatchingLoss = pix2pixHDFeatureMatchingLoss(realPredFtrs,genPredFtrs)

    % Number of feature maps
    numFtrsMaps = numel(realPredFtrs);

    % Initialize the feature matching loss
    featureMatchingLoss = 0;

    for i = 1:numFtrsMaps
        % Get the feature maps of the real image
        a = extractdata(realPredFtrs{i});
        % Get the feature maps of the synthetic image
        b = genPredFtrs{i};

        % Compute the feature matching loss
        featureMatchingLoss = featureMatchingLoss + mean(abs(a - b),"all");
    end
end
The helper function pix2pixHDVGGLoss computes the perceptual VGG loss between a real image and a synthetic image generated by the generator.
function vggLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG)

    featureWeights = [1.0/32 1.0/16 1.0/8 1.0/4 1.0];

    % Initialize the VGG loss
    vggLoss = 0;

    % Specify the names of the layers with desired feature maps
    featureNames = ["relu1_1","relu2_1","relu3_1","relu4_1","relu5_1"];

    % Extract the feature maps for the real image
    activReal = cell(size(featureNames));
    [activReal{:}] = forward(netVGG,realImage,"Outputs",featureNames);

    % Extract the feature maps for the synthetic image
    activGenerated = cell(size(featureNames));
    [activGenerated{:}] = forward(netVGG,generatedImage,"Outputs",featureNames);

    % Compute the VGG loss
    for i = 1:numel(featureNames)
        vggLoss = vggLoss + featureWeights(i)*mean(abs(activReal{i} - activGenerated{i}),"all");
    end
end
[1] Wang, Ting-Chun, Ming-Yu Liu, Jun-Yan Zhu, Andrew Tao, Jan Kautz, and Bryan Catanzaro. "High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs." In 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 8798–8807, 2018. https://doi.org/10.1109/CVPR.2018.00917.
[2] Brostow, Gabriel J., Julien Fauqueur, and Roberto Cipolla. "Semantic Object Classes in Video: A High-Definition Ground Truth Database." Pattern Recognition Letters. Vol. 30, Issue 2, 2009, pp 88-97.
combine | imageDatastore | pixelLabelDatastore | transform | trainingOptions (Deep Learning Toolbox) | trainNetwork (Deep Learning Toolbox) | vgg19 (Deep Learning Toolbox)