import pandas as pd
from rdkit import Chem
from datasets import Dataset, DatasetDict
from utils.data_utils import set_random_seed
import os

def create_uspto_dataset(input_file, output_dir='asset'):
    """Convert a pickled USPTO reaction DataFrame into a Hugging Face DatasetDict.

    The pickled DataFrame must provide these columns:

    * ``reactants_mol`` and ``products_mol``: molecule objects accepted by
      ``Chem.MolToSmiles``, or a missing value.
    * ``set``: the split name for each row.

    Both molecule columns are serialised to SMILES. Rows where either side is
    missing are dropped. One ``Dataset`` with columns ``equa`` (reactant
    SMILES) and ``prod`` (product SMILES) is built per distinct ``set`` value.

    The result is written with ``save_to_disk`` to
    ``<output_dir>/<input file stem>_dataset``.

    Args:
        input_file: Path to the pickled DataFrame.
        output_dir: Directory in which the dataset folder is created. Defaults
            to ``'asset'``, the location this function has always written to.

    Returns:
        The ``DatasetDict`` that was saved, keyed by split name.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    df = pd.read_pickle(input_file)

    def _to_smiles(mol):
        # A missing molecule may be stored as None or as a float NaN (e.g. after
        # a merge/reindex). NaN is truthy, so a plain `if mol` check would pass
        # it to MolToSmiles and crash. Map both to None so dropna removes them.
        if mol is None or (isinstance(mol, float) and pd.isna(mol)):
            return None
        return Chem.MolToSmiles(mol)

    df['equa'] = df['reactants_mol'].apply(_to_smiles)
    df['prod'] = df['products_mol'].apply(_to_smiles)

    df = df.dropna(subset=['equa', 'prod'])

    # groupby(sort=False) keeps splits in order of first appearance and scans
    # the frame once. Rows whose 'set' is missing are dropped (groupby's
    # default dropna=True) instead of yielding an empty NaN-keyed split that
    # save_to_disk cannot write. observed=True avoids empty groups if 'set'
    # happens to be categorical.
    dataset_dict = {
        split: Dataset.from_pandas(split_df[['equa', 'prod']], preserve_index=False)
        for split, split_df in df.groupby('set', sort=False, observed=True)
    }

    hf_dataset = DatasetDict(dataset_dict)
    output_name = os.path.splitext(os.path.basename(input_file))[0] + '_dataset'
    output_path = os.path.join(output_dir, output_name)
    hf_dataset.save_to_disk(output_path)

    print(f"Dataset saved with splits: {list(hf_dataset.keys())}")
    print(f"Sample counts: {[(k, len(v)) for k, v in hf_dataset.items()]}")

    return hf_dataset

if __name__ == "__main__":
    # Project helper from utils.data_utils; presumably seeds the RNGs used
    # downstream -- nothing in this script draws random numbers itself.
    set_random_seed(42)

    files = ['asset/uspto_50.pickle', 'asset/uspto_mixed.pickle']
    for path in files:
        # Inputs are optional: convert whichever pickles exist, but report a
        # missing one instead of skipping it silently.
        if not os.path.exists(path):
            print(f"Skipping {path}: file not found.")
            continue
        print(f"Processing {path}...")
        create_uspto_dataset(path)