This example shows how to convert decimal strings to Roman numerals using a recurrent sequence-to-sequence encoder-decoder model with attention.
Recurrent encoder-decoder models have proven successful at tasks like abstractive text summarization and neural machine translation. These models consist of an encoder, which typically processes the input data with a recurrent layer such as an LSTM, and a decoder, which maps the encoded input to the desired output, typically with a second recurrent layer. Models that incorporate attention mechanisms allow the decoder to focus on different parts of the encoded input while generating the translation.
For the encoder model, this example uses a simple network consisting of an embedding followed by two LSTM operations. Embedding is a method of converting categorical tokens into numeric vectors.
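As a minimal sketch of an embedding lookup (the variable names and sizes below are illustrative only and are not used elsewhere in this example), each token index selects one column of a weights matrix:

% Minimal sketch of an embedding lookup. Names and sizes are illustrative only.
embDim = 4;                              % dimension of each embedding vector
vocabSize = 6;                           % number of tokens in the vocabulary
embWeights = randn(embDim,vocabSize);    % one embedding vector per column

tokenIndices = [1 3 2];                  % a sequence of three token indices
embeddedSequence = embWeights(:,tokenIndices)   % embDim-by-3 array of embedding vectors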
For the decoder model, this example uses a network very similar to the encoder that contains two LSTMs. However, an important difference is that the decoder contains an attention mechanism. The attention mechanism allows the decoder to attend to specific parts of the encoder output.
Download the decimal-Roman numeral pairs from "romanNumerals.csv".
filename = fullfile("romanNumerals.csv");

options = detectImportOptions(filename, ...
    'TextType','string', ...
    'ReadVariableNames',false);
options.VariableNames = ["Source" "Target"];
options.VariableTypes = ["string" "string"];

data = readtable(filename,options);
Split the data into training and test partitions containing 50% of the data each.
idx = randperm(size(data,1),500);
dataTrain = data(idx,:);
dataTest = data;
dataTest(idx,:) = [];
View some of the decimal-Roman numeral pairs.
head(dataTrain)
ans=8×2 table
Source Target
______ _________
"168" "CLXVIII"
"154" "CLIV"
"765" "DCCLXV"
"714" "DCCXIV"
"649" "DCXLIX"
"346" "CCCXLVI"
"77" "LXXVII"
"83" "LXXXIII"
Preprocess the training data using the preprocessSourceTargetPairs function, listed at the end of the example. The preprocessSourceTargetPairs function converts the input text data to numeric sequences. The elements of the sequences are positive integers that index into a corresponding wordEncoding object. The wordEncoding object maps tokens to numeric indices and vice versa using a vocabulary. To highlight the beginnings and ends of sequences, the encoding also encapsulates the special tokens "<start>" and "<stop>".
startToken = "<start>"; stopToken = "<stop>"; [sequencesSource, sequencesTarget, encSource, encTarget] = preprocessSourceTargetPairs(dataTrain,startToken,stopToken);
For example, the decimal string "441" is encoded as follows:
strSource = "441";
Insert spaces between the characters.
strSource = strip(replace(strSource,""," "));
Add the special start and stop tokens.
strSource = startToken + strSource + stopToken
strSource = "<start>4 4 1<stop>"
Tokenize the text using the tokenizedDocument function and set the 'CustomTokens' option to the special tokens.
documentSource = tokenizedDocument(strSource,'CustomTokens',[startToken stopToken])
documentSource = tokenizedDocument: 5 tokens: <start> 4 4 1 <stop>
Convert the document to a sequence of token indices using the word2ind function with the corresponding wordEncoding object.
tokens = string(documentSource); sequenceSource = word2ind(encSource,tokens)
sequenceSource = 1×5
1 7 7 2 5
Sequence data such as text naturally have different sequence lengths. To train a model using variable length sequences, pad the mini-batches of input data to have the same length. To ensure that the padding values do not impact the loss calculations, create a mask which records which sequence elements are real, and which are just padding.
For example, consider a mini-batch containing the decimal strings "437", "431", and "102" with the corresponding Roman numeral strings "CDXXXVII", "CDXXXI", and "CII". For character-by-character sequences, the input sequences have the same length and do not need to be padded. The corresponding mask is an array of ones.
The output sequences have different lengths, so they require padding. The corresponding padding mask contains zeros where the corresponding time steps are padding values.
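The createBatch function, listed at the end of the example, builds these padded mini-batches and masks. As a minimal sketch of the target-side mask for the mini-batch above (assuming the start and stop tokens are counted as sequence elements, so the padded lengths are 10, 8, and 5):

% Minimal sketch: build the padding mask for the target sequences
% "CDXXXVII", "CDXXXI", and "CII", including the start and stop tokens.
sequenceLengths = [10 8 5];                      % lengths with start and stop tokens
maxLength = max(sequenceLengths);

maskTarget = false(numel(sequenceLengths),maxLength);
for n = 1:numel(sequenceLengths)
    maskTarget(n,1:sequenceLengths(n)) = true;   % true for real time steps, false for padding
end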
Initialize the model parameters. For both the encoder and decoder, specify an embedding dimension of 256, two LSTM layers with 200 hidden units, and dropout with probability 0.05.
embeddingDimension = 256; numHiddenUnits = 200; dropout = 0.05;
Initialize the encoder model parameters:
Specify an embedding dimension of 256 and the vocabulary size of the source vocabulary plus 1, where the extra value corresponds to the padding token.
Specify two LSTM operations with 200 hidden units.
Initialize the embedding weights by sampling from a random normal distribution.
Initialize the LSTM weights and biases by sampling from a uniform distribution using the uniformNoise function, listed at the end of the example.
inputSize = encSource.NumWords + 1;

parametersEncoder.emb.Weights = dlarray(randn([embeddingDimension inputSize]));

parametersEncoder.lstm1.InputWeights = dlarray(uniformNoise([4*numHiddenUnits embeddingDimension],1/numHiddenUnits));
parametersEncoder.lstm1.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersEncoder.lstm1.Bias = dlarray(uniformNoise([4*numHiddenUnits 1],1/numHiddenUnits));

parametersEncoder.lstm2.InputWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersEncoder.lstm2.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersEncoder.lstm2.Bias = dlarray(uniformNoise([4*numHiddenUnits 1],1/numHiddenUnits));
Initialize the decoder model parameters:
Specify an embedding dimension of 256 and the vocabulary size of the target vocabulary plus 1, where the extra value corresponds to the padding token.
Initialize the attention mechanism weights using the uniformNoise function.
Initialize the embedding weights by sampling from a random normal distribution.
Initialize the LSTM weights and biases by sampling from a uniform distribution using the uniformNoise function.
outputSize = encTarget.NumWords + 1;

parametersDecoder.emb.Weights = dlarray(randn([embeddingDimension outputSize]));

parametersDecoder.attn.Weights = dlarray(uniformNoise([numHiddenUnits numHiddenUnits],1/numHiddenUnits));

parametersDecoder.lstm1.InputWeights = dlarray(uniformNoise([4*numHiddenUnits embeddingDimension+numHiddenUnits],1/numHiddenUnits));
parametersDecoder.lstm1.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersDecoder.lstm1.Bias = dlarray(uniformNoise([4*numHiddenUnits 1],1/numHiddenUnits));

parametersDecoder.lstm2.InputWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersDecoder.lstm2.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersDecoder.lstm2.Bias = dlarray(uniformNoise([4*numHiddenUnits 1],1/numHiddenUnits));

parametersDecoder.fc.Weights = dlarray(uniformNoise([outputSize 2*numHiddenUnits],1/(2*numHiddenUnits)));
parametersDecoder.fc.Bias = dlarray(uniformNoise([outputSize 1],1/(2*numHiddenUnits)));
Create the functions modelEncoder and modelDecoder, listed at the end of the example, that compute the outputs of the encoder and decoder models, respectively.
The modelEncoder function, listed in the Encoder Model Function section of the example, takes the input data, the model parameters, and the optional mask that is used to determine the correct outputs for training, and returns the model outputs and the LSTM hidden state.
The modelDecoder function, listed in the Decoder Model Function section of the example, takes the input data, the model parameters, the context vector, the LSTM initial hidden state, the outputs of the encoder, and the dropout probability, and outputs the decoder output, the updated context vector, the updated LSTM state, and the attention scores.
Create the function modelGradients, listed in the Model Gradients Function section of the example, that takes the encoder and decoder model parameters, a mini-batch of input data and the padding masks corresponding to the input data, and the dropout probability, and returns the gradients of the loss with respect to the learnable parameters in the models and the corresponding loss.
Train with a mini-batch size of 32 for 40 epochs. Specify a learning rate of 0.002 and clip the gradients with a threshold of 5.
miniBatchSize = 32; numEpochs = 40; learnRate = 0.002; gradientThreshold = 5;
Initialize the options for Adam optimization.
gradientDecayFactor = 0.9; squaredGradientDecayFactor = 0.999;
Specify to plot the training progress. To disable the training progress plot, set the plots value to "none".
plots = "training-progress";
Train the model using a custom training loop.
For the first epoch, train with the sequences sorted by increasing sequence length. This results in batches with sequences of approximately the same sequence length, and ensures that mini-batches with shorter sequences are used to update the model before mini-batches with longer sequences. For subsequent epochs, shuffle the data.
For each mini-batch:
Convert the data to dlarray.
Compute loss and gradients.
Clip the gradients.
Update the encoder and decoder model parameters using the adamupdate function.
Update the training progress plot.
Sort the sequences for the first epoch.
sequenceLengthsEncoder = cellfun(@(sequence) size(sequence,2), sequencesSource);
[~,idx] = sort(sequenceLengthsEncoder);
sequencesSource = sequencesSource(idx);
sequencesTarget = sequencesTarget(idx);
Initialize the training progress plot.
if plots == "training-progress" figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on end
Initialize the values for the adamupdate function.
trailingAvgEncoder = []; trailingAvgSqEncoder = []; trailingAvgDecoder = []; trailingAvgSqDecoder = [];
Train the model.
numObservations = numel(sequencesSource);
numIterationsPerEpoch = floor(numObservations/miniBatchSize);

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Read mini-batch of data.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        [XSource, XTarget, maskSource, maskTarget] = createBatch(sequencesSource(idx), ...
            sequencesTarget(idx), inputSize, outputSize);

        % Convert mini-batch of data to dlarray.
        dlXSource = dlarray(XSource);
        dlXTarget = dlarray(XTarget);

        % Compute loss and gradients.
        [gradientsEncoder, gradientsDecoder, loss] = dlfeval(@modelGradients, parametersEncoder, ...
            parametersDecoder, dlXSource, dlXTarget, maskSource, maskTarget, dropout);

        % Gradient clipping.
        gradientsEncoder = dlupdate(@(w) clipGradient(w,gradientThreshold), gradientsEncoder);
        gradientsDecoder = dlupdate(@(w) clipGradient(w,gradientThreshold), gradientsDecoder);

        % Update encoder using adamupdate.
        [parametersEncoder, trailingAvgEncoder, trailingAvgSqEncoder] = adamupdate(parametersEncoder, ...
            gradientsEncoder, trailingAvgEncoder, trailingAvgSqEncoder, iteration, learnRate, ...
            gradientDecayFactor, squaredGradientDecayFactor);

        % Update decoder using adamupdate.
        [parametersDecoder, trailingAvgDecoder, trailingAvgSqDecoder] = adamupdate(parametersDecoder, ...
            gradientsDecoder, trailingAvgDecoder, trailingAvgSqDecoder, iteration, learnRate, ...
            gradientDecayFactor, squaredGradientDecayFactor);

        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(loss)))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end

    % Shuffle data.
    idx = randperm(numObservations);
    sequencesSource = sequencesSource(idx);
    sequencesTarget = sequencesTarget(idx);
end
To generate translations for new data using the trained model, convert the text data to numeric sequences using the same steps as when training, input the sequences into the encoder-decoder model, and convert the resulting sequences back into text using the token indices.
Select a mini-batch of test observations.
numObservationsTest = 16;
idx = randperm(size(dataTest,1),numObservationsTest);
dataTest(idx,:)
ans=16×2 table
Source Target
______ ___________
"857" "DCCCLVII"
"991" "CMXCI"
"143" "CXLIII"
"924" "CMXXIV"
"752" "DCCLII"
"85" "LXXXV"
"131" "CXXXI"
"124" "CXXIV"
"858" "DCCCLVIII"
"103" "CIII"
"497" "CDXCVII"
"76" "LXXVI"
"815" "DCCCXV"
"829" "DCCCXXIX"
"940" "CMXL"
"94" "XCIV"
Preprocess the text data using the same steps as when training. Use the transformText function, listed at the end of the example, to split the text into characters and add the start and stop tokens.
strSource = dataTest{idx,1};
strTarget = dataTest{idx,2};

documentsSource = transformText(strSource,startToken,stopToken);
Convert the tokenized text into a batch of padded sequences by using the doc2sequence function. To automatically pad the sequences, set the 'PaddingDirection' option to 'right' and set the padding value to the input size (the token index of the padding token).
sequencesSource = doc2sequence(encSource,documentsSource, ...
    'PaddingDirection','right', ...
    'PaddingValue',inputSize);
Concatenate and permute the sequence data into the required shape for the encoder model function (1-by-N-by-S, where N is the number of observations and S is the sequence length).
XSource = cat(3,sequencesSource{:}); XSource = permute(XSource,[1 3 2]);
Convert input data to dlarray and calculate the encoder model outputs.
dlXSource = dlarray(XSource); [dlZ, hiddenState] = modelEncoder(dlXSource, parametersEncoder);
Generate the translations by passing the encoded input through the decoder one time step at a time and converting the resulting sequences of token indices back into text. To initialize the translations, create a vector containing only the indices corresponding to the start token.
decoderInput = repmat(word2ind(encTarget,startToken),[1 numObservationsTest]); decoderInput = dlarray(decoderInput);
Initialize the context vector and the cell arrays containing the translated sequences and the attention scores for each observation.
context = dlarray(zeros([size(dlZ, 1) numObservationsTest]));

sequencesTranslated = cell(1,numObservationsTest);
attentionScores = cell(1,numObservationsTest);
Loop over time steps and translate the sequences. Keep looping over the time steps until all sequences are translated. For each observation, when the translation is finished (when the decoder predicts the stop token or the translation reaches the maximum sequence length), set a flag to stop translating that sequence.
stopIdx = word2ind(encTarget,stopToken);

stopTranslating = false(1, numObservationsTest);
maxSequenceLength = 10;

while ~all(stopTranslating)

    % Forward through decoder.
    [dlY, context, hiddenState, attn] = modelDecoder(decoderInput, parametersDecoder, context, ...
        hiddenState, dlZ);

    % Loop over observations.
    for i = 1:numObservationsTest

        % Skip already-translated sequences.
        if stopTranslating(i)
            continue
        end

        % Update attention scores.
        attentionScores{i} = [attentionScores{i} extractdata(attn(:,i))];

        % Predict next time step.
        prob = softmax(dlY(:,i), 'DataFormat', 'CB');
        [~, idx] = max(prob(1:end-1,:), [], 1);

        % Set stopTranslating flag when translation done.
        if idx == stopIdx || numel(sequencesTranslated{i}) == maxSequenceLength
            stopTranslating(i) = true;
        else
            sequencesTranslated{i} = [sequencesTranslated{i} extractdata(idx)];
            decoderInput(i) = idx;
        end
    end
end
View the source text, target text, and translations in a table.
tbl = table;
tbl.Source = strSource;
tbl.Target = strTarget;
tbl.Translated = cellfun(@(sequence) join(ind2word(encTarget,sequence),""),sequencesTranslated)';
tbl
tbl=16×3 table
Source Target Translated
______ ___________ ___________
"857" "DCCCLVII" "DCCCLVII"
"991" "CMXCI" "CMXCI"
"143" "CXLIII" "CXLIII"
"924" "CMXXIV" "CMXXIV"
"752" "DCCLII" "DCCLII"
"85" "LXXXV" "DCCCLVI"
"131" "CXXXI" "CXXXI"
"124" "CXXIV" "CXXIV"
"858" "DCCCLVIII" "DCCCLVIII"
"103" "CIII" "CIII"
"497" "CDXCVII" "CDXCVII"
"76" "LXXVI" "DCCLVII"
"815" "DCCCXV" "DCCCXV"
"829" "DCCCXXIX" "DCCCXXIX"
"940" "CMXL" "CMXL"
"94" "XCIV" "CMXLVI"
Plot the attention scores of the first sequence in a heat map. The attention scores highlight which areas of the source and translated sequences the model attends to when processing the translation.
idx = 1;
figure
xlabs = [ind2word(encTarget,sequencesTranslated{idx}) stopToken];
ylabs = string(documentsSource(idx));

heatmap(attentionScores{idx}, ...
    'CellLabelColor','none', ...
    'XDisplayLabels',xlabs, ...
    'YDisplayLabels',ylabs);

xlabel("Translation")
ylabel("Source")
title("Attention Scores")
The preprocessSourceTargetPairs function takes a table data containing the source-target pairs in two columns and, for each column, returns sequences of token indices and a corresponding wordEncoding object that maps the indices to words and vice versa.
function [sequencesSource, sequencesTarget, encSource, encTarget] = preprocessSourceTargetPairs(data,startToken,stopToken)

% Extract text data.
strSource = data{:,1};
strTarget = data{:,2};

% Create tokenized document arrays.
documentsSource = transformText(strSource,startToken,stopToken);
documentsTarget = transformText(strTarget,startToken,stopToken);

% Create word encodings.
encSource = wordEncoding(documentsSource);
encTarget = wordEncoding(documentsTarget);

% Convert documents to numeric sequences.
sequencesSource = doc2sequence(encSource, documentsSource,'PaddingDirection','none');
sequencesTarget = doc2sequence(encTarget, documentsTarget,'PaddingDirection','none');

end
The transformText function preprocesses and tokenizes the input text for translation by splitting the text into characters and adding start and stop tokens. To translate text by splitting the text into words instead of characters, skip the first step.
function documents = transformText(str,startToken,stopToken)

% Split text into characters.
str = strip(replace(str,""," "));

% Add start and stop tokens.
str = startToken + str + stopToken;

% Create tokenized document array.
documents = tokenizedDocument(str,'CustomTokens',[startToken stopToken]);

end
The createBatch function takes a mini-batch of source and target sequences and returns padded sequences with the corresponding padding masks.
function [XSource, XTarget, maskSource, maskTarget] = createBatch(sequencesSource, sequencesTarget, ...
    paddingValueSource, paddingValueTarget)

numObservations = size(sequencesSource,1);
sequenceLengthSource = max(cellfun(@(x) size(x,2), sequencesSource));
sequenceLengthTarget = max(cellfun(@(x) size(x,2), sequencesTarget));

% Initialize masks.
maskSource = false(numObservations, sequenceLengthSource);
maskTarget = false(numObservations, sequenceLengthTarget);

% Initialize mini-batch.
XSource = zeros(1,numObservations,sequenceLengthSource);
XTarget = zeros(1,numObservations,sequenceLengthTarget);

% Pad sequences and create masks.
for i = 1:numObservations

    % Source
    L = size(sequencesSource{i},2);
    paddingSize = sequenceLengthSource - L;
    padding = repmat(paddingValueSource, [1 paddingSize]);

    XSource(1,i,:) = [sequencesSource{i} padding];
    maskSource(i,1:L) = true;

    % Target
    L = size(sequencesTarget{i},2);
    paddingSize = sequenceLengthTarget - L;
    padding = repmat(paddingValueTarget, [1 paddingSize]);

    XTarget(1,i,:) = [sequencesTarget{i} padding];
    maskTarget(i,1:L) = true;
end

end
The function modelEncoder takes the input data, the model parameters, and the optional mask that is used to determine the correct outputs for training, and returns the model output and the LSTM hidden state.
function [dlZ, hiddenState] = modelEncoder(dlX, parametersEncoder, maskSource)

% Embedding
weights = parametersEncoder.emb.Weights;
dlZ = embedding(dlX,weights);

% LSTM 1
inputWeights = parametersEncoder.lstm1.InputWeights;
recurrentWeights = parametersEncoder.lstm1.RecurrentWeights;
bias = parametersEncoder.lstm1.Bias;

numHiddenUnits = size(recurrentWeights, 2);
initialHiddenState = dlarray(zeros([numHiddenUnits 1]));
initialCellState = dlarray(zeros([numHiddenUnits 1]));

dlZ = lstm(dlZ, initialHiddenState, initialCellState, inputWeights, ...
    recurrentWeights, bias, 'DataFormat', 'CBT');

% LSTM 2
inputWeights = parametersEncoder.lstm2.InputWeights;
recurrentWeights = parametersEncoder.lstm2.RecurrentWeights;
bias = parametersEncoder.lstm2.Bias;

[dlZ, hiddenState] = lstm(dlZ,initialHiddenState, initialCellState, ...
    inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Mask output for training
if nargin > 2
    dlZ = dlZ.*permute(maskSource, [3 1 2]);
    sequenceLengths = sum(maskSource, 2);

    % Mask final hidden state
    for ii = 1:size(dlZ, 2)
        hiddenState(:, ii) = dlZ(:, ii, sequenceLengths(ii));
    end
end

end
The function modelDecoder takes the input data, the model parameters, the context vector, the LSTM initial hidden state, the outputs of the encoder, and the dropout probability, and outputs the decoder output, the updated context vector, the updated LSTM state, and the attention scores.
function [dlY, context, hiddenState, attentionScores] = modelDecoder(dlX, parameters, context, ...
    hiddenState, encoderOutputs, dropout)

% Embedding
weights = parameters.emb.Weights;
dlX = embedding(dlX, weights);

% RNN input
dlY = cat(1, dlX, context);

% LSTM 1
initialCellState = dlarray(zeros(size(hiddenState)));

inputWeights = parameters.lstm1.InputWeights;
recurrentWeights = parameters.lstm1.RecurrentWeights;
bias = parameters.lstm1.Bias;
dlY = lstm(dlY, hiddenState, initialCellState, inputWeights, ...
    recurrentWeights, bias, 'DataFormat', 'CBT');

if nargin > 5
    % Dropout
    mask = ( rand(size(dlY), 'like', dlY) > dropout );
    dlY = dlY.*mask;
end

% LSTM 2
inputWeights = parameters.lstm2.InputWeights;
recurrentWeights = parameters.lstm2.RecurrentWeights;
bias = parameters.lstm2.Bias;
[~, hiddenState] = lstm(dlY, hiddenState, initialCellState, ...
    inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Attention
weights = parameters.attn.Weights;
attentionScores = attention(hiddenState, encoderOutputs, weights);

% Context
encoderOutputs = permute(encoderOutputs, [1 3 2]);
attentionScores = permute(attentionScores,[1 3 2]);
context = dlmtimes(encoderOutputs,attentionScores);
context = squeeze(context);

% Fully connect
weights = parameters.fc.Weights;
bias = parameters.fc.Bias;
dlY = weights*cat(1, hiddenState, context) + bias;

end
The embedding function maps numeric indices to the corresponding vector given by the input weights.
function Z = embedding(X, weights)

% Reshape inputs into a vector.
[N, T] = size(X, 2:3);
X = reshape(X, N*T, 1);

% Index into embedding matrix.
Z = weights(:, X);

% Reshape outputs by separating out batch and sequence dimensions.
Z = reshape(Z, [], N, T);

end
The attention function computes the attention scores according to Luong "general" scoring.
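In this scheme, the attention energy for an encoder time step is the dot product of the decoder hidden state with a learned linear transform of the encoder output, and the scores are the softmax of the energies over the encoder time steps. Using the standard notation for Luong general attention (not the variable names of this example), where $h_t$ is the decoder hidden state, $\bar{h}_s$ is the encoder output at time step $s$, and $W_a$ corresponds to parameters.attn.Weights:

$$\mathrm{score}(h_t,\bar{h}_s) = h_t^\top W_a \bar{h}_s, \qquad a_t(s) = \frac{\exp\left(\mathrm{score}(h_t,\bar{h}_s)\right)}{\sum_{s'} \exp\left(\mathrm{score}(h_t,\bar{h}_{s'})\right)}$$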
function attentionScores = attention(hiddenState, encoderOutputs, weights)

[N, S] = size(encoderOutputs, 2:3);
attentionEnergies = dlarray(zeros([S N]));

% The energy at each time step is the dot product of the hidden state
% and the learnable attention weights times the encoder output.
hWX = hiddenState .* dlmtimes(weights,encoderOutputs);
for tt = 1:S
    attentionEnergies(tt, :) = sum(hWX(:, :, tt), 1);
end

% Compute softmax scores.
attentionScores = softmax(attentionEnergies, 'DataFormat', 'CB');

end
The modelGradients function takes the encoder and decoder model parameters, a mini-batch of input data and the padding masks corresponding to the input data, and the dropout probability, and returns the gradients of the loss with respect to the learnable parameters in the models and the corresponding loss. For each mini-batch, the function randomly chooses between teacher forcing (feeding the decoder the target sequence at each time step) and feeding the decoder its own previous predictions.
function [gradientsEncoder, gradientsDecoder, maskedLoss] = modelGradients(parametersEncoder, ...
    parametersDecoder, dlXSource, dlXTarget, maskSource, maskTarget, dropout)

% Forward through encoder.
[dlZ, hiddenState] = modelEncoder(dlXSource, parametersEncoder, maskSource);

% Get parameter sizes.
[miniBatchSize, sequenceLength] = size(dlXTarget,2:3);
sequenceLength = sequenceLength - 1;
numHiddenUnits = size(dlZ,1);

% Initialize context vector.
context = dlarray(zeros([numHiddenUnits miniBatchSize]));

% Initialize loss.
loss = dlarray(zeros([miniBatchSize sequenceLength]));

% Get first time step for decoder.
decoderInput = dlXTarget(:,:,1);

% Choose whether to use teacher forcing.
doTeacherForcing = rand < 0.5;

if doTeacherForcing
    for t = 1:sequenceLength
        % Forward through decoder.
        [dlY, context, hiddenState] = modelDecoder(decoderInput, parametersDecoder, context, ...
            hiddenState, dlZ, dropout);

        % Update loss.
        dlT = dlarray(oneHot(dlXTarget(:,:,t+1), size(dlY,1)));
        loss(:,t) = crossEntropyAndSoftmax(dlY, dlT);

        % Get next time step.
        decoderInput = dlXTarget(:,:,t+1);
    end
else
    for t = 1:sequenceLength
        % Forward through decoder.
        [dlY, context, hiddenState] = modelDecoder(decoderInput, parametersDecoder, context, ...
            hiddenState, dlZ, dropout);

        % Update loss.
        dlT = dlarray(oneHot(dlXTarget(:,:,t+1), size(dlY,1)));
        loss(:,t) = crossEntropyAndSoftmax(dlY, dlT);

        % Greedily update next input time step.
        prob = softmax(dlY,'DataFormat','CB');
        [~, decoderInput] = max(prob,[],1);
    end
end

% Determine masked loss.
maskedLoss = sum(sum(loss.*maskTarget(:,2:end))) / miniBatchSize;

% Update gradients.
[gradientsEncoder, gradientsDecoder] = dlgradient(maskedLoss, parametersEncoder, parametersDecoder);

% For plotting, return loss normalized by sequence length.
maskedLoss = extractdata(maskedLoss) ./ sequenceLength;

end
The crossEntropyAndSoftmax function computes the cross-entropy loss between the targets and the softmax of the predictions, using a numerically stable log-softmax.
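For targets dlT encoded as one-hot vectors and raw decoder outputs dlY, the loss for a single observation is (the code below subtracts the maximum of dlY before exponentiating for numerical stability):

$$L(Y,T) = -\sum_i T_i \log\big(\mathrm{softmax}(Y)\big)_i$$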
function loss = crossEntropyAndSoftmax(dlY, dlT)

% Subtract the maximum for numerical stability, compute the log-softmax,
% and return the cross-entropy loss.
offset = max(dlY);
logSoftmax = dlY - offset - log(sum(exp(dlY-offset)));
loss = -sum(dlT.*logSoftmax);

end
The uniformNoise function samples weights from a uniform distribution on the interval [-sqrt(k), sqrt(k)].
function weights = uniformNoise(sz, k)
weights = -sqrt(k) + 2*sqrt(k).*rand(sz);
end
The clipGradient function clips the model gradients by rescaling any gradient whose norm exceeds the gradient threshold.
function g = clipGradient(g, gradientThreshold)

wnorm = norm(extractdata(g));
if wnorm > gradientThreshold
    g = (gradientThreshold/wnorm).*g;
end

end
The oneHot function encodes word indices as one-hot vectors.
function oh = oneHot(idx, numTokens)
tokens = (1:numTokens)';
oh = (tokens == idx);
end
adamupdate | crossentropy | dlarray | dlfeval | dlgradient | dlupdate | doc2sequence | lstm | softmax | tokenizedDocument | word2ind | wordEncoding