from BA3_loc import *
from tqdm import tqdm
import os.path as osp
import warnings
warnings.filterwarnings("ignore")
import random
import math
import torch
import copy

from scipy.stats import gamma
from scipy.stats import gompertz
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import weibull_min
from scipy.special import gamma, gammaincinv
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import euclidean_distances
from collections import deque
from paper import *

def create_complex_confounding_variables(original_role_id, G, confounding_prob):
    """Derive eight structural "confounder" attributes for every node of ``G``.

    Nodes are addressed by the integer labels ``0 .. len(original_role_id) - 1``.

    Args:
        original_role_id: Per-node role sequence; only its length is used here.
        G: networkx graph (directed or undirected) whose nodes carry those
            integer labels.
        confounding_prob: Not used by this function; kept so the signature
            stays aligned with the other generation steps.

    Returns:
        dict mapping each confounder name to a list with one entry per node:
        ``social_status``, ``clustering_info``, ``path_centrality``,
        ``temporal_correlation``, ``multilayer_interaction`` and
        ``structural_equivalence`` hold string levels; ``citation_strength``
        holds numbers capped at 10; ``neighbor_influence`` holds ints capped
        at 9.
    """

    def bucket(value, levels, fallback):
        # First label whose threshold is strictly exceeded, else the fallback.
        for threshold, label in levels:
            if value > threshold:
                return label
        return fallback

    num_nodes = len(original_role_id)
    has_nodes = G.number_of_nodes() > 0
    confounders = {}

    # to_undirected() copies the graph, so build the undirected view once and
    # share it between the clustering and betweenness computations.
    undirected = G.to_undirected()

    # --- social status: bucketed degree centrality ---
    degree_centrality = nx.degree_centrality(G) if has_nodes else {}
    status_levels = ((0.8, "elite"), (0.6, "high"), (0.3, "medium"))
    confounders['social_status'] = [
        bucket(degree_centrality.get(i, 0), status_levels, "low")
        for i in range(num_nodes)
    ]

    # --- citation strength: in-degree counted twice, scaled by graph size ---
    directed = G.is_directed()
    strength_scale = max(1, num_nodes * 0.1)
    citation_strength = []
    for i in range(num_nodes):
        if directed:
            in_degree, out_degree = G.in_degree(i), G.out_degree(i)
        else:
            in_degree = out_degree = G.degree(i)
        strength = (in_degree * 2 + out_degree) / strength_scale
        citation_strength.append(min(strength, 10))
    confounders['citation_strength'] = citation_strength

    # --- clustering level: bucketed clustering coefficient ---
    clustering_coeff = nx.clustering(undirected) if has_nodes else {}
    cluster_levels = ((0.7, "high_cluster"), (0.4, "medium_cluster"))
    confounders['clustering_info'] = [
        bucket(clustering_coeff.get(i, 0), cluster_levels, "low_cluster")
        for i in range(num_nodes)
    ]

    # Graph-wide degree statistics and adjacency are loop invariants; compute
    # them once instead of once per node (or once per node pair).
    degree_values = [degree for _, degree in G.degree()]
    mean_degree_scale = max(1, np.mean(degree_values)) if degree_values else 1
    max_degree_scale = max(1, max(degree_values, default=0))
    neighbor_lists = [list(G.neighbors(i)) for i in range(num_nodes)]

    # --- neighbor influence: mean neighbor degree relative to the graph mean ---
    neighbor_influence = []
    for neighbors in neighbor_lists:
        if neighbors:
            avg_neighbor_degree = np.mean([G.degree(n) for n in neighbors])
            influence_level = int(avg_neighbor_degree / mean_degree_scale * 5)
        else:
            influence_level = 0
        neighbor_influence.append(min(influence_level, 9))
    confounders['neighbor_influence'] = neighbor_influence

    # --- path centrality: bucketed betweenness (connected graphs only) ---
    # An empty mapping makes every lookup below fall back to 0.
    try:
        if G.number_of_nodes() > 1 and nx.is_connected(undirected):
            betweenness = nx.betweenness_centrality(undirected)
        else:
            betweenness = {}
    except Exception:
        # Best effort: treat a failed centrality computation as "all zero",
        # but let KeyboardInterrupt / SystemExit propagate.
        betweenness = {}
    path_levels = ((0.1, "high_path"), (0.05, "medium_path"))
    confounders['path_centrality'] = [
        bucket(betweenness.get(i, 0), path_levels, "low_path")
        for i in range(num_nodes)
    ]

    # --- temporal correlation: pseudo time stamp derived from the node label ---
    temporal_levels = ((0.7, "future_oriented"), (0.3, "present_focused"))
    temporal_correlation = []
    for i, neighbors in enumerate(neighbor_lists):
        time_factor = (i % 10) / 10.0
        neighbor_time_influence = 0
        if neighbors:
            neighbor_time_influence = np.mean([(n % 10) / 10.0 for n in neighbors])
        combined_temporal = (time_factor + neighbor_time_influence) / 2
        temporal_correlation.append(
            bucket(combined_temporal, temporal_levels, "past_oriented")
        )
    confounders['temporal_correlation'] = temporal_correlation

    # --- multilayer interaction: mean of three normalised structural scores ---
    interaction_levels = ((0.6, "strong_interaction"), (0.3, "moderate_interaction"))
    multilayer_interaction = []
    for i in range(num_nodes):
        degree_norm = G.degree(i) / max_degree_scale
        clustering_norm = clustering_coeff.get(i, 0)
        centrality_norm = degree_centrality.get(i, 0)
        interaction_score = (degree_norm + clustering_norm + centrality_norm) / 3
        multilayer_interaction.append(
            bucket(interaction_score, interaction_levels, "weak_interaction")
        )
    confounders['multilayer_interaction'] = multilayer_interaction

    # --- structural equivalence: best Jaccard overlap with any other node ---
    equivalence_levels = ((0.5, "high_equivalence"), (0.2, "medium_equivalence"))
    neighbor_sets = [set(neighbors) for neighbors in neighbor_lists]
    structural_equivalence = []
    for i, own in enumerate(neighbor_sets):
        max_similarity = 0
        for j, other in enumerate(neighbor_sets):
            # Skip the node itself and pairs where both neighborhoods are empty.
            if i == j or not (own or other):
                continue
            jaccard_sim = len(own & other) / max(1, len(own | other))
            max_similarity = max(max_similarity, jaccard_sim)
        structural_equivalence.append(
            bucket(max_similarity, equivalence_levels, "low_equivalence")
        )
    confounders['structural_equivalence'] = structural_equivalence

    return confounders

def apply_complex_confounding_bias(original_role_id, G, confounders, confounding_prob):
    """Return a copy of ``original_role_id`` with some roles replaced by a
    confounder-derived role.

    Each index is selected independently with probability ``confounding_prob``
    (one ``random.random()`` draw per index). For a selected index, integer
    points from all eight confounders plus the node degree (capped at 5) are
    summed and the new role is that sum modulo 5.

    Args:
        original_role_id: Per-node roles; must provide ``.copy()`` and ``len``.
        G: Graph exposing ``degree(i)`` and ``nodes()``.
        confounders: Mapping of confounder name to per-node list, as built by
            ``create_complex_confounding_variables``.
        confounding_prob: Probability of overwriting each node's role.

    Returns:
        The modified copy; ``original_role_id`` itself is left untouched.
    """
    # Point tables. An unrecognised social status scores 1; every other
    # table raises KeyError on an unknown label.
    status_points = {"elite": 5, "high": 4, "medium": 2}
    cluster_points = {"high_cluster": 4, "medium_cluster": 2, "low_cluster": 1}
    path_points = {"high_path": 5, "medium_path": 3, "low_path": 1}
    temporal_points = {"future_oriented": 4, "present_focused": 2, "past_oriented": 1}
    interaction_points = {"strong_interaction": 6, "moderate_interaction": 3, "weak_interaction": 1}
    equivalence_points = {"high_equivalence": 3, "medium_equivalence": 2, "low_equivalence": 1}

    biased_roles = original_role_id.copy()

    for node in range(len(original_role_id)):
        if not (random.random() < confounding_prob):
            continue

        status = confounders['social_status'][node]
        citation = confounders['citation_strength'][node]
        cluster = confounders['clustering_info'][node]
        influence = confounders['neighbor_influence'][node]
        path = confounders['path_centrality'][node]
        temporal = confounders['temporal_correlation'][node]
        interaction = confounders['multilayer_interaction'][node]
        equivalence = confounders['structural_equivalence'][node]

        degree = G.degree(node) if node in G.nodes() else 0

        score = (
            status_points.get(status, 1)
            + int(citation * 0.8)
            + cluster_points[cluster]
            + influence
            + path_points[path]
            + temporal_points[temporal]
            + interaction_points[interaction]
            + equivalence_points[equivalence]
            + min(degree, 5)
        )
        biased_roles[node] = score % 5

    return biased_roles

def generate_complex_confounded_features(G, original_role_id, confounders, confounding_prob):
    """Sample a 5-dimensional feature vector for every node of ``G``.

    Each node whose label is a valid index into ``original_role_id`` gets a
    Gaussian draw using its role's mean/std. With probability
    ``confounding_prob`` the vector is then shifted by four times the summed
    confounder effects. The vector is also stored on the graph as
    ``G.nodes[node_id]['features']``.

    Args:
        G: Graph exposing ``nodes()``, ``nodes[...]`` and ``degree(...)``.
        original_role_id: Per-node roles (values 0-4) indexed by node label.
        confounders: Mapping of confounder name to per-node list.
        confounding_prob: Probability of applying the confounding shift.

    Returns:
        dict mapping node label to its feature array (the same array object
        that is attached to the graph).
    """
    # Role-specific Gaussian parameters.
    role_params = {
        0: {'mean': [2.0, 3.0, 1.5, 2.5, 4.0], 'std': [0.3, 0.4, 0.2, 0.3, 0.5]},
        1: {'mean': [1.0, 2.0, 3.0, 1.5, 2.0], 'std': [0.2, 0.3, 0.4, 0.2, 0.3]},
        2: {'mean': [3.0, 1.0, 2.0, 3.5, 1.5], 'std': [0.4, 0.2, 0.3, 0.5, 0.2]},
        3: {'mean': [2.5, 4.0, 1.0, 2.0, 3.0], 'std': [0.3, 0.5, 0.1, 0.2, 0.4]},
        4: {'mean': [1.5, 2.5, 4.0, 1.0, 2.5], 'std': [0.2, 0.3, 0.6, 0.1, 0.3]},
    }

    # Additive offsets for the categorical confounders.
    status_shift = {
        "elite": np.array([6.0, 8.0, 2.0, 4.0, 9.0]),
        "high": np.array([4.0, 5.0, 1.5, 2.5, 6.0]),
        "medium": np.array([2.0, 2.5, 3.0, 2.0, 3.0]),
        "low": np.array([-1.0, -1.5, 5.0, -0.5, 1.5]),
    }
    cluster_shift = {
        "high_cluster": np.array([3.5, 4.0, -2.0, 3.0, 4.5]),
        "medium_cluster": np.array([1.5, 2.0, 0.5, 1.5, 2.0]),
        "low_cluster": np.array([-1.0, -1.0, 3.0, -0.5, -1.0]),
    }
    path_shift = {
        "high_path": np.array([4.0, 5.0, -1.5, 4.5, 5.5]),
        "medium_path": np.array([2.0, 2.5, 0.5, 2.0, 2.5]),
        "low_path": np.array([0.5, 0.5, 1.5, 0.5, 0.5]),
    }
    temporal_shift = {
        "future_oriented": np.array([3.0, 4.0, -1.0, 3.5, 4.0]),
        "present_focused": np.array([1.0, 1.5, 1.0, 1.5, 1.5]),
        "past_oriented": np.array([-1.0, -1.5, 3.0, -1.0, -1.5]),
    }
    interaction_shift = {
        "strong_interaction": np.array([5.0, 6.0, -2.0, 5.5, 6.5]),
        "moderate_interaction": np.array([2.5, 3.0, 0.5, 2.5, 3.0]),
        "weak_interaction": np.array([0.5, 0.5, 2.0, 0.5, 0.5]),
    }
    equivalence_shift = {
        "high_equivalence": np.array([2.5, 3.0, -1.0, 2.5, 3.0]),
        "medium_equivalence": np.array([1.0, 1.5, 0.5, 1.0, 1.5]),
        "low_equivalence": np.array([0.0, 0.0, 1.0, 0.0, 0.0]),
    }

    # Per-dimension coefficients for the numeric confounders.
    citation_coeffs = np.array([0.8, 0.6, -0.4, 1.0, 0.7])
    neighbor_coeffs = np.array([0.5, 0.7, -0.3, 0.8, 0.6])
    degree_coeffs = np.array([0.3, 0.4, -0.2, 0.5, 0.3])

    node_features = {}
    role_count = len(original_role_id)

    for node_id in G.nodes():
        # Nodes without a role entry get no features.
        if not (node_id < role_count):
            continue

        role = original_role_id[node_id]

        status = confounders['social_status'][node_id]
        citation = confounders['citation_strength'][node_id]
        cluster = confounders['clustering_info'][node_id]
        influence = confounders['neighbor_influence'][node_id]
        path = confounders['path_centrality'][node_id]
        temporal = confounders['temporal_correlation'][node_id]
        interaction = confounders['multilayer_interaction'][node_id]
        equivalence = confounders['structural_equivalence'][node_id]

        params = role_params[role]
        vector = np.random.normal(params['mean'], params['std'])

        if random.random() < confounding_prob:
            degree = G.degree(node_id) if node_id in G.nodes() else 0

            # Accumulate the effects in a fixed order so the floating-point
            # sum is reproducible.
            shift = np.zeros(5)
            shift += status_shift[status]
            shift += citation * citation_coeffs
            shift += cluster_shift[cluster]
            shift += influence * neighbor_coeffs
            shift += path_shift[path]
            shift += temporal_shift[temporal]
            shift += interaction_shift[interaction]
            shift += equivalence_shift[equivalence]
            shift += degree * degree_coeffs

            vector += shift * 4.0

        node_features[node_id] = vector
        G.nodes[node_id]['features'] = vector

    return node_features

def _build_complex_confounded_sample(confounding_prob, dataset_type):
    """Generate one graph and return all of its per-sample records as a dict.

    Nothing is written to shared state, so a failure part-way through leaves
    the caller's accumulators untouched.
    """
    G, original_role_id, label = generate_false_cause_dataset1()

    confounders = create_complex_confounding_variables(original_role_id, G, confounding_prob)
    confounded_role_id = apply_complex_confounding_bias(original_role_id, G, confounders, confounding_prob)
    features_dict = generate_complex_confounded_features(G, confounded_role_id, confounders, confounding_prob)

    # Edge list as a 2 x E integer array (2 x 0 for an edgeless graph).
    if G.number_of_edges() > 0:
        edge_index = np.array(list(G.edges), dtype=int).T
    else:
        edge_index = np.array([[], []], dtype=int)

    # Layout is best effort: fall back to an empty array if it fails.
    try:
        layout = nx.spring_layout(G) if G.number_of_nodes() <= 1000 else nx.random_layout(G)
        pos = np.array(list(layout.values()))
    except Exception:
        pos = np.array([])

    # An edge is ground truth when both endpoints have a non-zero ORIGINAL role.
    try:
        if edge_index.size > 0:
            row, col = edge_index
            roles = np.array(original_role_id)
            ground_truth = (
                np.array(roles[row] > 0, dtype=np.float64)
                * np.array(roles[col] > 0, dtype=np.float64)
            )
        else:
            ground_truth = np.array([])
    except Exception:
        ground_truth = np.array([])

    # Feature rows follow sorted node labels; nodes without features get zeros.
    if G.number_of_nodes() > 0:
        features = np.vstack([
            features_dict.get(n, np.zeros(5))
            for n in sorted(G.nodes())
        ])
    else:
        features = np.zeros((0, 5))

    # Fraction of nodes whose role was changed by the confounding step.
    confounding_strength = np.mean([
        1 if before != after else 0
        for before, after in zip(original_role_id, confounded_role_id)
    ])

    has_nodes = G.number_of_nodes() > 0
    network_complexity = {
        'avg_degree': np.mean([G.degree(n) for n in G.nodes()]) if has_nodes else 0,
        'clustering_coefficient': nx.average_clustering(G.to_undirected()) if has_nodes else 0,
        'density': nx.density(G) if has_nodes else 0
    }

    confounding_info = {
        'confounding_prob': confounding_prob,
        'dataset_type': dataset_type,
        'original_role_id': original_role_id,
        # np.asarray makes this work for list and ndarray role ids alike.
        'confounded_role_id': np.asarray(confounded_role_id).tolist(),
        'confounding_strength': confounding_strength,
        'confounders': confounders,
        'network_complexity': network_complexity,
        'label': label
    }

    return {
        'features': features,
        'edge_index': edge_index,
        'label': label,
        'ground_truth': ground_truth,
        'role_id': np.array(confounded_role_id),
        'pos': pos,
        'confounding_info': confounding_info,
    }


def generate_complex_confounding_dataset(confounding_prob, dataset_type='train', num_samples=2000):
    """Generate up to ``num_samples`` confounded graph samples.

    Generation is best effort: a sample that raises is skipped as a whole, so
    all per-sample lists in the result always have the same length. The first
    failure and the total number of skipped samples are reported through
    ``tqdm.write``.

    Args:
        confounding_prob: Probability passed to the confounding steps.
        dataset_type: Split name recorded in the output (e.g. ``'train'``).
        num_samples: Number of generation attempts.

    Returns:
        dict with per-sample lists under ``features``, ``edge_index``,
        ``label``, ``ground_truth``, ``role_id``, ``pos`` and
        ``confounding_info``, plus the scalar ``confounding_prob`` and
        ``dataset_type``.
    """
    samples = []
    skipped = 0

    for _ in tqdm(range(num_samples)):
        try:
            sample = _build_complex_confounded_sample(confounding_prob, dataset_type)
        except Exception as exc:
            skipped += 1
            if skipped == 1:
                tqdm.write(
                    f"[{dataset_type}] sample generation failed "
                    f"({type(exc).__name__}: {exc}); failed samples are skipped"
                )
            continue
        samples.append(sample)

    if skipped:
        tqdm.write(f"[{dataset_type}] skipped {skipped} of {num_samples} samples")

    per_sample_keys = (
        'features', 'edge_index', 'label', 'ground_truth',
        'role_id', 'pos', 'confounding_info',
    )
    dataset_dict = {key: [sample[key] for sample in samples] for key in per_sample_keys}
    dataset_dict['confounding_prob'] = confounding_prob
    dataset_dict['dataset_type'] = dataset_type

    return dataset_dict

def generate_complex_experiment_datasets(confounding_prob=0.7, base_dir='./data/casual/',
                                         num_train=3000, num_val=375, num_test=375):
    """Generate and save the train / val / test datasets.

    The train and val splits are generated with confounding probability 0;
    only the test split uses ``confounding_prob``. Each split is saved with
    ``np.save`` as ``<split>_casual_1_4.npy`` inside ``base_dir``. The saved
    object is a dict, so it is presumably pickled by ``np.save`` and loading
    it back would need ``np.load(..., allow_pickle=True)`` -- confirm against
    the numpy documentation.

    Args:
        confounding_prob: Confounding probability for the test split.
        base_dir: Output directory; created if missing.
        num_train: Number of generation attempts for the train split.
        num_val: Number of generation attempts for the val split.
        num_test: Number of generation attempts for the test split.

    Returns:
        List of the three written file paths, in train, val, test order.
    """
    import os
    os.makedirs(base_dir, exist_ok=True)

    # (split name, confounding probability, number of samples)
    split_specs = (
        ('train', 0, num_train),
        ('val', 0, num_val),
        ('test', confounding_prob, num_test),
    )

    generated_files = []
    for split, split_prob, split_size in split_specs:
        dataset = generate_complex_confounding_dataset(
            confounding_prob=split_prob,
            dataset_type=split,
            num_samples=split_size
        )
        # os.path.join avoids the doubled separator that plain concatenation
        # produces when base_dir already ends with a slash.
        file_path = os.path.join(base_dir, f'{split}_casual_1_4.npy')
        np.save(file_path, dataset)
        generated_files.append(file_path)

    return generated_files

def run_complex_confounding_experiment():
    """Generate the experiment datasets with a confounding probability of 0.7.

    Returns:
        Whatever ``generate_complex_experiment_datasets`` returns (the list of
        generated file paths).
    """
    return generate_complex_experiment_datasets(0.7)

# Script entry point: run the full dataset generation and list the written files.
if __name__ == "__main__":
    generated_files = run_complex_confounding_experiment()
    print(f"Generated files: {generated_files}")
