import pandas as pd
import numpy as np
import os
import sys
import pickle
from rdkit import Chem
from process_data import process_smiles
from process_utils import build_dataset

def create_splits(processed_data, original_df):
    """
    Splits the processed data into train, test_id, and test_ood sets.

    Every entry of ``processed_data['smiles']`` is looked up (exact string
    match) in ``original_df`` and routed by that row's metadata:

    * ``split == 'train'``                       -> train set
    * ``split == 'test'`` and falsy ``cliff_mol``  -> in-distribution test set
    * ``split == 'test'`` and truthy ``cliff_mol`` -> out-of-distribution test set

    Entries whose SMILES is not present in ``original_df``, or whose ``split``
    is neither ``'train'`` nor ``'test'``, are left out of all three sets.
    When a SMILES occurs more than once in ``original_df`` the last row wins.

    Args:
        processed_data: Mapping with the keys 'graphs', 'reps', 'smiles' and
            'targets'; all four values are forwarded to ``build_dataset``.
        original_df: DataFrame with 'smiles', 'split' and 'cliff_mol' columns.

    Returns:
        Tuple ``(train_set, test_id_set, test_ood_set)``, each being whatever
        ``build_dataset`` returns for the corresponding index list.

    Raises:
        KeyError: If ``original_df`` lacks one of the required columns or
            ``processed_data`` lacks one of the required keys.
    """
    # Iterate the three columns in lockstep instead of using iterrows(),
    # which constructs a full Series for every row.
    smiles_to_info = {
        smi: {'split': split, 'cliff_mol': cliff_mol}
        for smi, split, cliff_mol in zip(
            original_df['smiles'],
            original_df['split'],
            original_df['cliff_mol'],
        )
    }

    train_idx, id_idx, ood_idx = [], [], []
    # NOTE(review): the lookup relies on the processed SMILES strings being
    # identical to the ones in original_df; if the processing step rewrites
    # them, such entries are skipped silently here -- confirm upstream.
    for idx, smi in enumerate(processed_data['smiles']):
        info = smiles_to_info.get(smi)
        if info is None:
            continue
        if info['split'] == 'train':
            train_idx.append(idx)
        elif info['split'] == 'test':
            (ood_idx if info['cliff_mol'] else id_idx).append(idx)

    # Build the three subsets in the same order as the returned tuple.
    train_set, test_id_set, test_ood_set = [
        build_dataset(
            processed_data['graphs'],
            processed_data['reps'],
            processed_data['smiles'],
            processed_data['targets'],
            indices,
        )
        for indices in (train_idx, id_idx, ood_idx)
    ]

    return train_set, test_id_set, test_ood_set


def save_datasets_and_pickle(datasets, output_dir, base_name):
    """
    Write one CSV per recognised split plus a single pickle of everything.

    For the split names 'train', 'eval' and 'ood' a two-column CSV
    ('smiles', 'target') is written to ``output_dir``; entries of
    ``datasets`` under any other key get no CSV. The complete ``datasets``
    mapping (including unrecognised keys) is then pickled to
    ``<output_dir>/<base_name>.pkl``. ``output_dir`` is created if needed.
    """
    csv_name_for = dict(
        train='train_featurized.csv',
        eval='eval_featurized.csv',
        ood='ood_featurized.csv',
    )

    os.makedirs(output_dir, exist_ok=True)

    for split_name, split_data in datasets.items():
        csv_name = csv_name_for.get(split_name)
        if csv_name:
            csv_path = os.path.join(output_dir, csv_name)
            # Only the SMILES strings and their targets go into the CSV.
            columns = {
                'smiles': split_data['smiles'],
                'target': split_data['targets'],
            }
            pd.DataFrame(columns).to_csv(csv_path, index=False)
            print(f"Saved {split_name} CSV (smiles & target) to: {csv_path}")

    pkl_path = os.path.join(output_dir, f'{base_name}.pkl')
    with open(pkl_path, 'wb') as handle:
        pickle.dump(datasets, handle)
    print(f"Saved dataset dict pickle to: {pkl_path}")


def main():
    """
    Featurize every CSV in the raw data directory and save its splits.

    For each ``*.csv`` file (case-insensitive extension, processed in sorted
    order) found in ``../data/MoleculeACE_raw`` the file is read, its
    'smiles'/'y' columns are passed to ``process_smiles`` (with 'y' renamed
    to 'cls_label'), the result is divided by ``create_splits`` and written
    by ``save_datasets_and_pickle`` to ``../data/ac/ac/<file stem>/``.

    Both paths are relative, so they resolve against the current working
    directory. The process exits with status 1 when the input directory is
    missing or contains no CSV files. A failure while handling one file is
    reported and does not stop the remaining files.
    """
    class Args:
        def __init__(self):
            self.data_rep = None

    args = Args()

    data_dir = '../data/MoleculeACE_raw'
    output_base_dir = '../data/ac/ac'
    # 'smiles' and 'y' feed process_smiles; 'split' and 'cliff_mol' are read
    # by create_splits. Checking all four up front avoids featurizing a file
    # that cannot be split afterwards.
    required_columns = ('smiles', 'y', 'split', 'cliff_mol')

    # os.listdir would raise on a missing directory; report it the same way
    # as the "no CSV files" case instead.
    if not os.path.isdir(data_dir):
        print(f"Data directory not found: {data_dir}")
        sys.exit(1)

    csv_files = sorted(f for f in os.listdir(data_dir) if f.lower().endswith('.csv'))

    if not csv_files:
        print(f"No CSV files found in directory: {data_dir}")
        sys.exit(1)

    for csv_file in csv_files:
        try:
            print(f"\nProcessing file: {csv_file}")
            csv_path = os.path.join(data_dir, csv_file)
            original_df = pd.read_csv(csv_path)
            base_name = os.path.splitext(os.path.basename(csv_file))[0]

            missing = [col for col in required_columns if col not in original_df.columns]
            if missing:
                print(f"Skipping {csv_file}: required columns {missing} not found.")
                continue

            df_proc = (
                original_df[['smiles', 'y']]
                .rename(columns={'y': 'cls_label'})
            )
            records = df_proc.to_dict('records')
            processed = process_smiles(args, records)
            train_set, test_id_set, test_ood_set = create_splits(processed, original_df)
            datasets = {
                'train': train_set,
                'eval':  test_id_set,
                'ood':   test_ood_set
            }
            output_dir = os.path.join(output_base_dir, base_name)
            save_datasets_and_pickle(datasets, output_dir, base_name)

        except Exception as e:
            # Deliberately best-effort: one bad file must not abort the batch.
            print(f"Error processing {csv_file}: {e}")

# Run the batch conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
