from BA3_loc import *
from tqdm import tqdm
import os.path as osp
import warnings
warnings.filterwarnings("ignore")
import random
import math
import torch
import copy

from scipy.stats import gamma
from scipy.stats import gompertz
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import weibull_min
from scipy.special import gamma, gammaincinv
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import euclidean_distances
from collections import deque
from paper import *

def create_confounding_variables(original_role_id, confounding_prob):
    """Build index-derived confounder attributes for every node.

    Each attribute depends only on the node's position, so the result is
    fully determined by ``len(original_role_id)``.

    Args:
        original_role_id: Sized collection with one entry per node; only its
            length is read.
        confounding_prob: Not read by this function.

    Returns:
        dict with one list per attribute (each of length ``num_nodes``):
        ``social_status`` (quartile label by index), ``time_period``
        (index mod 5), ``domain_preference`` (index mod 7),
        ``network_influence`` (sine of the index, truncated to an int),
        ``resource_access`` (``(3 * index + 7) mod 6``) and
        ``collaboration_tendency`` (parity of the index).
    """
    total = len(original_role_id)
    indices = range(total)

    # Upper index bounds (exclusive) of the first three status bands.
    elite_end = total // 4
    high_end = total // 2
    medium_end = 3 * total // 4

    def status_of(idx):
        if idx < elite_end:
            return "elite"
        if idx < high_end:
            return "high"
        if idx < medium_end:
            return "medium"
        return "low"

    eras = ("ancient", "early", "middle", "late", "modern")

    return {
        'social_status': [status_of(idx) for idx in indices],
        'time_period': [eras[idx % 5] for idx in indices],
        'domain_preference': [idx % 7 for idx in indices],
        'network_influence': [
            int((np.sin(idx * 0.5) + 1) * 2.5) for idx in indices
        ],
        'resource_access': [(idx * 3 + 7) % 6 for idx in indices],
        'collaboration_tendency': [
            "cooperative" if idx % 2 == 0 else "competitive"
            for idx in indices
        ],
    }

def apply_confounding_bias(original_role_id, confounders, confounding_prob):
    """Return a copy of the role ids with confounder-driven rewrites applied.

    For every node ``i`` a uniform draw decides whether the node triggers a
    rewrite. When it does, an integer score is summed from that node's
    confounders and the leading entries of the copy (at most ten) are each
    either set to ``(score + randint(0, 5)) % 5`` or reset to ``0``,
    depending on a second draw against ``confounding_prob``.

    NOTE(review): the rewrite targets the leading entries (indices 0..9),
    not entry ``i`` itself. That behaviour is preserved here; confirm it is
    intended.

    Args:
        original_role_id: Per-node role ids; must support ``.copy()``,
            ``len()`` and item assignment on the copy. Not modified.
        confounders: Mapping with the keys ``social_status``,
            ``time_period``, ``domain_preference``, ``network_influence``,
            ``resource_access`` and ``collaboration_tendency``, each indexable
            by node position.
        confounding_prob: Threshold compared against ``random.random()``.

    Returns:
        The modified copy, of the same type as ``original_role_id.copy()``.

    Raises:
        KeyError: If a node's time period is not one of the five known names.
    """
    confounded_role_id = original_role_id.copy()

    # Any status other than these three contributes 1.
    status_weights = {"elite": 4, "high": 3, "medium": 2}
    time_weights = {"ancient": 0, "early": 1, "middle": 2, "late": 3, "modern": 4}

    # The rewrite span used to be a fixed ten entries, which raised
    # IndexError for inputs shorter than that. Clamping keeps the draw
    # sequence unchanged for inputs of ten or more entries.
    rewrite_span = min(10, len(confounded_role_id))

    for i in range(len(original_role_id)):
        if not (random.random() < confounding_prob):
            continue

        confounded_role = (
            status_weights.get(confounders['social_status'][i], 1)
            + time_weights[confounders['time_period'][i]]
            + confounders['domain_preference'][i]
            + confounders['network_influence'][i] * 2
            + confounders['resource_access'][i]
            + (3 if confounders['collaboration_tendency'][i] == "cooperative" else 1)
        )

        for j in range(rewrite_span):
            if random.random() < confounding_prob:
                confounded_role_id[j] = (confounded_role + random.randint(0, 5)) % 5
            else:
                confounded_role_id[j] = 0

    return confounded_role_id

def generate_confounded_features(G, original_role_id, confounders, confounding_prob):
    """Sample a 5-dimensional feature vector for each node and attach it to G.

    Nodes whose id is not below ``len(original_role_id)`` are skipped. For
    the rest, a vector is drawn from a per-role normal distribution; then,
    with probability ``confounding_prob``, a bias accumulated from the node's
    six confounders is scaled by 3.5 and added.

    Args:
        G: Graph-like object; ``G.nodes()`` is iterated and
            ``G.nodes[node]['features']`` is written.
        original_role_id: Role id per node, indexable by node id. Each value
            must be a key of the role table (0 through 4).
        confounders: Mapping of confounder name to a per-node sequence.
        confounding_prob: Threshold compared against ``random.random()``.

    Returns:
        dict mapping node id to its feature array (the same array object that
        is stored on the graph).
    """
    role_params = {
        0: {'mean': [2.0, 3.0, 1.5, 2.5, 4.0], 'std': [0.3, 0.4, 0.2, 0.3, 0.5]},
        1: {'mean': [1.0, 2.0, 3.0, 1.5, 2.0], 'std': [0.2, 0.3, 0.4, 0.2, 0.3]},
        2: {'mean': [3.0, 1.0, 2.0, 3.5, 1.5], 'std': [0.4, 0.2, 0.3, 0.5, 0.2]},
        3: {'mean': [2.5, 4.0, 1.0, 2.0, 3.0], 'std': [0.3, 0.5, 0.1, 0.2, 0.4]},
        4: {'mean': [1.5, 2.5, 4.0, 1.0, 2.5], 'std': [0.2, 0.3, 0.6, 0.1, 0.3]},
    }

    # Constant lookup tables, built once rather than per node.
    status_shift = {
        "elite": np.array([4.5, 6.0, 1.5, 3.0, 7.5]),
        "high": np.array([3.0, 4.0, 1.0, 2.0, 5.0]),
        "medium": np.array([1.5, 2.0, 2.5, 1.5, 2.5]),
        "low": np.array([-1.5, -2.0, 4.0, -1.0, 1.0]),
    }
    era_shift = {
        "ancient": np.array([2.0, -1.0, 3.0, 4.0, -2.0]),
        "early": np.array([1.5, 2.5, -2.0, 3.0, 2.0]),
        "middle": np.array([-1.0, 3.5, -1.5, 1.0, 3.0]),
        "late": np.array([3.0, 1.0, -4.0, -2.0, 5.0]),
        "modern": np.array([-2.0, 4.0, 2.0, -3.0, 4.5]),
    }
    cooperative_shift = np.array([2.5, 3.0, -1.5, 2.0, 3.5])
    competitive_shift = np.array([-1.5, -2.0, 3.5, -1.0, -2.5])

    role_count = len(original_role_id)
    sampled = {}

    for node in G.nodes():
        if not (node < role_count):
            continue

        role = original_role_id[node]

        status = confounders['social_status'][node]
        era = confounders['time_period'][node]
        domain = confounders['domain_preference'][node]
        influence = confounders['network_influence'][node]
        access = confounders['resource_access'][node]
        tendency = confounders['collaboration_tendency'][node]

        params = role_params[role]
        vector = np.random.normal(params['mean'], params['std'])

        if random.random() < confounding_prob:
            bias = np.zeros(5)
            bias += status_shift[status]
            bias += era_shift[era]
            bias += np.array([
                domain * 0.8,
                (domain - 3) * 1.0,
                (6 - domain) * 1.2,
                domain * 0.6,
                (domain + 2) * 0.9,
            ])
            bias += np.array([
                influence * 1.5,
                influence * 1.2,
                influence * -0.8,
                influence * 2.0,
                influence * 1.0,
            ])
            bias += np.array([
                access * 0.7,
                access * -0.5,
                access * 1.3,
                access * 0.9,
                access * -0.8,
            ])
            bias += cooperative_shift if tendency == "cooperative" else competitive_shift

            vector += bias * 3.5

        sampled[node] = vector
        G.nodes[node]['features'] = vector

    return sampled

def generate_confounding_dataset(confounding_prob, dataset_type='train', num_samples=2000):
    """Generate up to ``num_samples`` confounded graph samples.

    Each sample is produced by ``generate_false_cause_dataset1()``, then
    passed through ``create_confounding_variables``,
    ``apply_confounding_bias`` and ``generate_confounded_features``. A sample
    that raises anywhere in that pipeline is dropped as a whole, so all
    per-sample lists in the result always have the same length.

    Args:
        confounding_prob: Probability forwarded to the confounding helpers
            and recorded in the output.
        dataset_type: Label stored in the output (e.g. ``'train'``).
        num_samples: Number of generation attempts.

    Returns:
        dict with per-sample lists under ``features`` (node-by-5 matrix, rows
        in sorted node order), ``edge_index`` (2-by-E int array), ``label``,
        ``ground_truth`` (per-edge mask, 1.0 where both endpoints have an
        original role id above 0), ``role_id`` (confounded role ids),
        ``pos`` (layout coordinates) and ``confounding_info``; plus the
        scalar entries ``confounding_prob`` and ``dataset_type``.
    """
    edge_index_list, label_list = [], []
    ground_truth_list, role_id_list, pos_list, features_list = [], [], [], []
    confounding_info_list = []
    skipped = 0
    first_error = None

    for _ in tqdm(range(num_samples)):
        # Build every piece into locals first; nothing is appended until the
        # whole sample has succeeded, so a failure cannot leave the output
        # lists misaligned.
        try:
            G, original_role_id, label = generate_false_cause_dataset1()
            confounders = create_confounding_variables(original_role_id, confounding_prob)
            confounded_role_id = apply_confounding_bias(original_role_id, confounders, confounding_prob)
            features_dict = generate_confounded_features(G, confounded_role_id, confounders, confounding_prob)

            role_ids = np.array(confounded_role_id)

            if G.number_of_edges() > 0:
                edge_index = np.array(list(G.edges), dtype=int).T
            else:
                edge_index = np.array([[], []], dtype=int)

            # Layout is best-effort: fall back to an empty array on failure.
            # `except Exception` (not bare) lets Ctrl-C propagate.
            try:
                layout = nx.spring_layout(G) if G.number_of_nodes() <= 1000 else nx.random_layout(G)
                pos = np.array(list(layout.values()))
            except Exception:
                pos = np.array([])

            # Ground truth is derived from the ORIGINAL (pre-confounding) roles.
            try:
                if edge_index.size > 0:
                    row, col = edge_index
                    original_roles = np.array(original_role_id)
                    gd = (
                        np.array(original_roles[row] > 0, dtype=np.float64)
                        * np.array(original_roles[col] > 0, dtype=np.float64)
                    )
                else:
                    gd = np.array([])
            except Exception:
                gd = np.array([])

            if G.number_of_nodes() > 0:
                feat_mat = np.vstack([
                    features_dict.get(n, np.zeros(5))
                    for n in sorted(G.nodes())
                ])
            else:
                feat_mat = np.zeros((0, 5))

            # Fraction of nodes whose role id was changed by the bias step.
            confounding_strength = np.mean([
                1 if before != after else 0
                for before, after in zip(original_role_id, confounded_role_id)
            ])

            info = {
                'confounding_prob': confounding_prob,
                'dataset_type': dataset_type,
                'original_role_id': original_role_id,
                # np.asarray makes this work for list and ndarray role ids alike.
                'confounded_role_id': np.asarray(confounded_role_id).tolist(),
                'confounding_strength': confounding_strength,
                'confounders': confounders,
                'label': label,
            }
        except Exception as exc:
            skipped += 1
            if first_error is None:
                first_error = exc
            continue

        label_list.append(label)
        role_id_list.append(role_ids)
        edge_index_list.append(edge_index)
        pos_list.append(pos)
        ground_truth_list.append(gd)
        features_list.append(feat_mat)
        confounding_info_list.append(info)

    if skipped:
        # tqdm.write rather than warnings.warn: this module silences warnings
        # globally, and tqdm.write does not corrupt the progress bar.
        tqdm.write(
            f"generate_confounding_dataset: skipped {skipped} of {num_samples} "
            f"samples (first error: {first_error!r})"
        )

    dataset_dict = {
        'features': features_list,
        'edge_index': edge_index_list,
        'label': label_list,
        'ground_truth': ground_truth_list,
        'role_id': role_id_list,
        'pos': pos_list,
        'confounding_info': confounding_info_list,
        'confounding_prob': confounding_prob,
        'dataset_type': dataset_type
    }

    return dataset_dict

def generate_experiment_datasets(confounding_probs, base_dir='./data/paper/'):
    """Generate and save train/val/test datasets for each probability.

    For every value in ``confounding_probs`` three datasets are produced, in
    the order train (1500 samples), val (250) and test (250), and each is
    written with ``np.save`` to ``{base_dir}/{split}_conf_{prob:.1f}.npy``.

    Args:
        confounding_probs: Iterable of confounding probabilities.
        base_dir: Output directory; created if it does not exist.

    Returns:
        List of the file paths written, in generation order.
    """
    import os
    os.makedirs(base_dir, exist_ok=True)

    split_sizes = (('train', 1500), ('val', 250), ('test', 250))
    written = []

    for prob in confounding_probs:
        for split_name, sample_count in split_sizes:
            dataset = generate_confounding_dataset(
                confounding_prob=prob,
                dataset_type=split_name,
                num_samples=sample_count,
            )
            target = f'{base_dir}/{split_name}_conf_{prob:.1f}.npy'
            np.save(target, dataset)
            written.append(target)

    return written

def run_confounding_experiment():
    """Generate datasets for confounding probabilities 0.9 down to 0.1.

    Returns:
        The list of file paths returned by ``generate_experiment_datasets``.
    """
    sweep = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    return generate_experiment_datasets(sweep)

# Script entry point: run the full probability sweep and keep the list of
# written file paths in a module-level name.
if __name__ == "__main__":
    generated_files = run_confounding_experiment()
