import hydra
import pandas as pd
import os
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import canonicalize_rxn, \
        get_full_length, get_partial_sequences_with_completion_ratio, \
        class_to_idx, tokenize_reaction, get_vocab_from_trained_model

def check_reaction_tokens_in_vocab(reaction_smiles, vocab_set):
    """
    Report whether a reaction can be written using only known tokens.

    Args:
        reaction_smiles: Reaction SMILES string to tokenize.
        vocab_set: Collection of tokens that count as valid.

    Returns:
        The result of the subset test: True when every token produced by
        ``tokenize_reaction`` is contained in ``vocab_set``, else False.
    """
    # Relies on tokenize_reaction returning an object with ``issubset``
    # (i.e. a set of tokens rather than an ordered list).
    return tokenize_reaction(reaction_smiles).issubset(vocab_set)

def get_data(config):
    """
    Load the raw reaction table used to build the reaction type classifier data.

    The CSV is read from ``PROJECT_ROOT/data/<data_dir>/<subset>``, where both
    path parts come from ``config.reaction_dataset``.

    * ``uspto_50k/processed``: the CSV is returned as read.
    * ``uspto_190/train_unique_reactions``: reactions are optionally filtered
      to those whose tokens are all in the vocabulary of the trained
      single-step model (``config.reaction_dataset.filter_by_vocab``), the
      columns ``reaction`` / ``true_class`` are renamed to ``rxn_smiles`` /
      ``rxn_insight_class``, and the class column is translated through
      ``class_to_idx``.

    Args:
        config: Configuration object exposing ``reaction_dataset.data_dir``,
            ``reaction_dataset.subset``, ``reaction_dataset.filter_by_vocab``
            and ``single_step_model.checkpoint_full_path``.

    Returns:
        pandas.DataFrame with the loaded (and, for uspto_190, normalised) data.

    Raises:
        ValueError: If ``data_dir`` is not one of the supported directories.
        KeyError: If a ``true_class`` value has no entry in ``class_to_idx``
            (assuming ``class_to_idx`` is a plain mapping).
    """
    uspto_50k_dir = 'uspto_50k/processed'
    uspto_190_dir = 'uspto_190/train_unique_reactions'

    dataset_cfg = config.reaction_dataset
    data_dir = dataset_cfg.data_dir
    print(f'Loading data from {data_dir}')

    # Reject unknown directories before touching the filesystem.
    if data_dir not in (uspto_50k_dir, uspto_190_dir):
        raise ValueError(f'Invalid data directory: {data_dir}')

    data = pd.read_csv(os.path.join(PROJECT_ROOT,
                                    'data',
                                    data_dir,
                                    dataset_cfg.subset))
    if data_dir == uspto_50k_dir:
        return data

    if dataset_cfg.filter_by_vocab:
        # Load vocabulary from trained model
        vocab_set = set(get_vocab_from_trained_model(
            config.single_step_model.checkpoint_full_path))

        original_count = len(data)
        # astype(bool) keeps the mask a boolean indexer even when the frame is
        # empty (apply then yields an object-dtype Series).
        valid_mask = data['reaction'].apply(
            lambda rxn: check_reaction_tokens_in_vocab(rxn, vocab_set)
        ).astype(bool)
        valid_count = int(valid_mask.sum())
        filtered_count = original_count - valid_count
        # Guard the percentage against an empty input file.
        filtered_pct = filtered_count / original_count * 100 if original_count else 0.0

        print(f"Original reactions: {original_count}")
        print(f"Valid reactions: {valid_count}")
        print(f"Filtered out: {filtered_count} ({filtered_pct:.1f}%)")

        # Take an explicit copy so the column edits below do not operate on a
        # slice of the original frame (avoids SettingWithCopyWarning).
        data = data[valid_mask].copy()

    data = data.rename(columns={'reaction': 'rxn_smiles',
                                'true_class': 'rxn_insight_class'})
    # Translate class labels to their integer index; unknown labels raise.
    data['rxn_insight_class'] = data['rxn_insight_class'].apply(lambda label: class_to_idx[label])
    return data

@hydra.main(config_path='../configs', config_name='config.yaml')
def get_data_for_reaction_type_classifier(config):
    """
    Build and save the partial-sequence dataset for the reaction type classifier.

    Loads reactions via ``get_data``, canonicalizes them without atom maps,
    splits each reaction SMILES on ``>>`` into its left side (``rxn``) and
    right side (``product_smi``), and passes them with the class labels to
    ``get_partial_sequences_with_completion_ratio``. The resulting frame is
    written to
    ``PROJECT_ROOT/data/predictors/reaction_type/<data_dir>/<subset>_completion<limit>_augment<n>.csv``.

    Args:
        config: Hydra configuration; reads ``classifier_guidance.dataset``
            (``data_dir``, ``subset``, ``completion_lower_limit``,
            ``max_augmentations``) plus whatever ``get_data`` requires.
    """
    frame = get_data(config)

    # Columns consumed below: rxn, property, full_length (and product_smi).
    frame['rxn_smiles'] = frame['rxn_smiles'].apply(
        lambda smi: canonicalize_rxn(smi, remove_atom_map=True))
    frame['rxn'] = frame['rxn_smiles'].apply(lambda smi: smi.split('>>')[0])
    frame['product_smi'] = frame['rxn_smiles'].apply(lambda smi: smi.split('>>')[1])
    frame['full_length'] = frame['rxn'].apply(get_full_length)
    # The class label is used as-is as the property to predict.
    frame['property'] = frame['rxn_insight_class']

    guidance_cfg = config.classifier_guidance.dataset
    df = get_partial_sequences_with_completion_ratio(
        frame['rxn'].to_list(),
        frame['property'].to_list(),
        frame['full_length'].to_list(),
        product_smiles=frame['product_smi'].to_list(),
        completion_lower_limit=guidance_cfg.completion_lower_limit,  # from 0 to 1
        max_augmentations=guidance_cfg.max_augmentations,
        only_augment_complete=True,
    )

    # Output lives under data/predictors/reaction_type/<first component of data_dir>/.
    data_dir = guidance_cfg.data_dir.split('/')[0]
    subset = guidance_cfg.subset.split('.csv')[0]
    out_dir = os.path.join(PROJECT_ROOT,
                           'data',
                           'predictors',
                           'reaction_type',
                           data_dir)
    os.makedirs(out_dir, exist_ok=True)

    out_name = f'{subset}_completion{guidance_cfg.completion_lower_limit}_augment{guidance_cfg.max_augmentations}.csv'
    df.to_csv(os.path.join(out_dir, out_name), index=False)

# Script entry point: called without arguments because the @hydra.main
# decorator on the function builds and passes in the config.
if __name__ == '__main__':
    get_data_for_reaction_type_classifier()