This notebook is an example of specifying a model that maps a sequence $\{\mathbb{R}^{d_1\times d_2\ldots}\}^N$ of tensors to a discrete value. Specifically, we are going to use an LSTM to perform topic classification, and along the way you will see an example of implementing a more complex module from scratch.
Long short-term memory (LSTM) is one of the most popular architectures in the class of recurrent neural networks (RNNs). A typical RNN holds a hidden state $h$. At time step $t$, the state is updated by passing a linear combination of the previous state $h_{t-1}$ and the current input $x_t$ through a nonlinear activation such as $\mathop{sigmoid}$ or $\tanh$. The state at the final step can be used to compute the loss and train the classifier.
Here's an (incomplete) implementation of a two-class classifier in Kokoyi, with learnable parameters $W_x$, $W_h$, and $b$ in a vanilla RNN whose final representation summarizes the sentence. We will use a recursive pattern that should be familiar to you by now:
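$$h_t = \tanh(W_h \cdot h_{t-1} + W_x \cdot x_t + b)$$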
where $t$ starts from 1 and the representation array $h$ is initialized with $h_0$. Let $S$ be the input sentence; each element is an index into a word dictionary $D$. Note that the indexing of the input sequence usually starts from 0 (unless you pre-processed it), so we right-shift it by padding one <pad> token at the left.
%kokoyi
\Module {RNN} {S ; W_h, W_x, D, b, h_0}
L \gets |S| \Comment{Get length of the input sequence}\\
\bar{S} \gets \{ 0 \}^{1} || S \Comment{Right shift and add a 0}\\
x \gets \{D(s); s \in \bar{S} \} \Comment{Map the integer IDs to dense representations.} \\
h[0 \leq t \leq L] \gets \begin{cases}
h_0 & t = 0 \\
\tanh(W_h @ h[t-1] + W_x @ x[t] + b) & t \leq L
\end{cases} \\
\Return h[L] \\
\EndModule
The vanilla RNN has difficulty preserving long-term information; LSTM mitigates this problem by adding extra units (gates) that retain memory. At time step $t$, a forget gate $f_t$, an input gate $i_t$, and an output gate $o_t$ are applied to selectively drop some of the old memory and collect useful new state:
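$$f_t = \sigma(W_{xf} \cdot x_t + W_{hf} \cdot h_{t-1} + b_f)$$
$$i_t = \sigma(W_{xi} \cdot x_t + W_{hi} \cdot h_{t-1} + b_i)$$
$$o_t = \sigma(W_{xo} \cdot x_t + W_{ho} \cdot h_{t-1} + b_o)$$

Here $\sigma$ denotes the $\mathop{sigmoid}$ function, and each gate has its own weight matrices and bias (the subscripts simply distinguish the three sets of parameters).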
A candidate memory cell $\tilde{c}_t$ is computed similarly, except that $\tanh$ is used as the activation function:
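$$\tilde{c}_t = \tanh(W_{xc} \cdot x_t + W_{hc} \cdot h_{t-1} + b_c)$$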
Then, the memory cell is updated from the previous memory cell and the candidate memory cell, weighted by the forget gate and the input gate via the Hadamard product:
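$$c_t = f_t \circ c_{t-1} + i_t \circ \tilde{c}_t$$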
Finally, the output of each step is the hidden state $h_t$, calculated from the output gate and the memory cell:
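$$h_t = o_t \circ \tanh(c_t)$$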
We will now write the model in Kokoyi straight from the definitions. Let's first define some helper functions and modules. Since the calculations of $f_t$, $i_t$, and $o_t$ have an identical form, we can use a single module $T$ to update these gates (cells), and we will reuse it for $\tilde{c}_t$ as well. We will also define an inline function $\sigma$ for the sigmoid activation. Note the use of ";" to separate inputs from parameters:
%kokoyi
\Module {T} {x, h ; W_x, W_h, b}
\Return W_h @ h + W_x @ x + b \\
\EndModule
Then, we can write the main model from the definitions, using the standard cross entropy as the loss function. Note that $\{S\}^L$ is a sentence of length $L$; each token in the sentence is an index into the embedding table $D$. By taking $\{S\}^L$ as an input, $L$ is available inside the module. After padding the input, the next statement unpacks the four $T$ modules we mentioned earlier, and the statement $\{D(s); s \in \bar{S}\}$ maps each token id into its dense representation.
Note also that @ denotes matrix multiplication (rendered as "$\cdot$"), while * denotes the element-wise (Hadamard) product (rendered as "$\circ$").
%kokoyi
\sigma \gets \Sigmoid \\
\Module {LSTM} {\{S\}^L ; T_s, Linear, D, c_0, h_0}
\bar{S} \gets \{ 0 \}^{1} || S \Comment{Pad the input at index 0}\\
(T_f, T_i, T_o, T_c) \gets T_s \Comment{unpack the transformation modules } \\
x \gets \{D(s); s \in \bar{S} \} \Comment{Map the integer IDs to dense representations.} \\
\begin{group}
f[1 \leq t \leq L] \gets \sigma(T_f(x[t], h[t - 1])) \\
i[1 \leq t \leq L] \gets \sigma(T_i(x[t], h[t - 1])) \\
o[1 \leq t \leq L] \gets \sigma(T_o(x[t], h[t - 1])) \\
\tilde{c}[1 \leq t \leq L] \gets \tanh(T_c(x[t], h[t - 1])) \\
c[0 \leq t \leq L] \gets \begin{cases}
c_0 & t = 0 \\
f[t] * c[t - 1] + i[t] * \tilde{c}[t] & otherwise \\
\end{cases} \\
h[0 \leq t \leq L] \gets \begin{cases}
h_0 & t = 0 \\
o[t] * \tanh (c[t]) & otherwise
\end{cases} \\
\end{group} \\
\hat{y} \gets Linear(h[L]) \\
\Return \hat{y} \\
\EndModule
loss(\hat{y}, y) \gets \CrossEntropy(\hat{y}, y) \\
You can let Kokoyi set up the initialization skeleton for the LSTM (just copy, paste, and then fill in what's needed):
class T(torch.nn.Module):
    def __init__(self):
        """ Add your code for parameter initialization here (not necessarily the same names)."""
        super().__init__()
        self.W_x = None
        self.W_h = None
        self.b = None

    def get_parameters(self):
        """ Change the following code to return the parameters as a tuple in the order of (W_x, W_h, b)."""
        return None

    forward = kokoyi.symbol["T"]


class LSTM(torch.nn.Module):
    def __init__(self):
        """ Add your code for parameter initialization here (not necessarily the same names)."""
        super().__init__()
        self.T_s = None
        self.Linear = None
        self.D = None
        self.c_0 = None
        self.h_0 = None

    def get_parameters(self):
        """ Change the following code to return the parameters as a tuple in the order of (T_s, Linear, D, c_0, h_0)."""
        return None

    forward = kokoyi.symbol["LSTM"]
Here's the completed module definition. We import the Linear and Embedding modules from kokoyi.nn. NN modules in Kokoyi are basically the same as NN modules in torch: you can set up a Kokoyi module with the same parameters used in torch, and the forward function behaves almost the same, except for some changes for auto-batching.
import torch
from kokoyi.nn import Linear, Embedding
class T(torch.nn.Module):
def __init__(self, hidden_size, embed_size):
super().__init__()
self.W_x = torch.nn.Parameter(torch.empty(hidden_size, embed_size))
self.W_h = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))
self.b = torch.nn.Parameter(torch.empty(hidden_size,))
torch.nn.init.xavier_uniform_(self.W_x)
torch.nn.init.xavier_uniform_(self.W_h)
torch.nn.init.uniform_(self.b, -0.5, 0.5)
def get_parameters(self):
return self.W_x, self.W_h, self.b
forward = kokoyi.symbol['T']
class LSTM(torch.nn.Module):
def __init__(self, hidden_size, embed_size, vocab_size, label_size):
super().__init__()
self.T_s = torch.nn.ModuleList(
[T(hidden_size, embed_size),
T(hidden_size, embed_size),
T(hidden_size, embed_size),
T(hidden_size, embed_size)])
self.Linear = Linear(hidden_size, label_size)
self.D = Embedding(vocab_size, embed_size)
self.c_0 = torch.zeros(hidden_size,)
self.h_0 = torch.zeros(hidden_size,)
def get_parameters(self):
return self.T_s, self.Linear, self.D, self.c_0, self.h_0
forward = kokoyi.symbol['LSTM']
Let's first do some setup:
import os
import torch
from torch.optim import Adam
from torchtext.datasets import AG_NEWS
from torch.utils.data import DataLoader
from torch.nn.functional import pad
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab
import kokoyi
We will use the news articles in the AG_NEWS dataset from torchtext. The dataset consists of label-text pairs; we tokenize each text and map the tokens to integer ids using a vocabulary built from the training split.
if not os.path.exists('data'):
os.mkdir('data')
# os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(root='data/ag_news_csv', split='train')
counter = Counter()
for (label, line) in train_iter:
counter.update(tokenizer(line))
Vocab = vocab(counter, min_freq=1)
if '<unk>' not in Vocab: Vocab.insert_token('<unk>', 0)
if '<pad>' not in Vocab: Vocab.insert_token('<pad>', 1)
Vocab.set_default_index(0)
text_pipeline = lambda x: [Vocab[token] for token in tokenizer(x)]
label_pipeline = lambda x: int(x) - 1
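# Quick sanity check (a hypothetical example; the exact token ids depend on the
# vocabulary built above): the pipelines turn raw text and labels into model inputs.
print(text_pipeline('here is an example'))   # a list of integer token ids
print(label_pipeline('3'))                   # labels 1-4 are shifted to 0-3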
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
kokoyi.set_rt_device(device)
def collate_batch(batch):
label_list, text_list = [], []
for (_label, _text) in batch:
label_list.append(label_pipeline(_label))
processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
if processed_text.size(0) < MAX_LEN:
processed_text = pad(processed_text,
(1, MAX_LEN - processed_text.size(0) - 1),
'constant', 1)
else:
processed_text = processed_text[:MAX_LEN]
text_list.append(processed_text)
label_list = torch.tensor(label_list, dtype=torch.int64)
text_list = torch.stack(text_list)
return label_list.to(device), text_list.to(device)
train_iter = AG_NEWS(root='data/ag_news_csv', split='train')
test_iter = AG_NEWS(root='data/ag_news_csv', split='test')
train_dataset = list(train_iter)
test_dataset = list(test_iter)
BATCH_SIZE = 16
MAX_LEN = 24
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
shuffle=True, collate_fn=collate_batch)
Finally, we can set the hyper-parameters and start training!
num_epochs = 3
hidden_size = 128
embed_size = 64
vocab_size = len(Vocab)
label_size = 4
lstm = LSTM(hidden_size, embed_size, vocab_size, label_size).to(device)
optimizer = Adam(lstm.parameters(), lr=1e-3)
lstm.train()
for epoch in range(num_epochs):
total_loss = 0
for i, (label, text) in enumerate(train_dataloader):
optimizer.zero_grad()
pred = lstm(text, batch_level=[1])
loss = kokoyi.symbol['loss'](pred, label, batch_level=[1,1])
loss = torch.mean(loss)
loss.backward()
optimizer.step()
total_loss += loss.item()
if i % 10 == 0:
print('Epoch %d | Iter %d | Loss=%.4f' % (epoch, i, loss.item()))
print('Epoch %d: total loss=%.6f' % (epoch, total_loss))
We can validate the accuracy after training with similar code. We applied straightforward pre-processing to simplify the tutorial; feel free to change it for better performance!
total_acc = 0
lstm.eval()
for label, text in test_dataloader:
pred = lstm(text, batch_level=[1])
pred_y = torch.argmax(pred, dim=1)
total_acc += sum(torch.where(torch.eq(pred_y, label), 1, 0)).item()
print('Test accuracy: %.6f.' % (total_acc / len(test_dataset)))