{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "64f1cd87-f0ef-4f44-a9e0-4a29ccd5a28c",
   "metadata": {},
   "source": [
    "# BERT-GNN-CRF model without fined tuning BERT for Argument Unit Recognition and Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fb647d3",
   "metadata": {},
   "source": [
    "# Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "periodic-phrase",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "import random\n",
    "import datetime\n",
    "from collections import defaultdict\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.nn import Linear, ReLU\n",
    "\n",
    "from torchcrf import CRF\n",
    "\n",
    "from torch_geometric.nn import GCNConv,GATConv, SAGEConv,Sequential,AGNNConv,global_mean_pool\n",
    "from torch_geometric.data import Dataset, download_url,Data\n",
    "from torch_geometric.utils import to_undirected, sort_edge_index, to_networkx,to_dense_batch\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "import sys\n",
    "sys.path.insert(0, '../../../utils/')\n",
    "from utils_GNN_CRF import (\n",
    "    compute_tab_logits_pred_sentence_level,\n",
    "    extract_sentence_level_label,\n",
    "    F1_Loss,\n",
    "    find_max_size_batch_without_nt,\n",
    "    remove_internal_nodes,\n",
    "    remove_internal_nodes_without_labels,\n",
    "    reshape_label_to_batch_padded,\n",
    "    reshape_mask_to_batch_padded,\n",
    "    reshape_data_to_batch_padded,\n",
    "    find_the_most_frequent_element\n",
    ")\n",
    "\n",
    "from torch_geometric.data import DataLoader\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "interstate-release",
   "metadata": {},
   "source": [
    "# Dataset creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "tribal-parts",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = '../../../data/aurc/bert/large_depth'\n",
    "\n",
    "class MyOwnDataset(Dataset):\n",
    "    def __init__(self, num_data, root, transform=None, pre_transform=None):\n",
    "        super(MyOwnDataset, self).__init__(root, transform, pre_transform)\n",
    "        self.num_data = num_data\n",
    "        \n",
    "    @property\n",
    "    def raw_file_names(self):\n",
    "        return []\n",
    "\n",
    "    @property\n",
    "    def processed_file_names(self):\n",
    "        return ['data_{}.pt'.format(idx) for idx in range(self.num_data)]\n",
    "\n",
    "    def len(self):\n",
    "        return len(self.processed_file_names)\n",
    "\n",
    "    def get(self, idx):\n",
    "        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))\n",
    "        return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "greatest-trading",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train size : 3960\n",
      "dev size : 790\n",
      "test size : 1959\n"
     ]
    }
   ],
   "source": [
    "train_dataset = MyOwnDataset(num_data = 3960,root = data_dir+\"/Train\")\n",
    "test_dataset = MyOwnDataset(num_data = 1959,root = data_dir+\"/Test\")\n",
    "dev_dataset = MyOwnDataset(num_data = 790,root = data_dir+\"/Dev\")\n",
    "\n",
    "print(\"train size : \" + str(len(train_dataset)))\n",
    "print(\"dev size : \" + str(len(dev_dataset)))\n",
    "print(\"test size : \" + str(len(test_dataset)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e42a3e62",
   "metadata": {},
   "source": [
    "# Model Creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "skilled-preview",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 190\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)\n",
    "dev_loader = DataLoader(dev_dataset, batch_size=batch_size,shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size,shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "51f3f7e9-cdc7-4978-a3cd-4bc7f87b5498",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "62"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataloader = train_loader\n",
    "max_size_actu = 0\n",
    "for idx, (ele) in enumerate(dataloader):\n",
    "    sample_nt_tensor = ele[\"indice_nt_tensor\"]\n",
    "    sample_batch = ele[\"batch\"]\n",
    "    sample_x = ele[\"x\"]\n",
    "    sample_label = ele[\"y\"]\n",
    "    max_size_actu = max(find_max_size_batch_without_nt(sample_nt_tensor,sample_batch),max_size_actu)\n",
    "max_size_actu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a1766772-6022-4e03-b8bd-cdc52cf76227",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net_large_CRF(torch.nn.Module):\n",
    "    def __init__(self,num_node_features,num_classes, num_labels = 3, batch_first=True):\n",
    "        super(Net_large_CRF, self).__init__()\n",
    "        \n",
    "        self.num_labels = num_labels\n",
    "        self.batch_first = batch_first\n",
    "        \n",
    "        self.conv0 = GATConv(num_node_features, 303,heads = 2)\n",
    "        self.conv1 = GATConv(303*2, 216,heads = 2)\n",
    "        self.lin = Linear(216*2, num_classes)\n",
    "        \n",
    "        self.crf = CRF(self.num_labels, batch_first=self.batch_first)\n",
    "\n",
    "    def forward(self, data, labels=None):\n",
    "        x, edge_index, batch, indice_nt_tensor = data.x, data.edge_index, data.batch, data.indice_nt_tensor\n",
    "        \n",
    "        x = self.conv0(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x,0.05,training = self.training)\n",
    "        \n",
    "        edge_index =  to_undirected(edge_index)\n",
    "        \n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x,0.05,training = self.training)\n",
    "        \n",
    "        x = self.lin(x)\n",
    "        \n",
    "        if(labels is not None):\n",
    "            x_new,batch_new,label_new = remove_internal_nodes(indice_nt_tensor,x,batch,labels)\n",
    "            \n",
    "            labels, padding_mask = reshape_label_to_batch_padded(64,batch_new,label_new)\n",
    "            \n",
    "            logits = reshape_data_to_batch_padded(64,batch_new,x_new,self.num_labels)\n",
    "            \n",
    "            return -self.crf(emissions = logits, tags=labels,mask=padding_mask.byte())\n",
    "        \n",
    "        else:\n",
    "            x_new,batch_new = remove_internal_nodes_without_labels(indice_nt_tensor,x,batch)\n",
    "            \n",
    "            padding_mask = reshape_mask_to_batch_padded(64,batch_new)\n",
    "            \n",
    "            logits = reshape_data_to_batch_padded(64,batch_new,x_new,self.num_labels)\n",
    "            \n",
    "            return self.crf.decode(emissions =logits, mask=padding_mask.byte())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "54371d45-aec2-46b8-a84f-49654e9ae280",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = 3\n",
    "num_node_features = 1024\n",
    "max_grad_norm = 10\n",
    "device = \"cpu\" #torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')\n",
    "model = Net_large_CRF(num_node_features,num_classes).to(device)\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
    "scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')\n",
    "\n",
    "sentence_level=False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec731c02",
   "metadata": {},
   "source": [
    "# Train and Evaluate the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "631f9730",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(dataloader):\n",
    "        \n",
    "    model.train()\n",
    "    log_interval = 20\n",
    "    start_time = time.time()\n",
    "    dict_param = {}\n",
    "    tensor_board_grad = False\n",
    "    \n",
    "    for idx, (ele) in enumerate(dataloader):\n",
    "        ele = ele.to(device)\n",
    "        ele[\"x\"] = ele[\"x\"].float()\n",
    "        optimizer.zero_grad()\n",
    "        loss = model(ele,ele[\"y\"])\n",
    "        loss.backward()\n",
    "\n",
    "        # gradient clipping\n",
    "        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)\n",
    "        optimizer.step()\n",
    "        \n",
    "        # Print the loss at a given frequency\n",
    "        if idx % log_interval == 0 and idx > 0:\n",
    "            elapsed = time.time() - start_time\n",
    "            print('| epoch {:3d} | {:5d}/{:5d} batches '\n",
    "                      '| loss {:8.3f}'.format(epoch, idx, len(dataloader),\n",
    "                                                  loss))\n",
    "            start_time = time.time()\n",
    "        \n",
    "        # Save the gradient of the layer of the model to TensorBoard\n",
    "        if(tensor_board_grad == True):\n",
    "            for i, (name, param) in enumerate(model.named_parameters()):\n",
    "                            if param.requires_grad:\n",
    "                                if (name not in dict_param.keys()):\n",
    "                                    dict_param[name]= param.grad.view(-1)\n",
    "                                else:\n",
    "                                    dict_param[name] += param.grad.view(-1)\n",
    "                                    \n",
    "    return dict_param"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "60ac0b94",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(dataloader):\n",
    "    model.eval()\n",
    "    \n",
    "    total_acc, total_count, total_predict = 0, 0, 0\n",
    "    total_acc_sentence_level, total_size_sentence_level = 0, 0\n",
    "    \n",
    "    # initalize the variable to compute the F1 score at sentence level\n",
    "    total_class_predict_sentence = []\n",
    "    total_true_class_sentence = []\n",
    "    \n",
    "    # initalize the variable to compute the F1 score at token level\n",
    "    total_class_predict_token = []\n",
    "    total_true_class_token = []\n",
    "    \n",
    "    val_to_replace = torch.tensor(1e9, dtype=torch.float32).to(device)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for idx, (ele) in enumerate(dataloader):\n",
    "            ele = ele.to(device)\n",
    "            ele[\"x\"] = ele[\"x\"].float()\n",
    "            class_predict = model(ele)\n",
    "            \n",
    "            ## Flatten the result\n",
    "            class_predict_flat = [item for sublist in class_predict for item in sublist]\n",
    "            total_class_predict_token.extend(class_predict_flat)\n",
    "            \n",
    "            ## Extract the true label without the internal nodes\n",
    "            condition = ele[\"indice_nt_tensor\"] > 0\n",
    "            ele_y_clean = ele[\"y\"][condition].tolist()\n",
    "            total_true_class_token.extend(ele_y_clean)\n",
    "            \n",
    "                        \n",
    "            # Compute the sentence level label & accuracy\n",
    "            sentence_level_label = [find_the_most_frequent_element(list_ele) for list_ele in class_predict]\n",
    "            total_class_predict_sentence.extend(sentence_level_label)\n",
    "            \n",
    "            true_sentence_level_label = ele[\"sentence_level_label\"].tolist()\n",
    "            total_true_class_sentence.extend(true_sentence_level_label) \n",
    "            \n",
    "            total_acc_sentence_level += sum([ ele_1 == ele_2 for ele_1,ele_2 in zip(sentence_level_label,true_sentence_level_label)])\n",
    "            total_size_sentence_level += len(true_sentence_level_label)\n",
    "            \n",
    "            # Count the number of non 0 values to know if the prediction is only 0 or not\n",
    "            total_predict += sum(class_predict_flat)\n",
    "            \n",
    "            # Compute the token level accuracy\n",
    "            \n",
    "            total_acc += sum([1 for (ele1,ele2) in zip(class_predict_flat,ele_y_clean) if ele1 == ele2])\n",
    "            total_count += len(ele_y_clean)\n",
    "                                            \n",
    "        f1_token_level = f1_score(total_class_predict_token,total_true_class_token,labels = [0,1,2], average=\"micro\")\n",
    "        f1_sentence_level = f1_score(total_class_predict_sentence,total_true_class_sentence,labels = [0,1,2], average=\"micro\")\n",
    "        \n",
    "    return total_acc/total_count, total_predict, f1_token_level, total_acc_sentence_level/total_size_sentence_level, f1_sentence_level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "798717af",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/20 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| epoch   0 |    20/   21 batches | loss 4789.797\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|▌         | 1/20 [00:20<06:24, 20.25s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.22463016792875526 the number of non 0 values predicted is : 153129\n",
      "accuracy on the sentence level train is 0.29924242424242425\n",
      "\n",
      "f1 score on token level train is : 0.22463016792875526\n",
      "f1 score on sentence level train is : 0.29924242424242425\n",
      "\n",
      "accuracy on the evaluate dataset0.24418076013820694 the number of non 0 values predicted is : 31251\n",
      "accuracy on the sentence level evaluate0.17468354430379746\n",
      "\n",
      "f1 score on token level evaluate is : 0.24418076013820694\n",
      "f1 score on sentence level evaluate is : 0.17468354430379746\n",
      "\n",
      "accuracy on the test dataset0.25168856805558243 the number of non 0 values predicted is : 73119\n",
      "accuracy on the sentence level test0.32414497192445124\n",
      "\n",
      "f1 score on token level test is : 0.25168856805558243\n",
      "f1 score on sentence level  test is : 0.32414497192445124\n",
      "| epoch   1 |    20/   21 batches | loss 4714.850\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 2/20 [00:40<06:02, 20.15s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.5867508470496751 the number of non 0 values predicted is : 11880\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.5867508470496751\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5440534642662302 the number of non 0 values predicted is : 2370\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5440534642662302\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5254398018230729 the number of non 0 values predicted is : 5877\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5254398018230729\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch   2 |    20/   21 batches | loss 4551.806\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█▌        | 3/20 [00:59<05:38, 19.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.5867508470496751 the number of non 0 values predicted is : 11880\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.5867508470496751\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5440534642662302 the number of non 0 values predicted is : 2370\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5440534642662302\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5254398018230729 the number of non 0 values predicted is : 5877\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5254398018230729\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch   3 |    20/   21 batches | loss 4453.356\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 4/20 [01:19<05:17, 19.86s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.5867508470496751 the number of non 0 values predicted is : 11880\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.5867508470496751\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5440534642662302 the number of non 0 values predicted is : 2370\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5440534642662302\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5254398018230729 the number of non 0 values predicted is : 5877\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5254398018230729\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch   4 |    20/   21 batches | loss 4674.442\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 25%|██▌       | 5/20 [01:39<04:56, 19.78s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch   5 |    20/   21 batches | loss 4664.516\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 6/20 [01:58<04:36, 19.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch   6 |    20/   21 batches | loss 4432.745\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 35%|███▌      | 7/20 [02:18<04:15, 19.64s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch   7 |    20/   21 batches | loss 4388.549\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 8/20 [02:37<03:54, 19.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch   8 |    20/   21 batches | loss 4475.168\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 45%|████▌     | 9/20 [02:57<03:37, 19.74s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch   9 |    20/   21 batches | loss 4252.000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 10/20 [03:17<03:16, 19.67s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch  10 |    20/   21 batches | loss 4488.035\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 55%|█████▌    | 11/20 [03:37<02:56, 19.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch  11 |    20/   21 batches | loss 4601.744\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 12/20 [03:56<02:36, 19.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch  12 |    20/   21 batches | loss 4229.199\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 65%|██████▌   | 13/20 [04:15<02:16, 19.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch  13 |    20/   21 batches | loss 4279.568\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 14/20 [04:35<01:57, 19.54s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch  14 |    20/   21 batches | loss 4272.545\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 75%|███████▌  | 15/20 [04:54<01:37, 19.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch  15 |    20/   21 batches | loss 4304.869\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 16/20 [05:14<01:18, 19.67s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch  16 |    20/   21 batches | loss 4121.338\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 85%|████████▌ | 17/20 [05:34<00:58, 19.60s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch  17 |    20/   21 batches | loss 4086.607\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 18/20 [05:53<00:39, 19.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch  18 |    20/   21 batches | loss 4273.128\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 95%|█████████▌| 19/20 [06:13<00:19, 19.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n",
      "| epoch  19 |    20/   21 batches | loss 4158.973\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [06:32<00:00, 19.65s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy on the train dataset0.626631612079021 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level train is 0.4671717171717172\n",
      "\n",
      "f1 score on token level train is : 0.626631612079021\n",
      "f1 score on sentence level train is : 0.4671717171717172\n",
      "\n",
      "accuracy on the evaluate dataset0.5788779778141481 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level evaluate0.3924050632911392\n",
      "\n",
      "f1 score on token level evaluate is : 0.5788779778141481\n",
      "f1 score on sentence level evaluate is : 0.3924050632911392\n",
      "\n",
      "accuracy on the test dataset0.5642429989742795 the number of non 0 values predicted is : 0\n",
      "accuracy on the sentence level test0.3823379275140378\n",
      "\n",
      "f1 score on token level test is : 0.5642429989742795\n",
      "f1 score on sentence level  test is : 0.3823379275140378\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# You can uncomment the commented code in order to monitor the runs with tensorboard. \n",
    "# Writer will output to ./runs/ directory by default\n",
    "\n",
    "# writer = SummaryWriter(f'runs/{datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")}')\n",
    "\n",
    "for epoch in tqdm(range(20)):\n",
    "    dict_norm = train(train_loader)\n",
    "    \n",
    "    accu_train, total_predict_train, f1_token_level_train, acc_sentence_level_train, f1_sentence_level_train = evaluate(train_loader)\n",
    "    \n",
    "    accu_val, total_predict, f1_token_level_dev, acc_sentence_level_dev, f1_sentence_level_dev = evaluate(dev_loader)\n",
    "    acc_test, total_pred_test, f1_token_level_test, acc_sentence_level_test, f1_sentence_level_test = evaluate(test_loader)\n",
    "    \n",
    "    #scheduler.step(1-f1_token_level_dev)\n",
    "    \n",
    "    #for key, values in dict_norm.items():\n",
    "    #    writer.add_histogram(tag = 'distribution of the gradient of the layer : ' + str(key), \n",
    "    #                         values = values.unsqueeze(0).T , \n",
    "    #                         global_step = epoch)\n",
    "    \n",
    "    \n",
    "    #writer.add_scalars('Accuracy score token level over batchs',  {'test':acc_test,'eval':accu_val, 'train':accu_train},epoch)\n",
    "    #writer.add_scalars('F1 score token level over batchs',  {'test':f1_token_level_test,'eval':f1_token_level_dev, 'train':f1_token_level_train},epoch)\n",
    "    \n",
    "    #writer.add_scalars('Accuracy score sentence level over batchs',  {'test':acc_sentence_level_test,'eval':acc_sentence_level_dev, 'train':acc_sentence_level_train},epoch)\n",
    "    #writer.add_scalars('F1 score sentence level over batchs',  {'test':f1_sentence_level_test,'eval':f1_sentence_level_dev, 'train':f1_sentence_level_train},epoch)\n",
    "    \n",
    "    print('accuracy on the train dataset' + str(accu_train) + \" the number of non 0 values predicted is : \" + str(int(total_predict_train)))\n",
    "    print(\"accuracy on the sentence level train is \" + str(acc_sentence_level_train))\n",
    "    \n",
    "    print(\"\")\n",
    "    print(\"f1 score on token level train is : \" + str(f1_token_level_train))\n",
    "    #print(\"f1 score on span level evaluate is : \" + str(f1_span_level_dev))\n",
    "    print(\"f1 score on sentence level train is : \" + str(f1_sentence_level_train))\n",
    "    \n",
    "    print(\"\")\n",
    "    print('accuracy on the evaluate dataset' + str(accu_val) + \" the number of non 0 values predicted is : \" + str(int(total_predict)))\n",
    "    print(\"accuracy on the sentence level evaluate\" + str(acc_sentence_level_dev))\n",
    "    \n",
    "    print(\"\")\n",
    "    print(\"f1 score on token level evaluate is : \" + str(f1_token_level_dev))\n",
    "    #print(\"f1 score on span level evaluate is : \" + str(f1_span_level_dev))\n",
    "    print(\"f1 score on sentence level evaluate is : \" + str(f1_sentence_level_dev))\n",
    "    \n",
    "    print(\"\")\n",
    "    print('accuracy on the test dataset' + str(acc_test) + \" the number of non 0 values predicted is : \" + str(int(total_pred_test)))\n",
    "    print(\"accuracy on the sentence level test\" + str(acc_sentence_level_test))\n",
    "    \n",
    "    print(\"\")\n",
    "    print(\"f1 score on token level test is : \" + str(f1_token_level_test))\n",
    "    #print(\"f1 score on span level test is : \" + str(f1_span_level_test))\n",
    "    print(\"f1 score on sentence level  test is : \" + str(f1_sentence_level_test))\n",
    "    \n",
    "#writer.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concept_env",
   "language": "python",
   "name": "concept_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
