import os
import argparse

from general_helper_fns import ALL_POSSIBLE_GROUPS
from general_helper_fns import get_save_dir, load_data, split_data
from general_helper_fns import save_pickle
from calibrate_helper_fns import get_bin_targets as _get_bin_targets
from calibrate_helper_fns import get_asce, get_brier
from calibrate_helper_fns import get_summary, get_group_summary

import pdb

# NOTE(review): not referenced anywhere in the visible script, which reads
# --proportion_cal (same default value) instead.
PROPORTION_CAL = 0.8
# Passed as the second argument to the imported binning helper.
NUM_BINS = 5


def get_bin_targets(x):
    """Call the imported ``_get_bin_targets`` on ``x`` with ``NUM_BINS``."""
    return _get_bin_targets(x, NUM_BINS)

### Set up

def get_args():
    """Parse this script's command-line options and return the namespace.

    Options:
        --model_name, --save_dir, --split: select the data that is loaded
            and split into calibration/test portions.
        --model_name_cal, --save_dir_cal, --split_cal: optional second data
            source; when --model_name_cal is given, the script takes its
            calibration portion from there instead.
        --seed, --proportion_cal: forwarded to the data split.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    default_dir = './data/all/temperature=1.0/run=0'
    split_choices = ['nq', '683', 'all']

    cli = argparse.ArgumentParser()

    # Data whose scores are evaluated.
    cli.add_argument("--model_name", type=str, default='Llama2_7B_Chat')
    cli.add_argument("--save_dir", type=str, default=default_dir)
    cli.add_argument("--split", type=str, default='683', choices=split_choices)

    # Optional separate source for the calibration portion.
    cli.add_argument("--model_name_cal", type=str, default=None)
    cli.add_argument("--save_dir_cal", type=str, default=default_dir)
    cli.add_argument("--split_cal", type=str, default='nq', choices=split_choices)

    # Split controls.
    cli.add_argument("--seed", type=int, default=0)
    cli.add_argument("--proportion_cal", type=float, default=0.8)

    return cli.parse_args()

# Parse the CLI options once; everything below reads them from `args`.
args = get_args()

# set up the results directory

# Layout: <save_dir>/results/<split>/<model_name>/<calibrate subdir>/seed<seed>.
# The subdirectory name records the separate calibration source when one is
# given, so runs calibrated on different data do not overwrite each other.
calibrate_subir = 'calibrate'
if args.model_name_cal is not None:
    calibrate_subir += f'+{args.model_name_cal}_{args.split_cal}'
results_dir = os.path.join(args.save_dir, f'results/{args.split}/{args.model_name}/{calibrate_subir}/seed{args.seed}')
# exist_ok=True replaces the exists()/makedirs() pair, which could raise
# FileExistsError if another run created the directory between the two calls.
os.makedirs(results_dir, exist_ok=True)

# Load the primary data and split it into calibration / test portions.
df = load_data(args.save_dir, model_name=args.model_name, split=args.split)
df_calibration, df_test = split_data(df, args.seed, args.proportion_cal)

# When a separate calibration model is named, the calibration portion is
# replaced by the one obtained from that model's data (same seed and
# proportion). The test portion stays the one from the primary data.
use_separate_calibration = args.model_name_cal is not None
if use_separate_calibration:
    df = load_data(args.save_dir_cal, model_name=args.model_name_cal, split=args.split_cal)
    df_calibration, _ = split_data(df, args.seed, args.proportion_cal)

##### set up bins

# Bins and their targets are derived from the calibration scores only.
bin_targets_uncalibrated = get_bin_targets(df_calibration['score'].values)


def map_to_bin(values, bin_targets=bin_targets_uncalibrated):
    """Return, for each entry of ``values``, the key of the nearest bin target.

    Args:
        values: iterable of numeric scores.
        bin_targets: mapping from bin key to that bin's target; must support
            ``.keys()`` and item lookup. Defaults to the targets computed
            from the calibration scores when this function was defined.

    Returns:
        A list, in the order of ``values``, of the key whose target has the
        smallest absolute distance to each score. On ties ``min`` keeps the
        first such key in ``bin_targets.keys()`` order.
    """
    def nearest_bin(score):
        return min(bin_targets.keys(), key=lambda bin_key: abs(bin_targets[bin_key] - score))

    return [nearest_bin(score) for score in values]

# Add two columns to both splits:
#   'bin'       - key of the bin whose uncalibrated target is nearest the score
#   'score_bin' - the uncalibrated target of that bin
for df in (df_calibration, df_test):
    nearest_bins = map_to_bin(df['score'].values)
    df.loc[:, 'bin'] = nearest_bins
    df.loc[:, 'score_bin'] = df['bin'].map(bin_targets_uncalibrated)


##### get uncalibrated scores

results_path_uncalibrated = os.path.join(results_dir, 'uncalibrated_results.pkl')

# The pickle doubles as a cache: metrics are computed and saved only when the
# file is not there yet.
if not os.path.exists(results_path_uncalibrated):
    print("Calculating metrics on uncalibrated scores...")

    df_summary = get_summary('uncalibrated', df_calibration, df_test, get_asce, get_brier)
    df_group_summary = get_group_summary(
        'uncalibrated', ALL_POSSIBLE_GROUPS, df_calibration, df_test, get_asce, get_brier
    )
    save_pickle((df_summary, df_group_summary), results_path_uncalibrated)
else:
    print("Metrics for uncalibrated scores have already been calculated.")

### Run Histogram Binning ###

# Per-bin means over the calibration split: the mean of 'correct' becomes the
# bin's histogram-binning target (presumably the bin's observed accuracy), and
# the mean of 'score_bin' is the bin's uncalibrated target.
df_hb_correction = df_calibration.groupby('bin')[['correct', 'score_bin']].mean()
df_hb_correction['hb_correction'] = df_hb_correction['correct'] - df_hb_correction['score_bin']
bin_targets_hb = df_hb_correction['correct'].to_dict()

# get how much we need to add to correct each bin
hb_correction = df_hb_correction['hb_correction']
hb_correction_dict = hb_correction.to_dict()

# Apply the calibration-derived offsets to each split using that split's OWN
# rows. The right-hand side was previously built from df_calibration on every
# iteration, so the assignment into df_test was index-aligned against
# calibration rows and produced NaN or values from unrelated rows.
for df in [df_calibration, df_test]:
    df.loc[:, 'hb_score_bin'] = df['score_bin'] + df['bin'].map(hb_correction_dict)

##### get hb scores

def get_asce_hb(x):
    """``get_asce`` evaluated on the 'hb_score_bin' column."""
    return get_asce(x, score_col='hb_score_bin')


def get_brier_hb(x):
    """``get_brier`` evaluated on the 'hb_score_bin' column."""
    return get_brier(x, score_col='hb_score_bin')


results_path_hb = os.path.join(results_dir, 'hb_results.pkl')

# The pickle doubles as a cache: metrics are computed and saved only when the
# file is not there yet.
if not os.path.exists(results_path_hb):
    print("Calculating metrics on histogram binning scores...")

    df_summary = get_summary('hb', df_calibration, df_test, get_asce_hb, get_brier_hb)
    df_group_summary = get_group_summary(
        'hb', ALL_POSSIBLE_GROUPS, df_calibration, df_test, get_asce_hb, get_brier_hb
    )
    save_pickle((df_summary, df_group_summary), results_path_hb)

    # Also keep the annotated splits, but only on a run that computed the
    # metrics and never overwriting an existing file.
    file_path = os.path.join(results_dir, 'hb_files.pkl')
    if not os.path.exists(file_path):
        save_pickle((df_calibration, df_test), file_path)
else:
    print("Metrics for histogram binning scores have already been calculated.")
