{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "yellow-honduras",
   "metadata": {},
   "source": [
    "# Import the libraries and load the NLP pipelines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bizarre-static",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "# Must run before torch is imported: an empty CUDA_VISIBLE_DEVICES hides every GPU (CPU-only run)\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
    "\n",
    "import torch\n",
    "print(torch.version.cuda)\n",
    "print(torch.__version__)\n",
    "\n",
    "import spacy\n",
    "from spacy import displacy\n",
    "import benepar\n",
    "from benepar import BeneparComponent, NonConstituentException\n",
    "from torch_geometric.data import Data\n",
    "from tqdm import tqdm\n",
    "from transformers import (\n",
    "    AdamW,\n",
    "    BertConfig,\n",
    "    get_linear_schedule_with_warmup,\n",
    "    BertForTokenClassification,\n",
    "    BertTokenizer,\n",
    "    BertPreTrainedModel,\n",
    "    BertModel,\n",
    ")\n",
    "\n",
    "# spaCy English pipeline with the benepar parser component appended to it\n",
    "nlp = spacy.load('en_core_web_sm')\n",
    "benepar.download('benepar_en3')\n",
    "nlp.add_pipe(\"benepar\", config={\"model\": \"benepar_en3\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "processed-milwaukee",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Personal Library\n",
    "import sys\n",
    "sys.path.insert(0, '../../utils/')\n",
    "from utils_AURC import InputFeatures_Constituent, parse_data\n",
    "\n",
    "from utils_GNN_CRF import (\n",
    "    add_dimensions_null_to_embedding\n",
    ")\n",
    "from collections import Counter\n",
    "\n",
    "from constituency_parsing_utils_depth_3 import (\n",
    "    reshape_labels_nodes_3,\n",
    "    find_the_related_nodes_3,\n",
    "    construct_the_edges_matrix_consistuency_3,\n",
    "    construct_the_edges_matrix_consistuency_3_padding,\n",
    "    find_number_internal_node_3\n",
    ")\n",
    "\n",
    "from constituency_parsing_utils_depth_2 import (\n",
    "    reshape_labels_nodes_2,\n",
    "    find_the_related_nodes_2,\n",
    "    construct_the_edges_matrix_consistuency_2,\n",
    "    find_number_internal_node_2\n",
    ")\n",
    "\n",
    "from constituency_parsing_utils_depth_4 import (\n",
    "    reshape_labels_nodes_4,\n",
    "    find_the_related_nodes_4,\n",
    "    construct_the_edges_matrix_consistuency_4,\n",
    "    find_number_internal_node_4\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "lasting-cleanup",
   "metadata": {},
   "source": [
    "## Import our Dataset AURC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "saving-friday",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8 [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]\n",
      "8 ['abortion', 'cloning', 'death penalty', 'gun control', 'marijuana legalization', 'minimum wage', 'nuclear energy', 'school uniforms']\n"
     ]
    }
   ],
   "source": [
    "# --- Paths and model configuration ---\n",
    "data_dir = '../../data/aurc/bert/large_depth_IN_connected' #'../data/aurc/bert/pro_con_3'\n",
    "input_file = 'AURC_DATA_dict.json'\n",
    "fname = '../../data/aurc/AURC_DATA_dict.json'\n",
    "\n",
    "pretrained_weights = 'bert-large-uncased' #'bert-base-uncased'\n",
    "max_sequence_length = 64\n",
    "num_labels = 3\n",
    "\n",
    "# the domains are shuffled\n",
    "target_domain = 'Cross-Domain'\n",
    "\n",
    "# --- Load the corpus: a JSON object keyed by topic ---\n",
    "with open(fname,'r') as corpus_file:\n",
    "    AURC_DATA_dict = json.load(corpus_file)\n",
    "\n",
    "# number of topics, then number of examples per topic\n",
    "print(len(AURC_DATA_dict), [len(examples) for examples in AURC_DATA_dict.values()])\n",
    "\n",
    "# alphabetical list of the topics\n",
    "topics = sorted(AURC_DATA_dict)\n",
    "print(len(topics), topics)\n",
    "\n",
    "# mapping from the textual label to its integer id\n",
    "label2id = {'non': 0, 'con': 1, 'pro': 2}\n",
    "\n",
    "# Tokenizer and encoder from Hugging Face transformers, both built from the same checkpoint\n",
    "tokenizer = BertTokenizer.from_pretrained(pretrained_weights)\n",
    "model = BertModel.from_pretrained(pretrained_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0dfba96-b750-4258-b520-d8cd9f309bb9",
   "metadata": {},
   "source": [
    "# Tests of the different parts of the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35930346-8c43-4966-b5c3-ea7b579d968f",
   "metadata": {},
   "source": [
    "## Count_similarity_neighbourg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b74b168b-9606-4c4b-a43b-1da150051309",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_similarity_neighbourg(new_labels,indice_nt):\n",
    "    total_len = len(new_labels)\n",
    "    num_similar_neighbourg = 0\n",
    "    total_link = total_len-indice_nt-1\n",
    "    for i in range(indice_nt,total_len-1):\n",
    "        if(new_labels[i] == new_labels[i+1]):\n",
    "            num_similar_neighbourg+=1\n",
    "        \n",
    "    return (num_similar_neighbourg,total_link)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "013890b2-bb0b-438b-a45e-219ed09753d5",
   "metadata": {},
   "source": [
    "## Count_similarity_neighbourg_same_parent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "77ec5934-9d41-497f-8159-443fd4ec4ecf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_similarity_neighbourg_same_parent(new_labels,indice_nt,tab_edges_BERT):\n",
    "    \n",
    "    dict_parent = dict()\n",
    "    for ele in tab_edges_BERT:\n",
    "        dict_parent[ele[0]] = ele[1]\n",
    "    \n",
    "    total_len = len(new_labels)\n",
    "    num_similar_neighbourg = 0\n",
    "    total_link = 0\n",
    "    \n",
    "    for i in range(indice_nt,total_len-1):\n",
    "        if(i in dict_parent.keys()):\n",
    "            if((i+1) in dict_parent.keys()):\n",
    "                if(dict_parent[i] == dict_parent[i+1]):\n",
    "                    if(new_labels[i] == new_labels[i+1]):\n",
    "                        num_similar_neighbourg+=1\n",
    "                    total_link +=1\n",
    "        \n",
    "    return (num_similar_neighbourg,total_link)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91cebdd3-8d44-434a-8823-1082d114392a",
   "metadata": {},
   "source": [
    "## Count_similarity_in_leaf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f1e03624-0a1b-4175-a1fa-2bcc733ab28a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_similarity_in_leaf(new_labels,indice_nt,tab_edges_BERT):\n",
    "    total_len = len(new_labels)\n",
    "    total_word = total_len-indice_nt\n",
    "    num_similar_labels = 0\n",
    "    ## Parcourir l'intégralité des feuilles\n",
    "    for i in range(indice_nt,total_len):\n",
    "        ## Trouver le noeud interne associé à cette feuille\n",
    "        for ele in tab_edges_BERT:\n",
    "            if(ele[0] == i):\n",
    "                ## Regarder si ils ont le même label\n",
    "                if(new_labels[ele[1]] == new_labels[i]):\n",
    "                    num_similar_labels +=1\n",
    "    return (num_similar_labels,total_word)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "013ff6fb-60e5-4714-8e0a-be6f7402ff46",
   "metadata": {},
   "source": [
    "## Count_similarity_intern"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "278cb842-fc47-4868-98a4-493a8ca7b8fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_similarity_intern(new_labels,indice_nt,tab_edges_BERT):\n",
    "    total_len = len(new_labels)\n",
    "    num_similar_labels = 0\n",
    "    num_relation = 0\n",
    "    ## On parcours tous les noeuds internes sauf le noeud racine\n",
    "    for j in range(1,indice_nt):\n",
    "        ## Trouver le noeud interne parent\n",
    "        for ele in tab_edges_BERT:\n",
    "            if(ele[0] == j):\n",
    "                ## Regarder si ils ont le même label\n",
    "                if(new_labels[ele[1]] == new_labels[j]):\n",
    "                    num_similar_labels +=1\n",
    "                ## Compter le nombre totale de relation\n",
    "                num_relation+=1\n",
    "    return (num_similar_labels,num_relation)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "422b6e21-079c-4e1a-ab6a-0ed6f9bc1794",
   "metadata": {},
   "source": [
    "## Stat computation for a maximum depth of 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4a472cf2-73aa-4557-8282-299a6ae34cd3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/8 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
      "/home/jupyter/sguilluy/ArgumentGrammar/concept_env/lib/python3.6/site-packages/transformers/tokenization_utils_base.py:2155: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
      "  FutureWarning,\n",
      "/home/jupyter/sguilluy/ArgumentGrammar/concept_env/lib/python3.6/site-packages/torch/distributions/distribution.py:46: UserWarning: <class 'torch_struct.distributions.TreeCRF'> does not define `arg_constraints`. Please set `arg_constraints = {}` or initialize the distribution with `validate_args=False` to turn off validation.\n",
      "  'with `validate_args=False` to turn off validation.')\n",
      "100%|██████████| 8/8 [06:56<00:00, 52.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the results for a miximum depth of 3 are : \n",
      "The neighbourg similarity :  0.9682025507684262\n",
      "The neighbourg similarity by grammatical class :  0.982525430301584\n",
      "The similarity between children and parents :  0.905047942394508\n",
      "The similarity between internal nodes :  0.9316578744591311\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Corpus-wide totals: each statistic is kept as a (matches, candidates) pair\n",
    "total_num_similar_neighbourg = 0\n",
    "total_total_link = 0\n",
    "\n",
    "total_num_similar_neighbourg_same_parent = 0\n",
    "total_total_link_same_parent = 0\n",
    "\n",
    "total_num_similar_labels_extern = 0\n",
    "total_total_word = 0\n",
    "\n",
    "total_num_similar_labels_intern = 0\n",
    "total_num_relation_intern = 0\n",
    "\n",
    "for topic, AD in tqdm(AURC_DATA_dict.items()):\n",
    "    for ad in AD:\n",
    "        ## Use the BERT Tokenizer from Hugging Face. Padding and truncation are requested\n",
    "        ## explicitly; this replaces the deprecated pad_to_max_length flag.\n",
    "        sequence_dict = tokenizer.encode_plus(ad['sentence'], max_length=max_sequence_length,\n",
    "                                              padding='max_length', truncation=True,\n",
    "                                              add_special_tokens=True)\n",
    "\n",
    "        ## Label management: a 0 for the leading special token, then the labels\n",
    "        ## cut or zero-padded so that the list has max_sequence_length entries\n",
    "        input_labels = [label2id[label] for label in ad['tokenized_sentence_bert_labels'].split(' ')]\n",
    "        input_labels =  [0] + input_labels[:max_sequence_length-1] + [0]*max(0,max_sequence_length-len(input_labels)-1)\n",
    "        sequence_dict['label_ids'] = input_labels\n",
    "\n",
    "        ## Input tokens\n",
    "        input_tokens = tokenizer.convert_ids_to_tokens(sequence_dict['input_ids'])\n",
    "\n",
    "        ## Re construct the true sentence: it ends at the first separator token\n",
    "        seq_len = sequence_dict['input_ids'].index(tokenizer.sep_token_id)\n",
    "        ## keep one label per word: the label of its first word piece\n",
    "        correct_label_ids = [l for t,l in zip(input_tokens[1:seq_len], input_labels[1:seq_len]) if not t.startswith('##')]\n",
    "\n",
    "        seq = \" \".join(input_tokens[1:seq_len]).replace(' ##','')\n",
    "        assert len(correct_label_ids)==len(seq.split(' '))\n",
    "        doc = nlp(seq)\n",
    "\n",
    "        ## Skip the sentence when the parsed doc and the word labels differ in length\n",
    "        no_problem_size_1 =  len(doc)==len(correct_label_ids)\n",
    "        if not no_problem_size_1:\n",
    "            continue\n",
    "\n",
    "        ## Construct the labels for the graph; the helper also returns a flag\n",
    "        ## and the sentence is skipped when that flag is falsy\n",
    "        new_labels,no_problem_size_2 = reshape_labels_nodes_3(doc, correct_label_ids)\n",
    "        if not no_problem_size_2:\n",
    "            continue\n",
    "\n",
    "        indice_nt, tab_edges_BERT = construct_the_edges_matrix_consistuency_3(doc)\n",
    "\n",
    "        (num_similar_neighbourg,total_link) = count_similarity_neighbourg(new_labels,indice_nt)\n",
    "        total_num_similar_neighbourg += num_similar_neighbourg\n",
    "        total_total_link += total_link\n",
    "\n",
    "        (num_similar_neighbourg_same_parent,total_link_same_parent) = count_similarity_neighbourg_same_parent(new_labels,indice_nt,tab_edges_BERT)\n",
    "        total_num_similar_neighbourg_same_parent += num_similar_neighbourg_same_parent\n",
    "        total_total_link_same_parent += total_link_same_parent\n",
    "\n",
    "        (num_similar_labels_extern,total_word) = count_similarity_in_leaf(new_labels,indice_nt,tab_edges_BERT)\n",
    "        total_num_similar_labels_extern += num_similar_labels_extern\n",
    "        total_total_word += total_word\n",
    "\n",
    "        (num_similar_labels_intern,num_relation_intern) = count_similarity_intern(new_labels,indice_nt,tab_edges_BERT)\n",
    "        total_num_similar_labels_intern += num_similar_labels_intern\n",
    "        total_num_relation_intern += num_relation_intern\n",
    "\n",
    "print(\"the results for a miximum depth of 3 are : \")\n",
    "\n",
    "for description, matches, candidates in [\n",
    "    (\"The neighbourg similarity : \", total_num_similar_neighbourg, total_total_link),\n",
    "    (\"The neighbourg similarity by grammatical class : \", total_num_similar_neighbourg_same_parent, total_total_link_same_parent),\n",
    "    (\"The similarity between children and parents : \", total_num_similar_labels_extern, total_total_word),\n",
    "    (\"The similarity between internal nodes : \", total_num_similar_labels_intern, total_num_relation_intern),\n",
    "]:\n",
    "    ## nan instead of a ZeroDivisionError when no candidate was counted\n",
    "    ratio = matches/candidates if candidates else float('nan')\n",
    "    print(description,str(ratio))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "036a952d-5bb7-4d19-985a-7c9538e4060f",
   "metadata": {},
   "source": [
    "## Stat computation for a maximum depth of 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2ca0d1ca-2787-431e-8a76-afb1a24f1590",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [06:49<00:00, 51.15s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the results for a miximum depth of 2 are : \n",
      "The neighbourg similarity :  0.9682025507684262\n",
      "The neighbourg similarity by grammatical class :  0.9816489361702128\n",
      "The similarity between children and parents :  0.8940976756333247\n",
      "The similarity between internal nodes :  0.8806667508523803\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Corpus-wide totals: each statistic is kept as a (matches, candidates) pair\n",
    "total_num_similar_neighbourg = 0\n",
    "total_total_link = 0\n",
    "\n",
    "total_num_similar_neighbourg_same_parent = 0\n",
    "total_total_link_same_parent = 0\n",
    "\n",
    "total_num_similar_labels_extern = 0\n",
    "total_total_word = 0\n",
    "\n",
    "total_num_similar_labels_intern = 0\n",
    "total_num_relation_intern = 0\n",
    "\n",
    "for topic, AD in tqdm(AURC_DATA_dict.items()):\n",
    "    for ad in AD:\n",
    "        ## Use the BERT Tokenizer from Hugging Face. Padding and truncation are requested\n",
    "        ## explicitly; this replaces the deprecated pad_to_max_length flag.\n",
    "        sequence_dict = tokenizer.encode_plus(ad['sentence'], max_length=max_sequence_length,\n",
    "                                              padding='max_length', truncation=True,\n",
    "                                              add_special_tokens=True)\n",
    "\n",
    "        ## Label management: a 0 for the leading special token, then the labels\n",
    "        ## cut or zero-padded so that the list has max_sequence_length entries\n",
    "        input_labels = [label2id[label] for label in ad['tokenized_sentence_bert_labels'].split(' ')]\n",
    "        input_labels =  [0] + input_labels[:max_sequence_length-1] + [0]*max(0,max_sequence_length-len(input_labels)-1)\n",
    "        sequence_dict['label_ids'] = input_labels\n",
    "\n",
    "        ## Input tokens\n",
    "        input_tokens = tokenizer.convert_ids_to_tokens(sequence_dict['input_ids'])\n",
    "\n",
    "        ## Re construct the true sentence: it ends at the first separator token\n",
    "        seq_len = sequence_dict['input_ids'].index(tokenizer.sep_token_id)\n",
    "        ## keep one label per word: the label of its first word piece\n",
    "        correct_label_ids = [l for t,l in zip(input_tokens[1:seq_len], input_labels[1:seq_len]) if not t.startswith('##')]\n",
    "\n",
    "        seq = \" \".join(input_tokens[1:seq_len]).replace(' ##','')\n",
    "        assert len(correct_label_ids)==len(seq.split(' '))\n",
    "        doc = nlp(seq)\n",
    "\n",
    "        ## Skip the sentence when the parsed doc and the word labels differ in length\n",
    "        no_problem_size_1 =  len(doc)==len(correct_label_ids)\n",
    "        if not no_problem_size_1:\n",
    "            continue\n",
    "\n",
    "        ## Construct the labels for the graph; the helper also returns a flag\n",
    "        ## and the sentence is skipped when that flag is falsy\n",
    "        new_labels,no_problem_size_2 = reshape_labels_nodes_2(doc, correct_label_ids)\n",
    "        if not no_problem_size_2:\n",
    "            continue\n",
    "\n",
    "        indice_nt, tab_edges_BERT = construct_the_edges_matrix_consistuency_2(doc)\n",
    "\n",
    "        (num_similar_neighbourg,total_link) = count_similarity_neighbourg(new_labels,indice_nt)\n",
    "        total_num_similar_neighbourg += num_similar_neighbourg\n",
    "        total_total_link += total_link\n",
    "\n",
    "        (num_similar_neighbourg_same_parent,total_link_same_parent) = count_similarity_neighbourg_same_parent(new_labels,indice_nt,tab_edges_BERT)\n",
    "        total_num_similar_neighbourg_same_parent += num_similar_neighbourg_same_parent\n",
    "        total_total_link_same_parent += total_link_same_parent\n",
    "\n",
    "        (num_similar_labels_extern,total_word) = count_similarity_in_leaf(new_labels,indice_nt,tab_edges_BERT)\n",
    "        total_num_similar_labels_extern += num_similar_labels_extern\n",
    "        total_total_word += total_word\n",
    "\n",
    "        (num_similar_labels_intern,num_relation_intern) = count_similarity_intern(new_labels,indice_nt,tab_edges_BERT)\n",
    "        total_num_similar_labels_intern += num_similar_labels_intern\n",
    "        total_num_relation_intern += num_relation_intern\n",
    "\n",
    "print(\"the results for a miximum depth of 2 are : \")\n",
    "\n",
    "for description, matches, candidates in [\n",
    "    (\"The neighbourg similarity : \", total_num_similar_neighbourg, total_total_link),\n",
    "    (\"The neighbourg similarity by grammatical class : \", total_num_similar_neighbourg_same_parent, total_total_link_same_parent),\n",
    "    (\"The similarity between children and parents : \", total_num_similar_labels_extern, total_total_word),\n",
    "    (\"The similarity between internal nodes : \", total_num_similar_labels_intern, total_num_relation_intern),\n",
    "]:\n",
    "    ## nan instead of a ZeroDivisionError when no candidate was counted\n",
    "    ratio = matches/candidates if candidates else float('nan')\n",
    "    print(description,str(ratio))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e21c1b6-2b9b-4e4b-a786-6cc71cc196da",
   "metadata": {},
   "source": [
    "## Stat computation for a maximum depth of 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5f100091-a478-4f94-af24-f0398b299487",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [07:09<00:00, 53.70s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the results for a miximum depth of 4 are : \n",
      "The neighbourg similarity :  0.9682025507684262\n",
      "The neighbourg similarity by grammatical class :  0.9848786179282182\n",
      "The similarity between children and parents :  0.8306346304518151\n",
      "The similarity between internal nodes :  0.9282240321148725\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Corpus-wide totals: each statistic is kept as a (matches, candidates) pair\n",
    "total_num_similar_neighbourg = 0\n",
    "total_total_link = 0\n",
    "\n",
    "total_num_similar_neighbourg_same_parent = 0\n",
    "total_total_link_same_parent = 0\n",
    "\n",
    "total_num_similar_labels_extern = 0\n",
    "total_total_word = 0\n",
    "\n",
    "total_num_similar_labels_intern = 0\n",
    "total_num_relation_intern = 0\n",
    "\n",
    "for topic, AD in tqdm(AURC_DATA_dict.items()):\n",
    "    for ad in AD:\n",
    "        ## Use the BERT Tokenizer from Hugging Face. Padding and truncation are requested\n",
    "        ## explicitly; this replaces the deprecated pad_to_max_length flag.\n",
    "        sequence_dict = tokenizer.encode_plus(ad['sentence'], max_length=max_sequence_length,\n",
    "                                              padding='max_length', truncation=True,\n",
    "                                              add_special_tokens=True)\n",
    "\n",
    "        ## Label management: a 0 for the leading special token, then the labels\n",
    "        ## cut or zero-padded so that the list has max_sequence_length entries\n",
    "        input_labels = [label2id[label] for label in ad['tokenized_sentence_bert_labels'].split(' ')]\n",
    "        input_labels =  [0] + input_labels[:max_sequence_length-1] + [0]*max(0,max_sequence_length-len(input_labels)-1)\n",
    "        sequence_dict['label_ids'] = input_labels\n",
    "\n",
    "        ## Input tokens\n",
    "        input_tokens = tokenizer.convert_ids_to_tokens(sequence_dict['input_ids'])\n",
    "\n",
    "        ## Re construct the true sentence: it ends at the first separator token\n",
    "        seq_len = sequence_dict['input_ids'].index(tokenizer.sep_token_id)\n",
    "        ## keep one label per word: the label of its first word piece\n",
    "        correct_label_ids = [l for t,l in zip(input_tokens[1:seq_len], input_labels[1:seq_len]) if not t.startswith('##')]\n",
    "\n",
    "        seq = \" \".join(input_tokens[1:seq_len]).replace(' ##','')\n",
    "        assert len(correct_label_ids)==len(seq.split(' '))\n",
    "        doc = nlp(seq)\n",
    "\n",
    "        ## Skip the sentence when the parsed doc and the word labels differ in length\n",
    "        no_problem_size_1 =  len(doc)==len(correct_label_ids)\n",
    "        if not no_problem_size_1:\n",
    "            continue\n",
    "\n",
    "        ## Construct the labels for the graph; the helper also returns a flag\n",
    "        ## and the sentence is skipped when that flag is falsy\n",
    "        new_labels,no_problem_size_2 = reshape_labels_nodes_4(doc, correct_label_ids)\n",
    "        if not no_problem_size_2:\n",
    "            continue\n",
    "\n",
    "        indice_nt, tab_edges_BERT = construct_the_edges_matrix_consistuency_4(doc)\n",
    "\n",
    "        (num_similar_neighbourg,total_link) = count_similarity_neighbourg(new_labels,indice_nt)\n",
    "        total_num_similar_neighbourg += num_similar_neighbourg\n",
    "        total_total_link += total_link\n",
    "\n",
    "        (num_similar_neighbourg_same_parent,total_link_same_parent) = count_similarity_neighbourg_same_parent(new_labels,indice_nt,tab_edges_BERT)\n",
    "        total_num_similar_neighbourg_same_parent += num_similar_neighbourg_same_parent\n",
    "        total_total_link_same_parent += total_link_same_parent\n",
    "\n",
    "        (num_similar_labels_extern,total_word) = count_similarity_in_leaf(new_labels,indice_nt,tab_edges_BERT)\n",
    "        total_num_similar_labels_extern += num_similar_labels_extern\n",
    "        total_total_word += total_word\n",
    "\n",
    "        (num_similar_labels_intern,num_relation_intern) = count_similarity_intern(new_labels,indice_nt,tab_edges_BERT)\n",
    "        total_num_similar_labels_intern += num_similar_labels_intern\n",
    "        total_num_relation_intern += num_relation_intern\n",
    "\n",
    "print(\"the results for a miximum depth of 4 are : \")\n",
    "\n",
    "for description, matches, candidates in [\n",
    "    (\"The neighbourg similarity : \", total_num_similar_neighbourg, total_total_link),\n",
    "    (\"The neighbourg similarity by grammatical class : \", total_num_similar_neighbourg_same_parent, total_total_link_same_parent),\n",
    "    (\"The similarity between children and parents : \", total_num_similar_labels_extern, total_total_word),\n",
    "    (\"The similarity between internal nodes : \", total_num_similar_labels_intern, total_num_relation_intern),\n",
    "]:\n",
    "    ## nan instead of a ZeroDivisionError when no candidate was counted\n",
    "    ratio = matches/candidates if candidates else float('nan')\n",
    "    print(description,str(ratio))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concept_env",
   "language": "python",
   "name": "concept_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
