import os
import os.path as osp
import random
from tqdm import tqdm
import traceback
import json

import numpy as np
import networkx as nx
import torch


try:
    from molecular import motif_generators, motif_connectors
    from BA3_loc import find_gd
except ImportError as err:
    # The motif generators/connectors and the ground-truth helper live in
    # sibling modules; nothing in this script can run without them.
    import sys
    print("Error：Unable to import必要module。", file=sys.stderr)
    print("Please ensure molecular.py and BA3_loc.py files are in the same directory as this script。", file=sys.stderr)
    print(f"Details: {err}", file=sys.stderr)
    # Non-zero status so shells/CI see the failure. The previous bare `exit()`
    # reported success and depends on the `site` module having been loaded.
    sys.exit(1)


# Per-dimension mean / standard deviation of the 5-dimensional Gaussian from
# which node features are drawn (passed to every motif generator call).
# NOTE(review): mean and std are numerically identical — confirm this is intended.
_MEAN = np.array([1.5, 2.0, 1.2, 1.3, 1.8], dtype=np.float32)
_STD = np.array(_MEAN)  # independent float32 copy with the same values


# Display names (Chinese) of the motif families, keyed by motif id 1-10.
ALL_MOTIF_TYPES = dict(enumerate(
    (
        "星形",          # star
        "路径形",        # path
        "扇形",          # fan
        "尖角多边形",    # sharp-cornered polygon
        "随机二分图",    # random bipartite graph
        "树形",          # tree
        "三叉戟形",      # trident
        "锥形连接图",    # cone-connected graph
        "链旁路形",      # chain with bypass
        "部分多边形",    # partial polygon
    ),
    start=1,
))


MOTIF_EXPERIMENTS = []

# Two experiment configurations are registered per motif id 1-5:
#   causal_<id>   -- <id> is the only target causal motif; all confounding disabled.
#   confound_<id> -- <id> is the sole confounder; the other four ids are the causal motifs.
for motif_id in range(1, 6):
    other_motifs = [mid for mid in range(1, 6) if mid != motif_id]
    MOTIF_EXPERIMENTS.extend((
        {
            'name': f'causal_{motif_id}',
            'type': 'causal',
            'motif_id': motif_id,
            'file_suffix': f'causal_{motif_id}_no_conf',
            'target_causal_motifs': [motif_id],
            'confounder_pool_ids': [],
            'confound_prob': 0.0,
            'background_confounder_prob_train': 0.0,
            'background_confounder_prob_test': 0.0,
            'description': f'{ALL_MOTIF_TYPES[motif_id]}作为因果信息（无混杂）',
        },
        {
            'name': f'confound_{motif_id}',
            'type': 'confounding',
            'motif_id': motif_id,
            'file_suffix': f'confound_{motif_id}_others_causal',
            'target_causal_motifs': other_motifs,
            'confounder_pool_ids': [motif_id],
            'confound_prob': 0.7,
            'background_confounder_prob_train': 0.1,
            'background_confounder_prob_test': 0.1,
            'description': f'{ALL_MOTIF_TYPES[motif_id]}作为混杂信息，其他motif作为因果信息',
        },
    ))

class ExperimentDatasetGenerator:
    """Generate graphs from a configurable set of causal motifs, optionally
    followed by attaching a fixed list of confounder motifs.

    Config keys read (all optional):
        causal_motif_ids      -- motif ids tried in order as causal parts (default 1-5)
        confounding_motif_ids -- motif ids attached as confounders (default: none)
        intervention_prob     -- probability of attaching the confounders (default 0.6)
        feature_noise_level   -- stored on the instance only; no method reads it (default 0.05)

    NOTE(review): the entries of MOTIF_EXPERIMENTS use the keys
    'target_causal_motifs' / 'confounder_pool_ids', which this class does not
    read — confirm which keys are intended before feeding those configs here.
    """

    def __init__(self, experiment_config):
        self.config = experiment_config
        self.causal_motif_ids = experiment_config.get('causal_motif_ids', list(range(1, 6)))
        self.confounding_motif_ids = experiment_config.get('confounding_motif_ids', [])
        self.intervention_prob = experiment_config.get('intervention_prob', 0.6)
        self.feature_noise_level = experiment_config.get('feature_noise_level', 0.05)

    @staticmethod
    def _edge_index_from_graph(G):
        """Return G's edges as a C-contiguous int64 array of shape (2, E); (2, 0) if edgeless."""
        if G.number_of_edges() == 0:
            return np.array([[], []], dtype=np.int64)
        # (E, 2) list of endpoint pairs -> (2, E); this is always the right
        # orientation, so no shape check is needed afterwards.
        return np.ascontiguousarray(np.asarray(list(G.edges()), dtype=np.int64).T)

    def generate_targeted_causal_dataset(self, split='train'):
        """Return one ``(G, role_id, label, edge_index)`` sample.

        ``split`` is accepted for interface compatibility but is not used.
        With probability ``intervention_prob`` (and only when any are
        configured) every id in ``confounding_motif_ids`` is attached to G.

        NOTE(review): ``role_id`` and ``edge_index`` are computed before the
        confounders are attached, so they describe the causal part only —
        confirm this is intended.
        """
        G, role_id, label, edge_index = self._generate_selective_causal_graph()

        if self.confounding_motif_ids and random.random() < self.intervention_prob:
            G = self._attach_specific_confounders(G, self.confounding_motif_ids)

        return G, role_id, label, edge_index

    def _generate_selective_causal_graph(self):
        """Connect up to three motifs from ``causal_motif_ids`` (in order) into one graph.

        Motifs whose generation/connection raises are skipped with a warning;
        if no motif succeeds, ``_generate_fallback_graph`` is used instead.

        NOTE(review): the label is drawn uniformly from 0-4, independently of
        which motifs end up in the graph — confirm this is intended.
        """
        label = random.choice([0, 1, 2, 3, 4])

        G = nx.Graph()
        role_id = []
        motif_count = 0
        for causal_id in self.causal_motif_ids:
            if motif_count >= 3:
                break
            try:
                motif_G, motif_role = motif_generators[causal_id][0](
                    size=random.randint(5, 15),
                    branches=random.randint(2, 4),
                    node_feature_mean=_MEAN,
                    std=_STD,
                )
                if motif_G.number_of_nodes() == 0:
                    continue

                if G.number_of_nodes() == 0:
                    # First non-empty motif seeds the graph.
                    G = motif_G.copy()
                    role_id = motif_role.copy()
                else:
                    G, role_id, _, _ = motif_connectors[1][0](G, motif_G)
                motif_count += 1
            except Exception as e:
                print(f"Warning: Failed to generate causal motif {causal_id}: {e}")

        if G.number_of_nodes() == 0:
            return self._generate_fallback_graph()

        return G, role_id, label, self._edge_index_from_graph(G)

    def _generate_fallback_graph(self):
        """Return a single generator-1 motif graph with a uniformly random label 0-4."""
        G, role_id = motif_generators[1][0](
            size=random.randint(8, 15),
            branches=random.randint(2, 4),
            node_feature_mean=_MEAN,
            std=_STD,
        )
        label = random.choice([0, 1, 2, 3, 4])
        return G, role_id, label, self._edge_index_from_graph(G)

    def _attach_specific_confounders(self, G, confounder_ids):
        """Return a copy of G with one generated motif attached per id in ``confounder_ids``.

        Ids outside 1-5 are skipped with a warning, as are motifs whose
        generation or connection raises; G itself is not modified.
        """
        new_G = G.copy()

        for conf_id in confounder_ids:
            if conf_id < 1 or conf_id > 5:
                print(f"Warning: Skipping confounder motif {conf_id} - only motifs 1-5 are supported")
                continue

            try:
                confounder_motif, _ = motif_generators[conf_id][0](
                    size=random.randint(8, 20),
                    branches=random.randint(2, 4),
                    node_feature_mean=_MEAN,
                    std=_STD,
                )
                if confounder_motif.number_of_nodes() > 0 and new_G.number_of_nodes() > 0:
                    new_G, _, _, _ = motif_connectors[1][0](new_G, confounder_motif)
            except Exception as e:
                print(f"Warning: Failed to attach confounder motif {conf_id}: {e}")

        return new_G

def attach_confounders_to_graph(G, num_confounders, confounder_pool_ids=None):
    """Attach ``num_confounders`` randomly chosen confounder motifs to ``G``.

    Args:
        G: Graph to extend.
        num_confounders: Number of attachment attempts; each draws one id
            uniformly (with replacement) from ``confounder_pool_ids``.
        confounder_pool_ids: Candidate motif ids (default 1-5). Ids outside
            1-5 are skipped with a warning. An empty pool attaches nothing.

    Returns:
        The graph after all successful attachments; ``G`` itself when nothing
        was attached. Attempts whose generation/connection raises are skipped
        with a warning.
    """
    if confounder_pool_ids is None:
        confounder_pool_ids = list(range(1, 6))
    if not confounder_pool_ids:
        # random.choice([]) raises IndexError, which previously crashed the
        # warning handler with an UnboundLocalError. No pool -> nothing to attach.
        return G

    new_G = G
    for _ in range(num_confounders):
        # Drawn outside the try so the warning below always names this
        # iteration's id rather than an unbound or stale one.
        confounder_id = random.choice(confounder_pool_ids)
        try:
            if confounder_id < 1 or confounder_id > 5:
                print(f"Warning: Skipping confounder motif {confounder_id} - only motifs 1-5 are supported")
                continue

            confounder_motif, _ = motif_generators[confounder_id][0](
                size=random.randint(8, 20),
                branches=random.randint(2, 4),
                node_feature_mean=_MEAN,
                std=_STD,
            )

            if confounder_motif.number_of_nodes() == 0 or new_G.number_of_nodes() == 0:
                continue

            new_G, _, _, _ = motif_connectors[1][0](new_G, confounder_motif)
        except Exception as e:
            print(f"Warning: Failed to attach a confounder motif (ID: {confounder_id}). Error: {e}")

    return new_G



def generate_Y0():
    """Build a label-0 graph by connecting a generator-1 and a generator-2 motif.

    Both motifs share the same random ``branches`` value ``a``.

    Returns:
        ``(graph, role_id, 0, edge_index)``. If connecting the motifs raises,
        the generator-1 motif is returned alone with its own role ids and an
        edge index built from its edges.
    """
    a = random.randint(2, 4)
    motif1, role_id1 = motif_generators[1][0](
        random.randint(5, 10), a, _MEAN, _STD
    )
    motif2, _ = motif_generators[2][0](
        random.randint(5, 10), a, _MEAN, _STD
    )

    try:
        motifY0, role_id, edge_index, _ = motif_connectors[1][0](motif1, motif2)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        # NOTE(review): the fallback graph lacks motif 2 but is still labelled 0.
        motifY0 = motif1
        role_id = role_id1
        if motif1.number_of_edges() > 0:
            edge_index = torch.tensor(list(motif1.edges())).t().contiguous().cpu().numpy()
        else:
            edge_index = np.array([[], []], dtype=np.int64)

    label = 0
    return motifY0, role_id, label, edge_index

def generate_Y1():
    """Build a label-1 graph from generator-1, generator-3 and generator-5 motifs.

    The generator-1 and generator-3 motifs are connected first, then a
    generator-5 motif is connected to the result. All three share the same
    random ``branches`` value ``a``.

    Returns:
        ``(graph, role_id, 1, edge_index)``. If any step inside the try
        raises, the generator-1 motif is returned alone with its own role ids
        and an edge index built from its edges.
    """
    a = random.randint(2, 4)
    motif1, role_id1 = motif_generators[1][0](
        random.randint(5, 10), a, _MEAN, _STD
    )
    motif3, _ = motif_generators[3][0](
        random.randint(5, 10), a, _MEAN, _STD
    )

    try:
        G, _, _, _ = motif_connectors[1][0](motif1, motif3)
        motif5, _ = motif_generators[5][0](
            random.randint(5, 10), a, _MEAN, _STD
        )
        motifY1, role_id, edge_index, _ = motif_connectors[1][0](G, motif5)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        # NOTE(review): the fallback drops motifs 3 and 5 (even if the first
        # connection succeeded) but is still labelled 1.
        motifY1 = motif1
        role_id = role_id1
        if motif1.number_of_edges() > 0:
            edge_index = torch.tensor(list(motif1.edges())).t().contiguous().cpu().numpy()
        else:
            edge_index = np.array([[], []], dtype=np.int64)

    label = 1
    return motifY1, role_id, label, edge_index

def generate_Y2():
    """Build a label-2 graph from generator-1, generator-2 and generator-5 motifs.

    The generator-1 and generator-2 motifs are connected first, then a
    generator-5 motif is connected to the result. All three share the same
    random ``branches`` value ``a``.

    Returns:
        ``(graph, role_id, 2, edge_index)``. If any step inside the try
        raises, the generator-1 motif is returned alone with its own role ids
        and an edge index built from its edges.
    """
    a = random.randint(2, 4)
    motif1, role_id1 = motif_generators[1][0](
        random.randint(5, 10), a, _MEAN, _STD
    )
    motif2, _ = motif_generators[2][0](
        random.randint(5, 10), a, _MEAN, _STD
    )

    try:
        G, _, _, _ = motif_connectors[1][0](motif1, motif2)
        motif5, _ = motif_generators[5][0](
            random.randint(5, 10), a, _MEAN, _STD
        )
        motifY2, role_id, edge_index, _ = motif_connectors[1][0](G, motif5)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        # NOTE(review): the fallback drops motifs 2 and 5 (even if the first
        # connection succeeded) but is still labelled 2.
        motifY2 = motif1
        role_id = role_id1
        if motif1.number_of_edges() > 0:
            edge_index = torch.tensor(list(motif1.edges())).t().contiguous().cpu().numpy()
        else:
            edge_index = np.array([[], []], dtype=np.int64)

    label = 2
    return motifY2, role_id, label, edge_index

def generate_Y3():
    """Build a label-3 graph by connecting a generator-4 and a generator-5 motif.

    Both motifs share the same random ``branches`` value ``a``.

    Returns:
        ``(graph, role_id, 3, edge_index)``. If connecting the motifs raises,
        the generator-4 motif is returned alone with its own role ids and an
        edge index built from its edges.
    """
    a = random.randint(2, 4)
    motif4, role_id4 = motif_generators[4][0](
        random.randint(5, 10), a, _MEAN, _STD
    )
    motif5, _ = motif_generators[5][0](
        random.randint(5, 10), a, _MEAN, _STD
    )

    try:
        motifY3, role_id, edge_index, _ = motif_connectors[1][0](motif4, motif5)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        # NOTE(review): the fallback graph lacks motif 5 but is still labelled 3.
        motifY3 = motif4
        role_id = role_id4
        if motif4.number_of_edges() > 0:
            edge_index = torch.tensor(list(motif4.edges())).t().contiguous().cpu().numpy()
        else:
            edge_index = np.array([[], []], dtype=np.int64)

    label = 3
    return motifY3, role_id, label, edge_index

def generate_Y4():
    """Build a label-4 graph by connecting a generator-3 and a generator-4 motif.

    Both motifs share the same random ``branches`` value ``a``.

    Returns:
        ``(graph, role_id, 4, edge_index)``. If connecting the motifs raises,
        the generator-3 motif is returned alone with its own role ids and an
        edge index built from its edges.
    """
    a = random.randint(2, 4)
    motif3, role_id3 = motif_generators[3][0](
        random.randint(5, 10), a, _MEAN, _STD
    )
    motif4, _ = motif_generators[4][0](
        random.randint(5, 10), a, _MEAN, _STD
    )

    try:
        motifY4, role_id, edge_index, _ = motif_connectors[1][0](motif3, motif4)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        # NOTE(review): the fallback graph lacks motif 4 but is still labelled 4.
        motifY4 = motif3
        role_id = role_id3
        if motif3.number_of_edges() > 0:
            edge_index = torch.tensor(list(motif3.edges())).t().contiguous().cpu().numpy()
        else:
            edge_index = np.array([[], []], dtype=np.int64)

    label = 4
    return motifY4, role_id, label, edge_index

def generate_real_dataset_with_specific_motifs(target_motif_ids=None):
    """Generate one graph whose label is chosen from the motifs in ``target_motif_ids``.

    Label -> required motif ids: 0: [1, 2], 1: [1, 3, 5], 2: [1, 2, 5],
    3: [4, 5], 4: [3, 4]. A label is drawn uniformly among those whose
    required motifs are all in ``target_motif_ids``, and those motifs are
    generated and connected. If no label is fully covered, the label is
    derived from the first target id (1->0, 2->0, 3->1, 4->3, 5->1, anything
    else -> uniform 0-4) and the first three target ids are generated.

    Args:
        target_motif_ids: Allowed motif ids (default ``[1, 2, 3, 4, 5]``).

    Returns:
        ``(G, role_id, y, edge_index, motif_presence)``. ``motif_presence``
        maps motif ids (1-10 pre-populated with False) to whether that motif's
        generator returned successfully. If every motif fails, a single motif
        (or, as a last resort, a 5-node path) is used so G is never empty.
    """
    if target_motif_ids is None:
        target_motif_ids = [1, 2, 3, 4, 5]

    motif_presence = {i: False for i in range(1, 11)}

    label_motif_combinations = {
        0: [1, 2],
        1: [1, 3, 5],
        2: [1, 2, 5],
        3: [4, 5],
        4: [3, 4],
    }
    # Label used when no combination is fully available, keyed by the first target id.
    fallback_label_by_primary = {1: 0, 2: 0, 3: 1, 4: 3, 5: 1}

    possible_labels = [
        label
        for label, required_motifs in label_motif_combinations.items()
        if all(motif_id in target_motif_ids for motif_id in required_motifs)
    ]

    if possible_labels:
        y = random.choice(possible_labels)
        motifs_to_use = [m for m in label_motif_combinations[y] if m in target_motif_ids]
    else:
        primary_motif = target_motif_ids[0] if target_motif_ids else None
        y = fallback_label_by_primary.get(primary_motif)
        if y is None:
            y = random.choice([0, 1, 2, 3, 4])
        motifs_to_use = target_motif_ids[:3]

    G = nx.Graph()
    role_id = []
    for motif_id in motifs_to_use:
        try:
            a = random.randint(2, 4)
            motif_G, motif_role = motif_generators[motif_id][0](
                random.randint(5, 10), a, _MEAN, _STD
            )
            motif_presence[motif_id] = True

            if G.number_of_nodes() == 0:
                # The first motif that actually produced nodes seeds the graph.
                # (Keying on the loop index instead meant a failed first motif
                # made the next one get "connected" to an empty graph.)
                G = motif_G.copy()
                role_id = list(motif_role)
            else:
                try:
                    G, role_id, _, _ = motif_connectors[1][0](G, motif_G)
                except Exception:
                    # Connector failed: keep both parts as a disjoint union and
                    # fall back to node degrees as role ids.
                    G = nx.union(G, motif_G, rename=('G-', 'motif-'))
                    G = nx.convert_node_labels_to_integers(G)
                    role_id = [G.degree(node) for node in G.nodes()]
        except Exception as e:
            print(f"Warning: Failed to generate motif {motif_id}: {e}")

    if G.number_of_nodes() == 0:
        fallback_motif = target_motif_ids[0] if target_motif_ids else 1
        try:
            G, role_id = motif_generators[fallback_motif][0](
                random.randint(8, 15), 3, _MEAN, _STD
            )
            motif_presence[fallback_motif] = True
        except Exception:
            # Last resort: a 5-node path whose features are drawn from the prior.
            G = nx.Graph()
            G.add_nodes_from(range(5))
            G.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)])
            for node in G.nodes():
                G.nodes[node]['feature'] = np.random.normal(loc=_MEAN, scale=_STD).astype(np.float32)
            role_id = [G.degree(node) for node in G.nodes()]

    if G.number_of_edges() > 0:
        try:
            # (E, 2) endpoint pairs -> (2, E); the orientation is always correct.
            edge_index = torch.tensor(list(G.edges())).t().contiguous().cpu().numpy()
        except Exception:
            # e.g. node labels that torch cannot turn into a tensor.
            edge_index = np.array([[], []], dtype=np.int64)
    else:
        edge_index = np.array([[], []], dtype=np.int64)

    return G, role_id, y, edge_index, motif_presence

def generate_dataset_with_confounder(split='train', confound_prob=0.6, confounder_pool_ids=None,
                                     background_confounder_prob_train=0.20, background_confounder_prob_test=0.05,
                                     target_causal_motifs=None):
    """Generate one sample and randomly attach confounder motifs to it.

    Two independent attachment stages follow the causal graph:
      1. If at least one causal motif was generated, with probability
         ``confound_prob`` attach 1-3 motifs drawn from the pool.
      2. With the background probability (the test value when
         ``split == 'test'``, the train value otherwise) attach one more.

    Returns:
        ``(G, role_id, label, edge_index)``.
        NOTE(review): ``role_id`` and ``edge_index`` come from the causal graph
        before any confounder is attached — confirm callers expect that.
    """
    pool = list(range(1, 6)) if confounder_pool_ids is None else confounder_pool_ids

    G, role_id, label, edge_index, motif_presence = generate_real_dataset_with_specific_motifs(target_causal_motifs)

    # Stage 1 (short-circuits, so no random draw when no causal motif exists).
    if any(motif_presence.values()) and random.random() < confound_prob:
        G = attach_confounders_to_graph(G, random.randint(1, 3), pool)

    # Stage 2: split-dependent background confounder.
    if split == 'test':
        background_prob = background_confounder_prob_test
    else:
        background_prob = background_confounder_prob_train
    if random.random() < background_prob:
        G = attach_confounders_to_graph(G, 1, pool)

    return G, role_id, label, edge_index

def _mol_conf_node_feature(attrs):
    """Return a node's raw feature, looked up under 'feature', 'features', then 'feat'."""
    if 'feature' in attrs:
        return attrs['feature']
    if 'features' in attrs:
        return attrs['features']
    return attrs['feat']


def _mol_conf_perturb_graph(G, feature_noise, edge_del_prob, edge_add_prob):
    """Perturb ``G`` in place with noisy node features and a few dropped/added edges.

    Every node ends up with a ``'feature'`` attribute (drawn from
    N(_MEAN, _STD) if it had no feature) plus N(0, feature_noise) noise.
    Then ``int(edge_del_prob * E)`` random edges are removed and up to
    ``int(edge_add_prob * E)`` random node pairs are connected, where E is the
    edge count before perturbation.
    """
    for n, attrs in G.nodes(data=True):
        if not any(k in attrs for k in ('feature', 'features', 'feat')):
            G.nodes[n]['feature'] = np.random.normal(loc=_MEAN, scale=_STD).astype(np.float32)

    for n, attrs in G.nodes(data=True):
        orig = _mol_conf_node_feature(attrs)
        if not isinstance(orig, np.ndarray):
            orig = np.array(orig, dtype=np.float32)
        noise = np.random.normal(0.0, feature_noise, size=orig.shape).astype(np.float32)
        G.nodes[n]['feature'] = orig + noise

    edges = list(G.edges())
    if not edges:
        return

    num_del = int(edge_del_prob * len(edges))
    if num_del > 0:
        G.remove_edges_from(random.sample(edges, min(num_del, len(edges))))

    nodes = list(G.nodes())
    if len(nodes) > 1:
        for _ in range(int(edge_add_prob * len(edges))):
            u, v = random.sample(nodes, 2)
            if not G.has_edge(u, v):
                G.add_edge(u, v)


def _mol_conf_graph_to_arrays(G):
    """Convert a perturbed graph into the per-sample arrays that get saved.

    Nodes are relabelled to 0..n-1 first. Returns
    ``(node_feats, edge_idx, role_id, gt, pos)``: float32 features of shape
    (n, len(_MEAN)), an int64 (2, E) edge index, int64 node degrees used as
    role ids, the ``find_gd`` result (empty float64 array for an edgeless
    graph) and float32 ``nx.spring_layout`` positions.
    """
    G = nx.convert_node_labels_to_integers(G, first_label=0)
    nodes = sorted(G.nodes())

    if G.number_of_edges() > 0:
        edge_idx = np.ascontiguousarray(np.asarray(list(G.edges()), dtype=np.int64).T)
    else:
        edge_idx = np.array([[], []], dtype=np.int64)

    # Role ids are recomputed as node degrees of the perturbed graph.
    role_id = np.array([G.degree(i) for i in nodes], dtype=np.int64)

    if edge_idx.shape[1] > 0:
        gt = find_gd(edge_idx, role_id)
    else:
        gt = np.array([], dtype=np.float64)

    num_feat = len(_MEAN)
    features = []
    for i in nodes:
        feat = np.asarray(_mol_conf_node_feature(G.nodes[i]), dtype=np.float32).reshape(-1)
        if feat.size != num_feat:
            # Wrong-length, scalar or multi-row feature: resample from the prior so
            # every row has the same width (a scalar used to break np.vstack).
            feat = np.random.normal(loc=_MEAN, scale=_STD).astype(np.float32)
        features.append(feat)

    if features:
        node_feats = np.vstack(features).astype(np.float32)
    else:
        node_feats = np.zeros((0, num_feat), dtype=np.float32)

    if G.number_of_nodes() > 0:
        pos = np.array(list(nx.spring_layout(G).values()), dtype=np.float32)
    else:
        pos = np.zeros((0, 2), dtype=np.float32)

    return node_feats, edge_idx, role_id, gt, pos


def generate_mol_conf_dataset(
    num_samples: int,
    output_filename: str,
    split: str,
    experiment_config: dict = None,
    feature_noise: float = 0.05,
    edge_del_prob: float = 0.02,
    edge_add_prob: float = 0.02,
    max_consecutive_failures: int = 1000,
):
    """Generate ``num_samples`` perturbed, possibly confounded graphs and save them.

    Args:
        num_samples: Number of graphs to generate.
        output_filename: File name written inside ``./data/analysis/``.
        split: Split name; ``'test'`` selects the test background-confounder
            probability, any other value the train one.
        experiment_config: Optional dict with keys 'confound_prob',
            'confounder_pool_ids', 'background_confounder_prob_train',
            'background_confounder_prob_test' and 'target_causal_motifs'
            (defaults 0.6, ids 1-5, 0.20, 0.05, None).
        feature_noise: Std of the Gaussian noise added to every node feature.
        edge_del_prob: Fraction of edges randomly removed from each graph.
        edge_add_prob: Fraction (of the original edge count) of random edges added.
        max_consecutive_failures: Raise RuntimeError after this many failed or
            empty samples in a row instead of looping forever; None disables it.

    The file is written with ``np.save`` as a pickled dict of per-graph lists
    under keys 'node_features', 'edge_index', 'label', 'ground_truth',
    'role_id' and 'pos'; read it back with
    ``np.load(path, allow_pickle=True).item()``.
    """
    if experiment_config is None:
        experiment_config = {}  # every key below falls back to its default

    confound_prob = experiment_config.get('confound_prob', 0.6)
    confounder_pool_ids = experiment_config.get('confounder_pool_ids', list(range(1, 6)))
    bg_prob_train = experiment_config.get('background_confounder_prob_train', 0.20)
    bg_prob_test = experiment_config.get('background_confounder_prob_test', 0.05)
    target_causal_motifs = experiment_config.get('target_causal_motifs', None)

    data_dir = './data/analysis/'
    os.makedirs(data_dir, exist_ok=True)

    node_features_list = []
    edge_index_list = []
    label_list = []
    ground_truth_list = []
    role_id_list = []
    pos_list = []

    bg_prob = bg_prob_test if split == 'test' else bg_prob_train
    print(f"\nGenerating {split} dataset with ENHANCED confounders...")
    # generate_dataset_with_confounder applies confound_prob to every split, so
    # report that value (a hard-coded "10%" used to be printed for 'test').
    print(f"Causal Confounder Prob: {confound_prob*100:.0f}%, Background Confounder Prob: {bg_prob*100:.0f}%")
    print(f"Target causal motifs: {target_causal_motifs}")
    print(f"Confounder motifs: {confounder_pool_ids}")

    consecutive_failures = 0
    with tqdm(total=num_samples, desc=f'Generating {split.upper()} Samples') as pbar:
        while len(label_list) < num_samples:
            if max_consecutive_failures is not None and consecutive_failures >= max_consecutive_failures:
                raise RuntimeError(
                    f"Aborting '{split}' generation: {consecutive_failures} consecutive samples "
                    f"failed ({len(label_list)}/{num_samples} generated)."
                )
            try:
                G, _, label, _ = generate_dataset_with_confounder(
                    split=split,
                    confound_prob=confound_prob,
                    confounder_pool_ids=confounder_pool_ids,
                    background_confounder_prob_train=bg_prob_train,
                    background_confounder_prob_test=bg_prob_test,
                    target_causal_motifs=target_causal_motifs,
                )
                if G.number_of_nodes() == 0:
                    consecutive_failures += 1
                    continue

                _mol_conf_perturb_graph(G, feature_noise, edge_del_prob, edge_add_prob)
                node_feats, edge_idx, role_id, gt, pos_arr = _mol_conf_graph_to_arrays(G)
                # Converted before any append so the six lists stay aligned if it fails.
                label = int(label)
            except Exception as e:
                consecutive_failures += 1
                print(f"\nError generating a sample, skipping. Error: {e}")
                traceback.print_exc()
                continue

            consecutive_failures = 0
            node_features_list.append(node_feats)
            edge_index_list.append(edge_idx)
            label_list.append(label)
            ground_truth_list.append(gt)
            role_id_list.append(role_id)
            pos_list.append(pos_arr)
            pbar.update(1)

    if node_features_list:
        avg_nodes = np.mean([nf.shape[0] for nf in node_features_list])
        avg_edges = np.mean([ei.shape[1] for ei in edge_index_list])
        print(f"\nDataset generation for '{split}' complete.")
        print(f"# Graphs: {len(node_features_list)}")
        print(f"Avg Nodes: {avg_nodes:.2f}")
        print(f"Avg Edges: {avg_edges:.2f}")

        label_counts = {}
        for lbl in label_list:
            label_counts[lbl] = label_counts.get(lbl, 0) + 1
        print(f"Label distribution: {label_counts}")

    save_path = osp.join(data_dir, output_filename)
    np.save(save_path, {
        'node_features': node_features_list,
        'edge_index': edge_index_list,
        'label': label_list,
        'ground_truth': ground_truth_list,
        'role_id': role_id_list,
        'pos': pos_list,
    })
    print(f"Saved {split} dataset to: {save_path}")

def run_systematic_experiments(experiments=None, num_train=3000, num_val=375, num_test=375):
    """Generate train/val/test datasets for every experiment configuration.

    Args:
        experiments: Iterable of experiment config dicts (each needs 'name'
            and 'file_suffix'); defaults to ``MOTIF_EXPERIMENTS``.
        num_train, num_val, num_test: Graphs per split (defaults 3000/375/375).

    Files are named ``<split>_<file_suffix>.npy``. A failure inside one
    experiment is reported with its traceback, and the loop continues with
    the next experiment.
    """
    if experiments is None:
        experiments = MOTIF_EXPERIMENTS

    split_sizes = (('train', num_train), ('val', num_val), ('test', num_test))

    for exp_config in experiments:
        name = exp_config['name']
        try:
            suffix = exp_config['file_suffix']
            for split, num_samples in split_sizes:
                print(f"\nGenerating {split.upper()} set for {name}...")
                generate_mol_conf_dataset(
                    num_samples=num_samples,
                    output_filename=f"{split}_{suffix}.npy",
                    split=split,
                    experiment_config=exp_config,
                )
            print(f"✓ Completed experiment: {name}")
        except Exception as e:
            print(f"✗ Error in experiment {name}: {e}")
            traceback.print_exc()
            print("Continuing with next experiment...")

if __name__ == '__main__':
    # Seed the Python, NumPy and PyTorch global RNGs (in that order) before generating.
    SEED = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(SEED)

    run_systematic_experiments()
