'''
- process one chunk of the USPTO full dataset; the dataset is broken into chunks so that
  they can be processed in parallel by multiple Slurm jobs
- processing here means:
    - data validity: drop reactions containing a molecule that RDKit cannot parse
    - split parsed reactions into reactant and product sets of canonicalized SMILES
    - save the results to a new CSV file
'''
import os
import pandas as pd
import hydra
from rdkit import Chem

from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import get_reactant_and_product_from_reaction_smiles, clear_atom_map, get_data_chunk

@hydra.main(config_path='../configs', config_name='config.yaml')
def main(config):
    '''
    Validate one chunk of reactions and save the cleaned chunk as a CSV file.

    The chunk is obtained from `get_data_chunk(config)`. For every row:
        - surrounding double quotes are stripped from `ReactionSmiles`
        - the reaction is split into reactants and products
        - the row is marked invalid if `ReactionSmiles` is not a string or if
          RDKit fails to parse any reactant/product molecule
        - otherwise `Reactants`/`Products` ('.'-joined molecules) and
          `ReactantsSet`/`ProductsSet` (sets of `clear_atom_map(molecule)`) are filled in

    Invalid rows are dropped, and the remaining rows are written to
    `<PROJECT_ROOT>/data/<path_chunks>_<start_idx>_<end_idx>.csv`.

    Args:
        config: Hydra config; this function reads `config.route_dataset.start_idx`,
            `config.route_dataset.end_idx` and `config.route_dataset.path_chunks`.
    '''
    data_chunk = get_data_chunk(config)
    for new_column in ('Reactants', 'Products', 'ReactantsSet', 'ProductsSet'):
        data_chunk[new_column] = None

    invalid_rows = []
    for row_idx, row in data_chunk.iterrows():
        raw_smiles = row['ReactionSmiles']
        # A missing cell is not a string (pandas typically yields NaN); treat it as an
        # invalid row instead of letting `.strip` crash the whole chunk.
        if not isinstance(raw_smiles, str):
            print(f'========== Missing reaction SMILES in row {row_idx}')
            invalid_rows.append(row_idx)
            continue

        # remove "" around the reaction SMILES string if present
        reaction_smiles = raw_smiles.strip('"')
        reactants, products = get_reactant_and_product_from_reaction_smiles(reaction_smiles)

        # First molecule RDKit cannot parse, or None when all of them parse.
        invalid_molecule = next(
            (molecule for molecule in reactants + products if Chem.MolFromSmiles(molecule) is None),
            None,
        )
        if invalid_molecule is not None:
            print(f'========== Invalid molecule: {invalid_molecule} in row {row_idx}')
            invalid_rows.append(row_idx)
            continue

        data_chunk.at[row_idx, 'ReactionSmiles'] = reaction_smiles
        data_chunk.at[row_idx, 'Reactants'] = '.'.join(reactants)
        data_chunk.at[row_idx, 'Products'] = '.'.join(products)
        data_chunk.at[row_idx, 'ReactantsSet'] = {clear_atom_map(molecule) for molecule in reactants}
        data_chunk.at[row_idx, 'ProductsSet'] = {clear_atom_map(molecule) for molecule in products}

    start_idx = config.route_dataset.start_idx
    end_idx = config.route_dataset.end_idx
    print(f'Found {len(invalid_rows)} invalid rows in chunk {start_idx}-{end_idx}')

    data_chunk['ReactionSmiles'] = data_chunk['ReactionSmiles'].astype(str)
    data_chunk = data_chunk.drop(invalid_rows, axis=0)
    data_chunk = data_chunk[['ReactionSmiles', 'Reactants', 'Products', 'ReactantsSet', 'ProductsSet', 'TextMinedYield', 'CalculatedYield', 'PatentNumber', 'ParagraphNum', 'Year']]

    output_path = os.path.join(PROJECT_ROOT, 'data', f'{config.route_dataset.path_chunks}_{start_idx}_{end_idx}.csv')
    # Create the target directory if needed; exist_ok avoids a failure when
    # several jobs try to create it at the same time.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    # NOTE(review): the set columns are written via their string representation, so the
    # element order inside each cell is not guaranteed to be stable across runs.
    data_chunk.to_csv(output_path, index=False)

# Run the chunk processing only when this file is executed as a script (not on import).
# main is called without arguments because it is wrapped by the @hydra.main decorator.
if __name__ == '__main__':
    main()
