import os
import os.path as osp
import random
from tqdm import tqdm
import traceback

import numpy as np
import networkx as nx
import torch


# The motif generators/connectors and the ground-truth helper come from sibling
# project modules; stop early with a clear message when they cannot be imported.
try:
    from molecular import motif_generators, motif_connectors
    from BA3_loc import find_gd
except ImportError as import_error:
    # SystemExit with a string argument writes the text to stderr and ends the
    # process with status 1, so shells/CI see the failure. A bare exit() would
    # report success (status 0) and only exists when the `site` module is loaded.
    raise SystemExit(
        "Error：Unable to import 'molecular' or 'BA3_loc' module。\n"
        "Please ensure molecular.py and BA3_loc.py files are in the same directory as this script。\n"
        f"Underlying error: {import_error}"
    ) from import_error


# Probability with which generate_dataset_with_confounder attaches one
# confounder motif to a sample that contains causal motifs.
CONFOUND_PROB = 0.7

# Five-dimensional mean / standard-deviation vectors handed to the motif
# generators and used as loc/scale when a node feature has to be sampled.
# The two vectors hold the same numbers.
_MEAN = np.asarray((1.5, 2.0, 1.2, 1.3, 1.8), dtype=np.float32)
_STD = np.asarray((1.5, 2.0, 1.2, 1.3, 1.8), dtype=np.float32)

# Generator ids from which confounder motifs are drawn.
CONFOUNDER_POOL_IDS = [1, 2, 3, 4, 5]

# Probability of attaching one extra "background" confounder, per split:
# the TEST value is used for split == 'test', the TRAIN value for everything else.
BACKGROUND_CONFOUNDER_PROB_TRAIN = 0.05
BACKGROUND_CONFOUNDER_PROB_TEST = 0

def attach_confounders_to_graph(G, num_confounders):
    """Attach up to ``num_confounders`` randomly chosen confounder motifs to ``G``.

    For every attempt a generator id is drawn from ``CONFOUNDER_POOL_IDS``, a
    motif is generated with a random ``size`` in [8, 20] and random
    ``branches`` in [2, 4], and the motif is joined to the current graph with
    ``motif_connectors[1]``. Attempts that produce an empty motif, meet an
    empty graph, or raise are skipped (a warning is printed for exceptions).

    Args:
        G: Graph to extend.
        num_confounders: Number of attachment attempts; values <= 0 make the
            function return ``G`` untouched.

    Returns:
        The graph returned by the last successful connector call, or ``G``
        itself when nothing was attached.
    """
    graph = G
    for _attempt in range(num_confounders):
        try:
            confounder_id = random.choice(CONFOUNDER_POOL_IDS)
            motif, _roles = motif_generators[confounder_id][0](
                size=random.randint(8, 20),
                branches=random.randint(2, 4),
                node_feature_mean=_MEAN,
                std=_STD,
            )
            # Only connect when both sides actually contain nodes.
            if motif.number_of_nodes() != 0 and graph.number_of_nodes() != 0:
                graph, _, _, _ = motif_connectors[1][0](graph, motif)
        except Exception as e:
            # Best effort: report the failure and move on to the next attempt.
            print(f"Warning: Failed to attach a confounder motif (ID: {confounder_id}). Error: {e}")
    return graph



def generate_Y0():
    """Build one label-0 sample from motif 6 and motif 7.

    Both generators receive a random first argument in [5, 10], a shared
    random second argument in [2, 4] (presumably size and branch count, as in
    the keyword call inside attach_confounders_to_graph), plus ``_MEAN`` and
    ``_STD``. The two motifs are joined with ``motif_connectors[2]``.

    Returns:
        tuple: ``(graph, role_id, 0, edge_index)``; graph, role_id and
        edge_index are passed through from the connector.
    """
    shared_branches = random.randint(2, 4)
    left_motif, _ = motif_generators[6][0](random.randint(5, 10), shared_branches, _MEAN, _STD)
    right_motif, _ = motif_generators[7][0](random.randint(5, 10), shared_branches, _MEAN, _STD)
    # The trailing 2 is forwarded unchanged to the connector (meaning defined in molecular.py).
    graph, roles, edges = motif_connectors[2][0](left_motif, right_motif, 2)
    return graph, roles, 0, edges

def generate_Y1():
    """Build one label-1 sample from motifs 6, 8 and 10.

    Motifs 6 and 8 are generated and joined with ``motif_connectors[1]``;
    motif 10 is generated afterwards and joined to that intermediate graph
    with ``motif_connectors[2]``. All generators share one random second
    argument in [2, 4] and each gets its own random first argument in [5, 10].

    Returns:
        tuple: ``(graph, role_id, 1, edge_index)`` as produced by the final
        connector call.
    """
    shared_branches = random.randint(2, 4)
    base_motif, _ = motif_generators[6][0](random.randint(5, 10), shared_branches, _MEAN, _STD)
    second_motif, _ = motif_generators[8][0](random.randint(5, 10), shared_branches, _MEAN, _STD)
    # Connector 1 returns four values; only the merged graph is needed here.
    merged, _, _, _ = motif_connectors[1][0](base_motif, second_motif)
    third_motif, _ = motif_generators[10][0](random.randint(5, 10), shared_branches, _MEAN, _STD)
    graph, roles, edges = motif_connectors[2][0](merged, third_motif, 2)
    return graph, roles, 1, edges

def generate_Y2():
    """Build one label-2 sample from motifs 6, 7 and 10.

    Motifs 6 and 7 are generated and joined with ``motif_connectors[3]``;
    motif 10 is generated afterwards and joined to that intermediate graph
    with ``motif_connectors[2]``. All generators share one random second
    argument in [2, 4] and each gets its own random first argument in [5, 10].

    Returns:
        tuple: ``(graph, role_id, 2, edge_index)`` as produced by the final
        connector call.
    """
    shared_branches = random.randint(2, 4)
    base_motif, _ = motif_generators[6][0](random.randint(5, 10), shared_branches, _MEAN, _STD)
    second_motif, _ = motif_generators[7][0](random.randint(5, 10), shared_branches, _MEAN, _STD)
    # Connector 3 returns three values; only the merged graph is needed here.
    merged, _, _ = motif_connectors[3][0](base_motif, second_motif, 2)
    third_motif, _ = motif_generators[10][0](random.randint(5, 10), shared_branches, _MEAN, _STD)
    graph, roles, edges = motif_connectors[2][0](merged, third_motif, 2)
    return graph, roles, 2, edges

def generate_Y3():
    """Build one label-3 sample from motif 9 and motif 10.

    Each generator gets its own random first argument in [5, 10] and both
    share one random second argument in [2, 4]. The motifs are joined with
    ``motif_connectors[1]``, whose fourth return value is discarded.

    Returns:
        tuple: ``(graph, role_id, 3, edge_index)`` taken from the connector.
    """
    shared_branches = random.randint(2, 4)
    left_motif, _ = motif_generators[9][0](random.randint(5, 10), shared_branches, _MEAN, _STD)
    right_motif, _ = motif_generators[10][0](random.randint(5, 10), shared_branches, _MEAN, _STD)
    graph, roles, edges, _ = motif_connectors[1][0](left_motif, right_motif)
    return graph, roles, 3, edges

def generate_Y4():
    """Build one label-4 sample from motif 8 and motif 9.

    Each generator gets its own random first argument in [5, 10] and both
    share one random second argument in [2, 4]. The motifs are joined with
    ``motif_connectors[2]``.

    Returns:
        tuple: ``(graph, role_id, 4, edge_index)`` taken from the connector.
    """
    shared_branches = random.randint(2, 4)
    left_motif, _ = motif_generators[8][0](random.randint(5, 10), shared_branches, _MEAN, _STD)
    right_motif, _ = motif_generators[9][0](random.randint(5, 10), shared_branches, _MEAN, _STD)
    # The trailing 2 is forwarded unchanged to the connector (meaning defined in molecular.py).
    graph, roles, edges = motif_connectors[2][0](left_motif, right_motif, 2)
    return graph, roles, 4, edges

def generate_real_dataset():
    """Generate one sample of a uniformly chosen class (0-4).

    Returns:
        tuple: ``(G, role_id, label, edge_index, has_motif6, has_motif7,
        has_motif8, has_motif9, has_motif10)``. The first four values come
        from the matching ``generate_Y*`` function; the five booleans state
        which of the motifs 6-10 that class is built from.
    """
    y = random.choice([0, 1, 2, 3, 4])

    # class id -> (builder, ids of the motifs that builder combines)
    recipes = {
        0: (generate_Y0, (6, 7)),
        1: (generate_Y1, (6, 8, 10)),
        2: (generate_Y2, (6, 7, 10)),
        3: (generate_Y3, (9, 10)),
        4: (generate_Y4, (8, 9)),
    }
    build, used_motifs = recipes[y]
    G, role_id, label, edge_index = build()

    presence = tuple(motif_id in used_motifs for motif_id in range(6, 11))
    return (G, role_id, label, edge_index, *presence)

def generate_dataset_with_confounder(split='train'):
    """Generate one sample and randomly attach confounder motifs to it.

    A confounder is attached with probability ``CONFOUND_PROB`` whenever the
    sample contains at least one of the motifs 6-10; this probability is the
    same for every split. A second, "background" confounder is attached with
    ``BACKGROUND_CONFOUNDER_PROB_TEST`` when ``split == 'test'`` and with
    ``BACKGROUND_CONFOUNDER_PROB_TRAIN`` for any other split.

    Args:
        split: Name of the split being generated; only compared with 'test'.

    Returns:
        tuple: ``(G, role_id, label, edge_index)``. ``role_id`` and
        ``edge_index`` are those of the sample before confounders were added.
    """
    (graph, role_id, label, edge_index,
     has6, has7, has8, has9, has10) = generate_real_dataset()

    contains_causal_motif = has6 or has7 or has8 or has9 or has10
    # The random draw only happens when a causal motif is present.
    if contains_causal_motif and random.random() < CONFOUND_PROB:
        graph = attach_confounders_to_graph(graph, 1)

    if split == 'test':
        background_prob = BACKGROUND_CONFOUNDER_PROB_TEST
    else:
        background_prob = BACKGROUND_CONFOUNDER_PROB_TRAIN
    if random.random() < background_prob:
        graph = attach_confounders_to_graph(graph, 1)

    return graph, role_id, label, edge_index

# Node attribute names under which a feature vector may be stored, in lookup order.
_FEATURE_KEYS = ('feature', 'features', 'feat')


def _get_node_feature(attrs):
    """Return the feature stored under the first matching key of ``_FEATURE_KEYS``."""
    if 'feature' in attrs:
        return attrs['feature']
    if 'features' in attrs:
        return attrs['features']
    return attrs['feat']


def _add_feature_noise(G, feature_noise):
    """Give every node a 'feature' attribute and add Gaussian noise to it (in place)."""
    # Pass 1: nodes without any feature attribute get a freshly sampled vector.
    for n, attrs in G.nodes(data=True):
        if not any(k in attrs for k in _FEATURE_KEYS):
            G.nodes[n]['feature'] = np.random.normal(
                loc=_MEAN, scale=_STD
            ).astype(np.float32)

    # Pass 2: perturb every feature; the result is always written to 'feature'.
    # The two passes are kept separate so the order of RNG draws stays fixed.
    for n, attrs in G.nodes(data=True):
        orig = _get_node_feature(attrs)
        if not isinstance(orig, np.ndarray):
            orig = np.array(orig, dtype=np.float32)
        noise = np.random.normal(0.0, feature_noise, size=orig.shape).astype(np.float32)
        G.nodes[n]['feature'] = orig + noise


def _perturb_edges(G, edge_del_prob, edge_add_prob):
    """Randomly delete and add edges in place; both counts scale with the original edge count."""
    edges = list(G.edges())
    if not edges:
        return

    num_del = int(edge_del_prob * len(edges))
    if num_del > 0:
        G.remove_edges_from(random.sample(edges, min(num_del, len(edges))))

    nodes = list(G.nodes())
    if len(nodes) > 1:
        for _ in range(int(edge_add_prob * len(edges))):
            u, v = random.sample(nodes, 2)
            if not G.has_edge(u, v):
                G.add_edge(u, v)


def _graph_to_arrays(G):
    """Convert a graph with integer node labels into the arrays stored per sample.

    Returns ``(edge_idx, role_id, gt, node_feats, pos_arr)``.
    """
    edges = list(G.edges())
    if edges:
        # (E, 2) list of pairs -> contiguous (2, E) int64 array.
        edge_idx = np.ascontiguousarray(np.array(edges, dtype=np.int64).T)
    else:
        edge_idx = np.array([[], []], dtype=np.int64)

    # NOTE(review): the role ids returned by the generators are not used; the
    # stored role_id is the node degree vector, and it is also what find_gd
    # receives. Confirm that this is what find_gd expects.
    role_id = np.array([G.degree(i) for i in sorted(G.nodes())], dtype=np.int64)

    if edge_idx.shape[1] > 0:
        gt = find_gd(edge_idx, role_id)
    else:
        gt = np.array([], dtype=np.float64)

    features = []
    for i in sorted(G.nodes()):
        feat = _get_node_feature(G.nodes[i])
        if not isinstance(feat, np.ndarray):
            feat = np.array(feat, dtype=np.float32)
        if feat.ndim == 0:
            feat = np.array([feat], dtype=np.float32)
        elif feat.ndim == 1 and len(feat) != len(_MEAN):
            # Wrong dimensionality: replace with a freshly sampled vector.
            feat = np.random.normal(loc=_MEAN, scale=_STD).astype(np.float32)
        features.append(feat)

    if features:
        node_feats = np.vstack(features).astype(np.float32)
    else:
        node_feats = np.array([], dtype=np.float32).reshape(0, len(_MEAN))

    if G.number_of_nodes() > 0:
        pos_arr = np.array(list(nx.spring_layout(G).values()), dtype=np.float32)
    else:
        pos_arr = np.array([], dtype=np.float32).reshape(0, 2)

    return edge_idx, role_id, gt, node_feats, pos_arr


def _build_sample(split, feature_noise, edge_del_prob, edge_add_prob):
    """Generate, perturb and convert one sample; return ``None`` for an empty graph."""
    G, _role_id, label, _edge_index = generate_dataset_with_confounder(split=split)
    if G.number_of_nodes() == 0:
        return None

    _add_feature_noise(G, feature_noise)
    _perturb_edges(G, edge_del_prob, edge_add_prob)
    G = nx.convert_node_labels_to_integers(G, first_label=0)

    edge_idx, role_id, gt, node_feats, pos_arr = _graph_to_arrays(G)
    return node_feats, edge_idx, int(label), gt, role_id, pos_arr


def generate_mol_conf_dataset(
    num_samples: int,
    output_filename: str,
    split: str,
    feature_noise: float = 0.05,
    edge_del_prob: float = 0.02,
    edge_add_prob: float = 0.02,
    data_dir: str = './data/motif_new_conf/',
    max_consecutive_failures: int = 100,
):
    """Generate ``num_samples`` perturbed graphs and save them with ``np.save``.

    Each sample comes from ``generate_dataset_with_confounder``; its node
    features receive Gaussian noise, a fraction of edges is removed and added
    at random, nodes are relabelled to 0..n-1, and the graph is converted to
    arrays. Samples that raise or are empty are skipped and regenerated.

    Args:
        num_samples: Number of graphs to produce.
        output_filename: File name inside ``data_dir``.
        split: Split name; forwarded to the sample generator and used in logs.
        feature_noise: Standard deviation of the noise added to node features.
        edge_del_prob: Fraction of the original edges to delete.
        edge_add_prob: Fraction of the original edge count used as the number
            of random edge-insertion attempts.
        data_dir: Output directory (created if missing).
        max_consecutive_failures: Abort after this many failed or empty
            samples in a row instead of retrying forever.

    Raises:
        RuntimeError: If ``max_consecutive_failures`` samples in a row failed.

    The saved object is a dict with the keys 'node_features', 'edge_index',
    'label', 'ground_truth', 'role_id' and 'pos', each a list with one entry
    per graph.
    """
    os.makedirs(data_dir, exist_ok=True)

    node_features_list = []
    edge_index_list = []
    label_list = []
    ground_truth_list = []
    role_id_list = []
    pos_list = []

    print(f"\nGenerating {split} dataset with ENHANCED confounders using MOTIF 6-10 to MOTIF 1-5...")
    # The sampler uses CONFOUND_PROB for every split, so report exactly that
    # value (the log used to claim a fixed "10%" for the test split).
    conf_prob_display = f"{CONFOUND_PROB*100:.0f}%"
    bg_prob = BACKGROUND_CONFOUNDER_PROB_TRAIN if split != 'test' else BACKGROUND_CONFOUNDER_PROB_TEST
    bg_prob_display = f"{bg_prob*100:.0f}%"
    print(f"Causal Confounder Prob: {conf_prob_display}, Background Confounder Prob: {bg_prob_display}")
    print("Causal motifs: motif_generators (IDs 6-10)")
    print("Confounder motifs: motif_generators (IDs 1-5)")

    consecutive_failures = 0
    last_error = None
    # The context manager closes the bar even when generation is aborted.
    with tqdm(total=num_samples, desc=f'Generating {split.upper()} Samples') as pbar:
        while len(label_list) < num_samples:
            try:
                sample = _build_sample(split, feature_noise, edge_del_prob, edge_add_prob)
            except Exception as e:
                print(f"\nError generating a sample, skipping. Error: {e}")
                traceback.print_exc()
                last_error = e
                sample = None

            if sample is None:
                consecutive_failures += 1
                if consecutive_failures >= max_consecutive_failures:
                    # A systematic problem (bad arguments, broken generator)
                    # would otherwise make this loop spin forever.
                    raise RuntimeError(
                        f"Aborting '{split}' dataset generation: "
                        f"{consecutive_failures} consecutive samples failed or were empty."
                    ) from last_error
                continue

            consecutive_failures = 0
            node_feats, edge_idx, label, gt, role_id, pos_arr = sample
            node_features_list.append(node_feats)
            edge_index_list.append(edge_idx)
            label_list.append(label)
            ground_truth_list.append(gt)
            role_id_list.append(role_id)
            pos_list.append(pos_arr)
            pbar.update(1)

    if node_features_list:
        avg_nodes = np.mean([nf.shape[0] for nf in node_features_list])
        avg_edges = np.mean([ei.shape[1] for ei in edge_index_list])
        print(f"\nDataset generation for '{split}' complete.")
        print(f"# Graphs: {len(node_features_list)}")
        print(f"Avg Nodes: {avg_nodes:.2f}")
        print(f"Avg Edges: {avg_edges:.2f}")

        label_counts = {}
        for sample_label in label_list:
            label_counts[sample_label] = label_counts.get(sample_label, 0) + 1
        print(f"Label distribution: {label_counts}")

    save_path = osp.join(data_dir, output_filename)
    # A dict is stored as a pickled object array; load it with allow_pickle=True.
    np.save(save_path, {
        'node_features': node_features_list,
        'edge_index': edge_index_list,
        'label': label_list,
        'ground_truth': ground_truth_list,
        'role_id': role_id_list,
        'pos': pos_list
    })
    print(f"Saved {split} dataset to: {save_path}")

def main():
    """Generate the train (1500), val (200) and test (200) datasets, in that order.

    Every file is named ``<split>_mot_conf_<CONFOUND_PROB>.npy``.
    """
    split_sizes = (
        ('train', 1500),
        ('val', 200),
        ('test', 200),
    )
    for split_name, sample_count in split_sizes:
        generate_mol_conf_dataset(
            num_samples=sample_count,
            output_filename=f'{split_name}_mot_conf_{CONFOUND_PROB}.npy',
            split=split_name,
        )

if __name__ == '__main__':
    # Seed the three RNGs (random, numpy, torch) with the same fixed value
    # before any data is generated.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(42)

    main()
