import os
import os.path as osp
import random
from tqdm import tqdm
import traceback

import numpy as np
import networkx as nx
import torch

# The generator relies on two project-local modules. Fail early, with a
# readable hint, if they cannot be imported.
try:
    from molecular import generate_real_dataset, molecular_generators, motif_connectors
    from BA3_loc import find_gd
except ImportError as import_error:
    print("Error: Unable to import 'molecular' or 'BA3_loc' modules.")
    print("Please ensure molecular.py and BA3_loc.py files are in the same directory as this script.")
    # Show the real cause too: an ImportError raised *inside* one of those
    # modules (e.g. a missing dependency) would otherwise look like a
    # missing-file problem.
    print(f"Underlying error: {import_error}")
    # Bare exit() is a helper injected by the site module and terminates with
    # status 0; raise SystemExit with a non-zero code so callers see a failure.
    raise SystemExit(1)


# Probability that confounder motifs are attached to a graph for which at
# least one motif-present flag is set (used by generate_dataset_with_confounder
# for every split).
CONFOUND_PROB = 0.8
# Per-dimension mean / standard deviation of node features. They are passed to
# the motif generators and are also used to sample features for nodes that
# carry none (or a vector of the wrong length).
# NOTE(review): _STD holds exactly the same values as _MEAN — confirm that
# this is intended and not a copy-paste slip.
_MEAN = np.array([1.5, 2.0, 1.2, 1.3, 1.8], dtype=np.float32)
_STD  = np.array([1.5, 2.0, 1.2, 1.3, 1.8], dtype=np.float32)


# Keys into molecular_generators from which a confounder motif is drawn
# uniformly at random (6..26 inclusive).
CONFOUNDER_POOL_IDS = list(range(6, 27))
# Inclusive (min, max) bounds handed to random.randint for the number of
# confounders attached when the CONFOUND_PROB draw succeeds.
NUM_CONFOUNDERS_RANGE = (2, 5)
# Probability of attaching one extra confounder regardless of the motif flags.
# The TRAIN value is used for every split other than 'test'.
BACKGROUND_CONFOUNDER_PROB_TRAIN = 0.20
BACKGROUND_CONFOUNDER_PROB_TEST = 0.05

def attach_confounders_to_graph(G, num_confounders):
    """Attach up to ``num_confounders`` random confounder motifs to ``G``.

    For each attempt a generator ID is drawn from ``CONFOUNDER_POOL_IDS``, a
    motif with a random ``size`` in [8, 20] is generated, and the motif is
    joined to the current graph with ``motif_connectors[1][0]``. Attempts are
    best-effort: a failing attempt is reported and skipped.

    Args:
        G: Graph to extend. NOTE(review): whether the connector mutates ``G``
            in place or returns a new graph is not visible here — confirm
            before relying on ``G`` being untouched.
        num_confounders: Number of attachment attempts; values <= 0 return
            ``G`` unchanged.

    Returns:
        The graph returned by the last successful connector call, or ``G``
        itself if nothing was attached.
    """
    new_G = G
    for _ in range(num_confounders):
        # Draw the ID outside the try block so that the warning below can
        # always name it; previously a failure before the assignment would
        # have raised UnboundLocalError from inside the handler.
        confounder_id = random.choice(CONFOUNDER_POOL_IDS)
        try:
            confounder_motif, _ = molecular_generators[confounder_id][0](
                size=random.randint(8, 20),
                node_feature_mean=_MEAN,
                std=_STD,
            )

            # Nothing sensible to connect if either side is empty.
            if confounder_motif.number_of_nodes() == 0 or new_G.number_of_nodes() == 0:
                continue

            # The connector returns a 4-tuple; only the merged graph is kept.
            new_G, _, _, _ = motif_connectors[1][0](new_G, confounder_motif)
        except Exception as e:
            # Deliberate best-effort: one bad motif must not abort the sample.
            print(f"Warning: Failed to attach a confounder motif (ID: {confounder_id}). Error: {e}")

    return new_G

def generate_dataset_with_confounder(split='train'):
    """Generate one base graph and randomly decorate it with confounder motifs.

    A graph is obtained from ``generate_real_dataset``. If any of its five
    motif-present flags is truthy, then with probability ``CONFOUND_PROB``
    between ``NUM_CONFOUNDERS_RANGE[0]`` and ``NUM_CONFOUNDERS_RANGE[1]``
    confounders are attached. Independently, one more confounder is attached
    with the background probability of the split.

    Note that ``split`` only selects the background probability; the
    motif-conditioned probability is ``CONFOUND_PROB`` for every split.

    Args:
        split: ``'test'`` selects ``BACKGROUND_CONFOUNDER_PROB_TEST``; any
            other value selects ``BACKGROUND_CONFOUNDER_PROB_TRAIN``.

    Returns:
        ``(G, role_id, label, edge_index)`` where ``G`` is the possibly
        extended graph and the other three values are passed through
        unchanged from ``generate_real_dataset``.
    """
    (graph, role_id, label, edge_index,
     has_m1, has_m2, has_m3, has_m4, has_m5) = generate_real_dataset()

    any_motif = any((has_m1, has_m2, has_m3, has_m4, has_m5))

    # The random draw is only consumed when a motif flag is set.
    if any_motif:
        if random.random() < CONFOUND_PROB:
            how_many = random.randint(*NUM_CONFOUNDERS_RANGE)
            graph = attach_confounders_to_graph(graph, how_many)

    if split == 'test':
        background_prob = BACKGROUND_CONFOUNDER_PROB_TEST
    else:
        background_prob = BACKGROUND_CONFOUNDER_PROB_TRAIN

    if random.random() < background_prob:
        graph = attach_confounders_to_graph(graph, 1)

    return graph, role_id, label, edge_index

def _get_node_feature(attrs):
    """Return the feature stored in a node-attribute dict.

    The keys 'feature', 'features' and 'feat' are tried in that order; a
    KeyError for 'feat' is raised if none of them is present.
    """
    if 'feature' in attrs:
        return attrs['feature']
    if 'features' in attrs:
        return attrs['features']
    return attrs['feat']


def _perturb_node_features(G, feature_noise):
    """Give every node a 'feature' entry and add Gaussian noise to it (in place)."""
    # Two separate passes on purpose: merging them would interleave the
    # np.random draws differently and change the data produced for a seed.
    for n, attrs in G.nodes(data=True):
        if not any(k in attrs for k in ('feature', 'features', 'feat')):
            G.nodes[n]['feature'] = np.random.normal(
                loc=_MEAN, scale=_STD
            ).astype(np.float32)

    for n, attrs in G.nodes(data=True):
        orig = _get_node_feature(attrs)
        if not isinstance(orig, np.ndarray):
            orig = np.array(orig, dtype=np.float32)
        noise = np.random.normal(0.0, feature_noise, size=orig.shape).astype(np.float32)
        # The noisy value is always stored under 'feature', whichever key it
        # was read from.
        G.nodes[n]['feature'] = orig + noise


def _perturb_edges(G, edge_del_prob, edge_add_prob):
    """Randomly delete and add edges in place; both counts scale with the original edge count."""
    edges = list(G.edges())
    if not edges:
        return

    num_del = int(edge_del_prob * len(edges))
    if num_del > 0:
        G.remove_edges_from(random.sample(edges, min(num_del, len(edges))))

    nodes = list(G.nodes())
    if len(nodes) > 1:
        # len(edges) is the pre-deletion count.
        for _ in range(int(edge_add_prob * len(edges))):
            u, v = random.sample(nodes, 2)
            if not G.has_edge(u, v):
                G.add_edge(u, v)


def _edge_index_array(G):
    """Return the edges of ``G`` as a C-contiguous int64 array of shape (2, num_edges)."""
    # reshape(-1, 2) makes the edge-less case come out as shape (2, 0).
    return np.array(list(G.edges()), dtype=np.int64).reshape(-1, 2).T.copy()


def _stack_node_features(G):
    """Stack the per-node features of ``G`` (nodes in sorted order) into a float32 matrix."""
    features = []
    for i in sorted(G.nodes()):
        feat = _get_node_feature(G.nodes[i])
        if not isinstance(feat, np.ndarray):
            feat = np.array(feat, dtype=np.float32)
        if feat.ndim == 0:
            # NOTE(review): a scalar becomes a length-1 row, which np.vstack
            # rejects if other rows are longer (the sample is then skipped by
            # the caller) — confirm this is the desired handling.
            feat = np.array([feat], dtype=np.float32)
        elif feat.ndim == 1 and len(feat) != len(_MEAN):
            # Wrong-length vectors are replaced by a fresh random feature.
            feat = np.random.normal(loc=_MEAN, scale=_STD).astype(np.float32)
        features.append(feat)

    if features:
        return np.vstack(features).astype(np.float32)
    return np.array([], dtype=np.float32).reshape(0, len(_MEAN))


def _generate_sample(split, feature_noise, edge_del_prob, edge_add_prob):
    """Generate one perturbed graph and convert it to arrays.

    Returns ``None`` if the generator produced a graph without nodes,
    otherwise ``(node_feats, edge_idx, label, gt, role_id, pos_arr)``.
    """
    # NOTE(review): the role_id and edge_index returned by the generator are
    # discarded; role_id is recomputed below as the node degree — confirm that
    # find_gd is meant to receive degrees.
    G, _, label, _ = generate_dataset_with_confounder(split=split)
    if G.number_of_nodes() == 0:
        return None

    _perturb_node_features(G, feature_noise)
    _perturb_edges(G, edge_del_prob, edge_add_prob)

    # Relabel to 0..n-1 so that edge indices address rows of the feature matrix.
    G = nx.convert_node_labels_to_integers(G, first_label=0)
    edge_idx = _edge_index_array(G)

    role_id = np.array([G.degree(i) for i in sorted(G.nodes())], dtype=np.int64)

    if edge_idx.shape[1] > 0:
        gt = find_gd(edge_idx, role_id)
    else:
        gt = np.array([], dtype=np.float64)

    node_feats = _stack_node_features(G)
    pos_arr = np.array(list(nx.spring_layout(G).values()), dtype=np.float32)

    return node_feats, edge_idx, int(label), gt, role_id, pos_arr


def generate_mol_conf_dataset(
    num_samples: int,
    output_filename: str,
    split: str,
    feature_noise: float = 0.05,
    edge_del_prob: float = 0.02,
    edge_add_prob: float = 0.02,
    data_dir: str = './data/motif_new_conf/',
    max_consecutive_failures: int = 1000,
) -> None:
    """Generate ``num_samples`` perturbed graphs and save them as one ``.npy`` file.

    Each sample is produced by ``generate_dataset_with_confounder`` and then
    perturbed: Gaussian noise is added to the node features and a fraction of
    edges is removed / added at random. Samples that raise or come back empty
    are skipped and regenerated.

    The result is a dict of per-sample lists with the keys 'node_features',
    'edge_index', 'label', 'ground_truth', 'role_id' and 'pos'. It is written
    with ``np.save``, which pickles the dict; read it back with
    ``np.load(path, allow_pickle=True).item()``.

    Args:
        num_samples: Number of samples to generate.
        output_filename: File name created inside ``data_dir``.
        split: Split name; forwarded to the generator and used in messages.
        feature_noise: Standard deviation of the noise added to node features.
        edge_del_prob: Fraction of the original edges to delete.
        edge_add_prob: Fraction of the original edge count to try to add as
            new random edges.
        data_dir: Output directory, created if missing.
        max_consecutive_failures: Abort after this many failed attempts in a
            row instead of retrying forever.

    Raises:
        RuntimeError: If ``max_consecutive_failures`` attempts fail in a row.
    """
    os.makedirs(data_dir, exist_ok=True)

    node_features_list = []
    edge_index_list = []
    label_list = []
    ground_truth_list = []
    role_id_list = []
    pos_list = []

    print(f"\nGenerating {split} dataset with ENHANCED confounders...")
    # generate_dataset_with_confounder applies CONFOUND_PROB to every split and
    # only varies the background probability, so report exactly that (the
    # banner used to claim a hard-coded 10% for the test split).
    bg_prob = BACKGROUND_CONFOUNDER_PROB_TRAIN if split != 'test' else BACKGROUND_CONFOUNDER_PROB_TEST
    conf_prob_display = f"{CONFOUND_PROB*100:.0f}%"
    bg_prob_display = f"{bg_prob*100:.0f}%"
    print(f"Causal Confounder Prob: {conf_prob_display}, Background Confounder Prob: {bg_prob_display}")

    generated_count = 0
    consecutive_failures = 0
    # The context manager closes the bar even if generation is aborted.
    with tqdm(total=num_samples, desc=f'Generating {split.upper()} Samples') as pbar:
        while generated_count < num_samples:
            sample = None
            try:
                sample = _generate_sample(split, feature_noise, edge_del_prob, edge_add_prob)
            except Exception as e:
                # Best-effort: report the broken sample and try again.
                print(f"\nError generating a sample, skipping. Error: {e}")
                traceback.print_exc()

            if sample is None:
                consecutive_failures += 1
                if consecutive_failures >= max_consecutive_failures:
                    raise RuntimeError(
                        f"Aborting {split} generation: {consecutive_failures} "
                        f"consecutive attempts produced no usable sample."
                    )
                continue
            consecutive_failures = 0

            node_feats, edge_idx, label, gt, role_id, pos_arr = sample
            node_features_list.append(node_feats)
            edge_index_list.append(edge_idx)
            label_list.append(label)
            ground_truth_list.append(gt)
            role_id_list.append(role_id)
            pos_list.append(pos_arr)

            generated_count += 1
            pbar.update(1)

    if node_features_list:
        avg_nodes = np.mean([nf.shape[0] for nf in node_features_list])
        avg_edges = np.mean([ei.shape[1] for ei in edge_index_list])
        print(f"Average nodes per graph: {avg_nodes:.2f}, average edges per graph: {avg_edges:.2f}")

        label_counts = {}
        for sample_label in label_list:
            label_counts[sample_label] = label_counts.get(sample_label, 0) + 1
        print(f"Label distribution: {label_counts}")

    save_path = osp.join(data_dir, output_filename)
    np.save(save_path, {
        'node_features': node_features_list,
        'edge_index': edge_index_list,
        'label': label_list,
        'ground_truth': ground_truth_list,
        'role_id': role_id_list,
        'pos': pos_list
    })
    print(f"Saved {split} dataset to: {save_path}")

def main():
    """Generate and save the train, validation and test datasets, in that order."""
    # (number of samples, output file name, split name)
    jobs = (
        (1500, f'train_mol_conf_{CONFOUND_PROB}.npy', 'train'),
        (200, f'val_mol_conf_{CONFOUND_PROB}.npy', 'val'),
        (200, f'test_mol_conf_{CONFOUND_PROB}.npy', 'test'),
    )

    for sample_count, filename, split_name in jobs:
        generate_mol_conf_dataset(
            num_samples=sample_count,
            output_filename=filename,
            split=split_name,
        )

if __name__ == '__main__':
    # Seed Python's, NumPy's and PyTorch's global RNGs (in that order) with
    # the same value before generating any data.
    seed = 42
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    main()
