{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9fb647d3",
   "metadata": {},
   "source": [
    "# Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "46e71ab0-96e7-4ce8-8477-c0703a1ad6d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"\"\n",
    "\n",
    "import spacy\n",
    "nlp = spacy.load('en_core_web_sm')\n",
    "\n",
    "import benepar\n",
    "from benepar import BeneparComponent, NonConstituentException\n",
    "benepar.download('benepar_en3')\n",
    "nlp.add_pipe(\"benepar\", config={\"model\": \"benepar_en3\"})\n",
    "\n",
    "import sys\n",
    "import json\n",
    "import os.path as osp\n",
    "from tqdm import tqdm\n",
    "\n",
    "from transformers import (BertConfig,\n",
    "                                BertForTokenClassification, BertTokenizer, BertPreTrainedModel, BertModel)\n",
    "\n",
    "import torch\n",
    "from torch_geometric.data import Dataset,Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "interstate-release",
   "metadata": {},
   "source": [
    "# Dataset creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7b7f81c4-6c6a-4302-81e3-4397948286c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_the_edges_matrix_consistuency_4_padding(doc,list_index_to_replicate, max_indice_nt):\n",
    "    edges_index = []\n",
    "    current_node_index = -1\n",
    "    current_index_replicated = 0\n",
    "    \n",
    "    for sent in doc.sents:\n",
    "        ## On regarde chaque phrases de notre document\n",
    "        current_node_index +=1\n",
    "        index_sentence = current_node_index\n",
    "        for ele in sent._.children:\n",
    "            ## On regarde chaque élément de la phrase au niveau de profondeur 1 et on regarde si eux même ont des enfants\n",
    "            num_children = sum(1 for x in ele._.children)\n",
    "            if num_children != 0 :\n",
    "                ## Il y a des éléments de profondeurs 2\n",
    "                current_node_index += 1\n",
    "                index_second_node = current_node_index\n",
    "                edges_index.append([index_second_node,index_sentence])\n",
    "                for ele_child in ele._.children:\n",
    "                    \n",
    "                    num_children_children = sum(1 for x in ele_child._.children)\n",
    "                    if num_children_children != 0 :\n",
    "                        ## Il y a des éléments de profondeurs 3\n",
    "                        current_node_index += 1\n",
    "                        index_third_node = current_node_index\n",
    "                        edges_index.append([index_third_node,index_second_node])\n",
    "                        \n",
    "                        for ele_child_child in ele_child._.children:\n",
    "                            num_children_children_children = sum(1 for x in ele_child_child._.children)\n",
    "                            if num_children_children_children != 0 :\n",
    "                                \n",
    "                                ## Il y a des éléments de profondeurs 4\n",
    "                                current_node_index += 1\n",
    "                                index_fourth_node = current_node_index\n",
    "                                edges_index.append([index_fourth_node,index_third_node])\n",
    "                                \n",
    "                                for word in ele_child_child:\n",
    "                                    if(word.i in list_index_to_replicate):\n",
    "                                        edges_index.append([word.i+max_indice_nt+current_index_replicated,index_fourth_node])\n",
    "                                        current_index_replicated +=1\n",
    "                                    edges_index.append([word.i+max_indice_nt+current_index_replicated,index_fourth_node])\n",
    "                            \n",
    "                            else:\n",
    "                                ### Il n'y a pas d'élément de profondeur 4, le noeud de profondeur 3 est une feuille\n",
    "                                for word in ele_child_child:\n",
    "                                    if(word.i in list_index_to_replicate):\n",
    "                                        edges_index.append([word.i+max_indice_nt+current_index_replicated,index_third_node])\n",
    "                                        current_index_replicated +=1\n",
    "                                    edges_index.append([word.i+max_indice_nt+current_index_replicated,index_third_node])     \n",
    "                    else:\n",
    "                        ## Il y n'y a pas des éléments de profondeurs 3, le noeud de profondeur 2 est une feuille\n",
    "                        for word in ele_child:\n",
    "                            if(word.i in list_index_to_replicate):\n",
    "                                edges_index.append([word.i+max_indice_nt+current_index_replicated,index_second_node])\n",
    "                                current_index_replicated +=1\n",
    "                            edges_index.append([word.i+max_indice_nt+current_index_replicated,index_second_node])\n",
    "            else:\n",
    "                ## Il n'y a pas d'éléments de profondeurs 2\n",
    "                for word in ele:\n",
    "                    ## Normalement il n'y a que 1 mot, on doit le relier au noeud de la phrase\n",
    "                    if(word.i in list_index_to_replicate):\n",
    "                        edges_index.append([word.i+max_indice_nt+current_index_replicated,index_sentence])\n",
    "                        current_index_replicated +=1\n",
    "                    edges_index.append([word.i+max_indice_nt+current_index_replicated,index_sentence])\n",
    "    return edges_index"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74a97bb5-a106-4dc0-9e1d-7c00fab64d60",
   "metadata": {},
   "source": [
    "### Import Dataset AURC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "91aa520c-d0c5-402a-bd51-c481f585ab24",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8 [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]\n",
      "8 ['abortion', 'cloning', 'death penalty', 'gun control', 'marijuana legalization', 'minimum wage', 'nuclear energy', 'school uniforms']\n"
     ]
    }
   ],
   "source": [
    "need_new_writting = True\n",
    "if(need_new_writting == True):\n",
    "    data_dir = '../../../data/aurc/bert/end_2_end_depth_4_IN'\n",
    "    input_file = 'AURC_DATA_dict.json'\n",
    "    num_labels = 3\n",
    "\n",
    "    pretrained_weights =  'bert-large-uncased'\n",
    "\n",
    "    max_sequence_length = 64\n",
    "\n",
    "    # the domains are shuffle\n",
    "    target_domain = 'In-Domain'\n",
    "\n",
    "    fname = '../../../data/aurc/AURC_DATA_dict.json'\n",
    "\n",
    "    # load the json file\n",
    "    with open(fname,'r') as my_file:\n",
    "        AURC_DATA_dict = json.load(my_file)\n",
    "    print(len(AURC_DATA_dict), [len(AURC_DATA_dict[topic]) for topic in AURC_DATA_dict.keys()])\n",
    "\n",
    "    # check the number of example per topic\n",
    "    topics = sorted(set(AURC_DATA_dict.keys()))\n",
    "    print(len(topics), topics)\n",
    "\n",
    "    # define the label to id dictionnary\n",
    "    label2id = {}\n",
    "    label2id['non'] = 0\n",
    "    label2id['con'] = 1\n",
    "    label2id['pro'] = 2\n",
    "\n",
    "    # Choose the tokenizer from Hugging Face transformers\n",
    "    tokenizer = BertTokenizer.from_pretrained(pretrained_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1af1a9a4-8cb7-4ab4-9dcc-acff2d6cb9b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/8 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
      "/home/jupyter/sguilluy/ArgumentGrammar/concept_env/lib/python3.6/site-packages/transformers/tokenization_utils_base.py:2155: FutureWarning:\n",
      "\n",
      "The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
      "\n",
      "/home/jupyter/sguilluy/ArgumentGrammar/concept_env/lib/python3.6/site-packages/torch/distributions/distribution.py:46: UserWarning:\n",
      "\n",
      "<class 'torch_struct.distributions.TreeCRF'> does not define `arg_constraints`. Please set `arg_constraints = {}` or initialize the distribution with `validate_args=False` to turn off validation.\n",
      "\n",
      "100%|██████████| 8/8 [07:29<00:00, 56.22s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "size of the train set : 4157\n",
      "size of the test set : 1189\n",
      "size of the dev set : 593\n",
      "the number of error is : 2061\n",
      "the maximum number of internal node is :  21\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "if(need_new_writting == True):\n",
    "    train_features = []\n",
    "    eval_features = []\n",
    "    test_features = []\n",
    "    \n",
    "    max_indice_nt = 21\n",
    "    num_error = 0\n",
    "    number_data_train = 0\n",
    "    number_data_test = 0\n",
    "    number_data_dev = 0\n",
    "    \n",
    "    for topic, AD in tqdm(AURC_DATA_dict.items()):\n",
    "        for ad in AD:\n",
    "            ## Use the BERT Tokenizer from Hugging Face\n",
    "            sequence_dict = tokenizer.encode_plus(ad['sentence'], max_length=max_sequence_length, pad_to_max_length=True, \n",
    "                                                  add_special_tokens=True)\n",
    "            ## Label management\n",
    "            input_labels = [label2id[label] for label in ad['tokenized_sentence_bert_labels'].split(' ')]\n",
    "            len_sentence = min(len(input_labels[:max_sequence_length-1]),len(input_labels))\n",
    "            input_labels =  [0] + input_labels[:max_sequence_length-1] + [0]*max(0,max_sequence_length-len(input_labels)-1)\n",
    "\n",
    "            ## Input tokens\n",
    "            input_tokens = tokenizer.convert_ids_to_tokens(sequence_dict['input_ids'])\n",
    "\n",
    "            ## Reconstruct the true sentence\n",
    "            input_ids = sequence_dict['input_ids']\n",
    "            seq_len = [i for (i,t) in enumerate(input_ids)  if t==102][0]\n",
    "            seq = \" \".join(input_tokens[1:seq_len]).replace(' ##','')\n",
    "            \n",
    "            ## Construct the Spacy representation of the sentence\n",
    "            doc = nlp(seq)\n",
    "            no_problem_size_1 =  len(doc) == len(seq.split(' '))\n",
    "\n",
    "            if(no_problem_size_1):\n",
    "\n",
    "                ## Construct the labels for the graph\n",
    "                ## The labels refer to the label of the bert representation \n",
    "                y = torch.tensor(input_labels)\n",
    "                \n",
    "                list_index_to_replicate = [i for (i,ele) in enumerate(input_tokens) if (ele[0:2] == \"##\")]\n",
    "                \n",
    "                tab_edges_BERT = construct_the_edges_matrix_consistuency_4_padding(doc, list_index_to_replicate, max_indice_nt)\n",
    "                edge_matrix = torch.tensor(tab_edges_BERT).transpose(0,1)\n",
    "\n",
    "                #text_of_sentence = [\"NT\"]*max_indice_nt + seq.split(' ')\n",
    "                #indice_nt_list = max_indice_nt*[0]+(max_sequence_length-max_indice_nt)*[1]\n",
    "                #indice_nt_tensor = torch.tensor(indice_nt_list)\n",
    "                    \n",
    "                sentence_level_label = torch.tensor(label2id[ad[\"sentence_level_stance\"]])\n",
    "\n",
    "                data = Data(\n",
    "                    input_ids = torch.tensor(sequence_dict['input_ids']).unsqueeze(0),\n",
    "                    attention_mask = torch.tensor(sequence_dict['attention_mask']).unsqueeze(0), \n",
    "                    token_type_ids = torch.tensor(sequence_dict['token_type_ids']).unsqueeze(0),\n",
    "                    edge_index = edge_matrix,  \n",
    "                    y = y, \n",
    "                    #indice_nt_tensor = indice_nt_tensor, \n",
    "                    #text_of_sentence = text_of_sentence,\n",
    "                    sentence_level_label = sentence_level_label\n",
    "                )\n",
    "                \n",
    "                data.num_nodes = 64 + max_indice_nt\n",
    "\n",
    "                if ad[target_domain] == 'Train':\n",
    "                    path = data_dir + \"/\" + ad[target_domain] + \"/processed/\" + \"data_\" + str(number_data_train)+\".pt\"\n",
    "                    torch.save(data, path)\n",
    "                    number_data_train += 1\n",
    "                elif ad[target_domain] == 'Dev':\n",
    "                    path = data_dir + \"/\" + ad[target_domain] + \"/processed/\" + \"data_\" + str(number_data_dev)+\".pt\"\n",
    "                    torch.save(data, path)\n",
    "                    number_data_dev += 1\n",
    "                elif ad[target_domain] == 'Test':\n",
    "                    path = data_dir + \"/\" + ad[target_domain] + \"/processed/\" + \"data_\" + str(number_data_test)+\".pt\"\n",
    "                    torch.save(data, path)\n",
    "                    number_data_test += 1\n",
    "                else:\n",
    "                    num_error +=1 \n",
    "            else:\n",
    "                num_error +=1 \n",
    "\n",
    "\n",
    "    print(\"size of the train set : \" + str(number_data_train))\n",
    "    print(\"size of the test set : \" + str(number_data_test ))\n",
    "    print(\"size of the dev set : \" + str(number_data_dev ))\n",
    "    print(\"the number of error is : \" + str(num_error ))\n",
    "    print(\"the maximum number of internal node is : \", str(max_indice_nt))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aad95892-cc4f-4950-8f33-f3cb8034dd33",
   "metadata": {},
   "source": [
    "# Test the Pytorch Geometric Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "tribal-parts",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = '../../../data/aurc/bert/end_2_end_depth_4_IN'\n",
    "\n",
    "class MyOwnDataset(Dataset):\n",
    "    def __init__(self, num_data, root, transform=None, pre_transform=None):\n",
    "        super(MyOwnDataset, self).__init__(root, transform, pre_transform)\n",
    "        self.num_data = num_data\n",
    "        \n",
    "    @property\n",
    "    def raw_file_names(self):\n",
    "        return []\n",
    "\n",
    "    @property\n",
    "    def processed_file_names(self):\n",
    "        return ['data_{}.pt'.format(idx) for idx in range(self.num_data)]\n",
    "\n",
    "    def len(self):\n",
    "        return len(self.processed_file_names)\n",
    "\n",
    "    def get(self, idx):\n",
    "        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))\n",
    "        return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "greatest-trading",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train size : 4157\n",
      "dev size : 593\n",
      "test size : 1189\n"
     ]
    }
   ],
   "source": [
    "#train_dataset = MyOwnDataset(num_data = 3960,root = data_dir+\"/Train\")\n",
    "#test_dataset = MyOwnDataset(num_data = 1959,root = data_dir+\"/Test\")\n",
    "#dev_dataset = MyOwnDataset(num_data = 790,root = data_dir+\"/Dev\")\n",
    "\n",
    "train_dataset = MyOwnDataset(num_data = 4157,root = data_dir+\"/Train\")\n",
    "test_dataset = MyOwnDataset(num_data = 1189,root = data_dir+\"/Test\")\n",
    "dev_dataset = MyOwnDataset(num_data = 593,root = data_dir+\"/Dev\")\n",
    "\n",
    "print(\"train size : \" + str(len(train_dataset)))\n",
    "print(\"dev size : \" + str(len(dev_dataset)))\n",
    "print(\"test size : \" + str(len(test_dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "growing-purchase",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ele representation : Data(attention_mask=[1, 64], edge_index=[2, 33], input_ids=[1, 64], sentence_level_label=2, token_type_ids=[1, 64], y=[64])\n",
      "edge_index :  tensor([[ 1,  2, 21, 22, 23, 24, 25,  3, 26, 27,  4, 28, 29, 30, 31, 32, 33, 34,\n",
      "         35, 36, 37,  5, 38, 39,  6, 40,  7, 41, 42, 43, 44, 45, 46],\n",
      "        [ 0,  1,  2,  2,  2,  2,  1,  1,  3,  3,  3,  4,  4,  4,  4,  4,  4,  4,\n",
      "          4,  0,  0,  0,  5,  5,  5,  6,  6,  7,  7,  7,  7,  7,  0]]) the size of ele index is :  torch.Size([2, 33])\n",
      "attention_mask :  tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) the size of attention_mask is :  torch.Size([1, 64])\n",
      "y :  tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) the size of y is :  torch.Size([64])\n"
     ]
    }
   ],
   "source": [
    "ele = train_dataset[1]\n",
    "print(\"ele representation : \" + str(ele)  )\n",
    "print(\"edge_index : \",  str(ele[\"edge_index\"]) , \"the size of ele index is : \",str(ele[\"edge_index\"].size()) )\n",
    "print(\"attention_mask : \",  str(ele[\"attention_mask\"]), \"the size of attention_mask is : \",str(ele[\"attention_mask\"].size()) )\n",
    "print(\"y : \",  str(ele[\"y\"]), \"the size of y is : \",str(ele[\"y\"].size()) )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concept_env",
   "language": "python",
   "name": "concept_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
