This notebook is an example of specifying a sequence transduction model that transforms an input sequence $x\in\mathbb{R}^{N}$ to an output sequence $y\in\mathbb{R}^M$.
This architecture can be described in Kokoyi as follows:
%kokoyi
\Module {Seq2Seq} {x; Enc, Dec}
h_x \gets Enc(x) \\
\hat{y} \gets Dec(h_x) \\
\Return \hat{y} \\
\EndModule
This notebook will use teacher forcing, where the true labels are fed into the decoder.
%kokoyi
\Module {Seq2Seq} {x, y; Enc, Dec}
h_x \gets Enc(x) \\
\hat{y} \gets Dec(y, h_x) \Comment{Teacher forcing} \\
\Return \hat{y} \\
\EndModule
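Concretely, teacher forcing only changes what the decoder consumes during training. Here is a minimal PyTorch sketch (the tensors are illustrative, not part of the notebook's pipeline): the decoder reads the ground-truth sequence shifted right by one position, and its predictions are scored against the next tokens, which is exactly the `en[:, :-1]` / `en[:, 1:]` split used in the training loop further down.
import torch

# A hypothetical batch of ground-truth target token ids, shape (B, M)
y = torch.tensor([[1, 5, 9, 2]])
decoder_input = y[:, :-1]   # what the decoder consumes under teacher forcing
target = y[:, 1:]           # what its predictions are scored against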
The model can be trivially extended to go deep, such that both encoder and decoder have multiple layers:
%kokoyi
\Module {MultiLayerSeq2Seq} {x, y; EncLayers, DecLayers}
(L_{enc}, L_{dec}) \gets (|EncLayers|, |DecLayers|) \\
h[0 \leq l \leq L_{enc} - 1] \gets \begin{cases}
EncLayers[0](x) & l = 0 \\
EncLayers[l](h[l - 1]) & otherwise \\
\end{cases} \\
h_x \gets h[L_{enc} - 1] \\
k[0 \leq l \leq L_{dec} - 1] \gets \begin{cases}
DecLayers[0](y, h_x) & l = 0 \\
DecLayers[l](k[l - 1], h_x) & otherwise \\
\end{cases} \\
\Return k[L_{dec} - 1]\\
\EndModule
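In plain PyTorch, the same stacking pattern is just a loop over a module list. This is a hedged sketch, with `layers` standing in for any encoder or decoder layers whose outputs and inputs line up:
import torch

class MultiLayerEncoder(torch.nn.Module):
    """Sketch of the h[l] = EncLayers[l](h[l-1]) recurrence above."""
    def __init__(self, layers):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x):
        h = x
        for layer in self.layers:   # layer 0 sees x, layer l sees h[l-1]
            h = layer(h)
        return h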
The above are warm-ups, though we can already adapt these templates to generate real code. We will now develop a real LSTM-based translation system with an important mechanism, namely the attention module. In the example below, $x$ is "They", "are", "watching", ".", "<eos>" and $y$ is "Ils", "regardent", ".", "<eos>".
We will reuse the LSTM we developed for document classification, starting with the same helper function setup:
%kokoyi
\Module {T} {x, h ; W_x, W_h, b}
\Return W_h @ h + W_x @ x + b \\
\EndModule
\sigma \gets \Sigmoid \\
We assume the input comprises $s$, a list of embeddings, one for each word. One modification we add here is $h_0, c_0$, which provide the initial conditions for the LSTM states. This is used, for instance, to condition the decoder LSTM.
%kokoyi
\Module {LSTM} {s, h_0, c_0; T_s}
(L, d) \gets \GetShape(s) \\
x \gets \{ 0 \}^{1 \times d} || s \Comment{Pad the input at index 0}\\
(T_f, T_i, T_o, T_c) \gets T_s \Comment{unpack the transformation modules } \\
\begin{group}
f[1 \leq t \leq L] \gets \sigma(T_f(x[t], h[t-1])) \\
i[1 \leq t \leq L] \gets \sigma(T_i(x[t], h[t-1])) \\
o[1 \leq t \leq L] \gets \sigma(T_o(x[t], h[t-1])) \\
\tilde{c}[1 \leq t \leq L] \gets \tanh(T_c(x[t], h[t-1])) \\
c[0 \leq t \leq L] \gets \begin{cases}
c_0 & t = 0 \\
f[t] * c[t - 1] + i[t] * \tilde{c}[t] & otherwise \\
\end{cases} \\
h[0 \leq t \leq L] \gets \begin{cases}
h_0 & t = 0 \\
o[t] * \tanh (c[t]) & otherwise
\end{cases} \\
\end{group} \\
\Return (h[1:], c[1:]) \\
\EndModule
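For readers who prefer PyTorch, here is roughly the same recurrence unrolled by hand. This is only a sketch: `T_f`, `T_i`, `T_o`, `T_c` stand for any callables computing $W_h h + W_x x + b$ (in the notebook they are `T` modules invoked through the Kokoyi runtime), and the tensors are unbatched, with `s` of shape (L, d).
import torch

def lstm_forward(s, h0, c0, T_f, T_i, T_o, T_c):
    """Unrolled LSTM matching the equations above; returns all hidden and cell states."""
    h, c = h0, c0
    hs, cs = [], []
    for x_t in s:                                # one step per input embedding
        f = torch.sigmoid(T_f(x_t, h))           # forget gate
        i = torch.sigmoid(T_i(x_t, h))           # input gate
        o = torch.sigmoid(T_o(x_t, h))           # output gate
        c = f * c + i * torch.tanh(T_c(x_t, h))  # new cell state
        h = o * torch.tanh(c)                    # new hidden state
        hs.append(h)
        cs.append(c)
    return torch.stack(hs), torch.stack(cs)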
BiLSTM runs a second LSTM in the reverse direction and concatenates the states. This is a crude way of saying we don't trust that the left-to-right order encodes all the information; as we shall see in Part 2, the Transformer relaxes this even further.
%kokoyi
\Module {BiLSTM} {x, h_0, c_0, \overleftarrow{h_0}, \overleftarrow{c_0}; LSTM, \overleftarrow{LSTM}}
N \gets |x| \\
\overleftarrow{x} \gets \{ x[N-i] \}_{i=1}^{N} \Comment{get a reverse list}\\
(h, c) \gets LSTM(x, h_0, c_0) \\
(\overleftarrow{h}, \overleftarrow{c}) \gets \overleftarrow{LSTM}(\overleftarrow{x}, \overleftarrow{h_0}, \overleftarrow{c_0}) \\
\hat{h} \gets \{ h[i] || \overleftarrow{h}[N-1-i] \}_{i=0}^{N-1} \\
\hat{c} \gets \{ c[i] || \overleftarrow{c}[N-1-i] \}_{i=0}^{N-1} \\
\Return (\hat{h}, \hat{c})\\
\EndModule
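The reverse-and-concatenate step in plain PyTorch, reusing the hypothetical `lstm_forward` sketch above (`T_fwd` and `T_bwd` are the two directions' gate modules packed as 4-tuples):
import torch

def bilstm_forward(s, h0, c0, h0_rev, c0_rev, T_fwd, T_bwd):
    """Run the LSTM in both directions and concatenate states aligned by position."""
    h_f, c_f = lstm_forward(s, h0, c0, *T_fwd)
    h_b, c_b = lstm_forward(s.flip(0), h0_rev, c0_rev, *T_bwd)
    # Flip the backward outputs so index i in both directions refers to the same token.
    h_cat = torch.cat([h_f, h_b.flip(0)], dim=-1)
    c_cat = torch.cat([c_f, c_b.flip(0)], dim=-1)
    return h_cat, c_cat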
Now we are ready to define a translator that maps a sentence in one language to the other:
%kokoyi
\Module {Seq2Seq_{LSTM}} {x, y, h_0, c_0, \overleftarrow{h_0}, \overleftarrow{c_0};BiLSTM, LSTM, W}
(h_x, c_x) \gets BiLSTM(x, h_0, c_0, \overleftarrow{h_0}, \overleftarrow{c_0}) \Comment{Encoding} \\
\bar{h_x} \gets \Mean(h_x) \\
\bar{c_x} \gets \Mean(c_x) \\
(\hat{y}, c_y) \gets LSTM(y, \bar{h_x}, \bar{c_x}) \Comment{Decoding} \\
\Return \hat{y} @ W\\
\EndModule
One of the most effective techniques in machine translation is to align each output word to some words in the input. Since we cannot know the alignment a priori, we use attention to learn it. To do so, we define a new LSTM decoder that can attend to the output of the encoder:
%kokoyi
\Module {Seq2Seq_{Attn}} {x, y, h_0, c_0, \overleftarrow{h_0}, \overleftarrow{c_0};BiLSTM, LSTM_{Attn}, W}
(h_x, c_x) \gets BiLSTM(x, h_0, c_0, \overleftarrow{h_0}, \overleftarrow{c_0}) \Comment{Encoding} \\
\bar{h_x} \gets \Mean(h_x) \\
\bar{c_x} \gets \Mean(c_x) \\
(\hat{y}, c_y) \gets LSTM_{Attn}(y, \bar{h_x}, \bar{c_x}, h_x) \Comment{Decoding} \\
\Return \hat{y} @ W\\
\EndModule
The new LSTM decoder has an attention module $Attn$ that injects additional contextual information from the input $h_x$. Here is the definition of $Attn$ and a visualization of how it works:
%kokoyi
\Function{Attn}{q, \{k\}^M, \{v\}^M}
a \gets \Softmax(\{\frac{\trans{q} @ k[j]}{\sqrt{d}} \}_{j = 0}^{M - 1}) \where d \gets |q| \\
\Return a @ v \\
\EndFunction
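Before unpacking it, here is the same function in plain PyTorch, as a sketch: `q` is a single query vector of size $d$, and `k`, `v` are the $M$ keys and values stacked into matrices.
import torch
import torch.nn.functional as F

def attn(q, k, v):
    """Scaled dot-product attention for one query: a weighted sum of the values."""
    d = q.shape[-1]
    scores = (k @ q) / d ** 0.5        # (M,) similarity of q against every key
    a = F.softmax(scores, dim=-1)      # attention distribution over the M positions
    return a @ v                       # context vector: weighted sum of the values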
What it does is compute a similarity between the query $q$ and each key $k[j]$, turn the similarities into a distribution (using softmax), and then take a weighted sum over the values $v$. In our case, $k$ and $v$ are the encoder outputs $h_x$, $q$ is the decoder state $h$, and that's all the modification needed:
%kokoyi
\Module {LSTM_{Attn}} {s, h_0, c_0, h_x; T_s}
(L, d) \gets \GetShape(s) \\
\bar{s} \gets \{ 0 \}^{1 \times d} || s \\
(T_f, T_i, T_o, T_c) \gets T_s \Comment{Unpack the transformation modules } \\
\begin{group}
a[1 \leq t \leq L] \gets Attn(h[t-1], h_x, h_x) \Comment{Compute attention weight per decoder position to input} \\
x[1 \leq t \leq L] \gets \bar{s}[t] || a[t] \\
f[1 \leq t \leq L] \gets \sigma(T_f(x[t], h[t-1])) \\
i[1 \leq t \leq L] \gets \sigma(T_i(x[t], h[t-1])) \\
o[1 \leq t \leq L] \gets \sigma(T_o(x[t], h[t-1])) \\
\tilde{c}[1 \leq t \leq L] \gets \tanh(T_c(x[t], h[t-1])) \\
c[0 \leq t \leq L] \gets \begin{cases}
c_0 & t = 0 \\
f[t] * c[t - 1] + i[t] * \tilde{c}[t] & otherwise \\
\end{cases} \\
h[0 \leq t \leq L] \gets \begin{cases}
h_0 & t = 0 \\
o[t] * \tanh (c[t]) & otherwise \\
\end{cases} \\
\end{group} \\
\Return (h[1:], c[1:]) \\
\EndModule
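Putting the pieces together, a single decoder step differs from the plain LSTM step only in how its input is formed. Here is a sketch reusing the hypothetical `attn` and gate callables from above; note the gates must accept the wider concatenated input, which is why the `T` modules in `AttnLSTM` below are built with input size `dec_embed + enc_hidden`.
import torch

def attn_lstm_step(s_t, h_prev, c_prev, h_x, T_f, T_i, T_o, T_c):
    """One attention-augmented decoder step: attend, concatenate, then update as usual."""
    context = attn(h_prev, h_x, h_x)           # h_x serves as both keys and values
    x_t = torch.cat([s_t, context], dim=-1)    # embedding || attention context
    f = torch.sigmoid(T_f(x_t, h_prev))
    i = torch.sigmoid(T_i(x_t, h_prev))
    o = torch.sigmoid(T_o(x_t, h_prev))
    c = f * c_prev + i * torch.tanh(T_c(x_t, h_prev))
    h = o * torch.tanh(c)
    return h, c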
Let's first do some setup:
import os
import torch
import torchtext
from collections import Counter
from torchtext.datasets import IWSLT2016
from torch.utils.data import DataLoader
from torchtext.vocab import vocab, build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import kokoyi
We will use the IWSLT2016 dataset from torchtext and train our model on the German-English subset, which consists of bilingual sentence pairs. Each text sequence is tokenized into a sequence of integers, and sequences in a batch are padded to the same length.
if not os.path.exists('data'):
    os.mkdir('data')
print("Creating dataset ...")
train_iter = IWSLT2016(root='data', split='train', language_pair=('de', 'en'))
test_iter = IWSLT2016(root='data', split='test', language_pair=('de', 'en'))
train_dataset = list(train_iter)
test_dataset = list(test_iter)
print("Train set: %d" % len(train_dataset))
print("Test set: %d" % len(test_dataset))
de_tokenizer = get_tokenizer('spacy', language='de')
en_tokenizer = get_tokenizer('spacy', language='en')
def yield_tokens(dataset, idx, tokenizer):
    for sentence in dataset:
        yield tokenizer(sentence[idx])
de_vocab = build_vocab_from_iterator(yield_tokens(train_dataset, 0, de_tokenizer), specials=['<pad>', '<bos>', '<eos>'])
en_vocab = build_vocab_from_iterator(yield_tokens(train_dataset, 1, en_tokenizer), specials=['<pad>', '<bos>', '<eos>'])
text_pipeline = lambda x, vocab, tokenizer: [vocab['<bos>']] + [vocab[token] for token in tokenizer(x)] + [vocab['<eos>']]

def collate_batch(batch):
    de_batch, en_batch = [], []
    for (_de, _en) in batch:
        de_batch.append(torch.tensor(text_pipeline(_de, de_vocab, de_tokenizer)))
        en_batch.append(torch.tensor(text_pipeline(_en, en_vocab, en_tokenizer)))
    de_batch = pad_sequence(de_batch, padding_value=de_vocab['<pad>'], batch_first=True)
    en_batch = pad_sequence(en_batch, padding_value=en_vocab['<pad>'], batch_first=True)
    return de_batch, en_batch
BATCH_SIZE = 32
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=True, collate_fn=collate_batch)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
kokoyi.set_rt_device(device)
print('Device', device)
You can let Kokoyi set up the initialization for the Seq2Seq modules defined above (click the button and then fill in what's needed).
from kokoyi.nn import Linear
class T(torch.nn.Module):
    def __init__(self, hidden_size, embed_size):
        super().__init__()
        self.W_x = torch.nn.Parameter(torch.empty(hidden_size, embed_size))
        self.W_h = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b = torch.nn.Parameter(torch.empty(hidden_size))
        torch.nn.init.xavier_uniform_(self.W_x)
        torch.nn.init.xavier_uniform_(self.W_h)
        torch.nn.init.uniform_(self.b, -0.5, 0.5)

    def get_parameters(self):
        return self.W_x, self.W_h, self.b

    forward = kokoyi.symbol['T']
class LSTM(torch.nn.Module):
    def __init__(self, hidden_size, embed_size):
        super().__init__()
        self.T_s = torch.nn.ModuleList(
            [T(hidden_size, embed_size),
             T(hidden_size, embed_size),
             T(hidden_size, embed_size),
             T(hidden_size, embed_size)])

    def get_parameters(self):
        return self.T_s

    forward = kokoyi.symbol['LSTM']
class BiLSTM(torch.nn.Module):
    def __init__(self, hidden_size, embed_size):
        super().__init__()
        self.LSTM = LSTM(hidden_size, embed_size)
        self.overleftarrowLSTM = LSTM(hidden_size, embed_size)

    def get_parameters(self):
        return self.LSTM, self.overleftarrowLSTM

    forward = kokoyi.symbol['BiLSTM']
class AttnLSTM(torch.nn.Module):
    def __init__(self, dec_hidden, dec_embed, enc_hidden):
        super().__init__()
        self.T_s = torch.nn.ModuleList(
            [T(dec_hidden, dec_embed + enc_hidden),
             T(dec_hidden, dec_embed + enc_hidden),
             T(dec_hidden, dec_embed + enc_hidden),
             T(dec_hidden, dec_embed + enc_hidden)])

    def get_parameters(self):
        return self.T_s

    forward = kokoyi.symbol['LSTM_{Attn}']
class AttnSeq2Seq(torch.nn.Module):
    def __init__(self, enc_hidden, enc_embed, dec_embed, tgt_vocab_size):
        super().__init__()
        dec_hidden = 2 * enc_hidden
        self.BiLSTM = BiLSTM(enc_hidden, enc_embed)
        self.AttnLSTM = AttnLSTM(dec_hidden, dec_embed, 2 * enc_hidden)
        self.W = torch.nn.Parameter(torch.empty(2 * enc_hidden, tgt_vocab_size))
        torch.nn.init.xavier_uniform_(self.W)  # initialize the output projection explicitly

    def get_parameters(self):
        return self.BiLSTM, self.AttnLSTM, self.W

    forward = kokoyi.symbol['Seq2Seq_{Attn}']
Finally, we can set the hyper-parameters and start training! Note that we use teacher forcing, where the ground-truth output sequence is fed into the decoder.
num_epochs = 3
embed_size = 64
hidden_size = 128
src_vocab_size = len(de_vocab)
tgt_vocab_size = len(en_vocab)
src_embedding = torch.nn.Parameter(torch.empty((src_vocab_size, embed_size), device=device))
tgt_embedding = torch.nn.Parameter(torch.empty((tgt_vocab_size, embed_size), device=device))
gain = torch.nn.init.calculate_gain('relu')
torch.nn.init.xavier_uniform_(src_embedding, gain=gain)
torch.nn.init.xavier_uniform_(tgt_embedding, gain=gain)
model = AttnSeq2Seq(hidden_size, embed_size, embed_size, tgt_vocab_size).to(device)
print(model)
h_0 = torch.zeros(hidden_size, device=device)
c_0 = torch.zeros(hidden_size, device=device)
overleftarrowh_0 = torch.zeros(hidden_size, device=device)
overleftarrowc_0 = torch.zeros(hidden_size, device=device)
parameters = list(model.parameters()) + [src_embedding, tgt_embedding]
optimizer = torch.optim.Adam(parameters)
for epoch in range(num_epochs):
    total_loss, n_word_total, n_word_correct = 0, 0, 0
    for i, (de, en) in enumerate(train_dataloader):
        # prepare data
        de, en = de.to(device), en.to(device)
        src_seq = de
        tgt_seq, gold = en[:, :-1], en[:, 1:]
        # Look up the embedding table
        src_emb = F.embedding(src_seq, src_embedding, padding_idx=de_vocab['<pad>'])
        tgt_emb = F.embedding(tgt_seq, tgt_embedding, padding_idx=en_vocab['<pad>'])
        # forward
        optimizer.zero_grad()
        pred = model(src_emb, tgt_emb, h_0, c_0, overleftarrowh_0, overleftarrowc_0, batch_level=[1, 1, 0, 0, 0, 0])
        # backward and update parameters
        pred = pred.reshape(-1, pred.size(2))
        gold = gold.contiguous().view(-1)
        loss = F.cross_entropy(pred, gold, ignore_index=en_vocab['<pad>'], reduction='mean')
        pred = pred.max(1)[1]
        non_pad_mask = gold.ne(en_vocab['<pad>'])
        n_correct = pred.eq(gold).masked_select(non_pad_mask).sum().item()
        n_word = non_pad_mask.sum().item()
        loss.backward()
        optimizer.step()
        # bookkeeping
        n_word_total += n_word
        n_word_correct += n_correct
        total_loss += loss.item()
        if i % 5 == 0:
            print(f'Epoch {epoch:04d} | Iter {i:04d} | Loss {loss.item():.4f} | Acc {(n_correct / n_word):.4f}')
    print(f'Epoch {epoch:04d} | Avg Loss {(total_loss / (i + 1)):.4f} | Avg Acc {(n_word_correct / n_word_total):.4f}')