import networkx as nx
from builtins import str
from typing import Optional, Tuple, List
import numpy as np



def PointCloud2nxGraph(PC: np.ndarray, initial_labels: Optional[dict] = None) -> nx.Graph:
    '''
        Convert a point cloud into a complete, weighted, directed distance graph.

        Parameters
        ----------
        PC : array-like, shape (n, d)
            Point coordinates, one row per point (typically d = 3, but any
            dimension works since distances are summed over the last axis).
        initial_labels : dict, optional
            Mapping from node index (0 .. n-1) to its initial colour label.
            When omitted every node receives label 1.

        Returns
        -------
        nx.DiGraph
            Node ``i`` corresponds to ``PC[i]``. For every ordered pair
            ``(i, j)`` with ``i != j`` there is an edge whose ``'weight'`` is
            the Euclidean distance between ``PC[i]`` and ``PC[j]`` -- including
            a weight of 0.0 for coincident points. No self-loops are added.
            Every node carries a ``'label'`` attribute.

        Raises
        ------
        ValueError
            If ``initial_labels`` does not have exactly the keys 0 .. n-1.
    '''
    PC = np.asarray(PC)
    n = PC.shape[0]

    # Pairwise Euclidean distance matrix, shape (n, n), via broadcasting.
    DG = np.sqrt(
        np.sum(
            (PC[:, np.newaxis, :] - PC[np.newaxis, :, :]) ** 2,
            axis=2,
        )
    )

    G = nx.DiGraph()
    G.add_nodes_from(range(n))
    # Add every off-diagonal pair explicitly. Constructing the graph directly
    # from the matrix (nx.DiGraph(DG)) treats zero entries as "no edge", which
    # would silently drop the edge between coincident points.
    rows = DG.tolist()  # plain Python floats, as networkx stores them
    G.add_weighted_edges_from(
        (i, j, rows[i][j]) for i in range(n) for j in range(n) if i != j
    )

    # initial label
    if initial_labels is None:
        nx.set_node_attributes(G, 1, 'label')
    else:
        if set(initial_labels) != set(G.nodes):
            raise ValueError(
                "initial_labels must provide exactly one label for each node "
                f"0..{n - 1}"
            )
        nx.set_node_attributes(G, initial_labels, 'label')

    return G


def color_aggr(G: nx.Graph, round_num):
    '''
        Compute the colour-refinement signature ("colour text") of every node.

        For a node ``u`` the signature is the string form of
        ``[label(u), *sorted(pairs)]``, where ``pairs`` holds one tuple
        ``(label(v), round(weight(u, v), round_num))`` per entry ``v`` of
        ``G.adj[u]``. Sorting turns the neighbour pairs into a canonical
        multiset, so nodes with equal label and equal neighbourhood get
        equal strings.

        Args:
            G: graph whose nodes carry a ``'label'`` attribute and whose
               edges carry a ``'weight'`` attribute.
            round_num: number of decimals passed to ``round`` for each edge
               weight.

        Returns:
            dict mapping each node of ``G`` to its signature string.
    '''
    labels = nx.get_node_attributes(G, 'label')

    def signature(node):
        # multiset of (neighbour label, rounded edge weight), canonicalised by sorting
        neighbourhood = sorted(
            (labels[nbr], round(data['weight'], round_num))
            for nbr, data in G.adj[node].items()
        )
        # own label first, then the neighbourhood multiset
        return str([labels[node]] + neighbourhood)

    return {node: signature(node) for node in G.nodes()}
        


def color_reassignment(color_text_l: dict, color_text_r: dict, Gl: nx.Graph, Gr: nx.Graph):
    '''
        Compress the colour texts of two graphs into one shared integer palette.

        Distinct colour texts are numbered 1, 2, 3, ... in first-seen order,
        scanning all values of ``color_text_l`` before those of
        ``color_text_r``. Because both graphs share the palette, equal
        integers mean equal colour texts across the two graphs.

        The new integers are written into the ``'label'`` node attribute of
        ``Gl`` and then ``Gr``.

        Args:
            color_text_l: node -> colour text for every node of ``Gl``.
            color_text_r: node -> colour text for every node of ``Gr``.
            Gl, Gr: the graphs to relabel in place.

        Returns:
            (labels_l, labels_r): node -> new integer label for each graph.
    '''
    # always relabel from 1; the next fresh index is len(palette) + 1
    palette = {}
    for text in [*color_text_l.values(), *color_text_r.values()]:
        palette.setdefault(text, len(palette) + 1)

    def relabel(graph, texts):
        labels = {node: palette[texts[node]] for node in graph.nodes()}
        nx.set_node_attributes(graph, labels, 'label')
        return labels

    labels_l = relabel(Gl, color_text_l)
    labels_r = relabel(Gr, color_text_r)
    return labels_l, labels_r
        


def CR_E(Gl: nx.Graph, Gr: nx.Graph, verbose: bool = False, round_num: int = 1) -> Tuple[bool, Optional[dict], Optional[dict], int]:
    '''
        Run edge-weighted colour refinement jointly on two graphs.

        Each round recomputes every node's colour text (``color_aggr``) and
        maps the texts of both graphs to one shared integer palette
        (``color_reassignment``), which overwrites the ``'label'`` node
        attribute of ``Gl`` and ``Gr``. After every round the sorted label
        multisets of the two graphs are compared; if they differ the graphs
        are distinguished.

        Args:
            Gl, Gr: graphs with a ``'label'`` node attribute and a
                ``'weight'`` edge attribute. Their labels are modified in place.
            verbose: print the labels of every round.
            round_num: decimals used when rounding edge weights.

        Returns:
            (distinguished, labels_l, labels_r, iteration) where
            ``distinguished`` is True if some round produced different label
            multisets and False if refinement stabilised with equal
            multisets; ``labels_l``/``labels_r`` are the labels of the last
            round; ``iteration`` is the number of rounds performed.
    '''
    if verbose:
        print("-----------Beginning CR_E-----------")
    iteration = 0

    label_list_l, label_list_r = nx.get_node_attributes(Gl, 'label'), nx.get_node_attributes(Gr, 'label')
    if verbose:
        print("-----------Initial Labels-----------")
        print("Graph L:", label_list_l)
        print("Graph R:", label_list_r)
    while True:
        # color refinement (shared palette for both graphs)
        color_text_l, color_text_r = color_aggr(Gl, round_num), color_aggr(Gr, round_num)
        new_label_list_l, new_label_list_r = color_reassignment(color_text_l, color_text_r, Gl, Gr)
        iteration += 1

        stable = new_label_list_l == label_list_l and new_label_list_r == label_list_r

        # Compare the colour histograms in EVERY round, including a stable one.
        # Checking stability first would skip the comparison entirely when the
        # very first round is already stable (e.g. graphs of different sizes
        # with no edges, or initial labels that differ between the graphs),
        # wrongly reporting the graphs as indistinguishable.
        result_l = sorted(map(int, new_label_list_l.values()))
        result_r = sorted(map(int, new_label_list_r.values()))
        distinguished = result_l != result_r

        if stable and not distinguished:
            break

        if verbose:
            print(f"-----------Iteration {iteration}-----------")
            print("Graph L:", new_label_list_l)
            print("Graph R:", new_label_list_r)

        if distinguished:
            if verbose:
                print("-----------Ending CR-E-----------")
            return True, new_label_list_l, new_label_list_r, iteration

        label_list_l, label_list_r = new_label_list_l, new_label_list_r

    if verbose:
        print("-----------Ending CR-E-----------")
    return False, new_label_list_l, new_label_list_r, iteration


