import os
import hydra
import pandas as pd
import json
import pickle

from multiguide.helpers import PROJECT_ROOT
from multiguide.evaluation.helpers import generate_evaluation_summary, print_evaluation_summary
from multiguide.evaluation.helpers import load_experiment_results, _calculate_per_experiment_metrics
from multiguide.evaluation.helpers import calculate_route_completion_rates, simplify_metrics
from multiguide.evaluation.helpers import select_best_experiment_per_product

@hydra.main(config_path='../configs', config_name='config.yaml', version_base=None)
def report_manual_synthesis_results(config):
    """Report sample-quality and route-completion metrics for manual-synthesis runs.

    Steps performed:
      1. Load every experiment under ``experiments/manual_synthesis`` in the
         selected experiment group whose name matches ``experiment_regex``.
      2. Pick the best experiment per product and compute the per-experiment
         quality metrics on that selection.
      3. Load the reference routes and compute route-completion rates over
         all loaded experiments; print them.
      4. Pickle the ``mixed_param_completion`` route details into the
         ``notebooks`` directory for later inspection.
      5. Simplify the quality metrics (currently only consumed by the
         commented-out LaTeX printing that follows this code).

    Args:
        config: Configuration object injected by ``hydra.main``. It is not
            read by this function; all settings below are hard-coded.

    Returns:
        None. Results are printed to stdout and written to a pickle file.
    """
    # Logic: from unguided experiments, compute the sample quality + route complete metrics
    # from guided: pick the best experiment params per product, compute the sample quality + route complete metrics
        # alternative: keep track of all experiments that solved a given product?
    # might add improvement metrics
    # TODO: replace with loop to print entire latex table

    # Alternative experiment selections. Exactly one `experiment_regex` should
    # be active; the first three used to be live assignments that were
    # immediately overwritten, so they are kept here as comments instead.
    #experiment_regex = r'uspto_190_seed42_modelrootaligned_steeredfalse_guidance0_length0_results100_candidates53_time20251027_111323'
    #experiment_regex = r'uspto_190_seed42_modelrootaligned_steeredtrue_'
    #experiment_regex = r'uspto_190_50kcheckpoint_seed42_modelgraph2edits_steeredfalse_guidance0_length0_results100_candidates72_time20251027_184647'
    #experiment_regex = r'uspto_190_50kcheckpoint_seed42_modelmegan_steeredfalse_guidance0_length0_results100_candidates72_time20251027_184455'
    #experiment_regex = r'uspto_190_50kcheckpoint_seed42_modelmhnreact_steeredfalse_guidance0_length0_results100_candidates72_time20251027_184608'
    #experiment_regex = r'uspto_190_50kcheckpoint_seed42_modelrootaligned_steeredtrue_guidance\d+\.?\d*_length\d+\.?\d*_results100_candidates72_time\d+'
    experiment_regex = r'uspto_190_seed42_modelrootaligned_steeredtrue_guidance\d+\.?\d*_length\d+_results100_candidates53_time\d+'
    #experiment_regex = r'uspto_190_seed42_modelrootaligned_steeredfalse_guidance0_length0_results100_candidates53_time20251027_111323'
    #experiment_regex = r'uspto_190_seed42_modelneuralsym_steeredfalse_guidance0_length0_results100_candidates53_time20251027_164404'
    #experiment_regex = r'uspto_190_50kcheckpoint_seed42_modelrootaligned_steeredtrue_guidance1.5_length15_results100_candidates53_time20251027_172241'
    #experiment_regex = r'uspto_190_50kcheckpoint_seed42_modelrootaligned_steeredfalse_guidance0_length0_results100_candidates53_time20251027_172315'
    experiment_dir = 'experiments/manual_synthesis'
    experiment_group = 'reaction_type'
    #experiment_group = 'no_guidance'
    experiment_filters = {'experiment_regex': experiment_regex}

    # `results` is a mapping; its keys are passed on as experiment names and
    # its values as the per-experiment data.
    results = load_experiment_results(PROJECT_ROOT, experiment_dir, experiment_group, experiment_filters)
    guided_data, guided_experiments = select_best_experiment_per_product(
        list_dfs=results.values(),
        list_experiment_names=results.keys()
    )
    guided_quality_metrics = _calculate_per_experiment_metrics(guided_data)

    # Reference routes used to judge route completion.
    true_routes_path = 'uspto_190/in_json/test_processed.json'
    full_true_routes_path = os.path.join(PROJECT_ROOT, 'data', true_routes_path)
    with open(full_true_routes_path, 'r', encoding='utf-8') as f:
        true_routes = json.load(f)

    route_completion = calculate_route_completion_rates(
        results, true_routes, use_starting_material=True, max_steps=100
    )
    print(route_completion)

    # Look the details up before opening the output file so that a missing key
    # does not leave an empty pickle behind; the context manager guarantees the
    # handle is flushed and closed.
    route_details = route_completion['mixed_param_completion']['route_details']
    route_details_path = os.path.join(
        PROJECT_ROOT,
        'notebooks',
        'route_completion_rootaligned_reaction_type.pkl'
    )
    with open(route_details_path, 'wb') as pickle_file:
        pickle.dump(route_details, pickle_file)

    # TODO: add something to average across unguided experiments vs aggregating guided results
    # Column order for the LaTeX results row; only referenced by the
    # commented-out `format_latex_row` call below at the moment.
    metrics_of_interest = [
        'perc_samples_per_product',
        'percentage_products_with_exact_match',
        'percentage_products_with_class_correct',
        'perc_class_correct_samples_per_product',
        'percentage_products_with_rxn_name_correct',
        'perc_rxn_name_correct_samples_per_product',
        'percentage_products_with_round_trip_correct',
        'perc_round_trip_correct_samples_per_product',
        'avg_topk_1',
        'avg_topk_3',
        'avg_topk_5',
        'avg_topk_100',
        # 'avg_coverage_1',
        # 'avg_coverage_3',
        # 'avg_coverage_5',
        # 'avg_tanimoto_to_starting',
        # 'avg_tanimoto_to_target'
    ]
    guided_quality_metrics = simplify_metrics(guided_quality_metrics)
    #print(format_latex_row('Rsmiles-G', guided_quality_metrics, metrics_of_interest))
    #print(guided_quality_metrics)
    
    # method    & samples_product & prodexact_match & perc_with_class & samples_class   & rxn_name_correc & rxnname_samples & rt_correct      & rt_samples      &  top_1          & top_3           & top_5           & top_100 
    # Rsmiles-G & \makecell{0.68} & \makecell{0.82} & \makecell{0.67} & \makecell{0.07} & \makecell{0.99} & \makecell{0.14} & \makecell{0.91} & \makecell{0.12} & \makecell{0.50} & \makecell{0.67} & \makecell{0.71} & \makecell{0.84} & \makecell{0.69} & \makecell{0.82} & \makecell{0.85} & \makecell{0.59} & \makecell{0.65} \\
    # Rsmiles-G & \makecell{0.70} & \makecell{0.83} & \makecell{0.71} & \makecell{0.07} & \makecell{0.99} & \makecell{0.14} & \makecell{0.91} & \makecell{0.12} & \makecell{0.52} & \makecell{0.68} & \makecell{0.72} & \makecell{0.84} \\
    # Rsmiles   & \makecell{0.60} & \makecell{0.80} & \makecell{0.51} & \makecell{0.06} & \makecell{0.99} & \makecell{0.13} & \makecell{0.90} & \makecell{0.11} & \makecell{0.46} & \makecell{0.63} & \makecell{0.68} & \makecell{0.80} & \makecell{0.66} & \makecell{0.80} & \makecell{0.83} & \makecell{0.59} & \makecell{0.66} \\
    # Neuralsym & \makecell{0.23} & \makecell{0.69} & \makecell{0.27} & \makecell{0.04} & \makecell{1.00} & \makecell{0.14} & \makecell{0.84} & \makecell{0.10} & \makecell{0.35} & \makecell{0.51} & \makecell{0.57} & \makecell{0.69} \\
    
    # RsmilesG72 & \makecell{0.53} & \makecell{0.54} & \makecell{0.65} & \makecell{0.07} & \makecell{0.99} & \makecell{0.19} & \makecell{0.78} & \makecell{0.15} & \makecell{0.30} & \makecell{0.41} & \makecell{0.45} & \makecell{0.54} \\
    # RsmilesG72 & \makecell{0.47} & \makecell{0.52} & \makecell{0.54} & \makecell{0.06} & \makecell{1.00} & \makecell{0.20} & \makecell{0.78} & \makecell{0.15} & \makecell{0.28} & \makecell{0.39} & \makecell{0.43} & \makecell{0.52} \\
    # RsmilesG72 & \makecell{0.59} & \makecell{0.49} & \makecell{0.53} & \makecell{0.07} & \makecell{1.00} & \makecell{0.20} & \makecell{0.77} & \makecell{0.17} & \makecell{0.22} & \makecell{0.34} & \makecell{0.39} & \makecell{0.50} \\
    # Rsmiles5   & \makecell{0.36} & \makecell{0.51} & \makecell{0.42} & \makecell{0.04} & \makecell{1.00} & \makecell{0.18} & \makecell{0.78} & \makecell{0.13} & \makecell{0.26} & \makecell{0.39} & \makecell{0.44} & \makecell{0.51} \\
    # Graph2edit & \makecell{0.39} & \makecell{0.49} & \makecell{0.22} & \makecell{0.07} & \makecell{0.98} & \makecell{0.16} & \makecell{0.77} & \makecell{0.11} & \makecell{0.26} & \makecell{0.37} & \makecell{0.41} & \makecell{0.49} \\
    # Megan      & \makecell{0.52} & \makecell{0.50} & \makecell{0.35} & \makecell{0.05} & \makecell{0.99} & \makecell{0.28} & \makecell{0.76} & \makecell{0.23} & \makecell{0.22} & \makecell{0.33} & \makecell{0.38} & \makecell{0.50} \\
    # Mhnreact   & \makecell{0.45} & \makecell{0.48} & \makecell{0.25} & \makecell{0.07} & \makecell{1.00} & \makecell{0.31} & \makecell{0.76} & \makecell{0.22} & \makecell{0.21} & \makecell{0.35} & \makecell{0.38} & \makecell{0.48} \\
    # 

    # RsmilesG53 & \makecell{0.59} & \makecell{0.49} & \makecell{0.53} & \makecell{0.07} & \makecell{1.00} & \makecell{0.20} & \makecell{0.77} & \makecell{0.17} & \makecell{0.22} & \makecell{0.34} & \makecell{0.39} & \makecell{0.50} \\
    

    # summary = generate_evaluation_summary(
    #     project_root=PROJECT_ROOT,
    #     experiment_dir=experiment_dir,
    #     experiment_group=experiment_group, 
    #     use_starting_material=False,
    #     true_routes_path=true_routes_path,
    #     process_df=process_df,
    #     experiment_filters={'experiment_regex': experiment_regex},
    #     baseline_experiment_group=baseline_experiment_group,
    #     baseline_experiment_filters=baseline_experiment_filters
    # )

    # # TODO: print in latex table
    # metrics = ['avg_num_samples_per_product','avg_class_correct_samples_per_product', 
    #             'avg_rxn_name_correct_samples_per_product', 'avg_round_trip_correct_samples_per_product', 
    #             'avg_best_tanimoto_to_target_per_product', 'avg_best_tanimoto_to_starting_per_product'
    #             ]
    # for metric in metrics:
    #     metric_name = summary['baseline_comparison']['improvement_metrics']['comparison_summary'][metric]['metric_name']
    #     print('-'*100)
    #     print(f"{metric_name}")
    #     print(f"guided value: {summary['baseline_comparison']['improvement_metrics']['comparison_summary'][metric]['guided_value']}")
    #     print(f"baseline value: {summary['baseline_comparison']['improvement_metrics']['comparison_summary'][metric]['baseline_value']}")
    #     print(f"absolute improvement: {summary['baseline_comparison']['improvement_metrics']['comparison_summary'][metric]['absolute_improvement']}")
    #     print(f"improvement ratio: {summary['baseline_comparison']['improvement_metrics']['comparison_summary'][metric]['improvement_ratio']}")
    #     print(f"improvement percentage: {summary['baseline_comparison']['improvement_metrics']['comparison_summary'][metric]['improvement_percentage']}")

# Script entry point: the function is called without arguments because the
# `hydra.main` decorator builds the config and passes it in.
if __name__ == "__main__":
    report_manual_synthesis_results()