Sequence-to-Sequence Translation Using Attention

This example shows how to convert decimal strings to Roman numerals using a recurrent sequence-to-sequence encoder-decoder model with attention.

Recurrent encoder-decoder models have proven successful at tasks like abstractive text summarization and neural machine translation. The models consist of an encoder, which typically processes input data with a recurrent layer such as LSTM, and a decoder, which maps the encoded input into the desired output, typically with a second recurrent layer. Models that incorporate attention mechanisms allow the decoder to focus on parts of the encoded input while generating the translation.

For the encoder model, this example uses a simple network consisting of an embedding followed by two LSTM operations. Embedding is a method of converting categorical tokens into numeric vectors.

For the decoder model, this example uses a network very similar to the encoder that contains two LSTMs. However, an important difference is that the decoder contains an attention mechanism. The attention mechanism allows the decoder to attend to specific parts of the encoder output.

Load Training Data

Load the decimal-Roman numeral pairs from "romanNumerals.csv".

filename = fullfile("romanNumerals.csv");

options = detectImportOptions(filename, ...
    'TextType','string', ...
    'ReadVariableNames',false);
options.VariableNames = ["Source" "Target"];
options.VariableTypes = ["string" "string"];

data = readtable(filename,options);

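Split the data into training and test partitions. Select 500 observations at random for training and use the remaining observations for testing, which corresponds to 50% of the data each when the file contains 1000 pairs.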

idx = randperm(size(data,1),500);
dataTrain = data(idx,:);
dataTest = data;
dataTest(idx,:) = [];

View some of the decimal-Roman numeral pairs.

head(dataTrain)
ans=8×2 table
    Source     Target  
    ______    _________

    "168"     "CLXVIII"
    "154"     "CLIV"   
    "765"     "DCCLXV" 
    "714"     "DCCXIV" 
    "649"     "DCXLIX" 
    "346"     "CCCXLVI"
    "77"      "LXXVII" 
    "83"      "LXXXIII"

Preprocess Data

Preprocess the training data using the preprocessSourceTargetPairs function, listed at the end of the example. The preprocessSourceTargetPairs function converts the input text data to numeric sequences. The elements of the sequences are positive integers that index into a corresponding wordEncoding object. The wordEncoding maps tokens to a numeric index and vice versa using a vocabulary. To mark the beginning and end of each sequence, the encoding also includes the special tokens "<start>" and "<stop>".

startToken = "<start>";
stopToken = "<stop>";
[sequencesSource, sequencesTarget, encSource, encTarget] = preprocessSourceTargetPairs(dataTrain,startToken,stopToken);

Representing Text as Numeric Sequences

For example, the decimal string "441" is encoded as follows:

strSource = "441";

Insert spaces between the characters.

strSource = strip(replace(strSource,""," "));

Add the special start and stop tokens.

strSource = startToken + strSource + stopToken
strSource = 
"<start>4 4 1<stop>"

Tokenize the text using the tokenizedDocument function and set the 'CustomTokens' option to the special tokens.

documentSource = tokenizedDocument(strSource,'CustomTokens',[startToken stopToken])
documentSource = 
  tokenizedDocument:

   5 tokens: <start> 4 4 1 <stop>

Convert the document to a sequence of token indices using the word2ind function with the corresponding wordEncoding object.

tokens = string(documentSource);
sequenceSource = word2ind(encSource,tokens)
sequenceSource = 1×5

     1     7     7     2     5
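The index lookup can be pictured with a plain string array in place of the wordEncoding object. The following is a minimal sketch, not the toolbox implementation: the vocabulary order is an assumption, chosen to follow the order in which tokens first appear in the training rows shown earlier ("168", then "154"), and the actual order depends on the training data.

```matlab
% Hypothetical vocabulary, in assumed order of first appearance.
vocabulary = ["<start>" "1" "6" "8" "<stop>" "5" "4"];

tokens = ["<start>" "4" "4" "1" "<stop>"];

% For each token, ismember returns its position in the vocabulary
% (0 if the token is not in the vocabulary).
[~, sequence] = ismember(tokens, vocabulary);   % [1 7 7 2 5]

% Inverse mapping: index back into the vocabulary.
recovered = vocabulary(sequence);
```

Under this assumed vocabulary, the indices agree with the word2ind output above, and indexing back into the vocabulary recovers the original tokens.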

Padding and Masking

Sequence data such as text naturally have different sequence lengths. To train a model using variable length sequences, pad the mini-batches of input data to have the same length. To ensure that the padding values do not impact the loss calculations, create a mask which records which sequence elements are real, and which are just padding.

For example, consider a mini-batch containing the decimal strings "437", "431", and "102" with the corresponding Roman numeral strings "CDXXXVII", "CDXXXI", and "CII". For character-by-character sequences, the input sequences have the same length and do not need to be padded. The corresponding mask is an array of ones.

The output sequences have different lengths, so they require padding. The corresponding padding mask contains zeros where the corresponding time steps are padding values.
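The padding and masking of the output sequences can be sketched with plain arrays. The token indices and the padding value below are hypothetical stand-ins for "CDXXXVII", "CDXXXI", and "CII" (start and stop tokens omitted for brevity); only the sequence lengths matter for the illustration.

```matlab
% Hypothetical token-index sequences of lengths 8, 6, and 3.
sequences = {[3 4 5 5 5 6 7 7], [3 4 5 5 5 7], [3 7 7]};
paddingValue = 8;                      % assumed index of the padding token

lengths = cellfun(@numel, sequences);  % [8 6 3]
maxLength = max(lengths);

% Fill a matrix with the padding value, then copy each sequence in.
X = repmat(paddingValue, numel(sequences), maxLength);
for i = 1:numel(sequences)
    X(i,1:lengths(i)) = sequences{i};
end

% Mask is true where X holds a real token and false where it holds padding.
mask = (1:maxLength) <= lengths';
```

The createBatch function, listed at the end of the example, follows the same idea but arranges the padded data as 1-by-N-by-S arrays.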

Initialize Model Parameters

Initialize the model parameters. For both the encoder and decoder, specify an embedding dimension of 256 and two LSTM layers with 200 hidden units. Specify dropout with probability 0.05.

embeddingDimension = 256;
numHiddenUnits = 200;
dropout = 0.05;

Initialize the encoder model parameters:

  • Specify an embedding dimension of 256 and the vocabulary size of the source vocabulary plus 1, where the extra value corresponds to the padding token.

  • Specify two LSTM operations with 200 hidden units.

  • Initialize the embedding weights by sampling from a random normal distribution.

  • Initialize the LSTM weights and biases by sampling from a uniform distribution using the uniformNoise function, listed at the end of the example.

inputSize = encSource.NumWords + 1;
parametersEncoder.emb.Weights = dlarray(randn([embeddingDimension inputSize]));

parametersEncoder.lstm1.InputWeights = dlarray(uniformNoise([4*numHiddenUnits embeddingDimension],1/numHiddenUnits));
parametersEncoder.lstm1.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersEncoder.lstm1.Bias = dlarray(uniformNoise([4*numHiddenUnits 1],1/numHiddenUnits));

parametersEncoder.lstm2.InputWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersEncoder.lstm2.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersEncoder.lstm2.Bias = dlarray(uniformNoise([4*numHiddenUnits 1],1/numHiddenUnits));

Initialize the decoder model parameters.

  • Specify an embedding dimension of 256 and the vocabulary size of the target vocabulary plus 1, where the extra value corresponds to the padding token.

  • Initialize the attention mechanism weights using the uniformNoise function.

  • Initialize the embedding weights by sampling from a random normal distribution.

  • Initialize the LSTM weights and biases by sampling from a uniform distribution using the uniformNoise function.

outputSize = encTarget.NumWords + 1;
parametersDecoder.emb.Weights = dlarray(randn([embeddingDimension outputSize]));

parametersDecoder.attn.Weights = dlarray(uniformNoise([numHiddenUnits numHiddenUnits],1/numHiddenUnits));

parametersDecoder.lstm1.InputWeights = dlarray(uniformNoise([4*numHiddenUnits embeddingDimension+numHiddenUnits],1/numHiddenUnits));
parametersDecoder.lstm1.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersDecoder.lstm1.Bias = dlarray(uniformNoise([4*numHiddenUnits 1],1/numHiddenUnits));

parametersDecoder.lstm2.InputWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersDecoder.lstm2.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersDecoder.lstm2.Bias = dlarray(uniformNoise([4*numHiddenUnits 1], 1/numHiddenUnits));

parametersDecoder.fc.Weights = dlarray(uniformNoise([outputSize 2*numHiddenUnits],1/(2*numHiddenUnits)));
parametersDecoder.fc.Bias = dlarray(uniformNoise([outputSize 1], 1/(2*numHiddenUnits)));

Define Model Functions

Create the functions modelEncoder and modelDecoder, listed at the end of the example, that compute the outputs of the encoder and decoder models, respectively.

The modelEncoder function, listed in the Encoder Model Function section of the example, takes the input data, the model parameters, and an optional mask that is used to determine the correct outputs for training, and returns the model outputs and the LSTM hidden state.

The modelDecoder function, listed in the Decoder Model Function section of the example, takes the input data, the model parameters, the context vector, the LSTM initial hidden state, the outputs of the encoder, and the dropout probability and outputs the decoder output, the updated context vector, the updated LSTM state, and the attention scores.

Define Model Gradients Function

Create the function modelGradients, listed in the Model Gradients Function section of the example, that takes the encoder and decoder model parameters, a mini-batch of input data and the padding masks corresponding to the input data, and the dropout probability and returns the gradients of the loss with respect to the learnable parameters in the models and the corresponding loss.

Specify Training Options

Train with a mini-batch size of 32 for 40 epochs. Specify a learning rate of 0.002 and clip the gradients with a threshold of 5.

miniBatchSize = 32;
numEpochs = 40;
learnRate = 0.002;
gradientThreshold = 5;

Initialize the options for Adam.

gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;
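To see how these two decay factors enter the update, the following sketch applies one step of the textbook Adam rule to plain arrays. It is an illustration only: the parameter and gradient values are hypothetical, the epsilon value is an assumption, and the adamupdate function may differ in such details.

```matlab
learnRate = 0.002;
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;
epsilon = 1e-8;                 % assumed small constant, not from the example

w = [0.5; -0.3];                % hypothetical parameters
g = [0.1; -0.2];                % hypothetical gradients
avgG = zeros(size(w));          % trailing average of gradients
avgGSq = zeros(size(w));        % trailing average of squared gradients
iteration = 1;

% Update the trailing averages.
avgG = gradientDecayFactor*avgG + (1-gradientDecayFactor)*g;
avgGSq = squaredGradientDecayFactor*avgGSq + (1-squaredGradientDecayFactor)*g.^2;

% Correct the bias introduced by initializing the averages to zero.
avgGHat = avgG ./ (1 - gradientDecayFactor^iteration);
avgGSqHat = avgGSq ./ (1 - squaredGradientDecayFactor^iteration);

% Take the step.
wNew = w - learnRate .* avgGHat ./ (sqrt(avgGSqHat) + epsilon);
```

On the first iteration the bias-corrected averages reduce to the gradient and its square, so each parameter moves by approximately learnRate in the direction opposite to the sign of its gradient.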

Specify to plot the training progress. To disable the training progress plot, set the plots value to "none".

plots = "training-progress";

Train Model

Train the model using a custom training loop.

For the first epoch, train with the sequences sorted by increasing sequence length. This results in batches with sequences of approximately the same sequence length, and ensures smaller sequence batches are used to update the model before longer sequence batches. For subsequent epochs, shuffle the data.

For each mini-batch:

  • Convert the data to dlarray.

  • Compute loss and gradients.

  • Clip the gradients.

  • Update the encoder and decoder model parameters using the adamupdate function.

  • Update the training progress plot.

Sort the sequences for the first epoch.

sequenceLengthsEncoder = cellfun(@(sequence) size(sequence,2), sequencesSource);
[~,idx] = sort(sequenceLengthsEncoder);
sequencesSource = sequencesSource(idx);
sequencesTarget = sequencesTarget(idx);

Initialize the training progress plot.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Initialize the values for the adamupdate function.

trailingAvgEncoder = [];
trailingAvgSqEncoder = [];

trailingAvgDecoder = [];
trailingAvgSqDecoder = [];

Train the model.

numObservations = numel(sequencesSource);
numIterationsPerEpoch = floor(numObservations/miniBatchSize);

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Read mini-batch of data
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        [XSource, XTarget, maskSource, maskTarget] = createBatch(sequencesSource(idx), ...
            sequencesTarget(idx), inputSize, outputSize);
        
        % Convert mini-batch of data to dlarray.
        dlXSource = dlarray(XSource);
        dlXTarget = dlarray(XTarget);
        
        % Compute loss and gradients.
        [gradientsEncoder, gradientsDecoder, loss] = dlfeval(@modelGradients, parametersEncoder, ...
            parametersDecoder, dlXSource, dlXTarget, maskSource, maskTarget, dropout);
        
        % Gradient clipping.
        gradientsEncoder = dlupdate(@(w) clipGradient(w,gradientThreshold), gradientsEncoder);
        gradientsDecoder = dlupdate(@(w) clipGradient(w,gradientThreshold), gradientsDecoder);
        
        % Update encoder using adamupdate.
        [parametersEncoder, trailingAvgEncoder, trailingAvgSqEncoder] = adamupdate(parametersEncoder, ...
            gradientsEncoder, trailingAvgEncoder, trailingAvgSqEncoder, iteration, learnRate, ...
            gradientDecayFactor, squaredGradientDecayFactor);
        
        % Update decoder using adamupdate.
        [parametersDecoder, trailingAvgDecoder, trailingAvgSqDecoder] = adamupdate(parametersDecoder, ...
            gradientsDecoder, trailingAvgDecoder, trailingAvgSqDecoder, iteration, learnRate, ...
            gradientDecayFactor, squaredGradientDecayFactor);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(loss)))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
    
    % Shuffle data.
    idx = randperm(numObservations);
    sequencesSource = sequencesSource(idx);
    sequencesTarget = sequencesTarget(idx);
end
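The iteration counts implied by these settings can be checked with a few lines of arithmetic. This sketch assumes the 500 training pairs selected earlier; because floor discards the incomplete final mini-batch, some observations are skipped in every epoch.

```matlab
numObservations = 500;      % training pairs selected by randperm earlier
miniBatchSize = 32;
numEpochs = 40;

numIterationsPerEpoch = floor(numObservations/miniBatchSize);               % 15
numSkippedPerEpoch = numObservations - numIterationsPerEpoch*miniBatchSize; % 20
totalIterations = numEpochs*numIterationsPerEpoch;                          % 600
```

Because the data is shuffled at the end of each epoch, the skipped observations differ from epoch to epoch.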

Generate Translations

To generate translations for new data using the trained model, convert the text data to numeric sequences using the same steps as when training. Then, input the sequences into the encoder-decoder model and convert the resulting sequences back into text using the token indices.

Prepare Data for Translation

Select a mini-batch of test observations.

numObservationsTest = 16;
idx = randperm(size(dataTest,1),numObservationsTest);
dataTest(idx,:)
ans=16×2 table
    Source      Target   
    ______    ___________

    "857"     "DCCCLVII" 
    "991"     "CMXCI"    
    "143"     "CXLIII"   
    "924"     "CMXXIV"   
    "752"     "DCCLII"   
    "85"      "LXXXV"    
    "131"     "CXXXI"    
    "124"     "CXXIV"    
    "858"     "DCCCLVIII"
    "103"     "CIII"     
    "497"     "CDXCVII"  
    "76"      "LXXVI"    
    "815"     "DCCCXV"   
    "829"     "DCCCXXIX" 
    "940"     "CMXL"     
    "94"      "XCIV"     

Preprocess the text data using the same steps as when training. Use the transformText function, listed at the end of the example, to split the text into characters and add the start and stop tokens.

strSource = dataTest{idx,1};
strTarget = dataTest{idx,2};

documentsSource = transformText(strSource,startToken,stopToken);

Convert the tokenized text into a batch of padded sequences by using the doc2sequence function. To automatically pad the sequences, set the 'PaddingDirection' option to 'right' and set the padding value to the input size (the token index of the padding token).

sequencesSource = doc2sequence(encSource,documentsSource, ...
    'PaddingDirection','right', ...
    'PaddingValue',inputSize);

Concatenate and permute the sequence data into the required shape for the encoder model function (1-by-N-by-S, where N is the number of observations and S is the sequence length).

XSource = cat(3,sequencesSource{:});
XSource = permute(XSource,[1 3 2]);

Convert input data to dlarray and calculate the encoder model outputs.

dlXSource = dlarray(XSource);
[dlZ, hiddenState] = modelEncoder(dlXSource, parametersEncoder);

Translate Source Text

To generate translations for new data, input the sequences into the encoder-decoder model and convert the resulting sequences back into text using the token indices.

To initialize the translations, create a vector containing only the indices corresponding to the start token.

decoderInput = repmat(word2ind(encTarget,startToken),[1 numObservationsTest]);
decoderInput = dlarray(decoderInput);

Initialize the context vector and the cell arrays containing the translated sequences and the attention scores for each observation.

context = dlarray(zeros([size(dlZ, 1) numObservationsTest]));
sequencesTranslated = cell(1,numObservationsTest);
attentionScores = cell(1,numObservationsTest);

Loop over time steps and translate the sequences. Keep looping over the time steps until all sequences are translated. For each observation, when the translation is finished (when the decoder predicts the stop token or the translation reaches the maximum sequence length), set a flag to stop translating that sequence.

stopIdx = word2ind(encTarget,stopToken);

stopTranslating = false(1, numObservationsTest);
maxSequenceLength = 10;

while ~all(stopTranslating)
    
    % Forward through decoder.
    [dlY, context, hiddenState, attn] = modelDecoder(decoderInput, parametersDecoder, context, ...
        hiddenState, dlZ);
    
    % Loop over observations.
    for i = 1:numObservationsTest
        % Skip already-translated sequences.
        if stopTranslating(i)
            continue
        end

        
        % Update attention scores.
        attentionScores{i} = [attentionScores{i} extractdata(attn(:,i))];
        
        % Predict next time step.
        prob = softmax(dlY(:,i), 'DataFormat', 'CB');
        [~, idx] = max(prob(1:end-1,:), [], 1);
        
        % Set stopTranslating flag when translation done.
        if idx == stopIdx || numel(sequencesTranslated{i}) == maxSequenceLength
            stopTranslating(i) = true;
        else
            sequencesTranslated{i} = [sequencesTranslated{i} extractdata(idx)];
            decoderInput(i) = idx;
        end
    end
end
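The prediction step in the loop takes the maximum over prob(1:end-1,:) so that the padding token, stored at the last index, is never selected. The following sketch uses hypothetical scores to show the effect of excluding the last row.

```matlab
% Hypothetical scores for 4 real tokens plus the padding token (last row),
% for 2 observations.
scores = [0.1 2.0
          1.5 0.3
          0.2 0.1
          0.4 0.2
          9.0 9.0];   % padding row has the largest score in this illustration

% Taking the maximum over all rows would select the padding token.
[~, idxAll] = max(scores, [], 1);              % [5 5]

% Excluding the last row restricts the choice to real tokens.
[~, idxReal] = max(scores(1:end-1,:), [], 1);  % [2 1]
```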

View the source text, target text, and translations in a table.

tbl = table;
tbl.Source = strSource;
tbl.Target = strTarget;
tbl.Translated = cellfun(@(sequence) join(ind2word(encTarget,sequence),""),sequencesTranslated)';
tbl
tbl=16×3 table
    Source      Target       Translated 
    ______    ___________    ___________

    "857"     "DCCCLVII"     "DCCCLVII" 
    "991"     "CMXCI"        "CMXCI"    
    "143"     "CXLIII"       "CXLIII"   
    "924"     "CMXXIV"       "CMXXIV"   
    "752"     "DCCLII"       "DCCLII"   
    "85"      "LXXXV"        "DCCCLVI"  
    "131"     "CXXXI"        "CXXXI"    
    "124"     "CXXIV"        "CXXIV"    
    "858"     "DCCCLVIII"    "DCCCLVIII"
    "103"     "CIII"         "CIII"     
    "497"     "CDXCVII"      "CDXCVII"  
    "76"      "LXXVI"        "DCCLVII"  
    "815"     "DCCCXV"       "DCCCXV"   
    "829"     "DCCCXXIX"     "DCCCXXIX" 
    "940"     "CMXL"         "CMXL"     
    "94"      "XCIV"         "CMXLVI"   

Plot Attention Scores

Plot the attention scores of the first sequence in a heat map. The attention scores highlight which areas of the source and translated sequences the model attends to when processing the translation.

idx = 1;
figure
xlabs = [ind2word(encTarget,sequencesTranslated{idx}) stopToken];
ylabs = string(documentsSource(idx));

heatmap(attentionScores{idx}, ...
    'CellLabelColor','none', ...
    'XDisplayLabels',xlabs, ...
    'YDisplayLabels',ylabs);

xlabel("Translation")
ylabel("Source")
title("Attention Scores")

Preprocessing Function

The preprocessSourceTargetPairs function takes a table data containing the source-target pairs in two columns and, for each column, returns sequences of token indices and a corresponding wordEncoding object that maps the indices to words and vice versa.

function [sequencesSource, sequencesTarget, encSource, encTarget] = preprocessSourceTargetPairs(data,startToken,stopToken)

% Extract text data.
strSource = data{:,1};
strTarget = data{:,2};

% Create tokenized document arrays.
documentsSource = transformText(strSource,startToken,stopToken);
documentsTarget = transformText(strTarget,startToken,stopToken);

% Create word encodings.
encSource = wordEncoding(documentsSource);
encTarget = wordEncoding(documentsTarget);

% Convert documents to numeric sequences.
sequencesSource = doc2sequence(encSource, documentsSource,'PaddingDirection','none');
sequencesTarget = doc2sequence(encTarget, documentsTarget,'PaddingDirection','none');

end

Text Transformation Function

The transformText function preprocesses and tokenizes the input text for translation by splitting the text into characters and adding start and stop tokens. To translate text by splitting the text into words instead of characters, skip the first step.

function documents = transformText(str,startToken,stopToken)

% Split text into characters.
str = strip(replace(str,""," "));

% Add start and stop tokens.
str = startToken + str + stopToken;

% Create tokenized document array.
documents = tokenizedDocument(str,'CustomTokens',[startToken stopToken]);

end

Batch Creation Function

The createBatch function takes a mini-batch of source and target sequences and returns padded sequences with the corresponding padding masks.

function [XSource, XTarget, maskSource, maskTarget] = createBatch(sequencesSource, sequencesTarget, ...
    paddingValueSource, paddingValueTarget)

numObservations = size(sequencesSource,1);
sequenceLengthSource = max(cellfun(@(x) size(x,2), sequencesSource));
sequenceLengthTarget = max(cellfun(@(x) size(x,2), sequencesTarget));

% Initialize masks.
maskSource = false(numObservations, sequenceLengthSource);
maskTarget = false(numObservations, sequenceLengthTarget);

% Initialize mini-batch.
XSource = zeros(1,numObservations,sequenceLengthSource);
XTarget = zeros(1,numObservations,sequenceLengthTarget);

% Pad sequences and create masks.
for i = 1:numObservations
    
    % Source
    L = size(sequencesSource{i},2);
    paddingSize = sequenceLengthSource - L;
    padding = repmat(paddingValueSource, [1 paddingSize]);
    
    XSource(1,i,:) = [sequencesSource{i} padding];
    maskSource(i,1:L) = true;
    
    % Target
    L = size(sequencesTarget{i},2);
    paddingSize = sequenceLengthTarget - L;
    padding = repmat(paddingValueTarget, [1 paddingSize]);
    
    XTarget(1,i,:) = [sequencesTarget{i} padding];
    maskTarget(i,1:L) = true;
end

end

Encoder Model Function

The function modelEncoder takes the input data, the model parameters, and an optional mask that is used to determine the correct outputs for training, and returns the model output and the LSTM hidden state.

function [dlZ, hiddenState] = modelEncoder(dlX, parametersEncoder, maskSource)

% Embedding
weights = parametersEncoder.emb.Weights;
dlZ = embedding(dlX,weights);

% LSTM
inputWeights = parametersEncoder.lstm1.InputWeights;
recurrentWeights = parametersEncoder.lstm1.RecurrentWeights;
bias = parametersEncoder.lstm1.Bias;
numHiddenUnits = size(recurrentWeights, 2);
initialHiddenState = dlarray(zeros([numHiddenUnits 1]));
initialCellState = dlarray(zeros([numHiddenUnits 1]));

dlZ = lstm(dlZ, initialHiddenState, initialCellState, inputWeights, ...
    recurrentWeights, bias, 'DataFormat', 'CBT');

% LSTM
inputWeights = parametersEncoder.lstm2.InputWeights;
recurrentWeights = parametersEncoder.lstm2.RecurrentWeights;
bias = parametersEncoder.lstm2.Bias;

[dlZ, hiddenState] = lstm(dlZ,initialHiddenState, initialCellState, ...
    inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Mask output for training
if nargin > 2
    dlZ = dlZ.*permute(maskSource, [3 1 2]);
    sequenceLengths = sum(maskSource, 2);
    
    % Mask final hidden state
    for ii = 1:size(dlZ, 2)
        hiddenState(:, ii) = dlZ(:, ii, sequenceLengths(ii));
    end
end

end

Decoder Model Function

The function modelDecoder takes the input data, the model parameters, the context vector, the LSTM initial hidden state, the outputs of the encoder, and the dropout probability and outputs the decoder output, the updated context vector, the updated LSTM state, and the attention scores.

function [dlY, context, hiddenState, attentionScores] = modelDecoder(dlX, parameters, context, ...
    hiddenState, encoderOutputs, dropout)

% Embedding
weights = parameters.emb.Weights;
dlX = embedding(dlX, weights);

% RNN input
dlY = cat(1, dlX, context);

% LSTM 1
initialCellState = dlarray(zeros(size(hiddenState)));

inputWeights = parameters.lstm1.InputWeights;
recurrentWeights = parameters.lstm1.RecurrentWeights;
bias = parameters.lstm1.Bias;

dlY = lstm(dlY, hiddenState, initialCellState, inputWeights, ...
    recurrentWeights, bias, 'DataFormat', 'CBT');

if nargin > 5
    % Dropout
    mask = ( rand(size(dlY), 'like', dlY) > dropout );
    dlY = dlY.*mask;
end

% LSTM 2
inputWeights = parameters.lstm2.InputWeights;
recurrentWeights = parameters.lstm2.RecurrentWeights;
bias = parameters.lstm2.Bias;
[~, hiddenState] = lstm(dlY, hiddenState, initialCellState, ...
    inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Attention
weights = parameters.attn.Weights;
attentionScores = attention(hiddenState, encoderOutputs, weights);

% Context
encoderOutputs = permute(encoderOutputs, [1 3 2]);
attentionScores = permute(attentionScores,[1 3 2]);
context = dlmtimes(encoderOutputs,attentionScores);
context = squeeze(context);

% Fully connect
weights = parameters.fc.Weights;
bias = parameters.fc.Bias;
dlY = weights*cat(1, hiddenState, context) + bias;

end

Embedding Function

The embedding function maps numeric indices to the corresponding vector given by the input weights.

function Z = embedding(X, weights)
% Reshape inputs into a vector
[N, T] = size(X, 2:3);
X = reshape(X, N*T, 1);

% Index into embedding matrix
Z = weights(:, X);

% Reshape outputs by separating out batch and sequence dimensions
Z = reshape(Z, [], N, T);
end
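The indexing and reshaping steps can be followed on a small case. This sketch inlines the body of the embedding function and uses a hypothetical 3-by-5 weight matrix rather than the 256-dimensional embedding of the example.

```matlab
embeddingDimension = 3;     % small value for illustration
vocabularySize = 5;
weights = reshape(1:embeddingDimension*vocabularySize, ...
    embeddingDimension, vocabularySize);

% Token indices in 1-by-N-by-T layout, with N = 2 observations
% and T = 2 time steps.
X = zeros(1,2,2);
X(1,:,1) = [1 4];           % first time step
X(1,:,2) = [2 5];           % second time step

% Same steps as the embedding function.
[N, T] = size(X, 2:3);
Z = weights(:, reshape(X, N*T, 1));
Z = reshape(Z, [], N, T);   % embeddingDimension-by-N-by-T
```

Here, Z(:,2,1) is the fourth column of weights, because the second observation has token index 4 at the first time step.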

Attention Function

The attention function computes the attention scores according to Luong "general" scoring.

function attentionScores = attention(hiddenState, encoderOutputs, weights)

[N, S] = size(encoderOutputs, 2:3);
attentionEnergies = dlarray(zeros( [S N] ));

% The energy at each time step is the dot product of the hidden state
% and the learnable attention weights times the encoder output
hWX = hiddenState .* dlmtimes(weights,encoderOutputs);
for tt = 1:S
    attentionEnergies(tt, :) = sum(hWX(:, :, tt), 1);
end

% Compute softmax scores
attentionScores = softmax(attentionEnergies, 'DataFormat', 'CB');

end
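For a single observation, the "general" score between the decoder hidden state h and each encoder output x is h'*W*x. The following sketch uses plain arrays with random, hypothetical values in place of dlarray objects. It checks that the matrix form agrees with the elementwise form used in the attention function, and computes the context vector as the weighted sum of the encoder outputs.

```matlab
rng(1)                                      % fixed seed, for reproducibility only
numHiddenUnits = 4;
S = 3;                                      % number of source time steps
hiddenState = randn(numHiddenUnits,1);      % decoder hidden state
encoderOutputs = randn(numHiddenUnits,S);   % one column per source time step
W = randn(numHiddenUnits);                  % attention weights

% Energies in matrix form: h'*W*x for each column x.
energies = hiddenState' * (W*encoderOutputs);              % 1-by-S

% Elementwise form, as in the attention function.
energiesElementwise = sum(hiddenState .* (W*encoderOutputs), 1);

% Softmax over the source time steps.
scores = exp(energies - max(energies));
scores = scores ./ sum(scores);

% Context vector: attention-weighted sum of the encoder outputs.
context = encoderOutputs * scores';
```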

Model Gradients Function

The modelGradients function takes the encoder and decoder model parameters, a mini-batch of input data and the padding masks corresponding to the input data, and the dropout probability and returns the gradients of the loss with respect to the learnable parameters in the models and the corresponding loss.

function [gradientsEncoder, gradientsDecoder, maskedLoss] = modelGradients(parametersEncoder, ...
    parametersDecoder, dlXSource, dlXTarget, maskSource, maskTarget, dropout)

% Forward through encoder.
[dlZ, hiddenState] = modelEncoder(dlXSource, parametersEncoder, maskSource);

% Get parameter sizes.
[miniBatchSize, sequenceLength] = size(dlXTarget,2:3);
sequenceLength = sequenceLength - 1;
numHiddenUnits = size(dlZ,1);

% Initialize context vector.
context = dlarray(zeros([numHiddenUnits miniBatchSize]));

% Initialize loss.
loss = dlarray(zeros([miniBatchSize sequenceLength]));

% Get first time step for decoder.
decoderInput = dlXTarget(:,:,1);

% Choose whether to use teacher forcing.
doTeacherForcing = rand < 0.5;

if doTeacherForcing
    for t = 1:sequenceLength
        % Forward through decoder.
        [dlY, context, hiddenState] = modelDecoder(decoderInput, parametersDecoder, context, ...
            hiddenState, dlZ, dropout);
        
        % Update loss.
        dlT = dlarray(oneHot(dlXTarget(:,:,t+1), size(dlY,1)));
        loss(:,t) = crossEntropyAndSoftmax(dlY, dlT);
        
        % Get next time step.
        decoderInput = dlXTarget(:,:,t+1);
    end
else
    for t = 1:sequenceLength
        % Forward through decoder.
        [dlY, context, hiddenState] = modelDecoder(decoderInput, parametersDecoder, context, ...
            hiddenState, dlZ, dropout);
        
        % Update loss.
        dlT = dlarray(oneHot(dlXTarget(:,:,t+1), size(dlY,1)));
        loss(:,t) = crossEntropyAndSoftmax(dlY, dlT);
        
        % Greedily update next input time step.
        prob = softmax(dlY,'DataFormat','CB');
        [~, decoderInput] = max(prob,[],1);
    end
end

% Determine masked loss.
maskedLoss = sum(sum(loss.*maskTarget(:,2:end))) / miniBatchSize;

% Update gradients.
[gradientsEncoder, gradientsDecoder] = dlgradient(maskedLoss, parametersEncoder, parametersDecoder);

% For plotting, return loss normalized by sequence length.
maskedLoss = extractdata(maskedLoss) ./ sequenceLength;

end
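The masking step at the end of modelGradients can be traced on a small case. The loss values and mask below are hypothetical. The first column of the mask corresponds to the start token, which has no loss term, so the code drops it with maskTarget(:,2:end).

```matlab
% Hypothetical per-time-step losses for 2 observations and 3 decoding steps.
loss = [0.5 0.2 0.9
        0.4 0.7 0.3];

% Target mask including the start token column. The second sequence has
% one padded time step at the end.
maskTarget = logical([1 1 1 1
                      1 1 1 0]);

miniBatchSize = size(loss,1);
maskedLoss = sum(sum(loss.*maskTarget(:,2:end))) / miniBatchSize;   % 1.35
```

The loss value 0.3 at the padded time step does not contribute to the result.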

Cross-Entropy and Softmax Loss Function

The crossEntropyAndSoftmax function applies a softmax operation to the decoder output and computes the cross-entropy loss against the one-hot encoded targets.

function loss = crossEntropyAndSoftmax(dlY, dlT)

offset = max(dlY);
logSoftmax = dlY - offset - log(sum(exp(dlY-offset)));
loss = -sum(dlT.*logSoftmax);

end
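Subtracting the maximum before exponentiating keeps the computation finite for large inputs. The following sketch uses hypothetical values and inlines the function body to compare the shifted computation with a naive one.

```matlab
dlY = [1000; 1001; 1002];   % hypothetical large scores for one observation
dlT = [0; 0; 1];            % one-hot target

% Naive log-softmax: exp(1000) overflows to Inf, so every entry is -Inf.
naive = dlY - log(sum(exp(dlY)));

% Shifted version, as in crossEntropyAndSoftmax.
offset = max(dlY);
logSoftmax = dlY - offset - log(sum(exp(dlY-offset)));
loss = -sum(dlT.*logSoftmax);   % log(1 + exp(-1) + exp(-2))
```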

Uniform Noise Function

The uniformNoise function samples weights from a uniform distribution.

function weights = uniformNoise(sz, k)

weights = -sqrt(k) + 2*sqrt(k).*rand(sz);

end
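With k = 1/numHiddenUnits and 200 hidden units, as used for the LSTM weights in this example, the samples lie between -sqrt(k) and sqrt(k), which is approximately ±0.0707. The following sketch inlines the formula so that it runs on its own. The fixed random seed is an assumption added for reproducibility.

```matlab
numHiddenUnits = 200;
k = 1/numHiddenUnits;
bound = sqrt(k);                           % approximately 0.0707

rng(0)                                     % fixed seed, for reproducibility only
sz = [4*numHiddenUnits numHiddenUnits];
weights = -sqrt(k) + 2*sqrt(k).*rand(sz);  % same formula as uniformNoise

% Check that every sample lies within the bound.
inRange = all(abs(weights) <= bound, 'all');
```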

Gradient Clipping Function

The clipGradient function clips the model gradients.

function g = clipGradient(g, gradientThreshold)

wnorm = norm(extractdata(g));
if wnorm > gradientThreshold
    g = (gradientThreshold/wnorm).*g;
end

end
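Clipping rescales a gradient whose norm exceeds the threshold and leaves its direction unchanged. The following sketch inlines the function body on a hypothetical gradient vector, using plain arrays instead of dlarray objects.

```matlab
gradientThreshold = 5;
g = [6; 8];                     % hypothetical gradient with L2 norm 10

gnorm = norm(g);
if gnorm > gradientThreshold
    g = (gradientThreshold/gnorm).*g;   % rescale, keeping the direction
end
% g is now [3; 4], with norm 5.
```

Note that for matrix inputs, the norm function returns the largest singular value rather than the Frobenius norm, so for weight matrices the threshold applies to that quantity.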

One-Hot Encoding Function

The oneHot function encodes word indices as one-hot vectors.

function oh = oneHot(idx, numTokens)
tokens = (1:numTokens)';
oh = (tokens == idx);
end
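The comparison relies on implicit expansion: a column of token indices compared with a row of targets yields a matrix with one column per observation. The following sketch inlines the function body on hypothetical indices.

```matlab
numTokens = 4;
idx = [2 4 1];              % hypothetical token indices for 3 observations

tokens = (1:numTokens)';
oh = (tokens == idx);       % 4-by-3 logical, one true entry per column
```

Each column of oh contains a single true value, in the row given by the corresponding entry of idx.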
