import numpy as np
import networkx as nx
import random
import seaborn as sns
import matplotlib.pyplot as plt
from dodiscover.metrics import toporder_divergence
from tqdm import tqdm
import pandas as pd

def random_graph(rng, p_edge, n_vars):
    """Sample the adjacency matrix of a random DAG.

    Every strictly-lower-triangular entry is an independent Bernoulli(p_edge)
    draw. The node labels are then shuffled with a random permutation matrix
    ``P``, so the returned matrix ``P.T @ A @ P`` is acyclic but generally not
    triangular.

    Args:
        rng: ``numpy.random.Generator`` used for both the edge draws and the
            permutation (in that order).
        p_edge: Probability of each potential edge.
        n_vars: Number of nodes.

    Returns:
        ``(n_vars, n_vars)`` int array of 0/1 entries with a zero diagonal.
    """
    coin_flips = rng.binomial(n=1, p=p_edge, size=(n_vars, n_vars)).astype(int)

    # Keep only entries strictly below the diagonal (k=-1 also drops the
    # diagonal), which guarantees acyclicity before relabelling.
    lower_triangle = np.tril(coin_flips, k=-1)

    # Shuffling the rows of the identity yields a random permutation matrix.
    perm_matrix = rng.permutation(np.identity(n_vars, dtype=int))
    relabelled = perm_matrix.T @ lower_triangle
    return relabelled @ perm_matrix

def create_graph(g):
    """Convert an adjacency matrix into a ``networkx.DiGraph``.

    Nodes are ``0 .. g.shape[0] - 1``, and all of them are added even when
    isolated. An edge ``i -> j`` is created for every entry with
    ``g[i, j] == 1``. Entries with any other value are ignored.
    """
    digraph = nx.DiGraph()
    digraph.add_nodes_from(range(g.shape[0]))
    src_idx, dst_idx = np.where(g == 1)
    digraph.add_edges_from(zip(src_idx, dst_idx))
    return digraph

def compute_bound(g, p_int):
    """Sum a per-edge term ``(1 - p_int) ** |S_ij|`` over all edges of ``g``.

    For each edge ``i -> j``, ``S_ij`` is the set of nodes that are ancestors
    of ``j`` (or ``j`` itself) but not ancestors of ``i``. For a DAG this set
    always contains both ``i`` and ``j``.

    Args:
        g: 0/1 adjacency matrix (only entries equal to 1 count as edges).
        p_int: Probability value; each edge contributes ``(1 - p_int)`` raised
            to the size of its set.

    Returns:
        The accumulated float sum (``0.0`` for a graph without edges).
    """
    digraph = create_graph(g)
    q = 1 - p_int
    bound = 0.0
    # Explicit running sum keeps the exact float accumulation order.
    for src, dst in digraph.edges:
        upstream_of_dst = nx.ancestors(digraph, dst) | {dst}
        exclusive = upstream_of_dst - nx.ancestors(digraph, src)
        bound += q ** len(exclusive)
    return bound

def score_ordering(score_matrix, topological_order, d, eps=0.3):
    """Score a candidate ordering of the nodes ``0 .. d-1``.

    The nodes are walked in ``topological_order``. A node whose row in
    ``score_matrix`` has at least one positive entry contributes
    ``sum(score_matrix[node, k] - eps)`` over every node ``k`` not yet
    visited, i.e. the nodes placed after it. Nodes with no positive entry
    in their row contribute nothing.

    Args:
        score_matrix: ``(d, d)`` array of pairwise scores.
        topological_order: Sequence containing each of ``0 .. d-1`` once.
        d: Number of nodes.
        eps: Amount subtracted from every counted score.

    Returns:
        The total score (``0`` when no node contributes). Higher is better.

    Raises:
        ValueError: if the order repeats a node or contains one outside
            ``range(d)``.
    """
    remaining = list(range(d))
    total = 0
    for node in topological_order:
        remaining.remove(node)
        if not np.any(score_matrix[node, :] > 0.0):
            continue
        total += np.sum(score_matrix[node, remaining] - eps)
    return total

def sort_ranking(score_matrix, lmbda):
    """Greedily build a DAG from the highest-scoring pairwise entries.

    Off-diagonal entries of ``score_matrix`` are visited in descending order
    of score. The edge ``i -> j`` is added when its score is strictly greater
    than ``lmbda`` and adding it would not close a directed cycle. The scan
    stops at the first off-diagonal entry whose score is not above ``lmbda``.

    Args:
        score_matrix: 2-D array where entry ``[i, j]`` scores the edge
            ``i -> j``. Assumed square: nodes are ``range(cols)``.
        lmbda: Threshold that a score must strictly exceed.

    Returns:
        ``(cols, cols)`` float array with ``1.0`` at every accepted edge and
        ``0.0`` elsewhere.
    """
    rows, cols = score_matrix.shape
    G = nx.DiGraph()
    G.add_nodes_from(range(cols))

    # Indices of all entries, highest score first.
    sorted_flat_indices = np.argsort(-score_matrix.flatten())
    sorted_i, sorted_j = np.unravel_index(sorted_flat_indices, (rows, cols))

    for i, j in zip(sorted_i, sorted_j):
        if i == j:
            continue  # self-loops are never candidates
        if not score_matrix[i, j] > lmbda:
            break  # every later entry scores no higher
        # G is kept acyclic, so i -> j closes a cycle iff j already reaches i.
        # A single reachability query replaces re-validating the whole graph
        # after a tentative add/remove.
        if not nx.has_path(G, j, i):
            G.add_edge(i, j)

    W = np.zeros((cols, cols))
    for i, j in G.edges:
        W[i, j] = 1
    return W

 

def move_variable(perm, from_index, to_index):
    """Return ``perm`` with the element at ``from_index`` relocated to ``to_index``.

    The input list is never mutated. When the two indices are equal the
    original object itself (not a copy) is returned.
    """
    if from_index != to_index:
        moved = perm.copy()
        element = moved.pop(from_index)
        moved.insert(to_index, element)
        return moved
    return perm

def generate_all_possible_moves(perm):
    """Return every ordering reachable by relocating a single element of ``perm``.

    The element at position ``src`` is moved to position ``dst`` via
    ``move_variable`` for each ordered pair of distinct positions. ``src`` is
    the outer loop and ``dst`` the inner loop, both ascending, which gives
    ``n * (n - 1)`` results.

    Moving an element to an adjacent position produces the same swap as
    moving its neighbour back, so the list can contain duplicates.
    """
    size = len(perm)
    return [
        move_variable(perm, src, dst)
        for src in range(size)
        for dst in range(size)
        if src != dst
    ]

def local_search_extended(initial_perm, score_matrix, d, score_function):
    """Hill-climb over orderings using first-improvement single-element moves.

    Starting from ``initial_perm`` (a list, because moves copy/pop/insert),
    the search repeatedly scans ``generate_all_possible_moves`` of the current
    ordering. It jumps to the first neighbour whose
    ``score_function(score_matrix, perm, d)`` is strictly higher than the
    current score.

    The search stops at an ordering with no strictly better neighbour and
    returns it. Scores are maximised.
    """
    best_perm = initial_perm
    best_score = score_function(score_matrix, best_perm, d)
    improved = True
    while improved:
        improved = False
        for candidate in generate_all_possible_moves(best_perm):
            candidate_score = score_function(score_matrix, candidate, d)
            if candidate_score > best_score:
                best_perm, best_score = candidate, candidate_score
                improved = True
                break  # first improvement: restart from the new ordering
    return best_perm


import itertools
def brute_force(score_function, score_matrix, d):
    """Exhaustively search all orderings of ``range(d)`` for the highest score.

    Every permutation is scored with ``score_function(score_matrix, perm, d)``.
    Permutations are visited in an order shuffled with the global ``random``
    module, which is not controlled by any numpy ``Generator``. Among tied
    maxima the winner is therefore whichever comes first after shuffling.

    Warning: all ``d!`` permutations are materialised, so this is only
    feasible for small ``d``.

    Returns:
        The best permutation, as a tuple.
    """
    permutations = list(itertools.permutations(range(d)))
    random.shuffle(permutations)
    # Start from -inf. A finite sentinel such as -1000 silently returns an
    # arbitrary permutation whenever every score falls below it.
    max_score = float("-inf")
    candidate = permutations[0]
    for perm in permutations:
        score = score_function(score_matrix, perm, d)
        if score > max_score:
            max_score = score
            candidate = perm
    return candidate

def transitive_closure(matrix):
    """Return the reachability (transitive-closure) matrix of a directed graph.

    A nonzero ``matrix[i][j]`` means there is an edge ``i -> j``. The result
    is 1 at ``[i, j]`` iff ``j`` is reachable from ``i`` by a path of length
    >= 1. Diagonal entries are therefore 1 only for nodes that lie on a
    cycle.

    This is Warshall's algorithm with each pivot step vectorised in NumPy,
    replacing a pure-Python O(n^3) triple loop.

    Args:
        matrix: Square adjacency matrix (array-like of shape ``(n, n)``).

    Returns:
        ``np.ndarray`` of 0/1 ints with the same shape.
    """
    n = len(matrix)
    reach = np.array(matrix).astype(bool)
    for k in range(n):
        # Pivot k never changes row k or column k, so updating every (i, j)
        # at once matches the element-by-element in-place update.
        reach |= reach[:, k, None] & reach[None, k, :]
    return reach.astype(int)

def main():
    """Run the simulation study, save two figures and a CSV of results.

    For every ``(p_int, p_edge)`` pair, ``num_sim`` random DAGs over
    ``n_vars`` nodes are sampled. For each DAG the function records:

    * ``compute_bound(g, p_int)`` (series 'Upperbound Thm 2', also drawn as a
      KDE per configuration);
    * a closed-form value depending only on ``p_int``, ``p_edge`` and
      ``n_vars`` (series 'Upperbound Thm 4');
    * optionally, the mean ``toporder_divergence`` over ``repeat`` random
      masks for the exact (``brute_force``) ordering and for the approximate
      (``sort_ranking`` + ``local_search_extended``) ordering.

    Side effects: displays two figures and writes ``./pdf.png``,
    ``./violin_bound.png`` and ``results_sim_{n_vars}.csv``.
    """
    # Bernoulli probability used for the per-variable row mask below.
    p_int_list = [0.25, 0.33, 0.5, 0.66, 0.75]
    n_vars = 30
    p_edge_list = [0.05, 0.1001, 0.2001]  #[0.5, 0.66, 0.75] for 5 variables
    repeat = 10
    # Seeded generator for graph sampling and masks. brute_force shuffles with
    # the unseeded global `random` module, so the exact path is not reproduced
    # by this seed.
    rng = np.random.default_rng(seed=1000)
    num_sim = 20
    xs = []
    ys = []
    xs_2 = []
    xs_3 = []
    scores = []
    scores_approx = []
    upper_bound = []
    # Used as the `lmbda` threshold passed to sort_ranking.
    eps = 0.3
    # brute_force materialises all n_vars! permutations: infeasible for 30 vars.
    skip_exact = True
    skip_approx = False
    

    for p_int in p_int_list:
        for p_edge in p_edge_list:
            bounds = []
            for _ in tqdm(range(num_sim)):
                g = random_graph(rng, p_edge, n_vars)
                # Ground-truth DAG that predicted orderings are compared against.
                graph = create_graph(g)
                bound = compute_bound(g, p_int) 

                # Reachability matrix: [i, j] == 1 iff j is reachable from i.
                score_matrix = transitive_closure(g)
                
                score_matrix = score_matrix.astype(float)
                # x-axis value, labelled 'Effective intervention ratio' in the plot.
                x = p_int / np.sqrt(p_edge)
                estimate = []
                estimate_approx = []
                for _ in range(repeat):
                    # Keep each row with probability p_int and zero the others.
                    mask = rng.binomial(1, p_int, n_vars)
                    score_matrix_new = score_matrix.copy() * mask[:, None]
                    # +0.01 makes every entry of a kept row strictly positive
                    # (diagonal included), so score_ordering's
                    # `np.any(row > 0)` test holds exactly for kept rows.
                    score_matrix_new = score_matrix_new  + 0.01 * mask[:, None]
                    if not skip_exact: 
                        intersort_pred = brute_force(score_ordering, score_matrix_new, n_vars)
                        intersort_score = toporder_divergence(graph, intersort_pred)
                        estimate.append(intersort_score)

                    if not skip_approx:
                        # Greedy DAG -> topological order -> local search refinement.
                        pred_sort_ranking = sort_ranking(score_matrix_new, eps)
                        topological_order_sortranking = list(nx.topological_sort(create_graph(pred_sort_ranking))) 
                        candidate = local_search_extended(topological_order_sortranking, score_matrix_new, n_vars, score_ordering)
                        # NOTE(review): despite the name, this is the divergence
                        # of the local-search ordering (no PageRank involved).
                        score_ordering_pagerank = toporder_divergence(graph, candidate)
                        estimate_approx.append(score_ordering_pagerank)
                if not skip_exact:
                    scores.append(np.mean(estimate))
                    xs_2.append(x)
                if not skip_approx:
                    xs_3.append(x)
                    scores_approx.append(np.mean(estimate_approx))
                
                xs.append(x)
                ys.append(bound)
                # Closed-form bound: `inside` is sum_{k=1}^{n_vars} op**k
                # (geometric series, since 1 - op == p_int * p_edge). It depends
                # only on p_int/p_edge, so every simulation appends the same value.
                op = (1 - p_int * p_edge)
                inside = op - op**(n_vars + 1)
                inside /= p_int * p_edge
                upper_bound.append((n_vars - inside) * (1 - p_int)**2 / p_int)
                bounds.append(bound)
            # One KDE curve per (p_int, p_edge), all drawn on the current figure.
            sns.kdeplot(bounds, label=f'p_int={p_int}, p_edge={p_edge}', warn_singular=False)

    plt.title(f'Bound on expected FNR for {n_vars} variables')
    plt.xlabel('Value')
    plt.ylabel('Density')
    plt.grid(True)
    plt.legend()
    plt.savefig("./pdf.png", dpi=400)
    plt.show()

    # Truncate x to 4 decimals, presumably so equal x values fall into the
    # same category on the bar plot.
    precision = 10000.
    xs = [int(x * precision) / precision for x in xs]
    xs_2 = [int(x * precision) / precision for x in xs_2]
    xs_3 = [int(x * precision) / precision for x in xs_3]

    # Long-format frame; the x, y and label lists must be concatenated in
    # the same series order.
    df = pd.DataFrame({
        'x': xs + xs + xs_3 + xs_2,  # xs twice: once for each bound series
        'y': upper_bound + ys + scores_approx + scores,  # same order as the labels below
        'label': ['Upperbound Thm 4'] * len(upper_bound)  + ['Upperbound Thm 2'] * len(ys) + ['Approx opt'] * len(scores_approx) + ['Actual opt'] * len(scores)  # one label per row
    })

    df.to_csv(f"results_sim_{n_vars}.csv")

    # Grouped bar plot of y per x, one bar per label.
    plt.figure(figsize=(9,6))
    sns.barplot(x='x', y='y', hue='label', data=df) #, yerr=std['y'])
    plt.title(f'D_top of Intersort for {n_vars} variables')
    plt.xlabel('Effective intervention ratio',)
    plt.ylabel('Dtop')
    #plt.yticks([i for i in range(0, n_vars, 5)])
    plt.ylim(bottom=0)
    plt.tight_layout()
    plt.grid(True)
    plt.legend()
    # NOTE(review): the file name says "violin" but this figure is a bar plot.
    plt.savefig("./violin_bound.png", dpi=400)
    plt.show()

if __name__ == "__main__":
    # Run the simulation only when executed as a script, not on import.
    main()
