from BA3_loc import *
from tqdm import tqdm
import os.path as osp
import warnings
warnings.filterwarnings("ignore")
import random
import math
import torch
import copy

from scipy.stats import gamma
from scipy.stats import gompertz
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import weibull_min
from scipy.special import gamma, gammaincinv
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import euclidean_distances
from collections import deque
from paper import *

def create_confounding_variables(original_role_id, confounding_prob):
    """Build deterministic, index-based confounder values for every node.

    Only ``len(original_role_id)`` is used. Neither the role values nor
    ``confounding_prob`` affect the result. Each confounder maps to a list
    with one entry per node index ``k``:

    - ``social_status``: "elite" / "high" / "medium" / "low" by quartile of ``k``.
    - ``time_period``: cycles through five era labels (``k % 5``).
    - ``domain_preference``: ``k % 7``.
    - ``network_influence``: ``int((sin(0.5 * k) + 1) * 2.5)``, in the range 0..5.
    - ``resource_access``: ``(3 * k + 7) % 6``.
    - ``collaboration_tendency``: "cooperative" for even ``k``, else "competitive".

    Returns:
        dict mapping each confounder name to its per-node list.
    """
    n = len(original_role_id)
    quarter, half, three_quarters = n // 4, n // 2, 3 * n // 4
    eras = ("ancient", "early", "middle", "late", "modern")

    def status_of(k):
        # Quartile bucketing on the node index.
        if k < quarter:
            return "elite"
        if k < half:
            return "high"
        if k < three_quarters:
            return "medium"
        return "low"

    indices = range(n)
    return {
        'social_status': [status_of(k) for k in indices],
        'time_period': [eras[k % 5] for k in indices],
        'domain_preference': [k % 7 for k in indices],
        'network_influence': [int((np.sin(k * 0.5) + 1) * 2.5) for k in indices],
        'resource_access': [(k * 3 + 7) % 6 for k in indices],
        'collaboration_tendency': [
            "cooperative" if k % 2 == 0 else "competitive" for k in indices
        ],
    }

def create_fixed_intervention_confounders(num_nodes):
    """Return confounders in which every node shares the same constant value.

    The six keys are the same as those built by ``create_confounding_variables``.
    Each maps to a list of ``num_nodes`` copies of one constant, which removes
    all per-node variation (an intervention on the confounders).
    """
    constants = (
        ('social_status', 'medium'),
        ('time_period', 'middle'),
        ('domain_preference', 3),
        ('network_influence', 2),
        ('resource_access', 2),
        ('collaboration_tendency', 'cooperative'),
    )
    return {name: [constant] * num_nodes for name, constant in constants}

def apply_confounding_bias_gen_conf_style(original_role_id, confounders, confounding_prob, max_confounded_nodes=10):
    """Return a copy of ``original_role_id`` with confounder-driven role noise.

    Each node index ``i`` is selected with probability ``confounding_prob``.
    For a selected node, an integer score is computed from its confounders:
    social status (elite 4, high 3, medium 2, otherwise 1)
    + time period (ancient 0 ... modern 4)
    + domain preference
    + 2 * network influence
    + resource access
    + (3 if cooperative else 1).

    The first ``max_confounded_nodes`` entries of the copy are then re-drawn
    from that score. Each becomes ``(score + randint(0, 5)) % 5`` with
    probability ``confounding_prob``, and ``0`` otherwise. Entries at or beyond
    ``max_confounded_nodes`` keep their original role.

    NOTE(review): the re-draw writes indices ``0..max_confounded_nodes-1``
    rather than node ``i``. Later selections therefore overwrite earlier ones,
    and only the last selected node's score survives. This behaviour is kept
    as-is; confirm whether ``confounded_role_id[i]`` was intended.

    Args:
        original_role_id: per-node role ids supporting ``.copy()`` and item
            assignment (presumably a numpy array; the caller calls
            ``.tolist()`` on the result).
        confounders: dict of per-node lists keyed like the output of
            ``create_confounding_variables``.
        confounding_prob: probability used both for selecting nodes and for
            each per-entry re-draw.
        max_confounded_nodes: number of leading entries that are re-drawn.
            Defaults to 10, the previously hard-coded value. It is clamped to
            the number of nodes, so graphs with fewer nodes no longer raise
            ``IndexError``.

    Returns:
        The modified copy; the input is not mutated.
    """
    status_weights = {"elite": 4, "high": 3, "medium": 2}  # any other status -> 1
    time_weights = {"ancient": 0, "early": 1, "middle": 2, "late": 3, "modern": 4}

    confounded_role_id = original_role_id.copy()
    num_nodes = len(original_role_id)
    # Previously range(0, 10) unconditionally: IndexError on graphs with < 10 nodes.
    num_redrawn = min(max_confounded_nodes, num_nodes)

    for i in range(num_nodes):
        if random.random() < confounding_prob:
            confounded_role = (
                status_weights.get(confounders['social_status'][i], 1)
                + time_weights[confounders['time_period'][i]]
                + confounders['domain_preference'][i]
                + confounders['network_influence'][i] * 2
                + confounders['resource_access'][i]
                + (3 if confounders['collaboration_tendency'][i] == "cooperative" else 1)
            )

            for j in range(num_redrawn):
                if random.random() < confounding_prob:
                    confounded_role_id[j] = (confounded_role + random.randint(0, 5)) % 5
                else:
                    confounded_role_id[j] = 0

    return confounded_role_id

def apply_fixed_confounding_bias_gen_intervened_style(original_role_id, confounders, confounding_prob=0.9):
    """Return a copy of ``original_role_id`` with nodes reassigned to role 1 at random.

    Each node index ``i`` is independently set to role ``1`` with probability
    ``confounding_prob``. There is one ``random.random()`` draw per node, in
    index order. All other entries keep their original role, and the input is
    not modified.

    ``confounders`` is accepted so the signature matches
    ``apply_confounding_bias_gen_conf_style``, but its values do not affect the
    result.

    NOTE(review): a constant role of 1 is written regardless of the
    confounders. Confirm that this is the intended intervention.

    Args:
        original_role_id: per-node role ids supporting ``.copy()`` and item
            assignment.
        confounders: unused; kept for interface compatibility.
        confounding_prob: per-node probability of reassignment to role 1.

    Returns:
        The modified copy.
    """
    confounded_role_id = original_role_id.copy()

    for i in range(len(original_role_id)):
        # A confounder-derived score used to be computed here and then thrown
        # away; only the random draw and the constant assignment matter.
        if random.random() < confounding_prob:
            confounded_role_id[i] = 1

    return confounded_role_id

def generate_confounded_features_gen_conf_style(G, original_role_id, confounders, confounding_prob):
    """Sample 5-dimensional node features from role Gaussians plus confounder bias.

    For every ``node_id`` in ``G.nodes()`` with ``node_id < len(original_role_id)``:

    1. A base vector is drawn with ``np.random.normal`` from the mean/std
       table of ``original_role_id[node_id]``. Roles 0-4 are defined; any other
       role raises ``KeyError``.
    2. With probability ``confounding_prob`` (one ``random.random()`` draw per
       node), a deterministic bias is added. The bias is built from the node's
       six confounder values and scaled by 3.5.

    The resulting array is stored both in the returned dict and as
    ``G.nodes[node_id]['features']`` (the same object). Nodes whose id is not
    below ``len(original_role_id)`` are skipped.

    Args:
        G: graph whose ``nodes()`` yields integer ids and whose ``nodes[id]``
            is a mutable attribute dict (networkx-style; assumed from usage).
        original_role_id: per-node role ids indexed by node id. The caller in
            this file passes the already-confounded roles.
        confounders: dict of per-node lists keyed like the output of
            ``create_confounding_variables``.
        confounding_prob: probability of adding the confounder bias to a node.

    Returns:
        dict mapping node id to its feature vector (numpy array of shape (5,)).
    """
    base_feature_params = {
        0: {'mean': [2.0, 3.0, 1.5, 2.5, 4.0], 'std': [0.3, 0.4, 0.2, 0.3, 0.5]},
        1: {'mean': [1.0, 2.0, 3.0, 1.5, 2.0], 'std': [0.2, 0.3, 0.4, 0.2, 0.3]},
        2: {'mean': [3.0, 1.0, 2.0, 3.5, 1.5], 'std': [0.4, 0.2, 0.3, 0.5, 0.2]},
        3: {'mean': [2.5, 4.0, 1.0, 2.0, 3.0], 'std': [0.3, 0.5, 0.1, 0.2, 0.4]},
        4: {'mean': [1.5, 2.5, 4.0, 1.0, 2.5], 'std': [0.2, 0.3, 0.6, 0.1, 0.3]},
    }

    # The bias tables do not depend on the node, so build them once instead
    # of once per node inside the loop.
    status_effects = {
        "elite": np.array([4.5, 6.0, 1.5, 3.0, 7.5]),
        "high": np.array([3.0, 4.0, 1.0, 2.0, 5.0]),
        "medium": np.array([1.5, 2.0, 2.5, 1.5, 2.5]),
        "low": np.array([-1.5, -2.0, 4.0, -1.0, 1.0]),
    }
    time_effects = {
        "ancient": np.array([2.0, -1.0, 3.0, 4.0, -2.0]),
        "early": np.array([1.5, 2.5, -2.0, 3.0, 2.0]),
        "middle": np.array([-1.0, 3.5, -1.5, 1.0, 3.0]),
        "late": np.array([3.0, 1.0, -4.0, -2.0, 5.0]),
        "modern": np.array([-2.0, 4.0, 2.0, -3.0, 4.5]),
    }
    influence_coeffs = np.array([1.5, 1.2, -0.8, 2.0, 1.0])
    resource_coeffs = np.array([0.7, -0.5, 1.3, 0.9, -0.8])
    cooperative_bias = np.array([2.5, 3.0, -1.5, 2.0, 3.5])
    competitive_bias = np.array([-1.5, -2.0, 3.5, -1.0, -2.5])

    features_dict = {}
    num_roles = len(original_role_id)

    for node_id in G.nodes():
        if node_id >= num_roles:
            continue

        role_params = base_feature_params[original_role_id[node_id]]
        features = np.random.normal(role_params['mean'], role_params['std'])

        if random.random() < confounding_prob:
            domain_pref = confounders['domain_preference'][node_id]
            domain_bias = np.array([
                domain_pref * 0.8,
                (domain_pref - 3) * 1.0,
                (6 - domain_pref) * 1.2,
                domain_pref * 0.6,
                (domain_pref + 2) * 0.9,
            ])
            if confounders['collaboration_tendency'][node_id] == "cooperative":
                collab_bias = cooperative_bias
            else:
                collab_bias = competitive_bias

            # Summed in a fresh array (tables are never mutated), in the same
            # order as the original incremental additions.
            confounding_bias = (
                status_effects[confounders['social_status'][node_id]]
                + time_effects[confounders['time_period'][node_id]]
                + domain_bias
                + confounders['network_influence'][node_id] * influence_coeffs
                + confounders['resource_access'][node_id] * resource_coeffs
                + collab_bias
            )
            features += confounding_bias * 3.5

        features_dict[node_id] = features
        G.nodes[node_id]['features'] = features

    return features_dict

def generate_enhanced_confounded_features_gen_intervened_style(G, original_role_id, confounders, confounding_prob):
    """Sample 5-dimensional node features with randomly scaled confounder bias.

    This follows the same per-node recipe as the conf-style feature generator.
    The only difference is how the confounder bias is applied. With
    probability ``confounding_prob`` (one ``random.random()`` draw per node),
    the node's deterministic bias vector is:

    - multiplied by a scalar drawn from ``random.uniform(-20, 20)``, which may
      flip its sign;
    - then shifted by a scalar offset drawn from ``random.uniform(-30, 30)``.

    Both draws happen in that order.

    For every ``node_id`` in ``G.nodes()`` with ``node_id < len(original_role_id)``,
    the base vector is drawn with ``np.random.normal`` from the mean/std table
    of ``original_role_id[node_id]``. Roles 0-4 are defined; other roles raise
    ``KeyError``. The result is stored both in the returned dict and as
    ``G.nodes[node_id]['features']`` (the same object).

    Args:
        G: graph whose ``nodes()`` yields integer ids and whose ``nodes[id]``
            is a mutable attribute dict (networkx-style; assumed from usage).
        original_role_id: per-node role ids indexed by node id. The caller in
            this file passes the original, un-confounded roles.
        confounders: dict of per-node lists keyed like the output of
            ``create_confounding_variables``.
        confounding_prob: probability of adding the scaled bias to a node.

    Returns:
        dict mapping node id to its feature vector (numpy array of shape (5,)).
    """
    base_feature_params = {
        0: {'mean': [2.0, 3.0, 1.5, 2.5, 4.0], 'std': [0.3, 0.4, 0.2, 0.3, 0.5]},
        1: {'mean': [1.0, 2.0, 3.0, 1.5, 2.0], 'std': [0.2, 0.3, 0.4, 0.2, 0.3]},
        2: {'mean': [3.0, 1.0, 2.0, 3.5, 1.5], 'std': [0.4, 0.2, 0.3, 0.5, 0.2]},
        3: {'mean': [2.5, 4.0, 1.0, 2.0, 3.0], 'std': [0.3, 0.5, 0.1, 0.2, 0.4]},
        4: {'mean': [1.5, 2.5, 4.0, 1.0, 2.5], 'std': [0.2, 0.3, 0.6, 0.1, 0.3]},
    }

    # Loop-invariant bias tables, built once rather than once per node.
    status_effects = {
        "elite": np.array([4.5, 6.0, 1.5, 3.0, 7.5]),
        "high": np.array([3.0, 4.0, 1.0, 2.0, 5.0]),
        "medium": np.array([1.5, 2.0, 2.5, 1.5, 2.5]),
        "low": np.array([-1.5, -2.0, 4.0, -1.0, 1.0]),
    }
    time_effects = {
        "ancient": np.array([2.0, -1.0, 3.0, 4.0, -2.0]),
        "early": np.array([1.5, 2.5, -2.0, 3.0, 2.0]),
        "middle": np.array([-1.0, 3.5, -1.5, 1.0, 3.0]),
        "late": np.array([3.0, 1.0, -4.0, -2.0, 5.0]),
        "modern": np.array([-2.0, 4.0, 2.0, -3.0, 4.5]),
    }
    influence_coeffs = np.array([1.5, 1.2, -0.8, 2.0, 1.0])
    resource_coeffs = np.array([0.7, -0.5, 1.3, 0.9, -0.8])
    cooperative_bias = np.array([2.5, 3.0, -1.5, 2.0, 3.5])
    competitive_bias = np.array([-1.5, -2.0, 3.5, -1.0, -2.5])

    features_dict = {}
    num_roles = len(original_role_id)

    for node_id in G.nodes():
        if node_id >= num_roles:
            continue

        role_params = base_feature_params[original_role_id[node_id]]
        features = np.random.normal(role_params['mean'], role_params['std'])

        if random.random() < confounding_prob:
            domain_pref = confounders['domain_preference'][node_id]
            domain_bias = np.array([
                domain_pref * 0.8,
                (domain_pref - 3) * 1.0,
                (6 - domain_pref) * 1.2,
                domain_pref * 0.6,
                (domain_pref + 2) * 0.9,
            ])
            if confounders['collaboration_tendency'][node_id] == "cooperative":
                collab_bias = cooperative_bias
            else:
                collab_bias = competitive_bias

            # Fresh array, same summation order as the original additions.
            confounding_bias = (
                status_effects[confounders['social_status'][node_id]]
                + time_effects[confounders['time_period'][node_id]]
                + domain_bias
                + confounders['network_influence'][node_id] * influence_coeffs
                + confounders['resource_access'][node_id] * resource_coeffs
                + collab_bias
            )
            # Random scalar gain (may flip sign) and scalar offset, drawn in this order.
            features += confounding_bias * random.uniform(-20, 20) + random.uniform(-30, 30)

        features_dict[node_id] = features
        G.nodes[node_id]['features'] = features

    return features_dict

def generate_mixed_dataset(confounding_prob=0.7, intervention_prob=0.1, dataset_type='train', num_samples=2000):
    """Generate graph samples that mix confounded and intervened feature pipelines.

    Each of ``num_samples`` attempts draws ``(G, original_role_id, label)``
    from ``generate_false_cause_dataset1()`` and routes it to one pipeline.

    The intervened pipeline uses ``create_fixed_intervention_confounders``,
    ``apply_fixed_confounding_bias_gen_intervened_style`` and
    ``generate_enhanced_confounded_features_gen_intervened_style`` on the
    original roles, all with probability 0.9. It is chosen:

    - always when ``intervention_prob == 1.0``;
    - never when ``intervention_prob == 0.0``;
    - otherwise when a per-sample ``random.random() < intervention_prob``.

    The confounded pipeline is used in all other cases. It runs
    ``create_confounding_variables`` and
    ``apply_confounding_bias_gen_conf_style``, then draws features from the
    confounded roles via ``generate_confounded_features_gen_conf_style``.

    Any ``Exception`` raised while building a sample skips that sample.
    Every per-sample output is built first and appended to the result lists
    only after the whole sample succeeded. This keeps all per-sample lists
    index-aligned; previously a failure part-way through could leave, for
    example, a label without matching features.

    Args:
        confounding_prob: confounding probability for the confounded pipeline.
        intervention_prob: probability that a sample uses the intervened pipeline.
        dataset_type: tag stored with each sample and in the result
            (e.g. 'train', 'val', 'test').
        num_samples: number of generation attempts.

    Returns:
        dict with:

        - per-sample lists ``'features'`` (node-feature matrices, rows in
          sorted node order), ``'edge_index'`` (2 x E int arrays),
          ``'label'``, ``'ground_truth'``, ``'role_id'``, ``'pos'`` and
          ``'confounding_info'``. In ``'ground_truth'``, an edge is 1.0 when
          both endpoints have an original role > 0.
        - the settings ``'confounding_prob'``, ``'intervention_prob'`` and
          ``'dataset_type'``.
        - ``'num_skipped'``, the number of failed attempts. The per-sample
          lists therefore hold ``num_samples - num_skipped`` entries.
    """
    edge_index_list, label_list = [], []
    ground_truth_list, role_id_list, pos_list, features_list = [], [], [], []
    confounding_info_list = []
    num_skipped = 0

    for _ in tqdm(range(num_samples)):
        try:
            G, original_role_id, label = generate_false_cause_dataset1()

            # Exact 0.0 / 1.0 skip the per-sample draw, matching the RNG usage
            # of the original three-way branch.
            if intervention_prob == 0.0:
                use_intervention = False
            elif intervention_prob == 1.0:
                use_intervention = True
            else:
                use_intervention = random.random() < intervention_prob

            if use_intervention:
                confounders = create_fixed_intervention_confounders(len(original_role_id))
                confounded_role_id = apply_fixed_confounding_bias_gen_intervened_style(original_role_id, confounders, 0.9)
                # Intervened features are drawn from the original roles.
                features_dict = generate_enhanced_confounded_features_gen_intervened_style(G, original_role_id, confounders, 0.9)
            else:
                confounders = create_confounding_variables(original_role_id, confounding_prob)
                confounded_role_id = apply_confounding_bias_gen_conf_style(original_role_id, confounders, confounding_prob)
                # Confounded features are drawn from the confounded roles.
                features_dict = generate_confounded_features_gen_conf_style(G, confounded_role_id, confounders, confounding_prob)

            if G.number_of_edges() > 0:
                edge_index = np.array(list(G.edges), dtype=int).T
            else:
                edge_index = np.array([[], []], dtype=int)

            # Layout is best-effort: fall back to an empty array on failure.
            try:
                layout = nx.spring_layout(G) if G.number_of_nodes() <= 1000 else nx.random_layout(G)
                pos = np.array(list(layout.values()))
            except Exception:
                pos = np.array([])

            # Ground truth is best-effort as well.
            try:
                if edge_index.size > 0:
                    row, col = edge_index
                    role_arr = np.array(original_role_id)
                    gd = np.array(role_arr[row] > 0, dtype=np.float64) * np.array(role_arr[col] > 0, dtype=np.float64)
                else:
                    gd = np.array([])
            except Exception:
                gd = np.array([])

            if G.number_of_nodes() > 0:
                feat_mat = np.vstack([
                    features_dict.get(n, np.zeros(5))
                    for n in sorted(G.nodes())
                ])
            else:
                feat_mat = np.zeros((0, 5))

            # Fraction of nodes whose role differs after confounding.
            confounding_strength = np.mean([
                1 if original_role_id[j] != confounded_role_id[j] else 0
                for j in range(len(original_role_id))
            ])

            info = {
                'confounding_prob': confounding_prob,
                'intervention_prob': intervention_prob,
                'dataset_type': dataset_type,
                'original_role_id': original_role_id,
                # asarray makes this work for both list and ndarray role ids.
                'confounded_role_id': np.asarray(confounded_role_id).tolist(),
                'confounding_strength': confounding_strength,
                'confounders': confounders,
                'label': label,
            }
        except Exception:
            num_skipped += 1
            continue

        # Commit the sample only once every piece was built successfully.
        label_list.append(label)
        role_id_list.append(np.array(confounded_role_id))
        edge_index_list.append(edge_index)
        pos_list.append(pos)
        ground_truth_list.append(gd)
        features_list.append(feat_mat)
        confounding_info_list.append(info)

    return {
        'features': features_list,
        'edge_index': edge_index_list,
        'label': label_list,
        'ground_truth': ground_truth_list,
        'role_id': role_id_list,
        'pos': pos_list,
        'confounding_info': confounding_info_list,
        'confounding_prob': confounding_prob,
        'intervention_prob': intervention_prob,
        'dataset_type': dataset_type,
        'num_skipped': num_skipped,
    }

def generate_mixed_experiment_datasets(intervention_prob=0.1, base_dir='./data/int/', confounding_prob=0.7):
    """Generate and save train/val/test splits for one intervention rate.

    Splits are generated with ``generate_mixed_dataset`` in the order train
    (1500 samples), val (200), test (200). Each dataset dict is written with
    ``np.save`` to ``<base_dir>/<split>_int_<intervention_prob:.1f>.npy``.
    ``np.save`` stores a dict as a pickled 0-d object array, so read it back
    with ``np.load(path, allow_pickle=True).item()``.

    Args:
        intervention_prob: forwarded to ``generate_mixed_dataset``; also part
            of the file names.
        base_dir: output directory, created if missing.
        confounding_prob: forwarded to ``generate_mixed_dataset``. Defaults to
            0.7, the previously hard-coded value.

    Returns:
        List of the written file paths, in split order.
    """
    import os
    os.makedirs(base_dir, exist_ok=True)

    # Order matters for reproducibility: each split consumes the global RNGs.
    split_sizes = (('train', 1500), ('val', 200), ('test', 200))

    generated_files = []
    for split, num_samples in split_sizes:
        dataset = generate_mixed_dataset(
            confounding_prob=confounding_prob,
            intervention_prob=intervention_prob,
            dataset_type=split,
            num_samples=num_samples,
        )
        # os.path.join avoids the double slash the old f'{base_dir}/...' produced.
        out_file = os.path.join(base_dir, f'{split}_int_{intervention_prob:.1f}.npy')
        np.save(out_file, dataset)
        generated_files.append(out_file)

    return generated_files

def run_mixed_experiment(intervention_prob=0.1):
    """Build and save the train/val/test datasets for ``intervention_prob``.

    Thin entry point around ``generate_mixed_experiment_datasets`` (default
    output directory). Returns the list of file paths it produced.
    """
    return generate_mixed_experiment_datasets(intervention_prob)

if __name__ == "__main__":
    # Script entry: generate all splits with a 0.2 per-sample intervention probability.
    output_paths = run_mixed_experiment(intervention_prob=0.2)
    print(f"Generated files: {output_paths}")
    




