'''
    Extract the starting materials (leaf molecules) from one compound's
    route-similarity tree JSON files and save their SMILES to a text file.
'''
import json
import os
from collections import deque

import hydra
import pandas as pd
from rdkit import Chem

from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import get_starting_material_from_route

def _collect_starting_materials(route):
    '''
        Return the SMILES of every leaf molecule in a single route tree.

        The tree is walked breadth-first starting from the root's children.
        A node counts as a starting material when its 'type' is 'mol' and it
        has no 'children' (missing or empty), i.e. nothing was expanded
        below it. Nodes with a non-empty 'children' list are expanded.

        Args:
            route: parsed route-tree dict. Assumed layout (inferred from the
                keys read here): every node has 'type', 'mol' nodes have
                'smiles', and any node may have a 'children' list.
                TODO confirm against the route generator.

        Returns:
            list of SMILES strings in breadth-first order (may contain
            duplicates). A root without children yields an empty list.
    '''
    starting_materials = []
    # A deque gives O(1) pops from the front and leaves the loaded route's
    # own 'children' list untouched.
    queue = deque(route.get('children', []))
    while queue:
        node = queue.popleft()
        children = node.get('children')
        if children:
            queue.extend(children)
        elif node['type'] == 'mol':
            # A mol with no (or an empty) children list has no precursors,
            # so it is a leaf of the route.
            starting_materials.append(node['smiles'])
    return starting_materials


@hydra.main(config_path='../configs', config_name='config.yaml', version_base=None)
def extract_starting_material_from_similarity_routes(cfg):
    '''
        Extract starting material from similarity routes.

        Reads every '*_tree.json' file under
        <PROJECT_ROOT>/data/route_similarity_data/Outputs/<compound_name>.
        It collects the leaf molecules of each route and writes their SMILES,
        one per line, to
        <PROJECT_ROOT>/data/route_similarity_data/processed/<compound_name>_starting_materials.txt.
        Route files are processed in sorted filename order, so the output is
        reproducible.

        Args:
            cfg: Hydra config; only cfg.route_dataset.compound_name is read.
    '''
    compound_name = cfg.route_dataset.compound_name
    data_dir = os.path.join(PROJECT_ROOT, 'data', 'route_similarity_data')
    compound_dir = os.path.join(data_dir, 'Outputs', compound_name)
    processed_dir = os.path.join(data_dir, 'processed')
    os.makedirs(processed_dir, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sort for a deterministic
    # output file.
    route_files = sorted(f for f in os.listdir(compound_dir) if f.endswith('_tree.json'))
    print(f'Found {len(route_files)} route files in {compound_dir}')

    all_starting_materials = []
    for route_file in route_files:
        route_path = os.path.join(compound_dir, route_file)
        # Context manager closes the file handle as soon as it is parsed.
        with open(route_path, 'r', encoding='utf-8') as fh:
            route = json.load(fh)
        print(f'Loaded route from {route_path}')
        all_starting_materials.extend(_collect_starting_materials(route))

    print(f'Found {len(all_starting_materials)} starting materials')
    # save starting materials to txt file, one SMILES per line
    starting_materials_path = os.path.join(processed_dir, f'{compound_name}_starting_materials.txt')
    with open(starting_materials_path, 'w', encoding='utf-8') as f:
        f.writelines(f'{smiles}\n' for smiles in all_starting_materials)
    print(f'Saved starting materials to {starting_materials_path}')

if __name__ == "__main__":
    # No argument is passed: the @hydra.main decorator on the entry point
    # composes the config (../configs/config.yaml) and supplies `cfg` itself.
    extract_starting_material_from_similarity_routes()