import hydra
import time
import os
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import pickle
import pandas as pd
import json
from syntheseus.search.graph.and_or import AndNode
from syntheseus.search.graph.route import SynthesisGraph
from rdkit import Chem
from multiguide.helpers import PROJECT_ROOT
from multiguide.evaluation.helpers import compute_diversity, load_ground_truth_routes, extract_routes_for_target
from multiguide.dataset.helpers import extract_reaction_smiles
#from scripts.evaluate_single_step_of_one_molecule import single_step_evaluation_of_one_molecule
from multiguide.syntheseus.helpers import flatten_evaluation_results
from multiguide.dataset.helpers import get_ground_truth_for_step, get_sorted_cano_smiles
from multiguide.evaluation.helpers import evaluate_results_for_one_batch

# def evaluate_single_route(config, reactions, ground_truth_route, ground_truth_property):
#     '''
#     Evaluate a single route.
#     '''
#     all_rows = []
#     for num_reaction, reaction in enumerate(reactions):
#         product_smi, reactant_predictions = extract_reaction_smiles(reaction)
#         target_property, true_reactants = get_ground_truth_for_step(num_reaction,
#                                                                  ground_truth_route,
#                                                                  ground_truth_property)
#         evaluation_df = single_step_evaluation_of_one_molecule(config,
#                                                                 reactant_predictions,
#                                                                 product_smi,
#                                                                 true_reactants=true_reactants,
#                                                                 target_property=target_property)
#         flat_evaluation_df = flatten_evaluation_results(evaluation_df)
#         all_rows.extend(flat_evaluation_df)
#     evaluation_df = pd.DataFrame(all_rows)
#     return evaluation_df

def canonicalize_and_construct_reaction(product_smi, reactants_smi):
    '''
        Canonicalize both sides of a reaction and build its forward reaction SMILES.

        Each side is split on '.' into its individual molecules and passed to
        get_sorted_cano_smiles (product side first, then reactant side).

        Args:
            product_smi: SMILES of the product side ('.'-separated if several molecules).
            reactants_smi: SMILES of the reactant side ('.'-separated if several molecules).

        Returns:
            tuple of (reaction SMILES formatted as '<reactants>>><product>',
                      canonical reactants, canonical product), where the last two are
                      whatever get_sorted_cano_smiles returns (presumably a
                      '.'-joined SMILES string -- confirm against that helper).
    '''
    product_cano, reactants_cano = [
        get_sorted_cano_smiles(side.split('.')) for side in (product_smi, reactants_smi)
    ]
    reaction_cano = f'{reactants_cano}>>{product_cano}'
    return reaction_cano, reactants_cano, product_cano

@hydra.main(config_path='../configs', config_name='config.yaml')
def main(config):
    '''
        Evaluate multi-step search results against ground-truth routes.

        For every target index in the half-open range
        [config.multi_step_evaluation.route_start_idx, config.multi_step_evaluation.route_end_idx):
          * targets for which extract_routes_for_target returns None are skipped;
          * targets with zero routes get one placeholder row with solved=False;
          * otherwise one row is emitted per predicted reaction of every route, paired
            with a ground-truth reaction (exact canonical match if the predicted reaction
            appears anywhere in the ground-truth route, otherwise the reaction at the same
            position; placeholders if the predicted route is longer than the true one).

        All rows are scored with evaluate_results_for_one_batch and written as CSV to
        <PROJECT_ROOT>/experiments/search/<search.type>/<general.experiment_name>/
        strategy_<search.strategy>/evaluations/. Nothing is written if no rows were produced.

        Args:
            config: hydra configuration object supplied by @hydra.main.
    '''
    # Placeholder SMILES used when there is no ground-truth reaction to compare against.
    placeholder_true_reactants = 'N'
    placeholder_true_product = 'Br'
    placeholder_true_reaction = 'N>>Br'

    # read found routes: need experiment folder + target idx + ground truth target name/class
    ground_truth_routes = load_ground_truth_routes(config)
    data = []
    for target_idx in range(config.multi_step_evaluation.route_start_idx, config.multi_step_evaluation.route_end_idx):
        print(f'======= processing target {target_idx}')
        gt_route = ground_truth_routes[target_idx]
        true_route_length = len(gt_route['route'])
        start_time = time.time()
        routes, _, search_stats = extract_routes_for_target(
            config,
            target_idx
        )
        if routes is None:
            print(f'======= no routes found for target {target_idx}')
            continue
        print(f'======= extracted {len(routes)} routes in {time.time() - start_time} seconds')
        # ground_truth_route = ground_truth_routes[i]['route']
        # if config.classifier_guidance.true_property == 'similarity_to_starting_material':
        #     ground_truth_property = [p[-1] for p in ground_truth_routes[i][config.classifier_guidance.true_property]]
        # elif config.classifier_guidance.true_property == 'classes':
        #     ground_truth_property = ground_truth_routes[i][config.classifier_guidance.true_property]
        # else:
        #     raise ValueError(f'======= unknown true property: {config.classifier_guidance.true_property}')
        # ground_truth_property = None
        if len(routes) == 0:
            # Search ran but found nothing: record one unsolved placeholder row for this target.
            result = {
                'product_smi': 'C',
                'reactant_predictions': 'O',
                'reactions': 'O>>C',
                'true_product_cano': placeholder_true_product,
                'true_reactants': placeholder_true_reactants,
                'true_reaction_cano': placeholder_true_reaction,
                'true_class': -2,
                'sample_route_idx': -1,
                'target_idx': target_idx,
                'true_route_length': true_route_length,
                'original_target': 'S',
                'original_starting_material': 'F',
                'solved': False
            }
            result.update(search_stats)
            data.append(result)
            continue

        # Ground-truth reactions are stored retro ('product>>reactants'), whereas the
        # predictions below are canonical forward SMILES ('reactants>>product').
        # Convert the ground truth once into the same canonical forward form so that
        # predicted reactions can actually be matched against it.
        gt_reactions_cano = [
            canonicalize_and_construct_reaction(
                product_smi=true_reaction_retro.split('>>')[0],
                reactants_smi=true_reaction_retro.split('>>')[-1]
            )
            for true_reaction_retro in gt_route['route']
        ]
        gt_reaction_smiles_cano = [reaction_cano for reaction_cano, _, _ in gt_reactions_cano]

        for route_idx, route in enumerate(routes):
            print(f'length of gt route: {true_route_length}')
            reactions = route if config.search.type=='desp' else route.nodes()
            print(f'length of route {route_idx}: {len(reactions)}')
            for reaction_idx, reaction in enumerate(reactions):
                product_smi, reactant_predictions = extract_reaction_smiles(reaction)
                # NOTE: canonicalize the product might not be necessary (already handled by syntheseus,
                # but good to have this as fall back)
                predicted_reaction_cano, predicted_reactants_cano, predicted_product_cano = canonicalize_and_construct_reaction(
                    product_smi,
                    reactant_predictions
                )
                if reaction_idx < true_route_length:
                    # NOTE: handle the case of a branching route where the predicted reactions
                    # have a different order to ground truth reactions
                    if predicted_reaction_cano in gt_reaction_smiles_cano:
                        index_of_true_reaction = gt_reaction_smiles_cano.index(predicted_reaction_cano)
                    # Otherwise we assume the right reaction should be the one at the same position in the true route
                    else:
                        index_of_true_reaction = reaction_idx
                    true_reaction_cano, true_reactants_cano, true_product_cano = gt_reactions_cano[index_of_true_reaction]
                else:
                    # if the predicted route is longer than the true route
                    true_reactants_cano = placeholder_true_reactants
                    true_product_cano = placeholder_true_product
                    true_reaction_cano = placeholder_true_reaction
                result = {
                    'product_smi': predicted_product_cano, # canonicalize the product for each predicted step
                    'reactant_predictions': predicted_reactants_cano,
                    'reactions': predicted_reaction_cano,
                    'true_product_cano': true_product_cano,
                    'true_reactants': true_reactants_cano,
                    'true_reaction_cano': true_reaction_cano,
                    # -2 marks "no ground-truth class available" (same value as the placeholder row).
                    'true_class': gt_route['reaction_data'][true_reaction_cano]['reaction_type'] if true_reaction_cano in gt_route['reaction_data'] else -2,
                    'sample_route_idx': route_idx,
                    'target_idx': target_idx,
                    'true_route_length': true_route_length,
                    'original_target': gt_route['main_target'],
                    'original_starting_material': gt_route[config.classifier_guidance.starting_material_key],
                    'solved': True
                }
                result.update(search_stats)
                data.append(result)

    if len(data) > 0:
        df = pd.DataFrame(data)
        df = evaluate_results_for_one_batch(df, config)
        out_dir = os.path.join(
            PROJECT_ROOT,
            'experiments',
            'search',
            config.search.type,
            config.general.experiment_name,
            f'strategy_{config.search.strategy}',
            'evaluations'
        )
        os.makedirs(out_dir, exist_ok=True)
        df.to_csv(os.path.join(
            out_dir,
            f'evaluation_start{config.multi_step_evaluation.route_start_idx}_end{config.multi_step_evaluation.route_end_idx}.csv'
        ), index=False)
    print('Done evaluating.')

if __name__ == "__main__":
    # `main` is wrapped by @hydra.main, which supplies its `config` argument
    # (from ../configs/config.yaml), so it is called without arguments here.
    main()