# -----------------------------------------------------------------------------
# This file contains the functions to compute the node embeddings and to compute
# the WL-subtree-wasserstein distance matrix
# -----------------------------------------------------------------------------
import ot
import copy
import numpy as np
import random
import copy
import json
import sys

from collections import defaultdict
from sklearn.base import TransformerMixin

def compute_wl_embeddings_discrete(graphs, h):
    """Build one node-embedding matrix per graph from its WL subtrees.

    Each graph is run through ``WeisfeilerLehman.fit_transform`` with ``h``
    iterations and then through ``subtree_compression._decomposition``, which
    yields, for every WL subtree of the graph, a list of integer subtree ids.
    Row ``r`` of a graph's matrix counts how often each id occurs in the
    ``r``-th list.

    Args:
        graphs: sized iterable of graph objects accepted by
            ``WeisfeilerLehman.fit_transform``; each must also provide
            ``_free()``, which is called once the graph has been processed.
        h: number of WL iterations, also passed to ``subtree_compression``.

    Returns:
        A list with one ``numpy`` array per graph, each of shape
        ``(number of WL subtrees, sc.cnt + 1)``.
    """
    transformer = WeisfeilerLehman()
    compressor = subtree_compression(h)

    n_graphs = len(graphs)
    per_graph_ids = []
    for done, graph in enumerate(graphs, start=1):
        wl_graph = transformer.fit_transform(graph, h)
        per_graph_ids.append(compressor._decomposition(wl_graph))
        # presumably releases the per-graph WL structures -- confirm in the graph class
        graph._free()
        percent = '{:.1%}'.format(done / n_graphs)
        print('\rBuilding WL-subtrees... %d/%d [%s]' % (done, n_graphs, percent), end='')

    # Ids handed out by the compressor start at 1, so column 0 stays unused
    # and the matrices need cnt + 1 columns.
    n_types = compressor.cnt + 1
    graph_embeddings = []
    for id_lists in per_graph_ids:
        features = np.zeros((len(id_lists), n_types))
        for row, subtree_ids in enumerate(id_lists):
            for subtree_id in subtree_ids:
                features[row, subtree_id] += 1
        graph_embeddings.append(features)

    print(f'\nThe number of types of wl-subtree is {n_types}')
    return graph_embeddings

def tree():
    """Return an empty autovivifying tree.

    The result is a ``defaultdict`` whose missing keys are filled with
    further trees of the same kind, so arbitrarily deep paths can be
    created just by indexing.
    """
    nested = defaultdict(tree)
    return nested

def add(tree, path):
    """Make sure the branch described by ``path`` exists in ``tree``.

    Descends from ``tree`` one key at a time; with an autovivifying mapping
    every missing level is created by the lookup itself. Nothing is returned.
    """
    cursor = tree
    for key in path:
        cursor = cursor[key]
    
class WeisfeilerLehman(TransformerMixin):
    """
    Class that implements the Weisfeiler-Lehman subtree transform

    The graph objects handled here are expected to expose:
      * ``edge``        -- pair ``(sources, targets)`` of node-index sequences,
      * ``attr``        -- 2-D array; the argmax of each row is the node's
                           initial discrete label,
      * ``adj_node``    -- mapping filled with ``node -> [neighbour, ...]``,
      * ``adj_label``   -- mapping filled with
                           ``node -> [(label string, neighbour), ...]``,
      * ``wl_subtrees`` -- list receiving one nested-dict tree per node,
      * ``paths``       -- list receiving, per node, the current frontier of
                           ``(path string, end node)`` pairs.

    Path strings are node keys of the form ``"<label>(<node>)"`` joined
    with ``'>'``.
    """
    def __init__(self):
        self._label_dict = {}
        self._last_new_label = 0
        # original label (as string) -> compact integer label, shared by all
        # graphs processed with this instance
        self._preprocess_relabel_dict = {}
        self._results = defaultdict(dict)
        self._label_dicts = {}

    def _reset_label_generation(self):
        """Restart the compact-label counter at zero."""
        self._last_new_label = 0

    def _get_next_label(self):
        """Return the next unused compact label (first one is 1)."""
        self._last_new_label += 1
        return self._last_new_label

    @staticmethod
    def _node_key(label, node_num):
        """Format the tree key used for a node: ``"<label>(<node>)"``."""
        return f'{label}({node_num})'

    def _relabel_graph(self, graph):
        """Initialise adjacency, labels and depth-0 subtrees of ``graph``.

        Fills ``graph.adj_node`` / ``graph.adj_label`` from ``graph.edge``,
        replaces ``graph.attr`` by the list of compact integer labels, and
        appends one single-node tree and one single-entry path list per key
        of ``graph.adj_node``. The same (mutated) graph is returned.
        """
        edge_list1, edge_list2 = copy.deepcopy(graph.edge)
        for src, dst in zip(edge_list1, edge_list2):
            graph.adj_node.setdefault(src, []).append(dst)

        labels = list(map(str, np.argmax(graph.attr, axis=1)))

        new_labels = []
        for label in labels:
            if label not in self._preprocess_relabel_dict:
                self._preprocess_relabel_dict[label] = self._get_next_label()
            new_labels.append(self._preprocess_relabel_dict[label])
        graph.attr = new_labels

        for node_num, adj_node_list in graph.adj_node.items():
            graph.adj_label[node_num] = [
                (self._node_key(new_labels[adj_node], adj_node), adj_node)
                for adj_node in adj_node_list
            ]

        # Look the root label up by node index, exactly as is done for the
        # neighbours above. Pairing the adjacency keys with ``new_labels`` by
        # position is only correct when the keys happen to be 0..n-1 in
        # order, which breaks for isolated nodes or unsorted edge lists.
        for node_num in graph.adj_node:
            label = self._node_key(new_labels[node_num], node_num)
            wl_subtree = tree()
            add(wl_subtree, [label])
            graph.wl_subtrees.append(wl_subtree)
            graph.paths.append([(label, node_num)])

        return graph

    def fit_transform(self, graph, num_iterations=4):
        """Grow every node's WL subtree by ``num_iterations`` levels.

        Args:
            graph: graph object as described in the class docstring; it is
                modified in place.
            num_iterations: number of levels to add below each root.

        Returns:
            The same graph, whose ``wl_subtrees`` hold the unfolded trees and
            whose ``paths`` hold the deepest paths of each tree.
        """
        graph = self._relabel_graph(graph)

        for _ in np.arange(1, num_iterations + 1, 1):
            for i, (wl_subtree, path_list) in enumerate(zip(graph.wl_subtrees, graph.paths)):
                # Extend every frontier path by each neighbour of its end node.
                new_path = [
                    (path + '>' + next_node_label, next_node_num)
                    for (path, node_num) in path_list
                    for (next_node_label, next_node_num) in graph.adj_label[node_num]
                ]
                graph.paths[i] = new_path

                # Only this node's tree gains new branches; re-inserting the
                # paths of all other nodes here would be redundant work that
                # is quadratic in the number of nodes.
                for (path, _end_node) in new_path:
                    add(wl_subtree, path.split('>'))

        return graph

class subtree_compression(TransformerMixin):
    """
    Class that maps the nodes of WL subtrees to compact integer ids

    Every node of a WL subtree is hashed (modulo ``mod``) and the hash is
    translated into a running id through ``count_dict``; ``cnt`` is the
    number of distinct ids handed out so far.
    """
    mod = 1000000007

    def __init__(self, h=4):
        self.count_dict = {}
        self.cnt = 0
        # one random multiplier per tree depth 0 .. h-1
        self.random_num = [random.randint(0, self.mod) for _ in range(h)]

    def _add_count(self, res):
        """Return the id registered for ``res``, allocating the next one if new."""
        if res in self.count_dict:
            return self.count_dict[res]
        self.cnt += 1
        self.count_dict[res] = self.cnt
        return self.cnt

    def _search_tree(self, current_tree, tree_labels, depth):
        """Visit the subtree below the first key of ``current_tree`` in post-order.

        The id of every visited node is appended to ``tree_labels`` (children
        before their parent). Keys are expected to look like
        ``"<label>(<node>)"``; the integer before ``'('`` is the node label.

        Returns:
            ``(tree_labels, label)`` where ``label`` is the raw integer label
            of the visited root.
        """
        root_key, children = list(current_tree.items())[0]
        root_label = int(root_key[:root_key.find('(')])

        if not children:
            # A leaf is identified by its raw label.
            tree_labels.append(self._add_count(root_label))
            return tree_labels, root_label

        child_labels = []
        for child_key, child_subtree in list(children.items()):
            # Wrap the child so that it is again "a tree with a single root key".
            wrapper = tree()
            wrapper[child_key] = child_subtree
            child_labels.append(self._search_tree(wrapper, tree_labels, depth + 1)[1])

        # NOTE(review): the recursion hands back each child's raw label, not
        # its compressed id, so this hash only reflects the node's own label
        # and the labels of its direct children -- confirm this is intended.
        weighted = sum(self.random_num[depth] * child_label for child_label in child_labels)
        tree_labels.append(self._add_count((weighted + root_label) % self.mod))
        return tree_labels, root_label

    def _decomposition(self, graph):
        """Return, for each tree in ``graph.wl_subtrees``, the list of its node ids."""
        return [
            self._search_tree(wl_subtree, [], 0)[0]
            for wl_subtree in graph.wl_subtrees
        ]

def wl_subtree_wasserstein_distance(graph_embedding1, graph_embedding2):
    """Return the optimal-transport cost between two node-embedding sets.

    The rows of each embedding are treated as equally weighted points. The
    ground cost is the ``'cityblock'`` distance computed by ``ot.dist``, the
    transport plan comes from ``ot.emd``, and the result is the sum of the
    element-wise product of cost matrix and plan.
    """
    cost = ot.dist(graph_embedding1, graph_embedding2, metric='cityblock')

    n_rows1 = len(graph_embedding1)
    n_rows2 = len(graph_embedding2)
    weights1 = np.ones(n_rows1) / n_rows1
    weights2 = np.ones(n_rows2) / n_rows2

    plan = ot.emd(weights1, weights2, cost)
    return np.sum(cost * plan)

def compute_wl_subtree_wasserstein_distance(graph_embeddings, h):
    """Return the symmetric matrix of pairwise WL-subtree Wasserstein distances.

    Args:
        graph_embeddings: sequence of per-graph node-embedding matrices.
        h: not used by this function; kept in the signature for callers.

    Returns:
        ``numpy`` array of shape ``(n, n)`` with a zero diagonal, where entry
        ``[i, j]`` is the distance between graphs ``i`` and ``j``.
    """
    n = len(graph_embeddings)
    M = np.zeros((n, n))
    n_pairs = n * (n - 1) / 2
    done = 0
    for i in range(n):
        first = graph_embeddings[i]
        # Only the strict upper triangle is computed; it is mirrored below.
        for j in range(i + 1, n):
            M[i][j] = wl_subtree_wasserstein_distance(first, graph_embeddings[j])
            done += 1
            percent = '{:.1%}'.format(done / n_pairs)
            print('\rCalclulating WL-subtree-Wasserstein distances... %d/%d [%s]' % (done, n_pairs, percent), end='')

    return M + M.T
