import numpy as np
from scipy.special import expit as sigmoid
import igraph as ig
import random
import torch
import networkx as nx
import matplotlib.pyplot as plt
# import pydot
from IPython.display import Image
import pydotplus
import copy
import json
import os

############### linear gaussian ################
#####################case 1 (cause)#####################
def adjacency_matrix_to_dot(adjacency_matrix, node_names=None, selection=None):
    """Build a pydotplus graph from an adjacency matrix plus optional selection pairs.

    Args:
        adjacency_matrix: square matrix indexed as ``adjacency_matrix[i, j]``;
            an entry equal to 1 produces an edge from node ``i`` to node ``j``.
        node_names: optional sequence of labels, one per node. When it is
            missing or empty the ``label`` attribute is passed as ``None``.
        selection: optional sequence of node pairs. When non-empty, an extra
            node named ``'s'`` is added and both members of every pair get an
            edge into it. Defaults to no selection.

    Returns:
        pydotplus.Dot: the assembled graph.
    """
    # A ``None`` default avoids sharing one mutable list across calls.
    if selection is None:
        selection = []

    graph = pydotplus.Dot(directed=True)

    num_nodes = len(adjacency_matrix)
    for i in range(num_nodes):
        label = str(node_names[i]) if node_names else None
        graph.add_node(pydotplus.Node(str(i), label=label))

    for i in range(num_nodes):
        for j in range(num_nodes):
            if adjacency_matrix[i, j] == 1:
                graph.add_edge(pydotplus.Edge(str(i), str(j)))

    # ``len`` (not truthiness) so that numpy arrays of pairs are accepted too.
    if len(selection) != 0:
        graph.add_node(pydotplus.Node('s', label='s' if node_names else None))
        for pair in selection:
            graph.add_edge(pydotplus.Edge(str(pair[0]), 's'))
            graph.add_edge(pydotplus.Edge(str(pair[1]), 's'))

    return graph

def simulate_dag(d, s0, graph_type, number_of_edge):
    """Draw a random DAG and return its adjacency matrix with shuffled node labels.

    Args:
        d (int): number of nodes.
        s0 (int): number of edges requested for the bipartite ('BP') graph;
            not read by the other graph types.
        graph_type (str): one of 'ER', 'SF', 'BP'.
        number_of_edge (int): number of edges requested for the Erdos-Renyi
            ('ER') graph; not read by the other graph types.

    Returns:
        np.ndarray: [d, d] adjacency matrix of the sampled DAG.

    Raises:
        ValueError: if ``graph_type`` is not one of the supported names.
    """
    def _shuffle_nodes(mat):
        # np.random.permutation shuffles rows only, so permuting the identity
        # yields a random permutation matrix.
        perm = np.random.permutation(np.eye(mat.shape[0]))
        return perm.T @ mat @ perm

    def _to_adjacency(graph):
        return np.array(graph.get_adjacency().data)

    if graph_type not in ('ER', 'SF', 'BP'):
        raise ValueError('unknown graph type')

    if graph_type == 'ER':
        # Erdos-Renyi: sample an undirected graph, relabel it, then keep the
        # strictly lower triangle so the orientation is acyclic.
        undirected = _to_adjacency(ig.Graph.Erdos_Renyi(n=d, m=number_of_edge))
        B = np.tril(_shuffle_nodes(undirected), k=-1)
    elif graph_type == 'SF':
        # Scale-free, Barabasi-Albert
        B = _to_adjacency(ig.Graph.Barabasi(n=d, m=d, directed=True))
    else:
        # Bipartite, Sec 4.1 of (Gu, Fu, Zhou, 2018)
        n_top = int(0.2 * d)
        bipartite = ig.Graph.Random_Bipartite(
            n_top, d - n_top, m=s0, directed=True, neimode=ig.OUT
        )
        B = _to_adjacency(bipartite)

    B_perm = _shuffle_nodes(B)
    # Sanity check: relabelling nodes must not introduce a cycle.
    assert ig.Graph.Adjacency(B_perm.tolist()).is_dag()

    return B_perm

def is_dag(W):
    """Return True when the igraph graph built from adjacency matrix ``W`` is acyclic."""
    return ig.Graph.Weighted_Adjacency(W.tolist()).is_dag()

def S(a, b, x, y, mean, var):
    """Evaluate a noisy linear selection variable ``a * x + b * y + noise``.

    Args:
        a, b: coefficients applied to ``x`` and ``y``.
        x, y: arrays of values; ``x.shape`` sets the shape of the noise.
        mean: location passed to ``np.random.normal``.
        var: passed as the second argument of ``np.random.normal`` (its
            ``scale`` parameter), despite the name.

    Returns:
        The element-wise sum of the linear term and freshly drawn noise. Every
        call draws new noise, so repeated calls give different results.
    """
    noise = np.random.normal(mean, var, size=x.shape)
    linear_term = a * x + b * y
    return linear_term + noise


def my_simulate_linear_gaussian(W, n, s, min_kept=2000, max_kept=26000, max_gap=14000):
    """Simulate a linear Gaussian SEM with selection bias and hard interventions.

    Observational samples are drawn along a topological order of ``W``. Then
    ``s`` random node pairs each define a noisy linear selection variable
    (see ``S``); only rows whose selection variables all exceed a threshold
    are kept. The threshold is the 25th percentile of that variable when
    ``s > 1`` and its median otherwise. For every node a fresh data set is
    simulated with that node clamped to 0 and filtered with the same
    thresholds.

    Args:
        W (np.ndarray): [d, d] binary adjacency matrix of a DAG. It is handed
            to igraph, so ``W[i, j] != 0`` is read as an edge ``i -> j``.
        n (int): number of samples drawn before selection.
        s (int): number of selection variables (random node pairs). With
            ``s == 0`` no selection is applied and all rows are kept.
        min_kept (int): reject when fewer rows than this survive selection.
        max_kept (int): reject when more rows than this survive selection.
        max_gap (int): reject when the number of rows surviving in an
            interventional data set differs from the observational count by
            more than this.

    Returns:
        dict | int: ``0`` when any acceptance check fails. Otherwise a dict with
            'obs'        selected observational samples,
            'per_<i>'    selected samples with node ``i`` clamped to 0,
            'ske'        symmetric skeleton of ``W`` plus the selection pairs,
            'selection'  list of the selected node pairs,
            'order_v'    topological order used for sampling,
            'dag'        ``W`` itself.

    Raises:
        ValueError: if ``W`` is not a DAG.
    """
    d = W.shape[0]
    if not is_dag(W):
        raise ValueError('W must be a DAG')

    # Draw the selection mechanisms: a node pair, two coefficients and the
    # (mean, scale) of the selection noise for each of the s mechanisms.
    nodes = list(range(d))
    selection = []
    cofs = []
    s_m_v = []
    for _ in range(s):
        selection.append(random.sample(nodes, k=2))
        s_a = random.uniform(0, 3)
        s_b = random.uniform(-1, 3)
        s_var = np.random.uniform(low=1.0, high=3.0)
        s_mean = np.random.uniform(low=0.0, high=3.0)
        cofs.append([s_a, s_b])
        s_m_v.append([s_mean, s_var])

    G = ig.Graph.Weighted_Adjacency(W.tolist())
    ordered_vertices = G.topological_sorting()
    assert len(ordered_vertices) == d

    # Observational data: each node is a weighted sum of its parents plus
    # Gaussian noise whose parameters are remembered for the interventions.
    X = np.zeros([n, d])
    b = np.random.uniform(low=0.5, high=5.5, size=(d, d)) * W
    noise_params = {}
    for j in ordered_vertices:
        parents = G.neighbors(j, mode=ig.IN)
        var = np.random.uniform(low=1.0, high=3.0)
        mean = np.random.uniform(low=0.0, high=4.0)
        eps = np.random.normal(mean, var, size=n)
        noise_params[j] = (mean, var)
        X[:, j] = X[:, parents] @ b[parents, j] + eps

    # Starting from an all-True mask makes s == 0 keep every row.
    keep = np.ones(n, dtype=bool)
    selection_value = []
    for (s1, s2), (node_x, node_y), (s_mean, s_scale) in zip(cofs, selection, s_m_v):
        # S draws fresh noise on every call, so evaluate it once and reuse the
        # same realisation for the threshold, the mask and the diagnostics.
        values = S(s1, s2, X[:, node_x], X[:, node_y], s_mean, s_scale)
        if s > 1:
            thres_value = np.percentile(values, 25)
        else:
            thres_value = np.median(values)
        selection_value.append(thres_value)
        keep &= values > thres_value
        print(np.max(values))
        print(np.min(values))
        kept = int(np.sum(keep))
        print(kept)
        if kept < min_kept or kept > max_kept:
            print('#########')
            return 0

    obs_count = int(np.sum(keep))
    data = {'obs': X[keep]}

    for i in ordered_vertices:
        X_per = np.zeros([n, d])
        for j in ordered_vertices:
            if j == i:
                # Hard intervention: the target stays at its initial value 0.
                continue
            parents = G.neighbors(j, mode=ig.IN)
            mean, var = noise_params[j]
            eps = np.random.normal(mean, var, size=n)
            X_per[:, j] = X_per[:, parents] @ b[parents, j] + eps

        per_keep = np.ones(n, dtype=bool)
        for (s1, s2), (node_x, node_y), (s_mean, s_scale), value in zip(
                cofs, selection, s_m_v, selection_value):
            # Reuse the thresholds fixed on the observational data.
            per_keep &= S(s1, s2, X_per[:, node_x], X_per[:, node_y], s_mean, s_scale) > value
            per_count = int(np.sum(per_keep))
            print(per_count)
            if (per_count < min_kept
                    or abs(per_count - obs_count) > max_gap
                    or per_count > max_kept):
                return 0

        data[f'per_{i}'] = X_per[per_keep]

    # Skeleton: undirected version of W plus one edge per selection pair.
    ske = np.zeros_like(W)
    ske[(W == 1) | (W.T == 1)] = 1
    for node_x, node_y in selection:
        ske[node_x, node_y] = ske[node_y, node_x] = 1

    data['ske'] = ske
    data['selection'] = selection
    data['order_v'] = ordered_vertices
    data['dag'] = W
    return data

def my_simulate_linear_gaussian_soft(W, n, s, min_kept=2000, max_kept=26000, max_gap=14000):
    """Simulate a linear Gaussian SEM with selection bias and soft interventions.

    Observational samples are drawn along a topological order of ``W``. Then
    ``s`` random node pairs each define a noisy linear selection variable
    (see ``S``); only rows whose selection variables all lie strictly between
    0 and 20 are kept. For every node a fresh data set is simulated in which
    that node additionally receives a Uniform(3, 6) shift, and it is filtered
    with the same band.

    Args:
        W (np.ndarray): [d, d] binary adjacency matrix of a DAG. It is handed
            to igraph, so ``W[i, j] != 0`` is read as an edge ``i -> j``.
        n (int): number of samples drawn before selection.
        s (int): number of selection variables (random node pairs). With
            ``s == 0`` no selection is applied and all rows are kept.
        min_kept (int): reject when fewer rows than this survive selection.
        max_kept (int): reject when more rows than this survive selection.
        max_gap (int): reject when the number of rows surviving in an
            interventional data set differs from the observational count by
            more than this.

    Returns:
        dict | int: ``0`` when any acceptance check fails. Otherwise a dict with
            'obs'        selected observational samples,
            'per_<i>'    selected samples with node ``i`` shifted,
            'ske'        symmetric skeleton of ``W`` plus the selection pairs,
            'selection'  list of the selected node pairs,
            'order_v'    topological order used for sampling,
            'dag'        ``W`` itself.

    Raises:
        ValueError: if ``W`` is not a DAG.
    """
    d = W.shape[0]
    if not is_dag(W):
        raise ValueError('W must be a DAG')

    # Draw the selection mechanisms: a node pair, two coefficients and the
    # (mean, scale) of the selection noise for each of the s mechanisms.
    nodes = list(range(d))
    selection = []
    cofs = []
    s_m_v = []
    for _ in range(s):
        selection.append(random.sample(nodes, k=2))
        s_a = random.uniform(0, 3)
        s_b = random.uniform(-1, 3)
        s_var = np.random.uniform(low=1.0, high=2.0)
        s_mean = np.random.uniform(low=0.0, high=3.0)
        cofs.append([s_a, s_b])
        s_m_v.append([s_mean, s_var])

    G = ig.Graph.Weighted_Adjacency(W.tolist())
    ordered_vertices = G.topological_sorting()
    assert len(ordered_vertices) == d

    # Observational data: each node is a weighted sum of its parents plus
    # Gaussian noise whose parameters are remembered for the interventions.
    X = np.zeros([n, d])
    b = np.random.uniform(low=0.5, high=4.5, size=(d, d)) * W
    noise_params = {}
    for j in ordered_vertices:
        parents = G.neighbors(j, mode=ig.IN)
        var = np.random.uniform(low=1.0, high=3.0)
        mean = np.random.uniform(low=0.0, high=3.0)
        eps = np.random.normal(mean, var, size=n)
        noise_params[j] = (mean, var)
        X[:, j] = X[:, parents] @ b[parents, j] + eps

    # Starting from an all-True mask makes s == 0 keep every row.
    o_index = np.ones(n, dtype=bool)
    for (s1, s2), (node_x, node_y), (s_mean, s_scale) in zip(cofs, selection, s_m_v):
        # S draws fresh noise on every call; evaluate it once so that both
        # bounds of the band are applied to the same realisation.
        values = S(s1, s2, X[:, node_x], X[:, node_y], s_mean, s_scale)
        o_index &= (values > 0) & (values < 20)
        kept = int(np.sum(o_index))
        print(kept)
        if kept < min_kept or kept > max_kept:
            print('negtive return 0')
            return 0

    obs_count = int(np.sum(o_index))
    data = {'obs': X[o_index]}

    for i in ordered_vertices:
        X_per = np.zeros([n, d])
        for j in ordered_vertices:
            parents = G.neighbors(j, mode=ig.IN)
            mean, var = noise_params[j]
            eps = np.random.normal(mean, var, size=n)
            X_per[:, j] = X_per[:, parents] @ b[parents, j] + eps
            if j == i:
                # Soft intervention: keep the mechanism, add a uniform shift.
                X_per[:, j] += np.random.uniform(3, 6, size=n)
        print('finish per')

        per_index = np.ones(n, dtype=bool)
        for (s1, s2), (node_x, node_y), (s_mean, s_scale) in zip(cofs, selection, s_m_v):
            values = S(s1, s2, X_per[:, node_x], X_per[:, node_y], s_mean, s_scale)
            per_index &= (values > 0) & (values < 20)
            per_count = int(np.sum(per_index))
            print(per_count)
            if (per_count < min_kept
                    or abs(per_count - obs_count) > max_gap
                    or per_count > max_kept):
                print('interven does not mathc')
                return 0

        X_per = X_per[per_index]
        print(X_per.shape)
        data[f'per_{i}'] = X_per

    # Skeleton: undirected version of W plus one edge per selection pair.
    ske = np.zeros_like(W)
    ske[(W == 1) | (W.T == 1)] = 1
    for node_x, node_y in selection:
        ske[node_x, node_y] = ske[node_y, node_x] = 1

    data['ske'] = ske
    data['selection'] = selection
    data['order_v'] = ordered_vertices
    data['dag'] = W
    return data

# Driver: repeatedly sample a random DAG, simulate selection-biased data for it
# and write every accepted data set (graph picture + .npz) to its own folder.
num_of_sample = 8000
tiems = 0  # number of accepted data sets written to disk
interven = 'hard'

# One simulator per intervention type; both share the same call signature.
simulators = {
    'hard': my_simulate_linear_gaussian,
    'soft': my_simulate_linear_gaussian_soft,
}
if interven not in simulators:
    # Fail loudly instead of counting iterations that produced nothing.
    raise ValueError(f'unknown intervention type: {interven!r}')
simulate = simulators[interven]

for i in range(num_of_sample):
    d = 10
    number_of_edge = d-1
    graph_type, sem_type = 'ER', 'gauss'
    true_DAG_bin = simulate_dag(d, d, graph_type, number_of_edge) # ground-truth binary matrix

    sample_size = 30000
    number_of_selection = 4
    node_name = [str(k) for k in range(d)]

    data = simulate(true_DAG_bin, sample_size, number_of_selection)
    if not isinstance(data, dict):
        # The simulators return 0 when a draw fails their acceptance checks.
        if interven == 'soft':
            print('does not count continue')
        continue
    if interven == 'soft':
        print('######')

    dot_graph = adjacency_matrix_to_dot(true_DAG_bin, node_name, data['selection'])
    file_path = f'./selection/v_{d}/{interven}/select_{number_of_selection}/sample_{i}'
    # exist_ok=True already tolerates an existing directory.
    os.makedirs(file_path, exist_ok=True)
    dot_graph.write_png(os.path.join(file_path, f'graph_{interven}.png'))
    np.savez(os.path.join(file_path, f'sample_{interven}.npz'), **data)

    tiems += 1

print(tiems)


