{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6abd38c3-86c2-4f33-b461-71ad1a934307",
   "metadata": {},
   "source": [
    "# BERT-GNN model with fine-tuned BERT for Argument Unit Recognition and Classification (AURC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fb647d3",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46e71ab0-96e7-4ce8-8477-c0703a1ad6d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"\"\n",
    "\n",
    "import spacy\n",
    "from spacy import displacy\n",
    "\n",
    "nlp = spacy.load('en_core_web_sm')\n",
    "\n",
    "import benepar\n",
    "from benepar import BeneparComponent, NonConstituentException\n",
    "benepar.download('benepar_en3')\n",
    "\n",
    "nlp.add_pipe(\"benepar\", config={\"model\": \"benepar_en3\"})\n",
    "\n",
    "import json\n",
    "\n",
    "from transformers import (AdamW, BertConfig, get_linear_schedule_with_warmup, \n",
    "                                BertForTokenClassification, BertTokenizer, BertPreTrainedModel, BertModel)\n",
    "\n",
    "import os.path as osp\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "import random\n",
    "import datetime\n",
    "from collections import defaultdict\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.nn import Linear, ReLU\n",
    "\n",
    "from torch_geometric.nn import GCNConv,GATConv, SAGEConv,Sequential,AGNNConv,global_mean_pool\n",
    "from torch_geometric.data import Dataset, download_url,Data\n",
    "from torch_geometric.utils import to_undirected, sort_edge_index, to_networkx\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "from torch_geometric.data import DataLoader\n",
    "from torch import nn\n",
    "\n",
    "## Personal Library\n",
    "import sys\n",
    "sys.path.insert(0, '../../../utils/')\n",
    "from constituency_parsing_utils_depth_3 import (\n",
    "    construct_the_edges_matrix_consistuency_3\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "interstate-release",
   "metadata": {},
   "source": [
    "# Dataset creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba3c03e7-949d-44ba-847e-961c2afe2074",
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_the_edges_matrix_consistuency_3_padding(doc, list_index_to_replicate, max_indice_nt):\n",
    "    edges_index = []\n",
    "    current_node_index = -1\n",
    "    \n",
    "    current_index_replicated = 0\n",
    "    \n",
    "    for sent in doc.sents:\n",
    "        ## On regarde chaque phrases de notre document\n",
    "        current_node_index +=1\n",
    "        index_sentence = current_node_index\n",
    "        for ele in sent._.children:\n",
    "            ## On regarde chaque élément de la phrase au niveau de profondeur 1 et on regarde si eux même ont des enfants\n",
    "            num_children = sum(1 for x in ele._.children)\n",
    "            if num_children != 0 :\n",
    "                ## Il y a des éléments de profondeurs 2\n",
    "                current_node_index += 1\n",
    "                index_second_node = current_node_index\n",
    "                edges_index.append([index_second_node,index_sentence])\n",
    "                for ele_child in ele._.children:\n",
    "                    \n",
    "                    num_children_children = sum(1 for x in ele_child._.children)\n",
    "                    if num_children_children != 0 :\n",
    "                        ## Il y a des éléments de profondeurs 3\n",
    "                        current_node_index += 1\n",
    "                        index_third_node = current_node_index\n",
    "                        edges_index.append([index_third_node,index_second_node])\n",
    "                        \n",
    "                        for ele_child_child in ele_child._.children:\n",
    "                            \n",
    "                            for word in ele_child_child:\n",
    "                                if(word.i in list_index_to_replicate):\n",
    "                                    edges_index.append([word.i+max_indice_nt+current_index_replicated,index_third_node])\n",
    "                                    current_index_replicated +=1\n",
    "                                edges_index.append([word.i+max_indice_nt+current_index_replicated,index_third_node])\n",
    "                    else:\n",
    "                        ## Il y n'y a pas des éléments de profondeurs 3, le noeud de profondeur 2 est une feuille\n",
    "                        for word in ele_child:\n",
    "                            if(word.i in list_index_to_replicate):\n",
    "                                edges_index.append([word.i+max_indice_nt+current_index_replicated,index_second_node])\n",
    "                                current_index_replicated +=1\n",
    "                            ## Normalement il n'y a que 1 mot, on doit le relier au noeud de la phrase\n",
    "                            edges_index.append([word.i+max_indice_nt+current_index_replicated,index_second_node])\n",
    "            else:\n",
    "                ## Il n'y a pas d'éléments de profondeurs 2\n",
    "                for word in ele:\n",
    "                    if(word.i in list_index_to_replicate):\n",
    "                        edges_index.append([word.i+max_indice_nt+current_index_replicated,index_sentence])\n",
    "                        current_index_replicated +=1\n",
    "                    ## Normalement il n'y a que 1 mot, on doit le relier au noeud de la phrase\n",
    "                    edges_index.append([word.i+max_indice_nt+current_index_replicated,index_sentence])\n",
    "    return edges_index"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74a97bb5-a106-4dc0-9e1d-7c00fab64d60",
   "metadata": {},
   "source": [
    "### Import Dataset AURC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91aa520c-d0c5-402a-bd51-c481f585ab24",
   "metadata": {},
   "outputs": [],
   "source": [
    "need_new_writting = False\n",
    "if(need_new_writting == True):\n",
    "    data_dir = '../../../data/aurc/bert/end_2_end_IN'\n",
    "    input_file = 'AURC_DATA_dict.json'\n",
    "    num_labels = 3\n",
    "\n",
    "    pretrained_weights =  'bert-large-uncased'\n",
    "\n",
    "    max_sequence_length = 64\n",
    "\n",
    "    # the domains are shuffle\n",
    "    target_domain = 'Cross-Domain'\n",
    "\n",
    "    fname = '../../../data/aurc/AURC_DATA_dict.json'\n",
    "\n",
    "    # load the json file\n",
    "    with open(fname,'r') as my_file:\n",
    "        AURC_DATA_dict = json.load(my_file)\n",
    "    print(len(AURC_DATA_dict), [len(AURC_DATA_dict[topic]) for topic in AURC_DATA_dict.keys()])\n",
    "\n",
    "    # check the number of example per topic\n",
    "    topics = sorted(set(AURC_DATA_dict.keys()))\n",
    "    print(len(topics), topics)\n",
    "\n",
    "    # define the label to id dictionnary\n",
    "    label2id = {}\n",
    "    label2id['non'] = 0\n",
    "    label2id['con'] = 1\n",
    "    label2id['pro'] = 2\n",
    "\n",
    "    # Choose the tokenizer from Hugging Face transformers\n",
    "    tokenizer = BertTokenizer.from_pretrained(pretrained_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1af1a9a4-8cb7-4ab4-9dcc-acff2d6cb9b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "if need_new_writting:\n",
    "    max_indice_nt = 21  # node indices reserved for internal (non-terminal) nodes\n",
    "    num_error = 0\n",
    "    # Graphs written so far per split; also the index of the split's next data_<i>.pt file.\n",
    "    split_counts = {'Train': 0, 'Dev': 0, 'Test': 0}\n",
    "\n",
    "    for topic, AD in tqdm(AURC_DATA_dict.items()):\n",
    "        for ad in AD:\n",
    "            ## BERT encoding, padded and truncated to max_sequence_length\n",
    "            sequence_dict = tokenizer.encode_plus(ad['sentence'], max_length=max_sequence_length,\n",
    "                                                  padding='max_length', truncation=True,\n",
    "                                                  add_special_tokens=True)\n",
    "            input_ids = sequence_dict['input_ids']\n",
    "            input_tokens = tokenizer.convert_ids_to_tokens(input_ids)\n",
    "\n",
    "            ## Labels per BERT position: 0 for [CLS], the (truncated) labels of\n",
    "            ## 'tokenized_sentence_bert_labels', then 0 for the padding\n",
    "            input_labels = [label2id[label] for label in ad['tokenized_sentence_bert_labels'].split(' ')]\n",
    "            input_labels = [0] + input_labels[:max_sequence_length-1] + [0]*max(0, max_sequence_length-len(input_labels)-1)\n",
    "\n",
    "            ## Rebuild the sentence from the word pieces between [CLS] and [SEP]\n",
    "            seq_len = input_ids.index(tokenizer.sep_token_id)\n",
    "            seq = \" \".join(input_tokens[1:seq_len]).replace(' ##', '')\n",
    "\n",
    "            ## Skip sentences whose spaCy tokenisation does not match the BERT words one-to-one\n",
    "            doc = nlp(seq)\n",
    "            if len(doc) != len(seq.split(' ')):\n",
    "                num_error += 1\n",
    "                continue\n",
    "\n",
    "            y = torch.tensor(input_labels)\n",
    "            list_index_to_replicate = [i for (i, token) in enumerate(input_tokens) if token.startswith(\"##\")]\n",
    "            tab_edges_BERT = construct_the_edges_matrix_consistuency_3_padding(doc, list_index_to_replicate, max_indice_nt)\n",
    "            # reshape(-1, 2) still yields a well-formed (2, 0) edge_index when no edge was produced\n",
    "            edge_matrix = torch.tensor(tab_edges_BERT, dtype=torch.long).reshape(-1, 2).transpose(0, 1)\n",
    "            sentence_level_label = torch.tensor(label2id[ad[\"sentence_level_stance\"]])\n",
    "\n",
    "            data = Data(\n",
    "                input_ids=torch.tensor(input_ids).unsqueeze(0),\n",
    "                attention_mask=torch.tensor(sequence_dict['attention_mask']).unsqueeze(0),\n",
    "                token_type_ids=torch.tensor(sequence_dict['token_type_ids']).unsqueeze(0),\n",
    "                edge_index=edge_matrix,\n",
    "                y=y,\n",
    "                sentence_level_label=sentence_level_label\n",
    "            )\n",
    "            # One node per BERT position plus the reserved internal-node indices\n",
    "            data.num_nodes = max_sequence_length + max_indice_nt\n",
    "\n",
    "            split = ad[target_domain]\n",
    "            if split not in split_counts:\n",
    "                num_error += 1\n",
    "                continue\n",
    "            path = osp.join(data_dir, split, \"processed\", \"data_\" + str(split_counts[split]) + \".pt\")\n",
    "            torch.save(data, path)\n",
    "            split_counts[split] += 1\n",
    "\n",
    "    number_data_train = split_counts['Train']\n",
    "    number_data_dev = split_counts['Dev']\n",
    "    number_data_test = split_counts['Test']\n",
    "\n",
    "    print(\"size of the train set : \" + str(number_data_train))\n",
    "    print(\"size of the test set : \" + str(number_data_test ))\n",
    "    print(\"size of the dev set : \" + str(number_data_dev ))\n",
    "    print(\"the number of error is : \" + str(num_error ))\n",
    "    print(\"the maximum number of internal node is : \", str(max_indice_nt))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aad95892-cc4f-4950-8f33-f3cb8034dd33",
   "metadata": {},
   "source": [
    "### Definition of the PyTorch Geometric Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "tribal-parts",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = '../../../data/aurc/bert/end_2_end_IN'\n",
    "\n",
    "class MyOwnDataset(Dataset):\n",
    "    \"\"\"PyTorch Geometric dataset over pre-built graphs `<root>/processed/data_<idx>.pt`.\n",
    "\n",
    "    Nothing is downloaded or processed here: the files are written by the dataset\n",
    "    generation cell, and `num_data` says how many of them (indices 0..num_data-1)\n",
    "    to expose.\n",
    "    \"\"\"\n",
    "    def __init__(self, num_data, root, transform=None, pre_transform=None):\n",
    "        # Set first, in case the base constructor already inspects processed_file_names.\n",
    "        self.num_data = num_data\n",
    "        super(MyOwnDataset, self).__init__(root, transform, pre_transform)\n",
    "\n",
    "    @property\n",
    "    def raw_file_names(self):\n",
    "        return []\n",
    "\n",
    "    @property\n",
    "    def processed_file_names(self):\n",
    "        return ['data_{}.pt'.format(idx) for idx in range(self.num_data)]\n",
    "\n",
    "    def len(self):\n",
    "        return len(self.processed_file_names)\n",
    "\n",
    "    def get(self, idx):\n",
    "        path = osp.join(self.processed_dir, 'data_{}.pt'.format(idx))\n",
    "        # The files hold pickled Data objects written by this notebook (trusted input).\n",
    "        # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle them.\n",
    "        try:\n",
    "            return torch.load(path, weights_only=False)\n",
    "        except TypeError:\n",
    "            # torch < 1.13 has no weights_only argument.\n",
    "            return torch.load(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "greatest-trading",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of graphs in each split's processed/ folder; expected to match the counts printed\n",
    "# by the generation cell. Alternative sizes previously used here: 3960 / 1959 / 790\n",
    "# (Train / Test / Dev).\n",
    "split_sizes = {\"Train\": 4157, \"Test\": 1189, \"Dev\": 593}\n",
    "train_dataset, test_dataset, dev_dataset = (\n",
    "    MyOwnDataset(num_data=split_sizes[split], root=data_dir + \"/\" + split)\n",
    "    for split in (\"Train\", \"Test\", \"Dev\")\n",
    ")\n",
    "\n",
    "print(\"train size : \" + str(len(train_dataset)))\n",
    "print(\"dev size : \" + str(len(dev_dataset)))\n",
    "print(\"test size : \" + str(len(test_dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "growing-purchase",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one training graph and the shapes of its main tensors.\n",
    "ele = train_dataset[1]\n",
    "print(\"ele representation : \" + str(ele))\n",
    "for key, description in ((\"edge_index\", \"ele index\"), (\"attention_mask\", \"attention_mask\"), (\"y\", \"y\")):\n",
    "    value = ele[key]\n",
    "    print(key + \" : \", str(value), \"the size of \" + description + \" is : \", str(value.size()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "383eff10-4466-4440-ba34-eee797b6aad2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the inspected graph's per-position label tensor.\n",
    "ele.y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e42a3e62",
   "metadata": {},
   "source": [
    "# Model Creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "skilled-preview",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the training batches are reshuffled each epoch; evaluation order does not change the\n",
    "# aggregated metrics, so dev/test are iterated in a fixed, reproducible order.\n",
    "batch_size = 190\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "dev_loader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11b2bc3e-5d20-4d3b-b11a-22777d7342b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone 3-label BERT token classifier. NOTE(review): no later cell visible here reads\n",
    "# this global; Net_large_end2end loads its own copy of the weights.\n",
    "tokenbert = BertForTokenClassification.from_pretrained(\n",
    "    \"bert-large-uncased\",\n",
    "    output_attentions=False,\n",
    "    output_hidden_states=False,\n",
    "    num_labels=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f83052be-01a8-4a18-bc34-85c973668a56",
   "metadata": {},
   "source": [
    "## Experimentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d935270e-349a-4205-a142-b8a9fdfe9003",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of node indices reserved for internal (non-terminal) constituency nodes at the\n",
    "# start of every graph; the model places BERT's token outputs right after them, so this\n",
    "# should equal the max_indice_nt used when the graphs were generated.\n",
    "max_size_indice_nt = 21"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1766772-6022-4e03-b8bd-cdc52cf76227",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net_large_end2end(torch.nn.Module):\n",
    "    \"\"\"BERT encoder followed by a two-layer GAT over each sentence's constituency graph.\n",
    "\n",
    "    Every graph has `num_internal_nodes + max_length` nodes: the internal (non-terminal)\n",
    "    nodes come first and start from zero features, the following `max_length` nodes\n",
    "    receive BERT's last hidden state at each token position. Every node is classified\n",
    "    into `num_classes`; only the token nodes are used for the loss and the predictions.\n",
    "    \"\"\"\n",
    "    def __init__(self,num_node_features,num_classes, batch_first=True, output_hidden_states=False,\n",
    "            output_attentions=False, max_length = 64, model_name=\"bert-large-uncased\",\n",
    "            num_internal_nodes=None):\n",
    "        super(Net_large_end2end, self).__init__()\n",
    "\n",
    "        self.num_classes = num_classes\n",
    "        self.batch_first = batch_first\n",
    "        self.max_length = max_length\n",
    "        self.num_node_features = num_node_features\n",
    "        # Falls back to the notebook-level max_size_indice_nt, as the previous version did.\n",
    "        self.num_internal_nodes = max_size_indice_nt if num_internal_nodes is None else num_internal_nodes\n",
    "\n",
    "        self.tokenbert = BertForTokenClassification.from_pretrained(\n",
    "            model_name,\n",
    "            num_labels=self.num_classes,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            output_attentions=output_attentions\n",
    "        )\n",
    "\n",
    "        self.conv0 = GATConv(self.num_node_features, 290,heads = 3)\n",
    "        self.conv1 = GATConv(290*3, 103,heads = 3)\n",
    "        self.lin1 = Linear(103*3, 96)\n",
    "        self.lin2 = Linear(96, self.num_classes)\n",
    "\n",
    "    def forward(self, data, labels=None):\n",
    "        \"\"\"Encode a batch of graphs and return the loss or the per-token predictions.\n",
    "\n",
    "        Args:\n",
    "            data: batched PyG graphs carrying input_ids, attention_mask,\n",
    "                token_type_ids and edge_index.\n",
    "            labels: optional flattened per-position labels (e.g. the batched `y`),\n",
    "                one per token position of every graph.\n",
    "\n",
    "        Returns:\n",
    "            The cross-entropy loss over the non-padding positions when `labels` is\n",
    "            given, otherwise the predicted class of every token position, shaped\n",
    "            (num_graphs, max_length).\n",
    "        \"\"\"\n",
    "        outputs = self.tokenbert.bert(\n",
    "            data.input_ids,\n",
    "            attention_mask=data.attention_mask,\n",
    "            token_type_ids=data.token_type_ids\n",
    "        ).last_hidden_state\n",
    "\n",
    "        nodes_per_graph = self.max_length + self.num_internal_nodes\n",
    "        # new_zeros keeps the node features on BERT's device and dtype.\n",
    "        hidden_states = outputs.new_zeros((outputs.size(0), nodes_per_graph, self.num_node_features))\n",
    "        hidden_states[:, self.num_internal_nodes:, :] = outputs\n",
    "        hidden_states = hidden_states.view(-1, self.num_node_features)\n",
    "\n",
    "        # The stored edges only point child -> parent; without the reverse edges no\n",
    "        # message would ever flow from the constituents back to the token nodes.\n",
    "        edge_index = to_undirected(data.edge_index)\n",
    "\n",
    "        x = F.relu(self.conv0(hidden_states, edge_index))\n",
    "        x = F.relu(self.conv1(x, edge_index))\n",
    "        x = F.relu(self.lin1(x))\n",
    "        x = self.lin2(x)\n",
    "\n",
    "        # Keep only the token nodes: (num_graphs, max_length, num_classes).\n",
    "        x = x.view(-1, nodes_per_graph, self.num_classes)[:, self.num_internal_nodes:, :]\n",
    "        if labels is not None:\n",
    "            # Padding positions (attention_mask == 0) are excluded from the loss.\n",
    "            targets = labels.masked_fill(data.attention_mask.reshape(-1) == 0, -100)\n",
    "            loss = nn.CrossEntropyLoss(ignore_index=-100)\n",
    "            return loss(x.reshape(-1, self.num_classes), targets)\n",
    "        # Highest-scoring class at every token position.\n",
    "        return x.argmax(-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb8ea989-6423-4470-9282-c61bb8f600eb",
   "metadata": {},
   "source": [
    "## Train and evaluate the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a40e9c63-a8db-4264-9bff-c278d2bbec85",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(dataloader, tensor_board_grad=False):\n",
    "    \"\"\"Run one training epoch of the notebook-level `model` over `dataloader`.\n",
    "\n",
    "    Also reads the notebook-level `optimizer`, `device`, `max_grad_norm` and `epoch`\n",
    "    (the last one only for logging).\n",
    "\n",
    "    Args:\n",
    "        dataloader: loader of batched graphs carrying the per-position labels in `y`.\n",
    "        tensor_board_grad: when True, sum every parameter's flattened gradient over\n",
    "            the epoch, e.g. to log gradient histograms to TensorBoard.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping parameter names to their summed gradients (empty when\n",
    "        `tensor_board_grad` is False).\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    log_interval = 20\n",
    "    dict_param = {}\n",
    "\n",
    "    for idx, batch in enumerate(dataloader):\n",
    "        batch = batch.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss = model(batch, batch[\"y\"])\n",
    "        loss.backward()\n",
    "\n",
    "        # gradient clipping\n",
    "        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)\n",
    "        optimizer.step()\n",
    "\n",
    "        # Print the loss at a given frequency\n",
    "        if idx % log_interval == 0 and idx > 0:\n",
    "            print('| epoch {:3d} | {:5d}/{:5d} batches '\n",
    "                  '| loss {:8.3f}'.format(epoch, idx, len(dataloader), loss.item()))\n",
    "\n",
    "        # Accumulate the gradient of every layer for TensorBoard\n",
    "        if tensor_board_grad:\n",
    "            for name, param in model.named_parameters():\n",
    "                # Parameters that took no part in the forward pass have no gradient.\n",
    "                if not param.requires_grad or param.grad is None:\n",
    "                    continue\n",
    "                # Store a copy: aliasing param.grad would let `+=` write into the live\n",
    "                # gradient and let later zero_grad()/backward() calls alter the stored sum.\n",
    "                grad = param.grad.detach().reshape(-1).clone()\n",
    "                if name in dict_param:\n",
    "                    dict_param[name] += grad\n",
    "                else:\n",
    "                    dict_param[name] = grad\n",
    "\n",
    "    return dict_param"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54371d45-aec2-46b8-a84f-49654e9ae280",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = 3\n",
    "# Node feature size; must equal BERT's hidden size (1024 for bert-large-uncased).\n",
    "num_node_features = 1024\n",
    "max_grad_norm = 9.75\n",
    "# The first cell sets CUDA_VISIBLE_DEVICES=\"\", so only the CPU is visible anyway.\n",
    "device = \"cpu\" #torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Seed every RNG used here so the GAT initialisation and batch shuffling are reproducible.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "model = Net_large_end2end(num_node_features,num_classes).to(device)\n",
    "\n",
    "# Decay every parameter except biases and LayerNorm weights. torch.optim.AdamW reads the\n",
    "# per-group key 'weight_decay' (a 'weight_decay_rate' key is silently ignored, leaving the\n",
    "# optimizer default on every group), and Hugging Face BERT names its LayerNorm parameters\n",
    "# 'LayerNorm.weight' / 'LayerNorm.bias' rather than the TF-style 'gamma' / 'beta'.\n",
    "param_optimizer = list(model.named_parameters())\n",
    "no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.weight']\n",
    "optimizer_grouped_parameters = [\n",
    "    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n",
    "     'weight_decay': 0.01},\n",
    "    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],\n",
    "     'weight_decay': 0.0}\n",
    "]\n",
    "\n",
    "optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=3e-5)\n",
    "\n",
    "# Multiply the learning rate by 0.99 at every scheduler.step() call.\n",
    "lmbda = lambda epoch: 0.99\n",
    "scheduler =  torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)\n",
    "\n",
    "sentence_level = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4e768cc-d855-4abc-8dbc-e31217495328",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "def find_the_most_frequent_element(list_ele):\n",
    "    dict_size_ele = Counter(list_ele)\n",
    "    if(len(dict_size_ele) == 1):\n",
    "        return list(dict_size_ele.keys())[0]\n",
    "    else:\n",
    "        if( 1 in list(dict_size_ele.keys())):\n",
    "            dict_size_ele.pop(1)\n",
    "        return max(dict_size_ele, key=dict_size_ele.get)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "845a6651-666f-412a-ae80-861959576baa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(dataloader):\n",
    "    \"\"\"Evaluate the notebook-level `model` on `dataloader` at token and sentence level.\n",
    "\n",
    "    Padding positions (attention_mask == 0) are removed from both the predictions and\n",
    "    the gold labels before any metric is computed. The sentence-level prediction of a\n",
    "    graph is the find_the_most_frequent_element vote over its non-padding token\n",
    "    predictions.\n",
    "\n",
    "    Expects model(batch) to return one predicted class per token position, i.e. a\n",
    "    tensor with as many elements as batch[\"attention_mask\"] (num_graphs, seq_len).\n",
    "\n",
    "    Returns:\n",
    "        (token accuracy, number of non-zero token predictions, token micro-F1,\n",
    "         sentence accuracy, sentence micro-F1)\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    total_acc, total_count, total_predict = 0, 0, 0\n",
    "\n",
    "    # Predictions / gold labels for the sentence-level scores\n",
    "    total_class_predict_sentence = []\n",
    "    total_true_class_sentence = []\n",
    "\n",
    "    # Predictions / gold labels for the token-level scores\n",
    "    total_class_predict_token = []\n",
    "    total_true_class_token = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch in dataloader:\n",
    "            batch = batch.to(device)\n",
    "\n",
    "            attention_mask = batch[\"attention_mask\"]\n",
    "            class_predict = model(batch).reshape(attention_mask.shape)\n",
    "            true_labels = batch[\"y\"].reshape(attention_mask.shape)\n",
    "\n",
    "            rows = zip(class_predict.tolist(), true_labels.tolist(), attention_mask.tolist())\n",
    "            for pred_row, true_row, mask_row in rows:\n",
    "                # Keep only the real (non-padding) positions, so both lists stay aligned.\n",
    "                pred_tokens = [p for p, m in zip(pred_row, mask_row) if m > 0]\n",
    "                true_tokens = [t for t, m in zip(true_row, mask_row) if m > 0]\n",
    "\n",
    "                total_class_predict_token.extend(pred_tokens)\n",
    "                total_true_class_token.extend(true_tokens)\n",
    "                total_acc += sum(p == t for p, t in zip(pred_tokens, true_tokens))\n",
    "                total_count += len(true_tokens)\n",
    "                # Tokens predicted as anything other than 0 ('non')\n",
    "                total_predict += sum(p != 0 for p in pred_tokens)\n",
    "\n",
    "                total_class_predict_sentence.append(find_the_most_frequent_element(pred_tokens))\n",
    "\n",
    "            total_true_class_sentence.extend(batch[\"sentence_level_label\"].tolist())\n",
    "\n",
    "    total_acc_sentence_level = sum(p == t for p, t in zip(total_class_predict_sentence, total_true_class_sentence))\n",
    "    total_size_sentence_level = len(total_true_class_sentence)\n",
    "\n",
    "    f1_token_level = f1_score(total_true_class_token, total_class_predict_token, labels=[0, 1, 2], average=\"micro\")\n",
    "    f1_sentence_level = f1_score(total_true_class_sentence, total_class_predict_sentence, labels=[0, 1, 2], average=\"micro\")\n",
    "\n",
    "    return (total_acc / total_count, total_predict, f1_token_level,\n",
    "            total_acc_sentence_level / total_size_sentence_level, f1_sentence_level)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c508aeff-d33e-4818-affc-b9123e432ea1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# You can uncomment the commented code in order to monitor the runs with tensorboard. \n",
    "# Writer will output to ./runs/ directory by default\n",
    "# writer = SummaryWriter(f'runs/{datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")}')\n",
    "# NOTE(review): the dev evaluation is commented out; the test split is scored every epoch.\n",
    "time_string = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "for epoch in tqdm(range(6)):\n",
    "    dict_norm = train(train_loader)\n",
    "\n",
    "    accu_train, total_predict_train, f1_token_level_train, acc_sentence_level_train, f1_sentence_level_train = evaluate(train_loader)\n",
    "\n",
    "    #accu_val, total_predict, f1_token_level_dev, acc_sentence_level_dev, f1_sentence_level_dev = evaluate(dev_loader)\n",
    "    acc_test, total_pred_test, f1_token_level_test, acc_sentence_level_test, f1_sentence_level_test = evaluate(test_loader)\n",
    "    scheduler.step()\n",
    "\n",
    "    #for key, values in dict_norm.items():\n",
    "    #    writer.add_histogram(tag = 'distribution of the gradient of the layer : ' + str(key), \n",
    "    #                         values = values.unsqueeze(0).T , \n",
    "    #                         global_step = epoch)\n",
    "\n",
    "    #writer.add_scalars('Accuracy score token level over batchs',  {'test':acc_test,'eval':accu_val, 'train':accu_train},epoch)\n",
    "    #writer.add_scalars('F1 score token level over batchs',  {'test':f1_token_level_test,'eval':f1_token_level_dev, 'train':f1_token_level_train},epoch)\n",
    "\n",
    "    #writer.add_scalars('Accuracy score sentence level over batchs',  {'test':acc_sentence_level_test,'eval':acc_sentence_level_dev, 'train':acc_sentence_level_train},epoch)\n",
    "    #writer.add_scalars('F1 score sentence level over batchs',  {'test':f1_sentence_level_test,'eval':f1_sentence_level_dev, 'train':f1_sentence_level_train},epoch)\n",
    "\n",
    "    print('accuracy on the train dataset' + str(accu_train) + \" the number of non 0 values predicted is : \" + str(int(total_predict_train)))\n",
    "    print(\"accuracy on the sentence level train is \" + str(acc_sentence_level_train))\n",
    "\n",
    "    print(\"\")\n",
    "    print(\"f1 score on token level train is : \" + str(f1_token_level_train))\n",
    "    #print(\"f1 score on span level evaluate is : \" + str(f1_span_level_train))\n",
    "    print(\"f1 score on sentence level train is : \" + str(f1_sentence_level_train))\n",
    "\n",
    "    #print(\"\")\n",
    "    #print('accuracy on the evaluate dataset' + str(accu_val) + \" the number of non 0 values predicted is : \" + str(int(total_predict)))\n",
    "    #print(\"accuracy on the sentence level evaluate\" + str(acc_sentence_level_dev))\n",
    "\n",
    "    #print(\"\")\n",
    "    #print(\"f1 score on token level evaluate is : \" + str(f1_token_level_dev))\n",
    "    #print(\"f1 score on span level evaluate is : \" + str(f1_span_level_dev))\n",
    "    #print(\"f1 score on sentence level evaluate is : \" + str(f1_sentence_level_dev))\n",
    "\n",
    "    print(\"\")\n",
    "    print('accuracy on the test dataset' + str(acc_test) + \" the number of non 0 values predicted is : \" + str(int(total_pred_test)))\n",
    "    print(\"accuracy on the sentence level test\" + str(acc_sentence_level_test))\n",
    "\n",
    "    print(\"\")\n",
    "    print(\"f1 score on token level test is : \" + str(f1_token_level_test))\n",
    "    #print(\"f1 score on span level test is : \" + str(f1_span_level_test))\n",
    "    print(\"f1 score on sentence level  test is : \" + str(f1_sentence_level_test))\n",
    "\n",
    "    # Append this epoch's scores to the run log; `with` closes the file even if a write fails.\n",
    "    with open(time_string+\".txt\", \"a\") as f:\n",
    "        f.write('accuracy on the train dataset' + str(accu_train) + \"for the epoch : \" + str(epoch)+\"\\n\")\n",
    "        f.write(\"f1 score on token level train is : \" + str(f1_token_level_train)+\"\\n\")\n",
    "        f.write(\"f1 score on token level test is : \" + str(f1_token_level_test)+\"\\n\")\n",
    "        f.write(\"f1 score on sentence level  test is : \" + str(f1_sentence_level_test) +\"\\n\")\n",
    "\n",
    "#writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86694621-6a4f-432f-9f35-3d2eb35b80f4",
   "metadata": {},
   "source": [
    "## Others: scratch cells (loss shape check, evaluation debugging, standalone training loop)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63ee3375-3ab8-477f-88de-64199fa42732",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = nn.CrossEntropyLoss()\n",
    "input = torch.randn(3, 5, requires_grad=True)\n",
    "target = torch.empty(3, dtype=torch.long).random_(5)\n",
    "output = loss(input, target)\n",
    "print(input.size())\n",
    "print(target.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bb5d385-3139-404d-87af-f7ea90227b22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug pass over the training loader: print intermediate sizes and compute token / sentence level metrics.\n",
    "dataloader = train_loader\n",
    "model.eval()\n",
    "\n",
    "total_acc, total_count, total_predict = 0, 0, 0\n",
    "total_acc_sentence_level, total_size_sentence_level = 0, 0\n",
    "\n",
    "# Accumulators for the F1 score at sentence level\n",
    "total_class_predict_sentence = []\n",
    "total_true_class_sentence = []\n",
    "\n",
    "# Accumulators for the F1 score at token level\n",
    "total_class_predict_token = []\n",
    "total_true_class_token = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for idx, ele in enumerate(dataloader):\n",
    "        ele = ele.to(device)\n",
    "\n",
    "        # class_predict is a nested list of predicted token labels, presumably one inner list per sentence\n",
    "        class_predict = model(ele)\n",
    "        print(len(class_predict))\n",
    "        print(len(class_predict[0]))\n",
    "\n",
    "        print(\"ele[y] : \" , str(ele[\"y\"].size()))\n",
    "\n",
    "        ## Flatten the result\n",
    "        class_predict_flat = [item for sublist in class_predict for item in sublist]\n",
    "        print(\"class_predict_flat\" , str(len(class_predict_flat)))\n",
    "        total_class_predict_token.extend(class_predict_flat)\n",
    "\n",
    "        ## Extract the true labels without the padding: keep positions where the attention mask is set\n",
    "        print(ele[\"attention_mask\"].size())\n",
    "        attention_mask = ele[\"attention_mask\"].flatten()\n",
    "        ele_y_clean = ele[\"y\"].flatten()[attention_mask > 0].tolist()\n",
    "        print(\"ele_y_clean\" , str(len(ele_y_clean)))\n",
    "        total_true_class_token.extend(ele_y_clean)\n",
    "\n",
    "        # Compute the sentence level label & accuracy\n",
    "        sentence_level_label = [find_the_most_frequent_element(list_ele) for list_ele in class_predict]\n",
    "        print(\"sentence_level_label\" , str(len(sentence_level_label)))\n",
    "        total_class_predict_sentence.extend(sentence_level_label)\n",
    "\n",
    "        true_sentence_level_label = ele[\"sentence_level_label\"].tolist()\n",
    "        print(\"true_sentence_level_label\" , str(len(true_sentence_level_label)))\n",
    "        total_true_class_sentence.extend(true_sentence_level_label)\n",
    "\n",
    "        total_acc_sentence_level += sum(pred == gold for pred, gold in zip(sentence_level_label, true_sentence_level_label))\n",
    "        total_size_sentence_level += len(true_sentence_level_label)\n",
    "        # Count the number of non 0 predictions to know if the model predicts only 0 or not\n",
    "        # (summing the labels themselves would count a class-2 prediction twice)\n",
    "        total_predict += sum(1 for pred in class_predict_flat if pred != 0)\n",
    "\n",
    "        # Token level accuracy; zip stops at the shorter list if the sizes printed above disagree\n",
    "        total_acc += sum(1 for pred, gold in zip(class_predict_flat, ele_y_clean) if pred == gold)\n",
    "        total_count += len(ele_y_clean)\n",
    "\n",
    "# sklearn's signature is f1_score(y_true, y_pred)\n",
    "f1_token_level = f1_score(total_true_class_token, total_class_predict_token, labels=[0, 1, 2], average=\"micro\")\n",
    "f1_sentence_level = f1_score(total_true_class_sentence, total_class_predict_sentence, labels=[0, 1, 2], average=\"micro\")\n",
    "\n",
    "print(\"token level accuracy :\", total_acc / max(total_count, 1), \"| non 0 predictions :\", total_predict)\n",
    "print(\"sentence level accuracy :\", total_acc_sentence_level / max(total_size_sentence_level, 1))\n",
    "print(\"f1 token level :\", f1_token_level, \"| f1 sentence level :\", f1_sentence_level)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f617aa71-e088-4b7c-8917-c70f0ea9df40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standalone training loop over train_loader (no evaluation).\n",
    "# model(batch, labels) returns the loss tensor, which is back-propagated directly.\n",
    "dataloader = train_loader\n",
    "NUM_EPOCHS = 60\n",
    "log_interval = 20  # report the current batch loss every `log_interval` batches\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    model.train()\n",
    "    start_time = time.time()\n",
    "\n",
    "    for idx, ele in enumerate(dataloader):\n",
    "        ele = ele.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        loss = model(ele, ele[\"y\"])\n",
    "        loss.backward()\n",
    "\n",
    "        # gradient clipping\n",
    "        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)\n",
    "        optimizer.step()\n",
    "\n",
    "        if idx % log_interval == 0 and idx > 0:\n",
    "            # elapsed = seconds since the previous report (or since the epoch started)\n",
    "            elapsed = time.time() - start_time\n",
    "            print('| epoch {:3d} | {:5d}/{:5d} batches | loss {:8.3f} | {:5.2f} s'.format(\n",
    "                epoch, idx, len(dataloader), loss.item(), elapsed))\n",
    "            start_time = time.time()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concept_env",
   "language": "python",
   "name": "concept_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
