import os
import numpy as np
import pandas as pd
import scipy.stats as stats
from utils.util import load_config, get_accuracy_matrix, get_num_cot_steps_matrix
from utils.faithfulness import get_faithfulness_matrix
from utils.selection import get_ft_examples, get_qe_prompt, get_ea_response

# Short display titles for each run identifier. Keys are run names without
# the optional '_c' suffix (get_title adds a superscript for that suffix).
run_name_titles = dict(
    baseline_0='ZS',
    baseline_1='ZS-CoT',
    baseline_2='GTA',
    baseline_3='DU',
    approach_1='DF',
    approach_2='SU',
    approach_3='SF',
)

def _summarize_matrix(matrix, std_err=False, return_array=False):
    """Reduce a metric matrix to the form requested by the caller.

    Returns the flattened values when ``return_array`` is set (this takes
    precedence), otherwise ``(mean, standard error)`` when ``std_err`` is set,
    otherwise the mean. The standard error is numpy's default population
    standard deviation (``ddof=0``) divided by the square root of the number
    of entries.
    """
    if return_array:
        return matrix.flatten()
    if std_err:
        return matrix.mean(), matrix.std() / np.sqrt(len(matrix.flatten()))
    return matrix.mean()

def get_metric_value(metric, results_dir, run_name, std_err=False, cot=True, best_check=None, use_parsed=False, hard_faith=False, return_array=False):
    """Load one metric for a run directory and summarize it.

    Args:
        metric: One of ``'accuracy'``, ``'faithfulness'`` or ``'num_cot_steps'``.
        results_dir: Results root. The run path is appended by plain string
            concatenation, so this should end with ``'/'``.
        run_name: Sub-directory of ``results_dir`` holding the run.
        std_err: Return ``(mean, standard error)`` instead of the mean.
        cot, use_parsed: Forwarded to ``get_accuracy_matrix`` (accuracy only).
        best_check: Checkpoint number. When truthy, metrics are read from
            ``responses_{best_check}/`` inside the run directory.
        hard_faith: Select element 1 instead of element 0 of the pair returned
            by ``get_faithfulness_matrix``.
        return_array: Return the flattened values; takes precedence over
            ``std_err``.

    Returns:
        The mean, a ``(mean, standard error)`` tuple, or a flat array, as
        selected by ``std_err`` / ``return_array``. The same options apply to
        every metric.

    Raises:
        ValueError: If ``metric`` is not recognised.
    """
    if metric not in ('accuracy', 'faithfulness', 'num_cot_steps'):
        raise ValueError(f"Unknown metric: {metric}")

    response_dir = f'responses_{best_check}/' if best_check else ''
    run_dir = results_dir + f'{run_name}/{response_dir}'

    if metric == 'accuracy':
        matrix = get_accuracy_matrix(run_dir, cot=cot, use_parsed=use_parsed)
    elif metric == 'faithfulness':
        matrix = get_faithfulness_matrix(run_dir)[int(hard_faith)]
    else:
        # NOTE(review): assumes get_num_cot_steps_matrix returns an ndarray like
        # the other loaders, so it can honour std_err/return_array too.
        matrix = get_num_cot_steps_matrix(run_dir)
    return _summarize_matrix(matrix, std_err=std_err, return_array=return_array)

def get_ft_results_dict(dataset_name, model_name, run_names,
                        hard_faith=False, use_parsed=False, return_array=False,
                        best_checks=None, best_checks_dict=False,
                        axes=('accuracy', 'faithfulness'), std_err=False):
    """Collect metrics for fine-tuned runs of one model on one dataset.

    Results are read from
    ``ft_results/{dataset_name}/{model_name}/test_n_100_seed_42_temp_0.0_maxtokens_512/{run_name}/``.

    Args:
        dataset_name, model_name: Select the results directory.
        run_names: Runs to load.
        hard_faith, use_parsed, return_array, std_err: Forwarded to
            ``get_metric_value``.
        best_checks: Optional nested checkpoint selection, indexed as
            ``best_checks[model_name][dataset_name]``. The inner value is a
            sequence aligned with ``run_names`` or, when ``best_checks_dict``
            is set, a mapping from run name to checkpoint.
        best_checks_dict: Interpret the inner ``best_checks`` value as a
            mapping keyed by run name.
        axes: Metric names to compute, in output order.

    Returns:
        ``{run_name: tuple}`` with one entry per metric in ``axes``. With a
        checkpoint selection, each entry is that checkpoint's value. Without
        one, each entry is a list of values for checkpoints 1, 2 and 3.

    Raises:
        KeyError: With ``best_checks_dict``, if a run has no checkpoint entry.
    """
    results_dir = f'ft_results/{dataset_name}/{model_name}/test_n_100_seed_42_temp_0.0_maxtokens_512/'
    checks = best_checks[model_name][dataset_name] if best_checks else None

    def metrics_for(run_name, checkpoint):
        # Pass everything by keyword so the checkpoint can never be bound to
        # another parameter such as ``std_err`` or ``cot``.
        return [get_metric_value(metric, results_dir, run_name, std_err=std_err,
                                 best_check=checkpoint, use_parsed=use_parsed,
                                 hard_faith=hard_faith, return_array=return_array)
                for metric in axes]

    results_dict = {}
    if checks:
        if best_checks_dict:
            # Look each run up by name. Zipping against the mapping's keys
            # would drop runs whenever the mapping has fewer entries.
            pairs = ((run_name, checks[run_name]) for run_name in run_names)
        else:
            pairs = zip(run_names, checks)
        for run_name, checkpoint in pairs:
            results_dict[run_name] = tuple(metrics_for(run_name, checkpoint))
    else:
        for run_name in run_names:
            per_checkpoint = [metrics_for(run_name, checkpoint) for checkpoint in (1, 2, 3)]
            # Transpose to one list (over checkpoints) per metric.
            results_dict[run_name] = tuple(list(values) for values in zip(*per_checkpoint))
    return results_dict

def get_icl_results_dict(dataset_name, model_name, run_names,
                         hard_faith=False, use_parsed=True, std_err=False, return_array=False,
                         axes=('accuracy', 'faithfulness')):
    """Collect metrics for in-context-learning runs of one model on one dataset.

    Results are read from
    ``icl_results_neurips/{dataset_name}/{model_name}/test_n_100_seed_42_temp_0.0_maxtokens_1024/{run_name}/``
    (no checkpoint sub-directory).

    Args:
        dataset_name, model_name: Select the results directory.
        run_names: Runs to load.
        hard_faith, use_parsed, std_err, return_array: Forwarded to
            ``get_metric_value``.
        axes: Metric names to compute, in output order. An immutable tuple
            default avoids sharing a mutable list between calls.

    Returns:
        ``{run_name: tuple}`` with one value per metric in ``axes``.
    """
    results_dir = f'icl_results_neurips/{dataset_name}/{model_name}/test_n_100_seed_42_temp_0.0_maxtokens_1024/'
    return {
        run_name: tuple(
            get_metric_value(metric, results_dir, run_name, std_err=std_err,
                             best_check=None, use_parsed=use_parsed, return_array=return_array,
                             hard_faith=hard_faith)
            for metric in axes
        )
        for run_name in run_names
    }

def get_results_dict(dataset_names, model_name, std_err=False, hard_faith=False, use_parsed=False, axes=('accuracy', 'faithfulness'), cot=True, return_array=False):
    """Collect base-model metrics for several datasets.

    For each dataset, the run directory
    ``results/{dataset_name}/{model_name}/test_n_100_seed_42_temp_0.0_maxtokens_512/responses/``
    is summarized with ``get_metric_value``.

    Args:
        dataset_names: Datasets to load.
        model_name: Model whose results are loaded.
        std_err, hard_faith, use_parsed, cot, return_array: Forwarded to
            ``get_metric_value``.
        axes: Metric names to compute, in output order. An immutable tuple
            default avoids sharing a mutable list between calls.

    Returns:
        ``{dataset_name: tuple}`` with one value per metric in ``axes``.
    """
    results_dict = {}
    for dataset_name in dataset_names:
        results_dir = f'results/{dataset_name}/{model_name}/test_n_100_seed_42_temp_0.0_maxtokens_512/'
        results_dict[dataset_name] = tuple(
            get_metric_value(metric, results_dir, 'responses', std_err=std_err, cot=cot,
                             best_check=None, use_parsed=use_parsed, return_array=return_array,
                             hard_faith=hard_faith)
            for metric in axes
        )
    return results_dict

def get_title(run_name):
    """Map a run name to its table title.

    A trailing ``'_c'`` is stripped before the lookup in ``run_name_titles``
    and is rendered as a LaTeX superscript "c" appended to the title.
    Unknown base names raise ``KeyError``.
    """
    base_name = run_name.removesuffix('_c')
    marker = r'$^\mathrm{c}$' if run_name.endswith('_c') else ''
    return run_name_titles[base_name] + marker

def p_value_latex_table(dataset_names, model_name, run_names, best_checkpoints, use_parsed=True, hard_faith=False, icl=False):
    """Build LaTeX table rows of paired t-test p-values on faithfulness.

    For each dataset and each comparison run (``comparisons`` below), the
    run's flattened faithfulness values are tested against each run in
    ``baselines`` with ``scipy.stats.ttest_rel``. Before each test, both
    arrays are truncated to their common length. One ``& p`` cell is
    appended per (dataset, baseline) pair.

    Args:
        dataset_names: Datasets to include; each adds one cell per baseline.
        model_name: Model whose results are loaded.
        run_names: Runs to load. Must include every name in ``comparisons``
            and ``'baseline_2'``.
        best_checkpoints: Nested checkpoint selection forwarded to
            ``get_ft_results_dict`` (unused when ``icl`` is True).
        use_parsed, hard_faith: Forwarded to the metric loaders.
        icl: Load in-context-learning results instead of fine-tuning results.

    Returns:
        One row per comparison run, joined and terminated with LaTeX row
        breaks.
    """
    comparisons = ['baseline_3', 'baseline_3_c', 'approach_1', 'approach_1_c',
                   'approach_2', 'approach_2_c', 'approach_3', 'approach_3_c']
    baselines = ['baseline_1', 'baseline_2']
    latex_strs = [get_title(run_name) for run_name in comparisons]

    base_results = get_results_dict(dataset_names, model_name, return_array=True, hard_faith=hard_faith,
                                    use_parsed=use_parsed, std_err=False)

    for dataset_name in dataset_names:
        if icl:
            results_dict = get_icl_results_dict(dataset_name, model_name, run_names,
                                                use_parsed=use_parsed, hard_faith=hard_faith,
                                                std_err=False, return_array=True)
        else:
            results_dict = get_ft_results_dict(dataset_name, model_name, run_names,
                                               use_parsed=use_parsed, hard_faith=hard_faith,
                                               best_checks=best_checkpoints, best_checks_dict=True,
                                               std_err=False, return_array=True)
        # Prepend the base-model results as 'baseline_1'. With dict union the
        # right operand wins, so an existing 'baseline_1' in results_dict is kept.
        acc, faith = base_results[dataset_name]
        results_dict = {'baseline_1': (acc, faith)} | results_dict

        for row, comp in enumerate(comparisons):
            comp_faith = results_dict[comp][1]
            for baseline in baselines:
                # Truncate a fresh pair for every baseline. Reusing the
                # already-truncated comparison array would let one baseline's
                # length restrict the test against the next.
                comp_data, baseline_data = equalize_lengths(comp_faith, results_dict[baseline][1])
                p_value = stats.ttest_rel(comp_data, baseline_data).pvalue
                latex_strs[row] += f' & {p_value:.4f}'

    return '\\\\\n'.join(latex_strs) + '\\\\'

def equalize_lengths(comp_data, baseline_data):
    """Cut both sequences down to the shorter one's length.

    Returns a ``(comp_data, baseline_data)`` pair of leading slices; the
    inputs themselves are not modified.
    """
    shared = min(map(len, (comp_data, baseline_data)))
    return comp_data[:shared], baseline_data[:shared]

def ft_latex_table_bold(dataset_names, model_name, run_names, best_checkpoints, use_parsed=True, hard_faith=False, icl=False):
    """Build LaTeX rows of accuracy and faithfulness (mean +/- standard error).

    The first row is the base model ('baseline_1', from ``get_results_dict``),
    followed by one row per entry of ``run_names``. For each dataset, two cells
    are appended to every row: accuracy and faithfulness. Within each dataset,
    the best mean of each metric is wrapped in ``textbold``.

    Args:
        dataset_names: Datasets to include, one pair of cells each.
        model_name: Model whose results are loaded.
        run_names: Runs to load (any iterable of run names).
        best_checkpoints: Nested checkpoint selection forwarded to
            ``get_ft_results_dict`` (unused when ``icl`` is True).
        use_parsed, hard_faith: Forwarded to the metric loaders.
        icl: Load in-context-learning results instead of fine-tuning results.

    Returns:
        The rows joined and terminated with LaTeX row breaks.
    """
    all_run_names = ['baseline_1', *run_names]
    latex_strs = [get_title(run_name) for run_name in all_run_names]
    base_results = get_results_dict(dataset_names, model_name, std_err=True, hard_faith=hard_faith, use_parsed=use_parsed)

    for dataset_name in dataset_names:
        if icl:
            results_dict = get_icl_results_dict(dataset_name, model_name, run_names,
                                                use_parsed=use_parsed, hard_faith=hard_faith, std_err=True)
        else:
            results_dict = get_ft_results_dict(dataset_name, model_name, run_names,
                                               use_parsed=use_parsed, hard_faith=hard_faith,
                                               best_checks=best_checkpoints, best_checks_dict=True, std_err=True)
        # Prepend the base-model results as 'baseline_1'. With dict union the
        # right operand wins, so an existing 'baseline_1' in results_dict is kept.
        acc, faith = base_results[dataset_name]
        results_dict = {'baseline_1': (acc, faith)} | results_dict

        # Compare at the printed precision so every value that displays as the
        # best is bolded, not only the one that is best before rounding.
        best_acc = max(round(results_dict[name][0][0], 2) for name in all_run_names)
        best_faith = max(round(results_dict[name][1][0], 2) for name in all_run_names)

        # enumerate (not list.index) keeps each row's cells on its own row even
        # if a run name appears twice.
        for row, run_name in enumerate(all_run_names):
            (acc_mean, acc_std_err), (faith_mean, faith_std_err) = results_dict[run_name]

            acc_str = f'\\textbf{{{acc_mean:.2f}}}' if round(acc_mean, 2) == best_acc else f'{acc_mean:.2f}'
            faith_str = f'\\textbf{{{faith_mean:.2f}}}' if round(faith_mean, 2) == best_faith else f'{faith_mean:.2f}'

            latex_strs[row] += f' & {acc_str} $\\pm$ {acc_std_err:.2f} & {faith_str} $\\pm$ {faith_std_err:.2f}'

    return '\\\\\n'.join(latex_strs) + '\\\\'

def _checkpoint_metrics(run_dir, checkpoint):
    """Return ``(mean accuracy, mean faithfulness)`` read from ``{run_dir}responses_{checkpoint}/``."""
    ckpt_dir = f'{run_dir}responses_{checkpoint}/'
    # float() matches the float64 storage the values previously went through.
    return (float(get_accuracy_matrix(ckpt_dir).mean()),
            float(get_faithfulness_matrix(ckpt_dir)[0].mean()))

def add_results_to_run_table(run_table, dataset_name, model_name, best_checks=None, n_ckpts=3):
    """Add fine-tuned test accuracy/faithfulness columns to ``run_table`` in place.

    Run directories are derived from the ``'Run'`` column: values are
    lower-cased, spaces become underscores, and the result is looked up under
    ``ft_results/{dataset_name}/{model_name}/test_n_100_seed_42_temp_0.0_maxtokens_512/``.

    Without ``best_checks``, writes ``'Ckpt {k} Acc.'`` and ``'Ckpt {k} Faith.'``
    for each checkpoint ``k`` in ``1..n_ckpts``.

    With ``best_checks``, the selection is indexed as
    ``best_checks[model_name][dataset_name]``, as in ``get_ft_results_dict``.
    Each run's checkpoint is then taken by run name if the selection is a
    dict, or by row position otherwise. The chosen checkpoint and its metrics
    are written to ``'Best Ckpt'``, ``'Best Ckpt Acc.'`` and
    ``'Best Ckpt Faith.'``.

    The old best_checks branch stored three values in one array slot and
    always raised. It also never wrote to the table.

    Metric values are formatted as percentages with two decimals. All metrics
    are read before the table is modified.

    Returns:
        ``run_table`` (the same object, modified in place).
    """
    run_names = run_table['Run'].str.lower().str.replace(' ', '_').tolist()
    base_dir = f'ft_results/{dataset_name}/{model_name}/test_n_100_seed_42_temp_0.0_maxtokens_512/'
    # Rows are addressed by their actual index labels, so tables with a
    # non-default index are updated rather than enlarged by .loc.
    row_labels = run_table.index

    if best_checks:
        checks = best_checks[model_name][dataset_name]
        selected = [checks[run_name] if isinstance(checks, dict) else checks[pos]
                    for pos, run_name in enumerate(run_names)]
        metrics = [_checkpoint_metrics(f'{base_dir}{run_name}/', ckpt)
                   for run_name, ckpt in zip(run_names, selected)]
        for label, ckpt, (acc, faith) in zip(row_labels, selected, metrics):
            run_table.loc[label, 'Best Ckpt'] = str(ckpt)
            run_table.loc[label, 'Best Ckpt Acc.'] = f'{100*acc:.2f}%'
            run_table.loc[label, 'Best Ckpt Faith.'] = f'{100*faith:.2f}%'
        return run_table

    metrics = [[_checkpoint_metrics(f'{base_dir}{run_name}/', ckpt) for ckpt in range(1, n_ckpts + 1)]
               for run_name in run_names]
    for label, per_ckpt in zip(row_labels, metrics):
        for ckpt, (acc, faith) in enumerate(per_ckpt, start=1):
            run_table.loc[label, f'Ckpt {ckpt} Acc.'] = f'{100*acc:.2f}%'
            run_table.loc[label, f'Ckpt {ckpt} Faith.'] = f'{100*faith:.2f}%'
    return run_table

def create_run_table_from_jsons(save_dir, run_names, temperatures):
    """Summarize saved fine-tuning example sets as a DataFrame.

    For each run, the temperature comes from ``temperatures[run_name]`` and
    the examples from the ``'ft_examples'`` entry of
    ``{save_dir}{run_name}.json``. Each row reports the example count, the
    share of examples whose ``label`` equals ``llm_label``, and the mean
    ``faithfulness``; both metrics are formatted as percentages.
    """
    rows = []
    for run_name in run_names:
        temperature = temperatures[run_name]
        examples = load_config(f'{save_dir}{run_name}.json')['ft_examples']

        accuracy = np.mean([example['label'] == example['llm_label'] for example in examples])
        faithfulness = np.mean([example['faithfulness'] for example in examples])

        rows.append([
            run_name.title().replace('_', ' '),
            temperature,
            len(examples),
            f'{100*accuracy:.2f}%',
            f'{100*faithfulness:.2f}%',
        ])

    column_names = ['Run', 'Temperature', 'No. Examples', 'Examples Acc.', 'Examples Faith']
    return pd.DataFrame(rows, columns=column_names)

def _parse_temperature(responses_dir):
    """Return the float that follows the first ``'temp_'`` in ``responses_dir``."""
    start = responses_dir.find('temp_')
    if start == -1:
        # Without this check, find()'s -1 would slice off only the last
        # character and fail later with a cryptic IndexError/ValueError.
        raise ValueError(f"No 'temp_' segment in responses_dir: {responses_dir!r}")
    return float(responses_dir[start:].split('_')[1])

def create_run_table_from_scratch(runs):
    """Select fine-tuning examples from raw responses and summarize each run.

    Args:
        runs: Mapping ``run_name -> [responses_dir, top_p, prompt_fn,
            response_fn, filter_correct]``. The arguments are forwarded to
            ``get_ft_examples``. The temperature is parsed from the first
            ``temp_<value>_`` segment of ``responses_dir``.

    Returns:
        DataFrame with one row per run:
        - the selected-example count;
        - accuracy and faithfulness over the selected examples;
        - accuracy and faithfulness over the full matrices from
          ``get_accuracy_matrix`` / ``get_faithfulness_matrix``.

        When the faithfulness matrix has more than one column, each row of
        both matrices keeps only the column with the highest faithfulness
        before averaging.

    Raises:
        ValueError: If a ``responses_dir`` contains no ``'temp_'`` segment.
    """
    data = []
    for run_name, (responses_dir, top_p, prompt_fn, response_fn, filter_correct) in runs.items():
        temp = _parse_temperature(responses_dir)

        ft_examples = get_ft_examples(responses_dir, top_p, prompt_fn, response_fn, filter_correct)
        n_ft_examples = len(ft_examples)
        acc = np.mean([ex['label'] == ex['llm_label'] for ex in ft_examples])
        faith = np.mean([ex['faithfulness'] for ex in ft_examples])

        acc_full = get_accuracy_matrix(responses_dir)
        faith_full, _ = get_faithfulness_matrix(responses_dir)

        if faith_full.shape[1] > 1:
            # Keep, per row, the column with the highest faithfulness and take
            # the accuracy from that same column.
            rows = np.arange(faith_full.shape[0])
            best_idx = np.argmax(faith_full, axis=1)
            faith_full = faith_full[rows, best_idx]
            acc_full = acc_full[np.arange(acc_full.shape[0]), best_idx]

        data.append([run_name.title().replace('_', ' '), temp, n_ft_examples,
                     f'{100*acc:.2f}%', f'{100*np.mean(acc_full):.2f}%',
                     f'{100*faith:.2f}%', f'{100*np.mean(faith_full):.2f}%'])

    return pd.DataFrame(data, columns=['Run', 'Temperature', 'No. Examples',
                                       'Examples Acc.', 'Dataset Acc.',
                                       'Examples Faith', 'Dataset Faith'])

def get_run_tables(run_names, top_ps, dataset_name, model_name, temperatures):
    """Return one summary DataFrame per run, with one row per ``top_p`` value.

    Each run reads
    ``results/{dataset_name}/{model_name}/train_n_400_seed_42_temp_{temperature}_maxtokens_512/responses/``
    using its temperature from ``temperatures``. Runs whose name ends with
    ``'_c'`` are passed ``filter_correct=True``.
    """
    tables = []
    for run_name in run_names:
        responses_dir = (f'results/{dataset_name}/{model_name}/'
                         f'train_n_400_seed_42_temp_{temperatures[run_name]}_maxtokens_512/responses/')
        only_correct = run_name.endswith('_c')

        run_specs = {}
        for top_p in top_ps:
            run_specs[f'{run_name}_{top_p}'] = [responses_dir, top_p, get_qe_prompt, get_ea_response, only_correct]

        tables.append(create_run_table_from_scratch(run_specs))
    return tables