'''
Unifying script to process the 50k dataset:
- Remove duplicates
- Check overlap with other subsets (test with train and val, val with train)
- Add rxn insight information to the dataset: reaction type + info dict
- Find the least/most similar reactants
'''
import os
import json
import pickle
import hydra
import pandas as pd
from rdkit import Chem

from setup_path import *
from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import get_rxn_insight_info, find_reactant_by_similarity_rank
from multiguide.dataset.helpers import get_tanimoto
from multiguide.dataset.helpers import get_sorted_cano_smiles, class_to_idx, get_similarity_for_one_pair, get_similarity, get_starting_material_from_route

def reaction_dict_to_csv(reaction_data, output_filename='reactions.csv'):
    """
    Convert a reaction dictionary to a CSV file (one row per reaction) using pandas.

    Each value of ``reaction_data`` becomes one row and its keys become the
    columns. A dict-valued ``rxn_insight_info`` field is stringified with
    ``str()`` so it fits in a single CSV cell. When a ``sorted_cano_reactions``
    column is present, rows are de-duplicated on it (first occurrence kept).

    Args:
        reaction_data: Dictionary with reactions as keys and per-reaction info
            dicts as values. The info dicts are not modified.
        output_filename: Path of the output CSV file.

    Returns:
        pandas.DataFrame: The data exactly as written to the CSV.
    """
    rows = []
    for reaction_info in reaction_data.values():
        # Shallow copy so stringifying rxn_insight_info never mutates the caller's data.
        row = reaction_info.copy()
        if isinstance(row.get('rxn_insight_info'), dict):
            row['rxn_insight_info'] = str(row['rxn_insight_info'])
        rows.append(row)

    df = pd.DataFrame(rows)

    # An empty input produces a DataFrame without columns, on which
    # drop_duplicates(subset=[...]) would raise KeyError.
    if 'sorted_cano_reactions' in df.columns:
        df = df.drop_duplicates(subset=['sorted_cano_reactions'])

    df.to_csv(output_filename, index=False)

    print(f"CSV file '{output_filename}' created successfully!")
    print(f"Shape: {df.shape}")
    return df

def get_starting_material_from_route_with_rank(route, config):
    '''
        Rank a route's starting materials against the route's main target.

        The main target is the text before '>>' in ``route[0]``. Starting
        materials come from ``get_starting_material_from_route`` and are
        scored against the target with ``get_similarity`` using
        ``config.reaction_dataset.similarity_type`` and
        ``config.reaction_dataset.combination_weight``.

        Args:
            route (list): list of reaction SMILES; ``route[0]`` must contain '>>'.
            config: the Hydra cfg object exposing ``reaction_dataset``.

        Returns:
            tuple: ``(most_similar_sm, least_similar_sm, heaviest_sm)``. Each
            item is one entry from ``get_similarity`` with the RDKit atom count
            appended (``entry + (num_atoms,)``). Index 1 is the starting-material
            SMILES, index 2 its similarity score and index 3 the atom count.
            Ties resolve to the earliest entry.

        Raises:
            ValueError: if no starting materials were scored for the route.
    '''
    starting_material = get_starting_material_from_route(route)
    main_target = route[0].split('>>')[0]
    scored = get_similarity(
        main_target,
        starting_material,
        config.reaction_dataset.similarity_type,
        config.reaction_dataset.combination_weight,
    )
    # Append the atom count so the heaviest starting material can also be reported.
    scored_with_size = [
        entry + (Chem.MolFromSmiles(entry[1]).GetNumAtoms(),) for entry in scored
    ]
    if not scored_with_size:
        raise ValueError(
            f'No starting materials found for route with target {main_target!r}'
        )
    # NOTE: heaviness is deliberately not folded into the similarity ranking;
    # it is likely already reflected in the similarity score.
    # max/min return the first extreme element, same as a stable sort + [0].
    most_similar_sm = max(scored_with_size, key=lambda x: x[2])
    least_similar_sm = min(scored_with_size, key=lambda x: x[2])
    heaviest_sm = max(scored_with_size, key=lambda x: x[3])
    return most_similar_sm, least_similar_sm, heaviest_sm

def _load_routes(in_path):
    '''
        Load the raw route list from the JSON file at ``in_path``.

        The file is opened in a ``with`` block so the handle is always closed.
    '''
    with open(in_path, 'rb') as f:
        return json.load(f)


def _sm_to_reactants_similarity(sm_smiles, sorted_cano_reactants):
    '''
        Score one route starting material against a reaction's reactants.

        Args:
            sm_smiles (str): starting-material SMILES.
            sorted_cano_reactants (str): '.'-joined canonical reactant SMILES.

        Returns:
            tuple: ``(combined, max_reactant, max_similarity)``.
            ``combined`` is the ``get_tanimoto`` score against the whole
            '.'-joined reactant string. ``max_similarity`` is the highest
            per-reactant ``get_tanimoto`` score, and ``max_reactant`` is the
            first reactant reaching it.
    '''
    combined = get_tanimoto(sm_smiles, sorted_cano_reactants)
    reactant_list = sorted_cano_reactants.split('.')
    per_reactant = [get_tanimoto(sm_smiles, r) for r in reactant_list]
    max_similarity = max(per_reactant)
    max_reactant = reactant_list[per_reactant.index(max_similarity)]
    return combined, max_reactant, max_similarity


@hydra.main(config_path='../configs', config_name='config.yaml', version_base=None)
def process_uspto_hard(cfg):
    '''
        Process a route dataset into per-reaction (CSV) and per-route (JSON) outputs.

        Input: a JSON list of routes read from
        ``PROJECT_ROOT/data/<route_dataset.route_dir>/<route_dataset.subset>``.
        Each route is either a list of reaction SMILES or a dict with
        ``route_name`` and ``route_as_list``. Each reaction SMILES is parsed
        product side first ('products>>reactants').

        Output, under ``PROJECT_ROOT/data/<route_dataset.processed_route_dir>``:
          * ``processed_reaction_file``: CSV with one row per unique
            ``sorted_cano_reactions``. Its ``step``/``route_idx`` lists collect
            every occurrence across all routes.
          * ``processed_route_file``: JSON list with one entry per route,
            holding that route's own per-step reaction info.
    '''
    in_path = os.path.join(
        PROJECT_ROOT,
        'data',
        cfg.route_dataset.route_dir,
        cfg.route_dataset.subset
    )
    routes = _load_routes(in_path)
    print(f'Loaded {len(routes)} routes from {in_path}')

    similarity_type = cfg.reaction_dataset.similarity_type
    combination_weight = cfg.reaction_dataset.combination_weight

    reaction_data = {}
    json_routes = []
    for iter_route_idx, iter_route in enumerate(routes):
        if isinstance(iter_route, dict):
            route_idx = iter_route['route_name']
            route = iter_route['route_as_list']
        else:
            route_idx = iter_route_idx
            route = iter_route
        # Main target and route-level most/least similar and heaviest starting materials.
        main_target = route[0].split('>>')[0]
        route_most_similar_sm, route_least_similar_sm, route_heaviest_sm = get_starting_material_from_route_with_rank(
            route, cfg
        )
        route_reaction_data = {}
        for step, reaction in enumerate(route):
            parts = reaction.split('>>')
            reactants = parts[1].split('.')
            all_products = parts[0].split('.')
            main_product = all_products[0]
            most_similar_reactant, most_similar_reactant_similarity = find_reactant_by_similarity_rank(
                reactants,
                main_product,
                cfg,
                ranking='most'
            )
            least_similar_reactant, least_similar_reactant_similarity = find_reactant_by_similarity_rank(
                reactants,
                main_product,
                cfg,
                ranking='least'
            )
            sorted_cano_reactants = get_sorted_cano_smiles(reactants)
            sorted_cano_products = get_sorted_cano_smiles(all_products)
            # Forward direction (reactants>>products): the de-duplication key and rxn-insight input.
            sorted_cano_reactions = sorted_cano_reactants + '>>' + sorted_cano_products
            rxn_insight_info = get_rxn_insight_info(sorted_cano_reactions)
            # sorted_cano_reactants is already a '.'-joined string; calling
            # '.'.join() on it would put a dot between every character.
            main_target_similarity = get_similarity_for_one_pair(
                main_target,
                sorted_cano_reactants,
                similarity_type,
                combination_weight
            )
            most_combined, most_max_reactant, most_max_similarity = _sm_to_reactants_similarity(
                route_most_similar_sm[1], sorted_cano_reactants
            )
            least_combined, least_max_reactant, least_max_similarity = _sm_to_reactants_similarity(
                route_least_similar_sm[1], sorted_cano_reactants
            )
            heaviest_combined, heaviest_max_reactant, heaviest_max_similarity = _sm_to_reactants_similarity(
                route_heaviest_sm[1], sorted_cano_reactants
            )
            info = {
                'step': [step],
                'route_idx': [route_idx],
                'sorted_cano_products': sorted_cano_products,
                'sorted_cano_reactants': sorted_cano_reactants,
                'sorted_cano_reactions': sorted_cano_reactions,
                'reaction_type': class_to_idx[rxn_insight_info['CLASS']],
                'rxn_insight_info': rxn_insight_info,
                'main_target': main_target,
                'main_target_similarity': main_target_similarity,
                'immediate_most_similar_reactants': most_similar_reactant,
                'immediate_most_similar_reactants_similarity': most_similar_reactant_similarity,
                'immediate_least_similar_reactants': least_similar_reactant,
                'immediate_least_similar_reactants_similarity': least_similar_reactant_similarity,
                'most_sm': route_most_similar_sm[1],
                'most_sm_to_reactants_similarity_combined': most_combined,
                'most_sm_to_reactants_max_reactant': most_max_reactant,
                'most_sm_to_reactants_similarity_max': most_max_similarity,
                'least_sm': route_least_similar_sm[1],
                'least_sm_to_reactants_similarity_combined': least_combined,
                'least_sm_to_reactants_max_reactant': least_max_reactant,
                'least_sm_to_reactants_similarity_max': least_max_similarity,
                'heaviest_sm': route_heaviest_sm[1],
                'heaviest_sm_to_reactants_similarity_combined': heaviest_combined,
                'heaviest_sm_to_reactants_max_reactant': heaviest_max_reactant,
                'heaviest_sm_to_reactants_similarity_max': heaviest_max_similarity,
            }
            route_reaction_data[sorted_cano_reactions] = info
            if sorted_cano_reactions in reaction_data:
                reaction_data[sorted_cano_reactions]['step'].append(step)
                reaction_data[sorted_cano_reactions]['route_idx'].append(route_idx)
            else:
                # Store a copy with its own lists so later appends (from other
                # occurrences) do not leak into this route's JSON entry.
                reaction_data[sorted_cano_reactions] = {
                    **info,
                    'step': [step],
                    'route_idx': [route_idx],
                }
        json_routes.append({
            'route_idx': route_idx,
            'route': route,
            'main_target': main_target,
            'route_most_similar_starting_material': route_most_similar_sm[1],
            'route_most_similar_starting_material_similarity': route_most_similar_sm[2],
            'route_least_similar_starting_material': route_least_similar_sm[1],
            'route_least_similar_starting_material_similarity': route_least_similar_sm[2],
            'route_heaviest_starting_material': route_heaviest_sm[1],
            'route_heaviest_starting_material_similarity': route_heaviest_sm[2],
            'reaction_data': route_reaction_data
        })
    print(f'Loaded {len(reaction_data)} reactions from {in_path}')

    output_dir = os.path.join(
        PROJECT_ROOT,
        'data',
        cfg.route_dataset.processed_route_dir
    )
    os.makedirs(output_dir, exist_ok=True)

    # Save per-reaction data to CSV; count the rows actually written
    # (pd.DataFrame(reaction_data) would put reactions in columns).
    reaction_data_out_path = os.path.join(
        output_dir,
        cfg.route_dataset.processed_reaction_file
    )
    reaction_data_df = reaction_dict_to_csv(reaction_data, output_filename=reaction_data_out_path)
    print(f'Saved {len(reaction_data_df)} reactions to {reaction_data_out_path}')

    # Save per-route data to JSON.
    json_routes_out_path = os.path.join(
        output_dir,
        cfg.route_dataset.processed_route_file
    )
    with open(json_routes_out_path, 'w') as f:
        json.dump(json_routes, f, indent=4)
    print(f'Saved {len(json_routes)} routes to {json_routes_out_path}')

if __name__ == '__main__':
    # The @hydra.main decorator builds `cfg` from ../configs/config.yaml and
    # passes it in, so the function is called without arguments here.
    process_uspto_hard() # pylint: disable=no-value-for-parameter
