import hydra
import pandas as pd
import os

from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import canonicalize_rxn, get_full_length, get_partial_sequences_with_completion_ratio

@hydra.main(config_path='../configs', config_name='config.yaml')
def get_data_for_reaction_type_classifier(config):
    """Build the partial-product dataset for the reaction-type classifier.

    Reads ``<PROJECT_ROOT>/data/<reaction_dataset.data_dir>/raw/<subset>.csv``,
    canonicalizes each ``rxn_smiles`` entry (atom maps removed), keeps only the
    product side (text after ``'>>'``), labels it with ``rxn_insight_class`` and
    expands it into partial sequences via
    ``get_partial_sequences_with_completion_ratio``. The result is written to
    ``<PROJECT_ROOT>/data/predictors/reaction_type/uspto_50k/
    <subset>_product_completion<limit>_augment<n>.csv``.

    Args:
        config: Hydra config. Reads ``reaction_dataset.data_dir``,
            ``reaction_dataset.subset``,
            ``classifier_guidance.dataset.completion_lower_limit`` and
            ``classifier_guidance.dataset.max_augmentations``.
    """
    data_dir = config.reaction_dataset.data_dir
    subset = config.reaction_dataset.subset
    dataset_cfg = config.classifier_guidance.dataset
    completion_lower_limit = dataset_cfg.completion_lower_limit  # from 0 to 1
    max_augmentations = dataset_cfg.max_augmentations

    print(f'Loading data from {data_dir}')
    data = pd.read_csv(os.path.join(PROJECT_ROOT, 'data', data_dir, 'raw', f'{subset}.csv'))

    # Build the columns consumed below: rxn, property, full_length.
    # Only the product side of the canonicalized reaction is kept.
    data['rxn'] = (
        data['rxn_smiles']
        .apply(lambda x: canonicalize_rxn(x, remove_atom_map=True))
        .apply(lambda x: x.split('>>')[-1])
    )
    data['full_length'] = data['rxn'].apply(get_full_length)
    # Labels are copied verbatim; no re-indexing happens here.
    # NOTE(review): an older comment claimed this line turned labels 0-indexed —
    # confirm rxn_insight_class is already 0-indexed in the raw CSV.
    data['property'] = data['rxn_insight_class']

    df = get_partial_sequences_with_completion_ratio(
        data['rxn'].to_list(),
        data['property'].to_list(),
        data['full_length'].to_list(),
        completion_lower_limit=completion_lower_limit,
        max_augmentations=max_augmentations,
        only_augment_complete=True,
    )

    # NOTE(review): the output directory is hard-coded to 'uspto_50k' even though
    # the input directory comes from config.reaction_dataset.data_dir — confirm intended.
    out_dir = os.path.join(PROJECT_ROOT, 'data', 'predictors', 'reaction_type', 'uspto_50k')
    os.makedirs(out_dir, exist_ok=True)
    out_name = f'{subset}_product_completion{completion_lower_limit}_augment{max_augmentations}.csv'
    df.to_csv(os.path.join(out_dir, out_name), index=False)

if __name__ == '__main__':
    # No config is passed: the @hydra.main decorator composes it from
    # ../configs/config.yaml (plus any command-line overrides) and injects it.
    get_data_for_reaction_type_classifier()