import json, pickle, yaml
import copy

import numpy as np
import networkx as nx
import re

import matplotlib.pyplot as plt

from pyHEXgraph import HEXGraph

## UTILS

def readLeaves(file_name="map_clsloc.txt"):
    '''
    Read the leaf synset ids from a class-list file.

    For every non-blank line, the first whitespace-separated token is taken
    and its first character dropped (presumably ImageNet-style lines such as
    "n02119789 1 kit_fox" -> "02119789" — confirm for other files).
    Blank / whitespace-only lines (e.g. a trailing empty line) are skipped
    instead of raising IndexError.

    Inputs:
            - file_name : path of the file to read
    Outputs:
            - list of str ids, in file order
    '''
    with open(file_name, 'r') as leaves_file:
        return [line.split()[0][1:] for line in leaves_file if line.strip()]

def Synset2Node(D):
    '''
    Map every node of D to its position in D.nodes iteration order.

    Inputs:
            - D : a graph (anything exposing an iterable `nodes`)
    Outputs:
            - dict {node: index}; if a node appeared twice, its last
              index wins
    '''
    return {synset: index for index, synset in enumerate(D.nodes)}

# Hex RGB color palette for plotting. plot() uses myBlue for leaf nodes
# and myViolet for non-leaf nodes by default.
myGreen = "#B5D33D"
myPineGreen = "#0F3325"  # alternative shade that was tried: "#01796f"
myYellow = "#FED23F"
myBlue = "#6CA2EA"
myRed = "#EB7D5B"
myViolet = "#442288"

def plot(L, root, colors=None, special_colors=None):
    '''
    Draw the DAG L with a radial graphviz ("twopi") layout and show it.

    Inputs:
            - L : the graph to draw; each node's "is_leaf" attribute value
              (None when absent) is looked up in `colors`
            - root : the node placed at the center of the layout
            - colors : dict mapping an "is_leaf" value to a node color
              (default {True: myBlue, False: myViolet})
            - special_colors : dict mapping a node *position* (index in
              L.nodes order) to a color that overrides `colors`
    '''
    # None sentinels instead of mutable dict defaults shared across calls.
    if colors is None:
        colors = {True: myBlue, False: myViolet}
    if special_colors is None:
        special_colors = {}

    node_color = [colors[leaf] for _, leaf in L.nodes.data("is_leaf")]
    for position, color in special_colors.items():
        node_color[position] = color

    # Only the transitive reduction of the reversed graph is drawn, so edges
    # implied by longer paths are omitted. node_color is positional; this
    # assumes reverse()/transitive_reduction keep L's node order — confirm.
    tr = nx.transitive_reduction(L.reverse())
    pos = nx.nx_agraph.graphviz_layout(tr, prog="twopi", root=root)
    nx.draw(tr, pos, with_labels=False, node_size=20, arrowstyle='-',
            node_color=node_color, width=0.5)
    plt.show()

def getPaths(nb_keep=100, pruning=True, dir_path='./ImageNet/'):
    '''
    Build the output file locations for a harvested subset of leaves.

    Returns (stem, stem + "_config.json", stem + "_graph.pkl"), where stem
    is dir_path + "compilations/" + str(nb_keep), suffixed with "p" when
    `pruning` is truthy.
    '''
    stem = dir_path + "compilations/" + str(nb_keep) + ("p" if pruning else "")
    return stem, stem + "_config.json", stem + "_graph.pkl"


## CORE

def grow(hyp, D=None):
    '''
    Grow a Directed Acyclic Graph (DAG) from a list of hierarchical links.

    Inputs:
            - hyp : iterable of text lines; in each, the first integer is
              the son and the second the father. The first digit of each
              integer is a type code; only links where both codes are '1'
              are kept, and the remaining digits are used as node names.
              Lines with fewer than two integers are skipped.
            - D : an existing DiGraph to extend in place (a new one is
              created when None)
    Outputs:
            - D : the graph, with an edge father -> son for every kept link
    '''
    if D is None:
        D = nx.DiGraph()

    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    pattern = re.compile(r'(\d+)')

    for h in hyp:
        matches = pattern.findall(h)
        if len(matches) < 2:
            # no (son, father) pair on this line, e.g. a blank line
            continue
        son_id, father_id = matches[0], matches[1]
        if son_id[0] == '1' and father_id[0] == '1':
            son, father = son_id[1:], father_id[1:]
            # Explicitly add son before father: node insertion order is what
            # Synset2Node turns into indices. add_node on an existing node
            # does not move it.
            D.add_node(son)
            D.add_node(father)
            D.add_edge(father, son)

    return D


def pot(hyp, leaves, D=None):
    '''
    Build a Directed Acyclic Graph (DAG) from a list of hierarchical links and a list of leaves,
    only keeping the leaves and the nodes that lead to them.
    (much faster than growing the full DAG and then trimming it)

    Inputs:
            - hyp : iterable of text lines; in each, the first integer is
              the son and the second the father. The first digit of each
              integer is a type code; only links where both codes are '1'
              are used, and the remaining digits are node names. Lines with
              fewer than two integers are skipped.
            - leaves : the leaf nodes, added with is_leaf=True
            - D : an existing DiGraph to extend in place (a new one is
              created when None)
    Outputs:
            - D : the graph, with an edge father -> son for every used link
              whose son is a leaf or (transitively) a father of a leaf.
              Fathers get is_leaf=False unless already flagged as a leaf.
    '''
    if D is None:
        D = nx.DiGraph()

    for l in leaves:
        D.add_node(l, is_leaf=True)

    # Parse each line once: the fixed point below needs several passes, and
    # `hyp` may be a one-shot iterator. Line order is kept so nodes and
    # edges are inserted in the same order as before.
    pattern = re.compile(r'(\d+)')
    links = []
    for h in hyp:
        matches = pattern.findall(h)
        if len(matches) < 2:
            # no (son, father) pair on this line, e.g. a blank line
            continue
        son_id, father_id = matches[0], matches[1]
        if son_id[0] == '1' and father_id[0] == '1':
            links.append((son_id[1:], father_id[1:]))

    # Climb from the leaves: sweep the links until a full pass adds no edge.
    uncomplete = True
    while uncomplete:
        uncomplete = False
        for son, father in links:
            if son not in D:
                continue
            # Never downgrade a node that is itself one of the leaves.
            if not (father in D and D.nodes[father].get("is_leaf", False)):
                D.add_node(father, is_leaf=False)
            if not D.has_edge(father, son):
                uncomplete = True  # new edge: another pass is needed
                D.add_edge(father, son)

    return D


def prune(G):
    '''
    Delete a node if it has only one child,
    and connect the child to the parents of the deleted node.
    Do it iteratively until the graph is stable.

    Inputs:
            - G : a DiGraph (not modified; a deep copy is pruned)
    Outputs:
            - T : the pruned copy of G
    '''
    T = copy.deepcopy(G)
    stable = False
    while not stable:
        stable = True
        # Snapshot the node list: nodes are removed while iterating. Only
        # the node being visited is ever removed, so later ones still exist.
        for n in list(T.nodes):
            children = list(T.successors(n))
            if len(children) != 1:
                continue
            (child,) = children
            for p in list(T.predecessors(n)):
                T.add_edge(p, child)
            T.remove_node(n)
            stable = False

    return T


def trim(G, leaves):
    '''
    Delete a node if it has no path towards a leaf.

    Inputs:
            - G : a DiGraph (not modified; a deep copy is trimmed)
            - leaves : iterable of nodes that must all belong to G
    Outputs:
            - T : the copy of G restricted to the leaves and every node
              that has a directed path to at least one of them
    Raises:
            - Exception : if some leaf is not a node of G
    '''
    leaves = list(leaves)
    missing = [leaf for leaf in leaves if leaf not in G.nodes]
    if missing:
        raise Exception("All the leaves should belong to the graph; missing: "
                        + str(missing))

    T = copy.deepcopy(G)
    # A node has a path to a leaf iff it is that leaf or one of its
    # ancestors, so one reverse traversal per leaf replaces a has_path test
    # for every (node, leaf) pair.
    keep = set(leaves)
    for leaf in leaves:
        keep.update(nx.ancestors(T, leaf))
    T.remove_nodes_from([n for n in T.nodes if n not in keep])

    return T

def graft(Ls, L, n=None, mapping=None, s2n=None, s2n_path=None):
    '''
    Get mapping from the larger graph L to the subset graph Ls.

    Inputs:
            - Ls : the subset lattice
            - L : the larger lattice (must include Ls)
            - n : the node to inspect (None: every node of L not in Ls)
            - mapping : the mapping so far; updated in place and returned.
              A new empty dict is used when None.
            - s2n : a dict to transform synset names into node indexes
            - s2n_path : a path to a JSON file to load s2n from; when given,
              it replaces `s2n`

    Outputs :
            - mapping : the mapping of nodes in L that do not belong to Ls
             to all of their parents present in the smaller graph (their
             s2n indexes when s2n is set). A parent outside Ls contributes
             its own mapped ancestors, so entries may contain duplicates.
    '''
    if mapping is None:
        # Fresh dict per call: a shared `{}` default would carry results
        # from one call (and one subset graph) into the next.
        mapping = {}

    if s2n_path is not None:
        with open(s2n_path, 'r') as f:
            s2n = json.load(f)

    if n is None:
        for nl in L.nodes:
            if nl not in Ls.nodes and nl not in mapping:
                mapping = graft(Ls, L, nl, mapping, s2n)
        return mapping

    mapsto = []
    for p in L.predecessors(n):
        if p in Ls.nodes:
            mapsto.append(p if s2n is None else s2n[p])
        else:
            # Parent outside Ls: inherit its mapping, computing it if needed.
            if p not in mapping:
                mapping = graft(Ls, L, p, mapping, s2n)
            mapsto.extend(mapping[p])

    mapping[n] = mapsto
    return mapping

def harvest(nb_keep=100,
            save=True, dir_path='./',
            leaves=None, leaves_file_name="map_clsloc.txt",
            hyp_file_name='wn_hyp.pl',
            seed=42,
            pruning=True):

    '''
    Build the DAG corresponding to the subset of the leaves selected randomly.
    Also build the s2n and mapping dictionaries and save them.

    Inputs:
            - nb_keep : number of leaves kept; when >= 1000 the whole graph
              is kept and no mapping is computed (1000 presumably being the
              ImageNet class count — confirm)
            - save : whether to write the JSON config ("s2n", plus "mapping"
              when a subset is taken), the pickled graph and the HEXGraph
              under getPaths(nb_keep, pruning, dir_path)
            - dir_path : prefix for both the input files and the outputs
            - leaves : leaf synset ids (read from dir_path+leaves_file_name
              if None); a list or a numpy array
            - hyp_file_name : hierarchical-link file, under dir_path
            - seed : seed of NumPy's global RNG used for the shuffle
            - pruning : whether to prune the full DAG first
    Outputs:
            - Ls : the (possibly trimmed) DAG
    '''

    # `is None`: `leaves == None` is element-wise (and ambiguous) for arrays.
    if leaves is None:
        leaves = readLeaves(dir_path+leaves_file_name)

    # Get the hyp links
    with open(dir_path+hyp_file_name, "r") as hyp_file:
        hyplinks = hyp_file.readlines()

    L = pot(hyplinks, leaves)
    if pruning:
        L = prune(L)

    take_subset = nb_keep < 1000
    if take_subset:
        # Random subset; note this reseeds NumPy's global RNG.
        np.random.seed(seed)
        leaves = np.array(leaves)
        np.random.shuffle(leaves)
        Ls = trim(L, leaves[:nb_keep])
    else:
        Ls = L

    if save:
        config = {"s2n": Synset2Node(Ls)}
        if take_subset:
            # Explicit fresh dict so results never leak between calls.
            config["mapping"] = graft(Ls, L, mapping={}, s2n=config["s2n"])
        path, config_path, graph_path = getPaths(nb_keep, pruning, dir_path)
        with open(config_path, 'w') as f:
            json.dump(config, f)

        with open(graph_path, 'wb') as f:
            pickle.dump(Ls, f)

        # Generate HEX-graph and save it
        hexg = HEXGraph(Ls=Ls, L=L)
        hexg.save(path)

    return Ls

# INCREMENTAL LEARNING

def getPathsCL(split=0, pruning=True, dir_path='./ImageNet/'):
    '''
    Build the output file locations for incremental-learning split `split`.

    Returns (stem, stem + "_config.json", stem + "_graph.pkl"), where stem
    is dir_path + "compilations/CL/CL_p<split>" when pruning and
    dir_path + "CL_<split>" otherwise.
    NOTE(review): the unpruned stem is not under "compilations/CL/", unlike
    the pruned one and unlike getPaths; confirm this is intended.
    '''
    if not pruning:
        stem = dir_path + 'CL_' + str(split)
    else:
        stem = dir_path + "compilations/CL/" + 'CL_p' + str(split)
    return stem, stem + "_config.json", stem + "_graph.pkl"

def orchard(nb_splits=10, pruning=True, save=True, leaves=None,
                dir_path='./',
                leaves_file_name="map_clsloc.txt",
                hyp_file_name='wn_hyp.pl',
                seed=42):
    '''
    Build the sequence of DAGs used for incremental learning.

    The leaves are shuffled (NumPy's global RNG is seeded with `seed`) and
    cut into `nb_splits` groups; graph i is the full DAG trimmed to the
    leaves of groups 0..i. When `save` is set, each graph's JSON config
    ("s2n" and "mapping"), pickled graph and HEXGraph are written under
    getPathsCL(i, pruning, dir_path).

    Inputs:
            - nb_splits : number of leaf groups; when the number of leaves
              is not a multiple of it, the first groups get one extra leaf
            - pruning : whether to prune the full DAG before trimming
            - save : whether to write the per-split files
            - leaves : leaf synset ids (read from leaves_file_name if None);
              a list or a numpy array
            - dir_path : output prefix passed to getPathsCL
            - leaves_file_name, hyp_file_name : input files, opened as given
              (NOTE(review): unlike `harvest`, not prefixed with dir_path)
            - seed : seed for the shuffle
    Outputs:
            - graphs : list of the nb_splits trimmed DAGs
    '''

    # `is None`: `leaves == None` is element-wise (and ambiguous) for arrays.
    if leaves is None:
        leaves = readLeaves(leaves_file_name)

    # Split the leaves at random (array_split equals split on even division)
    np.random.seed(seed)
    leaves = np.array(leaves)
    np.random.shuffle(leaves)
    splits = np.array_split(leaves, nb_splits)

    # Get the hyp links
    with open(hyp_file_name, "r") as hyp_file:
        hyplinks = hyp_file.readlines()

    T = pot(hyplinks, leaves)
    if pruning:
        T = prune(T)

    graphs = []
    subleaves = []
    for (i, split) in enumerate(splits):
        subleaves.extend(split)  # cumulative: splits 0..i
        Ts = trim(T.copy(), subleaves)
        graphs.append(Ts)
        if save:
            config = {"s2n": Synset2Node(Ts)}
            # Fresh dict per split: graft fills `mapping` in place, and a
            # reused dict would return entries computed for earlier splits.
            config["mapping"] = graft(Ts, T, mapping={}, s2n=config["s2n"])
            path, config_path, graph_path = getPathsCL(i, pruning, dir_path)
            with open(config_path, 'w') as f:
                json.dump(config, f)

            with open(graph_path, 'wb') as f:
                pickle.dump(Ts, f)

            # Generate HEX-graph and save it
            hexg = HEXGraph(Ls=Ts, L=T)
            hexg.save(path)

    return graphs
    