{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.6"},"colab":{"name":"LiSHT_Seq2Seq-MachineTranslation.ipynb","provenance":[],"collapsed_sections":[]},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"7F-sOZwhKfun"},"source":["Hey guys, I have made a tutorial on Seq2Seq machine translation (German to English) from scratch. Hope you like it. Upvote :)"]},{"cell_type":"code","metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","id":"ykbPZi7zKfuo","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1606522703747,"user_tz":-330,"elapsed":4176,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}},"outputId":"9c7270fc-db42-464d-8824-a72995284e6d"},"source":["# Install the legacy torchtext API (Field / BucketIterator) BEFORE anything imports torchtext:\n","# pip-installing after the import has no effect on the module already loaded in this kernel.\n","# (If a different torchtext version was already imported, restart the runtime after installing.)\n","!pip install torchtext==0.6.0\n","\n","# Import Libraries (each one imported exactly once)\n","import random\n","import sys\n","\n","import numpy as np\n","import spacy\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import torchtext.data\n","from torch import Tensor\n","from torch.nn.init import constant_, xavier_normal_, xavier_uniform_\n","from torch.nn.parameter import Parameter\n","from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard\n","from torchtext.data import Field, BucketIterator\n","from torchtext.data.metrics import bleu_score\n","from torchtext.datasets import Multi30k  # German to English dataset"],"execution_count":1,"outputs":[{"output_type":"stream","text":["Requirement 
already satisfied: torchtext==0.6.0 in /usr/local/lib/python3.6/dist-packages (0.6.0)\n","Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (4.41.1)\n","Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (1.15.0)\n","Requirement already satisfied: sentencepiece in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (0.1.94)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (1.18.5)\n","Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (1.7.0+cu101)\n","Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (2.23.0)\n","Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->torchtext==0.6.0) (0.16.0)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->torchtext==0.6.0) (3.7.4.3)\n","Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->torchtext==0.6.0) (0.8)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext==0.6.0) (3.0.4)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext==0.6.0) (2.10)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext==0.6.0) (2020.11.8)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext==0.6.0) 
(1.24.3)\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"y1fGgVBDKfus"},"source":["![image.png](attachment:image.png)"]},{"cell_type":"code","metadata":{"_kg_hide-output":true,"collapsed":true,"id":"-yeVJvpPKfus","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1606522711244,"user_tz":-330,"elapsed":11663,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}},"outputId":"adc3f735-03a8-4ebd-b45b-314324efb15c"},"source":["!python -m spacy download de"],"execution_count":2,"outputs":[{"output_type":"stream","text":["Collecting de_core_news_sm==2.2.5\n","\u001b[?25l  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-2.2.5/de_core_news_sm-2.2.5.tar.gz (14.9MB)\n","\u001b[K     |████████████████████████████████| 14.9MB 786kB/s \n","\u001b[?25hRequirement already satisfied: spacy>=2.2.2 in /usr/local/lib/python3.6/dist-packages (from de_core_news_sm==2.2.5) (2.2.4)\n","Requirement already satisfied: wasabi<1.1.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (0.8.0)\n","Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (50.3.2)\n","Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.18.5)\n","Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (3.0.4)\n","Requirement already satisfied: thinc==7.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (7.4.0)\n","Requirement already satisfied: plac<1.2.0,>=0.9.6 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.1.3)\n","Requirement already satisfied: 
requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (2.23.0)\n","Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (2.0.4)\n","Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.0.4)\n","Requirement already satisfied: srsly<1.1.0,>=1.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.0.4)\n","Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (4.41.1)\n","Requirement already satisfied: blis<0.5.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (0.4.1)\n","Requirement already satisfied: catalogue<1.1.0,>=0.0.7 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.0.0)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (1.24.3)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (2020.11.8)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (3.0.4)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (2.10)\n","Requirement already satisfied: importlib-metadata>=0.20; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->de_core_news_sm==2.2.5) (2.0.0)\n","Requirement already satisfied: zipp>=0.5 in 
/usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < \"3.8\"->catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->de_core_news_sm==2.2.5) (3.4.0)\n","Building wheels for collected packages: de-core-news-sm\n","  Building wheel for de-core-news-sm (setup.py) ... \u001b[?25l\u001b[?25hdone\n","  Created wheel for de-core-news-sm: filename=de_core_news_sm-2.2.5-cp36-none-any.whl size=14907056 sha256=b21881cc19e7922bb6c66787da7a655f07ff12b62b33070db759c8cada6d4b89\n","  Stored in directory: /tmp/pip-ephem-wheel-cache-0di8427y/wheels/ba/3f/ed/d4aa8e45e7191b7f32db4bfad565e7da1edbf05c916ca7a1ca\n","Successfully built de-core-news-sm\n","Installing collected packages: de-core-news-sm\n","Successfully installed de-core-news-sm-2.2.5\n","\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n","You can now load the model via spacy.load('de_core_news_sm')\n","\u001b[38;5;2m✔ Linking successful\u001b[0m\n","/usr/local/lib/python3.6/dist-packages/de_core_news_sm -->\n","/usr/local/lib/python3.6/dist-packages/spacy/data/de\n","You can now load the model via spacy.load('de')\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"_cell_guid":"79c7e3d0-c299-4dcb-8224-4455121ee9b0","_uuid":"d629ff2d2480ee46fbb7e2d37f6b5fab8052498a","id":"vc95bSrWKfuw","executionInfo":{"status":"ok","timestamp":1606522713444,"user_tz":-330,"elapsed":13859,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["# Loading Tokeniser in german and English\n","spacy_ger = spacy.load('de')\n","spacy_eng = spacy.load('en')"],"execution_count":3,"outputs":[]},{"cell_type":"code","metadata":{"id":"JinX1f6QKfuz","executionInfo":{"status":"ok","timestamp":1606522713451,"user_tz":-330,"elapsed":13861,"user":{"displayName":"Shiv Ram 
Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["# Tokenization of German Language\n","def tokenize_ger(text):\n","    return [tok.text for tok in spacy_ger.tokenizer(text)]"],"execution_count":4,"outputs":[]},{"cell_type":"code","metadata":{"id":"Avd2uijgKfu2","executionInfo":{"status":"ok","timestamp":1606522713453,"user_tz":-330,"elapsed":13860,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["# Tokenization of English Language\n","\n","def tokenize_eng(text):\n","    return [tok.text for tok in spacy_eng.tokenizer(text)]"],"execution_count":5,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-zP-N5ToKfu6"},"source":["## Preprocessing of Text"]},{"cell_type":"code","metadata":{"id":"o8t7cvVcKfu7","executionInfo":{"status":"ok","timestamp":1606522713454,"user_tz":-330,"elapsed":13857,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["# Applyling Tokenization , lowercase and special Tokens for preprocessing\n","german = Field(tokenize = tokenize_ger,lower = True,init_token = '<sos>',eos_token = '<eos>')"],"execution_count":6,"outputs":[]},{"cell_type":"code","metadata":{"id":"9PkvjAIqKfu-","executionInfo":{"status":"ok","timestamp":1606522713455,"user_tz":-330,"elapsed":13855,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["english = Field(tokenize = tokenize_eng,lower = True,init_token = '<sos>',eos_token = 
'<eos>')"],"execution_count":7,"outputs":[]},{"cell_type":"code","metadata":{"_kg_hide-output":false,"id":"s-1M5hXkKfvB","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1606522723163,"user_tz":-330,"elapsed":23560,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}},"outputId":"9518566a-0bdc-4964-81de-4e98048e5525"},"source":["# Dwonloading Dataset and storing them\n","train_data, valid_data, test_data = Multi30k.splits(\n","    exts=(\".de\", \".en\"), fields=(german, english)\n",")"],"execution_count":8,"outputs":[{"output_type":"stream","text":["downloading training.tar.gz\n"],"name":"stdout"},{"output_type":"stream","text":["training.tar.gz: 100%|██████████| 1.21M/1.21M [00:02<00:00, 602kB/s]\n"],"name":"stderr"},{"output_type":"stream","text":["downloading validation.tar.gz\n"],"name":"stdout"},{"output_type":"stream","text":["validation.tar.gz: 100%|██████████| 46.3k/46.3k [00:00<00:00, 174kB/s]\n"],"name":"stderr"},{"output_type":"stream","text":["downloading mmt_task1_test2016.tar.gz\n"],"name":"stdout"},{"output_type":"stream","text":["mmt_task1_test2016.tar.gz: 100%|██████████| 66.2k/66.2k [00:00<00:00, 162kB/s]\n"],"name":"stderr"}]},{"cell_type":"code","metadata":{"id":"fJF0wjBmKfvE","executionInfo":{"status":"ok","timestamp":1606522723167,"user_tz":-330,"elapsed":23560,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["# Creating vocabulary in each language\n","german.build_vocab(train_data,max_size = 10000,min_freq = 2)\n","english.build_vocab(train_data,max_size = 10000,min_freq = 
2)\n"],"execution_count":9,"outputs":[]},{"cell_type":"code","metadata":{"id":"qqR3q-9C3a7N","executionInfo":{"status":"ok","timestamp":1606522723168,"user_tz":-330,"elapsed":23557,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["class LiSHT(torch.nn.Module):\n","     def forward(self, input: Tensor) -> Tensor:\n","        return input * torch.tanh(input)"],"execution_count":10,"outputs":[]},{"cell_type":"code","metadata":{"id":"96Sm0iJJKfvG","executionInfo":{"status":"ok","timestamp":1606522723168,"user_tz":-330,"elapsed":23554,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["\n","# Defining the Encoder part of the model\n","class Encoder(nn.Module):\n","    \n","    def __init__(self, input_size, embedding_size, hidden_size, num_layers, p):\n","        super(Encoder, self).__init__()\n","        self.dropout = nn.Dropout(p)\n","        self.hidden_size = hidden_size\n","        self.num_layers = num_layers\n","\n","        self.embedding = nn.Embedding(input_size, embedding_size)\n","        self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=p)\n","        self.lisht = LiSHT()\n","        \n","    def forward(self, x):\n","        # x shape: (seq_length, N) where N is batch size\n","\n","        embedding = self.dropout(self.lisht(self.embedding(x)))\n","        # embedding shape: (seq_length, N, embedding_size)\n","\n","        outputs, (hidden, cell) = self.rnn(embedding)\n","        # outputs shape: (seq_length, N, hidden_size)\n","\n","        return hidden, cell"],"execution_count":11,"outputs":[]},{"cell_type":"code","metadata":{"id":"aiyK-YsTKfvJ","executionInfo":{"status":"ok","timestamp":1606522723703,"user_tz":-330,"elapsed":24086,"user":{"displayName":"Shiv Ram 
Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["# Defining the Decoder part\n","\n","class Decoder(nn.Module):\n","    def __init__(\n","        self, input_size, embedding_size, hidden_size, output_size, num_layers, p):\n","        super(Decoder, self).__init__()\n","        self.dropout = nn.Dropout(p)\n","        self.hidden_size = hidden_size\n","        self.num_layers = num_layers\n","\n","        self.embedding = nn.Embedding(input_size, embedding_size)\n","        self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=p)\n","        self.fc = nn.Linear(hidden_size, output_size)\n","        self.lisht = LiSHT()\n","        \n","    def forward(self, x, hidden, cell):\n","        # x shape: (N) where N is for batch size, we want it to be (1, N), seq_length\n","        # is 1 here because we are sending in a single word and not a sentence\n","        x = x.unsqueeze(0)\n","\n","        embedding = self.dropout(self.lisht(self.embedding(x)))\n","        # embedding shape: (1, N, embedding_size)\n","\n","        outputs, (hidden, cell) = self.rnn(embedding, (hidden, cell))\n","        # outputs shape: (1, N, hidden_size)\n","\n","        predictions = self.fc(outputs)\n","\n","        # predictions shape: (1, N, length_target_vocabulary) to send it to\n","        # loss function we want it to be (N, length_target_vocabulary) so we're\n","        # just gonna remove the first dim\n","        predictions = predictions.squeeze(0)\n","\n","        return predictions, hidden, cell"],"execution_count":12,"outputs":[]},{"cell_type":"code","metadata":{"id":"mdJJXlGyKfvM","executionInfo":{"status":"ok","timestamp":1606522723705,"user_tz":-330,"elapsed":24084,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["# Defining 
the complete model\n","class Seq2Seq(nn.Module):\n","    # Wraps the Encoder and Decoder; forward() decodes the target one position at a time.\n","    def __init__(self, encoder, decoder):\n","        super(Seq2Seq, self).__init__()\n","        self.encoder = encoder\n","        self.decoder = decoder\n","\n","    def forward(self, source, target, teacher_force_ratio=0.5):\n","        # source: (src_len, N) and target: (target_len, N) token indices, N = batch size.\n","        # Returns decoder logits of shape (target_len, N, target_vocab_size). Row 0 stays\n","        # all zeros because position 0 holds the <sos> token, which is never predicted.\n","        batch_size = source.shape[1]\n","        target_len = target.shape[0]\n","        # Take the vocabulary size and device from the model and its inputs rather than the\n","        # notebook globals `english` / `device`, so the module does not depend on them.\n","        target_vocab_size = self.decoder.fc.out_features\n","\n","        outputs = torch.zeros(target_len, batch_size, target_vocab_size, device=source.device)\n","\n","        hidden, cell = self.encoder(source)\n","\n","        # Grab the first input to the Decoder which will be <SOS> token\n","        x = target[0]\n","\n","        for t in range(1, target_len):\n","            # Use previous hidden, cell as context from encoder at start\n","            output, hidden, cell = self.decoder(x, hidden, cell)\n","\n","            # Store next output prediction\n","            outputs[t] = output\n","\n","            # Get the best word the Decoder predicted (index in the vocabulary)\n","            best_guess = output.argmax(1)\n","\n","            # With probability of teacher_force_ratio we take the actual next word\n","            # otherwise we take the word that the Decoder predicted it to be.\n","            # Teacher Forcing is used so that the model gets used to seeing\n","            # similar inputs at training and testing time, if teacher forcing is 1\n","            # then inputs at test time might be completely different than what the\n","            # network is used to. 
This was a long comment.\n","            x = target[t] if random.random() < teacher_force_ratio else best_guess\n","\n","        return outputs"],"execution_count":13,"outputs":[]},{"cell_type":"code","metadata":{"id":"5ziLyXFxKfvP","executionInfo":{"status":"ok","timestamp":1606522723707,"user_tz":-330,"elapsed":24083,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["# Hyperparameters\n","num_epochs = 50\n","learning_rate = 0.001\n","batch_size = 256\n"],"execution_count":14,"outputs":[]},{"cell_type":"code","metadata":{"id":"Vr3ZyNckKfvS","executionInfo":{"status":"ok","timestamp":1606522723709,"user_tz":-330,"elapsed":24082,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["# Model hyperparameters\n","load_model = False\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else 'cpu')\n","input_size_encoder = len(german.vocab)\n","input_size_decoder = len(english.vocab)\n","output_size = len(english.vocab)\n","encoder_embedding_size = 300\n","decoder_embedding_size = 300\n","\n","hidden_size = 1024\n","num_layers = 1\n","enc_dropout = 0.5\n","dec_dropout = 0.5\n"],"execution_count":15,"outputs":[]},{"cell_type":"code","metadata":{"id":"6DZ7zj8LKfvU","executionInfo":{"status":"ok","timestamp":1606522725107,"user_tz":-330,"elapsed":25477,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["# Tensorboard to get nice loss plot\n","writer = SummaryWriter(f'runs/Loss_plot')\n","step = 
0"],"execution_count":16,"outputs":[]},{"cell_type":"code","metadata":{"id":"z6g5kRsiKfvX","executionInfo":{"status":"ok","timestamp":1606522725108,"user_tz":-330,"elapsed":25474,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["train_iterator, validation_iterator, test_iterator = BucketIterator.splits(\n","    (train_data, valid_data, test_data),\n","     batch_size = batch_size, sort_within_batch = True, \n","     sort_key = lambda x:len(x.src),\n","     device = device)"],"execution_count":17,"outputs":[]},{"cell_type":"code","metadata":{"id":"lwapcC4mKfva","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1606522739864,"user_tz":-330,"elapsed":40227,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}},"outputId":"53bc4fe0-b822-4d5c-9bd7-b1a2952e6f3a"},"source":["encoder_net = Encoder(input_size_encoder, \n","                      encoder_embedding_size,\n","                      hidden_size,num_layers, \n","                      enc_dropout).to(device)\n","\n","\n","decoder_net = Decoder(input_size_decoder, \n","                      decoder_embedding_size,\n","                      hidden_size,output_size,num_layers, \n","                      dec_dropout).to(device)"],"execution_count":18,"outputs":[{"output_type":"stream","text":["/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:61: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n","  \"num_layers={}\".format(dropout, 
num_layers))\n"],"name":"stderr"}]},{"cell_type":"code","metadata":{"id":"Liy7CquOKfvd","executionInfo":{"status":"ok","timestamp":1606522739870,"user_tz":-330,"elapsed":40229,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["model = Seq2Seq(encoder_net, decoder_net).to(device)\n","optimizer = optim.Adam(model.parameters(), lr=learning_rate)"],"execution_count":19,"outputs":[]},{"cell_type":"code","metadata":{"id":"7ex-6pH0Kfvf","executionInfo":{"status":"ok","timestamp":1606522739871,"user_tz":-330,"elapsed":40227,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["pad_idx = english.vocab.stoi['<pad>']\n","criterion = nn.CrossEntropyLoss(ignore_index = pad_idx)\n","\n"],"execution_count":20,"outputs":[]},{"cell_type":"code","metadata":{"id":"f6ioRsy5Kfvh","executionInfo":{"status":"ok","timestamp":1606522739872,"user_tz":-330,"elapsed":40225,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["def translate_sentence(model, sentence, german, english, device, max_length=50):\n","    # Greedily translate one German sentence with the trained Seq2Seq model.\n","    #\n","    # sentence: a raw German string, or an already-tokenized list of German tokens\n","    #           (e.g. the src field of a Multi30k example).\n","    # Returns the predicted English tokens without the leading <sos>; the list ends\n","    # with <eos> unless max_length decoding steps were reached first.\n","\n","    # Tokenize with the same tokenizer the german Field used to build its vocabulary\n","    # (instead of re-loading the spaCy model on every call) and lower-case like the vocab.\n","    if isinstance(sentence, str):\n","        tokens = [token.lower() for token in german.tokenize(sentence)]\n","    else:\n","        tokens = [token.lower() for token in sentence]\n","\n","    # Add <SOS> and <EOS> in beginning and end respectively\n","    tokens.insert(0, german.init_token)\n","    
tokens.append(german.eos_token)\n","\n","    # Go through each german token and convert to an index\n","    text_to_indices = [german.vocab.stoi[token] for token in tokens]\n","\n","    # Convert to a (src_len, 1) tensor, i.e. a batch holding a single sentence\n","    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)\n","\n","    eos_idx = english.vocab.stoi[\"<eos>\"]\n","    outputs = [english.vocab.stoi[\"<sos>\"]]\n","\n","    # Decoding must run with dropout disabled: switch to eval mode and restore the\n","    # caller's mode afterwards (the training loop leaves the model in train mode).\n","    was_training = model.training\n","    model.eval()\n","    try:\n","        with torch.no_grad():\n","            # Build encoder hidden, cell state\n","            hidden, cell = model.encoder(sentence_tensor)\n","\n","            for _ in range(max_length):\n","                previous_word = torch.LongTensor([outputs[-1]]).to(device)\n","                output, hidden, cell = model.decoder(previous_word, hidden, cell)\n","                best_guess = output.argmax(1).item()\n","                outputs.append(best_guess)\n","\n","                # Model predicts it's the end of the sentence\n","                if best_guess == eos_idx:\n","                    break\n","    finally:\n","        model.train(was_training)\n","\n","    translated_sentence = [english.vocab.itos[idx] for idx in outputs]\n","\n","    # remove start token\n","    return translated_sentence[1:]"],"execution_count":21,"outputs":[]},{"cell_type":"code","metadata":{"id":"rm9yIYvIKfvj","executionInfo":{"status":"ok","timestamp":1606522739873,"user_tz":-330,"elapsed":40222,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["def save_checkpoint(state, filename=\"my_checkpoint.pth.tar\"):\n","    print(\"=> Saving checkpoint\")\n","    torch.save(state, filename)\n"],"execution_count":22,"outputs":[]},{"cell_type":"code","metadata":{"id":"WbNauSlDKfvm","executionInfo":{"status":"ok","timestamp":1606522739875,"user_tz":-330,"elapsed":40220,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["def 
load_checkpoint(checkpoint, model, optimizer):\n","    print(\"=> Loading checkpoint\")\n","    model.load_state_dict(checkpoint[\"state_dict\"])\n","    optimizer.load_state_dict(checkpoint[\"optimizer\"])"],"execution_count":23,"outputs":[]},{"cell_type":"code","metadata":{"id":"CiNGFIGGKfvo","executionInfo":{"status":"ok","timestamp":1606522739875,"user_tz":-330,"elapsed":40217,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["if load_model:\n","    load_checkpoint(torch.load(\"my_checkpoint.pth.tar\"), model, optimizer)\n"],"execution_count":24,"outputs":[]},{"cell_type":"code","metadata":{"id":"rOiYTY4DKfvq","executionInfo":{"status":"ok","timestamp":1606522739876,"user_tz":-330,"elapsed":40215,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["sentence = \"Cristiano Ronaldo ist ein großartiger Fußballspieler mit erstaunlichen Fähigkeiten und Talenten.\"\n"],"execution_count":25,"outputs":[]},{"cell_type":"code","metadata":{"id":"ujHJNl4oKfvt","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1606524188632,"user_tz":-330,"elapsed":1488967,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}},"outputId":"63a33344-0254-49d3-a90a-059ae0673502"},"source":["for epoch in range(num_epochs):\n","    print(f\"[Epoch {epoch} / {num_epochs}]\")\n","\n","    # Dropout off for the optional example translation below\n","    model.eval()\n","\n","    #translated_sentence = translate_sentence(\n","    #    model, sentence, german, english, device, max_length=50\n","    #)\n","\n"," 
   #print(f\"Translated example sentence: \\n {translated_sentence}\")\n","\n","    model.train()\n","\n","    for batch_idx, batch in enumerate(train_iterator):\n","        # Get input and targets and get to cuda\n","        inp_data = batch.src.to(device)\n","        target = batch.trg.to(device)\n","\n","        # Forward prop\n","        output = model(inp_data, target)\n","\n","        # Output is of shape (trg_len, batch_size, output_dim) but Cross Entropy Loss\n","        # doesn't take input in that form. For example if we have MNIST we want to have\n","        # output to be: (N, 10) and targets just (N). Here we can view it in a similar\n","        # way that we have output_words * batch_size that we want to send in into\n","        # our cost function, so we need to do some reshaping. While we're at it,\n","        # let's also remove the start token (row 0 of the output is never predicted).\n","        output = output[1:].reshape(-1, output.shape[2])\n","        target = target[1:].reshape(-1)\n","\n","        optimizer.zero_grad()\n","        loss = criterion(output, target)\n","\n","        # Back prop\n","        loss.backward()\n","\n","        # Clip to avoid exploding gradient issues, makes sure grads are\n","        # within a healthy range\n","        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\n","\n","        # Gradient descent step\n","        optimizer.step()\n","\n","        # Plot to tensorboard (log a plain float, not the graph-carrying tensor)\n","        writer.add_scalar(\"Training loss\", loss.item(), global_step=step)\n","        step += 1\n","\n","    # Checkpoint AFTER this epoch's updates. Saving at the top of the loop stored the\n","    # weights from before each epoch, so the last epoch of training was never saved.\n","    checkpoint = {\"state_dict\": model.state_dict(), \"optimizer\": optimizer.state_dict()}\n","    save_checkpoint(checkpoint)"],"execution_count":26,"outputs":[{"output_type":"stream","text":["[Epoch 0 / 50]\n","=> Saving checkpoint\n","[Epoch 1 / 50]\n","=> Saving checkpoint\n","[Epoch 2 / 50]\n","=> Saving checkpoint\n","[Epoch 3 / 50]\n","=> Saving checkpoint\n","[Epoch 4 / 50]\n","=> Saving checkpoint\n","[Epoch 5 / 50]\n","=> Saving checkpoint\n","[Epoch 6 / 50]\n","=> Saving checkpoint\n","[Epoch 7 / 50]\n","=> Saving checkpoint\n","[Epoch 8 / 50]\n","=> Saving 
checkpoint\n","[Epoch 9 / 50]\n","=> Saving checkpoint\n","[Epoch 10 / 50]\n","=> Saving checkpoint\n","[Epoch 11 / 50]\n","=> Saving checkpoint\n","[Epoch 12 / 50]\n","=> Saving checkpoint\n","[Epoch 13 / 50]\n","=> Saving checkpoint\n","[Epoch 14 / 50]\n","=> Saving checkpoint\n","[Epoch 15 / 50]\n","=> Saving checkpoint\n","[Epoch 16 / 50]\n","=> Saving checkpoint\n","[Epoch 17 / 50]\n","=> Saving checkpoint\n","[Epoch 18 / 50]\n","=> Saving checkpoint\n","[Epoch 19 / 50]\n","=> Saving checkpoint\n","[Epoch 20 / 50]\n","=> Saving checkpoint\n","[Epoch 21 / 50]\n","=> Saving checkpoint\n","[Epoch 22 / 50]\n","=> Saving checkpoint\n","[Epoch 23 / 50]\n","=> Saving checkpoint\n","[Epoch 24 / 50]\n","=> Saving checkpoint\n","[Epoch 25 / 50]\n","=> Saving checkpoint\n","[Epoch 26 / 50]\n","=> Saving checkpoint\n","[Epoch 27 / 50]\n","=> Saving checkpoint\n","[Epoch 28 / 50]\n","=> Saving checkpoint\n","[Epoch 29 / 50]\n","=> Saving checkpoint\n","[Epoch 30 / 50]\n","=> Saving checkpoint\n","[Epoch 31 / 50]\n","=> Saving checkpoint\n","[Epoch 32 / 50]\n","=> Saving checkpoint\n","[Epoch 33 / 50]\n","=> Saving checkpoint\n","[Epoch 34 / 50]\n","=> Saving checkpoint\n","[Epoch 35 / 50]\n","=> Saving checkpoint\n","[Epoch 36 / 50]\n","=> Saving checkpoint\n","[Epoch 37 / 50]\n","=> Saving checkpoint\n","[Epoch 38 / 50]\n","=> Saving checkpoint\n","[Epoch 39 / 50]\n","=> Saving checkpoint\n","[Epoch 40 / 50]\n","=> Saving checkpoint\n","[Epoch 41 / 50]\n","=> Saving checkpoint\n","[Epoch 42 / 50]\n","=> Saving checkpoint\n","[Epoch 43 / 50]\n","=> Saving checkpoint\n","[Epoch 44 / 50]\n","=> Saving checkpoint\n","[Epoch 45 / 50]\n","=> Saving checkpoint\n","[Epoch 46 / 50]\n","=> Saving checkpoint\n","[Epoch 47 / 50]\n","=> Saving checkpoint\n","[Epoch 48 / 50]\n","=> Saving checkpoint\n","[Epoch 49 / 50]\n","=> Saving 
checkpoint\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"Ir5lIeOwKfvv","executionInfo":{"status":"ok","timestamp":1606524188635,"user_tz":-330,"elapsed":1488966,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["def bleu(data, model, german, english, device):\n","    # Corpus BLEU of the model's greedy translations of `data`, a sequence of torchtext\n","    # Examples with src/trg fields; each prediction has one reference translation.\n","    targets = []\n","    outputs = []\n","\n","    # Decode with dropout disabled (the training loop leaves the model in train mode),\n","    # and restore the caller's mode afterwards.\n","    was_training = model.training\n","    model.eval()\n","    try:\n","        for example in data:\n","            src = vars(example)[\"src\"]\n","            trg = vars(example)[\"trg\"]\n","\n","            prediction = translate_sentence(model, src, german, english, device)\n","            # Remove the trailing <eos> only if decoding produced one: when max_length is\n","            # reached there is no <eos> and the last token is a real prediction.\n","            if prediction and prediction[-1] == english.eos_token:\n","                prediction = prediction[:-1]\n","\n","            targets.append([trg])\n","            outputs.append(prediction)\n","    finally:\n","        model.train(was_training)\n","\n","    return bleu_score(outputs, targets)"],"execution_count":27,"outputs":[]},{"cell_type":"code","metadata":{"id":"hvYMk4hYKfvy","executionInfo":{"status":"ok","timestamp":1606524188636,"user_tz":-330,"elapsed":1488964,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":["import torch\n","import spacy\n","from torchtext.data.metrics import bleu_score\n","import sys"],"execution_count":28,"outputs":[]},{"cell_type":"code","metadata":{"id":"f-bW5z8oKfv0","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1606524296781,"user_tz":-330,"elapsed":1597103,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}},"outputId":"b3bc3231-61e2-4b7c-b866-78749f385502"},"source":["\n","score = bleu(test_data[1:100], model, german, english, device)\n","print(f\"Bleu score {score*100:.2f}\")"],"execution_count":29,"outputs":[{"output_type":"stream","text":["Bleu score 
18.95\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"-PN1sqf5Kfv2","executionInfo":{"status":"ok","timestamp":1606524296782,"user_tz":-330,"elapsed":1597100,"user":{"displayName":"Shiv Ram Dubey","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64","userId":"14553895138990175535"}}},"source":[""],"execution_count":29,"outputs":[]}]}