import hydra
import pandas as pd
import os

from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import clear_atom_map, get_full_length, \
    get_partial_sequences_with_completion_ratio, get_vocab_from_trained_model, tokenize_smiles

@hydra.main(config_path='../configs', config_name='config.yaml')
def get_data_for_reaction_type_classifier(config) -> None:
    """Build the partial-sequence training CSV for the similarity predictor.

    Reads the reaction CSV at
    ``PROJECT_ROOT/data/<reaction_dataset.data_dir>/<reaction_dataset.subset>``,
    strips atom maps from the ``main_target`` and ``starting_material``
    columns, optionally drops rows containing tokens unknown to the trained
    single-step model, expands the targets into partial sequences, and writes
    the result under ``PROJECT_ROOT/data/predictors/tanimoto/``.

    NOTE(review): the function name mentions the reaction-type classifier,
    but the label column is ``similarity`` and the output directory is
    ``tanimoto`` -- confirm the name is just a leftover.

    Args:
        config: Hydra config. Fields read here: ``reaction_dataset.data_dir``,
            ``reaction_dataset.subset``, ``reaction_dataset.filter_by_vocab``,
            ``single_step_model.checkpoint_full_path`` (only when filtering),
            and ``classifier_guidance.dataset.{completion_lower_limit,
            max_augmentations, separator, data_dir, subset}``.

    Raises:
        KeyError: if the input CSV lacks ``main_target``,
            ``starting_material`` or ``similarity``.
    """
    dataset_cfg = config.reaction_dataset
    guidance_cfg = config.classifier_guidance.dataset

    print(f'Loading data from {dataset_cfg.data_dir}')
    data = pd.read_csv(os.path.join(PROJECT_ROOT,
                                    'data',
                                    dataset_cfg.data_dir,
                                    dataset_cfg.subset))

    data['main_target'] = data['main_target'].apply(clear_atom_map)
    data['starting_material'] = data['starting_material'].apply(clear_atom_map)

    if dataset_cfg.filter_by_vocab:
        # Keep only reactions whose tokens are all known to the trained model.
        vocab_set = set(get_vocab_from_trained_model(
            config.single_step_model.checkpoint_full_path))

        def tokens_in_vocab(smiles):
            # tokenize_smiles is expected to return a set-like object
            # (the original code relied on its .issubset method).
            return tokenize_smiles(smiles).issubset(vocab_set)

        original_count = len(data)
        # An explicitly indexed boolean Series stays well-defined for an empty
        # frame, unlike DataFrame.apply(axis=1), which returns a DataFrame then.
        valid_mask = pd.Series(
            [tokens_in_vocab(target) and tokens_in_vocab(material)
             for target, material in zip(data['main_target'],
                                         data['starting_material'])],
            index=data.index,
            dtype=bool,
        )
        valid_count = int(valid_mask.sum())
        filtered_count = original_count - valid_count
        # Guard the percentage against an empty input file.
        filtered_pct = filtered_count / original_count * 100 if original_count else 0.0

        print(f"Original reactions: {original_count}")
        print(f"Valid reactions: {valid_count}")
        print(f"Filtered out: {filtered_count} ({filtered_pct:.1f}%)")

        # Copy so the column assignments below write to an independent frame
        # rather than a slice of the unfiltered data.
        data = data[valid_mask].copy()

    # Columns consumed by the augmentation helper: sequence, property, full_length.
    # The helper must be *called* per target; the previous lambda returned the
    # function object itself instead of a length.
    data['full_length'] = data['main_target'].apply(get_full_length)
    data['property'] = data['similarity']
    df = get_partial_sequences_with_completion_ratio(
        data['main_target'].to_list(),
        data['property'].to_list(),
        data['full_length'].to_list(),
        completion_lower_limit=guidance_cfg.completion_lower_limit,  # from 0 to 1
        max_augmentations=guidance_cfg.max_augmentations,
        only_augment_complete=True,
        starting_material_separator=guidance_cfg.separator,
        starting_material_smiles=data['starting_material'].to_list())

    # Output goes to data/predictors/tanimoto/<data_dir>/<subset>_completion..._augment....csv
    data_dir = guidance_cfg.data_dir.split('/')[0]
    subset = guidance_cfg.subset.split('.csv')[0]
    output_dir = os.path.join(PROJECT_ROOT, 'data', 'predictors', 'tanimoto', data_dir)
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(
        output_dir,
        f'{subset}_completion{guidance_cfg.completion_lower_limit}'
        f'_augment{guidance_cfg.max_augmentations}.csv')
    print(f'Saving to {output_path}')
    df.to_csv(output_path, index=False)

# Script entry point. The function is called without arguments because the
# @hydra.main decorator on it builds the config and passes it in.
if __name__ == '__main__':
    get_data_for_reaction_type_classifier()