import numpy as np
import networkx as nx
import random
import seaborn as sns
import matplotlib.pyplot as plt
from dodiscover.metrics import toporder_divergence
from tqdm import tqdm
import pandas as pd
import igraph as ig
from intersort import _sort_ranking, local_search_extended, score_ordering
from diffintersort import diffintersort

def random_graph(rng, p_edge, n_vars):
    """Sample a random DAG as an integer adjacency matrix.

    Every entry of an ``n_vars x n_vars`` matrix is drawn Bernoulli(``p_edge``).
    Only the strictly-lower triangle is kept, which makes the graph acyclic.
    The node labels are then shuffled by a random permutation matrix so the
    topological order is not simply the index order.

    Parameters
    ----------
    rng : numpy.random.Generator
        Source of randomness. It is used first for the Bernoulli draws, then
        for the permutation.
    p_edge : float
        Probability of each potential (lower-triangular) edge.
    n_vars : int
        Number of nodes.

    Returns
    -------
    numpy.ndarray
        ``(n_vars, n_vars)`` integer adjacency matrix of a DAG.
    """
    shape = (n_vars, n_vars)
    coin_flips = rng.binomial(n=1, p=p_edge, size=shape).astype(int)

    # k=-1 drops the diagonal as well, so there are no self-loops.
    lower = np.tril(coin_flips, k=-1)

    # Shuffling the rows of the identity yields a permutation matrix P;
    # P^T A P relabels the nodes consistently on both axes.
    identity = np.eye(n_vars).astype(int)
    perm_matrix = rng.permutation(identity)
    return perm_matrix.T.dot(lower).dot(perm_matrix)

def create_graph(g):
    """Build a ``networkx.DiGraph`` from a dense adjacency matrix.

    Parameters
    ----------
    g : array_like, shape (n, n)
        Adjacency matrix. Any nonzero entry ``g[i, j]`` denotes the edge
        ``i -> j``. Binary 0/1 matrices behave exactly as before; nonzero
        weights other than 1 are no longer silently dropped.

    Returns
    -------
    networkx.DiGraph
        Graph on nodes ``0 .. n-1`` (plain Python ``int``) with one edge per
        nonzero entry, added in row-major order.
    """
    adjacency = np.asarray(g)
    graph = nx.DiGraph()
    graph.add_nodes_from(range(adjacency.shape[0]))
    # np.nonzero yields numpy integers; ``tolist`` turns them into plain ints
    # so node keys are a single consistent type throughout the graph.
    rows, cols = np.nonzero(adjacency)
    graph.add_edges_from(zip(rows.tolist(), cols.tolist()))
    return graph

def graph_to_mat(g):
    """Convert an igraph ``Graph`` into a dense integer adjacency matrix.

    Parameters
    ----------
    g : igraph.Graph
        Graph whose ``get_adjacency()`` result is densified.

    Returns
    -------
    numpy.ndarray
        Integer matrix built from the rows in ``g.get_adjacency().data``.
    """
    adjacency_rows = g.get_adjacency().data
    dense = np.array(adjacency_rows)
    return dense.astype(int)

def generate_scale_free_dag(rng, edges_per_var, power, n_vars):
    """Sample a scale-free directed graph with igraph and return its adjacency matrix.

    Parameters
    ----------
    rng : numpy.random.Generator
        Used for one throw-away normal draw and then for the vertex
        permutation.
    edges_per_var : int
        Passed as ``m`` to ``ig.Graph.Barabasi``.
    power : float
        Passed as ``power`` to ``ig.Graph.Barabasi``.
    n_vars : int
        Number of vertices.

    Returns
    -------
    numpy.ndarray
        Integer adjacency matrix of the vertex-permuted graph.

    NOTE(review): ``ig.Graph.Barabasi`` appears to draw from igraph's own
    RNG rather than ``rng``, so seeding ``rng`` alone may not make the graph
    structure reproducible. Confirm before relying on it.
    """
    rng.normal()  # advance ``rng`` by one draw (value intentionally unused)
    vertex_order = rng.permutation(n_vars).tolist()
    barabasi = ig.Graph.Barabasi(n=n_vars, m=edges_per_var, directed=True, power=power)
    shuffled = barabasi.permute_vertices(vertex_order)
    return graph_to_mat(shuffled)

def compute_bound(g, p_int):
    """Compute the edge-wise bound used as 'Upperbound Thm 2' in ``main``.

    For every edge ``i -> j`` of the DAG, let
    ``S = (An(j) | {j}) - An(i)``, where ``An`` is the ancestor set. The
    bound is ``sum over edges of (1 - p_int) ** |S|``.

    Parameters
    ----------
    g : array_like, shape (n, n)
        Adjacency matrix, converted with ``create_graph``.
    p_int : float
        Intervention probability.

    Returns
    -------
    float
        The summed bound (0.0 for a graph without edges).
    """
    graph = create_graph(g)

    # An(v) depends only on v. Compute it once per node instead of twice per
    # incident edge (each nx.ancestors call is a full graph traversal).
    ancestors = {node: nx.ancestors(graph, node) for node in graph.nodes}

    bound = 0.0
    for i, j in graph.edges:
        between = (ancestors[j] | {j}) - ancestors[i]
        bound += (1 - p_int) ** len(between)

    return bound

import itertools
def brute_force(score_function, score_matrix, d, rng=None):
    """Exhaustively search every ordering of ``range(d)`` for the best score.

    Parameters
    ----------
    score_function : callable
        Called as ``score_function(perm, score_matrix, d)`` for each candidate
        permutation ``perm`` (a tuple). Larger is better.
    score_matrix : object
        Passed through unchanged to ``score_function``.
    d : int
        Number of variables. All ``d!`` permutations are materialised and
        scored, so this is only feasible for small ``d``.
    rng : random.Random, optional
        Object whose ``shuffle`` method randomises the visiting order, which
        decides tie-breaking. Defaults to the module-level :mod:`random`
        generator (the previous behaviour). Pass a seeded ``random.Random``
        for reproducible ties.

    Returns
    -------
    tuple
        The highest-scoring permutation. Among equally scored permutations,
        the one visited first in the shuffled order wins.
    """
    shuffler = random if rng is None else rng
    permutations = list(itertools.permutations(range(d)))
    # Shuffle so that ties between equal-score orderings are broken randomly
    # rather than always favouring the lexicographically smallest one.
    shuffler.shuffle(permutations)
    # Start from -inf, not an arbitrary constant, so the true maximum is
    # returned even when every score is very negative.
    max_score = float("-inf")
    candidate = permutations[0]
    for perm in permutations:
        score = score_function(perm, score_matrix, d)
        if score > max_score:
            max_score = score
            candidate = perm
    return candidate

def transitive_closure(matrix):
    """Return the transitive closure of a directed adjacency matrix.

    Entry ``[i, j]`` of the result is 1 iff there is a directed path of
    length >= 1 from ``i`` to ``j`` in ``matrix``. The diagonal is therefore
    1 only for nodes that lie on a cycle.

    This is Warshall's algorithm, vectorised over ``(i, j)`` for each pivot
    ``k``. That replaces the previous pure-Python triple loop, which was
    impractically slow for the larger graph sizes in ``config``.

    Parameters
    ----------
    matrix : array_like, shape (n, n)
        Adjacency matrix. Any truthy entry counts as an edge.

    Returns
    -------
    numpy.ndarray
        ``(n, n)`` integer 0/1 reachability matrix.
    """
    reach = np.array(matrix).astype(bool)
    n = len(matrix)
    for k in range(n):
        # During pass k, row k and column k never change (OR-ing with a
        # value that already implies them). So updating every (i, j) at once
        # from the current row/column k is identical to the sequential loop.
        reach |= reach[:, k, None] & reach[None, k, :]
    return reach.astype(int)

# Per-graph-size experiment settings, keyed by the number of variables
# (``n_vars`` in ``main``). Keys of each entry, as consumed by ``main``:
#   p_edge_list -- edge probabilities passed to ``random_graph``
#   skip_exact  -- when True, the brute-force exact search is not run
#   skip_approx -- when True, ``local_search_extended`` is not run and the
#                  sort-ranking order is used directly
#   repeat      -- number of random intervention masks per sampled graph
#   num_sim     -- number of graphs sampled per (p_int, edge density) pair
#   lr          -- learning rate forwarded to ``diffintersort``
config = {
    3: dict(
        p_edge_list=[0.5],
        skip_exact=False,
        skip_approx=False,
        repeat=5,
        num_sim=3,
        lr=0.05,
    ),
    5: dict(
        p_edge_list=[0.5, 0.66, 0.75],
        skip_exact=False,
        skip_approx=False,
        repeat=10,
        num_sim=10,
        lr=0.05,
    ),
    30: dict(
        p_edge_list=[0.05, 0.1001, 0.2001],
        skip_exact=True,
        skip_approx=False,
        repeat=5,
        num_sim=5,
        lr=0.01,
    ),
    100: dict(
        p_edge_list=[0.02, 0.01, 0.005],
        skip_exact=True,
        skip_approx=True,
        repeat=3,
        num_sim=2,
        lr=0.001,
    ),
    1000: dict(
        p_edge_list=[0.005, 0.002, 0.001],
        skip_exact=True,
        skip_approx=True,
        repeat=1,
        num_sim=2,
        lr=0.0005,
    ),
    2000: dict(
        p_edge_list=[0.0001, 0.00005, 0.00002],
        skip_exact=True,
        skip_approx=True,
        repeat=1,
        num_sim=2,
        lr=0.0001,
    ),
}

def main():
    """Run the intervention simulation and plot D_top against two bounds.

    For every intervention probability in ``p_int_list`` and every edge
    density, ``num_sim`` DAGs are sampled: ``random_graph`` by default, or
    ``generate_scale_free_dag`` when ``scale_free`` is True. For each graph,
    the ``compute_bound`` value and a closed-form bound are recorded. Then,
    for ``repeat`` random row masks of the reachability matrix,
    ``toporder_divergence`` is measured for:
      * the brute-force ordering (unless ``skip_exact``),
      * the ``_sort_ranking`` order, refined by ``local_search_extended``
        unless ``skip_approx``,
      * the ``diffintersort`` ordering.

    Per-graph means are written to ``results_sim_soft_{n_vars}.csv`` and
    plotted as a grouped bar chart saved to ``violin_bound_{n_vars}.png``.
    """
    p_int_list = [0.25, 0.33, 0.5, 0.66, 0.75]
    # Selects the ``config`` entry to use; must be one of its keys.
    n_vars = 3
    local_config = config[n_vars]
    p_edge_list = local_config["p_edge_list"] 
    repeat = local_config["repeat"]
    # When True, graphs come from generate_scale_free_dag and the loop below
    # iterates over edges_per_var instead of p_edge_list.
    scale_free = False
    # Single seeded generator shared by graph sampling and mask sampling, so
    # results depend on the exact order of draws below.
    rng = np.random.default_rng(seed=1000)
    num_sim = local_config["num_sim"]
    edges_per_var = [ 1, 2, 3 ]
    # Parallel result lists: each xs_* holds the x value for the matching
    # y list (xs/ys/upper_bound, xs_2/scores, xs_3/scores_approx,
    # xs_4/scores_diff).
    xs = []
    ys = []
    xs_2 = []
    xs_3 = []
    xs_4 = []
    scores = []
    scores_approx = []
    scores_diff = []
    upper_bound = []
    # Passed to both _sort_ranking and diffintersort.
    eps = 0.3
    skip_exact = local_config["skip_exact"]
    skip_approx = local_config["skip_approx"]
    

    for p_int in p_int_list:
        edges_densities = edges_per_var if scale_free else p_edge_list
        for edge_density in edges_densities:
            bounds = []
            for _ in tqdm(range(num_sim)):
                if scale_free:
                    g = generate_scale_free_dag(rng, edge_density, 1.0, n_vars)
                    # Empirical density: edge count over the maximum number
                    # of edges a DAG on n_vars nodes can have.
                    p_edge = np.sum(g) / (n_vars * (n_vars - 1) / 2)
                else:
                    p_edge = edge_density
                    g = random_graph(rng, p_edge, n_vars)
                graph = create_graph(g)
                bound = compute_bound(g, p_int) 
                

                # Reachability (transitive closure) of the true DAG, used as
                # the base score matrix.
                score_matrix = transitive_closure(g)
                
                score_matrix = score_matrix.astype(float)
                # x-axis value ("Effective intervention ratio" in the plot).
                x = p_int / np.sqrt(p_edge)
                estimate = []
                estimate_approx = []
                estimate_diff = []
                for _ in range(repeat):
                    # Bernoulli(p_int) per variable. Rows of variables with
                    # mask 0 are zeroed; rows with mask 1 keep their
                    # reachability and get +0.01 on every entry (diagonal
                    # included). Presumably mask 1 = "intervened on"; confirm.
                    mask = rng.binomial(1, p_int, n_vars)
                    score_matrix_new = score_matrix.copy() * mask[:, None]
                    score_matrix_new = score_matrix_new  + 0.01 * mask[:, None]
                    if not skip_exact: 
                        # Exhaustive search over all n_vars! orderings.
                        intersort_pred = brute_force(score_ordering, score_matrix_new, n_vars)
                        intersort_score = toporder_divergence(graph, intersort_pred)
                        estimate.append(intersort_score)

                    
                    # The first return value of _sort_ranking is consumed by
                    # nx.topological_sort, i.e. it is treated as a DAG here.
                    pred_sort_ranking, _ = _sort_ranking(score_matrix_new, eps)
                    topological_order_sortranking = list(nx.topological_sort(pred_sort_ranking)) 
                    if not skip_approx:
                        candidate = local_search_extended(topological_order_sortranking, score_matrix_new, n_vars)
                    else:
                        candidate = topological_order_sortranking
                    score_ordering_pagerank = toporder_divergence(graph, candidate)
                    
                    estimate_approx.append(score_ordering_pagerank)
                    
                    # Always run (no skip flag). The sort-ranking order is
                    # passed as diffintersort's third argument.
                    pred_diffintersort, _ = diffintersort(score_matrix_new, n_vars, topological_order_sortranking, n_iter=5000, lr=local_config["lr"], t_sinkhorn = 0.05, n_iter_sinkhorn=500, eps=eps, sinkhorn=False)
                    score_ordering_diff = toporder_divergence(graph, pred_diffintersort)
                    estimate_diff.append(score_ordering_diff)
                if not skip_exact:
                    scores.append(np.mean(estimate))
                    xs_2.append(x)
                
                xs_3.append(x)
                scores_approx.append(np.mean(estimate_approx))
                xs_4.append(x)
                scores_diff.append(np.mean(estimate_diff))
                xs.append(x)
                ys.append(bound)
                # Closed-form bound plotted as "Upperbound Thm 4". It depends
                # only on n_vars, p_int and p_edge, not on the sampled graph.
                op = (1 - p_int * p_edge)
                inside = op - op**(n_vars + 1)
                inside /= p_int * p_edge
                upper_bound.append((n_vars - inside) * (1 - p_int)**2 / p_int)
                bounds.append(bound)
            # NOTE(review): no figure is created before this call, so these
            # KDE curves land on matplotlib's current (implicitly created)
            # figure. That figure is separate from the bar plot below, gets
            # no legend() call, and is not the one passed to savefig. In
            # scale-free mode, p_edge in the label is only the last graph's
            # density.
            sns.kdeplot(bounds, label=f'p_int={p_int}, p_edge={p_edge}', warn_singular=False)


    # Truncate (int() floors positive values) x to 4 decimals, presumably
    # so equal settings map to one bar group on the categorical x axis.
    precision = 10000.
    xs = [int(x * precision) / precision for x in xs]
    xs_2 = [int(x * precision) / precision for x in xs_2]
    xs_3 = [int(x * precision) / precision for x in xs_3]
    xs_4 = [int(x * precision) / precision for x in xs_4]
    # Long-format table: one row per (series, graph); the three columns are
    # concatenated in the same series order.
    df = pd.DataFrame({
        'x': xs + xs + xs_3 + xs_4 + xs_2,  # xs twice: Thm 4 bound, then Thm 2 bound
        'y': upper_bound + ys + scores_approx + scores_diff + scores,  # y values in matching series order
        'label': ['Upperbound Thm 4'] * len(upper_bound)  + ['Upperbound Thm 2'] * len(ys) + ['Approx opt'] * len(scores_approx) + ['Diff opt'] * len(scores_diff) + ['Actual opt'] * len(scores)  # series name per row
    })

    df.to_csv(f"results_sim_soft_{n_vars}.csv")

    # Grouped bar plot on a fresh figure; seaborn aggregates repeated
    # (x, label) rows.
    plt.figure(figsize=(9,6))
    sns.barplot(x='x', y='y', hue='label', data=df) #, yerr=std['y'])
    plt.title(f'D_top of Intersort for {n_vars} variables')
    plt.xlabel('Effective intervention ratio',)
    plt.ylabel('Dtop')
    plt.ylim(bottom=0)
    plt.tight_layout()
    plt.grid(True)
    plt.legend()
    # Saves the current (bar-plot) figure; the filename says "violin" but the
    # plot is a bar chart.
    plt.savefig(f"./violin_bound_{n_vars}.png", dpi=400)
    plt.show()

# Run the simulation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
