{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "from collections import Counter\n",
    "from torchtext.vocab import build_vocab_from_iterator\n",
    "from torchtext.data.utils import get_tokenizer\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_punctuation_regex(text):\n",
    "    # 使用正则表达式去除标点符号\n",
    "    cleaned_text = re.sub(r'[^\\w\\s]', '', text)\n",
    "    return cleaned_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # get_tokenizer\n",
    "\n",
    "# en = open('./datasets/EN-DE/train.en', encoding='utf-8').readlines()\n",
    "# de = open('./datasets/EN-DE/train.de', encoding='utf-8').readlines()\n",
    "\n",
    "# def custom_tokenizer(text):\n",
    "#     # 在这里实现你的分词逻辑\n",
    "#     # 返回标记列表\n",
    "#     tokens = text.split()  # 这只是一个示例，将文本按空格分割\n",
    "#     return tokens\n",
    "# tokenizers = { 'en': get_tokenizer(custom_tokenizer), \n",
    "#                     'de': get_tokenizer(custom_tokenizer) }\n",
    "# def _get_tokens(corpus, lang):\n",
    "#     for line in corpus:\n",
    "#         yield tokenizers[lang](remove_punctuation_regex(line.lower().strip()))\n",
    "        \n",
    "# def _build_vocab():\n",
    "#     src_vocab = build_vocab_from_iterator(_get_tokens(en, 'en'), specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=2)  # discarding words occurs < 2 times\n",
    "#     trg_vocab = build_vocab_from_iterator(_get_tokens(de, 'de'), specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=2)\n",
    "  \n",
    "#     trg_vocab.set_default_index(trg_vocab['<unk>'])\n",
    "#     src_vocab.set_default_index(src_vocab['<unk>'])\n",
    "\n",
    "#     return src_vocab, trg_vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data_iterator = [\"This is a sample sentence.\", \"Another example sentence.\"]\n",
    "# def text_iterator():\n",
    "#     for text in data_iterator:\n",
    "#         yield text.split()  # 这里简单地将文本按空格分割为单词\n",
    "        \n",
    "# vocab = build_vocab_from_iterator(text_iterator(), min_freq=1, specials=[\"<unk>\", \"<pad>\"])\n",
    "# # indexed_data = [vocab[token] for token in text_iterator()]\n",
    "# for tokens in text_iterator():\n",
    "#     for token in tokens:\n",
    "#         print(vocab[token])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vocab = build_vocab_from_iterator(_get_tokens(en, 'en'), specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # [vocab[token] for token in _get_tokens(en, 'en')]\n",
    "# for tokens in _get_tokens(en, 'en'):\n",
    "#     for token in tokens:\n",
    "#         print(vocab[token])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_tokenizer(text):\n",
    "    # 在这里实现你的分词逻辑\n",
    "    # 返回标记列表\n",
    "    tokens = text.split()  # 这只是一个示例，将文本按空格分割\n",
    "    return tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class nmtDataset(Dataset):\n",
    "    def __init__(self, src_path, tgt_path,src_name,tgt_name,split, train_ds=None):\n",
    "        self.en = open(src_path, encoding='utf-8').readlines()\n",
    "        self.de = open(tgt_path, encoding='utf-8').readlines()\n",
    "        \n",
    "        self.tokenizers = { src_name: get_tokenizer(custom_tokenizer), \n",
    "                            tgt_name: get_tokenizer(custom_tokenizer) }\n",
    "        self.src_name = src_name\n",
    "        self.tgt_name = tgt_name\n",
    "        if split == 'train':\n",
    "            self.src_vocab, self.trg_vocab = self._build_vocab()\n",
    "        else:\n",
    "            self.src_vocab, self.trg_vocab = train_ds.src_vocab, train_ds.trg_vocab\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.en)\n",
    "\n",
    "    def __getitem__(self, item):\n",
    "        src_tokens = self.tokenizers[self.src_name](self.en[item].lower().strip())\n",
    "        trg_tokens = self.tokenizers[self.tgt_name](self.de[item].lower().strip())\n",
    "  \n",
    "        return {\n",
    "            \"src\": [self.src_vocab['<sos>']] + self.src_vocab.lookup_indices(src_tokens) + [self.src_vocab['<eos>']],\n",
    "            \"trg\": [self.trg_vocab['<sos>']] + self.trg_vocab.lookup_indices(trg_tokens) + [self.trg_vocab['<eos>']]\n",
    "        }\n",
    "        \n",
    "    def _build_vocab(self):\n",
    "        src_vocab = build_vocab_from_iterator(self._get_tokens(self.en, self.src_name), specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=2)  # discarding words occurs < 2 times\n",
    "        trg_vocab = build_vocab_from_iterator(self._get_tokens(self.de, self.tgt_name), specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=2)\n",
    "  \n",
    "        trg_vocab.set_default_index(trg_vocab['<unk>'])\n",
    "        src_vocab.set_default_index(src_vocab['<unk>'])\n",
    "\n",
    "        return src_vocab, trg_vocab\n",
    "    \n",
    "    def _get_tokens(self, corpus, lang):\n",
    "        for line in corpus:\n",
    "            yield self.tokenizers[lang](line.lower().strip())\n",
    "            \n",
    "# data = nmtDataset(src_path='./datasets/EN-DE/train.en',tgt_path='./datasets/EN-DE/train.de',src_name='en',tgt_name='de',split='train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import math\n",
    "import time\n",
    "import gc\n",
    "import copy\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.nn import Transformer\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import torchtext\n",
    "from torchtext.data.metrics import bleu_score\n",
    "\n",
    "# import spacy\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# from dataset import nmtDataset\n",
    "import helpers as utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Seeding for consistency in reproducibility\n",
    "SEED = 1234\n",
    "\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "log = utils.Logger('logs/jp-ch-transformers.out')  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "hyp_params = {\n",
    "    \"batch_size\": 128,\n",
    "    \"lr\": 0.0005,\n",
    "    \"num_epochs\": 10,\n",
    "    \n",
    "    # Same as presented in paper\n",
    "    \"d_model\": 512,\n",
    "    \n",
    "    # No. of multi-head attention block (aka paralle self-attention layers)\n",
    "    # Same as presented in paper\n",
    "    \"n_head\": 8,\n",
    "    \n",
    "    # N in the paper and they used 6 of each\n",
    "    \"num_encoder_layers\": 3,\n",
    "    \"num_decoder_layers\": 3,\n",
    "    \n",
    "    \"feedforward_dim\": 128,\n",
    "    \n",
    "    # Following paper\n",
    "    \"dropout\": 0.1\n",
    "}\n",
    "# data = nmtDataset(src_path='./datasets/EN-DE/train.en',tgt_path='./datasets/EN-DE/train.de',src_name='en',tgt_name='de',split='train')\n",
    "# nmtds_train = nmtDataset(src_path='./datasets/jp-ch/mintrain.jp',tgt_path='./datasets/jp-ch/mintrain.ch',src_name='jp',tgt_name='ch',split='train')\n",
    "# nmtds_valid = nmtDataset(src_path='./datasets/jp-ch/minvalid.jp',tgt_path='./datasets/jp-ch/minvalid.ch',src_name='jp',tgt_name='ch',split='val',train_ds=nmtds_train)\n",
    "# nmtds_test = nmtDataset(src_path='./datasets/jp-ch/minvalid.jp',tgt_path='./datasets/jp-ch/minvalid.ch',src_name='jp',tgt_name='ch',split='val',train_ds=nmtds_train)\n",
    "\n",
    "nmtds_train = nmtDataset(src_path='./datasets/jp-ch/train.jp',tgt_path='./datasets/jp-ch/train.ch',src_name='jp',tgt_name='ch',split='train')\n",
    "nmtds_valid = nmtDataset(src_path='./datasets/jp-ch/valid.jp',tgt_path='./datasets/jp-ch/valid.ch',src_name='jp',tgt_name='ch',split='val',train_ds=nmtds_train)\n",
    "nmtds_test = nmtDataset(src_path='./datasets/jp-ch/valid.jp',tgt_path='./datasets/jp-ch/valid.ch',src_name='jp',tgt_name='ch',split='val',train_ds=nmtds_train)\n",
    "# nmtds_valid = nmtDataset('datasets/Multi30k/', 'val', nmtds_train)\n",
    "# nmtds_test = nmtDataset('datasets/Multi30k/', 'test', nmtds_train)\n",
    "\n",
    "SRC_PAD_IDX = nmtds_train.src_vocab[\"<pad>\"]\n",
    "TRG_PAD_IDX = nmtds_train.trg_vocab[\"<pad>\"]\n",
    "\n",
    "train_dataloader = DataLoader(nmtds_train, batch_size=hyp_params[\"batch_size\"], shuffle=True,\n",
    "                              collate_fn=lambda batch_size: utils.collate_fn(batch_size, SRC_PAD_IDX, device))\n",
    "\n",
    "valid_dataloader = DataLoader(nmtds_valid, batch_size=hyp_params[\"batch_size\"], shuffle=True,\n",
    "                              collate_fn=lambda batch_size: utils.collate_fn(batch_size, SRC_PAD_IDX, device))\n",
    "\n",
    "hyp_params[\"src_vocab_size\"] = len(nmtds_train.src_vocab)\n",
    "hyp_params[\"trg_vocab_size\"] = len(nmtds_train.trg_vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionalEncoding(nn.Module):\n",
    "    def __init__(self, d_model, dropout, maxlen = 5000):\n",
    "        super(PositionalEncoding, self).__init__()\n",
    "        \n",
    "        # A tensor consists of all the possible positions (index) e.g 0, 1, 2, ... max length of input\n",
    "        # Shape (pos) --> [max len, 1]\n",
    "        pos = torch.arange(0, maxlen).unsqueeze(1)\n",
    "        \n",
    "        pos_encoding = torch.zeros((maxlen, d_model))\n",
    "        \n",
    "        # In the paper, they had 2i in the positional encoding formula\n",
    "        # where i is the dimension \n",
    "        sin_den = 10000 ** (torch.arange(0, d_model, 2)/d_model) # sin for even item of position's dimension\n",
    "        cos_den = 10000 ** (torch.arange(1, d_model, 2)/d_model) # cos for odd \n",
    "        \n",
    "        pos_encoding[:, 0::2] = torch.sin(pos / sin_den) \n",
    "        pos_encoding[:, 1::2] = torch.cos(pos / cos_den)\n",
    "        \n",
    "        # Shape (pos_embedding) --> [max len, d_model]\n",
    "        pos_encoding = pos_encoding.unsqueeze(-2)\n",
    "        # Shape (pos_embedding) --> [max len, 1, d_model]\n",
    "\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "        # We want pos_encoding be saved and restored in the `state_dict`, but not trained by the optimizer\n",
    "        # hence registering it!\n",
    "        # Source & credits: https://discuss.pytorch.org/t/what-is-the-difference-between-register-buffer-and-register-parameter-of-nn-module/32723/2\n",
    "        self.register_buffer('pos_encoding', pos_encoding)\n",
    "\n",
    "    def forward(self, token_embedding):\n",
    "        \n",
    "        # shape (token_embedding) --> [sentence len, batch size, d_model]\n",
    "        \n",
    "        # Concatenating embeddings with positional encodings\n",
    "        # Note: As we made positional encoding with the size max length of sentence in our dataset \n",
    "        #       hence here we are picking till the sentence length in a batch\n",
    "        #       Another thing to notice is in the paper they used FIXED positional encoding, there are\n",
    "        #       methods where we can also learn them but we are doing as presented in the paper\n",
    "        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])\n",
    "    \n",
    "class InputEmbedding(nn.Module):\n",
    "    def __init__(self, vocab_size, d_model):\n",
    "        super(InputEmbedding, self).__init__()\n",
    "        \n",
    "        self.embedding = nn.Embedding(vocab_size, d_model)\n",
    "        self.d_model = d_model\n",
    "\n",
    "    def forward(self, tokens):\n",
    "        # shape (tokens) --> [sentence len, batch size]\n",
    "        # shape (inp_emb) --> [sentence len, batch size, d_model]\n",
    "        # Multiplying with square root of d_model as they mentioned in the paper\n",
    "        inp_emb = self.embedding(tokens.long()) * math.sqrt(self.d_model)\n",
    "        return inp_emb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2SeqTransformer(nn.Module):\n",
    "    def __init__(self, \n",
    "                 src_vocab_size, \n",
    "                 trg_vocab_size, \n",
    "                 d_model, \n",
    "                 dropout,\n",
    "                 nhead,\n",
    "                 num_encoder_layers,\n",
    "                 num_decoder_layers,\n",
    "                 dim_feedforward,\n",
    "                 src_pad_idx,\n",
    "                 trg_pad_idx\n",
    "                ):\n",
    "        super(Seq2SeqTransformer, self).__init__()\n",
    "        \n",
    "        self.src_pad_idx = src_pad_idx\n",
    "        self.trg_pad_idx = trg_pad_idx\n",
    "        \n",
    "        self.src_inp_emb = InputEmbedding(src_vocab_size, d_model)\n",
    "        self.trg_inp_emb = InputEmbedding(trg_vocab_size, d_model)\n",
    "        \n",
    "        self.positional_encoding = PositionalEncoding(d_model, dropout=dropout)\n",
    "        \n",
    "        self.transformer = Transformer(d_model=d_model,\n",
    "                                       nhead=nhead,\n",
    "                                       num_encoder_layers=num_encoder_layers,\n",
    "                                       num_decoder_layers=num_decoder_layers,\n",
    "                                       dim_feedforward=dim_feedforward,\n",
    "                                       dropout=dropout)\n",
    "        \n",
    "        self.linear = nn.Linear(d_model, trg_vocab_size)\n",
    "    \n",
    "    def forward(self, src, trg):\n",
    "        src_emb = self.positional_encoding(self.src_inp_emb(src))\n",
    "        trg_emb = self.positional_encoding(self.trg_inp_emb(trg))\n",
    "        \n",
    "        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = self.create_mask(src, trg)\n",
    "        \n",
    "        outs = self.transformer(src = src_emb, \n",
    "                                tgt = trg_emb, \n",
    "                                src_mask = src_mask,\n",
    "                                tgt_mask = tgt_mask, \n",
    "                                src_key_padding_mask = src_padding_mask, \n",
    "                                tgt_key_padding_mask = tgt_padding_mask,\n",
    "                                memory_key_padding_mask = src_padding_mask\n",
    "                               )\n",
    "        return self.linear(outs)\n",
    "        \n",
    "\n",
    "    def create_mask(self, src, trg):\n",
    "        src_seq_len = src.shape[0]\n",
    "        trg_seq_len = trg.shape[0]\n",
    "\n",
    "        # Subsequent mask aka \"look ahead mask\" is important as it wont let Decoder\n",
    "        # to peek into future tokens.\n",
    "        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_len).to(device)\n",
    "        \n",
    "        src_mask = torch.zeros((src_seq_len, src_seq_len),device=device).type(torch.bool) # All False hence unchanged\n",
    "\n",
    "        # Padding masking will allow attention to ignore padding <pad> tokens \n",
    "        src_padding_mask = (src == self.src_pad_idx).transpose(0, 1)\n",
    "        trg_padding_mask = (trg == self.trg_pad_idx).transpose(0, 1)\n",
    "        \n",
    "        return src_mask, trg_mask, src_padding_mask, trg_padding_mask\n",
    "    \n",
    "    # These two functions will only used while inferring\n",
    "    \n",
    "    def encode(self, src):\n",
    "        src_mask = torch.zeros((src.shape[0], src.shape[0]),device=device).type(torch.bool)\n",
    "        src_padding_mask = (src == self.src_pad_idx).transpose(0, 1)\n",
    "        \n",
    "        return self.transformer.encoder(self.positional_encoding(self.src_inp_emb(src)), src_mask, src_padding_mask)\n",
    "\n",
    "    def decode(self, trg, memory):\n",
    "        # memory is the output from the encoder \n",
    "        trg_seq_len = trg.shape[0]\n",
    "        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_len).type(torch.bool).to(device)\n",
    "        trg_padding_mask = (trg == self.trg_pad_idx).transpose(0, 1)\n",
    "        \n",
    "        return self.transformer.decoder(tgt = self.positional_encoding(self.trg_inp_emb(trg)), \n",
    "                                        memory = memory,\n",
    "                                        tgt_mask = trg_mask,\n",
    "                                        tgt_key_padding_mask = trg_padding_mask\n",
    "                                       )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, train_dataloader, criterion, optimizer):\n",
    "    model.train()\n",
    "    epoch_loss = 0\n",
    "    for batch_idx, batch in enumerate(tqdm(train_dataloader)):\n",
    "        # shape (src, trg) --> [seq len, batch size]\n",
    "        src = batch[\"src\"]\n",
    "        trg = batch[\"trg\"]\n",
    "\n",
    "        # Clear the accumulating gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # shape (trg_inp, trg_out) --> [seq len - 1, batch size]\n",
    "        trg_inp = trg[:-1, :]\n",
    "        trg_out = trg[1:, :]\n",
    "\n",
    "        # shape --> (seq len - 1) * batch size \n",
    "        # Making all target seqeunces in 1d tensor\n",
    "        trg_out = trg_out.reshape(-1)\n",
    "\n",
    "        # shape (logits) --> [seq len - 1, batch size, trg vocab size]\n",
    "        logits = model(src, trg_inp)\n",
    "\n",
    "        # shape (logits) --> [(seq len - 1) * batch size, trg vocab size]\n",
    "        logits = logits.reshape(-1, logits.shape[-1])\n",
    "\n",
    "        loss = criterion(logits, trg_out)\n",
    "\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        epoch_loss += loss.detach().cpu()\n",
    "\n",
    "    return epoch_loss/len(train_dataloader)\n",
    "\n",
    "def evaluate_model(model, valid_dataloader, criterion):\n",
    "    model.eval()\n",
    "    epoch_loss = 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, batch in enumerate(valid_dataloader):\n",
    "            # shape (src, trg) --> [seq len, batch size]\n",
    "            src = batch[\"src\"]\n",
    "            trg = batch[\"trg\"]\n",
    "\n",
    "            # shape (trg_inp, trg_out) --> [seq len - 1, batch size]\n",
    "            trg_inp = trg[:-1, :]\n",
    "            trg_out = trg[1:, :]\n",
    "\n",
    "            # shape --> (seq len - 1) * batch size \n",
    "            # Making all target seqeunces in 1d tensor\n",
    "            trg_out = trg_out.reshape(-1)\n",
    "\n",
    "            # shape (logits) --> [seq len - 1, batch size, trg vocab size]\n",
    "            logits = model(src, trg_inp)\n",
    "\n",
    "            # shape (logits) --> [(seq len - 1) * batch size, trg vocab size]\n",
    "            logits = logits.reshape(-1, logits.shape[-1])\n",
    "\n",
    "            loss = criterion(logits, trg_out)\n",
    "\n",
    "            epoch_loss += loss.detach().cpu()\n",
    "    \n",
    "    return epoch_loss/len(valid_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Seq2SeqTransformer(hyp_params[\"src_vocab_size\"],\n",
    "                                hyp_params[\"trg_vocab_size\"],\n",
    "                                hyp_params[\"d_model\"],\n",
    "                                hyp_params[\"dropout\"],\n",
    "                                hyp_params[\"n_head\"],\n",
    "                                hyp_params[\"num_encoder_layers\"],\n",
    "                                hyp_params[\"num_decoder_layers\"],\n",
    "                                hyp_params[\"feedforward_dim\"],\n",
    "                                SRC_PAD_IDX,\n",
    "                                TRG_PAD_IDX\n",
    "                                ).to(device)\n",
    "\n",
    "# They did not mention it in paper however tutorials from \n",
    "# bentrevett and others uses it so -_(-.-)_-\n",
    "for p in model.parameters():\n",
    "    if p.dim() > 1:\n",
    "        nn.init.xavier_uniform_(p)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX).to(device)\n",
    "\n",
    "# They did not used fixed learning rate in the paper infact \n",
    "# their optimizer would look like \n",
    "#   optimizer = torch.optim.Adam(transformer.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)\n",
    "# with variable learning rate as they presented in the paper\n",
    "# but for the sake of simplicity and also fixed lr works fine\n",
    "optimizer = optim.Adam(model.parameters(), lr=hyp_params[\"lr\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def translate(snt, dataset, model, device):\n",
    "    snt = torch.tensor(snt).view(-1,1).to(device)\n",
    "    \n",
    "    num_tokens = snt.shape[0]\n",
    "    max_len = 50\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        memory = model.encode(snt).to(device)\n",
    "    \n",
    "    ys = torch.LongTensor([dataset.trg_vocab['<sos>']]).unsqueeze(0).to(device)\n",
    "\n",
    "    for i in range(max_len):\n",
    "        with torch.no_grad():\n",
    "            out = model.decode(ys, memory)\n",
    "            \n",
    "        out = out.transpose(0, 1)\n",
    "        prob = model.linear(out[:, -1])\n",
    "        next_word = prob.argmax().detach().item()\n",
    "\n",
    "        ys = torch.cat([ys, torch.tensor([next_word]).unsqueeze(1).to(device)])\n",
    "        \n",
    "        \n",
    "        if next_word == dataset.trg_vocab['<eos>']:\n",
    "            break\n",
    "    \n",
    "    return dataset.trg_vocab.lookup_tokens(ys.squeeze().cpu().numpy())\n",
    "\n",
    "\n",
    "def bleu(model, dataset, device):\n",
    "    targets = []\n",
    "    outputs = []\n",
    "\n",
    "    for example in tqdm(dataset):\n",
    "        src = example[\"src\"]\n",
    "        trg = example[\"trg\"]\n",
    "        \n",
    "        trg = dataset.trg_vocab.lookup_tokens(trg)    \n",
    "        prediction = translate(src, dataset, model, device)\n",
    "        \n",
    "        prediction = prediction[1:-1]  # removing <sos> and <eos> tokens\n",
    "        trg = trg[1:-1]\n",
    "        \n",
    "        targets.append([trg])\n",
    "        outputs.append(prediction)\n",
    "\n",
    "    return bleu_score(outputs, targets)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1329 [00:00<?, ?it/s]d:\\软件\\python\\lib\\site-packages\\torch\\nn\\functional.py:4999: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
      "  warnings.warn(\n",
      "100%|██████████| 1329/1329 [01:30<00:00, 14.69it/s]\n",
      "100%|██████████| 30000/30000 [15:42<00:00, 31.82it/s]\n",
      " 87%|████████▋ | 147297/170000 [12:20:42<12:37, 29.97it/s]        "
     ]
    }
   ],
   "source": [
    "min_el = math.inf\n",
    "patience = 1\n",
    "best_model = {}\n",
    "best_epoch = 0\n",
    "\n",
    "epoch_loss = 0\n",
    "for epoch in range(hyp_params[\"num_epochs\"]):\n",
    "    start = time.time()\n",
    "    \n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    \n",
    "    epoch_loss = train_model(model, train_dataloader, criterion, optimizer)\n",
    "    eval_loss = evaluate_model(model, valid_dataloader, criterion)\n",
    "    \n",
    "    # model_1=model.to('cpu')\n",
    "    \n",
    "    vaild_bleu = bleu(model, nmtds_valid, device)\n",
    "    train_bleu = bleu(model, nmtds_train, device)\n",
    "    \n",
    "    log.log(f\"Epoch: {epoch+1}, Train loss: {epoch_loss}, Eval loss: {eval_loss}, patience: {patience}, Train Bleu: {train_bleu},Vaild Bleu: {vaild_bleu}. Time {time.time() - start}\")\n",
    "\n",
    "    \n",
    "    if eval_loss < min_el:\n",
    "        best_epoch = epoch+1\n",
    "        min_el = eval_loss\n",
    "        best_model = copy.deepcopy(model)\n",
    "        torch.save({\n",
    "            'model_state_dict': model.state_dict(),\n",
    "            'optimizer_state_dict': optimizer.state_dict(),\n",
    "            'eval_loss': min_el\n",
    "        }, 'model-transformer.pt')\n",
    "        patience = 1\n",
    "    else:\n",
    "        patience += 1\n",
    "    \n",
    "    if patience == 90:\n",
    "        log.log(\"[STOPPING] Early stopping in action..\")\n",
    "        log.log(f\"Best epoch was {best_epoch} with {min_el} eval loss\")\n",
    "        break\n",
    "        \n",
    "log.log(f\"Best epoch was {best_epoch} with {min_el} eval loss\")\n",
    "log.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
