import hydra
import os
import pandas as pd

from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import compare_reactant_smiles, get_reaction_smiles, clear_atom_map, class_to_idx


def compute_class_info(df):
    """
    Compare the reaction-class columns of ``df`` row by row.

    Args:
        df: DataFrame with 'rxn_insight_class' and 'true_class' columns and,
            optionally, a 'classifier_output' column.

    Returns:
        A 3-tuple of row-wise agreement rates (fraction of rows where the two
        columns are equal), in this order:
            classifier_output vs rxn_insight_class,
            classifier_output vs true_class,
            rxn_insight_class vs true_class.
        The first two entries are None when 'classifier_output' is absent.
    """
    rxn_insight = df['rxn_insight_class']
    true_class = df['true_class']

    # Classifier-based rates only exist when the classifier column is present.
    cls_vs_rxn_insight = None
    cls_vs_true = None
    if 'classifier_output' in df.columns:
        predicted = df['classifier_output']
        cls_vs_rxn_insight = (predicted == rxn_insight).mean()
        cls_vs_true = (predicted == true_class).mean()

    rxn_insight_vs_true = (rxn_insight == true_class).mean()
    return cls_vs_rxn_insight, cls_vs_true, rxn_insight_vs_true


def compute_average_topk_and_coverage(df, topk_key='new_topk', coverage_key='round_trip_accuracy'):
    """
    Compute top-k exact-match rates and top-k coverage per product.

    Rows are ranked within each 'product_smi' group by their order of
    appearance in ``df`` (the first row of a product has rank 1).

    Args:
        df: DataFrame with a 'product_smi' column plus the two flag columns.
        topk_key: Column whose value equals 1 on rows counted as exact matches.
        coverage_key: Column whose value equals 1 on rows counted as
            successful for coverage. If this column is missing but
            'is_correct' is present, 'is_correct' is used instead and a
            warning is emitted, so frames that only carry 'is_correct' keep
            working.

    Returns:
        A ``(topk, coverage)`` pair of dicts keyed by k:
            topk[k]: number of exact-match rows with rank <= k divided by the
                number of unique products, for k in 1, 3, 5, 10, 50, 100.
            coverage[k]: fraction of products with at least one successful
                row among their first k rows, for k in 1, 3, 5, 10.
        Every value is 0 when ``df`` contains no products.

    Raises:
        KeyError: If 'product_smi', ``topk_key`` or the resolved coverage
            column is missing from ``df``.
    """
    import warnings

    topk_cutoffs = (1, 3, 5, 10, 50, 100)
    coverage_cutoffs = (1, 3, 5, 10)

    if coverage_key not in df.columns and 'is_correct' in df.columns:
        warnings.warn(
            f"coverage column {coverage_key!r} not found; falling back to 'is_correct'",
            stacklevel=2,
        )
        coverage_key = 'is_correct'

    num_products = df['product_smi'].nunique()
    if num_products == 0:
        # Nothing to rank; avoid dividing by zero below.
        return {k: 0 for k in topk_cutoffs}, {k: 0 for k in coverage_cutoffs}

    # Zero-based position of each row inside its product group. Plain arrays
    # are used from here on so duplicated index labels cannot cause alignment.
    rank = df.groupby('product_smi').cumcount().to_numpy()
    product_keys = df['product_smi'].to_numpy()
    is_match = (df[topk_key] == 1).to_numpy(dtype=bool, na_value=False)
    is_success = (df[coverage_key] == 1).to_numpy(dtype=bool, na_value=False)

    topk = {
        k: int((is_match & (rank < k)).sum()) / num_products
        for k in topk_cutoffs
    }

    coverage = {}
    for k in coverage_cutoffs:
        # A product is covered at k if ANY of its first k rows succeeded.
        hit_in_top_k = pd.Series(is_success & (rank < k))
        coverage[k] = hit_in_top_k.groupby(product_keys).any().mean()
    return topk, coverage

def not_necessary_with_latest_evaluation_files(df, true_reactions_path):
    """
    Back-fill columns that older evaluation files lack.

    Adds 'new_topk' and 'round_trip_accuracy' to ``df`` in place, then returns
    a new frame that is left-merged with the reference reaction classes
    ('true_class') and also carries 'rxn_insight_class_str' and
    'rxn_insight_class'. Should not be necessary with latest evaluation files.

    Args:
        df: Predictions with 'true_reactants', 'reactant_predictions',
            'product_smi', 'round_trip_results' and 'rxn_insight_CLASS'
            columns.
        true_reactions_path: CSV with 'reactants>reagents>production' and
            'class' columns.

    Returns:
        The merged DataFrame with the additional columns.
    """
    def _round_trip_hit(row):
        # NOTE(review): `in` is a substring test if 'round_trip_results'
        # holds strings rather than lists -- confirm the column's format.
        recovered = row['product_smi'] in row['round_trip_results']
        return int(recovered or compare_reactant_smiles(row['reactant_predictions'],
                                                        row['true_reactants']))

    def _product_from_reaction(rxn):
        # Product side is whatever follows the last '>>', with atom maps cleared.
        return clear_atom_map(get_reaction_smiles(rxn).split('>>')[-1])

    df['new_topk'] = (df['true_reactants'] == df['reactant_predictions']).astype(int)
    df['round_trip_accuracy'] = df.apply(_round_trip_hit, axis=1)

    # Attach the reference class of every product.
    reference = pd.read_csv(true_reactions_path)
    reference['product_smi'] = reference['reactants>reagents>production'].apply(_product_from_reaction)

    # NOTE(review): if the reference holds several rows per product, this left
    # merge repeats prediction rows -- confirm products are unique there.
    merged = df.merge(reference[['product_smi', 'class']], on='product_smi', how='left')
    merged = merged.rename(columns={'class': 'true_class'})

    merged['rxn_insight_class_str'] = merged['rxn_insight_CLASS']
    merged['rxn_insight_class'] = merged['rxn_insight_class_str'].apply(lambda name: class_to_idx[name])
    return merged

def read_batches_from_experiment(experiment_dir, start_batch=None, end_batch=None):
    """
    Read the batch CSVs of an experiment and concatenate them.

    Files ending in '.csv' are ordered by the integer found between '_start'
    and '_end' in their names.

    Args:
        experiment_dir: Directory containing the batch CSV files.
        start_batch: Position (in the sorted file list) of the first file to
            read; None means the first file.
        end_batch: Position of the last file to read, inclusive; None means
            through the last file.

    Returns:
        The concatenation of the selected files. Each file's own row index is
        kept, so index labels repeat across batches.
    """
    def _batch_start(filename):
        return int(filename.split('_start')[1].split('_end')[0])

    csv_names = [name for name in os.listdir(experiment_dir) if name.endswith('.csv')]
    csv_names.sort(key=_batch_start)
    print(f'====== {len(csv_names)} files found ======')

    first = 0 if start_batch is None else start_batch
    last = len(csv_names) + 1 if end_batch is None else end_batch

    # `last` is inclusive, hence the +1 on the slice bound.
    selected_paths = [os.path.join(experiment_dir, name) for name in csv_names[first:last + 1]]
    print(f'====== processing {len(selected_paths)} files ======')

    frames = [pd.read_csv(path) for path in selected_paths]
    return pd.concat(frames)

def load_data(experiment_dir):
    """
    Load every batch of an experiment into a single DataFrame.

    When the directory path contains '50k_steeredFalse', the frame is passed
    through ``not_necessary_with_latest_evaluation_files`` to add the columns
    those result files lack. The number of distinct
    (product_smi, true_reactants) pairs is printed before returning.

    Args:
        experiment_dir: Directory holding the experiment's batch CSV files.

    Returns:
        The loaded (and possibly augmented) DataFrame.
    """
    frame = read_batches_from_experiment(experiment_dir)

    needs_backfill = '50k_steeredFalse' in experiment_dir
    if needs_backfill:
        reference_csv = os.path.join(PROJECT_ROOT, 'data', 'uspto_50k', 'raw',
                                     'test_rxninsightclass_no_overlap.csv')
        frame = not_necessary_with_latest_evaluation_files(frame, true_reactions_path=reference_csv)

    pair_count = frame.groupby(['product_smi', 'true_reactants']).size().shape[0]
    print(f"# of unique product and reactant pairs: {pair_count}")
    return frame

def safe_groupby_accuracy(df, col1, col2):
    result = df.groupby('product_smi').apply(
        lambda x: (x[col1].values == x[col2].values).mean(),
        include_groups=False
    )
    return result.mean()

def compute_metrics_for_dataset(df, topk_key='new_topk', coverage_key='round_trip_accuracy_10'):
    """
    Collect the evaluation metrics of one experiment into a dict.

    Args:
        df: Experiment results, as returned by ``load_data``.
        topk_key: Forwarded to ``compute_average_topk_and_coverage``.
        coverage_key: Forwarded to ``compute_average_topk_and_coverage``.

    Returns:
        Dict with these keys:
            'topk', 'coverage': dicts from ``compute_average_topk_and_coverage``.
            'cls_to_rxn', 'cls_to_true', 'rxn_to_true': row-level class
                agreement rates from ``compute_class_info``.
            'prod_*_acc': the same three comparisons averaged per product via
                ``safe_groupby_accuracy``.
        Entries involving 'classifier_output' are None when ``df`` has no such
        column.
    """
    # Core performance metrics
    topk, coverage = compute_average_topk_and_coverage(df, topk_key=topk_key, coverage_key=coverage_key)

    # Class agreement over all rows
    cls_to_rxn, cls_to_true, rxn_to_true = compute_class_info(df)

    # Class agreement averaged per product
    per_product_cls_true = None
    per_product_cls_rxn = None
    if 'classifier_output' in df.columns:
        per_product_cls_true = safe_groupby_accuracy(df, 'classifier_output', 'true_class')
        per_product_cls_rxn = safe_groupby_accuracy(df, 'classifier_output', 'rxn_insight_class')
    per_product_rxn_true = safe_groupby_accuracy(df, 'rxn_insight_class', 'true_class')

    metrics = {'topk': topk, 'coverage': coverage}
    metrics['cls_to_rxn'] = cls_to_rxn
    metrics['cls_to_true'] = cls_to_true
    metrics['rxn_to_true'] = rxn_to_true
    metrics['prod_classifier_output_to_true_class_acc'] = per_product_cls_true
    metrics['prod_rxn_insight_class_to_true_class_acc'] = per_product_rxn_true
    metrics['prod_classifier_output_to_rxn_insight_class_acc'] = per_product_cls_rxn
    return metrics

def find_guided_improvements(unguided_df, guided_df, unguided_key='round_trip_accuracy', guided_key='topk_1'):
    """
    Find products where guided model finds solutions but unguided doesn't.

    A product "has a solution" in a frame when at least one of its rows holds
    a truthy value in the chosen column.

    Args:
        unguided_df: Unguided results with 'product_smi' and ``unguided_key``.
        guided_df: Guided results with 'product_smi' and ``guided_key``.
        unguided_key: Column to check for valid solutions in unguided data
        guided_key: Column to check for valid solutions in guided data

    Returns:
        Dict with:
            'products_no_unguided_solution': number of products without any
                solution in ``unguided_df``.
            'products_with_guided_solution': how many of those products have
                a solution in ``guided_df``.
            'improvement_rate': ratio of the two counts, or 0 when every
                unguided product already has a solution.
    """
    # Per-product "any solution" flag. A grouped reduction avoids applying
    # bitwise `~` to a scalar, which is only a logical NOT for NumPy bools.
    unguided_solved = unguided_df.groupby('product_smi')[unguided_key].any().astype(bool)
    no_solution_products = unguided_solved.index[~unguided_solved.to_numpy()]

    # Of those, which have solutions in guided
    candidates = guided_df[guided_df['product_smi'].isin(no_solution_products)]
    guided_solved = candidates.groupby('product_smi')[guided_key].any().astype(bool)

    num_unsolved = len(no_solution_products)
    num_rescued = int(guided_solved.sum())
    return {
        'products_no_unguided_solution': num_unsolved,
        'products_with_guided_solution': num_rescued,
        'improvement_rate': num_rescued / num_unsolved if num_unsolved > 0 else 0
    }

@hydra.main(config_path='../configs', config_name='config.yaml')
def main(config):
    """
    Compare one unguided and one guided single-step experiment.

    Loads each experiment once, prints the metrics dict of both, then prints
    the guided-improvement statistics for every combination of success
    columns listed in ``key_pairs``.

    Args:
        config: Hydra config object; accepted for the decorator but not read
            here (the experiment directories are hard-coded below).
    """
    single_step_root = os.path.join(PROJECT_ROOT, 'experiments', 'single_step_50k')
    unguided_dir = os.path.join(single_step_root, 'no_guidance',
                                'modelrootaligned_datasetschneider50k_20250828_165003')
    # Alternative guided run (reaction-type guidance):
    #   reaction_type/50k_steeredtrue_guidance1.5_length10_results100_time20250814_100438
    guided_dir = os.path.join(single_step_root, 'token_prefix',
                              'single_step_datauspto_50kraw_steeredtrue_guidance1000_length0_results100_time20250908_143752')

    unguided_df = load_data(experiment_dir=unguided_dir)
    guided_df = load_data(experiment_dir=guided_dir)

    metrics_unguided = compute_metrics_for_dataset(unguided_df, topk_key='topk_1', coverage_key='round_trip_accuracy_10')
    metrics_guided = compute_metrics_for_dataset(guided_df, topk_key='topk_1', coverage_key='is_correct')
    print(f'metrics_unguided: {metrics_unguided}')
    print(f'metrics_guided: {metrics_guided}')

    # (unguided column, guided column) pairs used to decide "has a solution".
    # NOTE(review): the guided frame is queried with 'topk' here but with
    # 'topk_1' in the metrics above -- confirm both columns exist in it.
    key_pairs = (
        ('is_correct', 'topk'),
        ('is_correct', 'is_correct'),
        ('topk_1', 'topk'),
        ('topk_1', 'is_correct'),
    )
    for unguided_key, guided_key in key_pairs:
        guided_improvements = find_guided_improvements(unguided_df, guided_df, unguided_key=unguided_key, guided_key=guided_key)
        print(f'guided_improvements, unguided key: {unguided_key}, guided key: {guided_key}: {guided_improvements}')


if __name__ == '__main__':
    main()