#!/usr/bin/env python

import os
import spacy
import benepar
from benepar import BeneparComponent, NonConstituentException
import sys
import json
import os.path as osp
from tqdm import tqdm
from transformers import (BertConfig,
                                BertForTokenClassification, BertTokenizer, BertPreTrainedModel, BertModel)
import torch
from torch_geometric.data import Dataset,Data


def construct_the_edges_matrix_consistuency_4_padding(doc, list_index_to_replicate, max_indice_nt):
    """Build the edge list of a constituency graph truncated at depth 4.

    Node ids are assigned as follows:

    * Non-terminal nodes get consecutive ids starting at 0, in traversal
      order: one per sentence of ``doc.sents``, plus one for every child,
      grandchild and great-grandchild constituent (via ``_.children``) that
      itself has children.
    * A word ``w`` gets id ``w.i + max_indice_nt + k``, where ``k`` is the
      number of already-visited words whose ``w.i`` was found in
      ``list_index_to_replicate``. Such a word is linked twice, under two
      consecutive ids.

    Every edge is ``[child_id, parent_id]``. A constituent without children
    attaches its words to its parent's node. A great-grandchild node keeps no
    internal structure: all of its words are attached directly to it.

    NOTE(review): ``word.i`` is a spaCy token index, while the caller in this
    file builds ``list_index_to_replicate`` from BERT word-piece positions
    (which include [CLS]); confirm the two index spaces are meant to be
    compared. A word split into three or more pieces still receives only one
    extra id. Nothing here checks that non-terminal ids stay below
    ``max_indice_nt``.

    Args:
        doc: spaCy ``Doc`` parsed with benepar; only ``doc.sents``,
            ``span._.children`` and ``token.i`` are used.
        list_index_to_replicate: indices whose word receives a second id.
        max_indice_nt: offset added to every word index.

    Returns:
        List of ``[source, target]`` integer pairs.
    """
    edges = []
    node_count = 0   # id handed out to the next non-terminal node
    replicated = 0   # number of words duplicated so far

    def new_node(parent=None):
        # Allocate the next non-terminal id and, if given, link it to parent.
        nonlocal node_count
        node_id = node_count
        node_count += 1
        if parent is not None:
            edges.append([node_id, parent])
        return node_id

    def attach_words(span, parent):
        # Link every word of span to parent; listed words get two ids.
        nonlocal replicated
        for word in span:
            if word.i in list_index_to_replicate:
                edges.append([word.i + max_indice_nt + replicated, parent])
                replicated += 1
            edges.append([word.i + max_indice_nt + replicated, parent])

    def has_children(span):
        return any(True for _ in span._.children)

    for sentence in doc.sents:
        sentence_node = new_node()
        for child in sentence._.children:
            if not has_children(child):
                # Leaf directly under the sentence.
                attach_words(child, sentence_node)
                continue
            child_node = new_node(sentence_node)
            for grandchild in child._.children:
                if not has_children(grandchild):
                    attach_words(grandchild, child_node)
                    continue
                grandchild_node = new_node(child_node)
                for great in grandchild._.children:
                    if has_children(great):
                        # Depth limit reached: flatten this subtree.
                        attach_words(great, new_node(grandchild_node))
                    else:
                        attach_words(great, grandchild_node)
    return edges


def _pad_labels(labels, max_len):
    """Align word-piece labels with a padded BERT sequence of length max_len.

    Assumes ``labels[k]`` belongs to BERT position ``k + 1`` (one label per
    word piece -- confirm against the data). Position 0 ([CLS]) gets 0,
    followed by at most ``max_len - 2`` labels, then 0 for the [SEP] slot and
    any padding. The result always has exactly ``max_len`` entries.
    """
    body = labels[:max_len - 2]
    return [0] + body + [0] * (max_len - 1 - len(body))


def main():
    """Convert the AURC json into one PyTorch Geometric graph file per sentence.

    Each sentence is BERT-tokenized (padded/truncated to max_sequence_length
    word pieces). The word pieces are glued back into words and re-parsed
    with spaCy + benepar, and the resulting constituency edges are stored
    with the BERT inputs and labels in a ``Data`` object. That object is
    saved as ``<data_dir>/<split>/processed/data_<k>.pt``, where ``split`` is
    ``ad[target_domain]``.

    Examples are skipped and counted when their split is not
    Train/Dev/Test, when spaCy's word count differs from the glued BERT
    words, or when a non-terminal node id would collide with word node ids.
    """
    # Hide all GPUs: everything below runs on CPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    nlp = spacy.load('en_core_web_sm')
    benepar.download('benepar_en3')
    nlp.add_pipe("benepar", config={"model": "benepar_en3"})

    data_dir = '../../../data/aurc/bert/end_2_end_depth_4_IN'
    fname = '../../../data/aurc/AURC_DATA_dict.json'
    pretrained_weights = 'bert-large-uncased'
    max_sequence_length = 64
    # Key of each example that holds its split ('Train' / 'Dev' / 'Test').
    target_domain = 'In-Domain'
    # Word node ids are offset by this value; non-terminal ids must stay below it.
    max_indice_nt = 21

    with open(fname, 'r', encoding='utf-8') as my_file:
        AURC_DATA_dict = json.load(my_file)
    print(len(AURC_DATA_dict), [len(examples) for examples in AURC_DATA_dict.values()])

    topics = sorted(AURC_DATA_dict)
    print(len(topics), topics)

    label2id = {'non': 0, 'con': 1, 'pro': 2}
    tokenizer = BertTokenizer.from_pretrained(pretrained_weights)

    # Per-split running index used to name the saved files.
    split_counts = {'Train': 0, 'Dev': 0, 'Test': 0}
    for split in split_counts:
        os.makedirs(osp.join(data_dir, split, 'processed'), exist_ok=True)
    num_error = 0
    num_nt_overflow = 0

    for AD in tqdm(AURC_DATA_dict.values()):
        for ad in AD:
            split = ad[target_domain]
            if split not in split_counts:
                num_error += 1
                continue

            encoding = tokenizer.encode_plus(
                ad['sentence'],
                max_length=max_sequence_length,
                padding='max_length',
                truncation=True,
                add_special_tokens=True,
            )
            input_ids = encoding['input_ids']
            input_tokens = tokenizer.convert_ids_to_tokens(input_ids)

            # Glue word pieces back into words: drop [CLS] and everything from
            # the first [SEP] onward, and merge "##" continuations.
            sep_position = input_ids.index(tokenizer.sep_token_id)
            seq = " ".join(input_tokens[1:sep_position]).replace(' ##', '')

            doc = nlp(seq)
            # spaCy must see exactly the whitespace-separated words.
            if len(doc) != len(seq.split(' ')):
                num_error += 1
                continue

            labels = [label2id[label] for label in ad['tokenized_sentence_bert_labels'].split(' ')]
            y = torch.tensor(_pad_labels(labels, max_sequence_length))

            # BERT positions of "##" continuation pieces.
            list_index_to_replicate = [i for i, token in enumerate(input_tokens) if token.startswith('##')]
            edges = construct_the_edges_matrix_consistuency_4_padding(doc, list_index_to_replicate, max_indice_nt)

            # Edge targets are always non-terminal ids, and word ids start at
            # max_indice_nt: a target at or above it can alias a word node.
            if edges and max(target for _, target in edges) >= max_indice_nt:
                num_nt_overflow += 1
                continue

            data = Data(
                input_ids=torch.tensor(input_ids).unsqueeze(0),
                attention_mask=torch.tensor(encoding['attention_mask']).unsqueeze(0),
                token_type_ids=torch.tensor(encoding['token_type_ids']).unsqueeze(0),
                # (E, 2) pairs -> contiguous (2, E); also valid when E == 0.
                edge_index=torch.tensor(edges, dtype=torch.long).view(-1, 2).t().contiguous(),
                y=y,
                sentence_level_label=torch.tensor(label2id[ad["sentence_level_stance"]]),
            )
            data.num_nodes = max_sequence_length + max_indice_nt

            path = osp.join(data_dir, split, 'processed', 'data_' + str(split_counts[split]) + '.pt')
            torch.save(data, path)
            split_counts[split] += 1

    print('saved:', split_counts, '| skipped:', num_error,
          '| non-terminal overflow:', num_nt_overflow)

                
if __name__ == '__main__':
    # Run the dataset conversion only when executed as a script.
    main()