import os
import os.path as osp
import random
from tqdm import tqdm
import traceback
import json

import numpy as np
import networkx as nx
import torch

try:
    from molecular import generate_real_dataset, molecular_generators, motif_connectors
    from BA3_loc import find_gd
except ImportError:
    print("Error: Unable to import 'molecular' or 'BA3_loc' modules.")
    print("Please ensure molecular.py and BA3_loc.py files are in the same directory as this script.")
    exit()

# Per-dimension location / scale of the 5-d Gaussian used for node features.
# Both vectors are passed to the molecule generators and to np.random.normal.
_MEAN = np.asarray((1.5, 2.0, 1.2, 1.3, 1.8), dtype=np.float32)
_STD = np.asarray((1.5, 2.0, 1.2, 1.3, 1.8), dtype=np.float32)

# Display name for every supported molecule ID (1-5).
ALL_MOLECULAR_TYPES = {mol_id: f"Molecule{mol_id}" for mol_id in range(1, 6)}

# Experiment configurations; populated at import time further down the file.
MOLECULAR_EXPERIMENTS = list()

# When True, multi-molecule causal samples are drawn from the fixed
# label -> molecule-pair table below.
BALANCED_MODE = True
# Label i (0-4) is realised by the i-th pair of molecule IDs.
BALANCED_LABEL_TO_MOLS = dict(enumerate([
    [1, 2],
    [2, 3],
    [3, 4],
    [4, 5],
    [1, 5],
]))

def generate_balanced_dataset(exclude_molecules=None, allowed_causal_molecules=None):
    """Build one sample whose label comes from ``BALANCED_LABEL_TO_MOLS``.

    A label is eligible when none of its molecules is in ``exclude_molecules``
    and, if ``allowed_causal_molecules`` is given, all of its molecules are in
    that set. One eligible label is picked uniformly at random and the
    molecules mapped to it are generated and joined into a single graph.

    Args:
        exclude_molecules: iterable of molecule IDs that must not appear
            (``None`` means no exclusion).
        allowed_causal_molecules: iterable of molecule IDs the label's
            molecules must be drawn from (``None`` means no restriction).

    Returns:
        ``(G_total, role_id_total, label, edges_arr, presence_map)`` where
        ``edges_arr`` is a ``(2, E)`` int64 array and ``presence_map`` maps
        molecule IDs 1-5 to whether they were requested for this sample.

    Raises:
        RuntimeError: if no label satisfies the constraints.
    """
    excluded = set() if exclude_molecules is None else set(exclude_molecules)
    allowed = None if allowed_causal_molecules is None else set(allowed_causal_molecules)

    def _is_eligible(mols):
        members = set(mols)
        if members & excluded:
            return False
        return allowed is None or members.issubset(allowed)

    candidate_labels = [
        lab for lab, mols in BALANCED_LABEL_TO_MOLS.items() if _is_eligible(mols)
    ]
    if not candidate_labels:
        raise RuntimeError("No available labels in Balanced mode")
    label = random.choice(candidate_labels)

    presence_map = dict.fromkeys(range(1, 6), False)
    G_total = None
    role_id_total = []
    for position, mol_id in enumerate(BALANCED_LABEL_TO_MOLS[label]):
        try:
            build_molecule = molecular_generators[mol_id][0]
            mol_G, mol_role = build_molecule(
                size=random.randint(18, 30),
                node_feature_mean=_MEAN,
                std=_STD,
            )
        except Exception:
            # Generator failed: substitute a 6-node path with Gaussian features
            # and use node degrees as roles.
            mol_G = nx.path_graph(6)
            for node in mol_G.nodes():
                mol_G.nodes[node]['feature'] = np.random.normal(
                    loc=_MEAN, scale=_STD
                ).astype(np.float32)
            mol_role = [mol_G.degree(node) for node in mol_G.nodes()]

        # The molecule is flagged present even when the fallback graph was used.
        presence_map[mol_id] = True

        if position == 0:
            G_total = mol_G.copy()
            role_id_total = mol_role.copy() if isinstance(mol_role, list) else list(mol_role)
            continue

        try:
            G_total, role_id_total, _, _ = motif_connectors[1][0](G_total, mol_G)
        except Exception:
            # Connector failed: keep both parts as a disjoint union instead.
            G_total = nx.union(G_total, mol_G, rename=('A-', 'B-'))
            G_total = nx.convert_node_labels_to_integers(G_total)
            role_id_total = [G_total.degree(node) for node in G_total.nodes()]

    edges_arr = (
        np.array(list(G_total.edges()), dtype=np.int64).T
        if G_total.number_of_edges() > 0
        else np.array([[], []], dtype=np.int64)
    )
    return G_total, role_id_total, label, edges_arr, presence_map

# Register two experiment configurations per molecule ID (1-5):
#   * 'causal_<id>'   - the molecule is the only causal motif, no confounders;
#   * 'confound_<id>' - the molecule is the confounder, the other four are causal.
for mol_id in range(1, 6):
    other_molecules = [mid for mid in range(1, 6) if mid != mol_id]
    MOLECULAR_EXPERIMENTS.extend((
        dict(
            name=f'causal_{mol_id}',
            type='causal',
            molecular_id=mol_id,
            file_suffix=f'mol_causal_{mol_id}_no_conf',
            target_causal_molecules=[mol_id],
            confounder_pool_ids=[],
            confound_prob=0.0,
            background_confounder_prob_train=0.0,
            background_confounder_prob_test=0.0,
            description=f'{ALL_MOLECULAR_TYPES[mol_id]} as causal information (no confounding)',
        ),
        dict(
            name=f'confound_{mol_id}',
            type='confounding',
            molecular_id=mol_id,
            file_suffix=f'mol_confound_{mol_id}_others_causal',
            target_causal_molecules=other_molecules,
            confounder_pool_ids=[mol_id],
            confound_prob=0.7,
            background_confounder_prob_train=0.1,
            background_confounder_prob_test=0.1,
            description=f'{ALL_MOLECULAR_TYPES[mol_id]} as confounding information, other molecules as causal information',
        ),
    ))

def attach_confounders_to_graph(G, num_confounders, confounder_pool_ids=None):
    """Attach up to ``num_confounders`` randomly chosen confounder motifs to ``G``.

    Each attempt draws a molecule ID from ``confounder_pool_ids``, generates
    the motif with ``molecular_generators`` and joins it to the current graph
    with ``motif_connectors[1][0]``. Failed attempts are reported with a
    warning and skipped, so fewer motifs than requested may be attached.

    The IDs that were attached are recorded in the returned graph's
    ``graph['attached_confounders']`` list. The list is carried over
    explicitly after every connection, so it is present even if the
    connector returns a new graph object.

    Args:
        G: graph to extend. Its ``graph`` dict is created if missing and gains
            the ``'attached_confounders'`` key (the input is mutated).
        num_confounders: number of attachment attempts.
        confounder_pool_ids: molecule IDs to draw from; ``None`` means 1-5.
            An empty pool attaches nothing.

    Returns:
        The graph with confounders attached (``G`` itself when nothing was
        attached).
    """
    if confounder_pool_ids is None:
        confounder_pool_ids = list(range(1, 6))

    new_G = G
    if not hasattr(new_G, 'graph') or not isinstance(new_G.graph, dict):
        new_G.graph = {}
    attached = new_G.graph.setdefault('attached_confounders', [])

    # random.choice raises IndexError on an empty sequence; with no pool there
    # is nothing to attach, so return early instead of failing per attempt.
    if not confounder_pool_ids:
        return new_G

    for _ in range(num_confounders):
        # Bound before the try so the warning below can always format it.
        confounder_id = None
        try:
            confounder_id = random.choice(confounder_pool_ids)

            if confounder_id < 1 or confounder_id > 5:
                print(f"Warning: Skipping confounder molecule {confounder_id} - only molecules 1-5 are supported")
                continue

            confounder_motif, _ = molecular_generators[confounder_id][0](
                size=random.randint(18, 30),
                node_feature_mean=_MEAN,
                std=_STD
            )

            if confounder_motif.number_of_nodes() == 0 or new_G.number_of_nodes() == 0:
                continue

            new_G, _, _, _ = motif_connectors[1][0](new_G, confounder_motif)
            attached.append(confounder_id)
            # Re-install the bookkeeping list: the connector may have returned
            # a graph whose ``graph`` dict lacks (or holds a copy of) it.
            new_G.graph['attached_confounders'] = attached

        except Exception as e:
            print(f"Warning: Failed to attach a confounder molecule (ID: {confounder_id}). Error: {e}")
            continue

    return new_G

def _edge_index_from_graph(G):
    """Return the edges of ``G`` as a ``(2, E)`` int64 NumPy array."""
    if G.number_of_edges() == 0:
        return np.array([[], []], dtype=np.int64)
    # list(G.edges()) is (E, 2); transposing always yields the (2, E) layout.
    return torch.tensor(list(G.edges())).t().contiguous().cpu().numpy()


def _fallback_path_graph(num_nodes):
    """Build a path graph with Gaussian node features; roles are node degrees."""
    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from((i, i + 1) for i in range(num_nodes - 1))
    for node in G.nodes():
        G.nodes[node]['feature'] = np.random.normal(loc=_MEAN, scale=_STD).astype(np.float32)
    return G, [G.degree(node) for node in G.nodes()]


def generate_real_dataset_with_specific_molecules(target_molecule_ids=None):
    """Generate one graph sample built from the requested molecules.

    Behaviour depends on ``target_molecule_ids``:

    * ``None`` - delegate to ``generate_real_dataset()`` and repackage its five
      presence flags into a dict.
    * one ID - generate that molecule; the label is ``(mol_id - 1) % 5``. If
      the generator fails, a 6-node path graph is used and the molecule is
      *not* flagged as present.
    * several IDs - generate and join at most the first three molecules that
      succeed. The label is drawn uniformly from 0-4 and is independent of
      the molecules used. If nothing could be generated, the first ID is
      retried and, failing that, an 8-node path graph is used.

    Args:
        target_molecule_ids: ``None`` or a sequence of molecule IDs.

    Returns:
        ``(G, role_id, label, edge_index, molecule_presence)`` where
        ``edge_index`` is a ``(2, E)`` int64 array and ``molecule_presence``
        maps molecule IDs 1-5 to whether that molecule ended up in ``G``.
    """
    if target_molecule_ids is None:
        G, role_id, label, edge_index, motif1_present, motif2_present, \
        motif3_present, motif4_present, motif5_present = generate_real_dataset()

        molecule_presence = {
            1: motif1_present,
            2: motif2_present,
            3: motif3_present,
            4: motif4_present,
            5: motif5_present
        }
        return G, role_id, label, edge_index, molecule_presence

    if len(target_molecule_ids) == 1:
        mol_id = target_molecule_ids[0]
        label = (mol_id - 1) % 5
        molecule_presence = {i: False for i in range(1, 6)}
        try:
            G, role_id = molecular_generators[mol_id][0](
                size=random.randint(18, 30),
                node_feature_mean=_MEAN,
                std=_STD
            )
            molecule_presence[mol_id] = True
        except Exception as e:
            print(f"Warning: fallback generation for molecule {mol_id}: {e}")
            G, role_id = _fallback_path_graph(6)

        return G, role_id, label, _edge_index_from_graph(G), molecule_presence

    y = random.choice([0, 1, 2, 3, 4])
    molecule_presence = {i: False for i in range(1, 6)}
    G = nx.Graph()
    role_id = []
    molecules_used = 0
    for mol_id in target_molecule_ids:
        if molecules_used >= 3:
            break
        try:
            mol_G, mol_role = molecular_generators[mol_id][0](
                size=random.randint(18, 30),
                node_feature_mean=_MEAN,
                std=_STD
            )
            if molecules_used == 0:
                G = mol_G.copy()
                role_id = mol_role.copy() if isinstance(mol_role, list) else list(mol_role)
            else:
                try:
                    G, role_id, _, _ = motif_connectors[1][0](G, mol_G)
                except Exception:
                    # Connector failed: fall back to a disjoint union. Build the
                    # result fully before rebinding G so a failure here leaves
                    # the graph assembled so far untouched.
                    merged = nx.convert_node_labels_to_integers(
                        nx.union(G, mol_G, rename=('G-', 'mol-'))
                    )
                    role_id = [merged.degree(node) for node in merged.nodes()]
                    G = merged
        except Exception as e:
            print(f"Warning: Failed to generate molecule {mol_id}: {e}")
            continue
        # Only flag the molecule once it is actually part of G.
        molecule_presence[mol_id] = True
        molecules_used += 1

    if G.number_of_nodes() == 0:
        try:
            fallback_mol = target_molecule_ids[0]
            G, role_id = molecular_generators[fallback_mol][0](
                size=random.randint(18, 30),
                node_feature_mean=_MEAN,
                std=_STD
            )
            molecule_presence[fallback_mol] = True
        except Exception as e:
            print(f"Warning: fallback generation failed, using a path graph: {e}")
            G, role_id = _fallback_path_graph(8)

    return G, role_id, y, _edge_index_from_graph(G), molecule_presence

def generate_dataset_with_confounder(split='train', confound_prob=0.6, confounder_pool_ids=None,
                                     background_confounder_prob_train=0.20,
                                     background_confounder_prob_test=0.05,
                                     target_causal_molecules=None):
    """Generate one causal sample and randomly attach confounder motifs to it.

    The causal part is produced by one of three strategies:

    * ``BALANCED_MODE`` with more than one target molecule - use
      ``generate_balanced_dataset`` (excluding the confounder pool), falling
      back to ``generate_real_dataset_with_specific_molecules`` on error;
    * no targets or exactly one target - use
      ``generate_real_dataset_with_specific_molecules`` directly;
    * otherwise - rejection-sample ``generate_real_dataset()`` up to 15 times
      for a graph that contains only allowed molecules (and not the first
      confounder ID), then fall back as above.

    Afterwards, if any molecule is present, 3-5 confounders are attached with
    probability ``confound_prob``; independently one extra "background"
    confounder is attached with the split-specific background probability
    (the test value for ``split == 'test'``, the train value otherwise).

    Returns:
        ``(G, role_id, label, edge_index)``. ``role_id`` and ``edge_index``
        come from the causal generation step and are not recomputed after
        confounders are attached.
    """
    pool = list(range(1, 6)) if confounder_pool_ids is None else confounder_pool_ids

    def _rejection_sample():
        """Return an accepted 5-tuple from generate_real_dataset(), or None."""
        permitted = set(target_causal_molecules)
        banned_id = pool[0] if pool else None
        for _ in range(15):
            try:
                (raw_G, raw_roles, raw_label, raw_edges,
                 p1, p2, p3, p4, p5) = generate_real_dataset()
                present = {1: p1, 2: p2, 3: p3, 4: p4, 5: p5}
                if banned_id is not None and present.get(banned_id, False):
                    continue
                active = {mid for mid, flag in present.items() if flag}
                # Accept only when at least one molecule is present and every
                # present molecule is an allowed causal one.
                if active & permitted and active <= permitted:
                    return raw_G, raw_roles, raw_label, raw_edges, present
            except Exception:
                continue
        return None

    if BALANCED_MODE and target_causal_molecules and len(target_causal_molecules) > 1:
        try:
            G, role_id, label, edge_index, molecule_presence = generate_balanced_dataset(
                exclude_molecules=pool or [],
                allowed_causal_molecules=target_causal_molecules,
            )
        except Exception as e:
            print("Balanced generation failed, fallback to original logic:", e)
            G, role_id, label, edge_index, molecule_presence = \
                generate_real_dataset_with_specific_molecules(target_causal_molecules)
    elif target_causal_molecules is None or len(target_causal_molecules) == 1:
        G, role_id, label, edge_index, molecule_presence = \
            generate_real_dataset_with_specific_molecules(target_causal_molecules)
    else:
        accepted = _rejection_sample()
        if accepted is None:
            accepted = generate_real_dataset_with_specific_molecules(target_causal_molecules)
        G, role_id, label, edge_index, molecule_presence = accepted

    # Main confounding: only when a causal molecule is present; the random
    # draw is skipped otherwise (short-circuit).
    if any(molecule_presence.values()) and random.random() < confound_prob:
        how_many = random.randint(3, 5)
        G = attach_confounders_to_graph(G, how_many, pool)

    # Background confounding, with a split-dependent probability.
    if split == 'test':
        background_prob = background_confounder_prob_test
    else:
        background_prob = background_confounder_prob_train
    if random.random() < background_prob:
        G = attach_confounders_to_graph(G, 1, pool)

    return G, role_id, label, edge_index

def _node_feature_value(attrs):
    """Return the feature stored under 'feature', 'features' or 'feat' (first match)."""
    for key in ('feature', 'features', 'feat'):
        if key in attrs:
            return attrs[key]
    raise KeyError('feat')


def generate_mol_conf_dataset(
    num_samples: int,
    output_filename: str,
    split: str,
    experiment_config: dict = None,
    feature_noise: float = 0.05,
    edge_del_prob: float = 0.02,
    edge_add_prob: float = 0.02
):
    """Generate ``num_samples`` perturbed graph samples and save them to disk.

    Every sample comes from ``generate_dataset_with_confounder`` and is then
    perturbed: Gaussian noise is added to the node features, a fraction of
    edges is removed and a comparable number of random edges is added. Nodes
    are relabelled to ``0..n-1``; roles are node degrees and the ground truth
    is computed by ``find_gd``. Attempts that raise or yield an empty graph
    are discarded and retried.

    The result is written with ``np.save`` to ``./data/analysis/<output_filename>``
    as a dict of per-sample lists with keys ``node_features``, ``edge_index``,
    ``label``, ``ground_truth``, ``role_id`` and ``pos``. Because a dict is
    saved, NumPy pickles it; presumably it must be read back with
    ``np.load(..., allow_pickle=True)`` - confirm against the consumer.

    Args:
        num_samples: number of samples to generate.
        output_filename: file name inside ``./data/analysis/``.
        split: split name; ``'test'`` selects the test background probability.
        experiment_config: dict with optional keys ``confound_prob``,
            ``confounder_pool_ids``, ``background_confounder_prob_train``,
            ``background_confounder_prob_test`` and ``target_causal_molecules``.
            ``None`` selects a default configuration.
        feature_noise: standard deviation of the additive feature noise.
        edge_del_prob: fraction of edges to delete.
        edge_add_prob: fraction (of the original edge count) of random edge
            insertions to attempt.

    Raises:
        RuntimeError: if 1000 attempts in a row fail, which indicates a
            systematic generation problem rather than an unlucky sample.
    """
    if experiment_config is None:
        experiment_config = {
            'confound_prob': 0.6,
            'confounder_pool_ids': list(range(1, 6)),
            'background_confounder_prob_train': 0.20,
            'background_confounder_prob_test': 0.05,
            'target_causal_molecules': None
        }

    confound_prob = experiment_config.get('confound_prob', 0.7)
    confounder_pool_ids = experiment_config.get('confounder_pool_ids', list(range(1, 6)))
    bg_prob_train = experiment_config.get('background_confounder_prob_train', 0.20)
    bg_prob_test = experiment_config.get('background_confounder_prob_test', 0.05)
    target_causal_molecules = experiment_config.get('target_causal_molecules', None)

    data_dir = './data/analysis/'
    os.makedirs(data_dir, exist_ok=True)

    node_features_list = []
    edge_index_list = []
    label_list = []
    ground_truth_list = []
    role_id_list = []
    pos_list = []

    debug_label_conf_counts = {l: 0 for l in range(5)}
    debug_label_counts = {l: 0 for l in range(5)}

    # Guard against spinning forever when generation is systematically broken.
    max_consecutive_failures = 1000
    consecutive_failures = 0
    generated_count = 0

    pbar = tqdm(total=num_samples, desc=f'Generating {split.upper()} Samples')
    try:
        while generated_count < num_samples:
            if consecutive_failures >= max_consecutive_failures:
                raise RuntimeError(
                    f"Aborting {split} generation: {consecutive_failures} consecutive "
                    f"failed attempts ({generated_count}/{num_samples} samples generated)"
                )
            try:
                # role_id / edge_index returned here are recomputed below from
                # the perturbed graph, so they are ignored.
                G, _, label, _ = generate_dataset_with_confounder(
                    split=split,
                    confound_prob=confound_prob,
                    confounder_pool_ids=confounder_pool_ids,
                    background_confounder_prob_train=bg_prob_train,
                    background_confounder_prob_test=bg_prob_test,
                    target_causal_molecules=target_causal_molecules
                )

                if G.number_of_nodes() == 0:
                    consecutive_failures += 1
                    continue

                # Two separate passes on purpose: merging them would interleave
                # the random draws and change the output for a given seed.
                for n, attrs in G.nodes(data=True):
                    if not any(k in attrs for k in ('feature', 'features', 'feat')):
                        G.nodes[n]['feature'] = np.random.normal(
                            loc=_MEAN, scale=_STD
                        ).astype(np.float32)

                for n, attrs in G.nodes(data=True):
                    orig = _node_feature_value(attrs)
                    if not isinstance(orig, np.ndarray):
                        orig = np.array(orig, dtype=np.float32)
                    noise = np.random.normal(0.0, feature_noise, size=orig.shape).astype(np.float32)
                    G.nodes[n]['feature'] = orig + noise

                edges = list(G.edges())
                if edges:
                    num_del = int(edge_del_prob * len(edges))
                    if num_del > 0:
                        G.remove_edges_from(random.sample(edges, min(num_del, len(edges))))

                    nodes = list(G.nodes())
                    if len(nodes) > 1:
                        # The insertion budget is based on the pre-deletion edge count.
                        for _ in range(int(edge_add_prob * len(edges))):
                            u, v = random.sample(nodes, 2)
                            if not G.has_edge(u, v):
                                G.add_edge(u, v)

                G = nx.convert_node_labels_to_integers(G, first_label=0)

                if G.number_of_edges() > 0:
                    # (E, 2) edge list transposed to the (2, E) layout.
                    edge_idx = (
                        torch.tensor(list(G.edges())).t().contiguous().cpu().numpy()
                    ).astype(np.int64)
                else:
                    edge_idx = np.array([[], []], dtype=np.int64)

                node_ids = sorted(G.nodes())
                role_id = np.array([G.degree(i) for i in node_ids], dtype=np.int64)

                if edge_idx.shape[1] > 0:
                    gt = find_gd(edge_idx, role_id)
                else:
                    gt = np.array([], dtype=np.float64)

                features = []
                for i in node_ids:
                    feat = _node_feature_value(G.nodes[i])
                    if not isinstance(feat, np.ndarray):
                        feat = np.array(feat, dtype=np.float32)
                    if feat.ndim == 0:
                        feat = np.array([feat], dtype=np.float32)
                    elif feat.ndim == 1 and len(feat) != len(_MEAN):
                        # Wrong dimensionality: replace with a fresh random feature.
                        feat = np.random.normal(loc=_MEAN, scale=_STD).astype(np.float32)
                    features.append(feat)

                # G is non-empty here (checked above), so there is at least one row.
                node_feats = np.vstack(features).astype(np.float32)
                pos_arr = np.array(list(nx.spring_layout(G).values()), dtype=np.float32)

                label_value = int(label)
                # A label outside 0-4 raises KeyError and the sample is rejected.
                if G.graph.get('attached_confounders'):
                    debug_label_conf_counts[label] += 1
                debug_label_counts[label] += 1

                node_features_list.append(node_feats)
                edge_index_list.append(edge_idx)
                label_list.append(label_value)
                ground_truth_list.append(gt)
                role_id_list.append(role_id)
                pos_list.append(pos_arr)

                generated_count += 1
                consecutive_failures = 0
                pbar.update(1)

            except Exception as e:
                consecutive_failures += 1
                print(f"\nError generating a sample, skipping. Error: {e}")
                traceback.print_exc()
                continue
    finally:
        pbar.close()

    # Built unconditionally so the summary below also works for an empty dataset.
    label_counts = {}
    for label in label_list:
        label_counts[label] = label_counts.get(label, 0) + 1
    print(f"Label distribution: {label_counts}")
    if BALANCED_MODE:
        print(f"BALANCED_MODE=True effective labels: {sorted(label_counts.keys())}")
    attach_rates = {l: (debug_label_conf_counts[l] / debug_label_counts[l] if debug_label_counts[l] > 0 else 0)
            for l in range(5)}
    print(f"Confounder attach counts per label: {debug_label_conf_counts}")
    print(f"Per-label attach rate: {attach_rates}")

    save_path = osp.join(data_dir, output_filename)
    np.save(save_path, {
        'node_features': node_features_list,
        'edge_index': edge_index_list,
        'label': label_list,
        'ground_truth': ground_truth_list,
        'role_id': role_id_list,
        'pos': pos_list
    })
    print(f"Saved {split} dataset to: {save_path}")

def run_systematic_experiments(experiments=None, split_sizes=None):
    """Generate the train/val/test datasets for every experiment configuration.

    For each experiment, one dataset per split is written via
    ``generate_mol_conf_dataset`` to ``<split>_<file_suffix>.npy``. If any
    split of an experiment fails, the error is reported, the remaining splits
    of that experiment are skipped and processing continues with the next
    experiment.

    Args:
        experiments: iterable of experiment config dicts (each needs a
            ``'file_suffix'`` key). ``None`` means ``MOLECULAR_EXPERIMENTS``.
        split_sizes: iterable of ``(split_name, num_samples)`` pairs. ``None``
            means 3000 train, 375 val and 375 test samples.

    Returns:
        List with the names of the experiments that failed (empty when all
        succeeded).
    """
    if experiments is None:
        experiments = MOLECULAR_EXPERIMENTS
    if split_sizes is None:
        split_sizes = (('train', 3000), ('val', 375), ('test', 375))

    failed = []
    for exp_config in experiments:
        try:
            for split, num_samples in split_sizes:
                generate_mol_conf_dataset(
                    num_samples=num_samples,
                    output_filename=f"{split}_{exp_config['file_suffix']}.npy",
                    split=split,
                    experiment_config=exp_config
                )
        except Exception as e:
            # .get keeps the error report itself from failing on a nameless config.
            exp_name = exp_config.get('name', '<unnamed>')
            print(f"✗ Error in experiment {exp_name}: {str(e)}")
            print("Continuing with next experiment...")
            failed.append(exp_name)
            continue

    return failed

if __name__ == '__main__':
    # Seed the Python, NumPy and torch RNGs (in that order) with the same
    # fixed value before generating anything.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(42)

    run_systematic_experiments()
