import argparse
import os
import pdb

import numpy as np
import pandas as pd
from tqdm import tqdm

from general_helper_fns import ALL_POSSIBLE_GROUP_COLS, ALL_POSSIBLE_GROUPS
from general_helper_fns import get_save_dir, load_data, split_data, save_pickle, load_pickle
from conformal_helper_fns import prepare_df_conformal, get_safe_threshold, processed_filtered_df
from conformal_helper_fns import get_conformal_summary, add_groups_to_filtered_dict, get_grouped_conformal_summary

def get_split_conformal_threshold(values, alpha=0.05):
    """Return the split-conformal quantile of ``values`` at miscoverage ``alpha``.

    The result is the ceil((n + 1) * (1 - alpha))-th smallest value (1-indexed).
    This is the order statistic behind the finite-sample guarantee
    P(new_value <= threshold) >= 1 - alpha for exchangeable data.

    ``np.quantile``'s default linear interpolation can land between two
    order statistics, below the required one. That would undercut the
    guarantee, so the order statistic is selected directly instead.

    If the required rank exceeds n (i.e. alpha < 1 / (n + 1)), the maximum
    value is returned, as the previous clip-to-1 implementation did. In that
    regime the finite-sample guarantee does not hold.

    Args:
        values: array-like of calibration values (flattened if multi-dimensional).
        alpha: target miscoverage level.

    Returns:
        One of the entries of ``values`` (as a float).

    Raises:
        ValueError: if ``values`` is empty.
    """
    sorted_values = np.sort(np.asarray(values, dtype=float).ravel())
    n = sorted_values.size
    if n == 0:
        raise ValueError("values must be non-empty to compute a conformal threshold")
    rank = int(np.ceil((n + 1) * (1 - alpha)))
    # clip the 1-indexed rank into [1, n]; mirrors the old clip of the level to [0, 1]
    rank = min(max(rank, 1), n)
    return sorted_values[rank - 1]

def filter_facts_sc(df_input, safe_thresholds, alpha):
    """Keep only rows scoring strictly above the split-conformal threshold.

    Args:
        df_input: dataframe with a 'score' column.
        safe_thresholds: calibration values passed to
            ``get_split_conformal_threshold``.
        alpha: miscoverage level passed to ``get_split_conformal_threshold``.

    Returns:
        The subset of ``df_input`` whose ``score`` is greater than the threshold.
    """
    cutoff = get_split_conformal_threshold(safe_thresholds, alpha=alpha)
    keep = df_input['score'].gt(cutoff)
    return df_input.loc[keep]

def get_conformal_helper_sc(df_input, safe_thresholds, step=0.05):
    """Apply split-conformal filtering for every alpha in a grid.

    The alpha grid is ``np.arange(step, 1, step)`` rounded to 3 decimals.

    Args:
        df_input: dataframe to filter (must have 'score' and 'text' columns).
        safe_thresholds: calibration values used to pick each threshold.
        step: spacing of the alpha grid.

    Returns:
        Dict mapping each alpha to the list
        [df_filtered, df_filtered_no_abstentions,
         df_all_supported_with_abstentions, df_all_supported, df_frac_supported].
        The last three come from ``processed_filtered_df``. Rows with empty
        'text' are dropped for ``df_filtered_no_abstentions`` (presumably the
        abstention rows added by ``prepare_df_conformal``; verify).
    """
    alphas = np.round(np.arange(step, 1, step), 3)
    results = {}
    for level in tqdm(alphas):
        kept = filter_facts_sc(df_input, safe_thresholds, level)
        with_abstentions, all_supported, frac_supported = processed_filtered_df(kept)
        non_empty = kept[kept['text'] != '']
        results[level] = [kept, non_empty, with_abstentions, all_supported, frac_supported]
    return results

### Set up

def get_args():
    """Parse the command-line options for the split-conformal experiment.

    Returns:
        argparse.Namespace with model_name, split, seed, proportion_cal and
        same_num attributes.
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--model_name", dict(type=str, default='Llama2_7B_Chat')),
        ("--split", dict(type=str, default='683', choices=['nq', '683'])),
        ("--seed", dict(type=int, default=0)),
        ("--proportion_cal", dict(type=float, default=0.8)),
        ("--same_num", dict(action='store_true')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

def verify_safe_thresholds(df_split, safe_thresholds):
    """Check that every fact scoring above its topic's safe threshold is correct.

    Args:
        df_split: fact-level dataframe with 'topic', 'score' and 'correct' columns.
        safe_thresholds: Series of per-topic thresholds indexed by topic.

    Raises:
        AssertionError: if some topic has a fact with ``score`` strictly above
            its threshold whose 'correct' value is falsy. It is raised
            explicitly (not via ``assert``) so the check still runs under
            ``python -O``.
    """
    df_thresholds = safe_thresholds.to_frame(name='threshold')
    df_merge = pd.merge(df_split, df_thresholds, left_on='topic', right_index=True)
    df_merge = df_merge.loc[df_merge['score'] > df_merge['threshold']]
    print(df_merge.shape)
    check = df_merge.groupby('topic')['correct'].all()
    if not check.all():
        problems = check.loc[~check.values]
        print(problems)
        raise AssertionError(
            f"safe thresholds admit incorrect facts for {len(problems)} topic(s)"
        )


if __name__ == "__main__":
    args = get_args()

    # load data
    save_dir = get_save_dir()
    # --same_num runs are written to a separate directory from the default run
    conformal_subdir = 'conformal_same_num' if args.same_num else 'conformal'
    results_dir = os.path.join(
        save_dir,
        f'results/{args.split}/{args.model_name}/{conformal_subdir}/seed{args.seed}',
    )
    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs(results_dir, exist_ok=True)

    df = load_data(
        save_dir,
        model_name=args.model_name,
        split=args.split,
    )

    if args.same_num:
        # Keep topics with at least `cutoff` rows (cutoff = 10th percentile of
        # per-topic row counts), then take `cutoff` rows from each topic.
        df_sizes = df.groupby('topic').size()
        cutoff = int(np.quantile(df_sizes, 0.1))
        topics_above_cutoff = df_sizes.loc[df_sizes >= cutoff].index.values

        df = df.loc[df['topic'].isin(topics_above_cutoff)]
        df = df.sample(frac=1, random_state=args.seed)  # shuffle so head() picks a random subset
        df = df.groupby('topic').head(cutoff)

    # add row for each topic corresponding to the empty set (abstention)
    df = prepare_df_conformal(df)

    # split data into calibration and test sets
    df_calibration, df_test = split_data(df, args.seed, args.proportion_cal)

    # get safe threshold (all facts with scores above the threshold are supported)
    df_thresholds_calibration = df_calibration.groupby('topic').apply(get_safe_threshold, include_groups=False)
    df_thresholds_test = df_test.groupby('topic').apply(get_safe_threshold, include_groups=False)

    # just to be sure, we verify that safe thresholds were calculated correctly
    verify_safe_thresholds(df_calibration, df_thresholds_calibration)
    verify_safe_thresholds(df_test, df_thresholds_test)

    ### Run Split Conformal ###

    # Filter both splits with thresholds derived from the calibration split only,
    # then build the intermediate dataframes that get summarized below.
    safe_thresholds = df_thresholds_calibration.values
    filtered_dict_calibration = get_conformal_helper_sc(df_calibration, safe_thresholds)
    filtered_dict_test = get_conformal_helper_sc(df_test, safe_thresholds)

    # add group features
    df_groups = df[['topic'] + ALL_POSSIBLE_GROUP_COLS].drop_duplicates().set_index('topic')
    add_groups_to_filtered_dict(filtered_dict_calibration, df_groups)
    add_groups_to_filtered_dict(filtered_dict_test, df_groups)

    ##### get scores

    results_path_sc = os.path.join(results_dir, 'split_conformal_results.pkl')

    if os.path.exists(results_path_sc):
        print("Run already completed")
        df_summary_calibration, df_summary_test = load_pickle(results_path_sc)
    else:
        df_summary_calibration = get_conformal_summary(filtered_dict_calibration)
        df_summary_test = get_conformal_summary(filtered_dict_test)

        save_pickle(
            (df_summary_calibration, df_summary_test),
            results_path_sc,
        )

    ##### get scores by group

    results_path_sc_grouped = os.path.join(results_dir, 'split_conformal_results_grouped.pkl')

    if os.path.exists(results_path_sc_grouped):
        print("Run already completed")
        grouped_summary_calibration, grouped_summary_test = load_pickle(results_path_sc_grouped)
    else:
        grouped_summary_calibration = {}
        grouped_summary_test = {}

        for group_columns in tqdm(ALL_POSSIBLE_GROUPS):
            group_name = '+'.join(group_columns)

            # evaluate on calibration set
            grouped_summary_calibration[group_name] = get_grouped_conformal_summary(
                group_columns, filtered_dict_calibration,
            )

            # evaluate on test set
            grouped_summary_test[group_name] = get_grouped_conformal_summary(
                group_columns, filtered_dict_test,
            )

        save_pickle(
            (grouped_summary_calibration, grouped_summary_test),
            results_path_sc_grouped,
        )

        