
import numpy as np
from rich import print as rprint

from order_notears import utils
# from notears.linear import notears_linear
from order_notears.linear import notears_linear
from order_notears.nonlinear import NotearsMLP, notears_nonlinear
import argparse
import random
import time
from order_notears.transitive import get_paths, add_path_to_matrix


def GT_simu(n, d, s0, graph_type, noise_type, nonlinear=False):
    """Simulate a ground-truth DAG and a data matrix sampled from an SEM on it.

    Args:
        n: forwarded to the SEM simulator; presumably the number of samples.
        d: forwarded to ``utils.simulate_dag``; presumably the number of nodes.
        s0: forwarded to ``utils.simulate_dag``; presumably the expected
            number of edges -- TODO confirm against ``order_notears.utils``.
        graph_type: random-graph family name forwarded to ``utils.simulate_dag``.
        noise_type: noise distribution name forwarded to the SEM simulator.
        nonlinear: if true, sample with ``utils.simulate_nonlinear_sem``;
            otherwise with ``utils.simulate_linear_sem``.

    Returns:
        tuple: ``(dag_true, X)`` -- the simulated graph and the sampled data.
    """
    dag_true = utils.simulate_dag(d, s0, graph_type)
    simulate_sem = utils.simulate_nonlinear_sem if nonlinear else utils.simulate_linear_sem
    samples = simulate_sem(dag_true, n, noise_type)
    return dag_true, samples

def notears(X, nonlinear=False, w_ord=None):
    """Fit NOTEARS to a data matrix and return the estimated graph.

    The number of variables is taken from ``X.shape[1]``. The result is not
    checked for acyclicity here.

    Args:
        X: data matrix whose columns are the variables.
        nonlinear: if true, fit a ``NotearsMLP`` with layer sizes
            ``[n_vars, 10, 1]`` via ``notears_nonlinear`` (lambda1=0.03,
            lambda2=0.03); otherwise call ``notears_linear`` with an L2 loss
            and lambda1=0.3.
        w_ord: optional ordering constraint forwarded to the solver (the
            script passes either ``None`` or an n_vars x n_vars mask built
            from known chains). Both branches use ``tau=3``.

    Returns:
        The estimated (weighted) adjacency matrix produced by the solver.
    """
    num_vars = X.shape[1]
    if not nonlinear:
        return notears_linear(X, lambda1=0.3, loss_type='l2', w_ord=w_ord, tau=3)
    mlp = NotearsMLP(dims=[num_vars, 10, 1], w_ord=w_ord, bias=True, tau=3)
    return notears_nonlinear(mlp, X, lambda1=0.03, lambda2=0.03)
    
def test_linear():
    """Build the command lines for the linear-SEM experiment sweep.

    Enumerates every combination of node count ``d``, per-node sample
    multiplier ``n``, graph type, noise type, edge/node ratio and
    known-ordering fraction. Each setting gets six data seeds, and six
    ordering seeds whenever some ordering is supplied (one otherwise). Each
    entry is a shell command for ``exp_1.py``, a separate script whose flags
    differ from the ones parsed in this file.

    Returns:
        list[str]: the commands, in nested-loop order
        (d, n, graph_type, noise_type, edge_ratio, ordering fraction,
        data seed, ordering seed).
    """
    import itertools

    d_lists = [10, 30, 50, 100]
    n_lists = [2, 1000]
    graph_types = ['ER', 'SF']
    noise_types = ['gauss', 'exp', 'gumbel']
    edge_node_ratio = [2, 4]
    ordering_num = [0, 0.5, 0.75, 1]

    commands = []
    for d, n, graph_type, noise_type, edge_ratio, ord_num in itertools.product(
            d_lists, n_lists, graph_types, noise_types, edge_node_ratio, ordering_num):
        # n is a per-node multiplier; total samples are capped at 1000.
        sample_size = min(1000, d * n)
        # Passed to exp_1.py as --ordering_num.
        ord_n = int(ord_num * d)
        # With no ordering information the ordering seed presumably has no
        # effect, so a single seed suffices.
        ord_iter = 1 if ord_n == 0 else 6
        for i in range(6):
            for j in range(ord_iter):
                # No commas between flags: argparse would otherwise read e.g. "5," as the value.
                commands.append(
                    f"python exp_1.py --d={d} --sample_size={sample_size} "
                    f"--graph_type={graph_type} --noise_type={noise_type} "
                    f"--edge_node_ratio={edge_ratio} --ordering_num={ord_n} "
                    f"--data_seed={i} --ordering_seed={j}"
                )
    return commands


def topological_sort(adj_matrix):
    """Return the nodes of a directed graph in topological order.

    The graph is a square adjacency matrix in which a nonzero entry
    ``adj_matrix[u, v]`` denotes an edge ``u -> v``, so ``u`` is placed before
    ``v`` in the result. Depth-first search starts from roots ``0, 1, 2, ...``
    and visits children in increasing index order. The reversed post-order is
    returned.

    Any nonzero entry counts as an edge, so weighted adjacency matrices are
    handled; for a 0/1 matrix the result matches a plain recursive DFS. The
    traversal is iterative, so long chains do not hit Python's recursion
    limit.

    No cycle check is performed: for a cyclic graph the result is a DFS
    ordering, not a valid topological order.

    Args:
        adj_matrix: square array-like of shape (num_nodes, num_nodes).

    Returns:
        list[int]: node indices, sources first.
    """
    adj = np.asarray(adj_matrix)
    num_nodes = adj.shape[0]
    visited = [False] * num_nodes
    postorder = []

    def children_of(node):
        # Out-neighbours in increasing index order, as plain Python ints.
        return iter(np.flatnonzero(adj[node]).tolist())

    for root in range(num_nodes):
        if visited[root]:
            continue
        visited[root] = True
        # Each frame pairs a node with a resumable iterator over its children.
        # This mirrors where a recursive DFS would continue after a call returns.
        stack = [(root, children_of(root))]
        while stack:
            node, children = stack[-1]
            for child in children:
                if not visited[child]:
                    visited[child] = True
                    stack.append((child, children_of(child)))
                    break
            else:
                # All children explored: the node is finished.
                stack.pop()
                postorder.append(node)
    return postorder[::-1]

def select_partial_orderings_by_chain(path_num, dag, proportion=True):
    """Return the longest paths of ``dag`` (as computed by ``get_paths``).

    Args:
        path_num: how many paths to keep. If ``proportion`` is true this is a
            fraction of all paths and the count is ``int(len(paths) * path_num)``;
            otherwise it is used directly as a count.
        dag: graph forwarded to ``get_paths``.
        proportion: selects how ``path_num`` is interpreted.

    Returns:
        list: the selected paths, longest first. Note that the list returned
        by ``get_paths`` is sorted in place.
    """
    candidates = get_paths(dag)
    # Stable sort: equal-length paths keep the order get_paths produced.
    candidates.sort(key=len, reverse=True)
    if not proportion:
        return candidates[:path_num]
    return candidates[:int(len(candidates) * path_num)]

def select_sublists(data, k, m):
    """Draw ``k`` random order-preserving subsequences of length ``m`` from ``data``.

    Each subsequence picks ``m`` distinct positions with ``random.sample`` and
    keeps the chosen items in their original relative order.

    Args:
        data: indexable sequence to draw from.
        k: number of subsequences to draw.
        m: length of each subsequence.

    Returns:
        list[list]: ``k`` subsequences of ``data``.

    Raises:
        ValueError: if ``m`` exceeds ``len(data)``.
    """
    size = len(data)
    if size < m:
        raise ValueError("m cannot be larger than the length of the list")
    return [
        [data[pos] for pos in sorted(random.sample(range(size), m))]
        for _ in range(k)
    ]

def select_partial_orderings(total_ordering, ordering_length, number):
    """Draw random partial orderings that are consistent with ``total_ordering``.

    Each partial ordering is ``ordering_length`` distinct elements sampled with
    ``random.sample`` and re-sorted by their position in ``total_ordering``,
    so it never contradicts the total order.

    Args:
        total_ordering: sequence giving the full order.
        ordering_length: number of elements in each partial ordering.
        number: how many partial orderings to draw.

    Returns:
        list[list]: ``number`` partial orderings.
    """
    return [
        sorted(random.sample(total_ordering, ordering_length), key=total_ordering.index)
        for _ in range(number)
    ]


def main(argv=None):
    """Run NOTEARS on a real-data benchmark and print accuracy metrics.

    Reads the ground-truth adjacency matrix (as ints) from
    ``real_data/<d>_graph.txt`` and the data matrix from
    ``real_data/<d>_<sample_size>.txt``. ``--use_size`` keeps only the first
    rows of the data; 0, the default, keeps all of them.

    When ``chain_num * chain_size`` is nonzero, ``chain_num`` random
    order-preserving chains of ``chain_size`` nodes are drawn from a
    topological order of the true graph. They are merged with
    ``add_path_to_matrix`` into a d x d mask passed to ``notears`` as
    ``w_ord``.

    Prints the run settings, then the metrics from ``utils.count_accuracy``
    plus the fit time in seconds under the ``"time"`` key.

    Args:
        argv: command-line arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Graph Simulation Settings")

    parser.add_argument('--d', type=str, required=True, help="dataset")
    parser.add_argument('--sample_size', type=int, required=True, help="Sample size for data")
    parser.add_argument('--use_size', type=int, default=0, help="Used size for data")
    parser.add_argument('--linear_flag', action='store_true', help="Flag to indicate if a linear model should be used")
    parser.add_argument('--data_index', type=int, default=123, help="data index")
    parser.add_argument('--ordering_seed', type=int, default=123, help="Seed for random ordering generation")
    parser.add_argument('--chain_num', type=int, default=0, help="Number of known-ordering chains to sample")
    parser.add_argument('--chain_size', type=int, default=0, help="Number of variables in each known-ordering chain")
    # NOTE(review): parsed but not used anywhere in this script.
    parser.add_argument('--ordering_num', type=float, default=0, help="Percentage of order pairs")

    args = parser.parse_args(argv)
    dataset, n = args.d, args.sample_size
    linear_flag, data_index = args.linear_flag, args.data_index
    chain_num, chain_size = args.chain_num, args.chain_size
    ordering_seed = args.ordering_seed
    use_size = args.use_size

    nonlinear = not linear_flag

    dag_true = np.loadtxt(f'real_data/{dataset}_graph.txt', dtype=int)
    X = np.loadtxt(f'real_data/{dataset}_{n}.txt')
    d = dag_true.shape[0]
    if use_size:
        X = X[:use_size, :]

    # Seed only the ordering selection; the seed is reset before fitting.
    utils.set_random_seed(ordering_seed)
    if chain_num * chain_size == 0:
        ordering_mask = None
    else:
        total_ordering = topological_sort(dag_true)
        ordering_chains = select_sublists(total_ordering, chain_num, chain_size)
        ordering_mask = np.zeros((d, d))
        for chain in ordering_chains:
            ordering_mask = add_path_to_matrix(ordering_mask, chain)

    # NOTE(review): presumably reseeds from entropy -- confirm in order_notears.utils.
    utils.set_random_seed(None)
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    start = time.perf_counter()
    dag_est = notears(X, nonlinear=nonlinear, w_ord=ordering_mask)
    elapsed = time.perf_counter() - start

    acc = utils.count_accuracy(dag_true, dag_est != 0)
    acc["time"] = elapsed

    print(f"{dataset}, {n},  {linear_flag}, {use_size},  {chain_num}x{chain_size}, {data_index}, {ordering_seed}\n {acc}")


if __name__ == "__main__":
    main()
