#!/usr/bin/env python


### Helpers for constituency-tree graphs truncated at depth 3

def find_number_internal_node_3(doc):
    """Count the internal (non-token) nodes of the depth-3 constituency graph.

    The internal nodes are:
      * one node per sentence in ``doc.sents``;
      * one node per depth-1 constituent (``sent._.children``) that itself
        has children;
      * one node per depth-2 constituent (child of such a depth-1 node) that
        itself has children.

    Tokens are not counted here; callers offset token indices by the value
    returned (token ``i`` becomes graph node ``i + indice_nt``).

    Args:
        doc: a parsed document exposing ``doc.sents`` whose spans expose
            ``span._.children`` (presumably benepar's spaCy extension —
            confirm against the caller).

    Returns:
        int: total number of internal nodes (sentences + depth-1 + depth-2).
    """
    indice_nt = 0
    for sent in doc.sents:
        indice_nt += 1  # the sentence node itself
        for ele in sent._.children:
            # Materialise once: the children are needed both to test for
            # emptiness and to inspect the depth-2 constituents.
            ele_children = list(ele._.children)
            if not ele_children:
                continue  # depth-1 leaf: its tokens hang off the sentence node
            indice_nt += 1  # depth-1 internal node
            for ele_child in ele_children:
                # A depth-2 constituent becomes an internal node only if it has
                # children of its own; deeper levels are flattened.
                if any(True for _ in ele_child._.children):
                    indice_nt += 1
    return indice_nt


def construct_the_edges_matrix_consistuency_3_padding(doc, max_indice_nt):
    """Build the ``[child, parent]`` edge list of the depth-3 constituency graph.

    Internal nodes are numbered from 0 in visiting order: each sentence, then
    (depth-first) its depth-1 constituents that have children, then their
    depth-2 constituents that have children. Token ``i`` becomes node
    ``i + max_indice_nt``. Constituents below depth 3 are flattened: every
    token of a depth-3 constituent is linked directly to its depth-2 node.

    Args:
        doc: parsed document exposing ``doc.sents`` and ``span._.children``.
        max_indice_nt: offset added to every token index. Presumably at least
            the number of internal nodes of any document in the batch (hence
            "padding") — confirm against the caller.

    Returns:
        list of ``[child_index, parent_index]`` pairs.

    NOTE(review): a sentence whose ``_.children`` is empty produces no edge,
    so its tokens are left unconnected.
    """
    edges = []
    next_node = 0
    for sentence in doc.sents:
        sentence_node = next_node
        next_node += 1
        for depth1 in sentence._.children:
            depth1_children = list(depth1._.children)
            if not depth1_children:
                # Depth-1 leaf: link its token(s) straight to the sentence node.
                edges.extend([tok.i + max_indice_nt, sentence_node] for tok in depth1)
                continue

            depth1_node = next_node
            next_node += 1
            edges.append([depth1_node, sentence_node])

            for depth2 in depth1_children:
                depth2_children = list(depth2._.children)
                if not depth2_children:
                    # Depth-2 leaf: link its token(s) to the depth-1 node.
                    edges.extend([tok.i + max_indice_nt, depth1_node] for tok in depth2)
                    continue

                depth2_node = next_node
                next_node += 1
                edges.append([depth2_node, depth1_node])

                # Depth 3 is the last level kept: all tokens of each depth-3
                # constituent hang directly off the depth-2 node.
                for depth3 in depth2_children:
                    edges.extend([tok.i + max_indice_nt, depth2_node] for tok in depth3)
    return edges

def construct_the_edges_matrix_consistuency_3(doc):
    """Build the depth-3 constituency edge list using this document's own NT count.

    This is ``construct_the_edges_matrix_consistuency_3_padding`` with the
    token offset set to the number of internal nodes of ``doc``, so token
    ``i`` becomes graph node ``i + indice_nt``. Delegating (rather than
    duplicating the traversal) keeps the two variants from drifting apart.

    Args:
        doc: parsed document exposing ``doc.sents`` and ``span._.children``.

    Returns:
        tuple ``(indice_nt, edges_index)``: the number of internal nodes and
        the list of ``[child_index, parent_index]`` pairs.
    """
    indice_nt = find_number_internal_node_3(doc)
    edges_index = construct_the_edges_matrix_consistuency_3_padding(doc, indice_nt)
    return indice_nt, edges_index




def find_the_related_nodes_3(doc):
    """Map every internal node index to the indices of its direct children.

    The traversal and numbering mirror the depth-3 edge construction:
    internal nodes are numbered from 0 in visiting order (sentence, then its
    depth-1 constituents that have children, then their depth-2 constituents
    that have children), and token ``i`` is node ``i + indice_nt``. Tokens of
    depth-3 constituents are attached directly to their depth-2 node.

    Args:
        doc: parsed document exposing ``doc.sents`` and ``span._.children``.

    Returns:
        tuple ``(dict_relation_node_input, indice_nt)`` where the dict maps a
        parent node index to the list of its child node indices (internal
        nodes and/or offset token indices), and ``indice_nt`` is the number of
        internal nodes.
    """
    indice_nt = find_number_internal_node_3(doc)
    relations = {}
    node_counter = -1
    for sentence in doc.sents:
        node_counter += 1
        sentence_node = node_counter
        for depth1 in sentence._.children:
            depth1_children = list(depth1._.children)
            if not depth1_children:
                # Depth-1 leaf: its tokens are children of the sentence node.
                relations.setdefault(sentence_node, []).extend(
                    tok.i + indice_nt for tok in depth1
                )
                continue

            # Depth-1 internal node, attached to the sentence root.
            node_counter += 1
            depth1_node = node_counter
            relations.setdefault(sentence_node, []).append(depth1_node)

            for depth2 in depth1_children:
                depth2_children = list(depth2._.children)
                if not depth2_children:
                    # Depth-2 leaf: its tokens are children of the depth-1 node.
                    relations.setdefault(depth1_node, []).extend(
                        tok.i + indice_nt for tok in depth2
                    )
                    continue

                # Depth-2 internal node, attached to its depth-1 parent.
                node_counter += 1
                depth2_node = node_counter
                relations.setdefault(depth1_node, []).append(depth2_node)

                # Every token under each depth-3 constituent hangs off the
                # depth-2 node (deeper structure is flattened).
                for depth3 in depth2_children:
                    relations.setdefault(depth2_node, []).extend(
                        tok.i + indice_nt for tok in depth3
                    )
    return relations, indice_nt





def _vote_internal_label_3(child_labels):
    """Return the label an internal node inherits from its children's labels.

    Label ``0`` is treated as "no label": it is returned only when no child
    carries a non-zero label (or there are no children). Otherwise the most
    frequent non-zero label wins; ties go to the smallest label value, which
    reproduces the original rule for the classes {0, 1, 2} (a 1/2 tie gives 1)
    while no longer raising ``KeyError`` for other label sets.
    """
    from collections import Counter

    counts = Counter(child_labels)
    counts.pop(0, None)  # background label never wins against a real label
    if not counts:
        return 0
    best = max(counts.values())
    return min(label for label, count in counts.items() if count == best)


def reshape_labels_nodes_3(doc, input_labels):
    """Prepend labels for the internal (NT) nodes to the per-token label list.

    Each internal node of the depth-3 constituency graph receives the
    majority non-zero label of its direct children (see
    ``_vote_internal_label_3``); nodes with no recorded children keep 0.

    Args:
        doc: parsed document exposing ``doc.sents`` and ``span._.children``.
        input_labels: one label per token of ``doc`` (any sequence). Label 0
            is treated as "no label"; presumably the labels are {0, 1, 2} —
            confirm against the caller.

    Returns:
        tuple ``(new_labels, no_problem_size)``. When the label count matches
        the token count, ``new_labels`` is ``indice_nt`` internal-node labels
        followed by the token labels and ``no_problem_size`` is True;
        otherwise ``([], False)``.
    """
    num_tokens = sum(1 for _ in doc)
    no_problem_size = len(input_labels) == num_tokens
    if not no_problem_size:
        return [], no_problem_size

    dict_relation_node_input, indice_nt = find_the_related_nodes_3(doc)
    # list(...) so tuples / arrays are concatenated rather than broadcast-added.
    new_labels = [0] * indice_nt + list(input_labels)

    # Internal nodes are numbered parent-before-child, so walking the indices
    # from last to first labels every child before its parent is voted on.
    for j in reversed(range(indice_nt)):
        neighbours = dict_relation_node_input.get(j)
        if neighbours is not None:
            new_labels[j] = _vote_internal_label_3(new_labels[k] for k in neighbours)
    return new_labels, no_problem_size
    
    
    
    