#!/usr/bin/env python

import os
import spacy
import benepar
from benepar import BeneparComponent, NonConstituentException
benepar.download('benepar_en3')
import sys
import json
import os.path as osp
from tqdm import tqdm
from transformers import (BertConfig,
                                BertForTokenClassification, BertTokenizer, BertPreTrainedModel, BertModel)
import torch
from torch_geometric.data import Dataset,Data


def construct_the_edges_matrix_consistuency_3_padding(doc, list_index_to_replicate, max_indice_nt):
    """Build the edge list of a constituency graph cut off at depth 3.

    Every sentence of ``doc`` becomes one non-terminal node. Below it, up to
    two further levels of non-terminal nodes are created for constituents
    that themselves have children. Words are attached to the deepest
    non-terminal created above them; everything below the third level is
    flattened onto the third-level node.

    Non-terminal nodes are numbered from 0 in visiting order, continuing
    across sentences. A word node gets the id
    ``word.i + max_indice_nt + shift`` where ``shift`` is the number of
    replicated words met so far.

    Args:
        doc: object exposing ``sents``; each sentence/constituent exposes
            ``._.children`` and iterates over words carrying an ``i``
            attribute (presumably a spaCy ``Doc`` parsed with benepar).
        list_index_to_replicate: word indices that must be emitted twice
            (one extra edge, then every later word id is shifted by one).
        max_indice_nt: offset added to word ids so they do not collide with
            the ids reserved for non-terminal nodes.

    Returns:
        A list of ``[child_id, parent_id]`` pairs in visiting order.
    """
    edges = []
    last_node = -1   # id of the most recently created non-terminal node
    shift = 0        # how many replicated word slots were emitted so far

    def attach_words(span, parent):
        """Link each word of ``span`` to ``parent``, duplicating flagged words."""
        nonlocal shift
        for token in span:
            if token.i in list_index_to_replicate:
                # Extra slot for the flagged word; later ids move up by one.
                edges.append([token.i + max_indice_nt + shift, parent])
                shift += 1
            edges.append([token.i + max_indice_nt + shift, parent])

    def new_node(parent):
        """Allocate the next non-terminal id and link it to ``parent``."""
        nonlocal last_node
        last_node += 1
        edges.append([last_node, parent])
        return last_node

    for sentence in doc.sents:
        # The sentence itself is a non-terminal node without a parent edge.
        last_node += 1
        sentence_node = last_node

        for depth1 in sentence._.children:
            depth2_items = list(depth1._.children)
            if not depth2_items:
                # Depth-1 leaf: normally a single word, hung on the sentence.
                attach_words(depth1, sentence_node)
                continue

            second_node = new_node(sentence_node)
            for depth2 in depth2_items:
                depth3_items = list(depth2._.children)
                if not depth3_items:
                    # Depth-2 leaf: normally a single word.
                    attach_words(depth2, second_node)
                    continue

                third_node = new_node(second_node)
                for depth3 in depth3_items:
                    # Anything deeper is flattened onto the depth-3 node.
                    attach_words(depth3, third_node)

    return edges


def main():
    """Convert the AURC JSON corpus into one saved graph file per example.

    For every example the sentence is tokenized with the BERT tokenizer
    (padded/truncated to ``max_sequence_length``), the WordPiece tokens are
    merged back into a plain sentence, and that sentence is parsed with the
    spaCy + benepar pipeline. The resulting edge list, the token-level labels
    and the sentence-level label are packed into a ``Data`` object and written
    to ``<data_dir>/<split>/processed/data_<n>.pt`` where ``<split>`` is the
    example's ``'In-Domain'`` value (``Train``, ``Dev`` or ``Test``) and
    ``<n>`` counts the examples saved so far for that split.

    Examples whose spaCy token count differs from the whitespace word count,
    or whose split value is not one of the three known ones, are skipped and
    counted; the totals are printed at the end.
    """
    # Hide all CUDA devices from libraries initialised after this point.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    nlp = spacy.load('en_core_web_sm')
    nlp.add_pipe("benepar", config={"model": "benepar_en3"})

    data_dir = '../../../data/aurc/bert/end_2_end_IN'
    pretrained_weights = 'bert-large-uncased'
    max_sequence_length = 64
    # Key of each example that holds its split name for this setting.
    target_domain = 'In-Domain'
    fname = '../../../data/aurc/AURC_DATA_dict.json'
    # Offset separating non-terminal node ids from word node ids.
    max_indice_nt = 21

    # Load the JSON file.
    with open(fname, 'r', encoding='utf-8') as my_file:
        AURC_DATA_dict = json.load(my_file)
    print(len(AURC_DATA_dict), [len(examples) for examples in AURC_DATA_dict.values()])

    # Check the number of topics.
    topics = sorted(set(AURC_DATA_dict.keys()))
    print(len(topics), topics)

    # Label-to-id dictionary.
    label2id = {'non': 0, 'con': 1, 'pro': 2}

    # Tokenizer from Hugging Face transformers.
    tokenizer = BertTokenizer.from_pretrained(pretrained_weights)

    # One output directory and one running file counter per split.
    split_dirs = {
        split: osp.join(data_dir, split, 'processed')
        for split in ('Train', 'Dev', 'Test')
    }
    for directory in split_dirs.values():
        # torch.save does not create missing parent directories.
        os.makedirs(directory, exist_ok=True)
    saved_per_split = dict.fromkeys(split_dirs, 0)
    num_error = 0

    for AD in tqdm(AURC_DATA_dict.values()):
        for ad in AD:
            # Explicit padding/truncation replaces the deprecated
            # `pad_to_max_length` flag and the implicit truncation fallback.
            sequence_dict = tokenizer.encode_plus(
                ad['sentence'],
                max_length=max_sequence_length,
                padding='max_length',
                truncation=True,
                add_special_tokens=True,
            )
            input_ids = sequence_dict['input_ids']

            # Labels: a 0 for the leading special token, the (possibly
            # truncated) token labels, then 0-padding up to the fixed length.
            input_labels = [label2id[label]
                            for label in ad['tokenized_sentence_bert_labels'].split(' ')]
            input_labels = [0] + input_labels[:max_sequence_length - 1]
            input_labels += [0] * (max_sequence_length - len(input_labels))

            input_tokens = tokenizer.convert_ids_to_tokens(input_ids)

            # Rebuild the sentence from the tokens between the first token
            # and the first [SEP], gluing '##' continuation pieces back on.
            seq_len = input_ids.index(tokenizer.sep_token_id)
            seq = " ".join(input_tokens[1:seq_len]).replace(' ##', '')

            doc = nlp(seq)
            if len(doc) != len(seq.split(' ')):
                # spaCy split the text differently from whitespace; skip it.
                num_error += 1
                continue

            # Labels of the graph refer to the BERT token positions.
            y = torch.tensor(input_labels)

            # NOTE(review): these are positions in the WordPiece sequence
            # (including the leading special token), while the helper compares
            # them with spaCy word indices - confirm this is intended.
            list_index_to_replicate = [i for i, token in enumerate(input_tokens)
                                       if token.startswith("##")]

            tab_edges_BERT = construct_the_edges_matrix_consistuency_3_padding(
                doc, list_index_to_replicate, max_indice_nt)
            # Turn the list of [child, parent] pairs into a 2 x num_edges tensor.
            edge_matrix = torch.tensor(tab_edges_BERT).transpose(0, 1)
            sentence_level_label = torch.tensor(label2id[ad["sentence_level_stance"]])

            data = Data(
                input_ids=torch.tensor(input_ids).unsqueeze(0),
                attention_mask=torch.tensor(sequence_dict['attention_mask']).unsqueeze(0),
                token_type_ids=torch.tensor(sequence_dict['token_type_ids']).unsqueeze(0),
                edge_index=edge_matrix,
                y=y,
                sentence_level_label=sentence_level_label,
            )
            # One node per token position plus the reserved non-terminal ids.
            data.num_nodes = max_sequence_length + max_indice_nt

            split = ad[target_domain]
            if split not in saved_per_split:
                num_error += 1
                continue

            path = osp.join(split_dirs[split],
                            "data_" + str(saved_per_split[split]) + ".pt")
            torch.save(data, path)
            saved_per_split[split] += 1

    print("saved:", saved_per_split, "skipped:", num_error)
        
                            
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()