{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7F-sOZwhKfun"
   },
   "source": [
    "Hey Guys , I have made a Tutorial for Seq2Seq Machine Translation from scratch. Hope you like it. Upvote :)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3879,
     "status": "ok",
     "timestamp": 1606478087201,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "ykbPZi7zKfuo",
    "outputId": "2361c13e-4953-435c-d36f-f273e8f3aa93"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: http://pypi.douban.com/simple/\n",
      "Requirement already satisfied: torchtext==0.6.0 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (0.6.0)\n",
      "Requirement already satisfied: tqdm in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torchtext==0.6.0) (4.66.5)\n",
      "Requirement already satisfied: requests in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torchtext==0.6.0) (2.32.3)\n",
      "Requirement already satisfied: torch in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torchtext==0.6.0) (2.4.0)\n",
      "Requirement already satisfied: numpy in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torchtext==0.6.0) (1.24.4)\n",
      "Requirement already satisfied: six in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torchtext==0.6.0) (1.16.0)\n",
      "Requirement already satisfied: sentencepiece in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torchtext==0.6.0) (0.2.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from requests->torchtext==0.6.0) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from requests->torchtext==0.6.0) (3.8)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from requests->torchtext==0.6.0) (2.2.2)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from requests->torchtext==0.6.0) (2024.8.30)\n",
      "Requirement already satisfied: filelock in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torch->torchtext==0.6.0) (3.16.0)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torch->torchtext==0.6.0) (4.12.2)\n",
      "Requirement already satisfied: sympy in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torch->torchtext==0.6.0) (1.13.1)\n",
      "Requirement already satisfied: networkx in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torch->torchtext==0.6.0) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torch->torchtext==0.6.0) (3.1.4)\n",
      "Requirement already satisfied: fsspec in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torch->torchtext==0.6.0) (2024.12.0)\n",
      "Requirement already satisfied: colorama in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from tqdm->torchtext==0.6.0) (0.4.6)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from jinja2->torch->torchtext==0.6.0) (2.1.5)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from sympy->torch->torchtext==0.6.0) (1.3.0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: The repository located at pypi.douban.com is not a trusted or secure host and is being ignored. If this repository is available via HTTPS we recommend you use HTTPS instead, otherwise you may silence this warning and allow it anyway with '--trusted-host pypi.douban.com'.\n"
     ]
    }
   ],
   "source": [
    "# Import Libraries\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchtext.datasets import Multi30k #German to English dataset\n",
    "from torchtext.data import Field, BucketIterator\n",
    "import numpy as np\n",
    "import spacy\n",
    "import random\n",
    "from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard\n",
    "import torch\n",
    "import spacy\n",
    "!pip install torchtext==0.6.0\n",
    "import torchtext.data\n",
    "from torchtext.data.metrics import bleu_score\n",
    "import sys"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y1fGgVBDKfus"
   },
   "source": [
    "![image.png](attachment:image.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "_kg_hide-output": true,
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "collapsed": true,
    "executionInfo": {
     "elapsed": 11002,
     "status": "ok",
     "timestamp": 1606478094336,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "-yeVJvpPKfus",
    "outputId": "8893285e-4bc7-4592-cbf8-1753f4404199"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: http://pypi.douban.com/simple/\n",
      "Collecting de-core-news-sm==3.8.0\n",
      "  Using cached https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.8.0/de_core_news_sm-3.8.0-py3-none-any.whl (14.6 MB)\n",
      "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
      "You can now load the package via spacy.load('de_core_news_sm')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: The repository located at pypi.douban.com is not a trusted or secure host and is being ignored. If this repository is available via HTTPS we recommend you use HTTPS instead, otherwise you may silence this warning and allow it anyway with '--trusted-host pypi.douban.com'.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: http://pypi.douban.com/simple/\n",
      "Collecting en-core-web-sm==3.8.0\n",
      "  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)\n",
      "     ---------------------------------------- 0.0/12.8 MB ? eta -:--:--\n",
      "     ---------------------------------------- 0.0/12.8 MB ? eta -:--:--\n",
      "      --------------------------------------- 0.3/12.8 MB ? eta -:--:--\n",
      "     - -------------------------------------- 0.5/12.8 MB 1.0 MB/s eta 0:00:12\n",
      "     - -------------------------------------- 0.5/12.8 MB 1.0 MB/s eta 0:00:12\n",
      "     -- ------------------------------------- 0.8/12.8 MB 1.1 MB/s eta 0:00:12\n",
      "     --- ------------------------------------ 1.0/12.8 MB 1.0 MB/s eta 0:00:12\n",
      "     --- ------------------------------------ 1.0/12.8 MB 1.0 MB/s eta 0:00:12\n",
      "     --- ----------------------------------- 1.3/12.8 MB 972.7 kB/s eta 0:00:12\n",
      "     ---- ---------------------------------- 1.6/12.8 MB 912.1 kB/s eta 0:00:13\n",
      "     ---- ---------------------------------- 1.6/12.8 MB 912.1 kB/s eta 0:00:13\n",
      "     ----- --------------------------------- 1.8/12.8 MB 907.1 kB/s eta 0:00:13\n",
      "     ----- --------------------------------- 1.8/12.8 MB 907.1 kB/s eta 0:00:13\n",
      "     ------ -------------------------------- 2.1/12.8 MB 870.1 kB/s eta 0:00:13\n",
      "     ------- ------------------------------- 2.4/12.8 MB 838.9 kB/s eta 0:00:13\n",
      "     ------- ------------------------------- 2.4/12.8 MB 838.9 kB/s eta 0:00:13\n",
      "     ------- ------------------------------- 2.4/12.8 MB 838.9 kB/s eta 0:00:13\n",
      "     ------- ------------------------------- 2.6/12.8 MB 807.7 kB/s eta 0:00:13\n",
      "     -------- ------------------------------ 2.9/12.8 MB 806.6 kB/s eta 0:00:13\n",
      "     -------- ------------------------------ 2.9/12.8 MB 806.6 kB/s eta 0:00:13\n",
      "     --------- ----------------------------- 3.1/12.8 MB 809.5 kB/s eta 0:00:12\n",
      "     ---------- ---------------------------- 3.4/12.8 MB 818.6 kB/s eta 0:00:12\n",
      "     ----------- --------------------------- 3.7/12.8 MB 829.4 kB/s eta 0:00:12\n",
      "     ----------- --------------------------- 3.7/12.8 MB 829.4 kB/s eta 0:00:12\n",
      "     ----------- --------------------------- 3.9/12.8 MB 830.1 kB/s eta 0:00:11\n",
      "     ------------ -------------------------- 4.2/12.8 MB 830.6 kB/s eta 0:00:11\n",
      "     ------------ -------------------------- 4.2/12.8 MB 830.6 kB/s eta 0:00:11\n",
      "     ------------- ------------------------- 4.5/12.8 MB 836.4 kB/s eta 0:00:10\n",
      "     ------------- ------------------------- 4.5/12.8 MB 836.4 kB/s eta 0:00:10\n",
      "     -------------- ------------------------ 4.7/12.8 MB 819.7 kB/s eta 0:00:10\n",
      "     -------------- ------------------------ 4.7/12.8 MB 819.7 kB/s eta 0:00:10\n",
      "     --------------- ----------------------- 5.0/12.8 MB 811.9 kB/s eta 0:00:10\n",
      "     --------------- ----------------------- 5.2/12.8 MB 811.2 kB/s eta 0:00:10\n",
      "     --------------- ----------------------- 5.2/12.8 MB 811.2 kB/s eta 0:00:10\n",
      "     ---------------- ---------------------- 5.5/12.8 MB 818.4 kB/s eta 0:00:09\n",
      "     ----------------- --------------------- 5.8/12.8 MB 823.2 kB/s eta 0:00:09\n",
      "     ------------------ -------------------- 6.0/12.8 MB 827.7 kB/s eta 0:00:09\n",
      "     ------------------ -------------------- 6.0/12.8 MB 827.7 kB/s eta 0:00:09\n",
      "     ------------------- ------------------- 6.3/12.8 MB 831.6 kB/s eta 0:00:08\n",
      "     ------------------- ------------------- 6.6/12.8 MB 837.2 kB/s eta 0:00:08\n",
      "     -------------------- ------------------ 6.8/12.8 MB 844.0 kB/s eta 0:00:08\n",
      "     --------------------- ----------------- 7.1/12.8 MB 850.4 kB/s eta 0:00:07\n",
      "     --------------------- ----------------- 7.1/12.8 MB 850.4 kB/s eta 0:00:07\n",
      "     ---------------------- ---------------- 7.3/12.8 MB 854.7 kB/s eta 0:00:07\n",
      "     ----------------------- --------------- 7.6/12.8 MB 854.2 kB/s eta 0:00:07\n",
      "     ----------------------- --------------- 7.6/12.8 MB 854.2 kB/s eta 0:00:07\n",
      "     ----------------------- --------------- 7.9/12.8 MB 855.1 kB/s eta 0:00:06\n",
      "     ------------------------ -------------- 8.1/12.8 MB 858.9 kB/s eta 0:00:06\n",
      "     ------------------------- ------------- 8.4/12.8 MB 862.6 kB/s eta 0:00:06\n",
      "     ------------------------- ------------- 8.4/12.8 MB 862.6 kB/s eta 0:00:06\n",
      "     -------------------------- ------------ 8.7/12.8 MB 866.0 kB/s eta 0:00:05\n",
      "     --------------------------- ----------- 8.9/12.8 MB 869.2 kB/s eta 0:00:05\n",
      "     --------------------------- ----------- 9.2/12.8 MB 870.9 kB/s eta 0:00:05\n",
      "     --------------------------- ----------- 9.2/12.8 MB 870.9 kB/s eta 0:00:05\n",
      "     ---------------------------- ---------- 9.4/12.8 MB 871.2 kB/s eta 0:00:04\n",
      "     ----------------------------- --------- 9.7/12.8 MB 877.9 kB/s eta 0:00:04\n",
      "     ----------------------------- -------- 10.0/12.8 MB 881.8 kB/s eta 0:00:04\n",
      "     ------------------------------ ------- 10.2/12.8 MB 889.2 kB/s eta 0:00:03\n",
      "     ------------------------------- ------ 10.5/12.8 MB 893.9 kB/s eta 0:00:03\n",
      "     ------------------------------- ------ 10.7/12.8 MB 896.0 kB/s eta 0:00:03\n",
      "     ------------------------------- ------ 10.7/12.8 MB 896.0 kB/s eta 0:00:03\n",
      "     -------------------------------- ----- 11.0/12.8 MB 895.7 kB/s eta 0:00:03\n",
      "     -------------------------------- ----- 11.0/12.8 MB 895.7 kB/s eta 0:00:03\n",
      "     --------------------------------- ---- 11.3/12.8 MB 881.9 kB/s eta 0:00:02\n",
      "     --------------------------------- ---- 11.3/12.8 MB 881.9 kB/s eta 0:00:02\n",
      "     ---------------------------------- --- 11.5/12.8 MB 874.5 kB/s eta 0:00:02\n",
      "     ---------------------------------- --- 11.5/12.8 MB 874.5 kB/s eta 0:00:02\n",
      "     ----------------------------------- -- 11.8/12.8 MB 869.5 kB/s eta 0:00:02\n",
      "     ----------------------------------- -- 11.8/12.8 MB 869.5 kB/s eta 0:00:02\n",
      "     ----------------------------------- -- 12.1/12.8 MB 862.9 kB/s eta 0:00:01\n",
      "     ----------------------------------- -- 12.1/12.8 MB 862.9 kB/s eta 0:00:01\n",
      "     ------------------------------------ - 12.3/12.8 MB 861.3 kB/s eta 0:00:01\n",
      "     -------------------------------------  12.6/12.8 MB 859.9 kB/s eta 0:00:01\n",
      "     -------------------------------------  12.6/12.8 MB 859.9 kB/s eta 0:00:01\n",
      "     -------------------------------------- 12.8/12.8 MB 859.6 kB/s eta 0:00:00\n",
      "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
      "You can now load the package via spacy.load('en_core_web_sm')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: The repository located at pypi.douban.com is not a trusted or secure host and is being ignored. If this repository is available via HTTPS we recommend you use HTTPS instead, otherwise you may silence this warning and allow it anyway with '--trusted-host pypi.douban.com'.\n"
     ]
    }
   ],
   "source": [
    "!python -m spacy download de_core_news_sm\n",
    "!python -m spacy download en_core_web_sm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
    "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a",
    "executionInfo": {
     "elapsed": 12908,
     "status": "ok",
     "timestamp": 1606478096248,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "vc95bSrWKfuw"
   },
   "outputs": [],
   "source": [
    "# Loading Tokeniser in german and English\n",
    "spacy_ger = spacy.load(\"de_core_news_sm\")\n",
    "spacy_eng = spacy.load(\"en_core_web_sm\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "executionInfo": {
     "elapsed": 12904,
     "status": "ok",
     "timestamp": 1606478096249,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "JinX1f6QKfuz"
   },
   "outputs": [],
   "source": [
    "# Tokenization of German Language\n",
    "def tokenize_ger(text):\n",
    "    return [tok.text for tok in spacy_ger.tokenizer(text)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "executionInfo": {
     "elapsed": 12900,
     "status": "ok",
     "timestamp": 1606478096249,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "Avd2uijgKfu2"
   },
   "outputs": [],
   "source": [
    "# Tokenization of English Language\n",
    "\n",
    "def tokenize_eng(text):\n",
    "    return [tok.text for tok in spacy_eng.tokenizer(text)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-zP-N5ToKfu6"
   },
   "source": [
    "## Preprocessing of Text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "executionInfo": {
     "elapsed": 12897,
     "status": "ok",
     "timestamp": 1606478096250,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "o8t7cvVcKfu7"
   },
   "outputs": [],
   "source": [
    "# Applyling Tokenization , lowercase and special Tokens for preprocessing\n",
    "german = Field(tokenize = tokenize_ger,lower = True,init_token = '<sos>',eos_token = '<eos>')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "executionInfo": {
     "elapsed": 12894,
     "status": "ok",
     "timestamp": 1606478096250,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "9PkvjAIqKfu-"
   },
   "outputs": [],
   "source": [
    "english = Field(tokenize = tokenize_eng,lower = True,init_token = '<sos>',eos_token = '<eos>')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "_kg_hide-output": false,
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 22088,
     "status": "ok",
     "timestamp": 1606478105447,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "s-1M5hXkKfvB",
    "outputId": "c5c95c6d-faab-4759-9f59-3baefd569687"
   },
   "outputs": [],
   "source": [
    "# Dwonloading Dataset and storing them\n",
    "train_data, valid_data, test_data = Multi30k.splits(\n",
    "    exts=(\".de\", \".en\"), fields=(german, english)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "executionInfo": {
     "elapsed": 22096,
     "status": "ok",
     "timestamp": 1606478105458,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "fJF0wjBmKfvE"
   },
   "outputs": [],
   "source": [
    "# Creating vocabulary in each language\n",
    "german.build_vocab(train_data,max_size = 10000,min_freq = 2)\n",
    "english.build_vocab(train_data,max_size = 10000,min_freq = 2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "executionInfo": {
     "elapsed": 22094,
     "status": "ok",
     "timestamp": 1606478105459,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "96Sm0iJJKfvG"
   },
   "outputs": [],
   "source": [
    "\n",
    "# Defining the Encoder part of the model\n",
    "class Encoder(nn.Module):\n",
    "    \n",
    "    def __init__(self, input_size, embedding_size, hidden_size, num_layers, p):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.dropout = nn.Dropout(p)\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        self.embedding = nn.Embedding(input_size, embedding_size)\n",
    "        self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=p)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        # x shape: (seq_length, N) where N is batch size\n",
    "\n",
    "        embedding = self.dropout(F.silu(self.embedding(x)))\n",
    "        # embedding shape: (seq_length, N, embedding_size)\n",
    "\n",
    "        outputs, (hidden, cell) = self.rnn(embedding)\n",
    "        # outputs shape: (seq_length, N, hidden_size)\n",
    "\n",
    "        return hidden, cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "executionInfo": {
     "elapsed": 22093,
     "status": "ok",
     "timestamp": 1606478105460,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "aiyK-YsTKfvJ"
   },
   "outputs": [],
   "source": [
    "# Defining the Decoder part\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(\n",
    "        self, input_size, embedding_size, hidden_size, output_size, num_layers, p):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.dropout = nn.Dropout(p)\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        self.embedding = nn.Embedding(input_size, embedding_size)\n",
    "        self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=p)\n",
    "        self.fc = nn.Linear(hidden_size, output_size)\n",
    "\n",
    "    def forward(self, x, hidden, cell):\n",
    "        # x shape: (N) where N is for batch size, we want it to be (1, N), seq_length\n",
    "        # is 1 here because we are sending in a single word and not a sentence\n",
    "        x = x.unsqueeze(0)\n",
    "\n",
    "        embedding = self.dropout(F.silu(self.embedding(x)))\n",
    "        # embedding shape: (1, N, embedding_size)\n",
    "\n",
    "        outputs, (hidden, cell) = self.rnn(embedding, (hidden, cell))\n",
    "        # outputs shape: (1, N, hidden_size)\n",
    "\n",
    "        predictions = self.fc(outputs)\n",
    "\n",
    "        # predictions shape: (1, N, length_target_vocabulary) to send it to\n",
    "        # loss function we want it to be (N, length_target_vocabulary) so we're\n",
    "        # just gonna remove the first dim\n",
    "        predictions = predictions.squeeze(0)\n",
    "\n",
    "        return predictions, hidden, cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "executionInfo": {
     "elapsed": 22091,
     "status": "ok",
     "timestamp": 1606478105461,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "mdJJXlGyKfvM"
   },
   "outputs": [],
   "source": [
    "# Defining the complete model\n",
    "class Seq2Seq(nn.Module):\n",
    "    def __init__(self, encoder, decoder):\n",
    "        super(Seq2Seq, self).__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "\n",
    "    def forward(self, source, target, teacher_force_ratio=0.5):\n",
    "        batch_size = source.shape[1]\n",
    "        target_len = target.shape[0]\n",
    "        target_vocab_size = len(english.vocab)\n",
    "\n",
    "        outputs = torch.zeros(target_len, batch_size, target_vocab_size).to(device)\n",
    "\n",
    "        hidden, cell = self.encoder(source)\n",
    "\n",
    "        # Grab the first input to the Decoder which will be <SOS> token\n",
    "        x = target[0]\n",
    "\n",
    "        for t in range(1, target_len):\n",
    "            # Use previous hidden, cell as context from encoder at start\n",
    "            output, hidden, cell = self.decoder(x, hidden, cell)\n",
    "\n",
    "            # Store next output prediction\n",
    "            outputs[t] = output\n",
    "\n",
    "            # Get the best word the Decoder predicted (index in the vocabulary)\n",
    "            best_guess = output.argmax(1)\n",
    "\n",
    "            # With probability of teacher_force_ratio we take the actual next word\n",
    "            # otherwise we take the word that the Decoder predicted it to be.\n",
    "            # Teacher Forcing is used so that the model gets used to seeing\n",
    "            # similar inputs at training and testing time, if teacher forcing is 1\n",
    "            # then inputs at test time might be completely different than what the\n",
    "            # network is used to. This was a long comment.\n",
    "            x = target[t] if random.random() < teacher_force_ratio else best_guess\n",
    "\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "executionInfo": {
     "elapsed": 22089,
     "status": "ok",
     "timestamp": 1606478105462,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "5ziLyXFxKfvP"
   },
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "num_epochs = 50\n",
    "learning_rate = 0.001\n",
    "batch_size = 256\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "executionInfo": {
     "elapsed": 22087,
     "status": "ok",
     "timestamp": 1606478105463,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "Vr3ZyNckKfvS"
   },
   "outputs": [],
   "source": [
    "# Model hyperparameters\n",
    "load_model = False\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else 'cpu')\n",
    "input_size_encoder = len(german.vocab)\n",
    "input_size_decoder = len(english.vocab)\n",
    "output_size = len(english.vocab)\n",
    "encoder_embedding_size = 300\n",
    "decoder_embedding_size = 300\n",
    "\n",
    "hidden_size = 1024\n",
    "num_layers = 1\n",
    "enc_dropout = 0.5\n",
    "dec_dropout = 0.5\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "executionInfo": {
     "elapsed": 23819,
     "status": "ok",
     "timestamp": 1606478107197,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "6DZ7zj8LKfvU"
   },
   "outputs": [],
   "source": [
    "# Tensorboard to get nice loss plot\n",
    "writer = SummaryWriter(f'runs/Loss_plot')\n",
    "step = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "executionInfo": {
     "elapsed": 23817,
     "status": "ok",
     "timestamp": 1606478107198,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "z6g5kRsiKfvX"
   },
   "outputs": [],
   "source": [
    "train_iterator, validation_iterator, test_iterator = BucketIterator.splits(\n",
    "    (train_data, valid_data, test_data),\n",
    "     batch_size = batch_size, sort_within_batch = True, \n",
    "     sort_key = lambda x:len(x.src),\n",
    "     device = device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 38350,
     "status": "ok",
     "timestamp": 1606478121734,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "lwapcC4mKfva",
    "outputId": "cd4de382-8790-4305-9ca4-8f2d6c811f3c"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\anaconda_download\\envs\\kan\\lib\\site-packages\\torch\\nn\\modules\\rnn.py:88: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n",
      "  warnings.warn(\"dropout option adds dropout after all but last \"\n"
     ]
    }
   ],
   "source": [
    "encoder_net = Encoder(input_size_encoder, \n",
    "                      encoder_embedding_size,\n",
    "                      hidden_size,num_layers, \n",
    "                      enc_dropout).to(device)\n",
    "\n",
    "\n",
    "decoder_net = Decoder(input_size_decoder, \n",
    "                      decoder_embedding_size,\n",
    "                      hidden_size,output_size,num_layers, \n",
    "                      dec_dropout).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "executionInfo": {
     "elapsed": 38351,
     "status": "ok",
     "timestamp": 1606478121737,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "Liy7CquOKfvd"
   },
   "outputs": [],
   "source": [
    "model = Seq2Seq(encoder_net, decoder_net).to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "executionInfo": {
     "elapsed": 38348,
     "status": "ok",
     "timestamp": 1606478121738,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "7ex-6pH0Kfvf"
   },
   "outputs": [],
   "source": [
    "pad_idx = english.vocab.stoi['<pad>']\n",
    "criterion = nn.CrossEntropyLoss(ignore_index = pad_idx)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "executionInfo": {
     "elapsed": 38346,
     "status": "ok",
     "timestamp": 1606478121739,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "f6ioRsy5Kfvh"
   },
   "outputs": [],
   "source": [
    "def translate_sentence(model, sentence, german, english, device, max_length=50):\n",
    "    # print(sentence)\n",
    "\n",
    "    # sys.exit()\n",
    "\n",
    "    # Load german tokenizer\n",
    "    spacy_ger = spacy.load(\"de\")\n",
    "\n",
    "    # Create tokens using spacy and everything in lower case (which is what our vocab is)\n",
    "    if type(sentence) == str:\n",
    "        tokens = [token.text.lower() for token in spacy_ger(sentence)]\n",
    "    else:\n",
    "        tokens = [token.lower() for token in sentence]\n",
    "\n",
    "    # print(tokens)\n",
    "\n",
    "    # sys.exit()\n",
    "    # Add <SOS> and <EOS> in beginning and end respectively\n",
    "    tokens.insert(0, german.init_token)\n",
    "    tokens.append(german.eos_token)\n",
    "\n",
    "    # Go through each german token and convert to an index\n",
    "    text_to_indices = [german.vocab.stoi[token] for token in tokens]\n",
    "\n",
    "    # Convert to Tensor\n",
    "    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)\n",
    "\n",
    "    # Build encoder hidden, cell state\n",
    "    with torch.no_grad():\n",
    "        hidden, cell = model.encoder(sentence_tensor)\n",
    "\n",
    "    outputs = [english.vocab.stoi[\"<sos>\"]]\n",
    "\n",
    "    for _ in range(max_length):\n",
    "        previous_word = torch.LongTensor([outputs[-1]]).to(device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            output, hidden, cell = model.decoder(previous_word, hidden, cell)\n",
    "            best_guess = output.argmax(1).item()\n",
    "\n",
    "        outputs.append(best_guess)\n",
    "\n",
    "        # Model predicts it's the end of the sentence\n",
    "        if output.argmax(1).item() == english.vocab.stoi[\"<eos>\"]:\n",
    "            break\n",
    "\n",
    "    translated_sentence = [english.vocab.itos[idx] for idx in outputs]\n",
    "\n",
    "    # remove start token\n",
    "    return translated_sentence[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "executionInfo": {
     "elapsed": 38343,
     "status": "ok",
     "timestamp": 1606478121739,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "rm9yIYvIKfvj"
   },
   "outputs": [],
   "source": [
    "def save_checkpoint(state, filename=\"my_checkpoint.pth.tar\"):\n",
    "    print(\"=> Saving checkpoint\")\n",
    "    torch.save(state, filename)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "executionInfo": {
     "elapsed": 38342,
     "status": "ok",
     "timestamp": 1606478121740,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "WbNauSlDKfvm"
   },
   "outputs": [],
   "source": [
    "def load_checkpoint(checkpoint, model, optimizer):\n",
    "    print(\"=> Loading checkpoint\")\n",
    "    model.load_state_dict(checkpoint[\"state_dict\"])\n",
    "    optimizer.load_state_dict(checkpoint[\"optimizer\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "executionInfo": {
     "elapsed": 38340,
     "status": "ok",
     "timestamp": 1606478121741,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "CiNGFIGGKfvo"
   },
   "outputs": [],
   "source": [
    "if load_model:\n",
    "    load_checkpoint(torch.load(\"my_checkpoint.pth.tar\"), model, optimizer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "executionInfo": {
     "elapsed": 38337,
     "status": "ok",
     "timestamp": 1606478121741,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "rOiYTY4DKfvq"
   },
   "outputs": [],
   "source": [
    "sentence = \"Cristiano Ronaldo ist ein großartiger Fußballspieler mit erstaunlichen Fähigkeiten und Talenten.\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1520107,
     "status": "ok",
     "timestamp": 1606479603514,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "ujHJNl4oKfvt",
    "outputId": "34c9c165-5046-4e7d-d6b7-3f0d30572cd3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0 / 50]\n",
      "=> Saving checkpoint\n",
      "[Epoch 1 / 50]\n",
      "=> Saving checkpoint\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(num_epochs):\n",
    "    print(f\"[Epoch {epoch} / {num_epochs}]\")\n",
    "\n",
    "    checkpoint = {\"state_dict\": model.state_dict(), \"optimizer\": optimizer.state_dict()}\n",
    "    save_checkpoint(checkpoint)\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    #translated_sentence = translate_sentence(\n",
    "    #    model, sentence, german, english, device, max_length=50\n",
    "    #)\n",
    "\n",
    "    #print(f\"Translated example sentence: \\n {translated_sentence}\")\n",
    "\n",
    "    model.train()\n",
    "\n",
    "    for batch_idx, batch in enumerate(train_iterator):\n",
    "        # Get input and targets and get to cuda\n",
    "        inp_data = batch.src.to(device)\n",
    "        target = batch.trg.to(device)\n",
    "\n",
    "        # Forward prop\n",
    "        output = model(inp_data, target)\n",
    "\n",
    "        # Output is of shape (trg_len, batch_size, output_dim) but Cross Entropy Loss\n",
    "        # doesn't take input in that form. For example if we have MNIST we want to have\n",
    "        # output to be: (N, 10) and targets just (N). Here we can view it in a similar\n",
    "        # way that we have output_words * batch_size that we want to send in into\n",
    "        # our cost function, so we need to do some reshapin. While we're at it\n",
    "        # Let's also remove the start token while we're at it\n",
    "        output = output[1:].reshape(-1, output.shape[2])\n",
    "        target = target[1:].reshape(-1)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        # Back prop\n",
    "        loss.backward()\n",
    "\n",
    "        # Clip to avoid exploding gradient issues, makes sure grads are\n",
    "        # within a healthy range\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\n",
    "\n",
    "        # Gradient descent step\n",
    "        optimizer.step()\n",
    "\n",
    "        # Plot to tensorboard\n",
    "        writer.add_scalar(\"Training loss\", loss, global_step=step)\n",
    "        step += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "executionInfo": {
     "elapsed": 1520107,
     "status": "ok",
     "timestamp": 1606479603517,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "Ir5lIeOwKfvv"
   },
   "outputs": [],
   "source": [
    "def bleu(data, model, german, english, device):\n",
    "    targets = []\n",
    "    outputs = []\n",
    "\n",
    "    for example in data:\n",
    "        src = vars(example)[\"src\"]\n",
    "        trg = vars(example)[\"trg\"]\n",
    "\n",
    "        prediction = translate_sentence(model, src, german, english, device)\n",
    "        prediction = prediction[:-1]  # remove <eos> token\n",
    "\n",
    "        targets.append([trg])\n",
    "        outputs.append(prediction)\n",
    "\n",
    "    return bleu_score(outputs, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "executionInfo": {
     "elapsed": 1520106,
     "status": "ok",
     "timestamp": 1606479603518,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "hvYMk4hYKfvy"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import spacy\n",
    "from torchtext.data.metrics import bleu_score\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1628226,
     "status": "ok",
     "timestamp": 1606479711641,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "f-bW5z8oKfv0",
    "outputId": "d058542a-b11a-4c7a-abd2-23211039c29f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bleu score 16.31\n"
     ]
    }
   ],
   "source": [
    "\n",
    "score = bleu(test_data[1:100], model, german, english, device)\n",
    "print(f\"Bleu score {score*100:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "executionInfo": {
     "elapsed": 1628224,
     "status": "ok",
     "timestamp": 1606479711642,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "-PN1sqf5Kfv2"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Softplus_Seq2Seq-MachineTranslation.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "kan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
