import numpy as np
from scipy.special import expit as sigmoid
from itertools import combinations
import igraph as ig
import random
import torch
import networkx as nx
import matplotlib.pyplot as plt
# import pydot
from IPython.display import Image
import pydotplus
import copy
import json
import os

############### additive Gaussian-noise SEM (random, possibly nonlinear link functions) ################
#####################case 1 (cause)#####################

def adjacency_matrix_to_dot(adjacency_matrix, latent=None, selection=None):
    """Render a DAG with latent and selection nodes as a pydotplus graph.

    Args:
        adjacency_matrix (np.ndarray): [d, d] matrix; an edge ``i -> j`` is
            drawn for every entry ``[i, j] == 1``.
        latent (list[list[int]] | None): for each latent node ``L_k``, the
            observed nodes it points to. ``None`` means no latents.
        selection (list[Sequence[int]] | None): for each selection node
            ``S_k``, the nodes feeding it. A node ``S_k`` is always added, but
            its incoming edges are drawn only when the entry has exactly two
            nodes. ``None`` means no selection nodes.

    Returns:
        pydotplus.Dot: the directed graph.
    """
    # None defaults avoid sharing one mutable list across calls.
    latent = [] if latent is None else latent
    selection = [] if selection is None else selection

    graph = pydotplus.Dot(directed=True)
    num_nodes = adjacency_matrix.shape[0]
    for i in range(num_nodes):
        graph.add_node(pydotplus.Node(str(i), label=str(i)))
    for i in range(num_nodes):
        for j in range(num_nodes):
            if adjacency_matrix[i, j] == 1:
                graph.add_edge(pydotplus.Edge(str(i), str(j)))

    for k, children in enumerate(latent):
        name = f'L_{k}'
        graph.add_node(pydotplus.Node(name, label=name))
        for child in children:
            graph.add_edge(pydotplus.Edge(name, str(child)))

    for k, pair in enumerate(selection):
        name = f'S_{k}'
        graph.add_node(pydotplus.Node(name, label=name))
        if len(pair) == 2:
            graph.add_edge(pydotplus.Edge(str(pair[0]), name))
            graph.add_edge(pydotplus.Edge(str(pair[1]), name))

    return graph

def simulate_dag(d, s0, graph_type, number_of_edge):
    """Simulate a random DAG, optionally annotated with latent and selection structure.

    Only ``graph_type == 'ER'`` goes through the latent/selection
    construction. ``'SF'`` and ``'BP'`` return a plain DAG and ``None`` for
    every latent/selection field.

    Args:
        d (int): number of nodes.
        s0 (int): number of edges for the bipartite (``'BP'``) generator.
            It is unused by ``'ER'`` and ``'SF'``.
        graph_type (str): ``'ER'``, ``'SF'`` or ``'BP'``.
        number_of_edge (int): edge count ``m`` passed to
            ``ig.Graph.Erdos_Renyi`` for ``'ER'``.

    Returns:
        tuple: ``(B, n_l, n_s, L, S, under_L)``:
            B (np.ndarray): [d, d] binary adjacency matrix. For ``'ER'`` it
                is taken after the edges into the latent nodes were deleted.
            n_l (int | None): number of latent nodes (``len(L)``).
            n_s (int | None): number of selection pairs (3 for ``'ER'``).
            L (list[int] | None): indices of the nodes treated as latent.
            S (list[tuple[int, int]] | None): node pairs subject to selection.
            under_L (list[list[int]] | None): children of each latent,
                aligned with ``L``.

    Raises:
        ValueError: for an unknown ``graph_type``.
    """

    def remove_duplicate_pairs(pairs_list):
        """Drop pairs already seen in either orientation, keeping first occurrences in order."""
        seen = set()
        result = []
        for pair in pairs_list:
            key = frozenset(pair)  # (a, b) and (b, a) share one key
            if key not in seen:
                seen.add(key)
                result.append(pair)
        return result

    def _random_permutation(M):
        # np.random.permutation permutes first axis only
        P = np.random.permutation(np.eye(M.shape[0]))
        return P.T @ M @ P

    def _random_acyclic_orientation(B_und):
        # Strict lower triangle of a randomly relabelled matrix is acyclic.
        return np.tril(_random_permutation(B_und), k=-1)

    def _graph_to_adjmat(G):
        return np.array(G.get_adjacency().data)

    def constrain(W):
        """Pick latent confounders and selection pairs on the directed igraph ``W``.

        Mutates ``W`` in place by deleting every edge into a chosen latent.
        Returns ``(G, n_latent, n_select, latent, selection, under_latent)``,
        or six ``None`` values when the draw is rejected: not exactly three
        selection pairs remain after pruning, or no latent is left.
        """
        # All unordered pairs (i, j) with i < j. They are not filtered
        # against existing edges, so they may coincide with real edges;
        # duplicates are removed later by remove_duplicate_pairs.
        candidate_pairs = [(i, j) for i in range(d) for j in range(i + 1, d)]
        random_selection = random.sample(candidate_pairs, 2)

        if d < 10:
            num_latent = 1
            num_select = random.randint(0, 1)
        else:
            num_latent = random.randint(1, 2)
            num_select = 1

        # Selection set: num_select existing edges plus the two random pairs.
        selection = random.sample(W.get_edgelist(), num_select)
        selection.extend(random_selection)

        # Latent candidates are nodes with at least two children.
        # W is still unmodified here.
        outdegrees = W.outdegree()
        confounder = [v for v, deg in enumerate(outdegrees) if deg >= 2]
        latent = random.sample(confounder, num_latent)

        # Make every latent a root by deleting the edges from its parents.
        for lat in latent:
            for parent in W.predecessors(lat):
                W.delete_edges((parent, lat))

        # Integer ids of collider nodes. Use ints, not igraph Vertex objects,
        # because a Vertex never compares equal to an int, so (u, Vertex)
        # would never match the int pairs in `selection`.
        vstructure = [v for v, deg in enumerate(W.indegree()) if deg == 2]

        for lat in latent:
            children = W.successors(lat)
            # Drop selection pairs formed by two children of the same latent.
            for p in combinations(children, r=2):
                if p in selection:
                    selection.remove(p)
                if (p[1], p[0]) in selection:
                    selection.remove((p[1], p[0]))
            # Drop every selection pair that touches the latent itself.
            # A new list is built because removing while iterating skips items.
            selection = [pair for pair in selection if lat not in pair]

        for v in vstructure:
            parent_v = W.predecessors(v)
            # An edge into a collider must not be a selection pair.
            for u in parent_v:
                if (u, v) in selection:
                    selection.remove((u, v))
            all_parent = set()
            for m in W.successors(v):
                all_parent.update(W.predecessors(m))
            # Discard latents that are parents of the collider and also of
            # one of the collider's children.
            inter = all_parent.intersection(parent_v)
            latent = [x for x in latent if x not in inter]

        # Keep only latents that still confound at least two children.
        # Build the children lists in the same pass so `latent` and
        # `under_latent` stay aligned.
        kept_latent = []
        under_latent = []
        for lat in latent:
            children = W.successors(lat)
            if len(children) >= 2:
                kept_latent.append(lat)
                under_latent.append(children)
        latent = kept_latent
        n_latent = len(latent)

        selection = remove_duplicate_pairs(selection)
        n_select = len(selection)

        # Accept only draws with exactly three selection pairs and at least
        # one latent.
        if n_select != 3 or not latent:
            return None, None, None, None, None, None

        G = np.array(W.get_adjacency().data)
        return G, n_latent, n_select, latent, selection, under_latent

    # SF / BP do not build latent or selection structure.
    n_l = n_s = L = S = under_L = None

    if graph_type == 'ER':
        # Erdos-Renyi, retried until a usable latent/selection draw is found.
        min_confounders = 1 if d < 10 else 2
        while True:
            G_und = ig.Graph.Erdos_Renyi(n=d, m=number_of_edge)
            B_und = _graph_to_adjmat(G_und)
            B = _random_acyclic_orientation(B_und)
            g = ig.Graph.Adjacency((B > 0).tolist(), mode=ig.ADJ_DIRECTED)

            outdegrees = g.outdegree()
            confounder = [i for i, deg in enumerate(outdegrees) if deg >= 2]
            # constrain() samples 1 latent for d < 10 and up to 2 otherwise.
            if len(confounder) < min_confounders:
                print('fail')
                continue

            B, n_l, n_s, L, S, under_L = constrain(g)
            if n_l is None:
                continue

            # Reject draws where a latent node is an endpoint of a selection pair.
            selected_nodes = {node for pair in S for node in pair}
            if any(lat in selected_nodes for lat in L):
                continue

            # Explicit check rather than `assert`, which is stripped under -O.
            if not ig.Graph.Adjacency(B.tolist()).is_dag():
                print('not dag')
                continue

            if n_l and n_s != 0:
                break
    elif graph_type == 'SF':
        # Scale-free, Barabasi-Albert
        G = ig.Graph.Barabasi(n=d, m=d, directed=True)
        B = _graph_to_adjmat(G)
    elif graph_type == 'BP':
        # Bipartite, Sec 4.1 of (Gu, Fu, Zhou, 2018)
        top = int(0.2 * d)
        G = ig.Graph.Random_Bipartite(top, d - top, m=s0, directed=True, neimode=ig.OUT)
        B = _graph_to_adjmat(G)
    else:
        raise ValueError('unknown graph type')

    return B, n_l, n_s, L, S, under_L

def is_dag(W):
    """Return whether the (weighted) adjacency matrix ``W`` describes an acyclic directed graph."""
    return ig.Graph.Weighted_Adjacency(W.tolist()).is_dag()

def S(a, b, x,y, low, up):
    """Selection score: a linear combination of two variables plus uniform noise.

    Returns ``a * x + b * y + e``, where ``e`` is drawn elementwise from
    ``np.random.uniform(low, up)`` with the same shape as ``x``.
    """
    noise = np.random.uniform(low, up, size=x.shape)
    weighted = a * x + b * y
    return weighted + noise



def my_simulate_general_hard(W, n, n_L,n_S, L, Sel, under_latent,):
    """Simulate observational and hard-intervention data under selection bias.

    Draws ``n`` rows from an additive Gaussian-noise SEM on the DAG ``W``.
    Each node's parent values pass through a randomly chosen link function
    (log, quadratic or shift). The selection mechanisms in ``Sel`` are then
    applied one after another, each keeping rows whose score ``S`` exceeds a
    percentile threshold.

    For every observed node, a hard-intervention dataset is built: the node is
    replaced by ``uniform(0, mean)`` noise, its descendants in topological
    order are regenerated, and the same thresholds are re-applied. Every
    dataset is subsampled to 2000 rows, and the result plus a PNG of the graph
    are written to disk.

    The function reads the module-level globals ``interven`` and ``times`` to
    build the output directory
    ``./selection_robust/selection_{n_S}/{interven}/v_{d}/2000/sample_{times}``.

    Args:
        W (np.ndarray): [d, d] binary adjacency matrix of a DAG
            (``W[i, j] == 1`` means ``i -> j``).
        n (int): number of rows drawn before selection.
        n_L (int): number of latent variables, stored as ``num_L``.
        n_S (int): number of selection mechanisms, stored as ``num_S``.
        L (list[int]): indices of latent nodes. These columns are dropped from
            every saved matrix.
        Sel (list[tuple[int, int]]): node pair entering each selection mechanism.
        under_latent (list[list[int]]): children of each latent node.

    Returns:
        dict | int: on success, the data dictionary with full-size
        intervention matrices. Returns ``0`` when a sampling check fails: too
        few rows survive, a selection or intervention step changes too little,
        or NaNs appear.

    Raises:
        ValueError: if ``W`` is not a DAG, or if a selection endpoint is a
            latent node.
    """

    def f(x, det=True, index=None, f_set=None):
        # det=True: draw a fresh link-function id and record it in
        # function_set, so the same function can be replayed later.
        # det=False: replay f_set[index].
        if det:
            f_index = np.random.randint(4)
            function_set.append(f_index)
        else:
            f_index = f_set[index]
        if f_index == 0:
            y = np.log(x+2)
        elif f_index == 1:
            y = x**2 - 2*x + 2
        elif f_index == 2:
            y = x**2 - 2
        elif f_index == 3:
            y = x+1
        return y

    def _selection_args(cof, item):
        # Return (node_x, node_y, s1, s2) for one selection mechanism.
        # A single int node uses only X[:, node] (node_y = 0 with weight 0).
        if type(item) != int:
            return item[0], item[1], cof[0], cof[1]
        return item, 0, cof[0], 0

    data = {}
    function_set = []
    d = W.shape[0]
    if not is_dag(W):
        raise ValueError('W must be a DAG')
    cofs = []
    uniform_l_u = []
    s_l_u = []

    # Weights and noise range for each selection mechanism.
    for _ in range(n_S):
        s_a = random.uniform(0,3)
        s_b = random.uniform(-2,3)
        low = np.random.uniform(low=0.0, high=2.0)
        up = np.random.uniform(low=2.0, high=4.0)
        cofs.append([s_a,s_b])
        s_l_u.append([low, up])
    G = ig.Graph.Weighted_Adjacency(W.tolist())
    ordered_vertices = G.topological_sorting()

    # Observed (non-latent) columns.
    col_index = [x for x in range(d) if x not in L]

    assert len(ordered_vertices) == d
    X = np.zeros([n, d])
    b = np.random.uniform(low=-3.5, high=3.5, size=(d,d))
    b = b * W
    for j in ordered_vertices:
        parents = G.neighbors(j, mode=ig.IN)
        # Gaussian noise. Its (mean, std) is stored in topological order so
        # the intervention datasets reuse the same noise law.
        low = np.random.uniform(low=0.0, high=4.0)
        up = np.random.uniform(0,2)
        eps = np.random.normal(low,up, size=n)
        uniform_l_u.append([low, up])
        X[:,j] = f(X[:,parents]) @ b[parents,j] + eps

    # Hard case: mechanisms are applied sequentially, each to the rows that
    # survived the previous one.
    new_index = []
    selection_value = []
    X_new = X.copy()
    pct = 15 if n_S > 1 else 30
    for cof, item, noise in zip(cofs, Sel, s_l_u):
        node_x, node_y, s1, s2 = _selection_args(cof, item)
        thres_value = np.percentile(S(s1,s2 ,X_new[:,node_x],X_new[:,node_y], noise[0], noise[1]), pct)
        selection_value.append(thres_value)
        index = S(s1,s2 ,X_new[:,node_x],X_new[:,node_y], noise[0], noise[1]) > thres_value
        X_new = X_new[index]

        if len(new_index) == 0:
            new_index = index
        else:
            old = np.sum(new_index)
            new = np.sum(index)
            new_index = index
            # Reject if this mechanism removed fewer than 2000 rows.
            if abs(old - new) < 2000:
                return 0
        print(np.sum(new_index))
        if np.sum(new_index) < 2000 :
            print('too many or less samples')
            return 0

    X_o_o = X_new[:,col_index]
    if np.isnan(X_o_o).any():
        return 0

    for idx_i, i in enumerate(ordered_vertices):
        if i not in col_index:
            continue
        co_index = col_index.index(i)
        X_per = np.zeros(X_new.shape)
        # `count` is the position of j in topological order; it indexes both
        # uniform_l_u and function_set.
        for count, j in enumerate(ordered_vertices):
            low_p, up_p = uniform_l_u[count]
            eps = np.random.normal(low_p, up_p, size=X_per.shape[0])
            parents = G.neighbors(j, mode=ig.IN)
            if count < idx_i:
                # Earlier in topological order: keep the selected values.
                X_per[:,j] = X_new[:,j]
            elif count == idx_i:
                # Hard intervention: ignore the parents, draw uniform(0, mean).
                ave = np.mean(X_new[:,j])
                print(f"ave = {ave}")
                X_per[:,j] = np.random.uniform(0,ave,size=X_per.shape[0])
            else:
                # Later nodes: regenerate with the recorded link function.
                X_per[:,j] = f(X_per[:,parents], False, count, function_set) @ b[parents,j] + eps

        per_index = []
        for cof, item, noise, value in zip(cofs, Sel, s_l_u, selection_value):
            node_x, node_y, s1, s2 = _selection_args(cof, item)
            index = S(s1,s2, X_per[:,node_x], X_per[:,node_y], noise[0], noise[1]) > value
            tag = i in item  # does this mechanism involve the intervened node?
            if len(per_index) == 0:
                per_index = index
            else:
                old_p = np.sum(per_index)
                per_index = per_index & index
                new_p = np.sum(per_index)
                if tag and abs(old_p - new_p) < 2000:
                    print('perturbation is not good')
                    return 0
            print('#')
            print(sum(per_index))
            if not tag and np.sum(per_index) < 2000:
                print('too many or less samples in perturbation')
                return 0
        if np.sum(per_index) < 2000 or np.sum(per_index) > 20000:
            print('too many or less samples in perturbation')
            return 0

        data[f'per_{co_index}'] = X_per[per_index][:,col_index]

    # Skeleton: DAG adjacencies, plus selection pairs, plus children of a
    # common latent.
    ske = np.zeros_like(W)
    for i in range(d):
        for j in range(i):
            if W[i][j] == 1 or W[j][i] == 1:
                ske[i][j] = ske[j][i] = 1
    for item in Sel:
        ske[item[0], item[1]] = ske[item[1], item[0]] = 1
    for l in L:
        # Reset per latent: children of *different* latents are not adjacent.
        child_of_L = [j for j in range(d) if W[l][j] == 1]
        for u, v in combinations(child_of_L, r=2):
            ske[u, v] = ske[v, u] = 1
    data['ske'] = ske[col_index][:,col_index]

    # Re-index latent children and selection pairs into observed-column space.
    new_under_latent = [[col_index.index(node) for node in pair] for pair in under_latent]

    new_sel = []
    for se in Sel:
        try:
            new_sel.append([col_index.index(se[0]), col_index.index(se[1])])
        except ValueError as err:
            raise ValueError(f'selection pair {se} involves a latent node') from err

    data['num_L'] = n_L
    data['num_S'] = n_S
    data['latent'] = new_under_latent
    data['latent_vars'] = L
    data['selection'] = new_sel
    data['order_v'] = ordered_vertices
    data['node'] = col_index
    data['dag'] = W[col_index][:,col_index]
    data['Gdag'] = W
    for sample_size_select in [2000]:
        data_new = data.copy()
        random_ind = np.random.choice(X_o_o.shape[0], size=sample_size_select, replace=False)
        data_new['obs'] = X_o_o[random_ind]
        for i in ordered_vertices:
            if i in col_index:
                co_index = col_index.index(i)
                X_per_o = data[f'per_{co_index}']
                random_ind_per = np.random.choice(X_per_o.shape[0], size=sample_size_select, replace=False)
                data_new[f'per_{co_index}'] = X_per_o[random_ind_per]
        dot_graph = adjacency_matrix_to_dot(data_new['dag'], data_new['latent'], data_new['selection'])
        file_path = f'./selection_robust/selection_{n_S}/{interven}/v_{d}/{sample_size_select}/sample_{times}'
        os.makedirs(file_path, exist_ok=True)
        dot_graph.write_png(os.path.join(file_path, f'graph_{interven}.png'))
        np.savez(os.path.join(file_path, f'sample_{interven}.npz'), **data_new)
    return data

def my_simulate_general_soft(W, n, n_L,n_S, L, Sel, under_latent,):
    """Simulate observational and soft-intervention data under selection bias.

    Draws ``n`` rows from an additive Gaussian-noise SEM on the DAG ``W``.
    Each node's parent values pass through a randomly chosen link function
    (log, quadratic or shift). All selection mechanisms in ``Sel`` are
    thresholded on the full sample and combined with a logical AND.

    For every observed node, a soft-intervention dataset is built: the node
    keeps its structural equation but gets an extra ``uniform(mean/3, mean)``
    shift, its later nodes in topological order are regenerated, and the same
    thresholds are re-applied. Every dataset is subsampled to 2000 rows, and
    the result plus a PNG of the graph are written to disk.

    The function reads the module-level globals ``interven`` and ``times`` to
    build the output directory
    ``./selection_robust/selection_{n_S}/{interven}/v_{d}/2000/sample_{times}``.

    Args:
        W (np.ndarray): [d, d] binary adjacency matrix of a DAG
            (``W[i, j] == 1`` means ``i -> j``).
        n (int): number of rows drawn before selection.
        n_L (int): number of latent variables, stored as ``num_L``.
        n_S (int): number of selection mechanisms, stored as ``num_S``.
        L (list[int]): indices of latent nodes. These columns are dropped from
            every saved matrix.
        Sel (list[tuple[int, int]]): node pair entering each selection mechanism.
        under_latent (list[list[int]]): children of each latent node.

    Returns:
        dict | int: on success, the data dictionary with full-size
        intervention matrices. Returns ``0`` when a sampling check fails: too
        few or too many rows survive, a selection or intervention step changes
        too little, or NaNs appear.

    Raises:
        ValueError: if ``W`` is not a DAG, or if a selection endpoint is a
            latent node.
    """

    def f(x, det=True, index=None, f_set=None):
        # det=True: draw a fresh link-function id and record it in
        # function_set, so the same function can be replayed later.
        # det=False: replay f_set[index].
        if det:
            f_index = np.random.randint(4)
            function_set.append(f_index)
        else:
            f_index = f_set[index]
        if f_index == 0:
            y = np.log(x+2)
        elif f_index == 1:
            y = x**2 - 2*x + 2
        elif f_index == 2:
            y = x**2 - 2
        elif f_index == 3:
            y = x+1
        return y

    def _selection_args(cof, item):
        # Return (node_x, node_y, s1, s2) for one selection mechanism.
        # A single int node uses only X[:, node] (node_y = 0 with weight 0).
        if type(item) != int:
            return item[0], item[1], cof[0], cof[1]
        return item, 0, cof[0], 0

    data = {}
    function_set = []
    d = W.shape[0]
    if not is_dag(W):
        raise ValueError('W must be a DAG')
    cofs = []
    uniform_l_u = []
    s_l_u = []

    # Weights and noise range for each selection mechanism.
    for _ in range(n_S):
        s_a = random.uniform(0,3)
        s_b = random.uniform(-2,3)
        low = np.random.uniform(low=0.0, high=2.0)
        up = np.random.uniform(low=2.0, high=4.0)
        cofs.append([s_a,s_b])
        s_l_u.append([low, up])
    G = ig.Graph.Weighted_Adjacency(W.tolist())
    ordered_vertices = G.topological_sorting()

    # Observed (non-latent) columns.
    col_index = [x for x in range(d) if x not in L]

    assert len(ordered_vertices) == d
    X = np.zeros([n, d])
    b = np.random.uniform(low=-3.5, high=3.5, size=(d,d))
    b = b * W
    for j in ordered_vertices:
        parents = G.neighbors(j, mode=ig.IN)
        # Gaussian noise. Its (mean, std) is stored in topological order so
        # the intervention datasets reuse the same noise law.
        low = np.random.uniform(low=0.0, high=4.0)
        up = np.random.uniform(0,2)
        eps = np.random.normal(low,up, size=n)
        uniform_l_u.append([low, up])
        X[:,j] = f(X[:,parents]) @ b[parents,j] + eps

    # Soft case: every mechanism is thresholded on the full sample and the
    # masks are combined with AND.
    new_index = []
    selection_value = []
    pct = 10 if n_S > 1 else 30
    for cof, item, noise in zip(cofs, Sel, s_l_u):
        node_x, node_y, s1, s2 = _selection_args(cof, item)
        thres_value = np.percentile(S(s1,s2 ,X[:,node_x],X[:,node_y], noise[0], noise[1]), pct)
        selection_value.append(thres_value)
        index = S(s1,s2 ,X[:,node_x],X[:,node_y], noise[0], noise[1]) > thres_value

        if len(new_index) == 0:
            new_index = index
        else:
            old = np.sum(new_index)
            new_index = new_index & index
            new = np.sum(new_index)
            # Reject if this mechanism removed fewer than 2000 rows.
            if abs(old - new) < 2000:
                print('selection not good')
                return 0
        print(np.sum(new_index))
        if np.sum(new_index) < 2000 :
            print('too less samples')
            return 0

    # Selected observational rows, computed once instead of inside the loops.
    X_sel = X[new_index]
    X_o_o = X_sel[:,col_index]
    if np.isnan(X_o_o).any():
        return 0

    for idx_i, i in enumerate(ordered_vertices):
        if i not in col_index:
            continue
        co_index = col_index.index(i)
        X_per = np.zeros(X_sel.shape)
        # `count` is the position of j in topological order; it indexes both
        # uniform_l_u and function_set.
        for count, j in enumerate(ordered_vertices):
            low_p, up_p = uniform_l_u[count]
            eps = np.random.normal(low_p, up_p, size=X_per.shape[0])
            parents = G.neighbors(j, mode=ig.IN)
            if count < idx_i:
                # Earlier in topological order: keep the selected values.
                X_per[:,j] = X_sel[:,j]
            elif count == idx_i:
                # Soft intervention: same equation plus a uniform(mean/3, mean) shift.
                ave = np.mean(X_sel[:,j])
                print(f"ave = {ave}")
                X_per[:,j] = f(X_per[:,parents], False, count, function_set) @ b[parents,j] + eps + np.random.uniform(ave/3,ave,size=X_per.shape[0])
            else:
                # Later nodes: regenerate with the recorded link function.
                X_per[:,j] = f(X_per[:,parents], False, count, function_set) @ b[parents,j] + eps

        per_index = []
        for cof, item, noise, value in zip(cofs, Sel, s_l_u, selection_value):
            node_x, node_y, s1, s2 = _selection_args(cof, item)
            index = S(s1,s2, X_per[:,node_x], X_per[:,node_y], noise[0], noise[1]) > value
            tag = i in item  # does this mechanism involve the intervened node?
            if len(per_index) == 0:
                per_index = index
            else:
                old_p = np.sum(per_index)
                per_index = per_index & index
                new_p = np.sum(per_index)
                if tag and abs(old_p - new_p) < 2000:
                    print('perbation not good')
                    return 0
            print(sum(per_index))
            if not tag and np.sum(per_index) < 2000:
                print('too many or less samples in perturbation')
                return 0
        n_kept = np.sum(per_index)
        if n_kept < 2000 or (abs(n_kept - np.sum(new_index)) > 15000) or n_kept > 20000:
            print('too many or less samples in perturbation')
            return 0

        data[f'per_{co_index}'] = X_per[per_index][:,col_index]

    # Skeleton: DAG adjacencies, plus selection pairs, plus children of a
    # common latent.
    ske = np.zeros_like(W)
    for i in range(d):
        for j in range(i):
            if W[i][j] == 1 or W[j][i] == 1:
                ske[i][j] = ske[j][i] = 1
    for item in Sel:
        ske[item[0], item[1]] = ske[item[1], item[0]] = 1
    for l in L:
        # Reset per latent: children of *different* latents are not adjacent.
        child_of_L = [j for j in range(d) if W[l][j] == 1]
        for u, v in combinations(child_of_L, r=2):
            ske[u, v] = ske[v, u] = 1
    data['ske'] = ske[col_index][:,col_index]

    # Re-index latent children and selection pairs into observed-column space.
    new_under_latent = [[col_index.index(node) for node in pair] for pair in under_latent]

    new_sel = []
    for se in Sel:
        try:
            new_sel.append([col_index.index(se[0]), col_index.index(se[1])])
        except ValueError as err:
            raise ValueError(f'selection pair {se} involves a latent node') from err

    data['num_L'] = n_L
    data['num_S'] = n_S
    data['latent'] = new_under_latent
    data['latent_vars'] = L
    data['selection'] = new_sel
    data['order_v'] = ordered_vertices
    data['node'] = col_index
    data['dag'] = W[col_index][:,col_index]
    data['Gdag'] = W
    for sample_size_select in [2000]:
        data_new = data.copy()
        random_ind = np.random.choice(X_o_o.shape[0], size=sample_size_select, replace=False)
        data_new['obs'] = X_o_o[random_ind]
        for i in ordered_vertices:
            if i in col_index:
                co_index = col_index.index(i)
                X_per_o = data[f'per_{co_index}']
                random_ind_per = np.random.choice(X_per_o.shape[0], size=sample_size_select, replace=False)
                data_new[f'per_{co_index}'] = X_per_o[random_ind_per]
        dot_graph = adjacency_matrix_to_dot(data_new['dag'], data_new['latent'], data_new['selection'])
        file_path = f'./selection_robust/selection_{n_S}/{interven}/v_{d}/{sample_size_select}/sample_{times}'
        os.makedirs(file_path, exist_ok=True)
        dot_graph.write_png(os.path.join(file_path, f'graph_{interven}.png'))
        np.savez(os.path.join(file_path, f'sample_{interven}.npz'), **data_new)
    return data

num_of_sample = 20000

interven = 'soft'
for d in [15]:
    # `times` counts successful datasets. my_simulate_general_* read it as a
    # module-level global to name the output directory sample_{times}.
    times = 0
    for _attempt in range(num_of_sample):
        number_of_edge = d
        graph_type, sem_type = 'ER', 'gauss'
        # Ground-truth binary DAG plus its latent / selection annotations.
        (true_DAG_bin, number_of_latent, number_of_selection,
         L, Sel, under_latent) = simulate_dag(d, d, graph_type, number_of_edge)

        print('data_generation')

        sample_size = 25000
        sim_args = (true_DAG_bin, sample_size, number_of_latent,
                    number_of_selection, L, Sel, under_latent)
        success = True
        if interven == 'hard':
            data = my_simulate_general_hard(*sim_args)
            success = data != 0
        elif interven == 'soft':
            data = my_simulate_general_soft(*sim_args)
            success = data != 0
            if not success:
                print('does not count continue')

        # Stop once more than 10 datasets have been written.
        times += 1 if success else 0
        if times > 10:
            break




