from BA3_loc import *
from tqdm import tqdm
import os.path as osp
import warnings
warnings.filterwarnings("ignore")
import random
import math
import torch
import copy

from scipy.stats import gamma
from scipy.stats import gompertz
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import weibull_min
from scipy.special import gamma, gammaincinv
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import euclidean_distances
from collections import deque
from paper import *

def create_complex_confounding_variables(G, original_role_id, confounding_prob):
    """Derive per-node structural confounders from graph ``G`` and the true roles.

    Each returned value is a list with exactly ``len(original_role_id)`` entries indexed
    by position ``i``. Positions that are not nodes of ``G`` get a neutral value
    (0, ``"low"`` or ``"isolated"``).

    Args:
        G: networkx graph. Node ids are compared against ``len(original_role_id)``, so
            they are presumably the ints ``0..n-1`` -- confirm against the caller.
        original_role_id: sequence of integer role ids indexed by node id.
        confounding_prob: not used by this function; kept so the signature matches callers.

    Returns:
        dict with integer-list entries 'neighbor_influence', 'path_dependency',
        'community_structure', 'centrality_composite', 'structural_similarity',
        'propagation_influence' and string-list entries 'social_status', 'network_role'.
    """
    num_nodes = len(original_role_id)
    node_set = set(G.nodes())
    # Built once and reused (it used to be rebuilt inside nested per-node loops).
    undirected = G.to_undirected()
    degree = {node: G.degree(node) for node in G.nodes()}

    # Global graph metrics. Deliberately best effort: any failure (empty graph, pagerank
    # not converging, ...) falls back to neutral values. ``Exception`` instead of a bare
    # ``except`` so KeyboardInterrupt / SystemExit are not swallowed.
    try:
        degree_centrality = nx.degree_centrality(G)
        if nx.is_connected(undirected):
            betweenness_centrality = nx.betweenness_centrality(G)
            closeness_centrality = nx.closeness_centrality(G)
        else:
            betweenness_centrality = {node: 0.0 for node in G.nodes()}
            closeness_centrality = {node: 0.0 for node in G.nodes()}
        clustering_coeff = nx.clustering(undirected)
        pagerank = nx.pagerank(G)
    except Exception:
        degree_centrality = {i: 0.0 for i in range(num_nodes)}
        betweenness_centrality = {i: 0.0 for i in range(num_nodes)}
        closeness_centrality = {i: 0.0 for i in range(num_nodes)}
        clustering_coeff = {i: 0.0 for i in range(num_nodes)}
        pagerank = {i: 1.0 / num_nodes for i in range(num_nodes)}

    # neighbor_influence: int((mean + max) / 2) over the neighbours' roles.
    neighbor_influence = []
    for i in range(num_nodes):
        score = 0
        if i in node_set:
            roles = [original_role_id[n] for n in G.neighbors(i) if n < num_nodes]
            if roles:
                score = int((np.mean(roles) + max(roles)) / 2)
        neighbor_influence.append(score)

    # path_dependency: max(0, 5 - int(mean hop distance)) to the reachable high-role
    # (role > 2) nodes. One BFS per source replaces a has_path + shortest_path call per
    # target; the high-role list does not depend on the source, so it is built once.
    high_role_nodes = [j for j in range(num_nodes) if original_role_id[j] > 2]
    path_dependency = []
    for i in range(num_nodes):
        score = 0
        if i in node_set and high_role_nodes:
            hops = nx.single_source_shortest_path_length(undirected, i)
            lengths = [hops[t] for t in high_role_nodes if t in hops]
            if lengths:
                score = max(0, 5 - int(np.mean(lengths)))
        path_dependency.append(score)

    # community_structure: high-role nodes in the node's Louvain community, capped at 4.
    # Falls back to min(4, degree // 2) if community detection fails. The list is built
    # completely before it is used, so a failure part-way through can no longer leave
    # extra entries behind (which previously made this list longer than num_nodes).
    try:
        communities = nx.community.louvain_communities(undirected)
        node_to_community = {}
        high_role_counts = []
        for comm_id, members in enumerate(communities):
            high_role_counts.append(sum(
                1 for node in members
                if node < num_nodes and original_role_id[node] > 2
            ))
            for node in members:
                node_to_community[node] = comm_id
        community_structure = [
            min(4, high_role_counts[node_to_community[i]]) if i in node_to_community else 0
            for i in range(num_nodes)
        ]
    except Exception:
        community_structure = [
            min(4, degree[i] // 2) if i in node_set else 0 for i in range(num_nodes)
        ]

    # centrality_composite: scaled average of four centrality scores, capped at 5.
    centrality_composite = []
    for i in range(num_nodes):
        score = 0
        if i in node_set:
            total = (degree_centrality.get(i, 0) * 10
                     + betweenness_centrality.get(i, 0) * 10
                     + closeness_centrality.get(i, 0) * 10
                     + pagerank.get(i, 0) * 100)
            score = min(5, int(total / 4))
        centrality_composite.append(score)

    # structural_similarity: int(mean role) of other nodes whose clustering differs by
    # < 0.3 and whose degree differs by <= 2.
    structural_similarity = []
    for i in range(num_nodes):
        score = 0
        if i in node_set:
            own_clustering = clustering_coeff.get(i, 0)
            own_degree = degree[i]
            similar_roles = [
                original_role_id[j] for j in range(num_nodes)
                if j != i and j in node_set
                and abs(own_clustering - clustering_coeff.get(j, 0)) < 0.3
                and abs(own_degree - degree[j]) <= 2
            ]
            if similar_roles:
                score = int(np.mean(similar_roles))
        structural_similarity.append(score)

    # propagation_influence: BFS up to distance 3 (visiting at most min(10, n) nodes),
    # summing role * (4 - distance) / 4; capped at 5.
    visit_limit = min(10, num_nodes)
    propagation_influence = []
    for i in range(num_nodes):
        if i not in node_set:
            propagation_influence.append(0)
            continue
        visited = set()
        queue = deque([(i, 0)])
        influence_sum = 0
        while queue and len(visited) < visit_limit:
            current_node, distance = queue.popleft()
            if current_node in visited or distance > 3:
                continue
            visited.add(current_node)
            if current_node < num_nodes:
                influence_sum += original_role_id[current_node] * (4 - distance) / 4
            for neighbor in G.neighbors(current_node):
                if neighbor not in visited:
                    queue.append((neighbor, distance + 1))
        propagation_influence.append(min(5, int(influence_sum)))

    # social_status: bucket by degree centrality or by degree relative to num_nodes.
    social_status = []
    for i in range(num_nodes):
        status = "low"
        if i in node_set:
            node_degree = degree[i]
            centrality_score = degree_centrality.get(i, 0)
            if centrality_score > 0.7 or node_degree > num_nodes * 0.3:
                status = "elite"
            elif centrality_score > 0.5 or node_degree > num_nodes * 0.2:
                status = "high"
            elif centrality_score > 0.3 or node_degree > num_nodes * 0.1:
                status = "medium"
        social_status.append(status)

    # network_role: by degree and clustering. The mean degree is the same for every node,
    # so it is computed once.
    mean_degree = np.mean(list(degree.values())) if degree else 0.0
    network_role = []
    for i in range(num_nodes):
        if i not in node_set or degree[i] == 0:
            role = "isolated"
        elif clustering_coeff.get(i, 0) > 0.7:
            role = "hub_connector"
        elif degree[i] > mean_degree:
            role = "bridge"
        else:
            role = "peripheral"
        network_role.append(role)

    return {
        'neighbor_influence': neighbor_influence,
        'path_dependency': path_dependency,
        'community_structure': community_structure,
        'centrality_composite': centrality_composite,
        'structural_similarity': structural_similarity,
        'propagation_influence': propagation_influence,
        'social_status': social_status,
        'network_role': network_role,
    }

def create_fixed_intervention_structural_confounders(G, num_nodes):
    """Build constant "intervened" confounders: every node gets identical values.

    ``G`` is accepted for signature symmetry with ``create_complex_confounding_variables``
    but is not used.

    Returns:
        dict with the same keys (and key order) as the complex confounders; each numeric
        entry is ``[2] * num_nodes``, 'social_status' is all ``'medium'`` and
        'network_role' is all ``'peripheral'``. Every key gets its own list object.
    """
    numeric_keys = (
        'neighbor_influence',
        'path_dependency',
        'community_structure',
        'centrality_composite',
        'structural_similarity',
        'propagation_influence',
    )
    result = {key: [2] * num_nodes for key in numeric_keys}
    result['social_status'] = ['medium'] * num_nodes
    result['network_role'] = ['peripheral'] * num_nodes
    return result

def apply_complex_confounding_bias(G, original_role_id, confounders, confounding_prob):
    """Return a copy of ``original_role_id`` with some roles replaced by confounded ones.

    Each position ``i`` is selected with probability ``confounding_prob``. For a selected
    position a score is accumulated from:

    * the weighted numeric confounders;
    * the ``social_status`` / ``network_role`` category weights;
    * for nodes of ``G`` only, degree, triangle count and 2-hop neighbourhood size.

    ``base_role = int(score) % 5``. With probability 0.8 the new label is
    ``int(0.3 * original + 0.7 * base_role) % 5``; otherwise it is the original role
    shifted by ``random.randint(-1, 1)``, modulo 5.

    Args:
        G: networkx graph (node ids presumably ``0..n-1`` -- confirm against caller).
        original_role_id: list/array of int roles; not modified (``.copy()`` is used).
        confounders: dict shaped like the output of ``create_complex_confounding_variables``.
        confounding_prob: per-position probability of applying the confounding.

    Returns:
        The confounded copy of ``original_role_id``.
    """
    numeric_weights = (
        ('neighbor_influence', 2),
        ('path_dependency', 1.5),
        ('community_structure', 2),
        ('centrality_composite', 2.5),
        ('structural_similarity', 1.8),
        ('propagation_influence', 2.2),
    )
    status_weights = {"elite": 5, "high": 4, "medium": 2, "low": 0}
    role_weights = {"hub_connector": 4, "bridge": 3, "peripheral": 1, "isolated": 0}

    confounded_role_id = original_role_id.copy()
    node_set = set(G.nodes())
    undirected = None  # built lazily, at most once, for triangle counts

    for i in range(len(original_role_id)):
        if not random.random() < confounding_prob:
            continue

        score = 0
        for key, weight in numeric_weights:
            score += confounders[key][i] * weight
        score += status_weights[confounders['social_status'][i]]
        score += role_weights[confounders['network_role'][i]]

        if i in node_set:
            score += min(3, G.degree(i) // 2)

            if undirected is None:
                undirected = G.to_undirected()
            # nx.triangles(graph, node) returns a plain int for a single node. Iterating
            # over it (as before) raised TypeError, and a bare except hid that, so this
            # term was silently never applied.
            try:
                triangles = nx.triangles(undirected, i)
            except nx.NetworkXException:
                triangles = 0  # e.g. not implemented for this graph type
            score += min(2, triangles)

            two_hop = {
                second
                for first in G.neighbors(i)
                for second in G.neighbors(first)
                if second != i
            }
            score += min(2, len(two_hop) // 3)

        base_role = int(score) % 5
        if random.random() < 0.8:
            blended = int(original_role_id[i] * 0.3 + base_role * 0.7)
            confounded_role_id[i] = blended % 5
        else:
            confounded_role_id[i] = (original_role_id[i] + random.randint(-1, 1)) % 5

    return confounded_role_id

def apply_fixed_structural_confounding_bias(G, original_role_id, confounders, confounding_prob=0.9):
    """Relabel positions purely from their confounders, each with probability ``confounding_prob``.

    ``G`` is accepted for signature symmetry with ``apply_complex_confounding_bias`` but
    is not used.

    A selected position receives ``int(score) % 5``. Here ``score`` is the weighted sum of
    its six numeric confounders plus the weights of its ``social_status`` and
    ``network_role`` categories. Unselected positions keep their original role.

    Returns:
        ``original_role_id.copy()`` with the selected entries overwritten.
    """
    numeric_terms = (
        ('neighbor_influence', 2),
        ('path_dependency', 1.5),
        ('community_structure', 2),
        ('centrality_composite', 2.5),
        ('structural_similarity', 1.8),
        ('propagation_influence', 2.2),
    )
    status_points = {"elite": 5, "high": 4, "medium": 2, "low": 0}
    role_points = {"hub_connector": 4, "bridge": 3, "peripheral": 1, "isolated": 0}

    relabelled = original_role_id.copy()
    for idx in range(len(original_role_id)):
        if not random.random() < confounding_prob:
            continue

        # Read every confounder before combining them (same order as the weights).
        row = [confounders[key][idx] for key, _ in numeric_terms]
        status = confounders['social_status'][idx]
        role = confounders['network_role'][idx]

        # Accumulate strictly left to right so float rounding is unchanged.
        score = 0
        for value, (_, weight) in zip(row, numeric_terms):
            score += value * weight
        score += status_points[status]
        score += role_points[role]

        relabelled[idx] = int(score) % 5
    return relabelled

def generate_complex_confounded_features(G, original_role_id, confounders, confounding_prob):
    """Sample 5-d node features from the true role, adding a structural bias to some nodes.

    For every node of ``G`` whose id is below ``len(original_role_id)``:

    1. A base vector is drawn with ``np.random.normal`` from the role's mean/std table.
    2. With probability ``confounding_prob`` a bias is added, scaled by 4.0. The bias sums:
       * each numeric confounder times its fixed 5-d coefficient row;
       * fixed offsets for the ``social_status`` and ``network_role`` categories;
       * degree, clustering-coefficient and neighbour-count terms.

    Side effect: each vector is also stored in ``G.nodes[node]['features']``.

    Args:
        G: networkx graph (node ids presumably ``0..n-1`` -- confirm against caller).
        original_role_id: int roles in 0..4, indexed by node id.
        confounders: dict shaped like the output of ``create_complex_confounding_variables``.
        confounding_prob: per-node probability of adding the bias.

    Returns:
        dict mapping node id -> length-5 numpy feature vector.
    """
    base_feature_params = {
        0: {'mean': [2.0, 3.0, 1.5, 2.5, 4.0], 'std': [0.3, 0.4, 0.2, 0.3, 0.5]},
        1: {'mean': [1.0, 2.0, 3.0, 1.5, 2.0], 'std': [0.2, 0.3, 0.4, 0.2, 0.3]},
        2: {'mean': [3.0, 1.0, 2.0, 3.5, 1.5], 'std': [0.4, 0.2, 0.3, 0.5, 0.2]},
        3: {'mean': [2.5, 4.0, 1.0, 2.0, 3.0], 'std': [0.3, 0.5, 0.1, 0.2, 0.4]},
        4: {'mean': [1.5, 2.5, 4.0, 1.0, 2.5], 'std': [0.2, 0.3, 0.6, 0.1, 0.3]}
    }
    # Coefficient rows for the numeric confounders, applied in this fixed order so the
    # floating-point sums are reproducible.
    scalar_coefficients = (
        ('neighbor_influence', np.array([1.2, 0.8, -0.6, 1.5, 1.0])),
        ('path_dependency', np.array([0.9, 1.3, -0.8, 1.1, 0.7])),
        ('community_structure', np.array([1.1, -0.7, 1.4, 0.9, -1.2])),
        ('centrality_composite', np.array([1.8, 1.5, -1.0, 2.0, 1.3])),
        ('structural_similarity', np.array([0.8, -1.1, 1.6, 0.7, 1.2])),
        ('propagation_influence', np.array([1.4, 1.0, -0.9, 1.7, 0.6])),
    )
    status_effects = {
        "elite": np.array([5.0, 6.5, -2.0, 4.0, 8.0]),
        "high": np.array([3.5, 4.5, -1.0, 2.5, 5.5]),
        "medium": np.array([1.0, 1.5, 1.0, 1.0, 2.0]),
        "low": np.array([-2.0, -3.0, 3.5, -1.5, -1.0])
    }
    role_effects = {
        "hub_connector": np.array([4.5, 5.0, -1.5, 3.5, 6.0]),
        "bridge": np.array([3.0, 3.5, -0.5, 2.0, 4.0]),
        "peripheral": np.array([0.5, 1.0, 1.5, 0.5, 1.5]),
        "isolated": np.array([-1.5, -2.0, 2.5, -1.0, -0.5])
    }
    degree_coef = np.array([0.3, 0.2, -0.1, 0.4, 0.25])
    clustering_coef = np.array([2.0, -1.5, 3.0, 1.0, 2.5])
    neighbor_count_coef = np.array([0.2, 0.15, -0.1, 0.25, 0.18])

    features_dict = {}
    num_roles = len(original_role_id)
    # Undirected copy for clustering coefficients. It is built at most once, instead of
    # deep-copying the whole graph for every confounded node.
    undirected = None

    for node_id in G.nodes():
        if not node_id < num_roles:
            continue

        params = base_feature_params[original_role_id[node_id]]
        features = np.random.normal(params['mean'], params['std'])

        if random.random() < confounding_prob:
            confounding_bias = np.zeros(5)
            for key, coef in scalar_coefficients:
                confounding_bias += confounders[key][node_id] * coef
            confounding_bias += status_effects[confounders['social_status'][node_id]]
            confounding_bias += role_effects[confounders['network_role'][node_id]]

            # Local-structure terms (node_id comes from G.nodes(), so it is in G).
            confounding_bias += G.degree(node_id) * degree_coef

            if undirected is None:
                undirected = G.to_undirected()
            try:
                clustering = nx.clustering(undirected, node_id)
            except nx.NetworkXException:
                clustering = None  # e.g. not implemented for multigraphs: skip the term
            if clustering is not None:
                confounding_bias += clustering * clustering_coef

            confounding_bias += len(list(G.neighbors(node_id))) * neighbor_count_coef

            features += confounding_bias * 4.0

        features_dict[node_id] = features
        G.nodes[node_id]['features'] = features

    return features_dict

def generate_fixed_structural_confounded_features(G, original_role_id, confounders, confounding_prob):
    """Sample 5-d node features from the true role, optionally adding a noisy confounder bias.

    For every node of ``G`` whose id is below ``len(original_role_id)``:

    1. A base vector is drawn with ``np.random.normal`` from the role's mean/std table.
    2. With probability ``confounding_prob`` a bias is built. It is the sum of each
       numeric confounder times its 5-d coefficient row, plus fixed offsets for the
       ``social_status`` and ``network_role`` categories. The vector is then shifted by
       ``bias * U(-25, 25) + U(-35, 35)`` (two ``random.uniform`` draws, in that order).

    Unlike ``generate_complex_confounded_features``, no degree/clustering terms are used.
    Side effect: each vector is also stored in ``G.nodes[node]['features']``.

    Returns:
        dict mapping node id -> length-5 numpy feature vector.
    """
    role_distributions = {
        0: {'mean': [2.0, 3.0, 1.5, 2.5, 4.0], 'std': [0.3, 0.4, 0.2, 0.3, 0.5]},
        1: {'mean': [1.0, 2.0, 3.0, 1.5, 2.0], 'std': [0.2, 0.3, 0.4, 0.2, 0.3]},
        2: {'mean': [3.0, 1.0, 2.0, 3.5, 1.5], 'std': [0.4, 0.2, 0.3, 0.5, 0.2]},
        3: {'mean': [2.5, 4.0, 1.0, 2.0, 3.0], 'std': [0.3, 0.5, 0.1, 0.2, 0.4]},
        4: {'mean': [1.5, 2.5, 4.0, 1.0, 2.5], 'std': [0.2, 0.3, 0.6, 0.1, 0.3]}
    }
    # Numeric confounder -> coefficient row; dict order fixes the accumulation order.
    coefficient_rows = {
        'neighbor_influence': np.array([1.2, 0.8, -0.6, 1.5, 1.0]),
        'path_dependency': np.array([0.9, 1.3, -0.8, 1.1, 0.7]),
        'community_structure': np.array([1.1, -0.7, 1.4, 0.9, -1.2]),
        'centrality_composite': np.array([1.8, 1.5, -1.0, 2.0, 1.3]),
        'structural_similarity': np.array([0.8, -1.1, 1.6, 0.7, 1.2]),
        'propagation_influence': np.array([1.4, 1.0, -0.9, 1.7, 0.6]),
    }
    status_offsets = {
        "elite": np.array([5.0, 6.5, -2.0, 4.0, 8.0]),
        "high": np.array([3.5, 4.5, -1.0, 2.5, 5.5]),
        "medium": np.array([1.0, 1.5, 1.0, 1.0, 2.0]),
        "low": np.array([-2.0, -3.0, 3.5, -1.5, -1.0])
    }
    role_offsets = {
        "hub_connector": np.array([4.5, 5.0, -1.5, 3.5, 6.0]),
        "bridge": np.array([3.0, 3.5, -0.5, 2.0, 4.0]),
        "peripheral": np.array([0.5, 1.0, 1.5, 0.5, 1.5]),
        "isolated": np.array([-1.5, -2.0, 2.5, -1.0, -0.5])
    }
    lookup_order = tuple(coefficient_rows) + ('social_status', 'network_role')

    generated = {}
    for nid in G.nodes():
        if not nid < len(original_role_id):
            continue

        role = original_role_id[nid]
        values = {key: confounders[key][nid] for key in lookup_order}
        dist = role_distributions[role]
        vec = np.random.normal(dist['mean'], dist['std'])

        if random.random() < confounding_prob:
            total = np.zeros(5)
            for key, row in coefficient_rows.items():
                total += values[key] * row
            total += status_offsets[values['social_status']]
            total += role_offsets[values['network_role']]

            scale = random.uniform(-25, 25)
            shift = random.uniform(-35, 35)
            vec += total * scale + shift

        generated[nid] = vec
        G.nodes[nid]['features'] = vec

    return generated

def generate_mixed_structural_dataset(confounding_prob=0.7, intervention_prob=0.1, dataset_type='train', num_samples=2000):
    """Generate ``num_samples`` confounded graph samples and collect them into a dict.

    Each sample starts from ``generate_false_cause_dataset1()`` and is confounded with
    one of two mechanisms:

    * "fixed": ``create_fixed_intervention_structural_confounders`` and the matching
      ``apply_fixed_*`` / ``generate_fixed_*`` functions;
    * "complex": ``create_complex_confounding_variables`` and the matching
      ``apply_complex_*`` / ``generate_complex_*`` functions.

    ``intervention_prob == 0.0`` always uses "complex" and ``== 1.0`` always uses "fixed".
    Any other value picks "fixed" per sample with that probability.

    Samples whose generation raises an ``Exception`` are skipped (best effort). The number
    skipped and the first error are printed, so failures are no longer silent. (The file
    globally ignores warnings, so ``print`` is used rather than ``warnings.warn``.)

    Returns:
        dict with per-sample lists 'features' (num_nodes x 5 float tensors),
        'edge_index' (2 x E long tensors), 'label' / 'role_id' (the same confounded-label
        tensors), 'ground_truth' (original roles), 'pos' (None when ``G`` has no ``pos``
        attribute) and 'confounding_info' (the confounder dicts). It also carries the
        scalar arguments 'confounding_prob', 'intervention_prob' and 'dataset_type'.
    """
    edge_index_list, label_list = [], []
    ground_truth_list, role_id_list, pos_list, features_list = [], [], [], []
    confounding_info_list = []
    num_skipped = 0
    first_error = None

    for _ in tqdm(range(num_samples)):
        try:
            G, original_role_id, _label = generate_false_cause_dataset1()

            # The RNG is only consulted for fractional probabilities, as before.
            if intervention_prob == 0.0:
                use_fixed = False
            elif intervention_prob == 1.0:
                use_fixed = True
            else:
                use_fixed = random.random() < intervention_prob

            if use_fixed:
                confounders = create_fixed_intervention_structural_confounders(G, len(original_role_id))
                confounded_role_id = apply_fixed_structural_confounding_bias(G, original_role_id, confounders, confounding_prob)
                features_dict = generate_fixed_structural_confounded_features(G, original_role_id, confounders, confounding_prob)
            else:
                confounders = create_complex_confounding_variables(G, original_role_id, confounding_prob)
                confounded_role_id = apply_complex_confounding_bias(G, original_role_id, confounders, confounding_prob)
                features_dict = generate_complex_confounded_features(G, original_role_id, confounders, confounding_prob)

            edges = list(G.edges)
            if edges:
                edge_index = torch.tensor(edges).t().contiguous()
            else:
                # torch.tensor([]) would be 1-D float; use an explicit empty 2 x 0 long tensor.
                edge_index = torch.empty((2, 0), dtype=torch.long)

            num_nodes = len(original_role_id)
            features = torch.zeros((num_nodes, 5))
            for node_id, feat in features_dict.items():
                if node_id < num_nodes:
                    features[node_id] = torch.tensor(feat, dtype=torch.float)

            labels = torch.tensor(confounded_role_id, dtype=torch.long)
            ground_truth = torch.tensor(original_role_id, dtype=torch.long)

            pos = None
            if hasattr(G, 'pos'):
                pos = torch.tensor([G.pos[i] for i in range(num_nodes)], dtype=torch.float)

            edge_index_list.append(edge_index)
            label_list.append(labels)
            ground_truth_list.append(ground_truth)
            role_id_list.append(labels)
            features_list.append(features)
            pos_list.append(pos)
            confounding_info_list.append(confounders)

        except Exception as err:  # best effort: skip the sample, but keep count
            num_skipped += 1
            if first_error is None:
                first_error = err

    if num_skipped:
        print(f"[{dataset_type}] skipped {num_skipped}/{num_samples} samples; "
              f"first error: {first_error!r}")

    return {
        'features': features_list,
        'edge_index': edge_index_list,
        'label': label_list,
        'ground_truth': ground_truth_list,
        'role_id': role_id_list,
        'pos': pos_list,
        'confounding_info': confounding_info_list,
        'confounding_prob': confounding_prob,
        'intervention_prob': intervention_prob,
        'dataset_type': dataset_type
    }

def generate_mixed_structural_experiment_datasets(intervention_prob=0.1, base_dir='./data/casual/'):
    """Generate and save the train/val/test splits of the "casual_2_4" experiment.

    Split settings (all with ``confounding_prob=0``):
      * train: ``intervention_prob=1``, 3000 samples
      * val:   ``intervention_prob=1``, 375 samples
      * test:  ``intervention_prob=0``, 375 samples

    NOTE(review): the ``intervention_prob`` argument is accepted but not used, and with
    ``confounding_prob=0`` no sample is actually confounded (``random.random() < 0`` is
    never true). Confirm this is the intended experiment setup.

    Each split dict is written with ``np.save``. A dict is stored as a pickled object
    array, so reading it back needs ``np.load(path, allow_pickle=True)``.

    Returns:
        list of the written file paths, in the order train, val, test.
    """
    import os
    os.makedirs(base_dir, exist_ok=True)

    # (split name, intervention_prob, number of samples), generated in this order.
    split_specs = (
        ('train', 1, 3000),
        ('val', 1, 375),
        ('test', 0, 375),
    )

    saved_paths = []
    for split_name, split_intervention, sample_count in split_specs:
        split_data = generate_mixed_structural_dataset(
            confounding_prob=0,
            intervention_prob=split_intervention,
            dataset_type=split_name,
            num_samples=sample_count
        )
        out_path = f'{base_dir}/{split_name}_casual_2_4.npy'
        np.save(out_path, split_data)
        saved_paths.append(out_path)

    return saved_paths

def run_mixed_structural_experiment(intervention_prob=0.1):
    """Build and save the experiment datasets; return the list of saved file paths.

    ``intervention_prob`` is forwarded to ``generate_mixed_structural_experiment_datasets``.
    The output directory is left at that function's default.
    """
    return generate_mixed_structural_experiment_datasets(intervention_prob)

if __name__ == "__main__":
    # Build the train/val/test files, then report where they were written.
    files = run_mixed_structural_experiment(intervention_prob=0.2)
    print(f"Generated files: {files}")
    





