import sys
import os
sys.path.append((os.path.abspath('../utils')))
import numpy as np
from grakel import Graph
import networkx as nx
from Torelli import *
import mgraph as mg
import time

# create synthetic graphs

def generate_random_graph(n, max_edges=None):
    """
    Generate a random connected weighted graph with n nodes.

    The graph starts as the path 0-1-...-(n-1), which makes it connected, and
    then receives extra edges between randomly chosen pairs of distinct nodes
    until it has max_edges edges. Every edge weight is np.random.rand()
    rounded to two decimals.

    n: Number of nodes.
    max_edges: Target number of edges. If None (or 0), defaults to n, capped
        at the number of distinct node pairs n*(n-1)/2. If it is smaller than
        the n-1 edges of the initial path, no extra edge is added.

    Returns the populated mg.MetricGraph.

    Raises ValueError if max_edges is larger than n*(n-1)/2: the sampling
    loop below never adds self-loops or repeated edges, so such a target can
    never be reached and the loop would spin forever.
    """

    # Validate the target before drawing any random numbers.
    max_possible_edges = n * (n - 1) // 2
    if not max_edges:
        # Falsy values (None or 0) fall back to the default of n edges.
        max_edges = min(n, max_possible_edges)
    elif max_edges > max_possible_edges:
        raise ValueError(
            f"max_edges={max_edges} exceeds the {max_possible_edges} distinct "
            f"node pairs available in a graph with {n} nodes"
        )

    # initialize the graph
    # assumes mg.MetricGraph follows the networkx Graph API — TODO confirm
    G = mg.MetricGraph()
    G.add_nodes_from(range(n))

    # create a path graph so that the result is connected
    for i in range(n - 1):
        # round the weight for better visualization
        # NOTE(review): rounding can yield a weight of exactly 0.0; confirm
        # that downstream consumers accept zero-weight edges.
        weight = round(np.random.rand(), 2)
        G.add_edge(i, i + 1, weight=weight)

    # Add additional random edges until the target edge count is reached.
    while G.number_of_edges() < max_edges:
        # replace=False guarantees two distinct endpoints (no self-loops)
        node1, node2 = np.random.choice(n, size=2, replace=False)
        if not G.has_edge(node1, node2):
            weight = round(np.random.rand(), 2)
            G.add_edge(node1, node2, weight=weight)

    return G

def to_grakel_graphs(class_list):
    '''
    Convert graphs grouped by class into one flat list of grakel graphs.

    class_list: list of lists of graphs, each inner list corresponds to a class

    Returns a list with one grakel Graph per input graph, in iteration order
    (classes in order, graphs in order within each class). Each Graph is built
    from the dense adjacency matrix whose entries are the 'weight' edge
    attribute.
    '''

    return [
        Graph(nx.adjacency_matrix(graph, weight='weight').toarray())
        for graphs_in_class in class_list
        for graph in graphs_in_class
    ]

# Benchmark configuration.
node_list = np.arange(50, 151, 10) # number of nodes in the graph
N = 5    # graphs generated per repetition
rep = 5  # repetitions per (n, coefficient) setting

s_coeff_list = [5, 10, 20]
ss_coeff_list = [1.0, 1.5, 2.0]
d_coeff_list = [0.1, 0.2, 0.3]

# One entry per density regime: (key prefix, coefficients, g as a function of
# the node count n and the coefficient c). Graphs are generated with
# n + g - 1 edges and the kernels are built with dimbound=g.
_REGIMES = (
    ("s", s_coeff_list, lambda n, c: c),                        # sparse
    ("ss", ss_coeff_list, lambda n, c: round(c * n)),           # semi-sparse
    ("d", d_coeff_list, lambda n, c: round(2 * n ** (1 + c))),  # dense
)
_KINDS = ("f", "e", "w")
_STATS = ("mean", "std")


def _time_kernels(n, g):
    """
    Time the two Torelli kernels on freshly generated random graphs.

    Runs `rep` repetitions; each one generates N graphs with n nodes and
    n + g - 1 edges and measures three intervals with time.perf_counter().

    Returns a dict of lists, one sample per repetition:
        "f": duration of TorelliEuclidean.fit
        "e": duration of TorelliEuclidean.transform
        "w": duration of TorelliWasserstein.transform
    """
    samples = {kind: [] for kind in _KINDS}
    for _ in range(rep):
        # generate N graphs
        G_list = [generate_random_graph(n, n + g - 1) for _ in range(N)]
        grakel_G = to_grakel_graphs([G_list])
        kernel_w = TorelliWasserstein(dimbound=g)
        kernel_e = TorelliEuclidean(dimbound=g)
        # fit time
        t0 = time.perf_counter()
        kernel_e.fit(grakel_G)
        t1 = time.perf_counter()
        # NOTE(review): the Wasserstein fit (t1 -> t2) is executed but its
        # duration is not recorded; only the Euclidean fit is reported as
        # the fit time — confirm this is intended.
        kernel_w.fit(grakel_G)
        t2 = time.perf_counter()
        # transform time
        kernel_e.transform(grakel_G)
        t3 = time.perf_counter()
        kernel_w.transform(grakel_G)
        t4 = time.perf_counter()
        samples["f"].append(t1 - t0)
        samples["e"].append(t3 - t2)
        samples["w"].append(t4 - t3)
    return samples


# results["<regime>_<kind>_time_<stat>"][i][j] is the statistic for
# node_list[i] and the j-th coefficient of that regime.
results = {
    f"{prefix}_{kind}_time_{stat}": []
    for prefix, _, _ in _REGIMES
    for kind in _KINDS
    for stat in _STATS
}

for n in node_list:
    for prefix, coeff_list, g_of in _REGIMES:
        # One row per (n, regime): a value for every coefficient.
        row = {
            (kind, stat): [] for kind in _KINDS for stat in _STATS
        }
        for coeff in coeff_list:
            samples = _time_kernels(n, g_of(n, coeff))
            for kind in _KINDS:
                row[(kind, "mean")].append(np.mean(samples[kind]))
                row[(kind, "std")].append(np.std(samples[kind]))
        for (kind, stat), values in row.items():
            results[f"{prefix}_{kind}_time_{stat}"].append(values)

path = os.path.abspath('../results/time_test/time.npz')
# Create the output directory up front so a missing folder cannot discard the
# measurements after the whole benchmark has already run.
os.makedirs(os.path.dirname(path), exist_ok=True)
np.savez(path, **results)