import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pandas as pd
import os
import json
import seaborn as sns
import numpy as np
from sklearn.metrics import auc


# Task-name groups that are left out of the AUC computation. The grouping
# follows the variable names; the rationale for each group is not stated in
# this file.
math_tasks = [
    'main',
    'elementary_math_qa_question_only',
    'tokenized',
    'unit_conversion_si_conversion',
    'intersect_geometry',
]
generation_tasks = [
    'sciq',
    'typescript_chunks',
    'disfl_qa',
    'qa_wikidata',
    'polish_sequence_labeling',
]
bad_tasks = [
    'wsc',
    'cola',
    'copa',
]

# Flat list of every excluded task name, in group order.
tasks_to_exclude = [*math_tasks, *generation_tasks, *bad_tasks]

def calculate_auc(df):
    """Return the area under the normalized accuracy-vs-data-size curve.

    Args:
        df: DataFrame with 'data_size_normalized' (x) and
            'accuracy_normalized' (y) columns.

    Returns:
        The value computed by ``sklearn.metrics.auc`` over the points,
        ordered by increasing x.
    """
    # Order the points by x before integrating.
    ordered = df.sort_values('data_size_normalized')
    xs = np.array(ordered['data_size_normalized'])
    ys = np.array(ordered['accuracy_normalized'])
    return auc(xs, ys)

def clean_curve(df, data_cutoff, human_eval):
    """Fill in missing points, force the curve to be non-decreasing, and normalize it.

    Steps:
      1. Any of the reference data sizes (50 ... 2500) absent from ``df`` is
         added with accuracy 0, so step 3 gives it the best accuracy seen at
         a smaller size.
      2. Rows with ``data_size > data_cutoff`` are dropped and the rest are
         sorted by ``data_size``.
      3. ``accuracy`` is replaced by its running maximum (never below 0), so
         the curve cannot dip.
      4. ``accuracy_normalized`` and ``data_size_normalized`` are added.

    Only the ``accuracy`` column is rewritten in step 3; any other columns
    present in ``df`` are left untouched. The input frame is not modified.

    Args:
        df: DataFrame with at least 'data_size' and 'accuracy' columns.
        data_cutoff: largest data size to keep.
        human_eval: accuracy value that maps to 1.0 after normalization.

    Returns:
        A new DataFrame with 'data_size' as a regular column plus the two
        normalized columns.
    """
    reference_sizes = (50, 100, 200, 500, 1000, 2500)
    known_sizes = df['data_size'].unique()
    missing = [size for size in reference_sizes if size not in known_sizes]
    if missing:
        filler = pd.DataFrame({'data_size': missing, 'accuracy': [0] * len(missing)})
        df = pd.concat([df, filler])

    curve = df[df['data_size'] <= data_cutoff]
    curve = curve.sort_values('data_size').set_index('data_size')

    # Running maximum starting from 0: each point is raised to the best
    # accuracy seen at any smaller data size. Restricting the assignment to
    # the 'accuracy' column keeps other metric columns intact.
    curve['accuracy'] = curve['accuracy'].cummax().clip(lower=0)

    curve = curve.reset_index()
    lowest = curve['accuracy'].min()
    # NOTE: if human_eval equals the lowest accuracy the denominator is zero
    # and the normalized values come out as NaN/inf.
    curve['accuracy_normalized'] = (curve['accuracy'] - lowest) / (human_eval - lowest)
    curve['data_size_normalized'] = curve['data_size'] / curve['data_size'].max()

    return curve


def load_run_full_results(model_name, tasks, data_cutoff=2500, use_log = False, results_dir="../results"):
    """Load each fine-tuning result file and compute its learning-curve AUC.

    For every file name in ``tasks`` the task name is recovered from the file
    name, the per-data-size accuracies are loaded, cleaned with
    ``clean_curve`` and integrated with ``calculate_auc``.

    A task is skipped when it has no entry (or a null score) in
    ``human_eval.json``, when it is listed in ``tasks_to_exclude``, or when
    fewer than two data points remain after applying ``data_cutoff``.

    Args:
        model_name: model prefix in the result file names; everything up to
            and including ``"<model_name>_"`` is stripped to get the task name.
        tasks: iterable of result file names located in ``results_dir``.
        data_cutoff: largest data size kept in the curve.
        use_log: when True the x-axis is ``log2(data_size)`` (a size of 0 is
            treated as 1) before it is normalized.
        results_dir: directory holding ``human_eval.json`` and the result
            files. Defaults to the original hard-coded location.

    Returns:
        dict mapping task name to a dict with keys 'max_acc', 'min_acc',
        'extrapolation_auc' and 'human_eval'.
    """
    results = {}

    with open(os.path.join(results_dir, "human_eval.json"), "r", encoding="utf-8") as f:
        d_human_eval = json.load(f)

    for task_json in tasks:

        # "<prefix><model_name>_<task>_full..." -> "<task>"
        task_name = task_json.split('_full')[0].split(f'{model_name}_')[-1]
        human_eval = d_human_eval[task_name]['score'] if task_name in d_human_eval else None

        if human_eval is None:
            print(f"{task_name} has no human eval")
            continue
        if task_name in tasks_to_exclude:
            continue

        with open(os.path.join(results_dir, task_json), encoding="utf-8") as f:
            raw = json.load(f)

        # The JSON's top-level keys become the rows; they are read as data sizes.
        df_full = pd.DataFrame(raw).T.reset_index().rename(columns={'index': 'data_size'})
        df_full = df_full.rename(columns={'exact_string_match_accuracy': 'accuracy'})
        df_full['data_size'] = df_full['data_size'].astype(int)
        df_full['accuracy'] = df_full['accuracy'].astype(float)
        df_full = df_full[(df_full['data_size'] >= 0) & (df_full['data_size'] <= data_cutoff)]

        # A curve needs at least two points.
        if df_full.shape[0] < 2:
            continue

        max_acc = df_full['accuracy'].max().item()
        min_acc = df_full['accuracy'].min().item()
        # The normalization ceiling is the higher of the human score and the
        # best accuracy observed.
        human_eval = max(human_eval, max_acc)

        df_clean = clean_curve(df_full, data_cutoff=data_cutoff, human_eval=human_eval)
        # Flag points within 0.02 of the best accuracy (not read again in
        # this function).
        df_clean['reached_max'] = np.where(df_clean['accuracy'] + 0.02 < max_acc, 0, 1)

        if use_log:
            df_clean['data_size_org'] = df_clean['data_size']
            # log2(0) is undefined, so a size of 0 is mapped to 1 (log2 -> 0).
            df_clean['data_size'] = np.where(df_clean['data_size'] == 0, 1, df_clean['data_size'])
            df_clean['data_size'] = np.log2(df_clean['data_size'].astype(int))
            df_clean['data_size_normalized'] = df_clean['data_size'] / df_clean['data_size'].max()

        auc_val_clean = calculate_auc(df_clean)

        results[task_name] = {'max_acc': max_acc,
                              'min_acc': min_acc,
                              'extrapolation_auc': auc_val_clean,
                              'human_eval': human_eval
                              }

    return results


def run_auc_calculation(model_prefix, data_cutoff):
    """Compute log-scale AUCs for every result file of a model and save them as JSON.

    Result files are the entries of ``../results`` whose name contains
    'full_result_v2' and ``model_prefix`` and does not contain 'notused'.
    The per-task results are written to
    ``../results/auc_res/<model_prefix>_auc_logscale_by_task_<data_cutoff>.json``.

    Args:
        model_prefix: model name used to select result files.
        data_cutoff: maximum data size passed on to ``load_run_full_results``.
    """
    # Sorted so the order of tasks in the output file does not depend on the
    # arbitrary order returned by os.listdir.
    all_tasks = sorted(
        name for name in os.listdir('../results')
        if 'full_result_v2' in name
        and model_prefix in name
        and 'notused' not in name
    )

    auc_res = load_run_full_results(model_name=model_prefix,
                                    tasks=all_tasks,
                                    use_log=True,
                                    data_cutoff=data_cutoff
                                    )

    # Create the output directory on first use instead of failing on open().
    out_dir = "../results/auc_res"
    os.makedirs(out_dir, exist_ok=True)

    with open(f"{out_dir}/{model_prefix}_auc_logscale_by_task_{data_cutoff}.json", "w") as f:
        json.dump(auc_res, f, indent=4)

# Script entry point: compute AUCs for the "mistral" results with a data
# budget of 2500.
if __name__ == "__main__":
    run_auc_calculation(model_prefix="mistral", data_cutoff=2500)
