import numpy as np
import networkx as nx
import os
from scipy.linalg import eigvalsh
import torch
import matplotlib.pyplot as plt
from datetime import datetime
import secrets
from string import ascii_uppercase, digits
import subprocess as sp
import random

def generate_samples(model, input_graph, num_tokens, n=10, device="cuda", save=False):
    """
    Generate graphs from a trained sequence model.

    For each of ``n`` samples a start node id is drawn uniformly from
    ``[0, input_graph.number_of_nodes())`` and passed to ``model.generate`` as
    a (1, 1) long tensor; the returned token sequence is turned into a graph
    with ``series2graph``. Requested generation lengths are spaced evenly from
    ``num_tokens`` to ``num_tokens + 200``.

    params:
        model: object with ``generate(context, max_new_tokens=...)`` whose
            result supports ``[0].tolist()``.
        input_graph (nx.Graph): training graph; only its node count is used.
        num_tokens (int): shortest requested generation length.
        n (int): number of graphs to generate.
        device (str): torch device for the start-token tensor.
        save (bool): if True, draw each generated graph with a Kamada-Kawai
            layout and save it as ``./generated_figures/<HH-MM>_<i>.png``.
    return:
        list of ``n`` graphs. If ``input_graph`` has more than 1000 nodes,
        each returned graph is instead the composition (``nx.compose_all``)
        of up to 4 randomly chosen generated graphs.
    """
    current_time = datetime.now().strftime('%H-%M')
    print("Number of tokens  :", num_tokens)

    if save:
        # plt.savefig does not create missing directories.
        os.makedirs("./generated_figures", exist_ok=True)

    sizes = np.linspace(num_tokens, num_tokens + 200, n, dtype=int)
    generated_graphs = []
    for i, size in enumerate(sizes):
        # NOTE(review): assumes input_graph's nodes are the model's token ids 0..N-1 — confirm.
        start = np.random.randint(0, input_graph.number_of_nodes())
        context = torch.tensor([[start]], dtype=torch.long, device=device)
        series = model.generate(context, max_new_tokens=int(size))[0].tolist()
        generated_graph = series2graph(series)
        generated_graphs.append(generated_graph)

        if save:
            plt.figure()
            nx.draw_kamada_kawai(generated_graph, node_size=20)
            plt.savefig(f"./generated_figures/{current_time}_{i}.png")
            plt.close()

    if input_graph.number_of_nodes() > 1000:
        # Presumably unions several samples so each is closer in size to the
        # large input — confirm. random.sample raises ValueError if asked for
        # more items than exist, so cap the group size for n < 4.
        group_size = min(4, len(generated_graphs))
        generated_graphs = [
            nx.compose_all(random.sample(generated_graphs, group_size))
            for _ in range(n)
        ]

    return generated_graphs

def series2graph(series):
    """
    Build an undirected graph from a walk-like sequence of node ids.

    Every pair of consecutive, distinct elements becomes an edge; consecutive
    repeats (which would be self-loops) are ignored. Nodes enter the graph
    only through edges, so a sequence with fewer than two distinct adjacent
    elements yields an empty graph.

    params:
        series (list): sequence of node ids.
    return:
        generated (nx.Graph): networkx graph object
    """
    graph = nx.Graph()
    for current, following in zip(series, series[1:]):
        if current != following:
            graph.add_edge(current, following)
    return graph


def mmd_evaluation(original_graph, generated_graphs):
    """
    Compute and print MMD statistics between reference and generated graphs.

    params:
        original_graph (nx.Graph or list of nx.Graph): reference graph(s); a
            single graph is wrapped in a one-element list.
        generated_graphs (list of nx.Graph): graphs to evaluate.
    return:
        tuple (degree, clustering, spectral, orbit, motif) as produced by
        ``mmd``; the same values are printed one per line.
    """
    if not isinstance(original_graph, list):
        original_graph = [original_graph]
    mmd_stats = mmd(original_graph, generated_graphs)

    # Labels follow mmd()'s return order.
    labels = ("Degree", "Clustering", "Spectral", "Orbit", "Motif")
    for label, value in zip(labels, mmd_stats):
        print(f"{label}: {value}")
    return mmd_stats

def mmd(graph_gt, graph_pred, plot=False):
    """
    Compute all MMD statistics between two graph lists.

    Statistics are evaluated in the order degree, clustering, spectral,
    motif, orbit (which fixes the order of any plots), but returned as
    (degree, clustering, spectral, orbit, motif).
    """
    evaluators = (
        ("degree", degree_stats),
        ("clustering", clustering_stats),
        ("spectral", spectral_stats),
        ("motif", motif_stats),
        ("orbit", orbit_stats_all),
    )
    results = {}
    for key, evaluate in evaluators:
        results[key] = evaluate(graph_gt, graph_pred, plot=plot)
    return (
        results["degree"],
        results["clustering"],
        results["spectral"],
        results["orbit"],
        results["motif"],
    )


def degree_stats(graph_ref_list, graph_pred_list, plot):
    """
    MMD between degree histograms of reference and generated graphs.

    Each graph contributes ``nx.degree_histogram`` as a numpy array;
    ``compute_mmd`` normalizes them and compares with ``gaussian_tv``.
    """
    def _degree_hists(graphs):
        return [np.array(nx.degree_histogram(G)) for G in graphs]

    return compute_mmd(
        _degree_hists(graph_ref_list),
        _degree_hists(graph_pred_list),
        kernel=gaussian_tv,
        plot=plot,
    )


def clustering_stats(graph_ref_list, graph_pred_list, plot, bins=100):
    """
    MMD between clustering-coefficient histograms of two graph lists.

    Each graph's per-node clustering coefficients are binned into ``bins``
    equal-width bins over [0, 1]; the histograms are compared with
    ``gaussian_tv`` using sigma = 0.1.
    """
    def _clustering_hist(G):
        coefficients = list(nx.clustering(G).values())
        counts, _edges = np.histogram(coefficients, bins=bins, range=(0.0, 1.0), density=False)
        return counts

    ref_hists = [_clustering_hist(G) for G in graph_ref_list]
    pred_hists = [_clustering_hist(G) for G in graph_pred_list]
    return compute_mmd(ref_hists, pred_hists, kernel=gaussian_tv, sigma=1.0 / 10, plot=plot)

# Motif name -> ORCA orbit columns whose per-row counts motif_stats() sums.
# NOTE(review): presumably orbits 1/2 are the end/middle positions of a
# 3-node path and orbit 8 is the 4-cycle orbit in ORCA numbering — confirm.
motif_to_indices = dict([
    ('3path', [1, 2]),
    ('4cycle', [8]),
])

# Header in ORCA's stdout; orca() parses the count rows that follow it.
COUNT_START_STR = 'orbit counts:'

def _motif_densities(graphs, indices):
    """Return a (len(graphs), 1) array: summed orbit counts at ``indices`` divided by node count, per graph."""
    densities = []
    for G in graphs:
        orbit_counts = orca(G)
        # Sum the selected orbit columns per row, then over all rows.
        motif_counts = np.sum(orbit_counts[:, indices], axis=1)
        densities.append(np.sum(motif_counts) / G.number_of_nodes())
    return np.array(densities)[:, None]


def motif_stats(graph_ref_list, graph_pred_list, plot, motif_type='3path', ground_truth_match=None):
    """
    MMD between per-graph motif densities of reference and generated graphs.

    params:
        graph_ref_list (list of nx.Graph): reference graphs.
        graph_pred_list (list of nx.Graph): generated graphs; those with no
            nodes are skipped.
        plot (bool): forwarded to ``compute_mmd``.
        motif_type (str): key of ``motif_to_indices`` selecting the orbit
            columns to sum.
        ground_truth_match: has no effect on the result; kept so existing
            callers that pass it keep working.
    return:
        MMD value from ``compute_mmd`` with ``gaussian_tv`` on raw
        (non-normalized) densities.
    raises:
        KeyError: if ``motif_type`` is not in ``motif_to_indices``.
    """
    indices = motif_to_indices[motif_type]
    # Empty generated graphs would divide by zero; reference graphs are used as given.
    non_empty_pred = [G for G in graph_pred_list if G.number_of_nodes() != 0]

    ref_densities = _motif_densities(graph_ref_list, indices)
    pred_densities = _motif_densities(non_empty_pred, indices)
    return compute_mmd(ref_densities, pred_densities, kernel=gaussian_tv, is_hist=False, plot=plot)

def orbit_stats_all(graph_ref_list, graph_pred_list, plot):
    """
    MMD between average ORCA orbit-count vectors of two graph lists.

    Each graph is summarized by its ORCA count matrix summed over rows and
    divided by its node count. Generated graphs with no nodes are skipped.
    Vectors are compared with ``gaussian_tv`` (sigma = 30) without
    normalization.
    """
    def _mean_orbit_counts(G):
        return np.sum(orca(G), axis=0) / G.number_of_nodes()

    non_empty_pred = [G for G in graph_pred_list if G.number_of_nodes() != 0]
    ref_vectors = np.array([_mean_orbit_counts(G) for G in graph_ref_list])
    pred_vectors = np.array([_mean_orbit_counts(G) for G in non_empty_pred])

    return compute_mmd(
        ref_vectors,
        pred_vectors,
        kernel=gaussian_tv,
        is_hist=False,
        sigma=30.0,
        plot=plot,
    )



def spectral_worker(G):
    """
    Return the normalized-Laplacian eigenvalue distribution of G.

    Eigenvalues are binned into 200 bins over [-1e-5, 2] and divided by the
    total count, so the result sums to 1. A zero total would produce NaNs.
    """
    laplacian = nx.normalized_laplacian_matrix(G).todense()
    eigenvalues = eigvalsh(laplacian)
    counts, _bin_edges = np.histogram(eigenvalues, bins=200, range=(-1e-5, 2), density=False)
    total = counts.sum()
    return counts / total

def spectral_stats(graph_ref_list, graph_pred_list, plot):
    """MMD between Laplacian-spectrum histograms (see ``spectral_worker``) of two graph lists."""
    ref_spectra = list(map(spectral_worker, graph_ref_list))
    pred_spectra = list(map(spectral_worker, graph_pred_list))
    return compute_mmd(ref_spectra, pred_spectra, kernel=gaussian_tv, plot=plot)


# edge_list_reindexed() is defined once, after orca(); the identical copy that
# stood here was dead code because that later definition rebinds the name.

def _plot_sample_histograms(samples1, samples2):
    """Show reference and generated samples as side-by-side histograms."""
    _fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)
    left, right = axes[0], axes[1]

    left.hist(samples1, bins=30, alpha=0.7, label='Original')
    left.set_xlabel('Value')
    left.set_ylabel('Frequency')
    left.set_title('Original Histogram')
    left.legend()

    right.hist(samples2, bins=30, alpha=0.7, label='Generated')
    right.set_xlabel('Value')
    right.set_title('Generated Histogram')
    right.legend()

    plt.tight_layout()
    plt.show()


def compute_mmd(samples1, samples2, kernel, is_hist=True, plot=False, *args, **kwargs):
    """
    Squared-MMD estimate between two sample sets under ``kernel``.

    Computes mean k(s1, s1') + mean k(s2, s2') - 2 * mean k(s1, s2), where
    each mean runs over all pairs (including each sample with itself).

    params:
        samples1, samples2: sequences of numpy arrays.
        kernel: callable ``kernel(a, b, *args, **kwargs)``.
        is_hist (bool): if True, each sample is first divided by its sum.
        plot (bool): if True, show both sample sets as histograms.
        *args, **kwargs: forwarded to ``kernel`` (e.g. ``sigma``).
    """
    if is_hist:
        samples1 = [sample / np.sum(sample) for sample in samples1]
        samples2 = [sample / np.sum(sample) for sample in samples2]

    if plot:
        _plot_sample_histograms(samples1, samples2)

    within_first = disc(samples1, samples1, kernel, *args, **kwargs)
    within_second = disc(samples2, samples2, kernel, *args, **kwargs)
    across = disc(samples1, samples2, kernel, *args, **kwargs)
    return within_first + within_second - 2 * across


def disc(samples1, samples2, kernel, *args, **kwargs):
    """
    Mean of ``kernel(a, b, *args, **kwargs)`` over every pair (a, b) in
    samples1 x samples2. Raises ZeroDivisionError if either set is empty.
    """
    pair_total = sum(
        kernel(a, b, *args, **kwargs)
        for a in samples1
        for b in samples2
    )
    return pair_total / (len(samples1) * len(samples2))


def gaussian_tv(x, y, sigma=1.0):
    """
    Gaussian kernel on the total-variation distance between two vectors.

    Inputs are cast to float32 and the shorter one is zero-padded to the
    longer length. With tv = 0.5 * sum(|x - y|), returns
    exp(-tv^2 / (2 * sigma^2)).
    """
    a = x.astype(np.float32)
    b = y.astype(np.float32)
    gap = len(a) - len(b)
    if gap > 0:
        b = np.hstack((b, [0.0] * gap))
    elif gap < 0:
        a = np.hstack((a, [0.0] * -gap))

    tv = np.abs(a - b).sum() / 2.0
    return np.exp(-tv * tv / (2 * sigma * sigma))


def orca(graph):
    """
    Count graphlet orbits with the external ORCA binary.

    The graph is written to a uniquely named temporary file in the ``orca``
    directory next to this module. The first line is
    ``<num_nodes> <num_edges>``, followed by one ``u v`` line per edge, with
    nodes renumbered by ``edge_list_reindexed``. Then
    ``orca/orca node 4 <file> std`` is run. The lines after the
    ``COUNT_START_STR`` header in its stdout are parsed as rows of integers.
    The temporary file is removed whether or not ORCA succeeds.

    params:
        graph (nx.Graph): graph to analyse.
    return:
        np.ndarray of ints, one row per count line in ORCA's output
        (presumably one row per node and one column per orbit — confirm
        against the ORCA documentation).
    raises:
        subprocess.CalledProcessError: if the ORCA binary exits non-zero.
        ValueError: if ORCA's output does not contain ``COUNT_START_STR``.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    token = "".join(secrets.choice(ascii_uppercase + digits) for _ in range(8))
    tmp_fname = os.path.join(module_dir, 'orca', f'tmp_{token}.txt')
    orca_binary = os.path.join(module_dir, 'orca', 'orca')

    try:
        with open(tmp_fname, 'w') as f:
            f.write(f"{graph.number_of_nodes()} {graph.number_of_edges()}\n")
            f.writelines(f"{u} {v}\n" for u, v in edge_list_reindexed(graph))
        output = sp.check_output([orca_binary, 'node', '4', tmp_fname, 'std'])
    finally:
        # Best-effort cleanup: also runs when writing or ORCA fails, so temp files don't accumulate.
        try:
            os.remove(tmp_fname)
        except OSError:
            pass

    text = output.decode('utf8').strip()
    header_pos = text.find(COUNT_START_STR)
    if header_pos == -1:
        raise ValueError(f"ORCA output does not contain {COUNT_START_STR!r}")
    # Drop the remainder of the header line; each following non-empty line is one row.
    body = text[header_pos + len(COUNT_START_STR):].partition('\n')[2]
    node_orbit_counts = np.array([
        list(map(int, line.split()))
        for line in body.splitlines()
        if line.strip()
    ])
    return node_orbit_counts

def edge_list_reindexed(G):
    """
    Return G's edges with nodes renumbered to consecutive integers.

    Nodes are numbered 0..N-1 in the order ``G.nodes()`` yields them.
    This keeps indices consistent with the node count ``orca()`` writes
    into its input file.

    params:
        G: graph exposing ``nodes()`` and ``edges()`` (e.g. nx.Graph);
           nodes must be hashable.
    return:
        list of (int, int) tuples, one per edge, in ``G.edges()`` order.
    """
    # Key by the node itself, not str(node): distinct nodes such as 1 and '1'
    # share a string form and would otherwise collapse onto one index.
    index_of = {node: idx for idx, node in enumerate(G.nodes())}
    return [(index_of[u], index_of[v]) for u, v in G.edges()]