'''
    Report key metrics in single-step evaluation. Calculates averages over the entire dataset, per product, and additional stats.
'''
import os
import hydra
import pandas as pd
import numpy as np
from typing import Dict, Any, List, Optional

from multiguide.helpers import PROJECT_ROOT
from multiguide.evaluation.helpers import _calculate_per_experiment_metrics, analyze_ground_truth_coverage_across_experiments
from multiguide.evaluation.helpers import load_single_step_results
from multiguide.dataset.helpers import compare_reactant_smiles

def get_unsolved_products(df, key='topk'):
    '''
        Return the products that have no truthy value in column `key`.

        Rows are grouped by 'product_smi'; a product counts as unsolved when
        none of its rows is truthy in `key`. The result is a plain list of
        product identifiers.
    '''
    # One boolean per product: True if at least one of its rows is truthy.
    solved_per_product = df.groupby('product_smi')[key].any()
    unsolved = solved_per_product.loc[~solved_per_product]
    return unsolved.index.unique().tolist()

def get_unsolved_products_across_experiments(unguided_experiments):
    '''
        Get the unsolved products across the unguided experiments.

        Each experiment is loaded from
        <PROJECT_ROOT>/experiments/single_step_50k/no_guidance/<experiment_name>.

        Args:
            unguided_experiments: iterable of experiment directory names.

        Returns:
            A tuple (common_unsolved_products, unguided_dfs):
              - common_unsolved_products: dict with keys 'topk' and
                'round_trip_accuracy', each mapping to the set of products
                that are unsolved (for that key) in every experiment. Both
                sets are empty when no experiments are given.
              - unguided_dfs: the loaded result frames, in input order.
    '''
    experiment_group_dir = os.path.join(
        PROJECT_ROOT,
        'experiments',
        'single_step_50k',
        'no_guidance'
    )
    keys = ('topk', 'round_trip_accuracy')
    unguided_dfs = []
    # Per key: one set of unsolved products per experiment.
    all_unsolved_products = {key: [] for key in keys}
    for experiment_name in unguided_experiments:
        experiment_dir = os.path.join(experiment_group_dir, experiment_name)
        experiment_df = load_single_step_results(experiment_dir)
        unguided_dfs.append(experiment_df)
        for key in keys:
            unsolved_products_experiment = get_unsolved_products(experiment_df, key=key)
            all_unsolved_products[key].append(set(unsolved_products_experiment))
    # set.intersection() needs at least one argument, so an empty experiment
    # list falls back to an empty set instead of raising TypeError.
    common_unsolved_products = {
        key: set.intersection(*unsolved_sets) if unsolved_sets else set()
        for key, unsolved_sets in all_unsolved_products.items()
    }
    return common_unsolved_products, unguided_dfs

def get_dfs_recovered_by_single_guided_experiments(
    common_unsolved_products, 
    list_experiment_names,
    group_dir,
    guided_key='topk',
    unguided_key='topk'
):
    '''
        Get the dfs recovered by the single guided experiments.

        For every guided experiment, select the rows of the products that
        (a) were unsolved by all unguided experiments according to
        `unguided_key`, (b) are solved by this experiment according to
        `guided_key`, and (c) were not already credited to an earlier
        experiment in `list_experiment_names`.

        Args:
            common_unsolved_products: dict mapping a metric key to the
                collection of products unsolved for that key.
            list_experiment_names: experiment directory names under `group_dir`;
                order matters, because a product is credited to the first
                experiment that solves it.
            group_dir: directory containing the guided experiments.
            guided_key: boolean column deciding whether a guided sample solves
                its product.
            unguided_key: key of `common_unsolved_products` to recover from.

        Returns:
            One dataframe per experiment (same order as the input names), each
            holding all rows of the products newly recovered by that experiment.
    '''
    products_no_match = common_unsolved_products[unguided_key]
    # A set gives O(1) "already credited?" checks; a list made this quadratic.
    credited_products = set()
    all_experiment_dfs = []
    for experiment_name in list_experiment_names:
        experiment_dir = os.path.join(group_dir, experiment_name)
        guided_df = load_single_step_results(experiment_dir)
        products_solved_df = guided_df.groupby('product_smi')[guided_key].any().reset_index()
        recovers_unsolved = (
            products_solved_df[guided_key]
            & products_solved_df['product_smi'].isin(products_no_match)
        )
        # Keep only the products that no earlier experiment has recovered.
        newly_solved = [
            p for p in products_solved_df.loc[recovers_unsolved, 'product_smi'].tolist()
            if p not in credited_products
        ]
        credited_products.update(newly_solved)
        all_experiment_dfs.append(guided_df[guided_df['product_smi'].isin(newly_solved)])
    return all_experiment_dfs

def jaccard_similarity(pred1, pred2):
    '''
        Jaccard similarity |A ∩ B| / |A ∪ B| between two collections.

        Both inputs are reduced to sets, so duplicates and ordering are
        ignored; their elements must be hashable. Two empty collections are
        treated as identical and yield 1.0 instead of dividing by zero.
    '''
    first, second = set(pred1), set(pred2)
    union = first | second
    if not union:
        # Both inputs are empty: conventionally J(∅, ∅) = 1.
        return 1.0
    return len(first & second) / len(union)

@hydra.main(config_path='../configs', config_name='config.yaml')
def report_single_step_evaluation(config):
    '''
        Report key metrics in single-step evaluation. Calculates averages over the entire dataset, 
        per product, and additional stats.

        Workflow:
          1. Load the unguided runs and find the products none of them solves.
          2. Load the guided runs and keep the rows of the products they recover.
          3. Compute metrics on the recovered products for the guided runs and
             for each unguided run restricted to the same products.
          4. Print guided values next to the mean ± std over the unguided runs,
             plus the Jaccard similarity between guided and unguided
             predictions per product.

        Args:
            config: hydra config object; it is not read by this function, the
                experiment names below are hard-coded.
    '''
    # Alternative unguided experiment sets (one per model); enable exactly one.
    # unguided_experiments = [
    #     '50k_seed42_rootaligned_steeredfalse_guidance0_length0_results100_candidates72_time20251021_224300',
    #     '50k_seed90_rootaligned_steeredfalse_guidance0.0_length0_results100_candidates72_time20251022_151349',
    #     '50k_seedrandom_rootaligned_steeredfalse_guidance0_length0_results100_candidates72_time20251021_001004'
    # ]
    # unguided_experiments = [
    #     '50k_seed42_modelneuralsym_steeredfalse_guidance0_length0_results100_candidates72_time20251023_122629',
    #     '50k_seed90_modelneuralsym_steeredfalse_guidance0_length0_results100_candidates72_time20251023_014219',
    #     '50k_seed101_modelneuralsym_steeredfalse_guidance0_length0_results100_candidates72_time20251023_122554'
    # ]
    # unguided_experiments = [
    #     '50k_seed42_modelchemformer_steeredfalse_guidance0_length0_results100_candidates72_time20251023_171624',
    #     '50k_seed90_modelchemformer_steeredfalse_guidance0_length0_results100_candidates72_time20251023_174108',
    #     '50k_seed101_modelchemformer_steeredfalse_guidance0_length0_results100_candidates72_time20251023_174131'
    # ]
    unguided_experiments = [
        '50k_seed42_modelmegan_steeredfalse_guidance0_length0_results100_candidates72_time20251023_140013',
        # NOTE(review): this run is named 'results1' while its siblings use
        # 'results100' - confirm this is the intended directory.
        '50k_seed90_modelmegan_steeredfalse_guidance0_length0_results1_candidates72_time20251023_183946',
        '50k_seed101_modelmegan_steeredfalse_guidance0_length0_results100_candidates72_time20251024_000601'
    ]
    # unguided_experiments = [
    #     '50k_seed42_modelchemformer_steeredfalse_guidance0_length0_results100_candidates72_time20251023_171624',
    #     '50k_seed42_modelneuralsym_steeredfalse_guidance0_length0_results100_candidates72_time20251023_122629',
    #     '50k_seed42_rootaligned_steeredfalse_guidance0_length0_results100_candidates72_time20251021_224300',
    #     '50k_seed42_modelgraph2edits_steeredfalse_guidance0_length0_results100_candidates72_time20251023_135250',
    #     '50k_seed42_modelmegan_steeredfalse_guidance0_length0_results100_candidates72_time20251023_140013',
    #     '50k_seed42_modelmhnreact_steeredfalse_guidance0_length0_results100_candidates72_time20251023_162655'
    # ]
    common_unsolved_products, unguided_dfs = get_unsolved_products_across_experiments(unguided_experiments)
    # Alternative guided experiment set (guidance scale / length sweep).
    # guided_experiments = [
    #     '50k_steeredtrue_guidance0.5_length15_results100_candidates72_time20251021_185901',
    #     '50k_steeredtrue_guidance1.0_length5_results100_candidates72_time20251021_224507',
    #     # '50k_steeredtrue_guidance1.0_length15_results100_candidates72_time20251021_100057',
    #     '50k_steeredtrue_guidance1.0_length15_results100_candidates72_time20251021_161934',
    #     '50k_steeredtrue_guidance1.5_length10_results100_candidates72_time20251020_192941',
    #     '50k_steeredtrue_guidance1.5_length15_results100_candidates72_time20251021_134411',
    #     '50k_steeredtrue_guidance2.0_length15_results100_candidates72_time20251021_164034',
    #     '50k_seed90_steeredtrue_guidance0.5_length5_results100_candidates72_time20251022_193804',
    #     '50k_seed90_modelrootaligned_steeredtrue_guidance0.7_length7_results100_candidates72_time20251023_015026'
    # ]
    guided_experiments = [
        #'50k_seed42_modelchemformer_steeredfalse_guidance0_length0_results100_candidates72_time20251023_171624'
        '50k_steeredtrue_guidance0.5_length15_results100_candidates72_time20251021_185901'
    ]
    guided_group_dir = os.path.join(
        PROJECT_ROOT,
        'experiments', 
        'single_step_50k',
        'reaction_type'
    )
    # Columns that decide "solved" for the guided and the unguided runs.
    guided_key='topk'
    unguided_key='topk'
    guided_experiment_dfs = get_dfs_recovered_by_single_guided_experiments(
        common_unsolved_products,
        guided_experiments,
        guided_group_dir,
        guided_key,
        unguided_key
    )
    # Sample metrics dict kept for reference (appears to be the output of
    # _calculate_per_experiment_metrics from an earlier run - TODO confirm):
    # {'total_samples': 5142,
    # 'total_products': 272,
    # 'avg_samples_per_product': 18.904411764705884,
    # 'sample_exact_match_accuracy': 0.0,
    # 'products_with_exact_match': 0,
    # 'percentage_products_with_exact_match': 0.0,
    # 'sample_class_accuracy': 0.25009723842862697,
    # 'products_with_class_correct_samples': 224,
    # 'percentage_products_with_class_correct': 0.8235294117647058,
    # 'avg_class_correct_samples_per_product': 5.741071428571429,
    # 'sample_rxn_name_accuracy': 0.6353558926487748,
    # 'products_with_rxn_name_correct_samples': 269,
    # 'percentage_products_with_rxn_name_correct': 0.9889705882352942,
    # 'avg_rxn_name_correct_samples_per_product': 12.144981412639405,
    # 'sample_round_trip_accuracy': 0.4634383508362505,
    # 'products_with_round_trip_correct_samples': 172,
    # 'percentage_products_with_round_trip_correct': 0.6323529411764706,
    # 'avg_round_trip_correct_samples_per_product': 13.854651162790697,
    # 'avg_tanimoto_to_starting': 0.636574532084749,
    # 'max_tanimoto_to_starting': 1.0,
    # 'avg_tanimoto_to_target': 0.8192727116602525,
    # 'max_tanimoto_to_target': 1.0,
    # 'avg_topk': {1: 0.0, 3: 0.0, 5: 0.0, 10: 0.0, 50: 0.0, 100: 0.0},
    # 'avg_coverage': {1: 0.5183823529411765,
    # 3: 0.6176470588235294,
    # 5: 0.6323529411764706,
    # 10: 0.6323529411764706}}
    metrics_of_interest = [
        'total_products',
        'unsolved_products',
        'avg_samples_per_product', 
        'avg_class_correct_samples_per_product',
        'avg_rxn_name_correct_samples_per_product',
        'avg_round_trip_correct_samples_per_product',
        #'avg_tanimoto_to_starting',
        #'avg_tanimoto_to_target'
    ]
    guided_recovered_df = pd.concat(guided_experiment_dfs)
    guided_metrics = _calculate_per_experiment_metrics(guided_recovered_df)
    # The guided frame only contains recovered products, so none is unsolved.
    guided_metrics['unsolved_products'] = 0
    # Both lookups depend only on the guided frame, so build them once
    # instead of once per unguided run / per product.
    recovered_products = guided_recovered_df['product_smi'].unique().tolist()
    guided_preds_by_product = {
        product_smi: preds.tolist()
        for product_smi, preds in guided_recovered_df.groupby('product_smi', sort=False)['reactant_predictions']
    }
    all_unguided_metrics = {}
    jaccard_scores = []
    for unguided_df in unguided_dfs:
        # Restrict the unguided run to the products the guided runs recovered.
        recovered_unguided_df = unguided_df[unguided_df['product_smi'].isin(recovered_products)]
        unguided_metrics = _calculate_per_experiment_metrics(recovered_unguided_df)
        unguided_metrics['unsolved_products'] = len(common_unsolved_products[unguided_key])
        # Overlap between guided and unguided predictions, one score per
        # product; scores are pooled across all unguided runs.
        unguided_groups = recovered_unguided_df.groupby('product_smi', sort=False)['reactant_predictions']
        for product_smi, unguided_preds in unguided_groups:
            guided_preds = guided_preds_by_product.get(product_smi, [])
            jaccard_scores.append(jaccard_similarity(unguided_preds.tolist(), guided_preds))
        for metric in metrics_of_interest:
            all_unguided_metrics.setdefault(metric, []).append(unguided_metrics[metric])
    # Mean and std of every metric of interest over the unguided runs.
    averaged_unguided_metrics = {}
    averaged_unguided_metrics['jaccard_similarity'] = np.mean(jaccard_scores)
    averaged_unguided_metrics['jaccard_similarity_std'] = np.std(jaccard_scores)
    for metric in metrics_of_interest:
        averaged_unguided_metrics[metric] = np.mean(all_unguided_metrics[metric])
        averaged_unguided_metrics[metric + '_std'] = np.std(all_unguided_metrics[metric])
    # print guided vs unguided metrics of interest
    for metric in metrics_of_interest:
        print(f"{metric}: guided {guided_metrics[metric]}, unguided {averaged_unguided_metrics[metric]} ± {averaged_unguided_metrics[metric + '_std']}")
    print(f"jaccard_similarity: {averaged_unguided_metrics['jaccard_similarity']} ± {averaged_unguided_metrics['jaccard_similarity_std']}")
    print('done')
    # list_dfs = []
    # list_experiment_names = [
    #     '50k_steeredtrue_guidance0.5_length15_results100_candidates72_time20251021_185901',
    #     '50k_steeredtrue_guidance1.0_length5_results100_candidates72_time20251021_224507',
    #     # '50k_steeredtrue_guidance1.0_length15_results100_candidates72_time20251021_100057',
    #     '50k_steeredtrue_guidance1.0_length15_results100_candidates72_time20251021_161934',
    #     '50k_steeredtrue_guidance1.5_length10_results100_candidates72_time20251020_192941',
    #     '50k_steeredtrue_guidance1.5_length15_results100_candidates72_time20251021_134411',
    #     '50k_steeredtrue_guidance2.0_length15_results100_candidates72_time20251021_164034'
    # ]
    # for experiment_name in list_experiment_names:
    #     experiment_dir = os.path.join(
    #         PROJECT_ROOT,
    #         'experiments', 
    #         'single_step_50k',
    #         'reaction_type',
    #         experiment_name
    #     )
    #     df = load_single_step_results(experiment_dir)
    #     list_dfs.append(df)
    # results = analyze_ground_truth_coverage_across_experiments(
    #     experiment_dfs=list_dfs,
    #     experiment_names=list_experiment_names
    # )
    # print(results)
    # load the results
    # experiment_dir = os.path.join(
    #     PROJECT_ROOT,
    #     'experiments', 
    #     'single_step_50k',
    #     'reaction_type',
    #     '50k_steeredtrue_guidance0.5_length15_results100_candidates72_time20251021_185901'
    # )
    # guided_df = load_single_step_results(experiment_dir)
    # not_guided_experiment_dir = os.path.join(
    #     PROJECT_ROOT,
    #     'experiments', 
    #     'single_step_50k',
    #     'no_guidance',
    #     '50k_steeredfalse_guidance0_length0_results100_candidates72_time20251021_001004'
    # )
    # not_guided_df = load_single_step_results(not_guided_experiment_dir)
    # comparison_results = compare_experiment_pair(
    #     df1=not_guided_df,
    #     df2=guided_df,
    #     exp1_name='no_guidance',
    #     exp2_name='guided_0.5_15',
    #     key1='topk',
    #     key2='round_trip_accuracy'
    # )
    # print(comparison_results)
    # compute metrics
    # df['round_trip_accuracy'] = df.apply(lambda x: x['product_smi'] in x['round_trip_results'], axis=1)
    # df['topk'] = df.apply(lambda x: compare_reactant_smiles(x['true_reactants'], x['reactant_predictions']), axis=1)
    # metrics = _calculate_per_experiment_metrics(df)
    # print(metrics)

if __name__ == "__main__":
    report_single_step_evaluation()
