import networkx as nx
import numpy as np
from scipy.linalg import eigvalsh
from string import ascii_uppercase, digits
import os
import secrets
import subprocess as sp


def evaluate_model(original_graph, generated_graphs, metrics_list):
    """Compute the requested graph statistics and collect them in a dict.

    Args:
        original_graph: Reference graphs. Despite the singular name, this
            value is forwarded as the ``graph_ref_list`` argument of each
            ``*_stats`` function, which iterate over it as a collection.
        generated_graphs: Generated graphs to compare against the reference.
        metrics_list: Container of metric names to compute. Recognised names
            are ``"degree"``, ``"clustering"``, ``"spectral"``, ``"orbit"``
            and ``"motif"``; any other entry is ignored.

    Returns:
        dict mapping each requested (and recognised) metric name to its
        value as a plain Python ``float``.
    """
    to_log = {}

    # float() is used rather than ``.item()`` because the statistic may be a
    # plain Python float (e.g. when both graph lists are empty and ``disc``
    # returns its sentinel), and plain floats have no ``.item()`` method.
    if "degree" in metrics_list:
        to_log["degree"] = float(degree_stats(original_graph, generated_graphs))

    if "clustering" in metrics_list:
        to_log["clustering"] = float(clustering_stats(original_graph, generated_graphs))

    if 'spectral' in metrics_list:
        to_log["spectral"] = float(spectral_stats(original_graph, generated_graphs))

    if 'orbit' in metrics_list:
        to_log["orbit"] = float(orbit_stats_all(original_graph, generated_graphs))

    if 'motif' in metrics_list:
        to_log["motif"] = float(motif_stats(original_graph, generated_graphs))

    return to_log


def degree_stats(graph_ref_list, graph_pred_list):
    """Compare degree histograms of reference and generated graphs.

    Generated graphs with zero nodes are dropped before comparison; the
    reference graphs are used as given.

    Args:
        graph_ref_list: Sequence of reference networkx graphs.
        graph_pred_list: Iterable of generated networkx graphs.

    Returns:
        The value of ``compute_mmd`` on the two sets of degree histograms,
        using the ``gaussian_tv`` kernel.
    """
    non_empty_pred = [g for g in graph_pred_list if g.number_of_nodes() != 0]

    ref_histograms = [np.array(nx.degree_histogram(g)) for g in graph_ref_list]
    pred_histograms = [np.array(nx.degree_histogram(g)) for g in non_empty_pred]

    return compute_mmd(ref_histograms, pred_histograms, kernel=gaussian_tv)

def clustering_stats(graph_ref_list, graph_pred_list, bins=20):
    """Compare clustering-coefficient histograms of two sets of graphs.

    For every graph the per-node values returned by ``nx.clustering`` are
    binned into ``bins`` equal-width bins over ``[0, 1]`` (raw counts, not
    densities). Generated graphs with zero nodes are dropped first.

    Args:
        graph_ref_list: Sequence of reference networkx graphs.
        graph_pred_list: Iterable of generated networkx graphs.
        bins: Number of histogram bins.

    Returns:
        The value of ``compute_mmd`` on the two sets of histograms, using
        the ``gaussian_tv`` kernel.
    """

    def coefficient_histogram(graph):
        # Per-node clustering coefficients -> count histogram on [0, 1].
        coefficients = list(nx.clustering(graph).values())
        counts, _bin_edges = np.histogram(
            coefficients, bins=bins, range=(0.0, 1.0), density=False
        )
        return counts

    non_empty_pred = [g for g in graph_pred_list if g.number_of_nodes() != 0]

    ref_histograms = [coefficient_histogram(g) for g in graph_ref_list]
    pred_histograms = [coefficient_histogram(g) for g in non_empty_pred]

    return compute_mmd(ref_histograms, pred_histograms, kernel=gaussian_tv)

def motif_stats(graph_ref_list, graph_pred_list, motif_type='3path', ground_truth_match=None):
    """Compare the average per-node motif count of two sets of graphs.

    For every graph, the orbit columns listed in
    ``motif_to_indices[motif_type]`` are summed over all nodes of the
    ``orca`` output and divided by the number of nodes. The resulting scalars
    (one per graph) are compared with ``compute_mmd`` using the ``gaussian``
    kernel and ``is_hist=False``. Generated graphs with zero nodes are
    dropped first.

    Args:
        graph_ref_list: Iterable of reference networkx graphs.
        graph_pred_list: Iterable of generated networkx graphs.
        motif_type: Key into ``motif_to_indices`` selecting the orbit columns.
        ground_truth_match: Accepted for backward compatibility. It never
            influenced the returned value: earlier code only derived
            match fractions from it that were then discarded, so that
            dead computation has been removed.

    Returns:
        The value returned by ``compute_mmd``.

    Raises:
        KeyError: If ``motif_type`` is not a key of ``motif_to_indices``.
    """
    graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0]
    indices = motif_to_indices[motif_type]

    total_counts_ref = _motif_count_per_node(graph_ref_list, indices)
    total_counts_pred = _motif_count_per_node(graph_pred_list_remove_empty, indices)

    return compute_mmd(total_counts_ref, total_counts_pred, kernel=gaussian, is_hist=False)


def _motif_count_per_node(graphs, indices):
    """Return an ``(n_graphs, 1)`` array of summed motif counts per node."""
    averages = []
    for G in graphs:
        orbit_counts = orca(G)
        # Sum the selected orbit columns over every node, then normalise by
        # graph size so graphs of different sizes are comparable.
        averages.append(np.sum(orbit_counts[:, indices]) / G.number_of_nodes())
    # Trailing axis turns each scalar into a length-1 sample vector.
    return np.array(averages)[:, None]


def orbit_stats_all(graph_ref_list, graph_pred_list):
    """Compare average per-node orbit counts of two sets of graphs.

    Each graph is summarised by the column sums of its ``orca`` count matrix
    divided by its number of nodes. Generated graphs with zero nodes are
    dropped before they are passed to ``orca``.

    Args:
        graph_ref_list: Iterable of reference networkx graphs.
        graph_pred_list: Iterable of generated networkx graphs.

    Returns:
        The value of ``compute_mmd`` on the two arrays of per-graph vectors,
        using the ``gaussian`` kernel with ``is_hist=False``.
    """
    non_empty_pred = [g for g in graph_pred_list if g.number_of_nodes() != 0]

    ref_vectors = [
        np.sum(orca(g), axis=0) / g.number_of_nodes() for g in graph_ref_list
    ]
    pred_vectors = [
        np.sum(orca(g), axis=0) / g.number_of_nodes() for g in non_empty_pred
    ]

    return compute_mmd(
        np.array(ref_vectors),
        np.array(pred_vectors),
        kernel=gaussian,
        is_hist=False,
    )


def spectral_stats(graph_ref_list, graph_pred_list, n_eigvals=-1):
    """Compare eigenvalue histograms of reference and generated graphs.

    Every graph is summarised by ``spectral_worker``. Generated graphs with
    zero nodes are dropped first.

    Args:
        graph_ref_list: Sequence of reference networkx graphs.
        graph_pred_list: Iterable of generated networkx graphs.
        n_eigvals: Forwarded unchanged to ``spectral_worker``.

    Returns:
        The value of ``compute_mmd`` on the two sets of spectral histograms,
        using the ``gaussian_tv`` kernel.
    """
    non_empty_pred = [g for g in graph_pred_list if g.number_of_nodes() != 0]

    ref_spectra = [spectral_worker(g, n_eigvals) for g in graph_ref_list]
    pred_spectra = [spectral_worker(g, n_eigvals) for g in non_empty_pred]

    return compute_mmd(ref_spectra, pred_spectra, kernel=gaussian_tv)



def compute_mmd(samples1, samples2, kernel, is_hist=True, *args, **kwargs):
    """Combine kernel averages into a discrepancy between two sample sets.

    The result is ``disc(s1, s1) + disc(s2, s2) - 2 * disc(s1, s2)``.

    Args:
        samples1: First collection of sample vectors.
        samples2: Second collection of sample vectors.
        kernel: Callable ``kernel(a, b, *args, **kwargs)``.
        is_hist: When true, every sample is first divided by its sum plus
            ``1e-6`` so that it is (approximately) normalised.
        *args, **kwargs: Extra arguments forwarded to ``kernel``.

    Returns:
        The combined discrepancy value.
    """
    if is_hist:
        # The small epsilon keeps an all-zero histogram from dividing by zero.
        samples1 = [hist / (np.sum(hist) + 1e-6) for hist in samples1]
        samples2 = [hist / (np.sum(hist) + 1e-6) for hist in samples2]

    within_first = disc(samples1, samples1, kernel, *args, **kwargs)
    within_second = disc(samples2, samples2, kernel, *args, **kwargs)
    across = disc(samples1, samples2, kernel, *args, **kwargs)

    return within_first + within_second - 2 * across


def disc(samples1, samples2, kernel, *args, **kwargs):
    """Average ``kernel`` over every pair drawn from the two collections.

    Args:
        samples1: First sized collection of samples.
        samples2: Second sized collection of samples.
        kernel: Callable ``kernel(a, b, *args, **kwargs)``.
        *args, **kwargs: Extra arguments forwarded to ``kernel``.

    Returns:
        The mean kernel value over all ``len(samples1) * len(samples2)``
        pairs, or the sentinel ``1e+6`` when either collection is empty.
    """
    pair_count = len(samples1) * len(samples2)
    if pair_count <= 0:
        # No pairs to average over: report the fixed sentinel instead.
        return 1e+6

    total = 0
    for first in samples1:
        for second in samples2:
            total += kernel(first, second, *args, **kwargs)

    total /= pair_count
    return total

def gaussian_tv(x, y, sigma=1.0):
    """Gaussian kernel on half the L1 distance between two vectors.

    The shorter vector is zero-padded on the right so both have the same
    length, then ``d = sum(|x - y|) / 2`` is plugged into
    ``exp(-d**2 / (2 * sigma**2))``.

    Args:
        x: 1-D array-like (ndarray, list or tuple of numbers).
        y: 1-D array-like; may differ in length from ``x``.
        sigma: Kernel bandwidth.

    Returns:
        The kernel value as a numpy float.
    """
    # np.asarray lets plain sequences be used as well as ndarrays.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    support_size = max(len(x), len(y))
    if len(x) < support_size:
        x = np.hstack((x, [0.0] * (support_size - len(x))))
    elif len(y) < support_size:
        y = np.hstack((y, [0.0] * (support_size - len(y))))

    dist = np.abs(x - y).sum() / 2.0
    return np.exp(-dist * dist / (2 * sigma * sigma))

def gaussian(x, y, sigma=1.0):
    """Gaussian (RBF) kernel on the Euclidean distance between two vectors.

    The shorter vector is zero-padded on the right so both have the same
    length, then ``d = ||x - y||_2`` is plugged into
    ``exp(-d**2 / (2 * sigma**2))``.

    Args:
        x: 1-D array-like (ndarray, list or tuple of numbers).
        y: 1-D array-like; may differ in length from ``x``.
        sigma: Kernel bandwidth.

    Returns:
        The kernel value as a numpy float.
    """
    # np.asarray lets plain sequences be used as well as ndarrays.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    support_size = max(len(x), len(y))
    if len(x) < support_size:
        x = np.hstack((x, [0.0] * (support_size - len(x))))
    elif len(y) < support_size:
        y = np.hstack((y, [0.0] * (support_size - len(y))))

    dist = np.linalg.norm(x - y, 2)
    return np.exp(-dist * dist / (2 * sigma * sigma))

def spectral_worker(G, n_eigvals=-1):
    """Return a 100-bin normalised histogram of a graph's Laplacian spectrum.

    The eigenvalues of the normalised Laplacian of ``G`` are binned over
    ``(-1e-5, 2)`` and the counts are divided by their total.

    Args:
        G: A networkx graph.
        n_eigvals: When positive, only ``eigs[1:n_eigvals + 1]`` is
            histogrammed (the first eigenvalue is skipped); otherwise all
            eigenvalues are used.

    Returns:
        A length-100 float array. It sums to 1 when at least one eigenvalue
        falls inside the histogram range, and is all zeros otherwise.
    """
    try:
        eigs = eigvalsh(nx.normalized_laplacian_matrix(G).todense())
    except Exception:
        # Best-effort fallback: if the spectrum cannot be computed, use an
        # all-zero spectrum. ``Exception`` (not a bare ``except``) so that
        # KeyboardInterrupt / SystemExit still propagate.
        eigs = np.zeros(G.number_of_nodes())
    if n_eigvals > 0:
        eigs = eigs[1:n_eigvals + 1]
    spectral_pmf, _ = np.histogram(eigs, bins=100, range=(-1e-5, 2), density=False)

    total = spectral_pmf.sum()
    if total == 0:
        # No eigenvalue landed in a bin (empty graph or empty slice);
        # dividing would give NaN, so return an all-zero histogram instead.
        return np.zeros(len(spectral_pmf))
    return spectral_pmf / total

# Maps a motif name to the column indices of the per-node count matrix
# returned by orca(); motif_stats() sums these columns to count the motif.
# NOTE(review): the indices presumably follow ORCA's orbit numbering — confirm.
motif_to_indices = {
    '3path': [1, 2],
    '4cycle': [8],
}
# Marker searched for in the orca executable's output; orca() parses the text
# that follows it as the per-node orbit counts.
COUNT_START_STR = 'orbit counts:'

def orca(graph):
    """Run the external ``orca`` executable on a graph and parse its counts.

    The graph is written to a randomly named temporary file in the ``orca``
    directory next to this module (first line ``"<nodes> <edges>"``, then one
    ``"<u> <v>"`` line per edge with nodes re-indexed from 0). The executable
    ``orca/orca`` is invoked with the arguments ``node 4 <file> std`` and the
    text following ``COUNT_START_STR`` in its output is parsed.

    Args:
        graph: A networkx graph.

    Returns:
        A numpy integer array with one row per line of the executable's
        count output and one column per space-separated value on that line.

    Raises:
        subprocess.CalledProcessError: If the executable exits non-zero.
        ValueError: If ``COUNT_START_STR`` is missing from the output or the
            counts cannot be parsed as integers.
    """
    base_dir = os.path.dirname(os.path.realpath(__file__))
    suffix = "".join(secrets.choice(ascii_uppercase + digits) for _ in range(8))
    tmp_fname = os.path.join(base_dir, f'./orca/tmp_{suffix}.txt')
    executable = str(os.path.join(base_dir, 'orca/orca'))

    try:
        # ``with`` guarantees the file is flushed and closed before the
        # executable reads it, even if writing fails part-way.
        with open(tmp_fname, 'w') as f:
            f.write(
                str(graph.number_of_nodes()) + ' ' + str(graph.number_of_edges()) + '\n')
            for (u, v) in edge_list_reindexed(graph):
                f.write(str(u) + ' ' + str(v) + '\n')
        output = sp.check_output([executable, 'node', '4', tmp_fname, 'std'])
    finally:
        # Always clean up the temp file, including when the subprocess fails.
        try:
            os.remove(tmp_fname)
        except OSError:
            pass

    output = output.decode('utf8').strip()
    marker_pos = output.find(COUNT_START_STR)
    if marker_pos < 0:
        # Without this check a missing marker would silently parse from a
        # meaningless offset.
        raise ValueError(
            f'orca output does not contain the marker {COUNT_START_STR!r}')
    # Skip the marker plus the two characters that follow it.
    idx = marker_pos + len(COUNT_START_STR) + 2
    output = output[idx:]

    node_orbit_counts = np.array([
        list(map(int, node_cnts.strip().split(' ')))
        for node_cnts in output.strip('\n').split('\n')
    ])

    return node_orbit_counts


def edge_list_reindexed(G):
    """Return the edges of ``G`` with nodes renamed to consecutive integers.

    Nodes are numbered ``0, 1, 2, ...`` in the order produced by
    ``G.nodes()``.

    Args:
        G: Graph-like object exposing ``nodes()`` and ``edges()``; nodes must
            be hashable.

    Returns:
        A list of ``(u_index, v_index)`` tuples, one per edge, in the order
        produced by ``G.edges()``.
    """
    # Key on the node object itself rather than its string form, so that
    # distinct nodes with equal ``str()`` (e.g. 1 and '1') do not collide.
    node_to_index = {node: index for index, node in enumerate(G.nodes())}
    return [(node_to_index[u], node_to_index[v]) for (u, v) in G.edges()]