{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/sean/.conda/envs/quantumtorch/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchquantum as tq\n",
    "import torchquantum.functional as tqf\n",
    "device = torch.device(\"cuda:3\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PrimEnc(tq.QuantumModule):\n",
    "    '''Encode one feature value per batch element onto a single qubit (`wire`).\n",
    "\n",
    "    Gate sequence applied on that wire: H, then RY(arctan(x)), then RY(arctan(x**2)).\n",
    "    '''\n",
    "\n",
    "    def __init__(self, wire) -> None:\n",
    "        super().__init__()\n",
    "        self.wire = wire\n",
    "        self.encoder = tq.QuantumModuleList([tq.H(), tq.RY(), tq.RY()])\n",
    "\n",
    "    def forward(self, qdevice, x):\n",
    "        '''Apply the encoding gates to `qdevice` in place.\n",
    "\n",
    "        `x` supplies the RY angles; QLSTM passes a single column x[:, i], so presumably\n",
    "        one value per batch element (TODO confirm for other callers).\n",
    "        '''\n",
    "        hadamard, ry_linear, ry_square = self.encoder\n",
    "        hadamard(qdevice, wires=self.wire)\n",
    "        ry_linear(qdevice, wires=self.wire, params=torch.arctan(x))\n",
    "        ry_square(qdevice, wires=self.wire, params=torch.arctan(x ** 2))\n",
    "\n",
    "class QLSTM(nn.Module):\n",
    "    '''LSTM whose four gates are 4-qubit torchquantum circuits.\n",
    "\n",
    "    At every time step (h_{t-1}, x_t) is concatenated, projected to n_qubits features by\n",
    "    clayer_in, fed through one variational circuit per gate (forget / input / update / output),\n",
    "    and each circuit readout is mapped to hidden_size by the shared clayer_out before the\n",
    "    usual LSTM cell update.\n",
    "\n",
    "    The circuits are hard-wired to 4 qubits, so n_qubits must be 4. n_qlayers, backend,\n",
    "    return_sequences and return_state are stored for API compatibility but are not used.\n",
    "    '''\n",
    "\n",
    "    class _QGate(tq.QuantumModule):\n",
    "        '''Gate circuit: per-wire PrimEnc encoding, fixed CNOT layer, trainable U3 per wire, PauliZ readout.'''\n",
    "\n",
    "        N_WIRES = 4\n",
    "        # (control, target) pairs of the fixed entangling layer, applied in this order.\n",
    "        CNOT_PAIRS = ((0, 1), (1, 2), (2, 3), (3, 0), (0, 2), (1, 3), (2, 0), (3, 1))\n",
    "\n",
    "        def __init__(self):\n",
    "            super().__init__()\n",
    "            self.n_wires = self.N_WIRES\n",
    "            self.encoder = tq.QuantumModuleList([PrimEnc(wire) for wire in range(self.n_wires)])\n",
    "            self.rotate = tq.QuantumModuleList([tq.CNOT(wires=[ctrl, tgt]) for ctrl, tgt in self.CNOT_PAIRS])\n",
    "            # Attribute names rx0..rx3 are kept so previously saved state_dicts still load.\n",
    "            self.rx0 = tq.U3(has_params=True, trainable=True)\n",
    "            self.rx1 = tq.U3(has_params=True, trainable=True)\n",
    "            self.rx2 = tq.U3(has_params=True, trainable=True)\n",
    "            self.rx3 = tq.U3(has_params=True, trainable=True)\n",
    "            self.measure = tq.MeasureAll(tq.PauliZ)\n",
    "\n",
    "        def forward(self, x, fs=None):\n",
    "            '''Run the circuit on a batch and return the tq.MeasureAll(tq.PauliZ) readout.\n",
    "\n",
    "            Args:\n",
    "                x: tensor indexed as x[:, wire]; column `wire` is encoded on qubit `wire`.\n",
    "                fs: optional length-4 sequence; a non-None entry fs[k] is applied as\n",
    "                    fs[k](qdev, [k]) right after the U3 on wire k (fault injection).\n",
    "            '''\n",
    "            if fs is None:\n",
    "                fs = [None] * self.n_wires\n",
    "            qdev = tq.QuantumDevice(n_wires=self.n_wires, bsz=x.shape[0], device=x.device)\n",
    "            for wire, encode in enumerate(self.encoder):\n",
    "                encode(qdev, x[:, wire])\n",
    "            for cnot in self.rotate:\n",
    "                cnot(qdev)\n",
    "            for wire, u3 in enumerate((self.rx0, self.rx1, self.rx2, self.rx3)):\n",
    "                u3(qdev, wires=wire)\n",
    "                if fs[wire] is not None:\n",
    "                    fs[wire](qdev, [wire])\n",
    "            return self.measure(qdev)\n",
    "\n",
    "    class QLayer_forget(_QGate):\n",
    "        '''Circuit for the forget gate f_t.'''\n",
    "\n",
    "    class QLayer_input(_QGate):\n",
    "        '''Circuit for the input gate i_t.'''\n",
    "\n",
    "    class QLayer_update(_QGate):\n",
    "        '''Circuit for the candidate (update) gate g_t.'''\n",
    "\n",
    "    class QLayer_output(_QGate):\n",
    "        '''Circuit for the output gate o_t.'''\n",
    "\n",
    "    def __init__(self,\n",
    "                 input_size,\n",
    "                 hidden_size,\n",
    "                 n_qubits=4,\n",
    "                 n_qlayers=1,\n",
    "                 batch_first=False,\n",
    "                 return_sequences=False,\n",
    "                 return_state=False,\n",
    "                 backend=\"default.qubit\"):\n",
    "        super(QLSTM, self).__init__()\n",
    "        if n_qubits != self._QGate.N_WIRES:\n",
    "            raise ValueError(f'QLSTM circuits are built for {self._QGate.N_WIRES} qubits, got n_qubits={n_qubits}')\n",
    "        self.n_inputs = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.concat_size = self.n_inputs + self.hidden_size\n",
    "        self.n_qubits = n_qubits\n",
    "        self.n_qlayers = n_qlayers  # stored only; a single layer is always used\n",
    "        self.backend = backend  # stored only; circuits always run on tq.QuantumDevice\n",
    "\n",
    "        self.batch_first = batch_first\n",
    "        self.return_sequences = return_sequences\n",
    "        self.return_state = return_state\n",
    "\n",
    "        self.clayer_in = torch.nn.Linear(self.concat_size, n_qubits)\n",
    "        self.forget = self.QLayer_forget()\n",
    "        self.input = self.QLayer_input()\n",
    "        self.update = self.QLayer_update()\n",
    "        self.output = self.QLayer_output()\n",
    "        # One readout layer shared by all four gates.\n",
    "        self.clayer_out = torch.nn.Linear(self.n_qubits, self.hidden_size)\n",
    "\n",
    "    def forward(self, x, init_states=None, fs=None):\n",
    "        '''Run the QLSTM over a whole sequence.\n",
    "\n",
    "        Args:\n",
    "            x: (seq_length, batch_size, feature_size), or\n",
    "               (batch_size, seq_length, feature_size) when batch_first=True.\n",
    "            init_states: optional (h_0, c_0); only index [0] of each is used.\n",
    "            fs: optional 4x4 nested sequence of fault operators; fs[g] goes to gate g in\n",
    "                the order forget, input, update, output. Defaults to no faults.\n",
    "\n",
    "        Returns:\n",
    "            (hidden_seq, (h_t, c_t)); hidden_seq is (batch_size, seq_length, hidden_size)\n",
    "            regardless of batch_first, h_t and c_t are (batch_size, hidden_size).\n",
    "        '''\n",
    "        if fs is None:\n",
    "            fs = [[None] * 4 for _ in range(4)]\n",
    "        if self.batch_first:\n",
    "            # Time is iterated over dim 0 below.\n",
    "            x = x.transpose(0, 1)\n",
    "        seq_length, batch_size, _ = x.size()\n",
    "\n",
    "        if init_states is None:\n",
    "            # Allocate on the input's device/dtype so the model works wherever it is placed.\n",
    "            h_t = torch.zeros(batch_size, self.hidden_size, device=x.device, dtype=x.dtype)  # hidden state (output)\n",
    "            c_t = torch.zeros(batch_size, self.hidden_size, device=x.device, dtype=x.dtype)  # cell state\n",
    "        else:\n",
    "            # Stacked RNNs are not supported: take only the first entry of each state.\n",
    "            h_t, c_t = init_states\n",
    "            h_t = h_t[0]\n",
    "            c_t = c_t[0]\n",
    "\n",
    "        hidden_seq = []\n",
    "        for t in range(seq_length):\n",
    "            x_t = x[t]\n",
    "            v_t = torch.cat((h_t, x_t), dim=1)\n",
    "            y_t = self.clayer_in(v_t)  # match qubit dimension\n",
    "\n",
    "            f_t = torch.sigmoid(self.clayer_out(self.forget(y_t, fs[0])))  # forget block\n",
    "            i_t = torch.sigmoid(self.clayer_out(self.input(y_t, fs[1])))  # input block\n",
    "            g_t = torch.tanh(self.clayer_out(self.update(y_t, fs[2])))  # update block\n",
    "            o_t = torch.sigmoid(self.clayer_out(self.output(y_t, fs[3])))  # output block\n",
    "\n",
    "            c_t = (f_t * c_t) + (i_t * g_t)\n",
    "            h_t = o_t * torch.tanh(c_t)\n",
    "            hidden_seq.append(h_t)\n",
    "        # Stack along dim 1 -> (batch_size, seq_length, hidden_size).\n",
    "        hidden_seq = torch.stack(hidden_seq, dim=1)\n",
    "        return hidden_seq, (h_t, c_t)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part-of-speech (POS) tagging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "tag_to_ix = {\"DET\": 0, \"NN\": 1, \"V\": 2, \"ADJ\": 3, \"ADV\": 4, \"PRON\": 5}  # Assign each tag with a unique index\n",
    "ix_to_tag = {i:k for k,i in tag_to_ix.items()}\n",
    "\n",
    "def prepare_sequence(seq, to_ix):\n",
    "    idxs = [to_ix[w] for w in seq]\n",
    "    return torch.tensor(idxs, dtype=torch.long).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vocabulary: {'I': 0, 'love': 1, 'the': 2, 'beautiful': 3, 'garden': 4, 'She': 5, 'quickly': 6, 'eats': 7, 'delicious': 8, 'cake': 9, 'The': 10, 'big': 11, 'dog': 12, 'chases': 13, 'small': 14, 'cat': 15, 'My': 16, 'friend': 17, 'sings': 18, 'very': 19, 'well': 20, 'You': 21, 'should': 22, 'always': 23, 'try': 24, 'your': 25, 'best': 26, 'We': 27, 'enjoy': 28, 'sunny': 29, 'weather': 30, 'today': 31, 'His': 32, 'new': 33, 'car': 34, 'runs': 35, 'smoothly': 36, 'They': 37, 'will': 38, 'visit': 39, 'old': 40, 'museum': 41, 'soon': 42, 'dances': 43, 'gracefully': 44, 'tonight': 45, 'read': 46, 'interesting': 47, 'books': 48, 'often': 49, 'He': 50, 'drives': 51, 'fast': 52, 'cars': 53, 'every': 54, 'weekend': 55, 'Their': 56, 'little': 57, 'child': 58, 'sleeps': 59, 'quietly': 60, 'write': 61, 'good': 62, 'stories': 63, 'sometimes': 64, 'happy': 65, 'children': 66, 'play': 67, 'outside': 68, 'watch': 69, 'funny': 70, 'movies': 71, 'together': 72, 'helps': 73, 'her': 74, 'friends': 75, 'usually': 76, 'drink': 77, 'hot': 78, 'coffee': 79, 'warm': 80, 'sunshine': 81, 'writes': 82, 'probably': 83, 'finish': 84, 'project': 85, 'successfully': 86, 'next': 87, 'week': 88}\n",
      "Entities: {0: 'DET', 1: 'NN', 2: 'V', 3: 'ADJ', 4: 'ADV', 5: 'PRON'}\n"
     ]
    }
   ],
   "source": [
    "training_data = [\n",
    "    # Tags are: DET - determiner; NN - noun; V - verb; ADJ - adjective; ADV - adverb; PRON - pronoun\n",
    "    # For example, the word \"The\" is a determiner\n",
    "    (\"I love the beautiful garden\".split(), [\"PRON\", \"V\", \"DET\", \"ADJ\", \"NN\"]),\n",
    "    (\"She quickly eats the delicious cake\".split(), [\"PRON\", \"ADV\", \"V\", \"DET\", \"ADJ\", \"NN\"]),\n",
    "    (\"The big dog chases the small cat\".split(), [\"DET\", \"ADJ\", \"NN\", \"V\", \"DET\", \"ADJ\", \"NN\"]),\n",
    "    (\"My friend sings very well\".split(), [\"DET\", \"NN\", \"V\", \"ADV\", \"ADV\"]),\n",
    "    (\"You should always try your best\".split(), [\"PRON\", \"V\", \"ADV\", \"V\", \"DET\", \"NN\"]),\n",
    "    (\"We enjoy the sunny weather today\".split(), [\"PRON\", \"V\", \"DET\", \"ADJ\", \"NN\", \"ADV\"]),\n",
    "    (\"His new car runs very smoothly\".split(), [\"DET\", \"ADJ\", \"NN\", \"V\", \"ADV\", \"ADV\"]),\n",
    "    (\"They will visit the old museum soon\".split(), [\"PRON\", \"V\", \"V\", \"DET\", \"ADJ\", \"NN\", \"ADV\"]),\n",
    "    (\"She dances gracefully tonight\".split(), [\"PRON\", \"V\", \"ADV\", \"ADV\"]),\n",
    "    (\"I read interesting books often\".split(), [\"PRON\", \"V\", \"ADJ\", \"NN\", \"ADV\"]),\n",
    "    (\"He drives fast cars every weekend\".split(), [\"PRON\", \"V\", \"ADJ\", \"NN\", \"DET\", \"NN\"]),\n",
    "    (\"Their little child sleeps very quietly\".split(), [\"DET\", \"ADJ\", \"NN\", \"V\", \"ADV\", \"ADV\"]),\n",
    "    (\"You write good stories sometimes\".split(), [\"PRON\", \"V\", \"ADJ\", \"NN\", \"ADV\"]),\n",
    "    (\"The happy children play outside\".split(), [\"DET\", \"ADJ\", \"NN\", \"V\", \"ADV\"]),\n",
    "    (\"We watch funny movies together\".split(), [\"PRON\", \"V\", \"ADJ\", \"NN\", \"ADV\"]),\n",
    "    (\"She always helps her friends\".split(), [\"PRON\", \"ADV\", \"V\", \"DET\", \"NN\"]),\n",
    "    (\"I usually drink hot coffee\".split(), [\"PRON\", \"ADV\", \"V\", \"ADJ\", \"NN\"]),\n",
    "    (\"They enjoy the warm sunshine today\".split(), [\"PRON\", \"V\", \"DET\", \"ADJ\", \"NN\", \"ADV\"]),\n",
    "    (\"My old friend always writes interesting funny stories\".split(), [\"DET\", \"ADJ\", \"NN\", \"ADV\", \"V\", \"ADJ\", \"ADJ\", \"NN\"]),\n",
    "    (\"They will probably finish the new project successfully next week\".split(), [\"PRON\", \"V\", \"ADV\", \"V\", \"DET\", \"ADJ\", \"NN\", \"ADV\", \"ADJ\", \"NN\"])\n",
    "]\n",
    "\n",
    "# Vocabulary: each distinct word gets the next free index, in order of first appearance.\n",
    "# setdefault evaluates len() before inserting, so a new word receives the current size.\n",
    "word_to_ix = {}\n",
    "for sent, tags in training_data:\n",
    "    for word in sent:\n",
    "        word_to_ix.setdefault(word, len(word_to_ix))\n",
    "\n",
    "print(f\"Vocabulary: {word_to_ix}\")\n",
    "print(f\"Entities: {ix_to_tag}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMTagger(nn.Module):\n",
    "\n",
    "    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size, n_qubits=0):\n",
    "        super(LSTMTagger, self).__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "\n",
    "        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)\n",
    "\n",
    "        # The LSTM takes word embeddings as inputs, and outputs hidden states\n",
    "        # with dimensionality hidden_dim.\n",
    "        if n_qubits > 0:\n",
    "            print(\"Tagger will use Quantum LSTM\")\n",
    "            self.lstm = QLSTM(embedding_dim, hidden_dim, n_qubits=n_qubits)\n",
    "        else:\n",
    "            print(\"Tagger will use Classical LSTM\")\n",
    "            self.lstm = nn.LSTM(embedding_dim, hidden_dim)\n",
    "\n",
    "        # The linear layer that maps from hidden state space to tag space\n",
    "        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)\n",
    "\n",
    "    def forward(self, sentence, fs=[[None]*4 for _ in range(4)]):\n",
    "        embeds = self.word_embeddings(sentence)\n",
    "        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1), fs=fs)\n",
    "        tag_logits = self.hidden2tag(lstm_out.view(len(sentence), -1))\n",
    "        tag_scores = F.log_softmax(tag_logits, dim=1)\n",
    "        return tag_scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Baseline: training without injected faults"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, n_epochs, training_data, optimizer, scheduler, word_to_ix, tag_to_ix):\n",
    "    '''Train a tagger for `n_epochs`, taking one optimizer step per sentence.\n",
    "\n",
    "    Args:\n",
    "        model: callable mapping a LongTensor of word indices to per-token log-probabilities.\n",
    "        n_epochs: number of passes over `training_data`.\n",
    "        training_data: list of (words, tags) pairs.\n",
    "        optimizer: stepped once per sentence.\n",
    "        scheduler: stepped once per epoch.\n",
    "        word_to_ix, tag_to_ix: vocabularies used by prepare_sequence.\n",
    "\n",
    "    Returns:\n",
    "        dict with per-epoch 'loss' (mean NLL) and 'acc' (token accuracy) as Python floats.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if training_data is empty while n_epochs > 0.\n",
    "    '''\n",
    "    if n_epochs > 0 and not training_data:\n",
    "        raise ValueError('training_data must contain at least one (sentence, tags) pair')\n",
    "    loss_function = nn.NLLLoss()\n",
    "    history = {'loss': [], 'acc': []}\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(n_epochs):\n",
    "        losses = []\n",
    "        n_correct = 0\n",
    "        n_tokens = 0\n",
    "        for sentence, tags in training_data:\n",
    "            # PyTorch accumulates gradients, so clear them before each sentence.\n",
    "            model.zero_grad()\n",
    "\n",
    "            sentence_in = prepare_sequence(sentence, word_to_ix)\n",
    "            labels = prepare_sequence(tags, tag_to_ix)\n",
    "\n",
    "            tag_scores = model(sentence_in)\n",
    "            loss = loss_function(tag_scores, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            losses.append(loss.item())\n",
    "\n",
    "            # Log-probabilities share their argmax with probabilities; no softmax needed.\n",
    "            n_correct += (tag_scores.argmax(dim=-1) == labels).sum().item()\n",
    "            n_tokens += labels.numel()\n",
    "\n",
    "        avg_loss = float(np.mean(losses))\n",
    "        # Plain floats (not device tensors) so history can be plotted or serialized directly.\n",
    "        accuracy = n_correct / n_tokens if n_tokens else float('nan')\n",
    "        history['loss'].append(avg_loss)\n",
    "        history['acc'].append(accuracy)\n",
    "\n",
    "        print(f\"Epoch {epoch+1} / {n_epochs}: Loss = {avg_loss:.3f} Acc = {accuracy:.2f}\")\n",
    "\n",
    "        scheduler.step()\n",
    "    return history\n",
    "\n",
    "def validate(model, test_data, word_to_ix, tag_to_ix):\n",
    "    '''Return the token-level accuracy of `model` on `test_data` as a 0-d float tensor.\n",
    "\n",
    "    Runs under torch.no_grad() with the model in eval mode; the model's previous\n",
    "    train/eval mode is restored afterwards, even if a forward pass raises.\n",
    "    '''\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    preds = []\n",
    "    targets = []\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for sentence, tags in test_data:\n",
    "                sentence_in = prepare_sequence(sentence, word_to_ix)\n",
    "                labels = prepare_sequence(tags, tag_to_ix)\n",
    "                tag_scores = model(sentence_in)\n",
    "                # Log-softmax scores share their argmax with probabilities; no softmax needed.\n",
    "                preds.append(tag_scores.argmax(dim=-1))\n",
    "                targets.append(labels)\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "    preds = torch.cat(preds)\n",
    "    targets = torch.cat(targets)\n",
    "    return (preds == targets).sum().float() / float(targets.size(0))\n",
    "\n",
    "def print_result(model, training_data, word_to_ix, ix_to_tag):\n",
    "    '''Print the first sentence of `training_data`, its gold tags and the model's predicted tags.'''\n",
    "    input_sentence = training_data[0][0]\n",
    "    labels = training_data[0][1]\n",
    "    with torch.no_grad():\n",
    "        inputs = prepare_sequence(input_sentence, word_to_ix)\n",
    "        tag_scores = model(inputs)\n",
    "    # .tolist() works for CPU and CUDA tensors alike; .numpy() raises on CUDA tensors.\n",
    "    tag_ids = torch.argmax(tag_scores, dim=1).tolist()\n",
    "    tag_labels = [ix_to_tag[k] for k in tag_ids]\n",
    "    print(f\"Sentence:  {input_sentence}\")\n",
    "    print(f\"Labels:    {labels}\")\n",
    "    print(f\"Predicted: {tag_labels}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tagger will use Quantum LSTM\n",
      "Epoch 1 / 50: Loss = 1.781 Acc = 0.18\n",
      "Epoch 2 / 50: Loss = 1.715 Acc = 0.21\n",
      "Epoch 3 / 50: Loss = 1.433 Acc = 0.42\n",
      "Epoch 4 / 50: Loss = 1.301 Acc = 0.46\n",
      "Epoch 5 / 50: Loss = 1.195 Acc = 0.52\n",
      "Epoch 6 / 50: Loss = 0.992 Acc = 0.60\n",
      "Epoch 7 / 50: Loss = 0.819 Acc = 0.69\n",
      "Epoch 8 / 50: Loss = 0.974 Acc = 0.62\n",
      "Epoch 9 / 50: Loss = 0.837 Acc = 0.66\n",
      "Epoch 10 / 50: Loss = 0.642 Acc = 0.72\n",
      "Epoch 11 / 50: Loss = 0.515 Acc = 0.81\n",
      "Epoch 12 / 50: Loss = 0.544 Acc = 0.76\n",
      "Epoch 13 / 50: Loss = 0.579 Acc = 0.77\n",
      "Epoch 14 / 50: Loss = 0.448 Acc = 0.85\n",
      "Epoch 15 / 50: Loss = 0.426 Acc = 0.85\n",
      "Epoch 16 / 50: Loss = 0.407 Acc = 0.82\n",
      "Epoch 17 / 50: Loss = 0.381 Acc = 0.86\n",
      "Epoch 18 / 50: Loss = 0.351 Acc = 0.86\n",
      "Epoch 19 / 50: Loss = 0.333 Acc = 0.85\n",
      "Epoch 20 / 50: Loss = 0.320 Acc = 0.86\n",
      "Epoch 21 / 50: Loss = 0.311 Acc = 0.86\n",
      "Epoch 22 / 50: Loss = 0.301 Acc = 0.86\n",
      "Epoch 23 / 50: Loss = 0.291 Acc = 0.88\n",
      "Epoch 24 / 50: Loss = 0.286 Acc = 0.91\n",
      "Epoch 25 / 50: Loss = 0.277 Acc = 0.92\n",
      "Epoch 26 / 50: Loss = 0.269 Acc = 0.91\n",
      "Epoch 27 / 50: Loss = 0.263 Acc = 0.92\n",
      "Epoch 28 / 50: Loss = 0.258 Acc = 0.91\n",
      "Epoch 29 / 50: Loss = 0.258 Acc = 0.92\n",
      "Epoch 30 / 50: Loss = 0.250 Acc = 0.92\n",
      "Epoch 31 / 50: Loss = 0.246 Acc = 0.91\n",
      "Epoch 32 / 50: Loss = 0.241 Acc = 0.92\n",
      "Epoch 33 / 50: Loss = 0.237 Acc = 0.92\n",
      "Epoch 34 / 50: Loss = 0.231 Acc = 0.92\n",
      "Epoch 35 / 50: Loss = 0.255 Acc = 0.90\n",
      "Epoch 36 / 50: Loss = 0.238 Acc = 0.92\n",
      "Epoch 37 / 50: Loss = 0.228 Acc = 0.92\n",
      "Epoch 38 / 50: Loss = 0.225 Acc = 0.92\n",
      "Epoch 39 / 50: Loss = 0.223 Acc = 0.92\n",
      "Epoch 40 / 50: Loss = 0.220 Acc = 0.92\n",
      "Epoch 41 / 50: Loss = 0.218 Acc = 0.92\n",
      "Epoch 42 / 50: Loss = 0.217 Acc = 0.93\n",
      "Epoch 43 / 50: Loss = 0.215 Acc = 0.93\n",
      "Epoch 44 / 50: Loss = 0.214 Acc = 0.93\n",
      "Epoch 45 / 50: Loss = 0.213 Acc = 0.93\n",
      "Epoch 46 / 50: Loss = 0.213 Acc = 0.93\n",
      "Epoch 47 / 50: Loss = 0.212 Acc = 0.93\n",
      "Epoch 48 / 50: Loss = 0.212 Acc = 0.93\n",
      "Epoch 49 / 50: Loss = 0.211 Acc = 0.93\n",
      "Epoch 50 / 50: Loss = 0.211 Acc = 0.93\n"
     ]
    }
   ],
   "source": [
    "SEED = 0\n",
    "torch.manual_seed(SEED)  # torch RNG, e.g. weight initialisation\n",
    "\n",
    "embedding_dim = 16\n",
    "hidden_dim = 32\n",
    "n_epochs = 50\n",
    "\n",
    "n_qubits = 4\n",
    "\n",
    "# The model must live on `device`: prepare_sequence places the input indices there.\n",
    "model_quantum = LSTMTagger(embedding_dim,\n",
    "                        hidden_dim,\n",
    "                        vocab_size=len(word_to_ix),\n",
    "                        tagset_size=len(tag_to_ix),\n",
    "                        n_qubits=n_qubits).to(device)\n",
    "\n",
    "optimizer = optim.Adam(model_quantum.parameters(), lr=0.03, weight_decay=1e-4)\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)\n",
    "\n",
    "history_quantum = train(model_quantum, n_epochs, training_data, optimizer, scheduler, word_to_ix, tag_to_ix)\n",
    "\n",
    "torch.save(model_quantum.state_dict(), 'lstm_baseline.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NAT: training with randomly injected Pauli faults (noise-aware training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, n_epochs, training_data, optimizer, scheduler, word_to_ix, tag_to_ix, errorRate=0):\n",
    "    '''Fault-injection training loop; redefines the baseline `train` with an extra `errorRate`.\n",
    "\n",
    "    Before every forward pass a fresh 4x4 grid of fault operators is sampled and passed to the\n",
    "    model as its second argument. QLSTM.forward hands row g to gate g (forget, input, update,\n",
    "    output) and the gate circuit applies entry k after the U3 on wire k. Each entry is None with\n",
    "    probability 1 - 3 * errorRate, and tq.PauliX / tq.PauliY / tq.PauliZ with probability\n",
    "    errorRate each, so errorRate=0 injects no faults.\n",
    "\n",
    "    Args:\n",
    "        model: tagger called as model(sentence_indices, faults), returning log-probabilities.\n",
    "        n_epochs: number of passes over `training_data`.\n",
    "        training_data: list of (words, tags) pairs.\n",
    "        optimizer: stepped once per sentence.\n",
    "        scheduler: stepped once per epoch.\n",
    "        word_to_ix, tag_to_ix: vocabularies used by prepare_sequence.\n",
    "        errorRate: probability of each Pauli fault per (gate, wire); must lie in [0, 1/3].\n",
    "\n",
    "    Returns:\n",
    "        dict with per-epoch 'loss' (mean NLL) and 'acc' (token accuracy) as Python floats.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if errorRate is outside [0, 1/3], or training_data is empty while n_epochs > 0.\n",
    "    '''\n",
    "    if not (errorRate >= 0 and 1 - 3 * errorRate >= 0):\n",
    "        raise ValueError(f'errorRate must lie in [0, 1/3], got {errorRate}')\n",
    "    if n_epochs > 0 and not training_data:\n",
    "        raise ValueError('training_data must contain at least one (sentence, tags) pair')\n",
    "\n",
    "    def randomSample(errorRate):\n",
    "        '''Return a 4x4 nested list of fault operators (None means no fault).'''\n",
    "        choices = [None, tq.PauliX(), tq.PauliY(), tq.PauliZ()]\n",
    "        probs = [1 - errorRate * 3, errorRate, errorRate, errorRate]\n",
    "        # Sample indices, not objects, so numpy never has to coerce the modules into an object array.\n",
    "        picks = np.random.choice(len(choices), size=(4, 4), p=probs)\n",
    "        return [[choices[k] for k in row] for row in picks]\n",
    "\n",
    "    loss_function = nn.NLLLoss()\n",
    "    history = {'loss': [], 'acc': []}\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(n_epochs):\n",
    "        losses = []\n",
    "        n_correct = 0\n",
    "        n_tokens = 0\n",
    "        for sentence, tags in training_data:\n",
    "            # PyTorch accumulates gradients, so clear them before each sentence.\n",
    "            model.zero_grad()\n",
    "\n",
    "            sentence_in = prepare_sequence(sentence, word_to_ix)\n",
    "            labels = prepare_sequence(tags, tag_to_ix)\n",
    "\n",
    "            # One fault grid per sentence, shared by all of its time steps.\n",
    "            fault = randomSample(errorRate)\n",
    "            tag_scores = model(sentence_in, fault)\n",
    "\n",
    "            loss = loss_function(tag_scores, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            losses.append(loss.item())\n",
    "\n",
    "            # Log-probabilities share their argmax with probabilities; no softmax needed.\n",
    "            n_correct += (tag_scores.argmax(dim=-1) == labels).sum().item()\n",
    "            n_tokens += labels.numel()\n",
    "\n",
    "        avg_loss = float(np.mean(losses))\n",
    "        # Plain floats (not device tensors) so history can be plotted or serialized directly.\n",
    "        accuracy = n_correct / n_tokens if n_tokens else float('nan')\n",
    "        history['loss'].append(avg_loss)\n",
    "        history['acc'].append(accuracy)\n",
    "\n",
    "        print(f\"Epoch {epoch+1} / {n_epochs}: Loss = {avg_loss:.3f} Acc = {accuracy:.2f}\")\n",
    "\n",
    "        scheduler.step()\n",
    "    return history\n",
    "\n",
    "# NOTE(review): this redefinition shadows the identical `validate` from the baseline section.\n",
    "def validate(model, test_data, word_to_ix, tag_to_ix):\n",
    "    '''Return the token-level accuracy of `model` on `test_data` as a 0-d float tensor.\n",
    "\n",
    "    Runs fault-free under torch.no_grad() with the model in eval mode; the model's previous\n",
    "    train/eval mode is restored afterwards, even if a forward pass raises.\n",
    "    '''\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    preds = []\n",
    "    targets = []\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for sentence, tags in test_data:\n",
    "                sentence_in = prepare_sequence(sentence, word_to_ix)\n",
    "                labels = prepare_sequence(tags, tag_to_ix)\n",
    "                tag_scores = model(sentence_in)\n",
    "                # Log-softmax scores share their argmax with probabilities; no softmax needed.\n",
    "                preds.append(tag_scores.argmax(dim=-1))\n",
    "                targets.append(labels)\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "    preds = torch.cat(preds)\n",
    "    targets = torch.cat(targets)\n",
    "    return (preds == targets).sum().float() / float(targets.size(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tagger will use Quantum LSTM\n",
      "Epoch 1 / 50: Loss = 1.771 Acc = 0.19\n",
      "Epoch 2 / 50: Loss = 1.593 Acc = 0.32\n",
      "Epoch 3 / 50: Loss = 1.430 Acc = 0.41\n",
      "Epoch 4 / 50: Loss = 1.310 Acc = 0.45\n",
      "Epoch 5 / 50: Loss = 1.240 Acc = 0.51\n",
      "Epoch 6 / 50: Loss = 0.993 Acc = 0.64\n",
      "Epoch 7 / 50: Loss = 0.882 Acc = 0.61\n",
      "Epoch 8 / 50: Loss = 0.817 Acc = 0.58\n",
      "Epoch 9 / 50: Loss = 0.736 Acc = 0.64\n",
      "Epoch 10 / 50: Loss = 0.846 Acc = 0.64\n",
      "Epoch 11 / 50: Loss = 0.846 Acc = 0.64\n",
      "Epoch 12 / 50: Loss = 0.867 Acc = 0.65\n",
      "Epoch 13 / 50: Loss = 0.762 Acc = 0.64\n",
      "Epoch 14 / 50: Loss = 0.755 Acc = 0.64\n",
      "Epoch 15 / 50: Loss = 0.665 Acc = 0.71\n",
      "Epoch 16 / 50: Loss = 0.726 Acc = 0.64\n",
      "Epoch 17 / 50: Loss = 0.718 Acc = 0.69\n",
      "Epoch 18 / 50: Loss = 0.620 Acc = 0.73\n",
      "Epoch 19 / 50: Loss = 0.660 Acc = 0.73\n",
      "Epoch 20 / 50: Loss = 0.610 Acc = 0.73\n",
      "Epoch 21 / 50: Loss = 0.578 Acc = 0.75\n",
      "Epoch 22 / 50: Loss = 0.546 Acc = 0.72\n",
      "Epoch 23 / 50: Loss = 0.545 Acc = 0.76\n",
      "Epoch 24 / 50: Loss = 0.526 Acc = 0.78\n",
      "Epoch 25 / 50: Loss = 0.513 Acc = 0.79\n",
      "Epoch 26 / 50: Loss = 0.501 Acc = 0.79\n",
      "Epoch 27 / 50: Loss = 0.520 Acc = 0.77\n",
      "Epoch 28 / 50: Loss = 0.456 Acc = 0.81\n",
      "Epoch 29 / 50: Loss = 0.442 Acc = 0.83\n",
      "Epoch 30 / 50: Loss = 0.423 Acc = 0.82\n",
      "Epoch 31 / 50: Loss = 0.415 Acc = 0.82\n",
      "Epoch 32 / 50: Loss = 0.411 Acc = 0.84\n",
      "Epoch 33 / 50: Loss = 0.402 Acc = 0.86\n",
      "Epoch 34 / 50: Loss = 0.381 Acc = 0.86\n",
      "Epoch 35 / 50: Loss = 0.382 Acc = 0.86\n",
      "Epoch 36 / 50: Loss = 0.371 Acc = 0.91\n",
      "Epoch 37 / 50: Loss = 0.361 Acc = 0.91\n",
      "Epoch 38 / 50: Loss = 0.349 Acc = 0.88\n",
      "Epoch 39 / 50: Loss = 0.349 Acc = 0.90\n",
      "Epoch 40 / 50: Loss = 0.344 Acc = 0.89\n",
      "Epoch 41 / 50: Loss = 0.340 Acc = 0.90\n",
      "Epoch 42 / 50: Loss = 0.337 Acc = 0.90\n",
      "Epoch 43 / 50: Loss = 0.364 Acc = 0.87\n",
      "Epoch 44 / 50: Loss = 0.333 Acc = 0.90\n",
      "Epoch 45 / 50: Loss = 0.345 Acc = 0.87\n",
      "Epoch 46 / 50: Loss = 0.330 Acc = 0.91\n",
      "Epoch 47 / 50: Loss = 0.329 Acc = 0.91\n",
      "Epoch 48 / 50: Loss = 0.329 Acc = 0.91\n",
      "Epoch 49 / 50: Loss = 0.329 Acc = 0.91\n",
      "Epoch 50 / 50: Loss = 0.328 Acc = 0.91\n"
     ]
    }
   ],
   "source": [
    "# Baseline run: quantum LSTMTagger trained with the earlier-defined `train(..., errorRate=er)`.\n",
    "er = 0.001\n",
    "\n",
    "embedding_dim = 16\n",
    "hidden_dim = 32\n",
    "n_epochs = 50\n",
    "\n",
    "n_qubits = 4\n",
    "\n",
    "model_quantum = LSTMTagger(embedding_dim, \n",
    "                        hidden_dim, \n",
    "                        vocab_size=len(word_to_ix), \n",
    "                        tagset_size=len(tag_to_ix), \n",
    "                        n_qubits=n_qubits).to(device)\n",
    "\n",
    "optimizer = optim.Adam(model_quantum.parameters(), lr=0.03, weight_decay=1e-4)\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)\n",
    "\n",
    "history_quantum = train(model_quantum, n_epochs, training_data, optimizer, scheduler, word_to_ix, tag_to_ix, errorRate=er)\n",
    "\n",
    "# Save the checkpoint, as the er=0.005 / er=0.01 cells do (lstm_NATM.pt / lstm_NATH.pt).\n",
    "torch.save(model_quantum.state_dict(), 'lstm_NATL.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tagger will use Quantum LSTM\n",
      "Epoch 1 / 50: Loss = 1.780 Acc = 0.17\n",
      "Epoch 2 / 50: Loss = 1.672 Acc = 0.22\n",
      "Epoch 3 / 50: Loss = 1.487 Acc = 0.36\n",
      "Epoch 4 / 50: Loss = 1.327 Acc = 0.52\n",
      "Epoch 5 / 50: Loss = 1.411 Acc = 0.45\n",
      "Epoch 6 / 50: Loss = 1.405 Acc = 0.47\n",
      "Epoch 7 / 50: Loss = 1.140 Acc = 0.58\n",
      "Epoch 8 / 50: Loss = 1.169 Acc = 0.49\n",
      "Epoch 9 / 50: Loss = 1.082 Acc = 0.58\n",
      "Epoch 10 / 50: Loss = 1.005 Acc = 0.58\n",
      "Epoch 11 / 50: Loss = 1.023 Acc = 0.62\n",
      "Epoch 12 / 50: Loss = 1.024 Acc = 0.59\n",
      "Epoch 13 / 50: Loss = 0.944 Acc = 0.68\n",
      "Epoch 14 / 50: Loss = 1.017 Acc = 0.61\n",
      "Epoch 15 / 50: Loss = 0.915 Acc = 0.62\n",
      "Epoch 16 / 50: Loss = 0.884 Acc = 0.58\n",
      "Epoch 17 / 50: Loss = 1.010 Acc = 0.64\n",
      "Epoch 18 / 50: Loss = 0.737 Acc = 0.70\n",
      "Epoch 19 / 50: Loss = 0.690 Acc = 0.75\n",
      "Epoch 20 / 50: Loss = 0.726 Acc = 0.69\n",
      "Epoch 21 / 50: Loss = 0.660 Acc = 0.72\n",
      "Epoch 22 / 50: Loss = 0.675 Acc = 0.76\n",
      "Epoch 23 / 50: Loss = 0.573 Acc = 0.77\n",
      "Epoch 24 / 50: Loss = 0.675 Acc = 0.75\n",
      "Epoch 25 / 50: Loss = 0.850 Acc = 0.76\n",
      "Epoch 26 / 50: Loss = 0.786 Acc = 0.71\n",
      "Epoch 27 / 50: Loss = 0.661 Acc = 0.80\n",
      "Epoch 28 / 50: Loss = 0.630 Acc = 0.80\n",
      "Epoch 29 / 50: Loss = 0.619 Acc = 0.77\n",
      "Epoch 30 / 50: Loss = 0.547 Acc = 0.84\n",
      "Epoch 31 / 50: Loss = 0.537 Acc = 0.82\n",
      "Epoch 32 / 50: Loss = 0.562 Acc = 0.80\n",
      "Epoch 33 / 50: Loss = 0.451 Acc = 0.82\n",
      "Epoch 34 / 50: Loss = 0.458 Acc = 0.87\n",
      "Epoch 35 / 50: Loss = 0.331 Acc = 0.92\n",
      "Epoch 36 / 50: Loss = 0.451 Acc = 0.87\n",
      "Epoch 37 / 50: Loss = 0.443 Acc = 0.85\n",
      "Epoch 38 / 50: Loss = 0.373 Acc = 0.87\n",
      "Epoch 39 / 50: Loss = 0.625 Acc = 0.83\n",
      "Epoch 40 / 50: Loss = 0.375 Acc = 0.89\n",
      "Epoch 41 / 50: Loss = 0.381 Acc = 0.87\n",
      "Epoch 42 / 50: Loss = 0.384 Acc = 0.86\n",
      "Epoch 43 / 50: Loss = 0.305 Acc = 0.92\n",
      "Epoch 44 / 50: Loss = 0.276 Acc = 0.93\n",
      "Epoch 45 / 50: Loss = 0.265 Acc = 0.93\n",
      "Epoch 46 / 50: Loss = 0.318 Acc = 0.91\n",
      "Epoch 47 / 50: Loss = 0.592 Acc = 0.86\n",
      "Epoch 48 / 50: Loss = 0.366 Acc = 0.89\n",
      "Epoch 49 / 50: Loss = 0.282 Acc = 0.94\n",
      "Epoch 50 / 50: Loss = 0.308 Acc = 0.92\n"
     ]
    }
   ],
   "source": [
    "# Baseline run at error rate `er`, using the `train(..., errorRate=...)` defined earlier in the notebook.\n",
    "er = 0.005\n",
    "\n",
    "embedding_dim, hidden_dim = 16, 32\n",
    "n_qubits = 4\n",
    "n_epochs = 50\n",
    "\n",
    "model_quantum = LSTMTagger(\n",
    "    embedding_dim,\n",
    "    hidden_dim,\n",
    "    vocab_size=len(word_to_ix),\n",
    "    tagset_size=len(tag_to_ix),\n",
    "    n_qubits=n_qubits,\n",
    ").to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model_quantum.parameters(), lr=0.03, weight_decay=1e-4)\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)\n",
    "\n",
    "history_quantum = train(\n",
    "    model_quantum, n_epochs, training_data, optimizer, scheduler,\n",
    "    word_to_ix, tag_to_ix, errorRate=er,\n",
    ")\n",
    "\n",
    "torch.save(model_quantum.state_dict(), 'lstm_NATM.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tagger will use Quantum LSTM\n",
      "Epoch 1 / 50: Loss = 1.761 Acc = 0.21\n",
      "Epoch 2 / 50: Loss = 1.657 Acc = 0.36\n",
      "Epoch 3 / 50: Loss = 1.414 Acc = 0.43\n",
      "Epoch 4 / 50: Loss = 1.259 Acc = 0.53\n",
      "Epoch 5 / 50: Loss = 1.266 Acc = 0.57\n",
      "Epoch 6 / 50: Loss = 1.384 Acc = 0.51\n",
      "Epoch 7 / 50: Loss = 1.057 Acc = 0.60\n",
      "Epoch 8 / 50: Loss = 1.150 Acc = 0.56\n",
      "Epoch 9 / 50: Loss = 0.935 Acc = 0.68\n",
      "Epoch 10 / 50: Loss = 0.950 Acc = 0.64\n",
      "Epoch 11 / 50: Loss = 0.905 Acc = 0.64\n",
      "Epoch 12 / 50: Loss = 0.828 Acc = 0.67\n",
      "Epoch 13 / 50: Loss = 0.797 Acc = 0.70\n",
      "Epoch 14 / 50: Loss = 0.805 Acc = 0.71\n",
      "Epoch 15 / 50: Loss = 1.015 Acc = 0.62\n",
      "Epoch 16 / 50: Loss = 0.724 Acc = 0.70\n",
      "Epoch 17 / 50: Loss = 0.884 Acc = 0.64\n",
      "Epoch 18 / 50: Loss = 0.881 Acc = 0.60\n",
      "Epoch 19 / 50: Loss = 0.653 Acc = 0.75\n",
      "Epoch 20 / 50: Loss = 0.663 Acc = 0.77\n",
      "Epoch 21 / 50: Loss = 0.705 Acc = 0.76\n",
      "Epoch 22 / 50: Loss = 0.721 Acc = 0.71\n",
      "Epoch 23 / 50: Loss = 0.756 Acc = 0.74\n",
      "Epoch 24 / 50: Loss = 0.708 Acc = 0.74\n",
      "Epoch 25 / 50: Loss = 0.645 Acc = 0.79\n",
      "Epoch 26 / 50: Loss = 0.665 Acc = 0.77\n",
      "Epoch 27 / 50: Loss = 0.565 Acc = 0.75\n",
      "Epoch 28 / 50: Loss = 0.563 Acc = 0.78\n",
      "Epoch 29 / 50: Loss = 0.686 Acc = 0.78\n",
      "Epoch 30 / 50: Loss = 0.672 Acc = 0.75\n",
      "Epoch 31 / 50: Loss = 0.561 Acc = 0.83\n",
      "Epoch 32 / 50: Loss = 0.707 Acc = 0.75\n",
      "Epoch 33 / 50: Loss = 0.668 Acc = 0.78\n",
      "Epoch 34 / 50: Loss = 0.564 Acc = 0.81\n",
      "Epoch 35 / 50: Loss = 0.529 Acc = 0.81\n",
      "Epoch 36 / 50: Loss = 0.520 Acc = 0.81\n",
      "Epoch 37 / 50: Loss = 0.517 Acc = 0.80\n",
      "Epoch 38 / 50: Loss = 0.659 Acc = 0.81\n",
      "Epoch 39 / 50: Loss = 0.489 Acc = 0.81\n",
      "Epoch 40 / 50: Loss = 0.644 Acc = 0.73\n",
      "Epoch 41 / 50: Loss = 0.484 Acc = 0.83\n",
      "Epoch 42 / 50: Loss = 0.582 Acc = 0.81\n",
      "Epoch 43 / 50: Loss = 0.517 Acc = 0.81\n",
      "Epoch 44 / 50: Loss = 0.499 Acc = 0.83\n",
      "Epoch 45 / 50: Loss = 0.696 Acc = 0.76\n",
      "Epoch 46 / 50: Loss = 0.482 Acc = 0.82\n",
      "Epoch 47 / 50: Loss = 0.519 Acc = 0.83\n",
      "Epoch 48 / 50: Loss = 0.485 Acc = 0.86\n",
      "Epoch 49 / 50: Loss = 0.525 Acc = 0.79\n",
      "Epoch 50 / 50: Loss = 0.464 Acc = 0.84\n"
     ]
    }
   ],
   "source": [
    "# Baseline run at error rate `er`, using the `train(..., errorRate=...)` defined earlier in the notebook.\n",
    "er = 0.01\n",
    "\n",
    "embedding_dim, hidden_dim = 16, 32\n",
    "n_qubits = 4\n",
    "n_epochs = 50\n",
    "\n",
    "model_quantum = LSTMTagger(\n",
    "    embedding_dim,\n",
    "    hidden_dim,\n",
    "    vocab_size=len(word_to_ix),\n",
    "    tagset_size=len(tag_to_ix),\n",
    "    n_qubits=n_qubits,\n",
    ").to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model_quantum.parameters(), lr=0.03, weight_decay=1e-4)\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)\n",
    "\n",
    "history_quantum = train(\n",
    "    model_quantum, n_epochs, training_data, optimizer, scheduler,\n",
    "    word_to_ix, tag_to_ix, errorRate=er,\n",
    ")\n",
    "\n",
    "torch.save(model_quantum.state_dict(), 'lstm_NATH.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ours: training against EA-searched Pauli error configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ea_task import OneGen_task\n",
    "import geatpy as ea\n",
    "from new_gates import PauliX, PauliY, PauliZ\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Gene value -> error gate: 0 means no gate (None), 1/2/3 are PauliX/PauliY/PauliZ objects moved to `device`.\n",
    "ERROR_DICT = {0: None, 1: PauliX().to(device), 2: PauliY().to(device), 3: PauliZ().to(device)}\n",
    "\n",
    "def gen_fd(var, shape=(4, 4)):\n",
    "    \"\"\"Decode an EA genome into a grid of error gates.\n",
    "\n",
    "    Each entry of `var` must be an ERROR_DICT key (0-3); integral floats such\n",
    "    as 2.0 are accepted. Returns an object ndarray of gate objects / None\n",
    "    reshaped to `shape` (default 4x4, i.e. the args.N = 16 genes).\n",
    "    \"\"\"\n",
    "    genes = list(var)\n",
    "    # Fill an explicit object array so numpy never tries to convert or\n",
    "    # introspect the gate modules themselves.\n",
    "    fd = np.empty(len(genes), dtype=object)\n",
    "    for idx, gene in enumerate(genes):\n",
    "        fd[idx] = ERROR_DICT[int(gene)]\n",
    "    return fd.reshape(shape)\n",
    "\n",
    "def aim(var, model, inputs, labels):\n",
    "    \"\"\"EA objective (registered as args.aim): NLL loss of `model` on\n",
    "    (inputs, labels) when run with the error grid decoded from `var`.\n",
    "    Evaluated under torch.no_grad(); returns a Python float.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        scores = model(inputs, gen_fd(var))\n",
    "        return F.nll_loss(scores, labels).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ARG():\n",
    "    \"\"\"Attribute container for the EA hyper-parameters and evolving state.\"\"\"\n",
    "\n",
    "##################################\n",
    "##### set up EA param.############\n",
    "##################################\n",
    "args = ARG()\n",
    "args.N = 16           # genes per individual; gen_fd reshapes them to 4x4\n",
    "args.M = 1            # passed to OneGen_task as M (presumably the number of objectives)\n",
    "args.NIND = 30        # population size\n",
    "# 10% of the population for elite selection. Kept as an int because it is a\n",
    "# count (the bare expression 0.1 * NIND is a float).\n",
    "args.K = int(round(0.1 * args.NIND))\n",
    "args.selS = 'etour'   # selection operator\n",
    "args.recS = 'xovdp'   # recombination operator\n",
    "args.mutS = 'mutbin'  # mutation operator\n",
    "args.Encoding = 'BG'  # chromosome encoding\n",
    "args.pc = 0.8         # recombination probability\n",
    "args.EA = False       # train() only runs its EA branch when this is True\n",
    "args.lambda_ = 0.5    # weight of loss2 (loss under the EA-selected error grid) in train()\n",
    "\n",
    "# Each gene is an integer in [0, 3] (an ERROR_DICT key), both bounds included.\n",
    "ranges = np.array([[0, 3]] * args.N).T\n",
    "borders = np.ones_like(ranges)\n",
    "varTypes = np.array([1] * args.N)\n",
    "codes = [0] * args.N\n",
    "precisions = [0] * args.N\n",
    "scales = [0] * args.N\n",
    "\n",
    "args.FieldD = ea.crtfld(args.Encoding, varTypes, ranges, borders, precisions, codes, scales)\n",
    "args.aim = aim\n",
    "args.Chrom = ea.crtpc(args.Encoding, args.NIND, args.FieldD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, training_data, optimizer, word_to_ix, tag_to_ix, args):\n",
    "    \"\"\"Run one epoch of EA-augmented training over `training_data`.\n",
    "\n",
    "    Every sentence contributes the clean NLL loss. When `args.EA` is set, one\n",
    "    EA generation is run via OneGen_task (updating args.Chrom and\n",
    "    args.best_pop), the best individual is decoded with gen_fd, and the\n",
    "    model's loss under that error grid is added with weight args.lambda_.\n",
    "\n",
    "    Returns (sum of clean losses, sum of EA-grid losses) over the epoch;\n",
    "    the second sum stays 0 while args.EA is False.\n",
    "\n",
    "    NOTE(review): this shadows the earlier `train(..., errorRate=...)` used by\n",
    "    the baseline cells; re-running those cells after this one raises TypeError.\n",
    "    \"\"\"\n",
    "    clean_sum, ea_sum = 0, 0\n",
    "    for sentence, tags in training_data:\n",
    "        inputs = prepare_sequence(sentence, word_to_ix)\n",
    "        targets = prepare_sequence(tags, tag_to_ix)\n",
    "        clean_loss = F.nll_loss(model(inputs), targets)\n",
    "        clean_sum += clean_loss.item()\n",
    "\n",
    "        total_loss = clean_loss\n",
    "        if args.EA:\n",
    "            args.Chrom, args.best_pop = OneGen_task(\n",
    "                args.N, args.M, args.K, args.NIND, args.selS, args.recS, args.mutS, args.FieldD,\n",
    "                model, inputs, targets, args.aim, args.Chrom, args.pc, args.Encoding,\n",
    "            )\n",
    "            ea_loss = F.nll_loss(model(inputs, gen_fd(args.best_pop)), targets)\n",
    "            ea_sum += ea_loss.item()\n",
    "            total_loss = clean_loss + args.lambda_ * ea_loss\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        total_loss.backward()\n",
    "        optimizer.step()\n",
    "    return clean_sum, ea_sum\n",
    "\n",
    "def validate(model, test_data, word_to_ix, tag_to_ix):\n",
    "    \"\"\"Token-level accuracy of `model` on `test_data`, as a 0-dim float tensor.\n",
    "\n",
    "    Predictions are the argmax over the last dim of the model's tag scores.\n",
    "    Runs under torch.no_grad(). Raises ValueError if `test_data` is empty.\n",
    "    \"\"\"\n",
    "    preds, targets = [], []\n",
    "    with torch.no_grad():\n",
    "        for sentence, tags in test_data:\n",
    "            sentence_in = prepare_sequence(sentence, word_to_ix)\n",
    "            labels = prepare_sequence(tags, tag_to_ix)\n",
    "            tag_scores = model(sentence_in)\n",
    "            # softmax is monotonic, so the argmax of the raw scores is the same\n",
    "            # prediction without the extra exp/normalise pass.\n",
    "            preds.append(tag_scores.argmax(dim=-1))\n",
    "            targets.append(labels)\n",
    "    if not targets:\n",
    "        raise ValueError(\"validate() needs at least one (sentence, tags) pair\")\n",
    "    preds = torch.cat(preds)\n",
    "    targets = torch.cat(targets)\n",
    "    return (preds == targets).float().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tagger will use Quantum LSTM\n",
      "epoch 0, loss1 35.55309009552002, loss2 0, accuracy 0.19491524994373322\n",
      "epoch 1, loss1 34.50109088420868, loss2 0, accuracy 0.24576270580291748\n",
      "epoch 2, loss1 33.596715807914734, loss2 0, accuracy 0.2881355881690979\n",
      "epoch 3, loss1 32.04462730884552, loss2 0, accuracy 0.33898305892944336\n",
      "epoch 4, loss1 29.48051166534424, loss2 0, accuracy 0.4237288236618042\n",
      "epoch 5, loss1 26.41031277179718, loss2 0, accuracy 0.4576271176338196\n",
      "epoch 6, loss1 23.885135889053345, loss2 0, accuracy 0.4576271176338196\n",
      "epoch 7, loss1 21.769143223762512, loss2 0, accuracy 0.49152541160583496\n",
      "epoch 8, loss1 20.239594638347626, loss2 0, accuracy 0.5847457647323608\n",
      "epoch 9, loss1 19.042158901691437, loss2 0, accuracy 0.5677965879440308\n",
      "epoch 10, loss1 19.650225341320038, loss2 87.29529786109924, accuracy 0.5\n",
      "epoch 11, loss1 22.064439237117767, loss2 54.69685173034668, accuracy 0.5847457647323608\n",
      "epoch 12, loss1 23.182980597019196, loss2 42.725456953048706, accuracy 0.6101694703102112\n",
      "epoch 13, loss1 22.14625197649002, loss2 37.30340898036957, accuracy 0.6440678238868713\n",
      "epoch 14, loss1 20.578870475292206, loss2 32.479291915893555, accuracy 0.7542372941970825\n",
      "epoch 15, loss1 18.328531444072723, loss2 28.578118562698364, accuracy 0.7627118825912476\n",
      "epoch 16, loss1 16.707754611968994, loss2 26.424447894096375, accuracy 0.805084764957428\n",
      "epoch 17, loss1 15.369911313056946, loss2 25.423142552375793, accuracy 0.7966101765632629\n",
      "epoch 18, loss1 14.34741073846817, loss2 24.672243475914, accuracy 0.8305084705352783\n",
      "epoch 19, loss1 13.740282356739044, loss2 23.6947922706604, accuracy 0.8813559412956238\n",
      "epoch 20, loss1 12.348374724388123, loss2 21.782179951667786, accuracy 0.8898305296897888\n",
      "epoch 21, loss1 12.008170545101166, loss2 20.97259432077408, accuracy 0.8559321761131287\n",
      "epoch 22, loss1 11.047289609909058, loss2 20.288437068462372, accuracy 0.9237288236618042\n",
      "epoch 23, loss1 10.628255873918533, loss2 19.586112082004547, accuracy 0.9067796468734741\n",
      "epoch 24, loss1 9.899688005447388, loss2 18.802454888820648, accuracy 0.8983050584793091\n",
      "epoch 25, loss1 9.093968272209167, loss2 17.751454889774323, accuracy 0.9237288236618042\n",
      "epoch 26, loss1 8.740900725126266, loss2 17.309669077396393, accuracy 0.9067796468734741\n",
      "epoch 27, loss1 8.274422854185104, loss2 16.711966931819916, accuracy 0.9067796468734741\n",
      "epoch 28, loss1 7.854577422142029, loss2 16.559388399124146, accuracy 0.9322034120559692\n",
      "epoch 29, loss1 7.652022987604141, loss2 15.904659867286682, accuracy 0.9067796468734741\n",
      "epoch 30, loss1 7.35828298330307, loss2 15.056484401226044, accuracy 0.9067796468734741\n",
      "epoch 31, loss1 7.159100532531738, loss2 15.690467476844788, accuracy 0.9237288236618042\n",
      "epoch 32, loss1 6.993060767650604, loss2 14.683244943618774, accuracy 0.9237288236618042\n",
      "epoch 33, loss1 6.771033316850662, loss2 14.72069787979126, accuracy 0.9152542352676392\n",
      "epoch 34, loss1 6.618819609284401, loss2 14.60159319639206, accuracy 0.9322034120559692\n",
      "epoch 35, loss1 6.445711329579353, loss2 14.343875348567963, accuracy 0.9237288236618042\n",
      "epoch 36, loss1 6.344765782356262, loss2 14.050084054470062, accuracy 0.9406779408454895\n",
      "epoch 37, loss1 6.250962510704994, loss2 13.944658041000366, accuracy 0.9322034120559692\n",
      "epoch 38, loss1 6.138379707932472, loss2 13.753267705440521, accuracy 0.9237288236618042\n",
      "epoch 39, loss1 6.087879061698914, loss2 13.68615898489952, accuracy 0.9406779408454895\n",
      "epoch 40, loss1 6.033051058650017, loss2 13.468632936477661, accuracy 0.9406779408454895\n",
      "epoch 41, loss1 5.975088983774185, loss2 13.377310395240784, accuracy 0.9406779408454895\n",
      "epoch 42, loss1 5.930608451366425, loss2 13.488123893737793, accuracy 0.9406779408454895\n",
      "epoch 43, loss1 5.906953260302544, loss2 13.30675756931305, accuracy 0.9322034120559692\n",
      "epoch 44, loss1 5.885826393961906, loss2 13.317530333995819, accuracy 0.9322034120559692\n",
      "epoch 45, loss1 5.863943666219711, loss2 13.187527477741241, accuracy 0.9406779408454895\n",
      "epoch 46, loss1 5.854416027665138, loss2 13.31262457370758, accuracy 0.9406779408454895\n",
      "epoch 47, loss1 5.846678480505943, loss2 13.108937680721283, accuracy 0.9406779408454895\n",
      "epoch 48, loss1 5.842467591166496, loss2 13.256053626537323, accuracy 0.9406779408454895\n",
      "epoch 49, loss1 5.840454116463661, loss2 13.282144963741302, accuracy 0.9406779408454895\n"
     ]
    }
   ],
   "source": [
    "embedding_dim = 16\n",
    "hidden_dim = 32\n",
    "n_epochs = 50\n",
    "n_qubits = 4\n",
    "EA_START_EPOCH = 10  # clean-only warm-up epochs before train() runs its EA branch\n",
    "\n",
    "model = LSTMTagger(embedding_dim, \n",
    "                    hidden_dim, \n",
    "                    vocab_size=len(word_to_ix), \n",
    "                    tagset_size=len(tag_to_ix), \n",
    "                    n_qubits=n_qubits).to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)\n",
    "\n",
    "# The EA state lives on the global `args`. Start from a fresh population so\n",
    "# re-running this cell does not continue from a previous run's args.Chrom.\n",
    "args.Chrom = ea.crtpc(args.Encoding, args.NIND, args.FieldD)\n",
    "\n",
    "loss1_trace = []\n",
    "loss2_trace = []\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    # Derive the flag from the epoch instead of flipping it once, so a re-run\n",
    "    # does not inherit args.EA = True and skip the warm-up.\n",
    "    args.EA = epoch >= EA_START_EPOCH\n",
    "    l1, l2 = train(model, training_data, optimizer, word_to_ix, tag_to_ix, args)\n",
    "    loss1_trace.append(l1)\n",
    "    loss2_trace.append(l2)\n",
    "    scheduler.step()\n",
    "    # NOTE: accuracy is measured on training_data (no held-out split is used here).\n",
    "    print(f\"epoch {epoch}, loss1 {l1}, loss2 {l2}, accuracy {validate(model, training_data, word_to_ix, tag_to_ix)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the weights of the EA-trained tagger.\n",
    "ours_state_dict = model.state_dict()\n",
    "torch.save(ours_state_dict, 'lstm_ours.pt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quantumtorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
