'''
Add atom mapping and correct header to prepare a file of reactions for rsmiles training.

'''
import hydra
import os
import pandas as pd
from rxnmapper import RXNMapper

from multiguide.helpers import PROJECT_ROOT
#from multiguide.evaluation.helpers import get_rxn_insight_info

# NOTE: need to define these here to minimize dependencies in rxnmapper environment
# Reaction class names listed in index order: a name's position in this tuple
# is the integer it maps to in class_to_idx.
_REACTION_CLASS_NAMES = (
    'Heteroatom Alkylation and Arylation',
    'Acylation',
    'C-C Coupling',
    'Aromatic Heterocycle Formation',
    'Protection',
    'Deprotection',
    'Reduction',
    'Oxidation',
    'Functional Group Addition',
    'Functional Group Interconversion',
    'Miscellaneous',
)
class_to_idx = {
    class_name: position for position, class_name in enumerate(_REACTION_CLASS_NAMES)
}
# The string '-1' maps to the integer -1 (presumably the label used for
# reactions without a known class -- confirm against the input data).
class_to_idx['-1'] = -1

def _get_shared_rxn_mapper():
    '''
        Return one RXNMapper shared by all calls, building it on first use.
    '''
    mapper = getattr(_get_shared_rxn_mapper, 'instance', None)
    if mapper is None:
        mapper = RXNMapper()
        # Stored on the function object so later calls skip the model load.
        _get_shared_rxn_mapper.instance = mapper
    return mapper


def atom_map_reactions_with_rxnmapper(reactions, 
                                      return_mapping_confidence=False,
                                      rxn_mapper=None):
    '''
        Add atom mapping to reactions.

        Args:
            reactions: reaction strings, passed as-is to
                get_attention_guided_atom_maps.
            return_mapping_confidence: if True, return the mapper's raw
                results; otherwise return only each result's 'mapped_rxn'.
            rxn_mapper: optional object exposing
                get_attention_guided_atom_maps(reactions). When omitted, a
                single RXNMapper is created lazily and reused across calls
                instead of being rebuilt on every call.

        Returns:
            The mapper's result list when return_mapping_confidence is True,
            else a list with the 'mapped_rxn' entry of each result.

        Raises:
            Whatever the mapper raises; the error and the offending reactions
            are printed before the exception is re-raised.
    '''
    if rxn_mapper is None:
        rxn_mapper = _get_shared_rxn_mapper()
    try:
        res = rxn_mapper.get_attention_guided_atom_maps(reactions)
    except Exception as e:
        # Show which input failed, then let the caller decide how to recover.
        print(f'Error mapping reactions: {e}')
        print(reactions)
        raise
    if return_mapping_confidence:
        return res
    return [atom_mapping_tuple['mapped_rxn'] for atom_mapping_tuple in res]

@hydra.main(config_path='../configs', config_name='config.yaml')
def prepare_file_for_rsmiles_training(cfg):
    '''
        Atom-map a slice of a reaction CSV and save it for rsmiles training.

        Reads data/<data_dir>/<subset> under PROJECT_ROOT, keeps rows
        start_idx:end_idx, atom-maps the 'reaction' column batch by batch and
        writes data/<first component of data_dir>/raw_for_rsmiles_training/raw_<subset>
        with columns id, class, reactants>reagents>production.

        Batches whose mapping raises are skipped and counted; the number of
        skipped reactions is printed before saving.

        Args:
            cfg: Hydra config; reads cfg.reaction_dataset.data_dir, .subset,
                .start_idx and .end_idx.
    '''
    # header: id,class,reactants>reagents>production
    # load data
    dataset_cfg = cfg.reaction_dataset
    data = pd.read_csv(os.path.join(PROJECT_ROOT, 'data', dataset_cfg.data_dir, dataset_cfg.subset))
    data = data.iloc[dataset_cfg.start_idx:dataset_cfg.end_idx]
    num_reactions = len(data)
    print(f'Processing {num_reactions} reactions')
    # This instance is only used to report which device the mapper runs on.
    rxn_mapper = RXNMapper()
    print(f'rxn_mapper: {rxn_mapper.device}')

    # Materialize the columns once instead of rebuilding the lists per batch.
    reactions = data['reaction'].tolist()
    row_ids = data.index.tolist()
    has_true_class = 'true_class' in data.columns
    true_classes = data['true_class'].tolist() if has_true_class else None

    mapped_reactions = [] 
    ids = []
    classes = []
    batch_size = 1
    num_reactions_skipped = 0
    # Ceiling division so an exact multiple of batch_size is not over-counted.
    num_batches = (num_reactions + batch_size - 1) // batch_size
    for batch_number, i in enumerate(range(0, num_reactions, batch_size), start=1):
        print(f'Processing batch {batch_number} of {num_batches}')
        batch = reactions[i:i+batch_size]
        print(f'len(batch): {len(batch)}')
        try:
            atom_mapped_batch = atom_map_reactions_with_rxnmapper(batch, return_mapping_confidence=False)
        except Exception as e:
            # Best effort: drop the whole batch and keep going.
            print(f'Error mapping reactions: {e}, skipping')
            num_reactions_skipped += len(batch)
            continue
        print(f'len(atom_mapped_batch): {len(atom_mapped_batch)}')
        mapped_reactions.extend(atom_mapped_batch)
        ids.extend(row_ids[i:i+batch_size])
        if has_true_class:
            classes.extend([class_to_idx[x] for x in true_classes[i:i+batch_size]])
        else:
            # NOTE: reaction class is not used for rsmiles training
            classes.extend([-1]*len(batch))
    print(f'Skipped {num_reactions_skipped} of {num_reactions} reactions')

    # save data
    out_data_dir = dataset_cfg.data_dir.split('/')[0]
    out_dir = os.path.join(PROJECT_ROOT, 'data', out_data_dir, 'raw_for_rsmiles_training')
    os.makedirs(out_dir, exist_ok=True)
    out_data = pd.DataFrame({'id': ids, 'class': classes, 'reactants>reagents>production': mapped_reactions})
    # Build the path once so the log line names the file that is actually written.
    out_path = os.path.join(out_dir, 'raw_'+dataset_cfg.subset)
    print(f'Saving data to {out_path}')
    out_data.to_csv(out_path, index=False)

if __name__ == '__main__':
    prepare_file_for_rsmiles_training()