import math
import heapq
import copy
import numpy as np
import numba as nb
import networkx as nx
import matplotlib.pyplot as plt
import torch

def get_id():
    """Yield the integers 0, 1, 2, ... indefinitely, one per ``next()`` call."""
    current = -1
    while True:
        current += 1
        yield current


def get_graph_parse(adj_matrix):
    """Summarise a square adjacency matrix.

    Args:
        adj_matrix: 2-D array indexable as ``adj_matrix[i, j]`` and exposing
            ``shape``; only ``shape[0]`` is used for both dimensions.

    Returns:
        A 4-tuple ``(num_nodes, total_volume, node_volumes, neighbours)``:
        the number of rows, the sum of all non-zero entries, a list holding
        the sum of the non-zero entries of each row, and a dict mapping each
        row index to the set of column indices with a non-zero entry.
    """
    num_nodes = adj_matrix.shape[0]
    neighbours = {}
    total_volume = 0
    node_volumes = []
    for row in range(num_nodes):
        row_volume = 0
        linked = set()
        for col in range(num_nodes):
            weight = adj_matrix[row, col]
            if weight == 0:
                continue
            row_volume += weight
            total_volume += weight
            linked.add(col)
        neighbours[row] = linked
        node_volumes.append(row_volume)
    return num_nodes, total_volume, node_volumes, neighbours


@nb.jit(nopython=True)
def cut_volume(adj_matrix, p1, p2):
    """Sum ``adj_matrix[a, b]`` over every ``a`` in ``p1`` and ``b`` in ``p2``.

    Args:
        adj_matrix: 2-D matrix indexed as ``adj_matrix[a, b]``.
        p1: sequence of row indices.
        p2: sequence of column indices.

    Returns:
        The accumulated weight of all non-zero entries between the two
        index sets (0 when there are none).
    """
    total = 0
    for src in p1:
        for dst in p2:
            weight = adj_matrix[src, dst]
            if weight != 0:
                total += weight
    return total


def LayerFirst(node_dict, start_id):
    """Yield node IDs in breadth-first (layer by layer) order.

    Args:
        node_dict: mapping from node ID to an object with a ``children``
            attribute (an iterable of child IDs, or a falsy value for none).
        start_id: ID of the node the traversal starts from.

    Yields:
        ``start_id`` first, then the IDs of its children, then their
        children, and so on.
    """
    # Local import keeps the block self-contained; deque gives O(1) pops from
    # the front, whereas list.pop(0) shifts every remaining element.
    from collections import deque

    queue = deque([start_id])
    while queue:
        node_id = queue.popleft()
        yield node_id
        # Children are read after the yield, exactly when the consumer asks
        # for the next ID, so edits made in between are still honoured.
        children = node_dict[node_id].children
        if children:
            queue.extend(children)


def merge(new_ID, id1, id2, cut_v, node_dict):
    """Create a parent node joining nodes ``id1`` and ``id2``.

    The new node's partition is the concatenation of both partitions, its
    volume is the sum of both volumes, its ``g`` is ``g1 + g2 - 2 * cut_v``
    and its height is one more than the taller child.  Both children get
    ``new_ID`` as parent and the new node is stored in ``node_dict``.

    Args:
        new_ID: ID under which the new node is registered.
        id1: ID of the first node to merge.
        id2: ID of the second node to merge.
        cut_v: cut weight between the two nodes.
        node_dict: mapping of node ID to node; modified in place.
    """
    left = node_dict[id1]
    right = node_dict[id2]
    joined = PartitionTreeNode(
        ID=new_ID,
        partition=left.partition + right.partition,
        children=[id1, id2],
        g=left.g + right.g - 2 * cut_v,
        vol=left.vol + right.vol,
        child_h=max(left.child_h, right.child_h) + 1,
        child_cut=cut_v,
    )
    left.parent = new_ID
    right.parent = new_ID
    node_dict[new_ID] = joined


def compressNode(node_dict, node_id, parent_id):
    """Remove ``node_id`` from the tree and attach its children to its parent.

    The removed node's ``child_cut`` is added to the parent's, its children
    are re-parented, and the node is deleted from ``node_dict``.  When the
    removed node was exactly one level below the parent's recorded height,
    ``child_h`` is recomputed upwards until an ancestor's value is already
    consistent or the top of the tree is reached.

    Args:
        node_dict: mapping of node ID to node; modified in place.
        node_id: ID of the node to remove.
        parent_id: ID of that node's parent.
    """
    parent = node_dict[parent_id]
    removed = node_dict[node_id]
    height_before = parent.child_h
    adopted = removed.children

    parent.child_cut += removed.child_cut
    parent.children.remove(node_id)
    parent.children = parent.children + adopted
    for child_id in adopted:
        node_dict[child_id].parent = parent_id
    removed_height = removed.child_h
    node_dict.pop(node_id)

    # Heights can only have changed if the removed node defined the parent's height.
    if height_before - removed_height != 1:
        return

    current_id = parent_id
    while current_id is not None:
        current = node_dict[current_id]
        tallest = max([node_dict[cid].child_h for cid in current.children])
        if current.child_h == tallest + 1:
            break
        current.child_h = tallest + 1
        current_id = current.parent

def deleteNode(code_tree, node_id):
    """Detach node ``node_id`` from its parent and update the parent's stats.

    The node is removed from the parent's ``children`` list, the first
    vertex of the node's partition is removed from the parent's partition,
    and the parent's ``vol`` and ``g`` are adjusted.  The node itself is not
    removed from ``code_tree.tree_node``.

    Args:
        code_tree: object exposing ``tree_node`` (node ID -> node) and
            ``adj_matrix``.
        node_id: ID of the node to detach.
    """
    node_dict = code_tree.tree_node
    node = node_dict[node_id]
    parent = node_dict[node.parent]

    parent.children.remove(node_id)
    parent.partition.remove(node.partition[0])
    # NOTE(review): the parent's volume is reduced by a constant 1 rather than
    # by the removed node's volume -- confirm this is intended.
    parent.vol -= 1
    # The cut is measured against the parent's partition *after* the removal.
    cut_v = cut_volume(code_tree.adj_matrix, np.array(node.partition), np.array(parent.partition))
    # Nodes expose ``vol``; the previous ``.v`` attribute does not exist on
    # PartitionTreeNode and raised AttributeError.
    parent.g = parent.g - node.vol + 2 * cut_v

def divide_community(code_tree, community_id, subgroups):
    """Replace community ``community_id`` with one node per subgroup.

    The community node is removed from the tree.  Every subgroup with two
    or more members becomes a new node (height 1) under the old community's
    parent; a subgroup with a single member only has that member's
    ``parent`` pointed at the old community's parent; empty subgroups are
    ignored.  Afterwards ``child_h`` is raised along the path from the last
    created node up to ``code_tree.root_id``.

    Args:
        code_tree: object exposing ``tree_node``, ``adj_matrix``, ``root_id``
            and the integer ID counter ``id_g``.
        community_id: ID of the community node to split.
        subgroups: iterable of sequences of member node IDs.
    """
    node_dict = code_tree.tree_node
    adj_matrix = code_tree.adj_matrix
    root_id = code_tree.root_id
    parent_id = node_dict[community_id].parent
    node_dict[community_id].children.clear()
    old_partition = node_dict[community_id].partition
    node_dict[parent_id].children.remove(community_id)
    node_dict.pop(community_id)

    new_id = None  # stays None when no subgroup is large enough to form a node
    for subgroup in subgroups:
        if len(subgroup) == 0:
            continue
        if len(subgroup) == 1:
            # NOTE(review): the singleton is re-parented but never appended to
            # the parent's ``children`` list -- confirm callers do not rely on
            # ``children`` being complete after this call.
            node_dict[subgroup[0]].parent = parent_id
            continue

        new_id = code_tree.id_g
        code_tree.id_g += 1
        # Totals are per community; they must restart for every subgroup.
        g = 0
        v = 0
        for i in subgroup:
            member = node_dict[i]
            member_vertices = set(member.partition)
            outside = [x for x in old_partition if x not in member_vertices]
            v += member.vol
            # Explicit integer arrays match the other cut_volume callers and
            # keep the jitted kernel typeable even when ``outside`` is empty.
            g += member.vol - cut_volume(adj_matrix,
                                         np.array(member.partition, dtype=np.int64),
                                         np.array(outside, dtype=np.int64))
            member.parent = new_id
        # Independent copies: partition and children must not alias each
        # other (or the caller's list), otherwise later in-place edits of one
        # silently change the other.
        new_node = PartitionTreeNode(ID=new_id, partition=list(subgroup), children=list(subgroup),
                                     g=g, vol=v, child_h=1)
        new_node.parent = parent_id
        node_dict[new_id] = new_node
        node_dict[parent_id].children.append(new_id)
        node_dict[parent_id].child_h = max(node_dict[parent_id].child_h, node_dict[new_id].child_h + 1)

    if new_id is None:
        # No internal node was created, so no height needs to be propagated.
        return

    community_id = new_id
    while True:
        node_dict[parent_id].child_h = max(node_dict[parent_id].child_h, node_dict[community_id].child_h + 1)
        if parent_id == root_id:
            break
        community_id = parent_id
        parent_id = node_dict[community_id].parent

def child_tree_deepth(node_dict, nid):
    """Return the number of ancestors of ``nid`` plus its own ``child_h``.

    Args:
        node_dict: mapping of node ID to node (with ``parent`` and ``child_h``).
        nid: ID of the node to measure.
    """
    start = node_dict[nid]
    levels_up = 0
    ancestor_id = start.parent
    while ancestor_id is not None:
        levels_up += 1
        ancestor_id = node_dict[ancestor_id].parent
    return levels_up + start.child_h


def CompressDelta(node1, p_node):
    """Return ``2 * node1.child_cut * log2((p_node.vol + 1) / (node1.vol + 1))``.

    Args:
        node1: node considered for compression (uses ``child_cut`` and ``vol``).
        p_node: its parent node (uses ``vol``).
    """
    child_volume = node1.vol + 1
    parent_volume = p_node.vol + 1
    doubled_cut = 2 * node1.child_cut
    return doubled_cut * math.log2(parent_volume / child_volume)


def CombineDelta(node1, node2, cut_v, g_vol):
    """Score the merge of ``node1`` and ``node2``.

    With ``v = vol + 1`` and ``g = g + 1`` for each node and ``v12 = v1 + v2``
    the result is::

        ((v1 - g1) * log2(v12 / v1)
         + (v2 - g2) * log2(v12 / v2)
         - 2 * cut_v * log2(g_vol / v12)) / g_vol

    Args:
        node1: first node (uses ``vol`` and ``g``).
        node2: second node (uses ``vol`` and ``g``).
        cut_v: cut weight between the two nodes.
        g_vol: total volume used for normalisation.
    """
    vol_a = node1.vol + 1
    vol_b = node2.vol + 1
    cut_a = node1.g + 1
    cut_b = node2.g + 1
    vol_joint = vol_a + vol_b
    term_a = (vol_a - cut_a) * math.log2(vol_joint / vol_a)
    term_b = (vol_b - cut_b) * math.log2(vol_joint / vol_b)
    term_cut = 2 * cut_v * math.log2(g_vol / vol_joint)
    return (term_a + term_b - term_cut) / g_vol


class PartitionTreeNode():
    """A single node of the partition tree.

    Stores the node's ID, the vertices it covers (``partition``), its parent
    and children IDs, its volume ``vol`` and cut ``g``, plus bookkeeping
    fields used while the tree is built (``merged``, ``child_h``,
    ``child_cut``, ``node_list``, ``sim_list``).
    """

    def __init__(self, ID, partition, vol, g, children: list=None, parent=None, child_h=0, child_cut=0):
        # The assignment order below is the order in which gatherAttrs()
        # reports the attributes, so it is kept stable.
        self.ID = ID
        self.partition = partition
        self.parent = parent
        self.children = children
        self.vol = vol
        self.g = g

        self.merged = False
        self.child_h = child_h
        self.child_cut = child_cut
        self.node_list = None
        self.sim_list = None

    def __str__(self):
        """Return ``{ClassName:attr=value,...}`` for this node."""
        return "{%s:%s}" % (type(self).__name__, self.gatherAttrs())

    def gatherAttrs(self):
        """Return all instance attributes as a comma-separated ``name=value`` string."""
        pairs = []
        for name, value in vars(self).items():
            pairs.append("{}={}".format(name, value))
        return ",".join(pairs)
    

class PartitionTree():
    """Coding (partition) tree built over a weighted adjacency matrix.

    Construction parses the matrix with ``get_graph_parse`` and creates one
    leaf ``PartitionTreeNode`` per vertex.  ``build_coding_tree`` then merges
    nodes greedily and, when ``k`` is given, compresses the tree until the
    root's ``child_h`` is at most ``k``.
    """

    def __init__(self, adj_matrix):
        """Parse ``adj_matrix`` and create the leaf nodes."""
        self.adj_matrix = adj_matrix
        self.tree_node = {}  # node ID -> PartitionTreeNode
        self.g_num_nodes, self.VOL, self.node_vol, self.adj_table = get_graph_parse(adj_matrix)
        self.id_g = 0  # next unused node ID
        self.leaves = []  # IDs of the leaf nodes, in vertex order
        self.build_leaves()

    def build_leaves(self):
        """Create one leaf per vertex, with ``g`` and ``vol`` set to the vertex volume."""
        for vertex in range(self.g_num_nodes):
            ID = self.id_g
            self.id_g += 1
            v = self.node_vol[vertex]
            leaf_node = PartitionTreeNode(ID=ID, partition=[vertex], g=v, vol=v)
            self.tree_node[ID] = leaf_node
            self.leaves.append(ID)

    def entropy(self, node_dict=None):
        """Return the entropy of the whole tree.

        Every node that has a parent contributes
        ``-(g / VOL) * log2((vol + 1) / (parent.vol + 1))``.

        Args:
            node_dict: node ID -> node mapping; defaults to ``self.tree_node``.
        """
        if node_dict is None:
            node_dict = self.tree_node
        ent = 0
        for node in node_dict.values():
            if node.parent is not None:
                node_p = node_dict[node.parent]
                node_vol = node.vol + 1
                node_g = node.g
                node_p_vol = node_p.vol + 1
                ent += - (node_g / self.VOL) * math.log2(node_vol/node_p_vol)
        return ent

    def __build_k_tree(self, g_vol, nodes_dict:dict, k=None):
        """Merge the nodes of ``nodes_dict`` into one tree and return the root ID.

        Candidate pairs are taken from a min-heap ordered by ``CombineDelta``
        and merged until one unmerged node is left or no candidate remains;
        any nodes still unmerged are then placed under a common new root.
        When ``k`` is not None, nodes are removed with ``compressNode`` in
        ascending ``CompressDelta`` order until the root's ``child_h`` is at
        most ``k``.

        Args:
            g_vol: total volume passed to ``CombineDelta``.
            nodes_dict: node ID -> node mapping; modified in place.
            k: maximum allowed root height, or None to skip compression.

        Returns:
            The ID of the root node (None only when ``nodes_dict`` is empty).
        """
        min_heap = []  # (CombineDelta, id1, id2, cut) candidate merges
        cmp_heap = []  # [CompressDelta, node_id, parent_id] candidate compressions
        nodes_ids = list(nodes_dict)
        new_id = None
        for i in nodes_ids:
            for j in self.adj_table[i]:
                if j > i:  # consider each adjacent pair once
                    n1 = nodes_dict[i]
                    n2 = nodes_dict[j]
                    if len(n1.partition) == 1 and len(n2.partition) == 1:
                        cut_v = self.adj_matrix[n1.partition[0],n2.partition[0]]
                    else:
                        cut_v = cut_volume(self.adj_matrix,p1 = np.array(n1.partition),p2=np.array(n2.partition))
                    diff = CombineDelta(n1, n2, cut_v, g_vol)
                    heapq.heappush(min_heap, (diff, i, j, cut_v))

        # Merge nodes until a single unmerged node remains.
        unmerged_count = len(nodes_ids)
        while unmerged_count > 1:
            if len(min_heap) == 0:
                break
            _, id1, id2, cut_v = heapq.heappop(min_heap)
            # Stale entry: one side has already been merged elsewhere.
            if nodes_dict[id1].merged or nodes_dict[id2].merged:
                continue
            nodes_dict[id1].merged = True
            nodes_dict[id2].merged = True
            new_id = self.id_g
            self.id_g += 1
            # merge() mutates nodes_dict in place and registers the new node.
            merge(new_id, id1, id2, cut_v, nodes_dict)
            # The new node inherits both adjacency sets; links are kept bidirectional.
            self.adj_table[new_id] = self.adj_table[id1].union(self.adj_table[id2])
            for i in self.adj_table[new_id]:
                self.adj_table[i].add(new_id)
            # Only non-leaf children are candidates for later compression.
            if nodes_dict[id1].child_h > 0:
                heapq.heappush(cmp_heap,[CompressDelta(nodes_dict[id1],nodes_dict[new_id]),id1,new_id])
            if nodes_dict[id2].child_h > 0:
                heapq.heappush(cmp_heap,[CompressDelta(nodes_dict[id2],nodes_dict[new_id]),id2,new_id])
            unmerged_count -= 1

            # Queue the merges of the new node with each still-unmerged neighbour.
            for ID in self.adj_table[new_id]:
                if not nodes_dict[ID].merged:
                    n1 = nodes_dict[ID]
                    n2 = nodes_dict[new_id]
                    cut_v = cut_volume(self.adj_matrix,np.array(n1.partition), np.array(n2.partition))
                    new_diff = CombineDelta(n1, n2, cut_v, g_vol)
                    heapq.heappush(min_heap, (new_diff, ID, new_id, cut_v))
        root = new_id
        if root is None and len(nodes_ids) == 1:
            # A single node needs no merging: it is the root itself.
            root = nodes_ids[0]

        if unmerged_count > 1:
            # Several disconnected pieces remain: join them under one new root.
            assert len(min_heap) == 0
            unmerged_nodes = [nid for nid, node in nodes_dict.items() if not node.merged]
            new_child_h = max([nodes_dict[i].child_h for i in unmerged_nodes]) + 1

            new_id = self.id_g
            self.id_g += 1
            new_node = PartitionTreeNode(ID=new_id, partition=list(nodes_ids), children=unmerged_nodes,
                                         vol=g_vol, g=0, child_h=new_child_h)
            nodes_dict[new_id] = new_node

            for i in unmerged_nodes:
                nodes_dict[i].merged = True
                nodes_dict[i].parent = new_id
                if nodes_dict[i].child_h > 0:
                    heapq.heappush(cmp_heap, [CompressDelta(nodes_dict[i], nodes_dict[new_id]), i, new_id])
            root = new_id

        if k is not None:
            while nodes_dict[root].child_h > k:
                _, node_id, p_id = heapq.heappop(cmp_heap)
                # Removing this node would not shorten any over-long path.
                if child_tree_deepth(nodes_dict, node_id) <= k:
                    continue
                # Captured before compressNode() deletes the node.
                children = nodes_dict[node_id].children
                compressNode(nodes_dict, node_id, p_id)
                if nodes_dict[root].child_h == k:
                    break
                # Refresh the heap entries affected by the removal.
                for e in cmp_heap:
                    if e[1] == p_id:
                        if child_tree_deepth(nodes_dict, p_id) > k:
                            e[0] = CompressDelta(nodes_dict[e[1]], nodes_dict[e[2]])
                    if e[1] in children:
                        if nodes_dict[e[1]].child_h == 0:
                            continue
                        if child_tree_deepth(nodes_dict, e[1]) > k:
                            e[2] = p_id
                            e[0] = CompressDelta(nodes_dict[e[1]], nodes_dict[p_id])
                # Entries were edited in place, so the heap order must be rebuilt.
                heapq.heapify(cmp_heap)
        return root


    def build_coding_tree(self, k=2, mode='v1'):
        """Build the coding tree and store its root in ``self.root_id``.

        Nothing is built when ``k == 1``.  The tree is built when ``mode`` is
        ``'v1'`` or ``k`` is None; afterwards the number of nodes reachable
        from the root is checked against ``len(self.tree_node)``.

        Args:
            k: maximum root height, or None for no height limit.
            mode: build strategy; only ``'v1'`` is implemented here.

        Raises:
            ValueError: if no tree was built for the given ``mode`` and no
                ``root_id`` exists yet.
        """
        if k == 1:
            return
        if mode == 'v1' or k is None:
            self.root_id = self.__build_k_tree(self.VOL, self.tree_node, k=k)
        elif getattr(self, 'root_id', None) is None:
            # Fail with a clear message instead of an AttributeError on root_id.
            raise ValueError("unsupported mode {!r}: no coding tree has been built".format(mode))

        count = sum(1 for _ in LayerFirst(self.tree_node, self.root_id))
        assert len(self.tree_node) == count

    def deduct_se(self, leaf_id, root_id=None):
        """Return the entropy summed along the path from ``leaf_id`` upwards.

        The path climbs through parents until a node whose parent equals
        ``root_id`` is reached.  A parent with the same partition as the
        current node replaces it on the path instead of extending it.  With
        ``root_id`` None the walk ends at the top of the tree and the top
        node itself is dropped from the path.

        Args:
            leaf_id: ID of the node the path starts from.
            root_id: ID at which the walk stops, or None for the whole path.

        Returns:
            A torch scalar: ``-sum(g / VOL * log2((vol + 1) / (parent_vol + 1)))``
            over the nodes of the path.
        """
        node_dict = self.tree_node
        path_id = [leaf_id]
        current_id = leaf_id
        while True:
            parent_id = node_dict[current_id].parent
            if parent_id == root_id:
                break
            if node_dict[parent_id].partition == node_dict[current_id].partition:
                current_id = parent_id
                path_id[-1] = current_id
                continue
            path_id.append(parent_id)
            current_id = parent_id
        if root_id is None:
            path_id = path_id[0:-1]  # remove root
        g = [node_dict[e].g for e in path_id]
        vol = [node_dict[e].vol for e in path_id]
        parent_vol = [node_dict[node_dict[e].parent].vol for e in path_id]
        g = torch.tensor(g)
        vol = torch.tensor(vol) + 1
        parent_vol = torch.tensor(parent_vol) + 1
        deduct_se = -(g / self.VOL * torch.log2(vol / parent_vol)).sum()
        return deduct_se
    
