import numpy as np
import networkx as nx
import random
import seaborn as sns
import argparse
import matplotlib.pyplot as plt
from dodiscover.metrics import toporder_divergence
from tqdm import tqdm
import pandas as pd
import igraph as ig
from intersort import _sort_ranking, local_search_extended, score_ordering
from diffintersort import diffintersort
import time
import os

def generate_generalized_BA_DAG(n, m0, m, kappa, seed=None):
    """Sample a DAG from a generalized Barabasi-Albert attachment process.

    Nodes ``0 .. m0 - 1`` start without edges. Every later node ``v`` draws
    ``min(m, v)`` distinct parents among the older nodes ``0 .. v - 1`` with
    probability proportional to ``times_chosen + kappa``, where
    ``times_chosen`` counts how often a node has already been picked as a
    parent. Edges always point from the older node to the newer one, so
    ``0, 1, ..., n - 1`` is a topological order of the result.

    Args:
        n: Total number of nodes.
        m0: Number of initial nodes that do not attach to anything. Node 0
            never attaches (it has no older nodes), so ``m0=0`` behaves like
            ``m0=1``.
        m: Maximum number of parents drawn for each attaching node.
        kappa: Additive attractiveness. Must be non-zero: with ``kappa == 0``
            the first attaching node sees only zero weights and no
            distribution can be formed.
            NOTE(review): positive values look like the intended use.
        seed: If not ``None``, passed to ``np.random.seed`` (this reseeds the
            global legacy NumPy RNG, which is also what is sampled from).

    Returns:
        An ``(n, n)`` integer array ``mat`` with ``mat[parent, child] == 1``
        for every sampled edge and 0 elsewhere.

    Raises:
        ValueError: If ``kappa == 0`` while at least one node has to attach.
    """
    if seed is not None:
        np.random.seed(seed)

    # Node 0 has no candidate parents, so attachment can start at 1 at the earliest.
    first_attaching = max(m0, 1)
    if kappa == 0 and first_attaching < n:
        raise ValueError(
            "kappa must be non-zero: all attachment weights of the first "
            "attaching node would be zero"
        )

    adjacency = np.zeros((n, n), dtype=int)
    times_chosen = np.zeros(n, dtype=int)  # how often each node was picked as a parent

    for new_node in range(first_attaching, n):
        # Eligible older nodes: 0 .. new_node - 1
        eligible_nodes = np.arange(new_node)
        weights = (times_chosen[:new_node] + kappa).astype(float)
        probs = weights / weights.sum()

        # Sample unique parents without replacement
        parents = np.random.choice(
            eligible_nodes, size=min(m, new_node), replace=False, p=probs
        )

        # Directed edges parent -> new_node keep the graph acyclic.
        # `parents` has no duplicates, so fancy-index updates touch each entry once.
        adjacency[parents, new_node] = 1
        times_chosen[parents] += 1

    return adjacency

def random_graph(rng, p_edge, n_vars):
    """Sample a random DAG adjacency matrix with i.i.d. Bernoulli edges.

    Args:
        rng: Object providing ``binomial`` and ``permutation`` (e.g. a
            ``numpy.random.Generator``).
        p_edge: Success probability of each Bernoulli edge draw.
        n_vars: Number of nodes; the result is ``n_vars x n_vars``.

    Returns:
        Integer 0/1 array: a strictly lower-triangular Bernoulli matrix
        conjugated by a random permutation matrix.
    """
    # One Bernoulli(p_edge) draw per ordered pair of nodes.
    bernoulli = rng.binomial(n=1, p=p_edge, size=(n_vars, n_vars)).astype(int)

    # Keeping only entries strictly below the diagonal (k=-1) rules out
    # self-loops and cycles.
    lower_triangular = np.tril(bernoulli, k=-1)

    # Shuffling the rows of the identity gives a random permutation matrix;
    # conjugating with it relabels the nodes.
    perm_matrix = rng.permutation(np.eye(n_vars, dtype=int))
    return perm_matrix.T @ lower_triangular @ perm_matrix

def create_graph(g):
    """Build a ``networkx.DiGraph`` from an adjacency matrix.

    Args:
        g: 2-D NumPy array; an entry ``g[i, j] == 1`` becomes the edge
            ``i -> j``. Entries with any other value produce no edge.

    Returns:
        A ``nx.DiGraph`` with nodes ``0 .. g.shape[0] - 1``.
    """
    digraph = nx.DiGraph()
    # Register every node first so isolated nodes are part of the graph too.
    digraph.add_nodes_from(range(g.shape[0]))
    digraph.add_edges_from(
        (src, dst) for src, dst in zip(*np.where(g == 1))
    )
    return digraph

def graph_to_mat(g):
    """Return the adjacency matrix of an ``ig.Graph`` as an integer NumPy array."""
    adjacency_rows = g.get_adjacency().data
    return np.array(adjacency_rows).astype(int)

def generate_scale_free_dag(rng, edges_per_var, power, n_vars):
    """Generate a directed Barabasi graph with igraph and return its adjacency matrix.

    Args:
        rng: Object providing ``normal`` and ``permutation`` (e.g. a
            ``numpy.random.Generator``); only used for the vertex relabelling.
        edges_per_var: Passed to ``ig.Graph.Barabasi`` as ``m``.
        power: Passed to ``ig.Graph.Barabasi`` as ``power``.
        n_vars: Number of vertices.

    Returns:
        Integer adjacency matrix of the vertex-permuted graph.
    """
    # The drawn value is discarded; the call only advances ``rng``.
    rng.normal()
    vertex_order = rng.permutation(n_vars).tolist()
    barabasi = ig.Graph.Barabasi(
        n=n_vars, m=edges_per_var, directed=True, power=power
    )
    return graph_to_mat(barabasi.permute_vertices(vertex_order))

def compute_bound(g, p_int):
    """Sum ``(1 - p_int) ** k`` over all edges of the graph encoded by ``g``.

    For an edge ``i -> j``, ``k`` is the size of
    ``(ancestors(j) | {j}) - ancestors(i)``.

    Args:
        g: Adjacency matrix accepted by ``create_graph``.
        p_int: Base probability used in ``(1 - p_int)``.

    Returns:
        The accumulated sum as a float (``0.0`` for a graph without edges).
    """
    dag = create_graph(g)
    total = 0.0
    for src, dst in dag.edges:
        upstream_of_dst = nx.ancestors(dag, dst) | {dst}
        upstream_of_src = nx.ancestors(dag, src)
        exponent = len(upstream_of_dst - upstream_of_src)
        total += (1 - p_int)**exponent
    return total

import itertools
def brute_force(score_function, score_matrix, d):
    """Exhaustively search all orderings of ``range(d)`` for the best score.

    Args:
        score_function: Callable invoked as
            ``score_function(perm, score_matrix, d)`` with ``perm`` a tuple;
            must return a value comparable with ``>``.
        score_matrix: Passed through to ``score_function`` unchanged.
        d: Number of elements to order. All ``d!`` permutations are held in
            memory, so this is only feasible for small ``d``.

    Returns:
        A tuple: the permutation with the highest score. Among equally
        scored permutations the winner depends on the shuffle below.
    """
    permutations = list(itertools.permutations(range(d)))
    # Shuffle so that the first-encountered maximum (strict `>` below) is not
    # always the lexicographically smallest of the tied permutations.
    random.shuffle(permutations)

    # Start from -inf rather than a fixed threshold so that score functions
    # whose values are all very negative still yield their true maximiser.
    best_score = float("-inf")
    candidate = permutations[0]
    for perm in permutations:
        score = score_function(perm, score_matrix, d)
        if score > best_score:
            best_score = score
            candidate = perm
    return candidate

def transitive_closure(matrix):
    """Return the transitive closure of a square adjacency matrix.

    Args:
        matrix: Square array-like; non-zero entry ``[i][j]`` means there is
            an edge ``i -> j``.

    Returns:
        Integer 0/1 NumPy array ``reach`` where ``reach[i, j] == 1`` iff
        ``j`` can be reached from ``i`` by a path of at least one edge.
        Diagonal entries are only set for nodes that lie on a cycle or have
        a self-loop.
    """
    reach = np.array(matrix).astype(bool)

    for k in range(len(matrix)):
        # Warshall step, vectorised over all (i, j): i reaches j if it
        # already did, or if i reaches k and k reaches j. Row k and column k
        # are unchanged by this step, so the whole-matrix update equals the
        # element-by-element in-place loop.
        reach |= reach[:, k, None] & reach[None, k, :]

    return reach.astype(int)

# Per-graph-size presets keyed by number of variables: every size uses
# 10 simulations, and the learning rate shrinks as the graph grows.
# Not referenced elsewhere in this file.
config = {
    size: {"num_sim": 10, "lr": learning_rate}
    for size, learning_rate in (
        (3, 0.05),
        (5, 0.05),
        (10, 0.05),
        (30, 0.01),
        (50, 0.01),
        (100, 0.001),
        (200, 0.001),
        (300, 0.001),
        (400, 0.001),
        (500, 0.0005),
        (1000, 0.0005),
        (2000, 0.0001),
    )
}

def main():
    """Run the DiffIntersort simulation study and write the results to CSV.

    For each of ``--num_sim`` random graphs, ``--repeat`` random masks are
    drawn, a masked score matrix is built from the adjacency matrix,
    ``_sort_ranking`` provides an initial ordering, ``diffintersort`` refines
    it, and ``toporder_divergence`` against the true graph is recorded
    together with the wall-clock time and the final score. One CSV row is
    written per (graph, repetition) pair.

    All randomness drawn here comes from a generator seeded with 1000; the
    scale-free generator is seeded from that generator as well, which also
    reseeds NumPy's global legacy RNG once per simulated graph.
    """
    parser = argparse.ArgumentParser(description='Run simulation experiments for DiffIntersort.')
    parser.add_argument('--n_vars', type=int, default=10, help='Number of variables (nodes) in the graph')
    parser.add_argument('--p_int', type=float, default=0.5, help='Probability of intervention')
    parser.add_argument('--num_sim', type=int, default=10, help='Number of simulations to run')
    parser.add_argument('--repeat', type=int, default=10, help='Number of repetitions for each simulation')
    parser.add_argument('--edges_per_var', type=int, default=3, help='Number of edges per variable for graph generation')
    parser.add_argument('--kappa', type=float, default=1.0, help='Attractiveness parameter for the generalized BA model')
    parser.add_argument('--lr', type=float, default=0.01, help='Learning rate for diffintersort')
    parser.add_argument('--t_sinkhorn', type=float, default=0.05, help='Temperature parameter for the Sinkhorn operator')
    parser.add_argument('--n_iter_sinkhorn', type=int, default=500, help='Number of iterations for the Sinkhorn operator')
    parser.add_argument('--eps', type=float, default=0.3, help='Epsilon parameter for score_ordering')
    parser.add_argument('--graph_type', type=str, default='scale_free', choices=['scale_free', 'er', 'scale_free_er'], help='Type of graph to generate')
    parser.add_argument('--p_edge', type=float, default=0.3, help='Probability of edge for ER graphs')
    parser.add_argument('--output_dir', type=str, default='results', help='Directory to save the results')
    args = parser.parse_args()

    # The output file depends only on the arguments, so resolve it (and create
    # the directory) before any expensive work: a bad path fails immediately,
    # and the path exists even when --num_sim is 0.
    if args.graph_type == 'scale_free':
        filename = f'results_type_{args.graph_type}_n{args.n_vars}_p{args.p_int}_e{args.edges_per_var}_k{args.kappa}.csv'
    elif args.graph_type == 'er':
        filename = f'results_type_{args.graph_type}_n{args.n_vars}_p{args.p_int}_p_edge{args.p_edge}.csv'
    else:  # 'scale_free_er' -- argparse `choices` rules out anything else
        filename = f'results_type_{args.graph_type}_n{args.n_vars}_p{args.p_int}_e{args.edges_per_var}.csv'
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, filename)

    rng = np.random.default_rng(seed=1000)
    scores_diff = []
    p_ints = []
    cs = []
    n_edges = []
    times = []
    final_scores = []

    for _sim in tqdm(range(args.num_sim)):
        if args.graph_type == 'scale_free':
            # The BA generator samples from NumPy's global RNG; hand it a seed
            # drawn from `rng` so that the whole run is reproducible.
            graph_seed = int(rng.integers(0, 2**32))
            g = generate_generalized_BA_DAG(
                args.n_vars, 1, args.edges_per_var, args.kappa, seed=graph_seed
            )
        elif args.graph_type == 'er':
            g = random_graph(rng, args.p_edge, args.n_vars)
        else:
            # ER graph whose edge probability is edges_per_var / n_vars.
            p_edge = args.edges_per_var / args.n_vars
            g = random_graph(rng, p_edge, args.n_vars)

        graph = create_graph(g)
        n_edge = g.sum()

        for _rep in range(args.repeat):
            # perf_counter is monotonic, unlike time.time().
            start_time = time.perf_counter()
            score_matrix = np.array(g).astype(float)
            # One Bernoulli(p_int) draw per variable. Rows with mask 0 are
            # zeroed; rows with mask 1 keep their entries plus 0.01 everywhere.
            # NOTE(review): presumably mask == 1 marks an intervened variable.
            mask = rng.binomial(1, args.p_int, args.n_vars)
            score_matrix_new = score_matrix * mask[:, None]
            score_matrix_new = score_matrix_new + 0.01 * mask[:, None]

            # Initial ordering: topological sort of the graph returned by _sort_ranking.
            pred_sort_ranking, _ = _sort_ranking(score_matrix_new, args.eps)
            topological_order_sortranking = list(nx.topological_sort(pred_sort_ranking))

            pred_diffintersort, final_score = diffintersort(
                score_matrix_new, args.n_vars, topological_order_sortranking,
                n_iter=5000, lr=args.lr, t_sinkhorn=args.t_sinkhorn,
                n_iter_sinkhorn=args.n_iter_sinkhorn, eps=args.eps
            )

            score_ordering_diff = toporder_divergence(graph, pred_diffintersort)

            end_time = time.perf_counter()
            times.append(end_time - start_time)
            scores_diff.append(score_ordering_diff)
            final_scores.append(final_score)
            cs.append(args.edges_per_var)
            p_ints.append(args.p_int)
            n_edges.append(n_edge)

    df = pd.DataFrame({
        'c': cs,
        'p_int': p_ints,
        'n_edge': n_edges,
        'D_top': scores_diff,
        'time': times,
        'final_score': final_scores,
    })

    df.to_csv(output_path, index=False)


if __name__ == "__main__":
    main()
