#!/usr/bin/env python

### Helpers for constituency trees explored down to a depth of 2
from collections import Counter

def find_number_internal_node_2(doc):
    """Return the number of internal (non-token) nodes of the depth-2 graph.

    One node is counted per sentence of ``doc.sents``, plus one node for
    every direct child of a sentence (``sent._.children``) that itself has
    at least one child.

    Args:
        doc: Parsed document exposing ``sents``; each sentence and each of
            its children must expose the ``_.children`` extension attribute
            (presumably set by a constituency parser -- TODO confirm).

    Returns:
        int: sentence count plus the count of first-level constituents that
        have children.
    """
    sentences = list(doc.sents)
    # Every sentence contributes its own root node.
    internal_nodes = len(sentences)
    for sentence in sentences:
        # A first-level constituent becomes a node only if it is not a leaf.
        internal_nodes += sum(
            1
            for constituent in sentence._.children
            if any(True for _ in constituent._.children)
        )
    return internal_nodes


def construct_the_edges_matrix_consistuency_2(doc):
    """Build the ``[child, parent]`` edge list of the depth-2 graph.

    Internal nodes are numbered in visiting order starting at 0: a sentence
    receives the next free index, then each of its first-level constituents
    that has children receives the following one. Token nodes are numbered
    ``word.i + indice_nt``, where ``indice_nt`` is the value returned by
    ``find_number_internal_node_2(doc)``.

    Args:
        doc: Parsed document exposing ``sents`` and, on sentences and their
            children, the ``_.children`` extension attribute.

    Returns:
        tuple: ``(indice_nt, edges_index)`` where ``edges_index`` is a list
        of two-element lists ``[child_index, parent_index]``.
    """
    indice_nt = find_number_internal_node_2(doc)
    edges_index = []
    next_free_index = 0
    for sent in doc.sents:
        # Each sentence gets its own root node.
        sentence_node = next_free_index
        next_free_index += 1
        for constituent in sent._.children:
            sub_constituents = list(constituent._.children)
            if not sub_constituents:
                # Leaf at depth 1 (normally a single word): attach its
                # tokens directly to the sentence node.
                edges_index.extend(
                    [token.i + indice_nt, sentence_node] for token in constituent
                )
                continue
            # The constituent has depth-2 elements: give it a node of its
            # own, link it to the sentence, then hang all its tokens on it.
            phrase_node = next_free_index
            next_free_index += 1
            edges_index.append([phrase_node, sentence_node])
            for sub_constituent in sub_constituents:
                edges_index.extend(
                    [token.i + indice_nt, phrase_node] for token in sub_constituent
                )
    return indice_nt, edges_index


def find_the_related_nodes_2(doc):
    """Map every internal node to the nodes attached directly below it.

    Node numbering follows the same visiting order as the edge builder of
    this module: a sentence takes the next free index, then each of its
    first-level constituents that has children takes the following one.
    Token nodes are numbered ``word.i + indice_nt``.

    Args:
        doc: Parsed document exposing ``sents`` and, on sentences and their
            children, the ``_.children`` extension attribute.

    Returns:
        tuple: ``(dict_relation_node_input, indice_nt)`` where the dict maps
        an internal node index to the list of its child node indices
        (internal nodes and/or offset token indices), and ``indice_nt`` is
        the value returned by ``find_number_internal_node_2(doc)``.
    """
    indice_nt = find_number_internal_node_2(doc)
    dict_relation_node_input = {}
    current_node_index = -1
    for sent in doc.sents:
        # Each sentence gets its own root node.
        current_node_index += 1
        index_sentence = current_node_index
        for ele in sent._.children:
            children = list(ele._.children)
            if children:
                # The constituent has depth-2 elements: it becomes a node
                # that is linked to the sentence root.
                current_node_index += 1
                index_second_node = current_node_index
                dict_relation_node_input.setdefault(index_sentence, []).append(
                    index_second_node
                )
                # All tokens covered by the depth-2 elements hang on the
                # new node.
                related = dict_relation_node_input.setdefault(index_second_node, [])
                for ele_child in children:
                    related.extend(word.i + indice_nt for word in ele_child)
            else:
                # Leaf at depth 1: its tokens hang directly on the sentence.
                dict_relation_node_input.setdefault(index_sentence, []).extend(
                    word.i + indice_nt for word in ele
                )
    return dict_relation_node_input, indice_nt



def reshape_labels_nodes_2(doc, input_labels):
    """Prepend labels for the internal (NT) nodes to the token label list.

    Each internal node receives the most frequent label among its related
    nodes, as given by ``find_the_related_nodes_2(doc)``. Label ``0`` is
    left out of the vote whenever at least one other label is present.
    Nodes are processed from the highest index down to 0; in the numbering
    used by this module a sentence node has a lower index than its own
    sub-nodes, so the sub-nodes are labelled before the sentence is.

    Args:
        doc: Iterable document; its number of items must equal
            ``len(input_labels)``.
        input_labels: Sequence with one label per item of ``doc``. Any
            sized iterable is accepted; it is copied, never mutated.

    Returns:
        tuple: ``(new_labels, no_problem_size)``. ``new_labels`` is a list
        holding the internal-node labels followed by the input labels, or
        ``[]`` when the sizes do not match; ``no_problem_size`` reports
        whether the sizes matched.
    """
    no_problem_size = len(input_labels) == sum(1 for _ in doc)
    if not no_problem_size:
        return [], no_problem_size

    dict_relation_node_input, indice_nt = find_the_related_nodes_2(doc)
    # list(...) guarantees a real concatenation: adding a tuple to a list
    # raises, and adding a numpy array would broadcast instead of append.
    new_labels = [0] * indice_nt + list(input_labels)
    for j in reversed(range(indice_nt)):
        neighbours = dict_relation_node_input.get(j)
        if not neighbours:
            # Nothing to vote with: keep the default label 0 (max() on an
            # empty Counter would raise ValueError).
            continue
        votes = Counter(new_labels[k] for k in neighbours)
        if len(votes) > 1:
            # Label 0 only wins when it is the sole candidate.
            votes.pop(0, None)
        new_labels[j] = max(votes, key=votes.get)
    return new_labels, no_problem_size
    

