import numpy as np
from scipy.special import expit as sigmoid
from itertools import combinations
import igraph as ig
import random
import torch
import networkx as nx
import matplotlib.pyplot as plt
# import pydot
from IPython.display import Image
import pydotplus
import copy
import json
import os

############### linear gaussian ################
#####################case 1 (cause)#####################
def adjacency_matrix_to_dot(adjacency_matrix, node_names=None, latent=None, selection=None):
    """Build a pydotplus directed graph from an adjacency matrix.

    Args:
        adjacency_matrix (np.ndarray): [d, d] matrix; an entry equal to 1 at
            [i, j] is drawn as an edge i -> j.
        node_names (list | None): labels for nodes 0..d-1. When falsy, every
            node (including the extra 's' / 'L' nodes) gets label=None.
        latent (list[list[int]] | None): each inner list holds node indices
            that receive an edge from one shared node 'L'.
        selection (list[list[int]] | None): node pairs [x, y]; both endpoints
            of every pair get an edge into one shared node 's'.

    Returns:
        pydotplus.Dot: the assembled (not yet rendered) graph.
    """
    # None defaults instead of [] avoid sharing one mutable list across calls.
    if latent is None:
        latent = []
    if selection is None:
        selection = []

    graph = pydotplus.Dot(directed=True)

    num_nodes = len(adjacency_matrix)
    for i in range(num_nodes):
        graph.add_node(pydotplus.Node(str(i), label=str(node_names[i]) if node_names else None))

    for i in range(num_nodes):
        for j in range(num_nodes):
            if adjacency_matrix[i, j] == 1:
                graph.add_edge(pydotplus.Edge(str(i), str(j)))

    if len(selection) != 0:
        graph.add_node(pydotplus.Node('s', label='s' if node_names else None))
        for pair in selection:
            graph.add_edge(pydotplus.Edge(str(pair[0]), 's'))
            graph.add_edge(pydotplus.Edge(str(pair[1]), 's'))

    if len(latent) != 0:
        graph.add_node(pydotplus.Node('L', label='L' if node_names else None))
        for children in latent:
            for child in children:
                graph.add_edge(pydotplus.Edge('L', str(child)))
    return graph

def simulate_dag(d, num_of_edge, graph_type):
    """Simulate a random DAG with a target number of edges.

    Args:
        d (int): number of nodes.
        num_of_edge (int): target number of edges. Passed as the edge count
            ``m`` for 'ER' and 'BP'; for 'SF' every new node attaches
            ``round(num_of_edge / d)`` edges, giving roughly num_of_edge in total.
        graph_type (str): 'ER' (Erdos-Renyi), 'SF' (scale-free,
            Barabasi-Albert) or 'BP' (bipartite, Sec 4.1 of Gu, Fu, Zhou, 2018).

    Returns:
        np.ndarray: [d, d] binary adjacency matrix of a DAG whose node labels
        have been randomly permuted.

    Raises:
        ValueError: if graph_type is not one of the values above.
    """
    def _random_permutation(M):
        # np.random.permutation permutes the first axis only, so P is a random
        # permutation matrix and P.T @ M @ P relabels the nodes of M.
        P = np.random.permutation(np.eye(M.shape[0]))
        return P.T @ M @ P

    def _random_acyclic_orientation(B_und):
        # Keeping the strict lower triangle of a relabelled undirected graph
        # orients every edge along one node order, which makes it acyclic.
        return np.tril(_random_permutation(B_und), k=-1)

    def _graph_to_adjmat(G):
        return np.array(G.get_adjacency().data)

    if graph_type == 'ER':
        G_und = ig.Graph.Erdos_Renyi(n=d, m=num_of_edge)
        B = _random_acyclic_orientation(_graph_to_adjmat(G_und))
    elif graph_type == 'SF':
        # m is the number of edges each new node attaches; scale it so the
        # total is about num_of_edge instead of a near-complete graph.
        G = ig.Graph.Barabasi(n=d, m=int(round(num_of_edge / d)), directed=True)
        B = _graph_to_adjmat(G)
    elif graph_type == 'BP':
        top = int(0.2 * d)
        G = ig.Graph.Random_Bipartite(top, d - top, m=num_of_edge, directed=True, neimode=ig.OUT)
        B = _graph_to_adjmat(G)
    else:
        raise ValueError('unknown graph type')
    B_perm = _random_permutation(B)
    # Internal invariant: relabelling nodes can never introduce a cycle.
    assert ig.Graph.Adjacency(B_perm.tolist()).is_dag()

    return B_perm

def is_dag(W):
    """Return True when the adjacency matrix ``W`` describes an acyclic digraph.

    The matrix is handed to igraph's ``Weighted_Adjacency`` constructor and
    the acyclicity test is delegated to igraph.
    """
    return ig.Graph.Weighted_Adjacency(W.tolist()).is_dag()

def S(a, b, x, y, mean, var):
    """Selection score ``a * x + b * y`` plus Gaussian noise.

    Args:
        a, b (float): coefficients applied to x and y.
        x, y (np.ndarray): equally shaped sample arrays.
        mean (float): mean of the additive noise.
        var (float): passed to ``np.random.normal`` as its scale, so it acts as
            the noise standard deviation rather than the variance.

    Returns:
        np.ndarray: array shaped like ``x``; every call draws fresh noise.
    """
    noise = np.random.normal(mean, var, size=x.shape)
    linear_part = a * x + b * y
    return linear_part + noise


def my_simulate_linear_gaussian(W, n, n_L,n_S, L, Sel, under_latent):
    """Simulate observational and hard-intervention data from a linear Gaussian
    SEM with latent confounders and selection bias.

    Nodes are generated in topological order as
    ``X_j = sum_{p in pa(j)} b[p, j] * X_p + eps_j`` with edge weights drawn from
    U(0.5, 5.5) and ``eps_j = np.random.normal(mean_j, var_j)``; note that
    ``var_j`` is passed as the *scale* (standard deviation). Each selection pair
    ``[x, y]`` in ``Sel`` keeps only rows whose score ``S(a, b, X_x, X_y, ...)``
    exceeds a percentile threshold. For every observed node one additional
    dataset is generated in which that node is held at 0 (hard intervention).

    Args:
        W (np.ndarray): [d, d] binary adjacency matrix of a DAG, converted to a
            graph with igraph's Weighted_Adjacency.
        n (int): number of samples drawn per regime before selection.
        n_L (int): number of latent confounders; only stored in the output.
        n_S (int): number of selection mechanisms. Presumably len(Sel) -- the
            zip() calls below silently truncate otherwise.
        L (list[int]): indices of latent nodes; their columns are dropped.
        Sel (list[list[int]]): node pairs [x, y] driving selection.
        under_latent (list[list[int]]): node lists whose pairs are all marked
            adjacent in the skeleton (the children of each latent, as built by
            constrain()).

    Returns:
        dict | int: 0 when the sample is rejected (fewer than 2000 rows survive
        selection, or an interventional dataset keeps a row count differing
        from the observational one by more than 16000). Otherwise a dict with
        keys 'obs', 'per_<k>' (k = position of the intervened node in 'node'),
        'ske', 'num_L', 'num_S', 'latent', 'selection', 'order_v', 'node',
        'dag' and 'Gdag'.

    Raises:
        ValueError: if W is not a DAG.
    """

    data = {}
    perturb_list = []  # NOTE(review): appended to below but never read
    d = W.shape[0]
    index = [i for i in range(d)]
    if not is_dag(W):
        raise ValueError('W must be a DAG')
    # Draw every selection mechanism's parameters up front:
    # cofs[k] = [a, b] coefficients, s_m_v[k] = [mean, scale] of its noise.
    cofs = []
    m_v = []  # [mean, scale] of each node's noise, in topological order
    s_m_v = []
    for i in range(n_S):
        s_a = random.uniform(0,3)
        s_b = random.uniform(-1,3)
        s_var = np.random.uniform(low=1.0, high=3.0)
        s_mean = np.random.uniform(low=0.0, high=3.0)
        cofs.append([s_a,s_b])
        s_m_v.append([s_mean, s_var])

    G = ig.Graph.Weighted_Adjacency(W.tolist())
    ordered_vertices = G.topological_sorting()

    # Observed columns = every node that is not listed as latent.
    col_index = [x for x in index if x not in L]

    assert len(ordered_vertices) == d
    X = np.zeros([n, d])
    # Edge weights ~ U(0.5, 5.5); multiplying by W zeroes the non-edges.
    b = np.random.uniform(low=0.5, high=5.5, size=(d,d))
    b = b * W
    # Observational regime: each node is generated after all of its parents.
    for j in ordered_vertices:
        perturb_list.append(j)
        parents = G.neighbors(j, mode=ig.IN)
        var = np.random.uniform(low=1.0, high=3.0)
        mean = np.random.uniform(low=0.0, high=3.0)
        # `var` is np.random.normal's scale argument, i.e. a standard deviation.
        eps = np.random.normal(mean, var, size=n)
        m_v.append([mean, var])
        X[:,j] = X[:,parents] @ b[parents,j] + eps
    if n_S > 0:
        new_index = []  # becomes the boolean mask of rows kept by all mechanisms
        selection_value = []  # per-mechanism thresholds, reused for 'per_*' data
        for cof, item, noise in zip(cofs,Sel, s_m_v):
            node_x, node_y = item[0], item[1]
            s1,s2 = cof[0], cof[1]
            # Threshold at the 25th percentile with several mechanisms, 40th with one.
            # NOTE(review): S() draws fresh noise on every call, so the threshold and
            # the mask below come from independent noise realisations.
            if n_S > 1:
                thres_value = np.percentile(S(s1,s2 ,X[:,node_x],X[:,node_y], noise[0], noise[1]), 25)
            else:
                thres_value = np.percentile(S(s1,s2 ,X[:,node_x],X[:,node_y], noise[0], noise[1]), 40)
            selection_value.append(thres_value)
            index = S(s1,s2 ,X[:,node_x],X[:,node_y], noise[0], noise[1]) > thres_value

            # A row survives only if it passes every mechanism seen so far.
            if len(new_index)==0:
                new_index = index
            else:
                new_index = new_index & index
            # Reject the whole sample when too few rows survive.
            if np.sum(new_index) < 2000:
                print('#########')
                return 0
        X_o = X[new_index][:,col_index]
        data['obs'] = X_o
    else:
        data['obs'] = X[:,col_index]
    # Hard-intervention regimes: one dataset per observed node i.
    for i in ordered_vertices:
        if i in col_index:
            co_index = col_index.index(i)
            X_per = np.zeros([n, d])
            X_per[:,i] = 0  # do(X_i = 0): column i is skipped in the loop below
            count = 0
            for j in ordered_vertices:
                if j == i:
                    count +=1
                    continue
                parents = G.neighbors(j, mode=ig.IN)
                # m_v was filled in this same topological order, so m_v[count]
                # holds node j's noise parameters; the draws themselves are fresh.
                mean, var = m_v[count][0], m_v[count][1]
                eps = np.random.normal(mean, var, size=n)
                X_per[:,j] = X_per[:,parents] @ b[parents,j] + eps
                count += 1
            if n_S > 0:
                per_index = []
                # Same mechanisms and thresholds as the observational data.
                for cof,item, noise, value in zip(cofs, Sel, s_m_v, selection_value):
                    node_x, node_y = item[0], item[1]
                    s1,s2 = cof[0], cof[1]
                    index = S(s1,s2, X_per[:,node_x], X_per[:,node_y], noise[0], noise[1]) > value
                    if len(per_index) == 0:
                        per_index = index
                    else:
                        per_index = per_index & index
                    print(sum(per_index))
                    # Reject when too few rows survive, or when the kept-row count
                    # differs from the observational one by more than 16000.
                    if np.sum(per_index) < 2000 or (abs(np.sum(per_index) - np.sum(new_index))> 16000):
                        return 0
                X_per = X_per[per_index][:,col_index]
            else:
                X_per = X_per[:,col_index]
            # Keyed by the intervened node's position among the observed columns.
            data[f'per_{co_index}'] = X_per
        else:
            continue
    # Symmetric skeleton over all d nodes: DAG edges in either direction,
    # selection pairs, and every pair inside each under_latent group.
    ske = np.zeros_like(W)
    for i in range(d):
        for j in range(i):
            if W[i][j] == 1 or W[j][i] == 1:
                ske[i][j] = ske[j][i] = 1
    for item in Sel:
        ske[item[0], item[1]] = ske[item[1], item[0]] = 1
    for item in under_latent:
        pair = list(combinations(item, r =2))
        for edge in pair:
            ske[edge[0], edge[1]] = ske[edge[1],edge[0]] = 1
    # Re-express latent groups and selection pairs as positions in col_index
    # (list.index raises ValueError if any of them is a latent node).
    new_under_latent = []
    for pair in under_latent:
        new = []
        for i in range(len(pair)):
            new.append(col_index.index(pair[i]))
        new_under_latent.append(new)

    new_sel = []
    for se in Sel:
        new_sel.append([col_index.index(se[0]), col_index.index(se[1])])
    data['ske'] = ske[col_index][:,col_index]
    data['num_L'] = n_L
    data['num_S'] = n_S
    data['latent'] = new_under_latent
    data['selection'] = new_sel
    data['order_v'] = ordered_vertices
    data['node'] = col_index
    data['dag'] = W[col_index][:,col_index]
    data['Gdag'] = W
    return data


def my_simulate_linear_gaussian_soft(W, n, n_L,n_S, L, Sel, under_latent):
    """Simulate observational and soft-intervention data from a linear Gaussian
    SEM with latent confounders and selection bias.

    Identical to my_simulate_linear_gaussian except for the interventional
    regimes: for every observed node i, all nodes are regenerated and node i's
    equation receives an extra additive term drawn from U(3, 6) per sample,
    instead of node i being held at 0.

    Nodes are generated in topological order as
    ``X_j = sum_{p in pa(j)} b[p, j] * X_p + eps_j`` with edge weights drawn from
    U(0.5, 5.5) and ``eps_j = np.random.normal(mean_j, var_j)``; ``var_j`` is
    passed as the *scale* (standard deviation). Each selection pair ``[x, y]``
    in ``Sel`` keeps only rows whose score ``S(a, b, X_x, X_y, ...)`` exceeds a
    percentile threshold.

    Args:
        W (np.ndarray): [d, d] binary adjacency matrix of a DAG, converted to a
            graph with igraph's Weighted_Adjacency.
        n (int): number of samples drawn per regime before selection.
        n_L (int): number of latent confounders; only stored in the output.
        n_S (int): number of selection mechanisms. Presumably len(Sel) -- the
            zip() calls below silently truncate otherwise.
        L (list[int]): indices of latent nodes; their columns are dropped.
        Sel (list[list[int]]): node pairs [x, y] driving selection.
        under_latent (list[list[int]]): node lists whose pairs are all marked
            adjacent in the skeleton (the children of each latent, as built by
            constrain()).

    Returns:
        dict | int: 0 when the sample is rejected (fewer than 2000 rows survive
        selection, or an interventional dataset keeps a row count differing
        from the observational one by more than 16000). Otherwise a dict with
        keys 'obs', 'per_<k>' (k = position of the intervened node in 'node'),
        'ske', 'num_L', 'num_S', 'latent', 'selection', 'order_v', 'node',
        'dag' and 'Gdag'.

    Raises:
        ValueError: if W is not a DAG.
    """

    data = {}
    perturb_list = []  # NOTE(review): appended to below but never read
    d = W.shape[0]
    index = [i for i in range(d)]
    if not is_dag(W):
        raise ValueError('W must be a DAG')
    # Draw every selection mechanism's parameters up front:
    # cofs[k] = [a, b] coefficients, s_m_v[k] = [mean, scale] of its noise.
    cofs = []
    m_v = []  # [mean, scale] of each node's noise, in topological order
    s_m_v = []
    for i in range(n_S):
        s_a = random.uniform(0,3)
        s_b = random.uniform(-1,3)
        s_var = np.random.uniform(low=1.0, high=3.0)
        s_mean = np.random.uniform(low=0.0, high=3.0)
        cofs.append([s_a,s_b])
        s_m_v.append([s_mean, s_var])

    G = ig.Graph.Weighted_Adjacency(W.tolist())
    ordered_vertices = G.topological_sorting()

    # Observed columns = every node that is not listed as latent.
    col_index = [x for x in index if x not in L]

    assert len(ordered_vertices) == d
    X = np.zeros([n, d])
    # Edge weights ~ U(0.5, 5.5); multiplying by W zeroes the non-edges.
    b = np.random.uniform(low=0.5, high=5.5, size=(d,d))
    b = b * W
    # Observational regime: each node is generated after all of its parents.
    for j in ordered_vertices:
        perturb_list.append(j)
        parents = G.neighbors(j, mode=ig.IN)
        var = np.random.uniform(low=1.0, high=3.0)
        mean = np.random.uniform(low=0.0, high=3.0)
        # `var` is np.random.normal's scale argument, i.e. a standard deviation.
        eps = np.random.normal(mean, var, size=n)
        m_v.append([mean, var])
        X[:,j] = X[:,parents] @ b[parents,j] + eps
    if n_S > 0:
        new_index = []  # becomes the boolean mask of rows kept by all mechanisms
        selection_value = []  # per-mechanism thresholds, reused for 'per_*' data
        for cof, item, noise in zip(cofs,Sel, s_m_v):
            node_x, node_y = item[0], item[1]
            s1,s2 = cof[0], cof[1]
            # Threshold at the 25th percentile with several mechanisms, 40th with one.
            # NOTE(review): S() draws fresh noise on every call, so the threshold and
            # the mask below come from independent noise realisations.
            if n_S > 1:
                thres_value = np.percentile(S(s1,s2 ,X[:,node_x],X[:,node_y], noise[0], noise[1]), 25)
            else:
                thres_value = np.percentile(S(s1,s2 ,X[:,node_x],X[:,node_y], noise[0], noise[1]), 40)
            selection_value.append(thres_value)
            index = S(s1,s2 ,X[:,node_x],X[:,node_y], noise[0], noise[1]) > thres_value

            # A row survives only if it passes every mechanism seen so far.
            if len(new_index)==0:
                new_index = index
            else:
                new_index = new_index & index
            # Reject the whole sample when too few rows survive.
            if np.sum(new_index) < 2000:
                print('#########')
                return 0
        X_o = X[new_index][:,col_index]
        data['obs'] = X_o
    else:
        data['obs'] = X[:,col_index]
    # Soft-intervention regimes: one dataset per observed node i.
    for i in ordered_vertices:
        if i in col_index:
            co_index = col_index.index(i)
            X_per = np.zeros([n, d])
            count = 0
            for j in ordered_vertices:
                parents = G.neighbors(j, mode=ig.IN)
                # m_v was filled in this same topological order, so m_v[count]
                # holds node j's noise parameters; the draws themselves are fresh.
                mean, var = m_v[count][0], m_v[count][1]
                eps = np.random.normal(mean, var, size=n)
                if j == i:
                    # Soft intervention: keep the mechanism, add a U(3, 6) shift.
                    X_per[:,j] = X_per[:,parents] @ b[parents,j] + eps + np.random.uniform(3,6,size=n)
                else:
                    X_per[:,j] = X_per[:,parents] @ b[parents,j] + eps
                count += 1
            if n_S > 0:
                per_index = []
                # Same mechanisms and thresholds as the observational data.
                for cof,item, noise, value in zip(cofs, Sel, s_m_v, selection_value):
                    node_x, node_y = item[0], item[1]
                    s1,s2 = cof[0], cof[1]
                    index = S(s1,s2, X_per[:,node_x], X_per[:,node_y], noise[0], noise[1]) > value
                    if len(per_index) == 0:
                        per_index = index
                    else:
                        per_index = per_index & index
                    print(sum(per_index))
                    # Reject when too few rows survive, or when the kept-row count
                    # differs from the observational one by more than 16000.
                    if np.sum(per_index) < 2000 or (abs(np.sum(per_index) - np.sum(new_index))> 16000):
                        return 0
                X_per = X_per[per_index][:,col_index]
            else:
                X_per = X_per[:,col_index]
            # Keyed by the intervened node's position among the observed columns.
            data[f'per_{co_index}'] = X_per
        else:
            continue
    # Symmetric skeleton over all d nodes: DAG edges in either direction,
    # selection pairs, and every pair inside each under_latent group.
    ske = np.zeros_like(W)
    for i in range(d):
        for j in range(i):
            if W[i][j] == 1 or W[j][i] == 1:
                ske[i][j] = ske[j][i] = 1
    for item in Sel:
        ske[item[0], item[1]] = ske[item[1], item[0]] = 1
    for item in under_latent:
        pair = list(combinations(item, r =2))
        for edge in pair:
            ske[edge[0], edge[1]] = ske[edge[1],edge[0]] = 1
    # Re-express latent groups and selection pairs as positions in col_index
    # (list.index raises ValueError if any of them is a latent node).
    new_under_latent = []
    for pair in under_latent:
        new = []
        for i in range(len(pair)):
            new.append(col_index.index(pair[i]))
        new_under_latent.append(new)

    new_sel = []
    for se in Sel:
        new_sel.append([col_index.index(se[0]), col_index.index(se[1])])
    data['ske'] = ske[col_index][:,col_index]
    data['num_L'] = n_L
    data['num_S'] = n_S
    data['latent'] = new_under_latent
    data['selection'] = new_sel
    data['order_v'] = ordered_vertices
    data['node'] = col_index
    data['dag'] = W[col_index][:,col_index]
    data['Gdag'] = W
    return data


def remove_unordered_pair(selection, a, b):
    """Remove [a, b] -- or, failing that, [b, a] -- from ``selection`` in place.

    ``selection`` stores pairs as lists (constrain builds it with
    ``ndarray.tolist()``), so membership must be tested with lists: a tuple
    such as ``(a, b)`` never compares equal to the list ``[a, b]``.
    """
    for candidate in ([a, b], [b, a]):
        if candidate in selection:
            selection.remove(candidate)
            return


def constrain(W):
    """Inject latent confounders and selection pairs into a DAG.

    Picks 1-2 latent nodes among the nodes with out-degree >= 2 and cuts every
    edge into them, and draws 1-2 disjoint node pairs as selection mechanisms.
    Selection pairs are then dropped when they join two children of a latent,
    contain a latent node, or join a node of in-degree >= 2 to one of its
    parents. A latent is dropped when it is a parent of such a node v and also
    a parent of one of v's children, or when it has fewer than two children.

    Args:
        W (np.ndarray): [d, d] binary adjacency matrix of a DAG.

    Returns:
        tuple: (G, n_latent, n_select, latent, selection, under_latent) where
        G is the modified [d, d] adjacency matrix, latent the kept latent node
        indices, selection a list of [x, y] pairs and under_latent the list of
        children of each kept latent. Returns ('c', 0, 0, 0, 0, 0) when fewer
        nodes have out-degree >= 2 than latents were requested.
    """
    W = ig.Graph.Adjacency(W.tolist())
    num_latent = random.randint(1, 2)
    num_select = random.randint(1, 2)
    nodes = list(range(W.vcount()))
    selection = np.random.choice(nodes, size=(num_select, 2), replace=False).tolist()
    confounder = [c.index for c in W.vs if c.outdegree() >= 2]
    if len(confounder) < num_latent:
        return 'c', 0, 0, 0, 0, 0
    latent = random.sample(confounder, num_latent)

    # A latent confounder must be exogenous: delete every edge into it.
    for i in latent:
        for j in W.predecessors(i):
            W.delete_edges((j, i))

    # Integer ids of nodes with in-degree >= 2 (after the surgery above);
    # ids, not Vertex objects, so they can be compared with selection entries.
    vstructure = [v.index for v in W.vs if v.indegree() >= 2]

    # Latent confounding and selection bias must not act on the same pair,
    # and a latent node cannot take part in selection.
    for i in latent:
        for p in combinations(W.successors(i), r=2):
            remove_unordered_pair(selection, p[0], p[1])
        # Rebuild instead of removing while iterating (which skips elements).
        selection = [edge for edge in selection if i not in edge]

    for v in vstructure:
        parent_v = W.predecessors(v)
        # No selection on an edge into a collider.
        for u in parent_v:
            remove_unordered_pair(selection, u, v)
        # Drop latents that are parents of v and also co-parents of a child of v.
        all_parent = set()
        for m in W.successors(v):
            all_parent.update(W.predecessors(m))
        inter = all_parent.intersection(parent_v)
        latent = [x for x in latent if x not in inter]

    # Keep only latents that still confound at least two children.
    under_latent = []
    kept_latent = []
    for la in latent:
        ch = W.successors(la)
        if len(ch) >= 2:
            kept_latent.append(la)
            under_latent.append(ch)
    latent = kept_latent

    n_latent = len(latent)
    n_select = len(selection)
    G = np.array(W.get_adjacency().data)
    return G, n_latent, n_select, latent, selection, under_latent

# Driver: sample random DAGs, add latent confounders / selection bias with
# constrain(), simulate data and save every accepted sample to disk.
num_of_sample = 1000
num_saved = 0  # number of samples actually written to disk
interven = 'soft'  # 'hard' or 'soft'
d = 18
number_of_edge = d
graph_type = 'ER'
sample_size = 30000

# Simulator and rejection message for each supported intervention type.
_regimes = {
    'hard': (my_simulate_linear_gaussian, 'does not count'),
    'soft': (my_simulate_linear_gaussian_soft, 'does not count continue'),
}
if interven not in _regimes:
    raise ValueError(f'unknown intervention type: {interven!r}')
simulate, reject_msg = _regimes[interven]

for i in range(num_of_sample):
    print(i)
    DAG_bin = simulate_dag(d, number_of_edge, graph_type)  # ground-truth binary matrix
    true_DAG_bin, number_of_latent, number_of_selection, L, Sel, under_latent = constrain(DAG_bin)
    # constrain() signals failure with the string 'c'. Test the type: on an
    # ndarray, `== 'c'` is element-wise and cannot be used as an `if` condition.
    if isinstance(true_DAG_bin, str):
        print('continue')
        continue
    print("good graph")
    if len(L) != len(under_latent):
        continue

    data = simulate(true_DAG_bin, sample_size, number_of_latent, number_of_selection, L, Sel, under_latent)
    # The simulators return 0 instead of a dict when a sample is rejected.
    if not isinstance(data, dict):
        print(reject_msg)
        continue
    node_name = [str(k) for k in range(data['dag'].shape[0])]
    dot_graph = adjacency_matrix_to_dot(data['dag'], node_name, data['latent'], data['selection'])
    file_path = f'./latent_with_selection/LG/{interven}/v_{d}/sample_{i}'
    os.makedirs(file_path, exist_ok=True)
    dot_graph.write_png(os.path.join(file_path, f'graph_{interven}.png'))

    # np.savez turns every value into an array; ragged lists (e.g. two latents
    # with different numbers of children) raise ValueError on NumPy >= 1.24,
    # so store them as 1-D object arrays (load with allow_pickle=True).
    arrays = {}
    for key, value in data.items():
        try:
            arrays[key] = np.asanyarray(value)
        except ValueError:
            ragged = np.empty(len(value), dtype=object)
            for idx, item in enumerate(value):
                ragged[idx] = item
            arrays[key] = ragged
    np.savez(os.path.join(file_path, f'sample_{interven}.npz'), **arrays)

    num_saved += 1

print(num_saved)


