Seq2Seq Part 2: Transformer

Recall the encoder-decoder architecture we used in the BiLSTM-LSTM notebook:

And compare this with the top-level Transformer as written below:

You will notice that the high-level encoder-decoder architecture is almost the same. Since a sequence model is nothing but a conditional probability model that computes $p(y_{[t]}|y_{[0:t-1]}, x)$, the crucial difference of the Transformer is the way it processes all tokens in $y$ in parallel at the input (remember we are doing teacher-forcing training). In contrast, our earlier model predicts $y_{[t]}$ given $x$ and $y_{[0:t-1]}$ sequentially. Taking all input tokens at once loses their positional information. As a remedy, the Transformer applies a positional encoding $Enc_{pos}$ to the input sequences (both source and target).
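For concreteness, here is a minimal PyTorch sketch of the sinusoidal positional encoding from the original Transformer paper; the class name `PositionalEncoding` and its interface are our own choices for illustration and need not match the notebook's $Enc_{pos}$.

```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Adds the sinusoidal positional encoding to the token embeddings."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        # Frequencies decay geometrically across the embedding dimensions.
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                             * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the encoding of the first seq_len positions.
        return x + self.pe[:, :x.size(1)]
```

Because each position gets a distinct, deterministic pattern, the model can recover token order even though all positions are processed in parallel.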

One important advantage of the Transformer is that it achieves parallelism while preserving the sequence-prediction nature; this is reflected in a couple of places, as we will show shortly.

Let's first look at the next level of detail, as depicted in the diagram below:

Transformer Encoder

Let's look at the encoder. Recall that in the Bi-LSTM the representation of each position fuses information from its neighbors on both sides. Pushing this to the extreme, the Transformer pulls from all positions: for position $w$, letting $S$ be the collection of all positions in the sentence, we compute $h^l[w] \gets f(h^{l-1}[w], g(\{h^{l-1}[v] \mid v \in S\}))$ for some functions $f$ and $g$ at layer $l$.
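The pooling over all positions is typically realized with scaled dot-product attention. The sketch below (our own naming, standing in for the $\mathrm{Attn}$ module developed earlier) shows the core computation:

```python
import math
import torch
import torch.nn.functional as F

def attn(query, key, value, mask=None):
    """Scaled dot-product attention: each query position pools from all key positions."""
    d_k = query.size(-1)
    # Similarity between every query position and every key position.
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # Positions where mask == 0 are not allowed to be attended to.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ value, weights
```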

This is done with the $\mathrm{MHAttn}$ block below, which reuses the attention module $\mathrm{Attn}$ we had developed earlier. $\mathrm{MHAttn}$ adds a few more things:
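For reference, here is a common realization of multi-head attention, assuming the standard design of per-head query/key/value projections followed by an output projection; the exact additions in the notebook's $\mathrm{MHAttn}$ may differ.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MHAttn(nn.Module):
    """Multi-head attention: several scaled dot-product attentions run in parallel."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        # Learned projections for queries, keys, values, and the merged output.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def _split(self, x):
        # (batch, seq, d_model) -> (batch, heads, seq, d_head)
        b, t, _ = x.shape
        return x.view(b, t, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        q = self._split(self.w_q(query))
        k = self._split(self.w_k(key))
        v = self._split(self.w_v(value))
        # Same scaled dot-product attention as in the attn sketch above, once per head.
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        out = F.softmax(scores, dim=-1) @ v
        # Merge heads: (batch, heads, seq, d_head) -> (batch, seq, d_model)
        out = out.transpose(1, 2).reshape(query.size(0), query.size(1), -1)
        return self.w_o(out)
```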

Now we are ready to build the encoder module. Note that:
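As a point of comparison with the Kokoyi definition, here is a minimal PyTorch sketch of one encoder layer, assuming the standard post-norm arrangement of self-attention plus a position-wise feed-forward network with residual connections; `nn.MultiheadAttention` stands in for $\mathrm{MHAttn}$.

```python
import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    """One Transformer encoder layer: self-attention followed by a feed-forward block."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # nn.MultiheadAttention (PyTorch >= 1.9 for batch_first) stands in for MHAttn.
        self.self_attn = nn.MultiheadAttention(d_model, num_heads,
                                               dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),
                                nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_padding_mask=None):
        # Every position attends to every position in the source sentence.
        attn_out, _ = self.self_attn(x, x, x, key_padding_mask=src_padding_mask)
        x = self.norm1(x + self.dropout(attn_out))
        return self.norm2(x + self.dropout(self.ff(x)))
```

The full encoder stacks several such layers on top of the embedded, positionally encoded source sequence.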

Transformer Decoder

The decoder is very similar to the encoder except that it needs to attend to the encoder states -- this part is no different from the BiLSTM-LSTM translator we saw earlier. As in the encoder, the decoder also performs attention internally, with one crucial difference: to reflect the sequence-prediction nature, a position never attends to future positions (i.e. $y_{[t]}$ only attends to $y_{[0:t-1]}$). To do so, we need to define a different attention module:
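A minimal sketch of such a masked (causal) attention is shown below; the helper names are our own, and the mask simply blocks attention from position $t$ to any position $s > t$.

```python
import math
import torch
import torch.nn.functional as F

def causal_mask(seq_len):
    """Entry (t, s) is True iff position t is allowed to attend to position s, i.e. s <= t."""
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

def masked_attn(query, key, value):
    """Self-attention over decoder states in which each position only sees its own past."""
    t, d_k = query.size(-2), query.size(-1)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)
    # Future positions get -inf, so they receive zero weight after the softmax.
    scores = scores.masked_fill(~causal_mask(t).to(scores.device), float("-inf"))
    return F.softmax(scores, dim=-1) @ value
```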

Now we are ready to write out the Transformer decoder. Note that we first do a partial (masked) attention over the decoder states, followed by a full attention over the encoder states so that the prediction is conditioned on $x$.
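A sketch of one decoder layer with this ordering, again using `nn.MultiheadAttention` as a stand-in and our own naming, might look like:

```python
import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    """One Transformer decoder layer: masked self-attention, then attention over the encoder states."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads,
                                               dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads,
                                                dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),
                                nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, y, memory, tgt_mask):
        # Partial (masked) self-attention: position t only sees y[0:t].
        # tgt_mask follows nn.MultiheadAttention's convention (True = blocked), e.g.
        # torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1).
        self_out, _ = self.self_attn(y, y, y, attn_mask=tgt_mask)
        y = self.norm1(y + self.dropout(self_out))
        # Full attention over the encoder states so the prediction is conditioned on x.
        cross_out, _ = self.cross_attn(y, memory, memory)
        y = self.norm2(y + self.dropout(cross_out))
        return self.norm3(y + self.dropout(self.ff(y)))
```

Stacking several such layers and projecting the final states onto the target vocabulary completes the decoder.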

Completing NN modules

We now complete the rest of the NN module definitions in Python using the templates generated by Kokoyi.

Training

We use the same setup as in the Seq2Seq_LSTM tutorial to train the Transformer model for the machine translation task.

We will use the IWSLT2016 dataset from torchtext. We train our model on the German-English subset, which consists of bilingual sentence pairs. Each text sequence is tokenized into a sequence of integers and padded to the same length.
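A sketch of this data pipeline is shown below, assuming the `IWSLT2016` loader and vocabulary utilities from recent torchtext releases and a plain whitespace tokenizer as a stand-in for the tokenizer used in the original setup; the actual notebook code may differ.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torchtext.datasets import IWSLT2016
from torchtext.vocab import build_vocab_from_iterator

tokenize = str.split  # whitespace tokenizer as a simple stand-in

def yield_tokens(index):
    # index 0 = German source, index 1 = English target
    for pair in IWSLT2016(split='train', language_pair=('de', 'en')):
        yield tokenize(pair[index])

specials = ['<unk>', '<pad>', '<bos>', '<eos>']
de_vocab = build_vocab_from_iterator(yield_tokens(0), min_freq=2, specials=specials)
en_vocab = build_vocab_from_iterator(yield_tokens(1), min_freq=2, specials=specials)
de_vocab.set_default_index(de_vocab['<unk>'])
en_vocab.set_default_index(en_vocab['<unk>'])

def collate(batch):
    """Turn raw (German, English) text pairs into padded integer tensors of equal length."""
    src, tgt = [], []
    for de_text, en_text in batch:
        src.append(torch.tensor([de_vocab['<bos>']] + de_vocab(tokenize(de_text)) + [de_vocab['<eos>']]))
        tgt.append(torch.tensor([en_vocab['<bos>']] + en_vocab(tokenize(en_text)) + [en_vocab['<eos>']]))
    return (pad_sequence(src, batch_first=True, padding_value=de_vocab['<pad>']),
            pad_sequence(tgt, batch_first=True, padding_value=en_vocab['<pad>']))
```

The `collate` function can be passed to a `torch.utils.data.DataLoader` via its `collate_fn` argument so that each batch arrives as a pair of padded source and target tensors.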