import hydra
import os
import pandas as pd
from multiguide.dataset.helpers import class_to_idx

from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import compare_reactant_smiles, \
                                 get_reaction_smiles, compute_class_info, clear_atom_map, \
                                 compute_average_topk_and_coverage

def compute_average_topk_and_coverage(df, topk_key='new_topk', coverage_key='round_trip_accuracy'):
    '''
        Compute per-product top-k exact-match accuracy and round-trip coverage.

        Rows are ranked within each product by the order in which they appear
        in `df` (the first row of a product has rank 1).

        NOTE: this definition shadows the function of the same name imported
        from multiguide.dataset.helpers at the top of this file.

        Args:
            df: frame with a 'product_smi' column plus the `topk_key` and
                `coverage_key` columns; a row counts as a hit when the
                column value equals 1.
            topk_key: column flagging an exact match for the row.
            coverage_key: column flagging a round-trip success for the row.

        Returns:
            (topk, coverage): `topk[k]` is the fraction of products having at
            least one exact match at rank <= k, for k in 1, 3, 5, 10, 50, 100.
            `coverage[k]` is the fraction of products having at least one
            round-trip success at rank <= k, for k in 1, 3, 5, 10.

        Raises:
            ZeroDivisionError: if `df` contains no products.
    '''
    num_products = df['product_smi'].nunique()

    # Copy only the needed columns into a positionally indexed frame so the
    # result does not depend on the index labels of `df`.
    ranked = pd.DataFrame({
        'product': df['product_smi'].to_numpy(),
        'exact': (df[topk_key] == 1).to_numpy(dtype=bool),
        'round_trip': (df[coverage_key] == 1).to_numpy(dtype=bool),
    })
    ranked['rank'] = ranked.groupby('product').cumcount() + 1

    # Use the best (smallest) matching rank per product, so a product with
    # several matching predictions is counted once instead of once per row.
    best_rank = ranked[ranked['exact']].groupby('product')['rank'].min()
    topk = {
        k: int((best_rank <= k).sum()) / num_products
        for k in (1, 3, 5, 10, 50, 100)
    }

    # For round-trip: check if ANY of the first k predictions succeeds per product.
    coverage = {}
    for k in (1, 3, 5, 10):
        success_within_k = ranked['round_trip'] & (ranked['rank'] <= k)
        coverage[k] = success_within_k.groupby(ranked['product']).any().mean()
    return topk, coverage

def not_necessary_with_latest_evaluation_files(df, true_reactions_path):
    '''
        Backfill new_topk, round_trip_accuracy, true_class and pred_class.
        Should not be necessary with latest evaluation files.

        The new_topk and round_trip_accuracy columns are written onto the
        `df` that was passed in; the returned frame is the result of merging
        that frame with the reference reactions read from
        `true_reactions_path`.
    '''
    def _is_exact_match(row):
        return compare_reactant_smiles(row['true_reactants'], row['reactant_predictions'])

    def _is_round_trip_hit(row):
        # Either the product is among the round-trip results, or the
        # prediction matches the true reactants.
        recovered = row['product_smi'] in row['round_trip_results']
        return int(recovered
                   or compare_reactant_smiles(row['reactant_predictions'], row['true_reactants']))

    def _product_of(reaction):
        # Product side of the reaction SMILES with atom mapping removed.
        return clear_atom_map(get_reaction_smiles(reaction).split('>>')[-1])

    def _class_index(info):
        if pd.isna(info):
            return None
        # NOTE(review): eval() executes whatever text is stored in the
        # rxn_insight_info column; only run this on trusted result files.
        return class_to_idx[eval(info)['CLASS']]

    df['new_topk'] = df.apply(_is_exact_match, axis=1)
    df['round_trip_accuracy'] = df.apply(_is_round_trip_hit, axis=1)

    # Assign true classes by joining on the product SMILES.
    reference = pd.read_csv(true_reactions_path)
    reference['product_smi'] = reference['reactants>reagents>production'].apply(_product_of)
    # NOTE(review): a left merge repeats prediction rows when the reference
    # file lists the same product more than once - confirm the reference
    # products are unique.
    merged = df.merge(reference[['product_smi', 'class']], on='product_smi', how='left')
    merged = merged.rename(columns={'class': 'true_class'})

    # Predicted class index taken from the rxn_insight_info column.
    merged['pred_class'] = merged['rxn_insight_info'].apply(_class_index)
    return merged

def read_batches_from_experiment(experiment_name, start_batch=None, end_batch=None):
    '''
        Read the batch CSV files of an experiment into a single dataframe.

        Files are taken from PROJECT_ROOT/<experiment_name> and ordered by the
        integer found between '_start' and '_end' in each file name.
        `start_batch` and `end_batch` are positions in that ordered list;
        `end_batch` is inclusive. With both left as None every file is read.
    '''
    def _batch_start(filename):
        return int(filename.split('_start')[1].split('_end')[0])

    csv_dir = os.path.join(PROJECT_ROOT, experiment_name)
    csvs = [entry for entry in os.listdir(csv_dir) if entry.endswith('.csv')]
    csvs.sort(key=_batch_start)
    print(f'====== {len(csvs)} files found ======')

    first = 0 if start_batch is None else start_batch
    last = len(csvs) + 1 if end_batch is None else end_batch

    # Build the full paths of the selected batches, then concatenate them.
    selected_paths = []
    for filename in csvs[first:last + 1]:
        selected_paths.append(os.path.join(csv_dir, filename))
    print(f'====== processing {len(selected_paths)} files ======')

    frames = [pd.read_csv(path) for path in selected_paths]
    return pd.concat(frames)

def load_data(experiment_name, reprocess=False):
    '''
        Load every batch of an experiment, optionally recomputing the
        derived columns, and print the number of distinct
        (product_smi, true_reactants) pairs.
    '''
    df = read_batches_from_experiment(experiment_name)

    if reprocess:
        # Reference reactions used to look up the true class of each product.
        true_reactions_path = os.path.join(
            PROJECT_ROOT, 'data', 'uspto_50k', 'raw',
            'test_rxninsightclass_no_overlap.csv',
        )
        df = not_necessary_with_latest_evaluation_files(df, true_reactions_path=true_reactions_path)

    num_pairs = len(df.groupby(['product_smi', 'true_reactants']).size())
    print(f"# of unique product and reactant pairs: {num_pairs}")
    return df

def safe_groupby_accuracy(df, col1, col2):
    result = df.groupby('product_smi').apply(
        lambda x: (x[col1].values == x[col2].values).mean(),
        include_groups=False
    )
    return result.mean()

def compute_metrics_for_dataset(df, topk_key='topk', coverage_key='round_trip_coverage'):
    '''
        Collect the evaluation metrics of one result dataframe into a dict.

        Keys: 'topk' and 'coverage' (from compute_average_topk_and_coverage),
        'cls_to_true' (return value of compute_class_info), and the
        per-product accuracies 'prod_cls_acc', 'prod_rxn_acc' and
        'prod_cls_to_rxn_acc'. The two classifier-based accuracies are None
        when `df` has no 'classifier_output' column.
    '''
    # Core performance metrics
    topk, coverage = compute_average_topk_and_coverage(
        df, topk_key=topk_key, coverage_key=coverage_key
    )

    # Class matching accuracies (overall). Only the value stored under
    # 'cls_to_true' is currently kept from compute_class_info.
    cls_to_true = compute_class_info(df)

    # Class matching accuracies (per product)
    prod_cls_acc = None
    prod_cls_to_rxn_acc = None
    if 'classifier_output' in df.columns:
        prod_cls_acc = safe_groupby_accuracy(df, 'classifier_output', 'true_class')
        # NOTE(review): requires an 'rxn_insight_class' column in df - confirm
        # the evaluation files provide it.
        prod_cls_to_rxn_acc = safe_groupby_accuracy(df, 'classifier_output', 'rxn_insight_class')
    prod_rxn_acc = safe_groupby_accuracy(df, 'pred_class', 'true_class')

    metrics = {}
    metrics['topk'] = topk
    metrics['coverage'] = coverage
    metrics['cls_to_true'] = cls_to_true
    metrics['prod_cls_acc'] = prod_cls_acc
    metrics['prod_rxn_acc'] = prod_rxn_acc
    metrics['prod_cls_to_rxn_acc'] = prod_cls_to_rxn_acc
    return metrics

def find_guided_improvements(unguided_df, guided_df, unguided_key='round_trip_accuracy', guided_key='topk_1'):
    """
    Find products where guided model finds solutions but unguided doesn't.

    A product counts as solved in a frame when any of its rows has a truthy
    value in the relevant key column.

    Args:
        unguided_df: Unguided results; needs 'product_smi' and `unguided_key`.
        guided_df: Guided results; needs 'product_smi' and `guided_key`.
        unguided_key: Column to check for valid solutions in unguided data
        guided_key: Column to check for valid solutions in guided data

    Returns:
        Dict with 'products_no_unguided_solution' (number of products with no
        unguided solution), 'products_with_guided_solution' (how many of
        those the guided data solves) and 'improvement_rate' (their ratio,
        or 0 when every product already has an unguided solution).
    """
    # Products with no solutions in unguided. The per-product flag is computed
    # vectorised rather than by applying `~` to a scalar `.any()` result,
    # which is only a logical NOT for numpy booleans.
    unguided_solved = unguided_df.groupby('product_smi')[unguided_key].any()
    no_solution_products = unguided_solved.index[~unguided_solved.to_numpy(dtype=bool)]
    num_unsolved = len(no_solution_products)
    if num_unsolved == 0:
        return {
            'products_no_unguided_solution': 0,
            'products_with_guided_solution': 0,
            'improvement_rate': 0,
        }

    # Of those, which have solutions in guided
    candidates = guided_df[guided_df['product_smi'].isin(no_solution_products)]
    guided_solved = candidates.groupby('product_smi')[guided_key].any()
    num_rescued = int(guided_solved.sum())

    return {
        'products_no_unguided_solution': num_unsolved,
        'products_with_guided_solution': num_rescued,
        'improvement_rate': num_rescued / num_unsolved,
    }

@hydra.main(config_path='../configs', config_name='config.yaml')
def main(config):
    '''
        Print the metrics of one unguided and one guided experiment, followed
        by the guided-vs-unguided improvement statistics for every
        combination of the 'round_trip_coverage' and 'topk' columns.

        The experiments are hard-coded below; `config` is not read.
    '''
    # ---- unguided baseline ----
    experiment_group = 'no_guidance'
    experiment_name = 'rsmiles'
    unguided_df = load_data(
        experiment_name=f'experiments/single_step_50k/{experiment_group}/{experiment_name}',
        reprocess=False,
    )
    # Original author's note: for guided 50k the coverage key is round_trip_accuracy_10.
    metrics_unguided = compute_metrics_for_dataset(
        unguided_df, topk_key='topk', coverage_key='round_trip_coverage'
    )
    print(f'Unguided experiment: {experiment_group}/{experiment_name}')
    print(f'metrics_unguided: {metrics_unguided}')

    # ---- guided run ----
    # Alternative 'no_guidance' experiment names previously listed here:
    #   modelretroknn_datasetschneider50k_20250827_122836
    #   modelgraph2edits_datasetschneider50k_20250828_105541
    experiment_group = 'reaction_type'
    experiment_name = '50k_steeredtrue_guidance1.5_length10_results100_time20250814_100438'
    guided_df = load_data(
        experiment_name=f'experiments/single_step_50k/{experiment_group}/{experiment_name}',
        reprocess=False,
    )
    metrics_guided = compute_metrics_for_dataset(
        guided_df, topk_key='topk', coverage_key='round_trip_coverage'
    )
    print(f'Guided experiment: {experiment_group}/{experiment_name}')
    print(f'metrics_guided: {metrics_guided}')

    # ---- guided vs. unguided, one report per (unguided key, guided key) ----
    key_pairs = (
        ('round_trip_coverage', 'topk'),
        ('round_trip_coverage', 'round_trip_coverage'),
        ('topk', 'round_trip_coverage'),
        ('topk', 'topk'),
    )
    for unguided_key, guided_key in key_pairs:
        guided_improvements = find_guided_improvements(
            unguided_df, guided_df, unguided_key=unguided_key, guided_key=guided_key
        )
        print(f'guided_improvements, unguided key: {unguided_key}, guided key: {guided_key}: {guided_improvements}')

    # An earlier, commented-out variant of this script loaded '50k_steeredFalse'
    # and the guided experiment name directly, and used the column names
    # 'new_topk' / 'topk_1' as top-k keys and 'round_trip_accuracy' /
    # 'round_trip_accuracy_10' as round-trip keys.

# Run the comparison only when this file is executed as a script.
if __name__ == '__main__':
    main()