import os
import numpy as np
import matplotlib
import networkx as nx
import torch
import matplotlib.pyplot as plt
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_networkx
from tqdm import tqdm

from gpl import TMP_DIR


# Integer ``node_type`` value -> chemical element symbol for MUTAG-style graphs.
# The position in this list is the integer code.
Mutag_node_labels = dict(enumerate([
    'C', 'O', 'Cl', 'H', 'N', 'F', 'Br',
    'S', 'P', 'I', 'Na', 'K', 'Li', 'Ca',
]))

# Element symbol -> hex node colour, taken from the Jmol colour table:
# http://jmol.sourceforge.net/jscolors/
# Hydrogen deviates from Jmol's white (#FFFFFF) and uses light grey (#D6D6D6)
# instead, presumably so H nodes remain visible on a white canvas.
Mutag_node_colors = dict(
    C='#909090',
    O='#FF0D0D',
    Cl='#1FF01F',
    H='#D6D6D6',
    N='#3050F8',
    F='#90E050',
    Br='#A62929',
    S='#FFFF30',
    P='#FF8000',
    I='#940094',
    Na='#AB5CF2',
    K='#8F40D4',
    Li='#CC80FF',
    Ca='#3DFF00',
)

# SPMotif integer ``node_label`` -> node colour (only labels 0 and 1 are styled).
Spmotif_node_colors = dict(enumerate(('black', 'green')))


def visualize_a_graph_v2(G_nx, node_mask, edge_mask, plot_original, ax, exp_name):
    """Draw one explanation graph onto ``ax``.

    Node styling is chosen by ``exp_name``:
      * ``'mutag'`` in the name: labels/colours come from the ``node_type``
        node attribute via ``Mutag_node_labels`` / ``Mutag_node_colors``;
      * ``'spmotif'`` in the name: labels are ``int(node_label)`` and colours
        come from ``Spmotif_node_colors`` (so only labels 0 and 1 are valid).

    When ``plot_original`` is False, the masks override that styling: a
    non-None ``node_mask`` colours the nodes (labels show the value with two
    decimals), and a non-None ``edge_mask`` colours the edges. Both use the
    ``'Reds'`` colormap.

    Args:
        G_nx: networkx graph. It must also carry a ``y_`` attribute, which is
            used for the ``'Class <y>'`` title.
        node_mask: per-node values indexable by node id, or None.
        edge_mask: per-edge values in ``G_nx.edges()`` order, or None.
        plot_original: if True, the masks are ignored.
        ax: matplotlib Axes to draw on.
        exp_name: experiment name selecting the node styling scheme.

    Raises:
        NotImplementedError: if ``exp_name`` matches neither scheme.
    """
    pos = nx.kamada_kawai_layout(G_nx)
    nodes = list(G_nx.nodes())

    if 'mutag' in exp_name:
        node_labels = {n: Mutag_node_labels[G_nx.nodes[n]['node_type']] for n in nodes}
        node_colors = [Mutag_node_colors[node_labels[n]] for n in nodes]
    elif 'spmotif' in exp_name:
        node_labels = {n: int(G_nx.nodes[n]['node_label']) for n in nodes}
        node_colors = [Spmotif_node_colors[node_labels[n]] for n in nodes]
    else:
        raise NotImplementedError(f'no node styling defined for experiment {exp_name!r}')

    draw_options = {
        'node_size': 120,
        'width': 2,  # edge width
        'labels': node_labels,
        'node_color': node_colors,
    }

    if plot_original is False:
        # matplotlib.cm.get_cmap was removed in matplotlib 3.9;
        # pyplot.get_cmap is the portable accessor.
        reds = plt.get_cmap('Reds')
        if node_mask is not None:
            draw_options['labels'] = {n: f"{node_mask[n]:.2f}" for n in nodes}
            draw_options['node_color'] = node_mask
            draw_options['cmap'] = reds
        if edge_mask is not None:
            draw_options['edge_color'] = edge_mask
            draw_options['edge_cmap'] = reds

    nx.draw_networkx(G_nx, pos, ax=ax, **draw_options)
    ax.set_title(label=f'Class {int(G_nx.y_)}')
    for side in ('top', 'bottom', 'right', 'left'):
        ax.spines[side].set_visible(False)



def split_batch_node_edge_mask(node_mask, edge_mask, batch_data):
    """Cut batch-level node/edge masks into one slice per graph.

    The graphs are taken from ``batch_data.to_data_list()``. Graph ``i``
    owns the range between the running totals of the preceding graphs'
    sizes:
      * nodes are counted with ``x.shape[0]``;
      * edges are counted with ``edge_index.shape[1]``.

    Each mask is sliced along its first dimension over that range.

    Returns:
        ``(node_masks, edge_masks)``: two lists with one entry per graph.
        An entry is ``None`` whenever the corresponding input mask is None.
    """
    graphs = batch_data.to_data_list()

    node_bounds, edge_bounds = [0], [0]
    for graph in graphs:
        edge_bounds.append(edge_bounds[-1] + graph.edge_index.shape[1])
        node_bounds.append(node_bounds[-1] + graph.x.shape[0])

    def _cut(mask, bounds):
        # One slice per graph, or a placeholder None per graph if there is no mask.
        if mask is None:
            return [None] * len(graphs)
        return [mask[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]

    return _cut(node_mask, node_bounds), _cut(edge_mask, edge_bounds)

def get_undirected_edge_color(data, G_nx, edge_mask):
    """Reorder a per-directed-edge mask to match ``G_nx.edges()``.

    ``edge_mask[i]`` is paired with the directed edge
    ``(data.edge_index[0][i], data.edge_index[1][i])``. If an edge appears
    more than once, the last occurrence wins.

    For every ``(u, v)`` yielded by ``G_nx.edges()``, the value stored for
    exactly that direction is returned. A missing direction raises
    ``KeyError``.

    Returns:
        List of Python scalars, one per edge of ``G_nx``.
    """
    sources, targets = data.edge_index[0], data.edge_index[1]
    value_by_edge = {
        (s.item(), t.item()): m.item()
        for s, t, m in zip(sources, targets, edge_mask)
    }
    return [value_by_edge[(int(a), int(b))] for a, b in G_nx.edges()]

def normalize_mask_values(mask):
    """Placeholder for rescaling mask values before plotting.

    No rescaling is performed at present: the very same ``mask`` object is
    handed back unchanged.
    """
    unchanged = mask
    return unchanged

def visualize_one_graph(G: Data, exp_name, node_mask, edge_mask, show_node_mask=False):
    """Render one PyG graph with its explanation masks and return the figure.

    Args:
        G: PyG ``Data`` with a single-element ``y``. For ``'mutag'``
            experiments it needs a ``node_type`` node attribute; for all
            other experiments it needs ``node_label``.
        exp_name: experiment name, forwarded to ``visualize_a_graph_v2``.
        node_mask: per-node mask tensor, or None. It is only used when
            ``show_node_mask`` is True. Node colouring was hard-disabled
            before, and that remains the default.
        edge_mask: mask aligned with the columns of ``G.edge_index``, or None
            to draw uncoloured edges.
        show_node_mask: opt in to colouring nodes by ``node_mask``.

    Returns:
        The matplotlib Figure. The caller saves and closes it.
    """
    node_attr = 'node_type' if 'mutag' in exp_name else 'node_label'
    G_nx = to_networkx(G, node_attrs=[node_attr], to_undirected=True)
    G_nx.y_ = G.y.item()  # read by visualize_a_graph_v2 for the plot title

    # Previously node_mask.flatten() ran unconditionally (crashing on None)
    # and its result was then discarded; only convert when it is actually used.
    if show_node_mask and node_mask is not None:
        node_mask = node_mask.flatten().detach().cpu().tolist()
    else:
        node_mask = None

    if edge_mask is not None:
        # Collapse the directed mask onto the undirected networkx edges.
        edge_mask = get_undirected_edge_color(G, G_nx, edge_mask)
        edge_mask = normalize_mask_values(edge_mask)

    fig, ax = plt.subplots(1, 1, figsize=(7, 6))
    visualize_a_graph_v2(G_nx, node_mask, edge_mask, plot_original=False, ax=ax, exp_name=exp_name)
    return fig
    
def investigate_Z_interpolation_evaluate_callback(test_results, **kwargs):
    """Visualize the masks for two random latent codes and their midpoint.

    Steps:
      1. Take the first graph of the test set.
      2. Sample ``Z1`` and ``Z2`` from a standard normal.
      3. Render the masks returned by ``model.get_mask`` for ``Z1``, ``Z2``
         and ``(Z1 + Z2) / 2``.
      4. Save them as ``fig1.pdf``, ``fig2.pdf`` and ``fig_ipl.pdf`` under
         ``TMP_DIR/'Z_interpolation_vis'``. The directory is created if needed.

    ``kwargs`` must provide:
      * ``'__trainer__'`` (with ``device`` and ``EXP_NAME``);
      * ``'model'``;
      * ``'dataloaders'``.

    An optional ``'latent_dim'`` sets the width of the sampled codes. It
    defaults to 64, the previously hard-coded size. ``test_results`` is
    unused.
    """
    trainer = kwargs['__trainer__']
    device = trainer.device
    exp_name = trainer.EXP_NAME
    model = kwargs['model']
    latent_dim = kwargs.get('latent_dim', 64)

    g = kwargs['dataloaders'].test_dataloader.dataset[0]
    num_nodes = g.x.shape[0]
    x = g.x.to(device)
    edge_index = g.edge_index.to(device)
    batch = torch.zeros(num_nodes, dtype=torch.long).to(device)  # all nodes belong to graph 0
    edge_attr = g.edge_attr.to(device) if g.edge_attr is not None else None

    # Sampled on CPU then moved, in the same order as before (same RNG stream).
    z1 = torch.normal(mean=0.0, std=1.0, size=(1, latent_dim)).to(device)
    z2 = torch.normal(mean=0.0, std=1.0, size=(1, latent_dim)).to(device)
    z_ipl = (z1 + z2) / 2  # interpolation midpoint

    with torch.no_grad():  # visualization only; no autograd graph needed
        embs = model.encoder.get_emb(x=x, edge_index=edge_index, batch=batch, edge_attr=edge_attr)
        masks = {
            name: model.get_mask(num_nodes, embs, edge_index, batch, z)
            for name, z in (('fig1', z1), ('fig2', z2), ('fig_ipl', z_ipl))
        }

    fig_save_dir = TMP_DIR / 'Z_interpolation_vis'
    fig_save_dir.mkdir(parents=True, exist_ok=True)
    for name, (edge_mask, node_mask) in masks.items():
        fig = visualize_one_graph(g, exp_name, node_mask, edge_mask)
        fig.savefig(fig_save_dir / f'{name}.pdf')
        plt.close(fig)



def visualize_assignments(batch_data: Batch, node_mask, edge_mask, batch_id, fig_save_dir, max_graphs=80, **kwargs): # assignment_logits
    """Save one PDF per graph of ``batch_data`` showing its masks.

    Args:
        batch_data: a batch of graphs (anything with ``to_data_list()``).
        node_mask: batch-level mask, or None. It is split per graph by
            ``split_batch_node_edge_mask``.
        edge_mask: batch-level mask, or None. It is split the same way.
        batch_id: batch index, used in the output file names.
        fig_save_dir: ``pathlib.Path`` directory. Files are written there as
            ``y_<label>_batch_<batch_id>_gid_<gid>.pdf``.
        max_graphs: plot at most this many graphs (default 80), or None for
            all of them. The old ``gid > 80`` check actually let 82 through.
        **kwargs: must contain ``'experiment_name'``.
    """
    exp_name = kwargs['experiment_name']
    graphs = batch_data.to_data_list()
    node_masks, edge_masks = split_batch_node_edge_mask(node_mask, edge_mask, batch_data)

    num_to_plot = len(graphs) if max_graphs is None else max(0, min(len(graphs), max_graphs))
    for gid in tqdm(range(num_to_plot), total=num_to_plot):
        G = graphs[gid]
        y = G.y.item()
        fig = visualize_one_graph(G, exp_name, node_masks[gid], edge_masks[gid])
        fig.savefig(fig_save_dir/f'y_{y}_batch_{batch_id}_gid_{gid}.pdf')
        plt.close(fig)  # keep the number of open figures bounded

def show_metrics_callback(test_results, **kwargs):
    """Print the loss, accuracy and AUC computed by ``process_one_set``.

    ``kwargs`` is accepted but not used.
    """
    from gpl.utils.evaluate import process_one_set

    test_loss, test_acc, test_auc = process_one_set(test_results)
    summary = f'test loss: {test_loss:.4f}, test acc: {test_acc:.4f}, test auc: {test_auc:.4f}'
    print(summary)

def show_assignments_callback(test_results, **kwargs):
    """Plot the node/edge masks of the first training batch.

    Runs ``model.forward_pass`` on the first batch of the training
    dataloader under ``torch.no_grad()``. The ``'node_mask'`` and
    ``'edge_mask'`` entries of the result (None when absent) are passed to
    ``visualize_assignments``.

    Figures go to ``TMP_DIR/'show_assignment_matrix'/<experiment_name>``.
    The directory is created, parents included, if needed.
    ``test_results`` is unused.
    """
    model = kwargs['model']
    dataloaders = kwargs['dataloaders']
    exp_name = kwargs['experiment_name']
    fig_save_dir = TMP_DIR / 'show_assignment_matrix' / exp_name
    # os.mkdir failed when 'show_assignment_matrix' itself did not exist yet.
    fig_save_dir.mkdir(parents=True, exist_ok=True)

    # Only the first batch is visualized.
    for batch_i, batch_data in enumerate(dataloaders.train_dataloader):
        with torch.no_grad():
            loss_dict = model.forward_pass(batch_data, batch_i)
        node_mask = loss_dict.get('node_mask', None)
        edge_mask = loss_dict.get('edge_mask', None)
        visualize_assignments(batch_data, node_mask, edge_mask, batch_i, fig_save_dir=fig_save_dir, **kwargs)
        break


def _save_mask_figures(model, dataset, dataset_indices, exp_name, fig_save_dir, prefix):
    """Render ``dataset[idx]`` with the model's masks and save it as ``<prefix>_<i>.jpg``."""
    for i, idx in enumerate(dataset_indices):
        g = dataset[idx]
        batch_data = Batch.from_data_list([g,])
        with torch.no_grad():  # inference only; avoid building autograd graphs
            loss_dict = model.forward_pass(batch_data, batch_idx=0, compute_loss=False)
        node_mask = loss_dict.get('node_mask', None)
        edge_mask = loss_dict.get('edge_mask', None)

        fig = visualize_one_graph(g, exp_name, node_mask, edge_mask)
        fig_path = fig_save_dir/f'{prefix}_{i}.jpg'
        fig.savefig(fig_path, bbox_inches='tight', dpi=300)
        plt.close(fig)  # up to 200 figures are produced; don't keep them all open
        print(fig_path, 'saved')


def show_prototypes_criticisms_callback(test_results, **kwargs):
    """Select MMD-critic prototypes/criticisms of the training set and plot them.

    Steps:
      1. Build training embeddings by concatenating the ``'subg_embs'`` and
         ``'embs_recon_graph'`` arrays returned by ``get_embeddings``.
      2. Select 100 prototypes and 100 criticisms with
         ``select_prototypes_criticisms``.
      3. Save each selected graph as ``prototype_<i>.jpg`` /
         ``criticism_<i>.jpg`` under
         ``TMP_DIR/'show_prototypes'/<experiment_name>``.

    ``test_results`` is unused.
    """
    from gpl.utils.evaluate import get_embeddings
    from gpl.utils.mmd.mmd_critic import select_prototypes_criticisms

    model = kwargs['model']
    dataloaders = kwargs['dataloaders']
    exp_name = kwargs['experiment_name']
    train_dataloader = dataloaders.train_dataloader
    fig_save_dir = TMP_DIR/'show_prototypes'/exp_name
    os.makedirs(fig_save_dir, exist_ok=True)

    num_prototypes = 100
    num_criticisms = 100

    embs_dict, train_y, _ = get_embeddings(model, train_dataloader, key=['subg_embs', 'embs_recon_graph'])
    train_embs = np.concatenate([embs_dict['subg_embs'], embs_dict['embs_recon_graph']], axis=1)
    train_embs = torch.tensor(train_embs, dtype=torch.float)
    train_y = torch.tensor(train_y, dtype=torch.long).reshape(-1)

    selection = select_prototypes_criticisms(train_embs, train_y, num_prototypes=num_prototypes, num_criticisms=num_criticisms)

    # The selected indices are positions in the embedding matrix. They are
    # mapped to dataset indices through the sampler's ``indexes``. This
    # presumes get_embeddings iterates in that sampler order; verify.
    train_emb_indices = np.array(train_dataloader.sampler.indexes)
    proto_in_dataset_indices = train_emb_indices[selection.prototype_indices.numpy()]
    criti_in_dataset_indices = train_emb_indices[selection.criticism_indices.numpy()]

    _save_mask_figures(model, train_dataloader.dataset, proto_in_dataset_indices, exp_name, fig_save_dir, 'prototype')
    _save_mask_figures(model, train_dataloader.dataset, criti_in_dataset_indices, exp_name, fig_save_dir, 'criticism')
