import os
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm

from general_helper_fns import ALL_POSSIBLE_GROUP_COLS, ALL_POSSIBLE_GROUPS, get_all_groups
from general_helper_fns import get_save_dir, load_data, split_data, save_pickle, load_pickle
from conformal_helper_fns import prepare_df_conformal, get_safe_threshold, processed_filtered_df
from conformal_helper_fns import get_conformal_summary, add_groups_to_filtered_dict, get_grouped_conformal_summary

import pdb

def get_split_conformal_threshold(values, alpha=0.05):
    """Return the split-conformal quantile of ``values`` for miscoverage ``alpha``.

    The target level ``1 - alpha`` is inflated by ``(n + 1) / n`` where ``n``
    is ``len(values)``, clipped into ``[0, 1]``, and handed to
    ``np.quantile`` (default interpolation).
    """
    num_values = len(values)
    finite_sample_factor = (num_values + 1) / num_values
    # Clip because the inflated level exceeds 1 whenever alpha < 1 / (n + 1).
    level = np.clip(finite_sample_factor * (1 - alpha), 0, 1)
    return np.quantile(values, level)

def get_group_mask(df, group_columns, group_values):
    """Return a boolean array marking the rows of ``df`` that belong to a group.

    A row is selected when, for every ``(column, value)`` pair, the column's
    value cast to ``str`` equals ``value``. With no pairs, every row is
    selected.
    """
    selected = np.ones(len(df), dtype=bool)
    for column, value in zip(group_columns, group_values):
        # Compare as strings, since callers pass string group values.
        matches = np.asarray(df[column].astype(str) == value, dtype=bool)
        selected = selected & matches
    return selected

def get_groups_coverages(df_filtered, groups):
    """Compute per-group topic counts and coverage for each group definition.

    A topic counts as covered when every one of its rows in ``df_filtered``
    has ``label == 'S'``. A group's coverage is the mean of that flag over
    the topics falling in the group.

    Args:
        df_filtered: DataFrame with a ``topic`` column, a ``label`` column
            and every column named in ``groups``.
        groups: iterable of lists of column names; each list defines one
            grouping.

    Returns:
        dict mapping ``'; '.join(group_columns)`` to a DataFrame indexed by
        the group values, with columns ``size`` (number of distinct
        topic/group rows in the group) and ``coverage``.
    """
    group_coverage_dict = {}
    for group_columns in groups:
        group_name = '; '.join(group_columns)
        topic_keys = ['topic'] + group_columns

        # coverage: flag each topic whose retained rows are all labelled 'S',
        # then average that flag within each group.
        all_facts_true = df_filtered.groupby(topic_keys)['label'].apply(
            lambda labels: (labels == 'S').all()
        )
        # Name the Series explicitly: assigning `.columns` on a Series has no
        # effect, which previously left the merged column called 'label'.
        coverage = all_facts_true.groupby(group_columns).mean().rename('coverage')

        # size: number of distinct (topic, group) rows per group
        sizes = (
            df_filtered[topic_keys]
            .drop_duplicates()
            .groupby(group_columns)
            .size()
            .to_frame('size')
        )

        # merge on the shared group index
        group_coverage_dict[group_name] = sizes.merge(
            coverage, left_index=True, right_index=True,
        )
    return group_coverage_dict

def concat_group_dict(group_dict, score_col_name):
    """Stack per-group tables into one long DataFrame.

    Each value of ``group_dict`` is a two-column table indexed by group
    values. Its columns are renamed to ``size`` and ``score_col_name``, and
    its index is moved into a ``group_values`` column (multi-index entries
    are joined with ``'; '``). The dict key is stored in a ``group`` column.

    Returns a DataFrame with columns
    ``['group', 'group_values', 'size', score_col_name]`` and a fresh
    0..n-1 index.
    """
    column_order = ['group', 'group_values', 'size', score_col_name]
    frames = []
    for group_name, table in group_dict.items():
        labels = table.index.values
        if len(table.index.names) > 1:  # flatten multi-index tuples into strings
            labels = ['; '.join(np.array(levels, dtype=str)) for levels in labels]

        flat = table.reset_index(drop=True)
        flat.columns = ['size', score_col_name]
        flat = flat.assign(group=group_name, group_values=labels)

        frames.append(flat[column_order])

    return pd.concat(frames, ignore_index=True)

def run_multivalid_split_conformal(
    df_input_calibration,
    df_input_test,
    df_thresholds_calibration,
    groups,
    alpha,
    MAX_STEPS=100,
    score_col='score',
):
    """Iteratively patch a split-conformal score threshold, one group at a time.

    Every row starts with the same threshold, the split-conformal quantile of
    all calibration safe thresholds. Each step then:

    1. filters the calibration rows with the current per-row thresholds,
    2. measures each group's coverage and its size-weighted squared
       deviation from ``1 - alpha``,
    3. picks the group with the largest weighted error and overwrites the
       threshold of its rows (calibration and test) with the quantile
       computed from that group's topics only.

    Iteration stops after ``MAX_STEPS`` steps or as soon as the total
    weighted error matches a previously seen value.

    Args:
        df_input_calibration: calibration rows; needs ``topic``, ``label``,
            ``score_col`` and all group columns.
        df_input_test: test rows; needs ``score_col`` and all group columns.
        df_thresholds_calibration: per-topic safe thresholds for the
            calibration set, looked up by topic label.
        groups: list of lists of group column names.
        alpha: target miscoverage level.
        MAX_STEPS: maximum number of patching iterations.
        score_col: name of the score column compared against the thresholds.

    Returns:
        ``(df_filtered_calibration, df_filtered_test)``: the rows of each
        input whose score is strictly above its final threshold.
    """
    safe_thresholds = df_thresholds_calibration.values
    q = get_split_conformal_threshold(safe_thresholds, alpha=alpha)
    q_vector_calibration = np.ones(len(df_input_calibration)) * q
    q_vector_test = np.ones(len(df_input_test)) * q

    # Use `score_col` (not a hard-coded 'score') so the maximum is taken over
    # the same column that is thresholded below.
    max_val = pd.concat([df_input_calibration, df_input_test])[score_col].max()

    # `np.inf` rather than `np.infty`: the latter alias was removed in NumPy 2.0.
    error_list = [np.inf]
    for _ in range(MAX_STEPS):
        # The filter is a strict '>', so a threshold equal to the largest
        # score would reject that row too; nudge such thresholds just below.
        q_vector_calibration[np.isclose(q_vector_calibration, max_val)] -= 1e-8
        q_vector_test[np.isclose(q_vector_test, max_val)] -= 1e-8

        mask = df_input_calibration[score_col] > q_vector_calibration
        df_filtered_calibration = df_input_calibration[mask]

        # get errors
        group_coverage_dict = get_groups_coverages(df_filtered_calibration, groups)
        df_group_coverage = concat_group_dict(group_coverage_dict, 'coverage')
        df_group_coverage['coverage_error'] = ((1 - alpha) - df_group_coverage['coverage']).abs()
        weighted_errors = (
            df_group_coverage['size'] / len(df_input_calibration)
            * df_group_coverage['coverage_error'] ** 2
        )
        df_group_coverage['weighted_error'] = weighted_errors

        # find worst group; `argmax` returns a position, so index with `iloc`
        worst_group = df_group_coverage.iloc[weighted_errors.argmax()]
        group_columns = worst_group['group'].split('; ')
        group_values = str(worst_group['group_values']).split('; ')

        # get q for worst group, using only the safe thresholds of its topics
        mask_worst = get_group_mask(df_input_calibration, group_columns, group_values)
        worst_topics = df_input_calibration.loc[mask_worst, 'topic'].unique()
        worst_safe_thresholds = df_thresholds_calibration.loc[worst_topics].values

        q = get_split_conformal_threshold(worst_safe_thresholds, alpha=alpha)
        q_vector_calibration[mask_worst] = q

        # apply q to test set
        mask_worst = get_group_mask(df_input_test, group_columns, group_values)
        q_vector_test[mask_worst] = q

        # Stop once the total error repeats a value seen before: further
        # patching would only cycle.
        error = weighted_errors.sum()
        if np.isclose(error_list, error).any():
            break
        error_list.append(error)

    mask = df_input_calibration[score_col] > q_vector_calibration
    df_filtered_calibration = df_input_calibration[mask]
    mask = df_input_test[score_col] > q_vector_test
    df_filtered_test = df_input_test[mask]

    return df_filtered_calibration, df_filtered_test

def get_conformal_helper_multi_sc(
    df_input_calibration,
    df_input_test,
    df_thresholds_calibration,
    groups,
    step=0.05,
):
    """Run multi-valid split conformal over a grid of alpha values.

    For each ``alpha`` in ``np.round(np.arange(step, 1, step), 3)`` the
    calibration and test sets are filtered with
    ``run_multivalid_split_conformal`` and post-processed with
    ``processed_filtered_df``.

    Returns:
        ``(calibration_results, test_results)``: two dicts keyed by alpha.
        Each value is the list ``[filtered rows, filtered rows with
        non-empty 'text', *the three outputs of processed_filtered_df]``.
    """
    results_by_split = {'calibration': {}, 'test': {}}
    alpha_grid = np.round(np.arange(step, 1, step), 3)

    for alpha in tqdm(alpha_grid):
        filtered_calibration, filtered_test = run_multivalid_split_conformal(
            df_input_calibration, df_input_test, df_thresholds_calibration, groups, alpha,
        )
        per_split = (('calibration', filtered_calibration), ('test', filtered_test))
        for split_name, filtered in per_split:
            with_abstentions, all_supported, frac_supported = processed_filtered_df(filtered)
            # Rows whose text is the empty string are dropped here.
            without_abstentions = filtered[filtered['text'] != '']
            results_by_split[split_name][alpha] = [
                filtered,
                without_abstentions,
                with_abstentions,
                all_supported,
                frac_supported,
            ]

    return results_by_split['calibration'], results_by_split['test']

### Set up

def get_args():
    """Parse this script's command-line options and return the namespace.

    Options: ``--group_col`` (required), ``--model_name``, ``--split``
    (``'nq'`` or ``'683'``), ``--seed`` and ``--proportion_cal``.
    """
    option_specs = (
        ("--group_col", {"type": str, "required": True}),
        ("--model_name", {"type": str, "default": 'Llama2_7B_Chat'}),
        ("--split", {"type": str, "default": '683', "choices": ['nq', '683']}),
        ("--seed", {"type": int, "default": 0}),
        ("--proportion_cal", {"type": float, "default": 0.8}),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

if __name__ == "__main__":
    # Script entry point: load the data, split it into calibration/test, run
    # multi-valid split conformal over a grid of alphas, and pickle the
    # overall and per-group summaries under `results_dir`.
    args = get_args()

    # define groups (note that we evaluate uncalibrated scores on all groups)
    # `--group_col` holds one or more column names separated by ';'.
    group_cols = args.group_col.split(';')
    groups = get_all_groups(group_cols)

    # load data
    save_dir = get_save_dir()
    results_dir = os.path.join(save_dir, f'results/{args.split}/{args.model_name}/multi_conformal/{args.group_col}/seed{args.seed}')
    os.makedirs(results_dir, exist_ok=True)

    # TODO
    df = load_data(
        save_dir,
        model_name=args.model_name,
        split=args.split,
    )

    # add row for each topic corresponding to the empty set (abstention)
    df = prepare_df_conformal(df)

    # get group cols
    df_groups = df[['topic'] + group_cols].drop_duplicates().set_index('topic')

    # split data into calibration and test sets
    df_calibration, df_test = split_data(df, args.seed, args.proportion_cal)

    # get safe threshold (all facts with scores above the threshold are supported)
    df_thresholds_calibration = df_calibration.groupby('topic').apply(get_safe_threshold, include_groups=False)
    df_thresholds_test = df_test.groupby('topic').apply(get_safe_threshold, include_groups=False)

    # just to be sure, we verify that safe thresholds were calculated correctly:
    # every row scoring strictly above its topic's threshold must be correct.
    for _df, _df_thresholds in zip([df_calibration, df_test], [df_thresholds_calibration, df_thresholds_test]):
        _df_thresholds = _df_thresholds.to_frame()
        _df_thresholds.columns = ['threshold']
        df_merge = pd.merge(_df, _df_thresholds, left_on='topic', right_index=True)
        df_merge = df_merge.loc[df_merge['score'] > df_merge['threshold']]
        print(df_merge.shape)
        check = df_merge.groupby('topic').apply(lambda x: x['correct'].all(), include_groups=False)
        if not check.all():
            problems = check.loc[~check.values]
            print(problems)
            # Raise explicitly: a bare `assert` is stripped under `python -O`.
            raise ValueError(
                f"Safe thresholds are inconsistent for {len(problems)} topic(s): "
                "some rows scoring above the threshold are not correct."
            )


    ### Run Multi Valid Split Conformal ###

    # use split conformal to filter dataframes. Then get intermediate dfs that will be processed into results
    # Calibration uses each requested column as its own single-column group.
    filtered_dict_calibration, filtered_dict_test = get_conformal_helper_multi_sc(
        df_calibration,
        df_test,
        df_thresholds_calibration,
        [[x] for x in group_cols],
    )

    # add group features
    add_groups_to_filtered_dict(filtered_dict_calibration, df_groups)
    add_groups_to_filtered_dict(filtered_dict_test, df_groups)

    ##### get scores

    results_path_sc = os.path.join(results_dir, 'multi_split_conformal_results.pkl')

    # NOTE: result caching is currently disabled, so summaries are always recomputed.
    if False: #os.path.exists(results_path_sc):
        print("Run already completed")
        df_summary_calibration, df_summary_test = load_pickle(results_path_sc)
    else:
        df_summary_calibration = get_conformal_summary(filtered_dict_calibration)
        df_summary_test = get_conformal_summary(filtered_dict_test)

        save_pickle(
            (df_summary_calibration, df_summary_test),
            results_path_sc,
        )

    ##### get scores by group

    results_path_sc_grouped = os.path.join(results_dir, 'multi_split_conformal_results_grouped.pkl')

    # NOTE: result caching is currently disabled here as well.
    if False: # os.path.exists(results_path_sc_grouped):
        print("Run already completed")
        grouped_summary_calibration, grouped_summary_test = load_pickle(results_path_sc_grouped)
    else:
        grouped_summary_calibration = {}
        grouped_summary_test = {}

        for group_columns in tqdm(groups):
            group_name = '+'.join(group_columns)

            # evaluate on calibration set
            df_grouped_summary_calibration = get_grouped_conformal_summary(
                group_columns, filtered_dict_calibration,
            )
            grouped_summary_calibration[group_name] = df_grouped_summary_calibration

            # evaluate on test set
            df_grouped_summary_test = get_grouped_conformal_summary(
                group_columns, filtered_dict_test,
            )
            grouped_summary_test[group_name] = df_grouped_summary_test

        save_pickle(
            (grouped_summary_calibration, grouped_summary_test),
            results_path_sc_grouped,
        )
