{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "from collections import Counter\n",
    "from torchtext.vocab import build_vocab_from_iterator\n",
    "from torchtext.data.utils import get_tokenizer\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_punctuation_regex(text):\n",
    "    # 使用正则表达式去除标点符号\n",
    "    cleaned_text = re.sub(r'[^\\w\\s]', '', text)\n",
    "    return cleaned_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # get_tokenizer\n",
    "\n",
    "# en = open('./datasets/EN-DE/train.en', encoding='utf-8').readlines()\n",
    "# de = open('./datasets/EN-DE/train.de', encoding='utf-8').readlines()\n",
    "\n",
    "# def custom_tokenizer(text):\n",
    "#     # 在这里实现你的分词逻辑\n",
    "#     # 返回标记列表\n",
    "#     tokens = text.split()  # 这只是一个示例，将文本按空格分割\n",
    "#     return tokens\n",
    "# tokenizers = { 'en': get_tokenizer(custom_tokenizer), \n",
    "#                     'de': get_tokenizer(custom_tokenizer) }\n",
    "# def _get_tokens(corpus, lang):\n",
    "#     for line in corpus:\n",
    "#         yield tokenizers[lang](remove_punctuation_regex(line.lower().strip()))\n",
    "        \n",
    "# def _build_vocab():\n",
    "#     src_vocab = build_vocab_from_iterator(_get_tokens(en, 'en'), specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=2)  # discarding words occurs < 2 times\n",
    "#     trg_vocab = build_vocab_from_iterator(_get_tokens(de, 'de'), specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=2)\n",
    "  \n",
    "#     trg_vocab.set_default_index(trg_vocab['<unk>'])\n",
    "#     src_vocab.set_default_index(src_vocab['<unk>'])\n",
    "\n",
    "#     return src_vocab, trg_vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data_iterator = [\"This is a sample sentence.\", \"Another example sentence.\"]\n",
    "# def text_iterator():\n",
    "#     for text in data_iterator:\n",
    "#         yield text.split()  # 这里简单地将文本按空格分割为单词\n",
    "        \n",
    "# vocab = build_vocab_from_iterator(text_iterator(), min_freq=1, specials=[\"<unk>\", \"<pad>\"])\n",
    "# # indexed_data = [vocab[token] for token in text_iterator()]\n",
    "# for tokens in text_iterator():\n",
    "#     for token in tokens:\n",
    "#         print(vocab[token])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vocab = build_vocab_from_iterator(_get_tokens(en, 'en'), specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # [vocab[token] for token in _get_tokens(en, 'en')]\n",
    "# for tokens in _get_tokens(en, 'en'):\n",
    "#     for token in tokens:\n",
    "#         print(vocab[token])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_tokenizer(text):\n",
    "    # 在这里实现你的分词逻辑\n",
    "    # 返回标记列表\n",
    "    tokens = text.split()  # 这只是一个示例，将文本按空格分割\n",
    "    return tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class nmtDataset(Dataset):\n",
    "    def __init__(self, src_path, tgt_path,src_name,tgt_name,split, train_ds=None):\n",
    "        self.en = open(src_path, encoding='utf-8').readlines()\n",
    "        self.de = open(tgt_path, encoding='utf-8').readlines()\n",
    "        \n",
    "        self.tokenizers = { src_name: get_tokenizer(custom_tokenizer), \n",
    "                            tgt_name: get_tokenizer(custom_tokenizer) }\n",
    "        self.src_name = src_name\n",
    "        self.tgt_name = tgt_name\n",
    "        if split == 'train':\n",
    "            self.src_vocab, self.trg_vocab = self._build_vocab()\n",
    "        else:\n",
    "            self.src_vocab, self.trg_vocab = train_ds.src_vocab, train_ds.trg_vocab\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.en)\n",
    "\n",
    "    def __getitem__(self, item):\n",
    "        src_tokens = self.tokenizers[self.src_name](self.en[item].lower().strip())\n",
    "        trg_tokens = self.tokenizers[self.tgt_name](self.de[item].lower().strip())\n",
    "  \n",
    "        return {\n",
    "            \"src\": [self.src_vocab['<sos>']] + self.src_vocab.lookup_indices(src_tokens) + [self.src_vocab['<eos>']],\n",
    "            \"trg\": [self.trg_vocab['<sos>']] + self.trg_vocab.lookup_indices(trg_tokens) + [self.trg_vocab['<eos>']]\n",
    "        }\n",
    "        \n",
    "    def _build_vocab(self):\n",
    "        src_vocab = build_vocab_from_iterator(self._get_tokens(self.en, self.src_name), specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=2)  # discarding words occurs < 2 times\n",
    "        trg_vocab = build_vocab_from_iterator(self._get_tokens(self.de, self.tgt_name), specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=2)\n",
    "  \n",
    "        trg_vocab.set_default_index(trg_vocab['<unk>'])\n",
    "        src_vocab.set_default_index(src_vocab['<unk>'])\n",
    "\n",
    "        return src_vocab, trg_vocab\n",
    "    \n",
    "    def _get_tokens(self, corpus, lang):\n",
    "        for line in corpus:\n",
    "            yield self.tokenizers[lang](line.lower().strip())\n",
    "            \n",
    "# data = nmtDataset(src_path='./datasets/EN-DE/train.en',tgt_path='./datasets/EN-DE/train.de',src_name='en',tgt_name='de',split='train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import math\n",
    "import time\n",
    "import gc\n",
    "import copy\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.nn import Transformer\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import torchtext\n",
    "from torchtext.data.metrics import bleu_score\n",
    "\n",
    "# import spacy\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# from dataset import nmtDataset\n",
    "import helpers as utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Seeding for consistency in reproducibility\n",
    "SEED = 1234\n",
    "\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logger from the local `helpers` module (imported as `utils`), given a path\n",
    "# under logs/; the training loop below calls log.log(...) and log.close() on it.\n",
    "log = utils.Logger('logs/1-transformers.out')  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "hyp_params = {\n",
    "    \"batch_size\": 128,\n",
    "    \"lr\": 0.0005,\n",
    "    \"num_epochs\": 120,\n",
    "    \n",
    "    # Same as presented in paper\n",
    "    \"d_model\": 512,\n",
    "    \n",
    "    # Number of heads in each multi-head attention block (aka parallel self-attention layers)\n",
    "    # Same as presented in paper\n",
    "    \"n_head\": 8,\n",
    "    \n",
    "    # N in the paper, where 6 of each were used; 3 of each are used here\n",
    "    \"num_encoder_layers\": 3,\n",
    "    \"num_decoder_layers\": 3,\n",
    "    \n",
    "    \"feedforward_dim\": 128,\n",
    "    \n",
    "    # Following paper\n",
    "    \"dropout\": 0.1\n",
    "}\n",
    "# data = nmtDataset(src_path='./datasets/EN-DE/train.en',tgt_path='./datasets/EN-DE/train.de',src_name='en',tgt_name='de',split='train')\n",
    "nmtds_train = nmtDataset(src_path='./datasets/en-es/train_sub.es',tgt_path='./datasets/en-es/train_sub.en',src_name='es',tgt_name='en',split='train')\n",
    "nmtds_valid = nmtDataset(src_path='./datasets/en-es/valid_sub.es',tgt_path='./datasets/en-es/valid_sub.en',src_name='es',tgt_name='en',split='val',train_ds=nmtds_train)\n",
    "# NOTE: the \"test\" set below is loaded from the same files as the validation set.\n",
    "nmtds_test = nmtDataset(src_path='./datasets/en-es/valid_sub.es',tgt_path='./datasets/en-es/valid_sub.en',src_name='es',tgt_name='en',split='val',train_ds=nmtds_train)\n",
    "# nmtds_valid = nmtDataset('datasets/Multi30k/', 'val', nmtds_train)\n",
    "# nmtds_test = nmtDataset('datasets/Multi30k/', 'test', nmtds_train)\n",
    "\n",
    "SRC_PAD_IDX = nmtds_train.src_vocab[\"<pad>\"]\n",
    "TRG_PAD_IDX = nmtds_train.trg_vocab[\"<pad>\"]\n",
    "\n",
    "# The collate function receives a list of examples (a batch) from the DataLoader.\n",
    "# NOTE(review): only SRC_PAD_IDX is handed to utils.collate_fn; this presumably relies on\n",
    "# '<pad>' having the same index in both vocabularies (same specials list) - confirm.\n",
    "train_dataloader = DataLoader(nmtds_train, batch_size=hyp_params[\"batch_size\"], shuffle=True,\n",
    "                              collate_fn=lambda batch: utils.collate_fn(batch, SRC_PAD_IDX, device))\n",
    "\n",
    "# Validation data is only scored, never trained on, so it is not shuffled:\n",
    "# a fixed order keeps the batch composition identical from epoch to epoch.\n",
    "valid_dataloader = DataLoader(nmtds_valid, batch_size=hyp_params[\"batch_size\"], shuffle=False,\n",
    "                              collate_fn=lambda batch: utils.collate_fn(batch, SRC_PAD_IDX, device))\n",
    "\n",
    "hyp_params[\"src_vocab_size\"] = len(nmtds_train.src_vocab)\n",
    "hyp_params[\"trg_vocab_size\"] = len(nmtds_train.trg_vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionalEncoding(nn.Module):\n",
    "    def __init__(self, d_model, dropout, maxlen = 5000):\n",
    "        super(PositionalEncoding, self).__init__()\n",
    "        \n",
    "        # A tensor consists of all the possible positions (index) e.g 0, 1, 2, ... max length of input\n",
    "        # Shape (pos) --> [max len, 1]\n",
    "        pos = torch.arange(0, maxlen).unsqueeze(1)\n",
    "        \n",
    "        pos_encoding = torch.zeros((maxlen, d_model))\n",
    "        \n",
    "        # In the paper, they had 2i in the positional encoding formula\n",
    "        # where i is the dimension \n",
    "        sin_den = 10000 ** (torch.arange(0, d_model, 2)/d_model) # sin for even item of position's dimension\n",
    "        cos_den = 10000 ** (torch.arange(1, d_model, 2)/d_model) # cos for odd \n",
    "        \n",
    "        pos_encoding[:, 0::2] = torch.sin(pos / sin_den) \n",
    "        pos_encoding[:, 1::2] = torch.cos(pos / cos_den)\n",
    "        \n",
    "        # Shape (pos_embedding) --> [max len, d_model]\n",
    "        pos_encoding = pos_encoding.unsqueeze(-2)\n",
    "        # Shape (pos_embedding) --> [max len, 1, d_model]\n",
    "\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "        # We want pos_encoding be saved and restored in the `state_dict`, but not trained by the optimizer\n",
    "        # hence registering it!\n",
    "        # Source & credits: https://discuss.pytorch.org/t/what-is-the-difference-between-register-buffer-and-register-parameter-of-nn-module/32723/2\n",
    "        self.register_buffer('pos_encoding', pos_encoding)\n",
    "\n",
    "    def forward(self, token_embedding):\n",
    "        \n",
    "        # shape (token_embedding) --> [sentence len, batch size, d_model]\n",
    "        \n",
    "        # Concatenating embeddings with positional encodings\n",
    "        # Note: As we made positional encoding with the size max length of sentence in our dataset \n",
    "        #       hence here we are picking till the sentence length in a batch\n",
    "        #       Another thing to notice is in the paper they used FIXED positional encoding, there are\n",
    "        #       methods where we can also learn them but we are doing as presented in the paper\n",
    "        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])\n",
    "    \n",
    "class InputEmbedding(nn.Module):\n",
    "    def __init__(self, vocab_size, d_model):\n",
    "        super(InputEmbedding, self).__init__()\n",
    "        \n",
    "        self.embedding = nn.Embedding(vocab_size, d_model)\n",
    "        self.d_model = d_model\n",
    "\n",
    "    def forward(self, tokens):\n",
    "        # shape (tokens) --> [sentence len, batch size]\n",
    "        # shape (inp_emb) --> [sentence len, batch size, d_model]\n",
    "        # Multiplying with square root of d_model as they mentioned in the paper\n",
    "        inp_emb = self.embedding(tokens.long()) * math.sqrt(self.d_model)\n",
    "        return inp_emb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2SeqTransformer(nn.Module):\n",
    "    def __init__(self, \n",
    "                 src_vocab_size, \n",
    "                 trg_vocab_size, \n",
    "                 d_model, \n",
    "                 dropout,\n",
    "                 nhead,\n",
    "                 num_encoder_layers,\n",
    "                 num_decoder_layers,\n",
    "                 dim_feedforward,\n",
    "                 src_pad_idx,\n",
    "                 trg_pad_idx\n",
    "                ):\n",
    "        super(Seq2SeqTransformer, self).__init__()\n",
    "        \n",
    "        self.src_pad_idx = src_pad_idx\n",
    "        self.trg_pad_idx = trg_pad_idx\n",
    "        \n",
    "        self.src_inp_emb = InputEmbedding(src_vocab_size, d_model)\n",
    "        self.trg_inp_emb = InputEmbedding(trg_vocab_size, d_model)\n",
    "        \n",
    "        self.positional_encoding = PositionalEncoding(d_model, dropout=dropout)\n",
    "        \n",
    "        self.transformer = Transformer(d_model=d_model,\n",
    "                                       nhead=nhead,\n",
    "                                       num_encoder_layers=num_encoder_layers,\n",
    "                                       num_decoder_layers=num_decoder_layers,\n",
    "                                       dim_feedforward=dim_feedforward,\n",
    "                                       dropout=dropout)\n",
    "        \n",
    "        self.linear = nn.Linear(d_model, trg_vocab_size)\n",
    "    \n",
    "    def forward(self, src, trg):\n",
    "        src_emb = self.positional_encoding(self.src_inp_emb(src))\n",
    "        trg_emb = self.positional_encoding(self.trg_inp_emb(trg))\n",
    "        \n",
    "        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = self.create_mask(src, trg)\n",
    "        \n",
    "        outs = self.transformer(src = src_emb, \n",
    "                                tgt = trg_emb, \n",
    "                                src_mask = src_mask,\n",
    "                                tgt_mask = tgt_mask, \n",
    "                                src_key_padding_mask = src_padding_mask, \n",
    "                                tgt_key_padding_mask = tgt_padding_mask,\n",
    "                                memory_key_padding_mask = src_padding_mask\n",
    "                               )\n",
    "        return self.linear(outs)\n",
    "        \n",
    "\n",
    "    def create_mask(self, src, trg):\n",
    "        src_seq_len = src.shape[0]\n",
    "        trg_seq_len = trg.shape[0]\n",
    "\n",
    "        # Subsequent mask aka \"look ahead mask\" is important as it wont let Decoder\n",
    "        # to peek into future tokens.\n",
    "        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_len).to(device)\n",
    "        \n",
    "        src_mask = torch.zeros((src_seq_len, src_seq_len),device=device).type(torch.bool) # All False hence unchanged\n",
    "\n",
    "        # Padding masking will allow attention to ignore padding <pad> tokens \n",
    "        src_padding_mask = (src == self.src_pad_idx).transpose(0, 1)\n",
    "        trg_padding_mask = (trg == self.trg_pad_idx).transpose(0, 1)\n",
    "        \n",
    "        return src_mask, trg_mask, src_padding_mask, trg_padding_mask\n",
    "    \n",
    "    # These two functions will only used while inferring\n",
    "    \n",
    "    def encode(self, src):\n",
    "        src_mask = torch.zeros((src.shape[0], src.shape[0]),device=device).type(torch.bool)\n",
    "        src_padding_mask = (src == self.src_pad_idx).transpose(0, 1)\n",
    "        \n",
    "        return self.transformer.encoder(self.positional_encoding(self.src_inp_emb(src)), src_mask, src_padding_mask)\n",
    "\n",
    "    def decode(self, trg, memory):\n",
    "        # memory is the output from the encoder \n",
    "        trg_seq_len = trg.shape[0]\n",
    "        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_len).type(torch.bool).to(device)\n",
    "        trg_padding_mask = (trg == self.trg_pad_idx).transpose(0, 1)\n",
    "        \n",
    "        return self.transformer.decoder(tgt = self.positional_encoding(self.trg_inp_emb(trg)), \n",
    "                                        memory = memory,\n",
    "                                        tgt_mask = trg_mask,\n",
    "                                        tgt_key_padding_mask = trg_padding_mask\n",
    "                                       )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, train_dataloader, criterion, optimizer):\n",
    "    model.train()\n",
    "    epoch_loss = 0\n",
    "    for batch_idx, batch in enumerate(tqdm(train_dataloader)):\n",
    "        # shape (src, trg) --> [seq len, batch size]\n",
    "        src = batch[\"src\"]\n",
    "        trg = batch[\"trg\"]\n",
    "\n",
    "        # Clear the accumulating gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # shape (trg_inp, trg_out) --> [seq len - 1, batch size]\n",
    "        trg_inp = trg[:-1, :]\n",
    "        trg_out = trg[1:, :]\n",
    "\n",
    "        # shape --> (seq len - 1) * batch size \n",
    "        # Making all target seqeunces in 1d tensor\n",
    "        trg_out = trg_out.reshape(-1)\n",
    "\n",
    "        # shape (logits) --> [seq len - 1, batch size, trg vocab size]\n",
    "        logits = model(src, trg_inp)\n",
    "\n",
    "        # shape (logits) --> [(seq len - 1) * batch size, trg vocab size]\n",
    "        logits = logits.reshape(-1, logits.shape[-1])\n",
    "\n",
    "        loss = criterion(logits, trg_out)\n",
    "\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        epoch_loss += loss.detach().cpu()\n",
    "\n",
    "    return epoch_loss/len(train_dataloader)\n",
    "\n",
    "def evaluate_model(model, valid_dataloader, criterion):\n",
    "    model.eval()\n",
    "    epoch_loss = 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, batch in enumerate(valid_dataloader):\n",
    "            # shape (src, trg) --> [seq len, batch size]\n",
    "            src = batch[\"src\"]\n",
    "            trg = batch[\"trg\"]\n",
    "\n",
    "            # shape (trg_inp, trg_out) --> [seq len - 1, batch size]\n",
    "            trg_inp = trg[:-1, :]\n",
    "            trg_out = trg[1:, :]\n",
    "\n",
    "            # shape --> (seq len - 1) * batch size \n",
    "            # Making all target seqeunces in 1d tensor\n",
    "            trg_out = trg_out.reshape(-1)\n",
    "\n",
    "            # shape (logits) --> [seq len - 1, batch size, trg vocab size]\n",
    "            logits = model(src, trg_inp)\n",
    "\n",
    "            # shape (logits) --> [(seq len - 1) * batch size, trg vocab size]\n",
    "            logits = logits.reshape(-1, logits.shape[-1])\n",
    "\n",
    "            loss = criterion(logits, trg_out)\n",
    "\n",
    "            epoch_loss += loss.detach().cpu()\n",
    "    \n",
    "    return epoch_loss/len(valid_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[17], line 17\u001b[0m\n\u001b[0;32m     15\u001b[0m \u001b[39mfor\u001b[39;00m p \u001b[39min\u001b[39;00m model\u001b[39m.\u001b[39mparameters():\n\u001b[0;32m     16\u001b[0m     \u001b[39mif\u001b[39;00m p\u001b[39m.\u001b[39mdim() \u001b[39m>\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[1;32m---> 17\u001b[0m         nn\u001b[39m.\u001b[39;49minit\u001b[39m.\u001b[39;49mxavier_uniform_(p)\n\u001b[0;32m     18\u001b[0m model \u001b[39m=\u001b[39m model\u001b[39m.\u001b[39mto(device)\n\u001b[0;32m     19\u001b[0m criterion \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mCrossEntropyLoss(ignore_index\u001b[39m=\u001b[39mTRG_PAD_IDX)\u001b[39m.\u001b[39mto(device)\n",
      "File \u001b[1;32md:\\软件\\python\\lib\\site-packages\\torch\\nn\\init.py:327\u001b[0m, in \u001b[0;36mxavier_uniform_\u001b[1;34m(tensor, gain)\u001b[0m\n\u001b[0;32m    324\u001b[0m std \u001b[39m=\u001b[39m gain \u001b[39m*\u001b[39m math\u001b[39m.\u001b[39msqrt(\u001b[39m2.0\u001b[39m \u001b[39m/\u001b[39m \u001b[39mfloat\u001b[39m(fan_in \u001b[39m+\u001b[39m fan_out))\n\u001b[0;32m    325\u001b[0m a \u001b[39m=\u001b[39m math\u001b[39m.\u001b[39msqrt(\u001b[39m3.0\u001b[39m) \u001b[39m*\u001b[39m std  \u001b[39m# Calculate uniform bounds from standard deviation\u001b[39;00m\n\u001b[1;32m--> 327\u001b[0m \u001b[39mreturn\u001b[39;00m _no_grad_uniform_(tensor, \u001b[39m-\u001b[39;49ma, a)\n",
      "File \u001b[1;32md:\\软件\\python\\lib\\site-packages\\torch\\nn\\init.py:14\u001b[0m, in \u001b[0;36m_no_grad_uniform_\u001b[1;34m(tensor, a, b)\u001b[0m\n\u001b[0;32m     12\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_no_grad_uniform_\u001b[39m(tensor, a, b):\n\u001b[0;32m     13\u001b[0m     \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n\u001b[1;32m---> 14\u001b[0m         \u001b[39mreturn\u001b[39;00m tensor\u001b[39m.\u001b[39;49muniform_(a, b)\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "model = Seq2SeqTransformer(\n",
    "    src_vocab_size=hyp_params[\"src_vocab_size\"],\n",
    "    trg_vocab_size=hyp_params[\"trg_vocab_size\"],\n",
    "    d_model=hyp_params[\"d_model\"],\n",
    "    dropout=hyp_params[\"dropout\"],\n",
    "    nhead=hyp_params[\"n_head\"],\n",
    "    num_encoder_layers=hyp_params[\"num_encoder_layers\"],\n",
    "    num_decoder_layers=hyp_params[\"num_decoder_layers\"],\n",
    "    dim_feedforward=hyp_params[\"feedforward_dim\"],\n",
    "    src_pad_idx=SRC_PAD_IDX,\n",
    "    trg_pad_idx=TRG_PAD_IDX,\n",
    ").to(device)\n",
    "\n",
    "# Xavier-uniform initialisation of every parameter with more than one dimension.\n",
    "# Per the original author's note it is not mentioned in the paper, but tutorials\n",
    "# (bentrevett's and others) use it. The model is already on `device` at this point.\n",
    "for param in model.parameters():\n",
    "    if param.dim() > 1:\n",
    "        nn.init.xavier_uniform_(param)\n",
    "\n",
    "# Positions whose target is <pad> are excluded from the loss\n",
    "criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX).to(device)\n",
    "\n",
    "# Per the original author's note, the paper does not use a fixed learning rate;\n",
    "# its optimizer would look like\n",
    "#   optimizer = torch.optim.Adam(transformer.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)\n",
    "# combined with a variable learning rate. A fixed learning rate is used here\n",
    "# for the sake of simplicity.\n",
    "optimizer = optim.Adam(model.parameters(), lr=hyp_params[\"lr\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Seq2SeqTransformer(\n",
       "  (src_inp_emb): InputEmbedding(\n",
       "    (embedding): Embedding(50526, 512)\n",
       "  )\n",
       "  (trg_inp_emb): InputEmbedding(\n",
       "    (embedding): Embedding(36418, 512)\n",
       "  )\n",
       "  (positional_encoding): PositionalEncoding(\n",
       "    (dropout): Dropout(p=0.1, inplace=False)\n",
       "  )\n",
       "  (transformer): Transformer(\n",
       "    (encoder): TransformerEncoder(\n",
       "      (layers): ModuleList(\n",
       "        (0-2): 3 x TransformerEncoderLayer(\n",
       "          (self_attn): MultiheadAttention(\n",
       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)\n",
       "          )\n",
       "          (linear1): Linear(in_features=512, out_features=128, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "          (linear2): Linear(in_features=128, out_features=512, bias=True)\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (dropout1): Dropout(p=0.1, inplace=False)\n",
       "          (dropout2): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (decoder): TransformerDecoder(\n",
       "      (layers): ModuleList(\n",
       "        (0-2): 3 x TransformerDecoderLayer(\n",
       "          (self_attn): MultiheadAttention(\n",
       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)\n",
       "          )\n",
       "          (multihead_attn): MultiheadAttention(\n",
       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)\n",
       "          )\n",
       "          (linear1): Linear(in_features=512, out_features=128, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "          (linear2): Linear(in_features=128, out_features=512, bias=True)\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (dropout1): Dropout(p=0.1, inplace=False)\n",
       "          (dropout2): Dropout(p=0.1, inplace=False)\n",
       "          (dropout3): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "  )\n",
       "  (linear): Linear(in_features=512, out_features=36418, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def translate(snt, dataset, model, device):\n",
    "    snt = torch.tensor(snt).view(-1,1).to(device)\n",
    "    \n",
    "    num_tokens = snt.shape[0]\n",
    "    max_len = 50\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        memory = model.encode(snt).to(device)\n",
    "    \n",
    "    ys = torch.LongTensor([dataset.trg_vocab['<sos>']]).unsqueeze(0).to(device)\n",
    "\n",
    "    for i in range(max_len):\n",
    "        with torch.no_grad():\n",
    "            out = model.decode(ys, memory)\n",
    "            \n",
    "        out = out.transpose(0, 1)\n",
    "        prob = model.linear(out[:, -1])\n",
    "        next_word = prob.argmax().detach().item()\n",
    "\n",
    "        ys = torch.cat([ys, torch.tensor([next_word]).unsqueeze(1).to(device)])\n",
    "        \n",
    "        \n",
    "        if next_word == dataset.trg_vocab['<eos>']:\n",
    "            break\n",
    "    \n",
    "    return dataset.trg_vocab.lookup_tokens(ys.squeeze().cpu().numpy())\n",
    "\n",
    "\n",
    "def bleu(model, dataset, device):\n",
    "    targets = []\n",
    "    outputs = []\n",
    "\n",
    "    for example in tqdm(dataset):\n",
    "        src = example[\"src\"]\n",
    "        trg = example[\"trg\"]\n",
    "        \n",
    "        trg = dataset.trg_vocab.lookup_tokens(trg)    \n",
    "        prediction = translate(src, dataset, model, device)\n",
    "        \n",
    "        prediction = prediction[1:-1]  # removing <sos> and <eos> tokens\n",
    "        trg = trg[1:-1]\n",
    "        \n",
    "        targets.append([trg])\n",
    "        outputs.append(prediction)\n",
    "\n",
    "    return bleu_score(outputs, targets)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/968 [00:00<?, ?it/s]d:\\软件\\python\\lib\\site-packages\\torch\\nn\\functional.py:4999: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
      "  warnings.warn(\n",
      " 10%|▉         | 92/968 [05:08<48:54,  3.35s/it]  \n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[20], line 13\u001b[0m\n\u001b[0;32m     10\u001b[0m gc\u001b[39m.\u001b[39mcollect()\n\u001b[0;32m     11\u001b[0m torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mempty_cache()\n\u001b[1;32m---> 13\u001b[0m epoch_loss \u001b[39m=\u001b[39m train_model(model, train_dataloader, criterion, optimizer)\n\u001b[0;32m     14\u001b[0m eval_loss \u001b[39m=\u001b[39m evaluate_model(model, valid_dataloader, criterion)\n\u001b[0;32m     15\u001b[0m vaild_bleu \u001b[39m=\u001b[39m bleu(model, nmtds_valid, device)\n",
      "Cell \u001b[1;32mIn[16], line 28\u001b[0m, in \u001b[0;36mtrain_model\u001b[1;34m(model, train_dataloader, criterion, optimizer)\u001b[0m\n\u001b[0;32m     24\u001b[0m logits \u001b[39m=\u001b[39m logits\u001b[39m.\u001b[39mreshape(\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, logits\u001b[39m.\u001b[39mshape[\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m])\n\u001b[0;32m     26\u001b[0m loss \u001b[39m=\u001b[39m criterion(logits, trg_out)\n\u001b[1;32m---> 28\u001b[0m loss\u001b[39m.\u001b[39;49mbackward()\n\u001b[0;32m     30\u001b[0m optimizer\u001b[39m.\u001b[39mstep()\n\u001b[0;32m     31\u001b[0m epoch_loss \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m loss\u001b[39m.\u001b[39mdetach()\u001b[39m.\u001b[39mcpu()\n",
      "File \u001b[1;32md:\\软件\\python\\lib\\site-packages\\torch\\_tensor.py:487\u001b[0m, in \u001b[0;36mTensor.backward\u001b[1;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[0;32m    477\u001b[0m \u001b[39mif\u001b[39;00m has_torch_function_unary(\u001b[39mself\u001b[39m):\n\u001b[0;32m    478\u001b[0m     \u001b[39mreturn\u001b[39;00m handle_torch_function(\n\u001b[0;32m    479\u001b[0m         Tensor\u001b[39m.\u001b[39mbackward,\n\u001b[0;32m    480\u001b[0m         (\u001b[39mself\u001b[39m,),\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    485\u001b[0m         inputs\u001b[39m=\u001b[39minputs,\n\u001b[0;32m    486\u001b[0m     )\n\u001b[1;32m--> 487\u001b[0m torch\u001b[39m.\u001b[39;49mautograd\u001b[39m.\u001b[39;49mbackward(\n\u001b[0;32m    488\u001b[0m     \u001b[39mself\u001b[39;49m, gradient, retain_graph, create_graph, inputs\u001b[39m=\u001b[39;49minputs\n\u001b[0;32m    489\u001b[0m )\n",
      "File \u001b[1;32md:\\软件\\python\\lib\\site-packages\\torch\\autograd\\__init__.py:200\u001b[0m, in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[0;32m    195\u001b[0m     retain_graph \u001b[39m=\u001b[39m create_graph\n\u001b[0;32m    197\u001b[0m \u001b[39m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[0;32m    198\u001b[0m \u001b[39m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[0;32m    199\u001b[0m \u001b[39m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[1;32m--> 200\u001b[0m Variable\u001b[39m.\u001b[39;49m_execution_engine\u001b[39m.\u001b[39;49mrun_backward(  \u001b[39m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[0;32m    201\u001b[0m     tensors, grad_tensors_, retain_graph, create_graph, inputs,\n\u001b[0;32m    202\u001b[0m     allow_unreachable\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, accumulate_grad\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Best validation loss seen so far, and the (1-based) epoch it was seen in\n",
    "min_el = math.inf\n",
    "best_epoch = 0\n",
    "# Number of epochs since the last improvement, counted from 1 (informational only)\n",
    "patience = 1\n",
    "best_model = {}\n",
    "\n",
    "epoch_loss = 0\n",
    "for epoch in range(hyp_params[\"num_epochs\"]):\n",
    "    start = time.time()\n",
    "    \n",
    "    # Free cached memory before each epoch\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    \n",
    "    epoch_loss = train_model(model, train_dataloader, criterion, optimizer)\n",
    "    eval_loss = evaluate_model(model, valid_dataloader, criterion)\n",
    "    vaild_bleu = bleu(model, nmtds_valid, device)\n",
    "    # NOTE: this decodes the whole training set one sentence at a time, every epoch\n",
    "    train_bleu = bleu(model, nmtds_train, device)\n",
    "    log.log(f\"Epoch: {epoch+1}, Train loss: {epoch_loss}, Eval loss: {eval_loss}, patience: {patience}, Train Bleu: {train_bleu},Vaild Bleu: {vaild_bleu}. Time {time.time() - start}\")\n",
    "\n",
    "    # Keep track of the best epoch so that the summary logged after the loop is\n",
    "    # accurate. Checkpointing and early stopping remain disabled (see below).\n",
    "    if eval_loss < min_el:\n",
    "        best_epoch = epoch+1\n",
    "        min_el = float(eval_loss)\n",
    "        patience = 1\n",
    "        # best_model = copy.deepcopy(model)\n",
    "        # torch.save({\n",
    "        #     'model_state_dict': model.state_dict(),\n",
    "        #     'optimizer_state_dict': optimizer.state_dict(),\n",
    "        #     'eval_loss': min_el\n",
    "        # }, 'model-transformer.pt')\n",
    "    else:\n",
    "        patience += 1\n",
    "    \n",
    "    # if patience == 10:\n",
    "    #     log.log(\"[STOPPING] Early stopping in action..\")\n",
    "    #     log.log(f\"Best epoch was {best_epoch} with {min_el} eval loss\")\n",
    "    #     break\n",
    "        \n",
    "log.log(f\"Best epoch was {best_epoch} with {min_el} eval loss\")\n",
    "log.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
