import os
import os.path as osp
import random
from tqdm import tqdm
import traceback
import json

import numpy as np
import networkx as nx
import torch


from BA3_loc import find_gd
from paper import generate_false_cause_dataset1
from gen_conf import create_confounding_variables, apply_confounding_bias, generate_confounded_features
import warnings
warnings.filterwarnings("ignore")


# Per-dimension location (_MEAN) and scale (_STD) passed to np.random.normal
# whenever this module has to sample a 5-dimensional node feature vector.
# NOTE(review): _STD holds exactly the same values as _MEAN — confirm this is
# intentional and not a copy/paste slip.
_MEAN = np.array([1.5, 2.0, 1.2, 1.3, 1.8], dtype=np.float32)
_STD  = np.array([1.5, 2.0, 1.2, 1.3, 1.8], dtype=np.float32)


# Display name of every paper element type, keyed by its integer element id
# (ids are assigned by position, starting at 0).
ALL_PAPER_ELEMENTS = dict(enumerate((
    "Basic elements",
    "Citation elements",
    "Author elements",
    "Topic elements",
    "Method elements",
)))


# Confounding profile of each element type, keyed by element id.
#   name                    - display name of the element type
#   confounding_intensity   - category that selects how a confounder subgraph is
#                             reshaped and how many confounders get attached
#   feature_bias_multiplier - scale applied to the random feature bias
#   structural_disruption   - scale applied to the number of bridging edges
#   label_confusion_prob    - probability that the sample label is perturbed
# Each row below is (name, intensity, bias multiplier, disruption, confusion).
ELEMENT_CONFOUNDING_PROPERTIES = {
    element_key: {
        'name': element_name,
        'confounding_intensity': intensity,
        'feature_bias_multiplier': bias_multiplier,
        'structural_disruption': disruption,
        'label_confusion_prob': confusion_prob,
    }
    for element_key, (
        element_name, intensity, bias_multiplier, disruption, confusion_prob
    ) in enumerate((
        ('Basic elements', 'minimal', 0.8, 0.3, 0.15),
        ('Citation elements', 'moderate', 1.8, 1.2, 0.35),
        ('Author elements', 'severe', 3.2, 2.5, 0.55),
        ('Topic elements', 'strong', 2.5, 1.8, 0.45),
        ('Method elements', 'mild', 1.3, 0.7, 0.25),
    ))
}


# Experiment grid, two entries per element id (causal first, then confounding):
#   * "causal"      - the element is the only causal target; no confounder pool
#                     and every confounding probability is zero.
#   * "confounding" - the element is the sole confounder while the four other
#                     elements are the causal targets.
PAPER_EXPERIMENTS = []

for element_id in range(5):
    # Every element except the one under study.
    other_elements = [eid for eid in range(5) if eid != element_id]

    PAPER_EXPERIMENTS.extend((
        {
            'name': f'causal_{element_id}',
            'type': 'causal',
            'element_id': element_id,
            'file_suffix': f'element_causal_{element_id}_no_conf',
            'target_causal_elements': [element_id],
            'confounder_pool_ids': [],
            'confound_prob': 0.0,
            'background_confounder_prob_train': 0.0,
            'background_confounder_prob_test': 0.0,
        },
        {
            'name': f'confound_{element_id}',
            'type': 'confounding',
            'element_id': element_id,
            'file_suffix': f'element_confound_{element_id}_others_causal',
            'target_causal_elements': other_elements,
            'confounder_pool_ids': [element_id],
            'confound_prob': 0.7,
            'background_confounder_prob_train': 0.1,
            'background_confounder_prob_test': 0.1,
        },
    ))

def _edge_index_from_graph(G):
    """Return the edges of ``G`` as a ``(2, num_edges)`` array.

    A graph without edges yields an empty ``(2, 0)`` int64 array.
    """
    if G.number_of_edges() == 0:
        return np.array([[], []], dtype=np.int64)

    edge_index = torch.tensor(list(G.edges())).t().contiguous().cpu().numpy()
    # Defensive: keep the two endpoint rows on axis 0.
    if edge_index.shape[0] != 2:
        edge_index = edge_index.T
    return edge_index


def _fallback_chain_graph(num_nodes):
    """Build a path graph of ``num_nodes`` nodes with freshly sampled features."""
    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from([(i, i + 1) for i in range(num_nodes - 1)])
    for node in G.nodes():
        G.nodes[node]['feature'] = np.random.normal(loc=_MEAN, scale=_STD).astype(np.float32)
    return G


def generate_paper_dataset_with_specific_elements(target_element_ids=None):
    """Generate one base graph whose node roles are limited to given elements.

    Args:
        target_element_ids: Element ids (keys of ``ALL_PAPER_ELEMENTS``) that
            the nodes may take as role.
            * ``None`` or empty: the graph, roles and label from
              ``generate_false_cause_dataset1`` are returned unchanged.
            * one id: every node gets that role and the label is that id.
            * several ids: the first ``len(target_element_ids)`` nodes get the
              targets in order, the remaining nodes a random target; the label
              is drawn uniformly from 0..4 regardless of the targets.

    Returns:
        ``(G, role_id, label, edge_index, element_presence)`` where
        ``edge_index`` is a ``(2, num_edges)`` array built from ``G.edges()``
        and ``element_presence`` maps each element id 0..4 to whether at least
        one node carries it.

    If the underlying generator raises while targets are given, a small path
    graph with random features is used instead and a warning is printed.
    """
    # No restriction requested: pass the generator output straight through.
    # An empty sequence is treated like None; it used to crash in
    # random.choice further down.
    if target_element_ids is None or len(target_element_ids) == 0:
        G, role_id, label = generate_false_cause_dataset1()

        element_presence = {i: False for i in range(5)}
        for role in role_id:
            if 0 <= role <= 4:
                element_presence[role] = True

        return G, role_id, label, _edge_index_from_graph(G), element_presence

    # Exactly one target element: homogeneous roles, label equals the element.
    if len(target_element_ids) == 1:
        element_id = target_element_ids[0]
        label = element_id
        element_presence = {i: False for i in range(5)}

        try:
            G, _, _ = generate_false_cause_dataset1()
            role_id = [element_id] * G.number_of_nodes()
        except Exception as e:
            print(f"Warning: fallback generation for element {element_id}: {e}")
            G = _fallback_chain_graph(8)
            role_id = [element_id] * G.number_of_nodes()

        # Set on both paths: the fallback graph also consists solely of this
        # element, and callers use the flag to decide whether to confound.
        element_presence[element_id] = True

        return G, role_id, label, _edge_index_from_graph(G), element_presence

    # Several target elements: mixed roles, random label.
    label = random.choice([0, 1, 2, 3, 4])
    element_presence = {i: False for i in range(5)}

    try:
        G, original_role_id, _ = generate_false_cause_dataset1()

        role_id = []
        for i, _orig_role in enumerate(original_role_id):
            if i < len(target_element_ids):
                # Guarantees every target appears once when the graph is
                # large enough.
                new_role = target_element_ids[i]
            else:
                new_role = random.choice(target_element_ids)
            role_id.append(new_role)
            element_presence[new_role] = True

    except Exception as e:
        print(f"Warning: Failed to generate multi-element dataset: {e}")

        G = _fallback_chain_graph(10)
        role_id = [random.choice(target_element_ids) for _ in range(10)]
        for eid in target_element_ids:
            element_presence[eid] = True

    return G, role_id, label, _edge_index_from_graph(G), element_presence

# Feature dimensions (with their noise scale) that receive a bias for the
# element types whose confounding only touches a few dimensions.
_SPARSE_BIAS_SCALES = {
    1: ((0, 0.8), (3, 0.6)),
    3: ((1, 1.0), (2, 1.1), (4, 0.9)),
    4: ((0, 0.5), (2, 0.7)),
}


def _add_random_edges(graph, num_attempts):
    """Try ``num_attempts`` times to link two random distinct nodes of ``graph``."""
    nodes = list(graph.nodes())
    if len(nodes) < 2:
        return
    for _ in range(num_attempts):
        u, v = random.sample(nodes, 2)
        if not graph.has_edge(u, v):
            graph.add_edge(u, v)


def _confounder_feature_bias(confounder_id, base_feature, bias_multiplier):
    """Sample the element-specific bias added to one confounder node feature."""
    if confounder_id == 0:
        return np.random.normal(0, 0.3, size=base_feature.shape) * bias_multiplier

    if confounder_id == 2:
        bias = np.random.normal(0, 1.2, size=base_feature.shape) * bias_multiplier
        # Fixed offset, normalised so that it is unscaled at multiplier 3.2.
        bias += np.array([2.0, -1.5, 1.8, 2.5, -1.2]) * (bias_multiplier / 3.2)
        return bias

    bias = np.zeros_like(base_feature)
    for dim, scale in _SPARSE_BIAS_SCALES[confounder_id]:
        bias[dim] += np.random.normal(0, scale) * bias_multiplier
    return bias


def _build_confounder_subgraph(confounder_id, confounder_props):
    """Create the stand-alone confounder graph for one element type.

    The graph comes from ``generate_false_cause_dataset1`` and is reshaped
    according to the confounding intensity; if anything fails a 6-node path
    graph with biased random features is used instead.
    """
    bias_multiplier = confounder_props['feature_bias_multiplier']

    try:
        conf_G, _, _ = generate_false_cause_dataset1()

        intensity = confounder_props['confounding_intensity']
        if intensity == 'minimal':
            # Keep roughly 60% of the nodes (at least 4 when available).
            target_nodes = max(4, int(conf_G.number_of_nodes() * 0.6))
            nodes_to_keep = random.sample(
                list(conf_G.nodes()), min(target_nodes, conf_G.number_of_nodes())
            )
            conf_G = conf_G.subgraph(nodes_to_keep).copy()
        elif intensity == 'severe':
            _add_random_edges(conf_G, int(conf_G.number_of_edges() * 0.5))
        elif intensity == 'strong':
            _add_random_edges(conf_G, int(conf_G.number_of_edges() * 0.3))

        # Only nodes that carry no feature yet receive a biased one.
        for node in conf_G.nodes():
            if 'feature' not in conf_G.nodes[node]:
                base_feature = np.random.normal(loc=_MEAN, scale=_STD).astype(np.float32)
                bias = _confounder_feature_bias(confounder_id, base_feature, bias_multiplier)
                conf_G.nodes[node]['feature'] = (base_feature + bias).astype(np.float32)

    except Exception:
        conf_G = nx.path_graph(6)
        for node in conf_G.nodes():
            base_feature = np.random.normal(loc=_MEAN, scale=_STD).astype(np.float32)
            bias = np.random.normal(0, 0.5, size=base_feature.shape) * bias_multiplier
            conf_G.nodes[node]['feature'] = (base_feature + bias).astype(np.float32)

    return conf_G


def _bridge_confounder(graph, confounder_id, main_nodes, conf_nodes, num_connections):
    """Add the element-specific edges that tie a confounder to the main graph."""
    if confounder_id in (0, 3):
        # Independent random main <-> confounder links.
        for _ in range(num_connections):
            main_node = random.choice(main_nodes)
            conf_node = random.choice(conf_nodes)
            graph.add_edge(main_node, conf_node)

    elif confounder_id == 1:
        # One main node acts as a hub for several confounder nodes.
        hub_node = random.choice(main_nodes)
        for conf_node in random.sample(conf_nodes, min(num_connections, len(conf_nodes))):
            graph.add_edge(hub_node, conf_node)

    elif confounder_id == 2:
        # Twice as many random links, plus extra edges inside the confounder.
        for _ in range(num_connections * 2):
            main_node = random.choice(main_nodes)
            conf_node = random.choice(conf_nodes)
            graph.add_edge(main_node, conf_node)
        if len(conf_nodes) > 1:
            for _ in range(max(1, len(conf_nodes) // 2)):
                u, v = random.sample(conf_nodes, 2)
                graph.add_edge(u, v)

    elif confounder_id == 4:
        # Pair sampled nodes of both sides in ascending id order.
        sorted_main = sorted(random.sample(main_nodes, min(num_connections, len(main_nodes))))
        sorted_conf = sorted(random.sample(conf_nodes, min(num_connections, len(conf_nodes))))
        for main_node, conf_node in zip(sorted_main, sorted_conf):
            graph.add_edge(main_node, conf_node)


def attach_confounders_to_paper_graph(G, role_id, num_confounders, confounder_element_ids=None):
    """Attach ``num_confounders`` confounder subgraphs to a copy of ``G``.

    Args:
        G: Base graph. It is copied, never modified; node ids are expected to
            be integers because confounder nodes are numbered after
            ``max(G.nodes())``.
        role_id: Per-node roles of ``G`` (any sequence); it is copied.
        num_confounders: Number of attachment attempts.
        confounder_element_ids: Element ids to draw confounder types from.
            ``None`` means all five element types; an empty sequence means
            nothing is attached.

    Returns:
        ``(new_G, new_role_id)``. ``new_role_id`` is a list extended by the
        confounder element id for every added node. ``new_G.graph`` holds
        ``'attached_confounders'`` (ids in attachment order) and
        ``'confounding_effects'`` (one descriptive dict per attachment).

    An attempt that raises is reported with a warning and skipped; the
    remaining attempts still run.
    """
    if confounder_element_ids is None:
        confounder_element_ids = list(range(5))

    new_G = G.copy()
    new_role_id = list(role_id)

    # Graph.copy() shares attribute values with the source graph, so take
    # private copies of the bookkeeping lists before appending to them.
    attached = list(new_G.graph.get('attached_confounders', []))
    effects = list(new_G.graph.get('confounding_effects', []))
    new_G.graph['attached_confounders'] = attached
    new_G.graph['confounding_effects'] = effects

    # An empty pool offers nothing to draw from.
    if len(confounder_element_ids) == 0:
        return new_G, new_role_id

    for _ in range(num_confounders):
        confounder_id = None  # keeps the warning below printable in every case
        try:
            confounder_id = random.choice(confounder_element_ids)
            confounder_props = ELEMENT_CONFOUNDING_PROPERTIES[confounder_id]

            conf_G = _build_confounder_subgraph(confounder_id, confounder_props)

            if conf_G.number_of_nodes() == 0 or new_G.number_of_nodes() == 0:
                continue

            # Shift confounder node ids past every existing id so that the two
            # node sets are disjoint, as nx.union requires.
            node_offset = max(new_G.nodes()) + 1
            conf_G = nx.relabel_nodes(
                conf_G, {old: old + node_offset for old in conf_G.nodes()}
            )

            new_G = nx.union(new_G, conf_G)
            new_role_id.extend([confounder_id] * conf_G.number_of_nodes())

            # Split by id rather than by node count: ids are not contiguous
            # once a 'minimal' confounder (a node subset) has been attached.
            main_nodes = [n for n in new_G.nodes() if n < node_offset]
            conf_nodes = [n for n in new_G.nodes() if n >= node_offset]

            if main_nodes and conf_nodes:
                disruption_factor = confounder_props['structural_disruption']
                base_connections = 2
                num_connections = max(1, int(base_connections * disruption_factor))
                _bridge_confounder(new_G, confounder_id, main_nodes, conf_nodes, num_connections)

            attached.append(confounder_id)
            effects.append({
                'element_id': confounder_id,
                'element_name': confounder_props['name'],
                'intensity': confounder_props['confounding_intensity'],
                # Optional key: the property table does not define it for
                # every element.
                'theory': confounder_props.get('theory'),
            })

        except Exception as e:
            print(f"Warning: Failed to attach a confounder element (ID: {confounder_id}). Error: {e}")
            continue

    # nx.union merges the graph attributes of both operands; make sure the
    # bookkeeping lists maintained here are the ones that end up on the result.
    new_G.graph['attached_confounders'] = attached
    new_G.graph['confounding_effects'] = effects

    return new_G, new_role_id

# Inclusive (low, high) range of confounders to attach per intensity category.
# Any category not listed here (e.g. 'severe') uses (3, 5).
_ATTACH_COUNT_RANGE = {
    'minimal': (1, 2),
    'mild': (1, 3),
    'moderate': (2, 3),
    'strong': (2, 4),
}


def _confuse_label(label, primary_confounder):
    """Return ``label`` perturbed in the way specific to the primary confounder."""
    if primary_confounder == 0:
        # Shift to a neighbouring class only half of the time.
        if random.random() < 0.5:
            label = (label + random.choice([-1, 1])) % 5
    elif primary_confounder == 1:
        label = (label + random.choice([-1, 1])) % 5
    elif primary_confounder == 2:
        if label in [0, 1, 2]:
            label = random.choice([3, 4])
        else:
            label = random.choice([0, 1, 2, 3, 4])
    elif primary_confounder == 3:
        # Redraw inside the label's cluster (may return the same label).
        theme_clusters = [[0, 2], [1, 3], [4]]
        for cluster in theme_clusters:
            if label in cluster:
                label = random.choice(cluster)
                break
    elif primary_confounder == 4:
        complexity_order = [0, 1, 2, 3, 4]
        current_idx = complexity_order.index(label)
        # Only interior labels move; the two extremes are left untouched.
        if current_idx > 0 and current_idx < len(complexity_order) - 1:
            label = complexity_order[current_idx + random.choice([-1, 1])]
    return label


def generate_dataset_with_paper_confounder(split='train', confound_prob=0.7, confounder_pool_ids=None, 
                                         background_confounder_prob_train=0.10, background_confounder_prob_test=0.05,
                                         target_causal_elements=None):
    """Generate one sample graph, optionally confounded.

    Args:
        split: Dataset split; ``'test'`` selects the test background
            probability, anything else the train one.
        confound_prob: Probability of attaching primary confounders.
        confounder_pool_ids: Element ids used as confounders; the first one is
            the primary confounder that controls how many are attached and how
            the label may be perturbed. ``None`` means all five elements; an
            empty sequence disables primary confounding.
        background_confounder_prob_train: Probability of one extra background
            confounder (element 0 or 4) for non-test splits.
        background_confounder_prob_test: Same for the test split.
        target_causal_elements: Forwarded to
            ``generate_paper_dataset_with_specific_elements``.

    Returns:
        ``(G, role_id, label, edge_index)``.
        NOTE: ``edge_index`` is the one computed for the base graph, before any
        confounder was attached; rebuild it from ``G`` when the final edges
        are needed.
    """
    if confounder_pool_ids is None:
        confounder_pool_ids = list(range(5))

    G, role_id, label, edge_index, element_presence = generate_paper_dataset_with_specific_elements(target_causal_elements)

    causal_elements_present = any(element_presence.values())

    # The pool check comes last so the random draw happens exactly as often as
    # before. An empty pool used to reach the attach step and fail there.
    if causal_elements_present and random.random() < confound_prob and len(confounder_pool_ids) > 0:
        primary_confounder = confounder_pool_ids[0]
        confounder_props = ELEMENT_CONFOUNDING_PROPERTIES[primary_confounder]

        low, high = _ATTACH_COUNT_RANGE.get(confounder_props['confounding_intensity'], (3, 5))
        num_to_attach = random.randint(low, high)

        G, role_id = attach_confounders_to_paper_graph(G, role_id, num_to_attach, confounder_pool_ids)

        # Confounded samples may additionally receive a corrupted label.
        if random.random() < confounder_props['label_confusion_prob']:
            label = _confuse_label(label, primary_confounder)

    # Independently of the above, occasionally add one background confounder.
    background_prob = background_confounder_prob_train if split != 'test' else background_confounder_prob_test
    if random.random() < background_prob:
        bg_confounder_pool = [0, 4]
        G, role_id = attach_confounders_to_paper_graph(G, role_id, 1, bg_confounder_pool)

    return G, role_id, label, edge_index

def _node_feature(attrs):
    """Return a node's feature vector from whichever known key holds it."""
    if 'feature' in attrs:
        return attrs['feature']
    if 'features' in attrs:
        return attrs['features']
    return attrs['feat']


def generate_paper_conf_dataset(
    num_samples: int,
    output_filename: str,
    split: str,
    experiment_config: dict = None,
    feature_noise: float = 0.05,
    edge_del_prob: float = 0.02,
    edge_add_prob: float = 0.02,
    max_consecutive_failures: int = 100,
):
    """Generate ``num_samples`` graphs and save them under ``./data/analysis/``.

    Every sample is produced by ``generate_dataset_with_paper_confounder``,
    then perturbed (Gaussian feature noise, random edge deletions/additions),
    relabelled to consecutive node ids and flattened into arrays.

    Args:
        num_samples: Number of samples to collect.
        output_filename: File name inside ``./data/analysis/``.
        split: Split name, forwarded to the sample generator and used in
            progress/log output.
        experiment_config: Dict with optional keys ``confound_prob``,
            ``confounder_pool_ids``, ``background_confounder_prob_train``,
            ``background_confounder_prob_test`` and ``target_causal_elements``.
            ``None`` selects the defaults. It is stored alongside the data.
        feature_noise: Standard deviation of the noise added to node features.
        edge_del_prob: Fraction of the edges to delete.
        edge_add_prob: Fraction (of the original edge count) of random edge
            insertions to attempt.
        max_consecutive_failures: Abort after this many failed samples in a
            row; ``None`` disables the limit.

    Raises:
        RuntimeError: If ``max_consecutive_failures`` samples in a row could
            not be generated.

    The output is written with ``np.save`` as a dict with the keys
    ``features``, ``edge_index``, ``label``, ``ground_truth``, ``role_id``,
    ``pos`` and ``experiment_config``; being a Python object it is pickled, so
    it has to be read back with ``np.load(..., allow_pickle=True)``.
    """
    if experiment_config is None:
        experiment_config = {
            'confound_prob': 0.7,
            'confounder_pool_ids': list(range(5)),
            'background_confounder_prob_train': 0.10,
            'background_confounder_prob_test': 0.05,
            'target_causal_elements': None
        }

    confound_prob = experiment_config.get('confound_prob', 0.7)
    confounder_pool_ids = experiment_config.get('confounder_pool_ids', list(range(5)))
    bg_prob_train = experiment_config.get('background_confounder_prob_train', 0.10)
    bg_prob_test = experiment_config.get('background_confounder_prob_test', 0.05)
    target_causal_elements = experiment_config.get('target_causal_elements', None)

    data_dir = './data/analysis/'
    os.makedirs(data_dir, exist_ok=True)

    node_features_list = []
    edge_index_list = []
    label_list = []
    ground_truth_list = []
    role_id_list = []
    pos_list = []

    generated_count = 0
    consecutive_failures = 0

    # Per label: samples carrying at least one confounder / all samples.
    debug_label_conf_counts = {l: 0 for l in range(5)}
    debug_label_counts = {l: 0 for l in range(5)}

    pbar = tqdm(total=num_samples, desc=f'Generating {split.upper()} Samples')
    try:
        while generated_count < num_samples:
            # Without this limit a generator that always fails would keep the
            # loop spinning forever.
            if max_consecutive_failures is not None and consecutive_failures >= max_consecutive_failures:
                raise RuntimeError(
                    f"Aborting {split} generation: {consecutive_failures} consecutive "
                    f"samples failed ({generated_count}/{num_samples} generated)."
                )

            try:
                G, role_id, label, _ = generate_dataset_with_paper_confounder(
                    split=split,
                    confound_prob=confound_prob,
                    confounder_pool_ids=confounder_pool_ids,
                    background_confounder_prob_train=bg_prob_train,
                    background_confounder_prob_test=bg_prob_test,
                    target_causal_elements=target_causal_elements
                )

                if G.number_of_nodes() == 0:
                    consecutive_failures += 1
                    continue

                # 1) Make sure every node carries a feature vector.
                for n, attrs in G.nodes(data=True):
                    if not any(k in attrs for k in ('feature', 'features', 'feat')):
                        G.nodes[n]['feature'] = np.random.normal(
                            loc=_MEAN, scale=_STD
                        ).astype(np.float32)

                # 2) Add Gaussian noise; the result always lands in 'feature'.
                for n, attrs in G.nodes(data=True):
                    orig = _node_feature(attrs)
                    if not isinstance(orig, np.ndarray):
                        orig = np.array(orig, dtype=np.float32)
                    noise = np.random.normal(0.0, feature_noise, size=orig.shape).astype(np.float32)
                    G.nodes[n]['feature'] = (orig + noise)

                # 3) Randomly delete, then add, a small share of the edges.
                edges = list(G.edges())
                if len(edges) > 0:
                    num_del = int(edge_del_prob * len(edges))
                    if num_del > 0:
                        G.remove_edges_from(random.sample(edges, min(num_del, len(edges))))

                    nodes = list(G.nodes())
                    if len(nodes) > 1:
                        for _ in range(int(edge_add_prob * len(edges))):
                            u, v = random.sample(nodes, 2)
                            if not G.has_edge(u, v):
                                G.add_edge(u, v)

                # 4) Relabel to 0..n-1 following ascending original node id.
                node_mapping = {old: new for new, old in enumerate(sorted(G.nodes()))}
                G = nx.relabel_nodes(G, node_mapping)
                num_nodes = G.number_of_nodes()

                # 5) Force one role per node: truncate, or pad with the last role.
                role_id = list(role_id)[:num_nodes]
                if len(role_id) < num_nodes:
                    pad_value = role_id[-1] if role_id else 0
                    role_id.extend([pad_value] * (num_nodes - len(role_id)))
                role_id = np.array(role_id, dtype=np.int64)

                # 6) Edge index as a (2, num_edges) int64 array.
                if G.number_of_edges() > 0:
                    ei = torch.tensor(list(G.edges())).t().contiguous().cpu().numpy()
                    if ei.shape[0] != 2 and ei.shape[1] == 2:
                        ei = ei.T
                    edge_idx = ei.astype(np.int64)
                else:
                    edge_idx = np.array([[], []], dtype=np.int64)

                # 7) Ground truth from find_gd (skipped for edgeless graphs).
                if edge_idx.shape[1] > 0:
                    gt = find_gd(edge_idx, role_id)
                else:
                    gt = np.array([], dtype=np.float64)

                # 8) Stack node features in ascending node id order.
                node_order = sorted(G.nodes())
                features = []
                for i in node_order:
                    feat = _node_feature(G.nodes[i])
                    if not isinstance(feat, np.ndarray):
                        feat = np.array(feat, dtype=np.float32)
                    if feat.ndim == 0:
                        feat = np.array([feat], dtype=np.float32)
                    elif feat.ndim == 1 and len(feat) != len(_MEAN):
                        # Wrong dimensionality: replace by a fresh sample.
                        feat = np.random.normal(loc=_MEAN, scale=_STD).astype(np.float32)
                    features.append(feat)
                # The graph is known to be non-empty here.
                node_feats = np.vstack(features).astype(np.float32)

                # 9) 2-D layout, in the same node order as the feature rows
                #    (the layout dict follows graph iteration order, which
                #    need not be sorted).
                layout = nx.spring_layout(G)
                pos_arr = np.array([layout[i] for i in node_order], dtype=np.float32)

                if G.graph.get('attached_confounders'):
                    debug_label_conf_counts[label] += 1
                debug_label_counts[label] += 1

                node_features_list.append(node_feats)
                edge_index_list.append(edge_idx)
                label_list.append(int(label))
                ground_truth_list.append(gt)
                role_id_list.append(role_id)
                pos_list.append(pos_arr)

                generated_count += 1
                consecutive_failures = 0
                pbar.update(1)

            except Exception as e:
                consecutive_failures += 1
                print(f"\nError generating a sample, skipping. Error: {e}")
                traceback.print_exc()
                continue
    finally:
        # Close the progress bar even when generation is aborted.
        pbar.close()

    if node_features_list:
        avg_nodes = np.mean([nf.shape[0] for nf in node_features_list])
        avg_edges = np.mean([ei.shape[1] for ei in edge_index_list])

        label_counts = {}
        for label in label_list:
            label_counts[label] = label_counts.get(label, 0) + 1
        print(f"Label distribution: {label_counts}")

        attach_rates = {l: (debug_label_conf_counts[l] / debug_label_counts[l] if debug_label_counts[l] > 0 else 0)
                for l in range(5)}
        print(f"Average nodes per graph: {avg_nodes:.2f}, average edges per graph: {avg_edges:.2f}")
        print(f"Confounder attach rate per label: {attach_rates}")

    save_path = osp.join(data_dir, output_filename)
    np.save(save_path, {
        'features': node_features_list,
        'edge_index': edge_index_list,
        'label': label_list,
        'ground_truth': ground_truth_list,
        'role_id': role_id_list,
        'pos': pos_list,
        'experiment_config': experiment_config
    })
    print(f"Saved {split} dataset to: {save_path}")

def run_systematic_paper_experiments(experiments=None, split_sizes=None):
    """Generate the train/val/test datasets of every configured experiment.

    Args:
        experiments: Iterable of experiment config dicts; each needs at least
            the keys ``'name'`` and ``'file_suffix'`` and is passed on as
            ``experiment_config``. Defaults to ``PAPER_EXPERIMENTS``.
        split_sizes: Iterable of ``(split, num_samples)`` pairs, processed in
            order. Defaults to 3000 train, 375 val and 375 test samples.

    Output files are named ``"<split>_<file_suffix>.npy"``. An experiment that
    raises is reported and skipped; the remaining experiments still run.
    """
    if experiments is None:
        experiments = PAPER_EXPERIMENTS
    if split_sizes is None:
        split_sizes = (('train', 3000), ('val', 375), ('test', 375))

    for exp_config in experiments:
        try:
            for split, num_samples in split_sizes:
                print(f"\nGenerating {split.upper()} set for {exp_config['name']}...")
                generate_paper_conf_dataset(
                    num_samples=num_samples,
                    output_filename=f"{split}_{exp_config['file_suffix']}.npy",
                    split=split,
                    experiment_config=exp_config
                )

            print(f"✓ Completed experiment: {exp_config['name']}")

        except Exception as e:
            print(f"✗ Error in experiment {exp_config['name']}: {str(e)}")
            print("Continuing with next experiment...")
            continue

if __name__ == '__main__':
    # Seed every random number generator the module imports (Python, NumPy,
    # PyTorch) with the same fixed value before generating anything.
    seed = 42
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    run_systematic_paper_experiments()
