import random
import pandas as pd
import numpy as np
import pyAgrum as gum
from itertools import combinations
from scipy.stats import chisquare
import networkx as nx
import pickle
import os
from Utils import *

# Uncomment the lines below to fix the random seed for reproducible runs
#seed = 42
#gum.initRandom(seed=seed)
#np.random.seed(seed)
#random.seed(seed)

def random_complete_dag(n):
    """
    Generate a random *complete* DAG on n nodes (0..n-1).

    A random permutation of the nodes is drawn and used as a topological
    order. Every unordered pair {a, b} is then oriented from the node that
    comes earlier in the permutation to the one that comes later, so the
    result is acyclic and every pair is adjacent.

    Parameters
    ----------
    n : int
        Number of nodes; nodes are labelled 0..n-1.

    Returns
    -------
    dict
        ``{'directed': [(u, v), ...]}`` in which every distinct pair of nodes
        appears exactly once, oriented u -> v. Edges are listed in
        lexicographic order of the unordered pair (a, b) with a < b.
    """
    # Random topological order (in-place shuffle uses the global `random` RNG).
    order = list(range(n))
    random.shuffle(order)

    # pos[x] = index of node x in the topological order, for O(1) comparisons.
    pos = {node: idx for idx, node in enumerate(order)}

    # Orient each pair from the earlier node to the later one.
    directed_edges = [
        (a, b) if pos[a] < pos[b] else (b, a)
        for a, b in combinations(range(n), 2)
    ]

    return {'directed': directed_edges}


def add_bidirected_edges(dag, n, p=0.5, fix=False):
    """
    Augment a complete DAG with randomly chosen bidirected edges (an ADMG).

    Parameters
    ----------
    dag : dict
        A DAG in the form returned by ``random_complete_dag``: a dict whose
        ``'directed'`` key holds a list of ``(u, v)`` edges.
    n : int
        Number of nodes (labelled 0..n-1).
    p : float, default 0.5
        Probability of adding each bidirected edge i<->j (i < j),
        independently per pair. Ignored when ``fix`` is True.
    fix : bool, default False
        If True, ignore ``p`` and instead choose exactly ``n - 1`` distinct
        unordered pairs uniformly at random (none when n <= 1).

    Returns
    -------
    dict
        ``{'directed': [...], 'bidirected': [(x, y), ...]}`` where every
        bidirected edge satisfies x < y. ``'directed'`` is a copy of
        ``dag['directed']``, so mutating the result does not modify ``dag``.
    """
    # Copy so the returned ADMG does not alias the caller's DAG edge list.
    directed_edges = list(dag['directed'])

    # All unordered pairs (i, j) with i < j, in lexicographic order.
    pairs = list(combinations(range(n), 2))

    if fix:
        # Shuffle, then keep the first n-1 pairs (max() guards n == 0).
        random.shuffle(pairs)
        bidirected_edges = pairs[:max(n - 1, 0)]
    else:
        # One independent Bernoulli(p) draw per pair, in pair order.
        bidirected_edges = [pair for pair in pairs if random.random() < p]

    return {
        'directed': directed_edges,
        'bidirected': bidirected_edges,
    }

if __name__ == "__main__":
    # ---- Experiment configuration ---------------------------------------
    n = 2             # number of nodes in every sampled ADMG
    # Second argument passed to random_targets (defined in Utils).
    # NOTE(review): presumably the number of intervention targets — confirm
    # against Utils.random_targets.
    k = 2
    epsilon = 0.01    # desired accuracy of the Monte Carlo estimate
    delta = 0.01      # allowed failure probability
    # Number of random ADMGs compared per experiment:
    #   n_admg = ln(1/delta) / (2 * epsilon^2)
    # which matches the one-sided Hoeffding sample-size bound for estimating
    # a probability to within epsilon with confidence 1 - delta.
    n_admg = int(-np.log(delta) / (2 * epsilon**2))
    n_experiment = 30
    c_hard_list = []
    c_soft_list = []

    for i_exp in range(n_experiment):
        print('Experiment: ', i_exp)
        # Step 1: Make a random complete DAG
        dag = random_complete_dag(n)

        # Step 2: Add each bidirected edge independently with probability 0.5
        admg_true = add_bidirected_edges(dag, n, p=0.5, fix=False)
        print(admg_true)

        # Number of random ADMGs judged equivalent to admg_true by
        # verify_MAG_equi_adj (Utils), under hard / soft interventions.
        c_hard = 0
        c_soft = 0

        for i_admg in range(n_admg):
            # Print progress
            if i_admg % 5000 == 0:
                print('Counting ADMGs: ', i_admg, 'out of ', n_admg)

            # Step 3: Draw a random target; the same target_dict is used for
            # both the true and the random ADMG below.
            target_dict = random_targets(n, k)

            # Adjacency matrices of the augmented MAGs of the true ADMG
            # (hard interventions, then soft interventions).
            aug_adj_dir, aug_adj_bi = get_adj_of_aug_mag(admg_true, n, target_dict)
            aug_adj_soft_dir, aug_adj_soft_bi = get_adj_of_aug_mag_soft(admg_true, n, target_dict)

            # Step 4: Generate a random ADMG to compare against.
            dag_random = random_complete_dag(n)
            admg_random = add_bidirected_edges(dag_random, n, p=0.5)

            # Same adjacency matrices for the random ADMG.
            aug_adj_dir_random, aug_adj_bi_random = get_adj_of_aug_mag(admg_random, n, target_dict)
            aug_adj_soft_dir_random, aug_adj_soft_bi_random = get_adj_of_aug_mag_soft(admg_random, n, target_dict)

            # Verify equivalence
            if verify_MAG_equi_adj(aug_adj_dir, aug_adj_bi, aug_adj_dir_random, aug_adj_bi_random):
                c_hard += 1
            if verify_MAG_equi_adj(aug_adj_soft_dir, aug_adj_soft_bi, aug_adj_soft_dir_random, aug_adj_soft_bi_random):
                c_soft += 1

        c_hard_list.append(c_hard)
        c_soft_list.append(c_soft)

    # ---- Summary ---------------------------------------------------------
    # Average count (over experiments) of equivalent random ADMGs, reported
    # both as "count out of n_admg" and as a fraction.
    mean_hard = np.mean(c_hard_list)
    mean_soft = np.mean(c_soft_list)
    print('{} out of {} ADMGs are valid under hard interventions (fraction: {})'.format(
        mean_hard, n_admg, mean_hard / n_admg))
    print('{} out of {} ADMGs are valid under soft interventions (fraction: {})'.format(
        mean_soft, n_admg, mean_soft / n_admg))
    if mean_soft > 0:
        print('Ratio of valid ADMGs under hard interventions to valid ADMGs under soft interventions: ',
              mean_hard / mean_soft)
    else:
        # Avoid a 0/0 or x/0 division when no ADMG matched under soft interventions.
        print('Ratio of valid ADMGs under hard interventions to valid ADMGs under soft interventions: '
              'undefined (no ADMG was valid under soft interventions)')

    # Save the raw counts together with the parameters needed to interpret them.
    results = {
        'hard': c_hard_list,
        'soft': c_soft_list,
        'n_admg': n_admg,
        'n': n,
        'k': k,
        'epsilon': epsilon,
        'delta': delta,
    }
    with open('bound_results_{}.pkl'.format(n), 'wb') as f:
        pickle.dump(results, f)