# process_lohi.py (uses process_utils.mol_to_graph)

import os
import numpy as np
import pandas as pd
import pickle
import argparse
import warnings
import random
import torch

from rdkit import Chem
from tqdm import tqdm
import deepchem as dc
from deepchem.feat.molecule_featurizers import RDKitDescriptors

# HI/LO splitter
import lohi_splitter as lohi

# >>> Use your canonical graph/embedding utilities <<<
from process_utils import (
    mol_to_graph,          # builds x (int64, 2 cols) and edge_attr (int64, 2 cols)
    build_dataset,
    get_gnn_embedding,
    get_smi_ted_embedding,
)

# Process-wide noise reduction, applied once at import time.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is presumably meant to silence TensorFlow's
# native logging (TensorFlow may be pulled in by deepchem). It is set *after*
# the deepchem import above, so it may come too late to have an effect — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Suppress every Python warning for the rest of the process.
warnings.filterwarnings('ignore')

def set_args(argv=None):
    """Define the command-line interface and parse it.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv[1:]``,
            which is what this function always did; passing a list makes the
            parser usable from tests, notebooks or other scripts.

    Returns:
        argparse.Namespace with the attributes ``output_dir``,
        ``dataset_name``, ``property``, ``data_rep``, ``seed`` and the Lo-Hi
        parameters (``hi_similarity_threshold``, ``train_min_frac``,
        ``test_min_frac``, ``lo_similarity_threshold``, ``min_cluster_size``,
        ``max_clusters``, ``std_threshold``).

    Raises:
        SystemExit: raised by argparse when ``--property`` is missing or an
            argument is not among its allowed choices.
    """
    parser = argparse.ArgumentParser(description="Preprocess molecular datasets with Lo-Hi splitter.")
    # Data / features
    parser.add_argument('--output_dir', default='../data', help="Root directory for the generated Lo-Hi data.")
    parser.add_argument('--dataset_name', '-dn', default='molnet', choices=['molnet'])
    parser.add_argument('--property', '-p', type=str, required=True,
                        choices=['bace', 'esol', 'freesolv', 'lipo', 'bbbp', 'clintox', 'sider'])
    # None keeps the descriptor features produced by the featurizer.
    parser.add_argument('--data_rep', '-dr', type=str, default=None, choices=['gnn', 'smi_ted'])
    parser.add_argument('--seed', default=42, type=int)

    # Lo-Hi parameters: these three are forwarded to the Hi splitter ...
    parser.add_argument('--hi_similarity_threshold', type=float, default=0.4)
    parser.add_argument('--train_min_frac', type=float, default=0.7)
    parser.add_argument('--test_min_frac', type=float, default=0.1)

    # ... and these four to the Lo splitter.
    parser.add_argument('--lo_similarity_threshold', type=float, default=0.4)
    parser.add_argument('--min_cluster_size', type=int, default=5)
    parser.add_argument('--max_clusters', type=int, default=50)
    parser.add_argument('--std_threshold', type=float, default=0.6)

    # parse_args(None) falls back to sys.argv[1:], preserving the old behavior.
    return parser.parse_args(argv)

def set_seed(seed: int):
    """Seed every random number generator this script touches.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU,
    plus the CUDA generators when CUDA is available), and stores the seed in
    the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed applied to all generators.
    """
    # NOTE(review): exporting PYTHONHASHSEED here presumably only influences
    # child processes, not the already-running interpreter — confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def save_split_as_csv(args, dataset_dict, split_type, property_name):
    """Export the train/eval parts of a split as two-column CSV files.

    Files are written to
    ``<args.output_dir>/<args.dataset_name>/<split_type>/<property_name>/``
    and named ``train_featurized.csv`` / ``eval_featurized.csv``. The
    directory is created if needed.

    Args:
        args: Object exposing ``output_dir`` and ``dataset_name``.
        dataset_dict: Mapping that may contain ``'train'`` and/or ``'eval'``
            entries; each entry must provide ``'smiles'`` and ``'targets'``.
            Other keys are ignored, and a missing part is skipped.
        split_type: Sub-directory name for the split (e.g. ``'hi'``, ``'lo'``).
        property_name: Sub-directory name for the property.
    """
    out_dir = os.path.join(args.output_dir, args.dataset_name, split_type, property_name)
    os.makedirs(out_dir, exist_ok=True)

    available_parts = [part for part in ('train', 'eval') if part in dataset_dict]
    for part in available_parts:
        subset = dataset_dict[part]
        smiles_column = subset['smiles']
        # Targets may be nested (e.g. one-element rows); flatten to 1-D.
        target_column = np.array(subset['targets']).ravel()
        frame = pd.DataFrame({'smiles': smiles_column, 'target': target_column})

        csv_path = os.path.join(out_dir, f'{part}_featurized.csv')
        frame.to_csv(csv_path, index=False)
        print(f"Saved {part} data to: {csv_path}")

# Lookup table: --property value -> (name of the deepchem.molnet loader,
# task whose column is used as the target, or None to use column 0).
# Loader names are kept as strings and resolved lazily in
# _load_molnet_datasets.
_MOLNET_LOADERS = {
    'bace': ('load_bace_regression', None),
    'esol': ('load_delaney', None),
    'freesolv': ('load_freesolv', None),
    'lipo': ('load_lipo', None),
    'bbbp': ('load_bbbp', None),
    'clintox': ('load_clintox', 'CT_TOX'),
    'sider': ('load_sider', 'Hepatobiliary disorders'),
}


def _load_molnet_datasets(property_name):
    """Load one MoleculeNet dataset; return ``(datasets, target_index)``."""
    try:
        loader_name, task_name = _MOLNET_LOADERS[property_name]
    except KeyError:
        raise ValueError(f"Dataset '{property_name}' is not supported.") from None

    featurizer = RDKitDescriptors(is_normalized=True)
    # NOTE(review): the loaders run with their default splitter/transformers,
    # so the returned y values may already be transformed — confirm that is
    # what downstream training expects.
    _, datasets, _ = getattr(dc.molnet, loader_name)(featurizer=featurizer)

    target_index = 0
    if task_name is not None:
        # Multi-task datasets: pick the column of the named task.
        target_index = datasets[2].tasks.tolist().index(task_name)
    return datasets, target_index


def _persist_split(args, split_type, filename, dataset_dict):
    """Pickle *dataset_dict* under the split's directory and export its CSVs."""
    save_path = os.path.join(args.output_dir, args.dataset_name, split_type, args.property, filename)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, 'wb') as f:
        pickle.dump(dataset_dict, f)

    print(f"{split_type.upper()} split data saved to: {save_path}")
    save_split_as_csv(args, dataset_dict, split_type, args.property)


def _make_smiles_index_resolver(smiles):
    """Return a callable mapping a SMILES string to its position in *smiles*.

    A plain ``{smiles: index}`` dict keeps only the last position of a
    repeated SMILES, so every duplicate would resolve to the same sample and
    the others would be lost. The resolver instead hands out the positions of
    a repeated SMILES one after another; if it is asked for a SMILES more
    often than it occurs, it keeps returning the last position.
    """
    positions = {}
    for idx, smi in enumerate(smiles):
        positions.setdefault(smi, []).append(idx)
    handed_out = {}

    def resolve(smi):
        slots = positions[smi]  # KeyError for an unknown SMILES, as before
        k = handed_out.get(smi, 0)
        handed_out[smi] = k + 1
        return slots[k] if k < len(slots) else slots[-1]

    return resolve


def generate_splits(args):
    """Load a MoleculeNet dataset, clean it, and write its Hi and Lo splits.

    The train/valid/test subsets returned by DeepChem are merged back into
    one pool. Molecules with NaN descriptors, a NaN target, or a SMILES that
    RDKit cannot parse are dropped; the rest are converted to graphs with
    ``process_utils.mol_to_graph`` and handed to ``perform_hi_split`` and
    ``perform_lo_split``.

    Args:
        args: Parsed CLI namespace (see ``set_args``).

    Returns:
        ``True`` if both splits were written, ``False`` if no molecule
        survived cleaning or at least one split failed.

    Raises:
        ValueError: if ``args.property`` is not a supported dataset.
    """
    print(f"Loading '{args.property}' dataset from MoleculeNet...")
    datasets, target_index = _load_molnet_datasets(args.property)
    train, valid, test = datasets

    # DeepChem's own split is discarded: everything is pooled and re-split.
    all_reps    = np.concatenate([train.X,  valid.X,  test.X])
    all_smiles  = np.concatenate([train.ids, valid.ids, test.ids])
    all_targets = np.concatenate([train.y,  valid.y,  test.y])[:, target_index]

    # Build graphs with process_utils.mol_to_graph; the four lists stay aligned.
    graph_data_list, valid_reps, valid_smiles, valid_targets = [], [], [], []
    print("Cleaning molecules and building graphs via process_utils.mol_to_graph ...")
    for rep, smi, target in tqdm(zip(all_reps, all_smiles, all_targets), total=len(all_smiles)):
        if np.isnan(rep).any() or np.isnan(target):
            continue
        mol = smi if isinstance(smi, Chem.Mol) else Chem.MolFromSmiles(smi)
        if mol is None:
            continue

        data = mol_to_graph(mol)
        data.y = [float(target)]
        graph_data_list.append(data)

        valid_reps.append(rep)
        # Re-serialise so every downstream consumer sees RDKit's own SMILES.
        valid_smiles.append(Chem.MolToSmiles(mol))
        valid_targets.append(float(target))

    print(f"Total valid molecules for splitting: {len(valid_smiles)}")
    if not valid_smiles:
        # Nothing to embed or split; stop instead of failing inside both splitters.
        print("No valid molecules remain after cleaning; nothing to split.")
        return False

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Optional representations; the pickle name records which one was used.
    if args.data_rep == 'gnn':
        processed_reps = get_gnn_embedding(graph_data_list, device)
        filename = f'{args.property}_gnn.pkl'
    elif args.data_rep == 'smi_ted':
        processed_reps = get_smi_ted_embedding(valid_smiles, device)
        filename = f'{args.property}_smi_ted.pkl'
    else:
        processed_reps = valid_reps
        filename = f'{args.property}.pkl'

    outcome = {}

    # HI split
    print("\n" + "="*20 + " HI SPLIT " + "="*20)
    outcome['HI'] = perform_hi_split(args, filename, graph_data_list, processed_reps, valid_smiles, valid_targets)

    # LO split
    print("\n" + "="*20 + " LO SPLIT " + "="*20)
    outcome['LO'] = perform_lo_split(args, filename, graph_data_list, processed_reps, valid_smiles, valid_targets)

    failed = [name for name, ok in outcome.items() if not ok]
    if failed:
        # Do not claim success when a splitter reported an error above.
        print(f"\n❌ Split(s) failed: {', '.join(failed)}. See the error(s) above; "
              f"any successful output is in '{args.output_dir}'.")
        return False

    print(f"\n✅ All splits generated successfully in '{args.output_dir}'.")
    return True

def perform_hi_split(args, filename, graph_data_list, reps, smiles, targets):
    """Run the Hi split and save it as a pickle plus train/eval CSVs.

    Args:
        args: Parsed CLI namespace; supplies the Hi parameters and output paths.
        filename: Name of the pickle file to write.
        graph_data_list, reps, smiles, targets: Aligned per-molecule data,
            forwarded to ``build_dataset`` together with the split indices.

    Returns:
        ``True`` when the split was computed and saved, ``False`` when any
        step raised (the error is printed, not re-raised, so the other split
        can still run).
    """
    print(f"Performing 'hi' split for '{args.property}'...")
    try:
        # The splitter's result is unpacked as (train indices, eval indices).
        train_indices, eval_indices = lohi.hi_train_test_split(
            smiles=smiles,
            similarity_threshold=args.hi_similarity_threshold,
            train_min_frac=args.train_min_frac,
            test_min_frac=args.test_min_frac,
            verbose=True
        )
        print(f"Train samples: {len(train_indices)} | Eval samples: {len(eval_indices)}")

        dataset_dict = {
            'train': build_dataset(graph_data_list, reps, smiles, targets, train_indices),
            'eval':  build_dataset(graph_data_list, reps, smiles, targets, eval_indices),
        }
        _persist_split(args, 'hi', filename, dataset_dict)
    except Exception as e:
        # Deliberately broad: this is a best-effort step, but the failure is
        # reported to the caller through the return value.
        print(f"Could not perform HI split for {args.property}. Error: {e}")
        return False
    return True

def perform_lo_split(args, filename, graph_data_list, reps, smiles, targets):
    """Run the Lo split and save it as a pickle plus train/eval CSVs.

    The Lo splitter returns SMILES strings rather than indices, so they are
    mapped back to positions in *smiles*. Repeated SMILES are assigned to
    distinct positions in order of appearance (train first, then eval); which
    duplicate ends up on which side cannot be recovered from the SMILES alone.

    Args:
        args: Parsed CLI namespace; supplies the Lo parameters and output paths.
        filename: Name of the pickle file to write.
        graph_data_list, reps, smiles, targets: Aligned per-molecule data,
            forwarded to ``build_dataset`` together with the split indices.

    Returns:
        ``True`` when the split was computed and saved, ``False`` when any
        step raised (the error is printed, not re-raised).
    """
    print(f"Performing 'lo' split for '{args.property}'...")
    try:
        eval_clusters_smiles, train_smiles_list = lohi.lo_train_test_split(
            smiles=smiles,
            values=targets,
            threshold=args.lo_similarity_threshold,
            min_cluster_size=args.min_cluster_size,
            max_clusters=args.max_clusters,
            std_threshold=args.std_threshold
        )

        # One shared resolver so train and eval never claim the same duplicate.
        resolve = _make_smiles_index_resolver(smiles)
        train_indices = [resolve(smi) for smi in train_smiles_list]
        flat_eval_smiles = [smi for cluster in eval_clusters_smiles for smi in cluster]
        eval_indices = [resolve(smi) for smi in flat_eval_smiles]

        print(f"Train samples: {len(train_indices)} | Eval samples: {len(eval_indices)} (clusters: {len(eval_clusters_smiles)})")

        dataset_dict = {
            'train':    build_dataset(graph_data_list, reps, smiles, targets, train_indices),
            'eval':     build_dataset(graph_data_list, reps, smiles, targets, eval_indices),
            'clusters': eval_clusters_smiles
        }
        _persist_split(args, 'lo', filename, dataset_dict)
    except Exception as e:
        # Deliberately broad: this is a best-effort step, but the failure is
        # reported to the caller through the return value.
        print(f"Could not perform LO split for {args.property}. Error: {e}")
        return False
    return True

def main():
    """Script entry point: parse the CLI, seed the RNGs, generate the splits."""
    args = set_args()
    set_seed(args.seed)

    # Only the MoleculeNet collection is handled; anything else is reported.
    if args.dataset_name != 'molnet':
        print(f"Dataset collection '{args.dataset_name}' is not yet supported in this script.")
        return

    generate_splits(args)

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()