import numpy as np
from ot.utils import unif
from fngw import fngw_barycenters

import networkx as nx

from joblib import Parallel, delayed
import os
from texttable import Texttable


def plot_qm9(G, ax, text=None, pos=None, draw_edge_feature=True):
    """Draw a QM9-style molecular graph onto a matplotlib axis.

    Node labels and colours are looked up from each node's ``"feature"``
    attribute (expected to be one of the strings "0".."3"); when
    ``draw_edge_feature`` is True, each edge's ``"bond"`` attribute
    ("0".."2") selects its line style.

    Parameters
    ----------
    G : networkx.Graph
        Graph whose nodes carry a ``"feature"`` attribute (and whose edges
        carry ``"bond"`` when ``draw_edge_feature`` is True).
    ax : matplotlib axis
        Target axis for ``nx.draw``.
    text : str, optional
        Axis title.
    pos : dict, optional
        Precomputed node positions; a Kamada-Kawai layout is computed when
        omitted.
    draw_edge_feature : bool
        Whether to encode bond types as edge line styles.

    Returns
    -------
    dict
        The node positions used, so the same layout can be reused for
        another plot.
    """
    atom_symbols = {"0": "C", "1": "N", "2": "O", "3": "F"}
    atom_colors = {"0": 'lightgreen', "1": "turquoise", "2": "lightblue", "3": "tan"}
    bond_line_styles = {"0": "solid", "1": "dotted", "2": "dashed"}

    layout = nx.kamada_kawai_layout(G) if pos is None else pos

    node_features = list(G.nodes.data("feature"))
    labels = {node: atom_symbols[feat] for node, feat in node_features}
    colors = [atom_colors[feat] for _, feat in node_features]

    draw_kwargs = dict(labels=labels, width=3, ax=ax, node_color=colors, node_size=500)
    if draw_edge_feature:
        draw_kwargs["style"] = [
            bond_line_styles[bond] for _, _, bond in G.edges.data("bond")
        ]
    nx.draw(G, layout, **draw_kwargs)

    ax.set_title(text)
    ax.set_axis_on()
    return layout


def to_networkx(y, use_edge_feature=True, thres=None):
    """Convert an array-based graph dict into a ``networkx.Graph``.

    Parameters
    ----------
    y : dict
        Must contain ``"F"`` (node-feature matrix; the argmax over axis 1 is
        the node class). When ``use_edge_feature`` is True it must also
        contain ``"E"``, an edge-feature array whose LAST channel is treated
        as "no edge". Otherwise it must contain ``"A"``, an adjacency matrix
        binarised with ``thres``.
    use_edge_feature : bool
        Build edges from ``y["E"]`` (True) or from ``y["A"]`` (False).
    thres : float, optional
        Threshold applied to ``y["A"]``; required when ``use_edge_feature``
        is False.

    Returns
    -------
    networkx.Graph
        Nodes ``0..len(F)-1``, each with a string ``"feature"`` attribute
        (the node class). When ``use_edge_feature`` is True, every edge has
        a string ``"bond"`` attribute (the argmax channel of ``E``). The
        graph never contains self-loops.

    Raises
    ------
    ValueError
        If ``thres`` is missing while ``use_edge_feature`` is False, or if
        the adjacency references node indices beyond the rows of ``F``.
    """
    if not use_edge_feature and thres is None:
        # Without this check the comparison ``A > None`` fails with an obscure TypeError.
        raise ValueError("`thres` must be provided when use_edge_feature=False")

    if use_edge_feature:
        E = y["E"]
        adj = np.argmax(E, axis=-1)
        # The last channel encodes "no bond"; any other argmax is a bond type.
        no_edge_channel = E.shape[-1] - 1
        A = np.asarray(adj != no_edge_channel, dtype=int)
    else:
        # np.where returns a fresh array, so y["A"] is never mutated.
        A = np.where(np.asarray(y["A"]) > thres, 1, 0)

    # Molecular graphs have no self-loops. Clear the diagonal in BOTH branches
    # (previously only the thresholded branch did this).
    np.fill_diagonal(A, 0)

    node_classes = np.argmax(y["F"], axis=1)

    rows, cols = np.nonzero(A)
    edges = list(zip(rows.tolist(), cols.tolist()))
    G = nx.Graph()
    G.add_edges_from(edges)
    # Isolated atoms would otherwise be missing from the graph.
    G.add_nodes_from(range(len(node_classes)))

    nx.set_node_attributes(
        G,
        {k: str(c.item()) for k, c in enumerate(node_classes)},
        name="feature",
    )

    if use_edge_feature:
        # For a symmetric adjacency both (i, j) and (j, i) are present. They
        # map to the same undirected edge, so the one inserted later wins.
        nx.set_edge_attributes(
            G, {(i, j): {"bond": str(adj[i, j])} for i, j in edges}
        )

    if sorted(G.nodes()) != list(range(G.number_of_nodes())):
        raise ValueError(
            "adjacency references node indices outside the node-feature matrix"
        )

    return G


def fngw_barycentre_func(Y_tr, lambdas, alpha, beta, n_bary, N=20, random_state=42):
    """Compute an FNGW barycentre of the ``n_bary`` highest-scored training graphs.

    Parameters
    ----------
    Y_tr : sequence of dict
        Training graphs. Each entry must provide ``"F"``, ``"A"`` and ``"E"``
        arrays, which are forwarded to ``fngw_barycenters`` as ``Fs``, ``As``
        and ``Cs``.
    lambdas : array-like of float
        One score per entry of ``Y_tr``; the ``n_bary`` largest select the
        candidate graphs.
    alpha, beta : float
        Forwarded unchanged to ``fngw_barycenters``.
    n_bary : int
        Number of candidate graphs to combine; must be >= 1.
    N : int, default 20
        Forwarded to ``fngw_barycenters`` (presumably the barycentre's node
        count -- confirm against fngw).
    random_state : int, default 42
        Seed forwarded to ``fngw_barycenters``. The default matches the value
        previously hard-coded here.

    Returns
    -------
    dict
        ``"F"``, ``"A"``, ``"E"``: barycentre arrays returned by the solver.
        ``"cands"``: selected indices into ``Y_tr``, in ascending score order.
        ``"weights"``: normalised weights used for each candidate.
        ``"Ts"``: ``log["T"]`` from the solver (presumably the transport
        plans -- confirm).

    Raises
    ------
    ValueError
        If ``n_bary < 1``. Without the check, ``[-0:]`` would silently select
        every candidate.
    """
    if n_bary < 1:
        raise ValueError(f"n_bary must be >= 1, got {n_bary}")

    # argsort is ascending, so the tail holds the n_bary largest scores.
    idxs = np.argsort(lambdas)[-n_bary:]
    # Negative scores are clipped to 0. The epsilon keeps the sum strictly
    # positive so the normalisation below never divides by zero.
    lambdas_nbary = np.clip([lambdas[idx] for idx in idxs], 0, 1e8) + 1e-12
    lambdas_nbary /= np.sum(lambdas_nbary)
    ys = [Y_tr[idx] for idx in idxs]

    Fs = [y["F"] for y in ys]
    As = [y["A"] for y in ys]
    Es = [y["E"] for y in ys]

    # Uniform node weights for each candidate graph.
    ps = [unif(A.shape[0]) for A in As]

    F_bary, A_bary, E_bary, log = fngw_barycenters(
        N=N,
        Fs=Fs,
        As=As,
        Cs=Es,
        ps=ps,
        lambdas=lambdas_nbary,
        alpha=alpha,
        beta=beta,
        random_state=random_state,
        log=True,
    )

    return {
        "F": F_bary,
        "A": A_bary,
        "E": E_bary,
        "cands": idxs,
        "weights": lambdas_nbary,
        "Ts": log["T"],
    }

def eval_graph(G_preds, G_trgts, with_edge_feature=True):
    """Compute graph edit distances between paired predicted and target graphs.

    Nodes match when their ``"feature"`` attributes are equal. When
    ``with_edge_feature`` is True, edges must also have equal ``"bond"``
    attributes; otherwise edge attributes are ignored.

    Parameters
    ----------
    G_preds, G_trgts : iterable of networkx.Graph
        Predicted and target graphs, paired by position.
    with_edge_feature : bool
        Whether bond types take part in edge matching.

    Returns
    -------
    dict
        ``"edit_distance"``: mean GED (NaN when there are no pairs).
        ``"eds"``: list of per-pair GEDs.

    Raises
    ------
    ValueError
        If the two inputs have different lengths. Previously ``zip``
        truncated them silently, which biased the mean.
    """
    G_preds = list(G_preds)
    G_trgts = list(G_trgts)
    if len(G_preds) != len(G_trgts):
        raise ValueError(
            f"got {len(G_preds)} predictions but {len(G_trgts)} targets"
        )

    def node_match(a, b):
        return a["feature"] == b["feature"]

    def bond_match(a, b):
        return a["bond"] == b["bond"]

    # None is networkx's default, meaning edge attributes are not compared.
    edge_match = bond_match if with_edge_feature else None

    geds = [
        nx.graph_edit_distance(
            G_pred, G_trgt, node_match=node_match, edge_match=edge_match
        )
        for G_pred, G_trgt in zip(G_preds, G_trgts)
    ]

    return {
        # Avoid numpy's "Mean of empty slice" warning on empty input.
        "edit_distance": np.mean(geds) if geds else np.nan,
        "eds": geds,
    }

def eval_graph_parallel(G_preds, G_trgts, with_edge_feature=True, n_jobs=64):
    """Parallel (joblib) version of pairwise graph-edit-distance evaluation.

    Nodes match when their ``"feature"`` attributes are equal. When
    ``with_edge_feature`` is True, edges must also have equal ``"bond"``
    attributes; otherwise edge attributes are ignored.

    Parameters
    ----------
    G_preds, G_trgts : iterable of networkx.Graph
        Predicted and target graphs, paired by position.
    with_edge_feature : bool
        Whether bond types take part in edge matching.
    n_jobs : int
        Number of joblib workers.

    Returns
    -------
    dict
        ``"edit_distance"``: mean GED (NaN when there are no pairs).
        ``"eds"``: list of per-pair GEDs, in input order.

    Raises
    ------
    ValueError
        If the two inputs have different lengths. Previously ``zip``
        truncated them silently, which biased the mean.
    """
    G_preds = list(G_preds)
    G_trgts = list(G_trgts)
    if len(G_preds) != len(G_trgts):
        raise ValueError(
            f"got {len(G_preds)} predictions but {len(G_trgts)} targets"
        )

    def node_match(a, b):
        return a["feature"] == b["feature"]

    def bond_match(a, b):
        return a["bond"] == b["bond"]

    # None is networkx's default, meaning edge attributes are not compared.
    edge_match = bond_match if with_edge_feature else None

    if G_preds:
        geds = Parallel(n_jobs=n_jobs)(
            delayed(nx.graph_edit_distance)(
                G_pred, G_trgt, node_match=node_match, edge_match=edge_match
            )
            for G_pred, G_trgt in zip(G_preds, G_trgts)
        )
    else:
        # Nothing to compare: skip starting a worker pool.
        geds = []

    return {
        # Avoid numpy's "Mean of empty slice" warning on empty input.
        "edit_distance": np.mean(geds) if geds else np.nan,
        "eds": geds,
    }


def tab_printer(args):
    """Render an object's attributes as a two-column "Parameter / Value" text table.

    Parameters
    ----------
    args : object with a ``__dict__`` (e.g. an ``argparse.Namespace``)
        Its attributes are listed in alphabetical order of their names. Each
        name has underscores replaced with spaces and is capitalised.

    Returns
    -------
    str
        The table as drawn by ``Texttable.draw()``, with float precision 10.
    """
    params = vars(args)
    table = Texttable()
    table.set_precision(10)
    header = ["Parameter", "Value"]
    body = [
        [name.replace("_", " ").capitalize(), params[name]]
        for name in sorted(params)
    ]
    table.add_rows([header, *body])
    return table.draw()
