{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "7F-sOZwhKfun"
   },
   "source": [
    "Hey Guys , I have made a Tutorial for Seq2Seq Machine Translation from scratch. Hope you like it. Upvote :)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 218
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 4837,
     "status": "ok",
     "timestamp": 1600772202373,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "ykbPZi7zKfuo",
    "outputId": "d3bc8dcf-70b7-4230-d0ab-3c4491577c8e"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\anaconda_download\\envs\\kan\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: http://pypi.douban.com/simple/\n",
      "Requirement already satisfied: torchtext==0.6.0 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (0.6.0)\n",
      "Requirement already satisfied: tqdm in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torchtext==0.6.0) (4.66.5)\n",
      "Requirement already satisfied: requests in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torchtext==0.6.0) (2.32.3)\n",
      "Requirement already satisfied: torch in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torchtext==0.6.0) (1.8.0)\n",
      "Requirement already satisfied: numpy in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torchtext==0.6.0) (1.21.6)\n",
      "Requirement already satisfied: six in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torchtext==0.6.0) (1.16.0)\n",
      "Requirement already satisfied: sentencepiece in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torchtext==0.6.0) (0.2.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from requests->torchtext==0.6.0) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from requests->torchtext==0.6.0) (3.8)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from requests->torchtext==0.6.0) (2.2.2)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from requests->torchtext==0.6.0) (2024.8.30)\n",
      "Requirement already satisfied: typing-extensions in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from torch->torchtext==0.6.0) (4.12.2)\n",
      "Requirement already satisfied: colorama in d:\\anaconda_download\\envs\\kan\\lib\\site-packages (from tqdm->torchtext==0.6.0) (0.4.6)\n"
     ]
    }
   ],
   "source": [
    "# Import Libraries\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchtext.datasets import Multi30k #German to English dataset\n",
    "from torchtext.data import Field, BucketIterator\n",
    "import numpy as np\n",
    "import spacy\n",
    "import random\n",
    "from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard\n",
    "import torch\n",
    "import spacy\n",
    "!pip install torchtext==0.6.0\n",
    "import torchtext.data\n",
    "from torchtext.data.metrics import bleu_score\n",
    "import sys"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "y1fGgVBDKfus"
   },
   "source": [
    "![image.png](attachment:image.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "_kg_hide-output": true,
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 642
    },
    "colab_type": "code",
    "collapsed": true,
    "executionInfo": {
     "elapsed": 15900,
     "status": "ok",
     "timestamp": 1600772213444,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "-yeVJvpPKfus",
    "outputId": "5da9e427-a940-459a-ec9e-5c412970551a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: http://pypi.douban.com/simple/\n",
      "Collecting de-core-news-sm==3.8.0\n",
      "  Using cached https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.8.0/de_core_news_sm-3.8.0-py3-none-any.whl (14.6 MB)\n",
      "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
      "You can now load the package via spacy.load('de_core_news_sm')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: The repository located at pypi.douban.com is not a trusted or secure host and is being ignored. If this repository is available via HTTPS we recommend you use HTTPS instead, otherwise you may silence this warning and allow it anyway with '--trusted-host pypi.douban.com'.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: http://pypi.douban.com/simple/\n",
      "Collecting en-core-web-sm==3.8.0\n",
      "  Using cached https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)\n",
      "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
      "You can now load the package via spacy.load('en_core_web_sm')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: The repository located at pypi.douban.com is not a trusted or secure host and is being ignored. If this repository is available via HTTPS we recommend you use HTTPS instead, otherwise you may silence this warning and allow it anyway with '--trusted-host pypi.douban.com'.\n"
     ]
    }
   ],
   "source": [
    "!python -m spacy download de_core_news_sm\n",
    "!python -m spacy download en_core_web_sm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
    "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a",
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 17256,
     "status": "ok",
     "timestamp": 1600772214804,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "vc95bSrWKfuw"
   },
   "outputs": [],
   "source": [
    "# Loading Tokeniser in german and English\n",
    "spacy_ger = spacy.load(\"de_core_news_sm\")\n",
    "spacy_eng = spacy.load(\"en_core_web_sm\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 17254,
     "status": "ok",
     "timestamp": 1600772214806,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "JinX1f6QKfuz"
   },
   "outputs": [],
   "source": [
    "# Tokenization of German Language\n",
    "def tokenize_ger(text):\n",
    "    return [tok.text for tok in spacy_ger.tokenizer(text)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 17250,
     "status": "ok",
     "timestamp": 1600772214807,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "Avd2uijgKfu2"
   },
   "outputs": [],
   "source": [
    "# Tokenization of English Language\n",
    "\n",
    "def tokenize_eng(text):\n",
    "    return [tok.text for tok in spacy_eng.tokenizer(text)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-zP-N5ToKfu6"
   },
   "source": [
    "## Preprocessing of Text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 17246,
     "status": "ok",
     "timestamp": 1600772214807,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "o8t7cvVcKfu7"
   },
   "outputs": [],
   "source": [
    "# Applyling Tokenization , lowercase and special Tokens for preprocessing\n",
    "german = Field(tokenize = tokenize_ger,lower = True,init_token = '<sos>',eos_token = '<eos>')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 17243,
     "status": "ok",
     "timestamp": 1600772214808,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "9PkvjAIqKfu-"
   },
   "outputs": [],
   "source": [
    "english = Field(tokenize = tokenize_eng,lower = True,init_token = '<sos>',eos_token = '<eos>')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "_kg_hide-output": false,
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 118
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 34389,
     "status": "ok",
     "timestamp": 1600772231959,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "s-1M5hXkKfvB",
    "outputId": "7c78a274-0ac2-4ff6-ec3e-309a6ba28aba"
   },
   "outputs": [],
   "source": [
    "# Dwonloading Dataset and storing them\n",
    "train_data, valid_data, test_data = Multi30k.splits(\n",
    "    exts=(\".de\", \".en\"), fields=(german, english)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 34387,
     "status": "ok",
     "timestamp": 1600772231961,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "fJF0wjBmKfvE"
   },
   "outputs": [],
   "source": [
    "# Creating vocabulary in each language\n",
    "german.build_vocab(train_data,max_size = 10000,min_freq = 2)\n",
    "english.build_vocab(train_data,max_size = 10000,min_freq = 2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CauchyActivation(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(CauchyActivation, self).__init__()\n",
    "        self.lambda1 = nn.Parameter(torch.empty(1, dtype=torch.float32)) \n",
    "        self.lambda2 = nn.Parameter(torch.empty(1, dtype=torch.float32)) \n",
    "        self.d = nn.Parameter(torch.empty(1, dtype=torch.float32))    \n",
    "\n",
    "        init.uniform_(self.lambda1, -0.2, 0.2)  \n",
    "        init.uniform_(self.lambda2, -1.0, 1.0) \n",
    "        init.uniform_(self.d, -1, 1.0)         \n",
    "\n",
    "    def forward(self, x):\n",
    "        denominator = x**2 + self.d**2\n",
    "        return (self.lambda1 * x / denominator) + (self.lambda2 / denominator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 34383,
     "status": "ok",
     "timestamp": 1600772231961,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "96Sm0iJJKfvG"
   },
   "outputs": [],
   "source": [
    "\n",
    "# Defining the Encoder part of the model\n",
    "class Encoder(nn.Module):\n",
    "    \n",
    "    def __init__(self, input_size, embedding_size, hidden_size, num_layers, p):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.dropout = nn.Dropout(p)\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        self.embedding = nn.Embedding(input_size, embedding_size)\n",
    "        self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=p)\n",
    "        self.cauchy_activation = CauchyActivation()\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x shape: (seq_length, N) where N is batch size\n",
    "\n",
    "        embedding = self.dropout(self.cauchy_activation(self.embedding(x)))\n",
    "        # embedding shape: (seq_length, N, embedding_size)\n",
    "\n",
    "        outputs, (hidden, cell) = self.rnn(embedding)\n",
    "        # outputs shape: (seq_length, N, hidden_size)\n",
    "\n",
    "        return hidden, cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 34380,
     "status": "ok",
     "timestamp": 1600772231962,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "aiyK-YsTKfvJ"
   },
   "outputs": [],
   "source": [
    "# Defining the Decoder part\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(\n",
    "        self, input_size, embedding_size, hidden_size, output_size, num_layers, p):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.dropout = nn.Dropout(p)\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        self.embedding = nn.Embedding(input_size, embedding_size)\n",
    "        self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=p)\n",
    "        self.fc = nn.Linear(hidden_size, output_size)\n",
    "        self.cauchy_activation = CauchyActivation()\n",
    "\n",
    "    def forward(self, x, hidden, cell):\n",
    "        # x shape: (N) where N is for batch size, we want it to be (1, N), seq_length\n",
    "        # is 1 here because we are sending in a single word and not a sentence\n",
    "        x = x.unsqueeze(0)\n",
    "\n",
    "        embedding = self.dropout(self.cauchy_activation(self.embedding(x)))\n",
    "        # embedding shape: (1, N, embedding_size)\n",
    "\n",
    "        outputs, (hidden, cell) = self.rnn(embedding, (hidden, cell))\n",
    "        # outputs shape: (1, N, hidden_size)\n",
    "\n",
    "        predictions = self.fc(outputs)\n",
    "\n",
    "        # predictions shape: (1, N, length_target_vocabulary) to send it to\n",
    "        # loss function we want it to be (N, length_target_vocabulary) so we're\n",
    "        # just gonna remove the first dim\n",
    "        predictions = predictions.squeeze(0)\n",
    "\n",
    "        return predictions, hidden, cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 34376,
     "status": "ok",
     "timestamp": 1600772231962,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "mdJJXlGyKfvM"
   },
   "outputs": [],
   "source": [
    "# Defining the complete model\n",
    "class Seq2Seq(nn.Module):\n",
    "    def __init__(self, encoder, decoder):\n",
    "        super(Seq2Seq, self).__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "\n",
    "    def forward(self, source, target, teacher_force_ratio=0.5):\n",
    "        batch_size = source.shape[1]\n",
    "        target_len = target.shape[0]\n",
    "        target_vocab_size = len(english.vocab)\n",
    "\n",
    "        outputs = torch.zeros(target_len, batch_size, target_vocab_size).to(device)\n",
    "\n",
    "        hidden, cell = self.encoder(source)\n",
    "\n",
    "        # Grab the first input to the Decoder which will be <SOS> token\n",
    "        x = target[0]\n",
    "\n",
    "        for t in range(1, target_len):\n",
    "            # Use previous hidden, cell as context from encoder at start\n",
    "            output, hidden, cell = self.decoder(x, hidden, cell)\n",
    "\n",
    "            # Store next output prediction\n",
    "            outputs[t] = output\n",
    "\n",
    "            # Get the best word the Decoder predicted (index in the vocabulary)\n",
    "            best_guess = output.argmax(1)\n",
    "\n",
    "            # With probability of teacher_force_ratio we take the actual next word\n",
    "            # otherwise we take the word that the Decoder predicted it to be.\n",
    "            # Teacher Forcing is used so that the model gets used to seeing\n",
    "            # similar inputs at training and testing time, if teacher forcing is 1\n",
    "            # then inputs at test time might be completely different than what the\n",
    "            # network is used to. This was a long comment.\n",
    "            x = target[t] if random.random() < teacher_force_ratio else best_guess\n",
    "\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 34371,
     "status": "ok",
     "timestamp": 1600772231963,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "5ziLyXFxKfvP"
   },
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "num_epochs = 50\n",
    "# num_epochs = 5\n",
    "learning_rate = 0.001\n",
    "batch_size = 256\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 34790,
     "status": "ok",
     "timestamp": 1600772232388,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "Vr3ZyNckKfvS"
   },
   "outputs": [],
   "source": [
    "# Model hyperparameters\n",
    "load_model = False\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else 'cpu')\n",
    "input_size_encoder = len(german.vocab)\n",
    "input_size_decoder = len(english.vocab)\n",
    "output_size = len(english.vocab)\n",
    "encoder_embedding_size = 300\n",
    "decoder_embedding_size = 300\n",
    "hidden_size = 1024\n",
    "\n",
    "\n",
    "num_layers = 1\n",
    "enc_dropout = 0.5\n",
    "dec_dropout = 0.5\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 36469,
     "status": "ok",
     "timestamp": 1600772234071,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "6DZ7zj8LKfvU"
   },
   "outputs": [],
   "source": [
    "# Tensorboard to get nice loss plot\n",
    "writer = SummaryWriter(f'runs/Loss_plot')\n",
    "step = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 36464,
     "status": "ok",
     "timestamp": 1600772234072,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "z6g5kRsiKfvX"
   },
   "outputs": [],
   "source": [
    "train_iterator, validation_iterator, test_iterator = BucketIterator.splits(\n",
    "    (train_data, valid_data, test_data),\n",
    "     batch_size = batch_size, sort_within_batch = True, \n",
    "     sort_key = lambda x:len(x.src),\n",
    "     device = device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 70
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 57253,
     "status": "ok",
     "timestamp": 1600772254866,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "lwapcC4mKfva",
    "outputId": "54803bb1-dd6a-4571-db35-2ae7679c97c8"
   },
   "outputs": [],
   "source": [
    "import torch.nn.init as init\n",
    "encoder_net = Encoder(input_size_encoder, \n",
    "                      encoder_embedding_size,\n",
    "                      hidden_size,num_layers, \n",
    "                      enc_dropout).to(device)\n",
    "\n",
    "\n",
    "decoder_net = Decoder(input_size_decoder, \n",
    "                      decoder_embedding_size,\n",
    "                      hidden_size,output_size,num_layers, \n",
    "                      dec_dropout).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 57251,
     "status": "ok",
     "timestamp": 1600772254869,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "Liy7CquOKfvd"
   },
   "outputs": [],
   "source": [
    "model = Seq2Seq(encoder_net, decoder_net).to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 57247,
     "status": "ok",
     "timestamp": 1600772254869,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "7ex-6pH0Kfvf"
   },
   "outputs": [],
   "source": [
    "pad_idx = english.vocab.stoi['<pad>']\n",
    "criterion = nn.CrossEntropyLoss(ignore_index = pad_idx)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 57244,
     "status": "ok",
     "timestamp": 1600772254870,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "f6ioRsy5Kfvh"
   },
   "outputs": [],
   "source": [
    "def translate_sentence(model, sentence, german, english, device, max_length=50):\n",
    "    # print(sentence)\n",
    "\n",
    "    # sys.exit()\n",
    "\n",
    "    # Load german tokenizer\n",
    "    spacy_ger = spacy.load(\"de_core_news_sm\")\n",
    "\n",
    "    # Create tokens using spacy and everything in lower case (which is what our vocab is)\n",
    "    if type(sentence) == str:\n",
    "        tokens = [token.text.lower() for token in spacy_ger(sentence)]\n",
    "    else:\n",
    "        tokens = [token.lower() for token in sentence]\n",
    "\n",
    "    # print(tokens)\n",
    "\n",
    "    # sys.exit()\n",
    "    # Add <SOS> and <EOS> in beginning and end respectively\n",
    "    tokens.insert(0, german.init_token)\n",
    "    tokens.append(german.eos_token)\n",
    "\n",
    "    # Go through each german token and convert to an index\n",
    "    text_to_indices = [german.vocab.stoi[token] for token in tokens]\n",
    "\n",
    "    # Convert to Tensor\n",
    "    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)\n",
    "\n",
    "    # Build encoder hidden, cell state\n",
    "    with torch.no_grad():\n",
    "        hidden, cell = model.encoder(sentence_tensor)\n",
    "\n",
    "    outputs = [english.vocab.stoi[\"<sos>\"]]\n",
    "\n",
    "    for _ in range(max_length):\n",
    "        previous_word = torch.LongTensor([outputs[-1]]).to(device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            output, hidden, cell = model.decoder(previous_word, hidden, cell)\n",
    "            best_guess = output.argmax(1).item()\n",
    "\n",
    "        outputs.append(best_guess)\n",
    "\n",
    "        # Model predicts it's the end of the sentence\n",
    "        if output.argmax(1).item() == english.vocab.stoi[\"<eos>\"]:\n",
    "            break\n",
    "\n",
    "    translated_sentence = [english.vocab.itos[idx] for idx in outputs]\n",
    "\n",
    "    # remove start token\n",
    "    return translated_sentence[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 57239,
     "status": "ok",
     "timestamp": 1600772254870,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "rm9yIYvIKfvj"
   },
   "outputs": [],
   "source": [
    "def save_checkpoint(state, filename=\"my_checkpoint.pth.tar\"):\n",
    "    print(\"=> Saving checkpoint\")\n",
    "    torch.save(state, filename)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 57236,
     "status": "ok",
     "timestamp": 1600772254871,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "WbNauSlDKfvm"
   },
   "outputs": [],
   "source": [
    "def load_checkpoint(checkpoint, model, optimizer):\n",
    "    print(\"=> Loading checkpoint\")\n",
    "    model.load_state_dict(checkpoint[\"state_dict\"])\n",
    "    optimizer.load_state_dict(checkpoint[\"optimizer\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 57232,
     "status": "ok",
     "timestamp": 1600772254872,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "CiNGFIGGKfvo"
   },
   "outputs": [],
   "source": [
    "if load_model:\n",
    "    load_checkpoint(torch.load(\"my_checkpoint.pth.tar\"), model, optimizer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 57228,
     "status": "ok",
     "timestamp": 1600772254872,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "rOiYTY4DKfvq"
   },
   "outputs": [],
   "source": [
    "sentence = \"Cristiano Ronaldo ist ein großartiger Fußballspieler mit erstaunlichen Fähigkeiten und Talenten.\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2443005,
     "status": "ok",
     "timestamp": 1600774640653,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "ujHJNl4oKfvt",
    "outputId": "5f49c3cd-80d4-4e0e-8ed2-ac33a20f073f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 1 Average Loss: 1.2180\n",
      "[Epoch 1 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 2 Average Loss: 1.1606\n",
      "[Epoch 2 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 3 Average Loss: 1.1486\n",
      "[Epoch 3 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 4 Average Loss: 1.1318\n",
      "[Epoch 4 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 5 Average Loss: 1.0935\n",
      "[Epoch 5 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 6 Average Loss: 1.0503\n",
      "[Epoch 6 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 7 Average Loss: 1.0472\n",
      "[Epoch 7 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 8 Average Loss: 0.9988\n",
      "[Epoch 8 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 9 Average Loss: 0.9964\n",
      "[Epoch 9 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 10 Average Loss: 0.9715\n",
      "[Epoch 10 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 11 Average Loss: 0.9670\n",
      "[Epoch 11 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 12 Average Loss: 0.8921\n",
      "[Epoch 12 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 13 Average Loss: 0.8936\n",
      "[Epoch 13 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 14 Average Loss: 0.8674\n",
      "[Epoch 14 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 15 Average Loss: 0.8695\n",
      "[Epoch 15 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 16 Average Loss: 0.8181\n",
      "[Epoch 16 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 17 Average Loss: 0.8355\n",
      "[Epoch 17 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 18 Average Loss: 0.8049\n",
      "[Epoch 18 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 19 Average Loss: 0.7724\n",
      "[Epoch 19 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 20 Average Loss: 0.7751\n",
      "[Epoch 20 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 21 Average Loss: 0.7626\n",
      "[Epoch 21 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 22 Average Loss: 0.7225\n",
      "[Epoch 22 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 23 Average Loss: 0.7171\n",
      "[Epoch 23 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 24 Average Loss: 0.6820\n",
      "[Epoch 24 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 25 Average Loss: 0.6774\n",
      "[Epoch 25 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 26 Average Loss: 0.6585\n",
      "[Epoch 26 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 27 Average Loss: 0.6298\n",
      "[Epoch 27 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 28 Average Loss: 0.6285\n",
      "[Epoch 28 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 29 Average Loss: 0.6019\n",
      "[Epoch 29 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 30 Average Loss: 0.6217\n",
      "[Epoch 30 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 31 Average Loss: 0.5736\n",
      "[Epoch 31 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 32 Average Loss: 0.5866\n",
      "[Epoch 32 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 33 Average Loss: 0.5509\n",
      "[Epoch 33 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 34 Average Loss: 0.5238\n",
      "[Epoch 34 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 35 Average Loss: 0.5194\n",
      "[Epoch 35 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 36 Average Loss: 0.4973\n",
      "[Epoch 36 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 37 Average Loss: 0.4914\n",
      "[Epoch 37 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 38 Average Loss: 0.4841\n",
      "[Epoch 38 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 39 Average Loss: 0.4762\n",
      "[Epoch 39 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 40 Average Loss: 0.4673\n",
      "[Epoch 40 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 41 Average Loss: 0.4477\n",
      "[Epoch 41 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 42 Average Loss: 0.4217\n",
      "[Epoch 42 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 43 Average Loss: 0.4397\n",
      "[Epoch 43 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 44 Average Loss: 0.4206\n",
      "[Epoch 44 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 45 Average Loss: 0.4143\n",
      "[Epoch 45 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 46 Average Loss: 0.3762\n",
      "[Epoch 46 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 47 Average Loss: 0.3790\n",
      "[Epoch 47 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 48 Average Loss: 0.3916\n",
      "[Epoch 48 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 49 Average Loss: 0.3686\n",
      "[Epoch 49 / 50]\n",
      "=> Saving checkpoint\n",
      "Epoch 50 Average Loss: 0.3652\n"
     ]
    }
   ],
   "source": [
    "# Train for num_epochs epochs. The checkpoint is written at the end of each\n",
    "# epoch so the saved file always holds the most recently trained weights\n",
    "# (saving before the epoch would never persist the final epoch's updates).\n",
    "for epoch in range(num_epochs):\n",
    "    # 1-based epoch number, matching the average-loss line printed below\n",
    "    print(f\"[Epoch {epoch + 1} / {num_epochs}]\")\n",
    "\n",
    "    model.train()\n",
    "    epoch_loss = 0\n",
    "    for batch in train_iterator:\n",
    "        # Get input and targets and move them to the training device\n",
    "        inp_data = batch.src.to(device)\n",
    "        target = batch.trg.to(device)\n",
    "\n",
    "        # Forward prop\n",
    "        output = model(inp_data, target)\n",
    "\n",
    "        # Output is of shape (trg_len, batch_size, output_dim) but Cross Entropy Loss\n",
    "        # doesn't take input in that form. For example if we have MNIST we want to have\n",
    "        # output to be: (N, 10) and targets just (N). Here we can view it in a similar\n",
    "        # way: we have output_words * batch_size predictions to send into our cost\n",
    "        # function, so we need to do some reshaping. While we're at it, let's also\n",
    "        # drop the first time step (the start token position).\n",
    "        output = output[1:].reshape(-1, output.shape[2])\n",
    "        target = target[1:].reshape(-1)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss = criterion(output, target)\n",
    "        # Read the scalar once; it feeds both the epoch average and the log\n",
    "        batch_loss = loss.item()\n",
    "        epoch_loss += batch_loss\n",
    "\n",
    "        # Back prop\n",
    "        loss.backward()\n",
    "\n",
    "        # Clip to avoid exploding gradient issues, makes sure grads are\n",
    "        # within a healthy range\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\n",
    "\n",
    "        # Gradient descent step\n",
    "        optimizer.step()\n",
    "\n",
    "        # Plot to tensorboard\n",
    "        writer.add_scalar(\"Training loss\", batch_loss, global_step=step)\n",
    "        step += 1\n",
    "\n",
    "    avg_loss = epoch_loss / len(train_iterator)\n",
    "    print(f\"Epoch {epoch + 1} Average Loss: {avg_loss:.4f}\")\n",
    "\n",
    "    # Persist model and optimizer state after this epoch's updates\n",
    "    checkpoint = {\"state_dict\": model.state_dict(), \"optimizer\": optimizer.state_dict()}\n",
    "    save_checkpoint(checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2443002,
     "status": "ok",
     "timestamp": 1600774640657,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "Ir5lIeOwKfvv"
   },
   "outputs": [],
   "source": [
    "def bleu(data, model, german, english, device):\n",
    "    \"\"\"Score the model's translations of `data` with `bleu_score`.\n",
    "\n",
    "    Each example's `src` is translated with `translate_sentence` and paired\n",
    "    with its `trg` as the single reference for that example.\n",
    "\n",
    "    Args:\n",
    "        data: iterable of examples exposing `src` and `trg` attributes.\n",
    "        model: the translation model. It is put in eval mode while translating\n",
    "            and its previous train/eval mode is restored afterwards.\n",
    "        german: forwarded unchanged to `translate_sentence`.\n",
    "        english: forwarded unchanged to `translate_sentence`.\n",
    "        device: forwarded unchanged to `translate_sentence`.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `bleu_score(outputs, targets)` returns.\n",
    "    \"\"\"\n",
    "    targets = []\n",
    "    outputs = []\n",
    "\n",
    "    # Score in eval mode: this is typically called right after training, when\n",
    "    # the model is still in train mode and train-only layers (e.g. dropout, if\n",
    "    # the model has any) would otherwise perturb the translations.\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        for example in data:\n",
    "            src = vars(example)[\"src\"]\n",
    "            trg = vars(example)[\"trg\"]\n",
    "\n",
    "            prediction = translate_sentence(model, src, german, english, device)\n",
    "            # Drop the trailing <eos> only when it is really there; slicing\n",
    "            # unconditionally would remove a real token from a translation\n",
    "            # that stopped without emitting one.\n",
    "            # NOTE(review): assumes the end token is the literal \"<eos>\" -- confirm.\n",
    "            if prediction and prediction[-1] == \"<eos>\":\n",
    "                prediction = prediction[:-1]\n",
    "\n",
    "            targets.append([trg])\n",
    "            outputs.append(prediction)\n",
    "    finally:\n",
    "        # Leave the model in the mode the caller had it in\n",
    "        model.train(was_training)\n",
    "\n",
    "    return bleu_score(outputs, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2442999,
     "status": "ok",
     "timestamp": 1600774640658,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "hvYMk4hYKfvy"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import spacy\n",
    "from torchtext.data.metrics import bleu_score\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2569844,
     "status": "ok",
     "timestamp": 1600774767507,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "f-bW5z8oKfv0",
    "outputId": "3c1a0ebf-6ff8-4383-f9aa-af0c25b2ceab"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bleu score 20.96\n"
     ]
    }
   ],
   "source": [
    "# BLEU on a slice of the test set, printed as a percentage.\n",
    "# NOTE(review): the slice starts at index 1, so test_data[0] is not scored --\n",
    "# confirm this is intended.\n",
    "eval_examples = test_data[1:100]\n",
    "score = bleu(eval_examples, model, german, english, device)\n",
    "print(f\"Bleu score {score*100:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2569843,
     "status": "ok",
     "timestamp": 1600774767510,
     "user": {
      "displayName": "Shiv Ram Dubey",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gjtd_B-VgS3jkucANDIMEI21-5JajxJBYL4BagKOA=s64",
      "userId": "14553895138990175535"
     },
     "user_tz": -330
    },
    "id": "-PN1sqf5Kfv2"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import spacy\n",
    "from torchtext.data.metrics import bleu_score\n",
    "import sys\n",
    "\n",
    "# BLEU score computation (redefines the `bleu` function from an earlier cell)\n",
    "def bleu(data, model, german, english, device):\n",
    "    \"\"\"Score the model's translations of `data` with `bleu_score`.\n",
    "\n",
    "    Each example's `src` is translated with `translate_sentence` and paired\n",
    "    with its `trg` as the single reference for that example.\n",
    "\n",
    "    Args:\n",
    "        data: iterable of examples exposing `src` and `trg` attributes.\n",
    "        model: the translation model. It is put in eval mode while translating\n",
    "            and its previous train/eval mode is restored afterwards.\n",
    "        german: forwarded unchanged to `translate_sentence`.\n",
    "        english: forwarded unchanged to `translate_sentence`.\n",
    "        device: forwarded unchanged to `translate_sentence`.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `bleu_score(outputs, targets)` returns.\n",
    "    \"\"\"\n",
    "    targets = []\n",
    "    outputs = []\n",
    "\n",
    "    # Score in eval mode: the training loop leaves the model in train mode, and\n",
    "    # train-only layers (e.g. dropout, if the model has any) would otherwise\n",
    "    # perturb the translations.\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        for example in data:\n",
    "            src = vars(example)[\"src\"]\n",
    "            trg = vars(example)[\"trg\"]\n",
    "\n",
    "            # Translate the sentence\n",
    "            prediction = translate_sentence(model, src, german, english, device)\n",
    "            # Drop the trailing <eos> only when it is really there; slicing\n",
    "            # unconditionally would remove a real token from a translation\n",
    "            # that stopped without emitting one.\n",
    "            # NOTE(review): assumes the end token is the literal \"<eos>\" -- confirm.\n",
    "            if prediction and prediction[-1] == \"<eos>\":\n",
    "                prediction = prediction[:-1]\n",
    "\n",
    "            targets.append([trg])\n",
    "            outputs.append(prediction)\n",
    "    finally:\n",
    "        # Leave the model in the mode the caller had it in\n",
    "        model.train(was_training)\n",
    "\n",
    "    return bleu_score(outputs, targets)\n",
    "\n",
    "\n",
    "def save_checkpoint(state, filename=\"my_checkpoint.pth.tar\"):\n",
    "    print(\"=> Saving checkpoint\")\n",
    "    torch.save(state, filename)\n",
    "\n",
    "\n",
    "num_epochs = 10  # number of epochs run by this cell\n",
    "loss_results = []  # average training loss, one entry per epoch\n",
    "bleu_results = []  # BLEU score in percent, one entry per epoch\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    print(f\"[Epoch {epoch + 1} / {num_epochs}]\")\n",
    "\n",
    "    model.train()\n",
    "    epoch_loss = 0\n",
    "    for batch in train_iterator:\n",
    "        # Move the source and target batches to the training device\n",
    "        inp_data = batch.src.to(device)\n",
    "        target = batch.trg.to(device)\n",
    "\n",
    "        # Forward pass\n",
    "        output = model(inp_data, target)\n",
    "\n",
    "        # Drop the first time step and flatten so the loss receives\n",
    "        # (N, output_dim) scores and (N,) targets\n",
    "        output = output[1:].reshape(-1, output.shape[2])\n",
    "        target = target[1:].reshape(-1)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss = criterion(output, target)\n",
    "        epoch_loss += loss.item()\n",
    "\n",
    "        loss.backward()\n",
    "\n",
    "        # Clip the gradient norm to guard against exploding gradients\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\n",
    "\n",
    "        optimizer.step()\n",
    "\n",
    "    avg_loss = epoch_loss / len(train_iterator)\n",
    "    print(f\"Epoch {epoch + 1} Average Loss: {avg_loss:.4f}\")\n",
    "    loss_results.append(avg_loss)\n",
    "\n",
    "    # Checkpoint after the epoch's updates so the file on disk always holds\n",
    "    # the most recently trained weights (including those of the last epoch)\n",
    "    checkpoint = {\"state_dict\": model.state_dict(), \"optimizer\": optimizer.state_dict()}\n",
    "    save_checkpoint(checkpoint)\n",
    "\n",
    "    # Score in eval mode, then return to train mode as the loop expects\n",
    "    # NOTE(review): the slice starts at index 1, so test_data[0] is not scored -- confirm.\n",
    "    model.eval()\n",
    "    score = bleu(test_data[1:100], model, german, english, device)\n",
    "    model.train()\n",
    "    bleu_results.append(score * 100)\n",
    "    print(f\"Epoch {epoch + 1} BLEU Score: {score * 100:.2f}\")\n",
    "\n",
    "\n",
    "print(\"\\nTraining Complete!\")\n",
    "print(\"Loss Results:\", loss_results)\n",
    "print(\"BLEU Results:\", bleu_results)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "LReLU_Seq2Seq-MachineTranslation.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "kan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
