import os
import json
import pickle
import argparse
import itertools
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer, log_loss
from sklearn.linear_model import LogisticRegression

from general_helper_fns import get_all_groups
from general_helper_fns import get_save_dir, load_data, split_data
from general_helper_fns import save_pickle, load_pickle
from calibrate_helper_fns import get_bin_targets as _get_bin_targets
from calibrate_helper_fns import get_ce_df, get_asce, get_brier
from calibrate_helper_fns import get_group_df_all, get_summary, get_group_summary

import pdb

# NOTE(review): PROPORTION_CAL is not referenced anywhere else in this file;
# the --proportion_cal option carries its own 0.8 default.
PROPORTION_CAL = 0.8
# Number of bins handed to the imported bin-target helper.
NUM_BINS = 5


def get_bin_targets(x):
    """Call the imported `_get_bin_targets` helper on `x` using NUM_BINS bins."""
    return _get_bin_targets(x, NUM_BINS)

### Set up

def get_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with the attributes ``group_col``, ``model_name``,
        ``save_dir``, ``split``, ``model_name_cal``, ``save_dir_cal``,
        ``split_cal``, ``seed``, ``proportion_cal``, ``max_steps`` and
        ``target``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--group_col", type=str, required=True,
        help="group column(s); separate several columns with ';'",
    )

    # data the scores are evaluated on
    parser.add_argument("--model_name", type=str, default='Llama2_7B_Chat')
    parser.add_argument("--save_dir", type=str, default='./data/all/temperature=1.0/run=0')
    parser.add_argument("--split", type=str, default='683', choices=['nq', '683', 'all'])

    # optional alternative source for the calibration set
    parser.add_argument(
        "--model_name_cal", type=str, default=None,
        help="if given, the calibration set is loaded from this model instead",
    )
    parser.add_argument("--save_dir_cal", type=str, default='./data/all/temperature=1.0/run=0')
    parser.add_argument("--split_cal", type=str, default='nq', choices=['nq', '683', 'all'])

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--proportion_cal", type=float, default=0.8)
    parser.add_argument(
        "--max_steps", type=int, default=100,
        help="maximum number of correction steps",
    )
    # The target is a small fractional threshold (default 1e-3). It must be
    # parsed as float: with type=int any value such as 0.005 given on the
    # command line is rejected by argparse.
    parser.add_argument(
        "--target", type=float, default=1e-3,
        help="stop once no group's size-weighted ASCE exceeds this value",
    )

    args = parser.parse_args()
    return args

args = get_args()

# define groups (note that we evaluate uncalibrated scores on all groups)
# str.split already returns a one-element list when there is no ';' in the
# argument, so no separate single-column branch is needed.
groups = get_all_groups(args.group_col.split(';'))


# set up results dir
multicalibrate_subir = 'multicalibrate'
if args.model_name_cal is not None:
    # runs calibrated on another model/split get their own sub-directory
    multicalibrate_subir += f'+{args.model_name_cal}_{args.split_cal}'
results_dir = os.path.join(args.save_dir, f'results/{args.split}/{args.model_name}/{multicalibrate_subir}/{args.group_col}/seed{args.seed}')
# exist_ok=True avoids the check-then-create race when several runs
# (e.g. different seeds or groups) start at the same time.
os.makedirs(results_dir, exist_ok=True)

results_path = os.path.join(results_dir, 'ighb_results.pkl')
file_path = os.path.join(results_dir, 'ighb_files.pkl')
# Skip only when BOTH outputs exist, so a partially written run is redone.
if os.path.exists(results_path) and os.path.exists(file_path):
    print("Run already completed.")
    # The exit() builtin is injected by the `site` module for interactive
    # use; raising SystemExit is the equivalent that always works in scripts.
    raise SystemExit(0)

# load the data of the evaluated model and split it into calibration / test
df = load_data(args.save_dir, model_name=args.model_name, split=args.split)
df_calibration, df_test = split_data(df, args.seed, args.proportion_cal)

# optionally replace the calibration part with one drawn from another
# model / split; the test part from above is kept
if args.model_name_cal is not None:
    df = load_data(args.save_dir_cal, model_name=args.model_name_cal, split=args.split_cal)
    df_calibration, _ = split_data(df, args.seed, args.proportion_cal)

##### set up bins

# initial bins and their targets, derived from the calibration scores only
bin_targets_uncalibrated = get_bin_targets(df_calibration['score'].values)

def map_to_bin(values, bin_targets=bin_targets_uncalibrated):
    """Return, for every entry of `values`, the key of the nearest bin target.

    `bin_targets` maps a bin key to its target value. Distance is the absolute
    difference between value and target; on a tie the key that comes first
    when iterating `bin_targets.keys()` wins. The result is a list with one
    key per input value.
    """
    def _nearest_bin(score):
        return min(bin_targets.keys(), key=lambda bin_key: abs(bin_targets[bin_key] - score))

    return list(map(_nearest_bin, values))

# Attach two columns to both splits:
#   bin       -> key of the uncalibrated bin whose target is nearest to the score
#   score_bin -> the uncalibrated target of that bin
for df in (df_calibration, df_test):
    df.loc[:, 'bin'] = map_to_bin(df['score'].values)
    df.loc[:, 'score_bin'] = df['bin'].map(bin_targets_uncalibrated)


### Iterative Grouped Histogram Binning (IGHB)

##### helper functions for getting group statistics

def get_asce_ighb(x):
    """Run the imported `get_asce` on `x`, scoring the 'ighb_score_bin' column."""
    return get_asce(x, score_col='ighb_score_bin')


def get_brier_ighb(x):
    """Run the imported `get_brier` on `x`, scoring the 'ighb_score_bin' column."""
    return get_brier(x, score_col='ighb_score_bin')

def get_ce(df, score_col='score_bin'):
    """Return the 'size' and 'calibration_error' columns of `get_ce_df(df)`.

    `score_col` is forwarded to `get_ce_df` unchanged.
    """
    ce_table = get_ce_df(df, score_col=score_col)
    return ce_table[['size', 'calibration_error']]

def get_sce(df, score_col='score_bin'):
    """Return the 'size' and 'squared_calibration_error' columns of `get_ce_df(df)`.

    `score_col` is forwarded to `get_ce_df` unchanged.
    """
    ce_table = get_ce_df(df, score_col=score_col)
    return ce_table[['size', 'squared_calibration_error']]

def get_group_df_by_bin(df, score_fn=get_sce, score_col='ighb_score_bin'):
    """Apply `score_fn` to every group of `df`, for each entry of the global `groups`.

    Args:
        df: frame holding the group columns and `score_col`.
        score_fn: called as ``score_fn(sub_df, score_col=score_col)`` on each
            group's rows (the grouping columns are excluded from `sub_df`).
        score_col: score column name; after scoring it is moved from the
            result's index into a regular column.

    Returns:
        dict keyed by the group's column names joined with '; ', each value
        being the scored frame of that grouping.
    """
    def _score_one(sub_df):
        return score_fn(sub_df, score_col=score_col)

    scored_by_group = {}
    for group in groups:
        scored = df.groupby(group).apply(_score_one, include_groups=False)
        scored_by_group['; '.join(group)] = scored.reset_index(score_col)
    return scored_by_group

def concat_group_dict_by_bin(group_dict, score_col='ighb_score_bin', metric_col_name='error'):
    """Stack the per-grouping frames of `group_dict` into one long frame.

    Args:
        group_dict: maps a group name to a three-column frame (score, size,
            metric, in that order) whose index holds the group values.
        score_col: name given to the first column.
        metric_col_name: name given to the third column.

    Returns:
        A frame with a fresh 0..n-1 index and the columns
        ['group', 'group_values', score_col, 'size', metric_col_name].
        Values of a multi-level index are rendered as one string joined
        with '; '; single-level index values are kept as they are.
    """
    column_order = ['group', 'group_values', score_col, 'size', metric_col_name]
    frames = []
    for name, frame in group_dict.items():
        labels = frame.index.values
        if frame.index.nlevels > 1:  # flatten multi-index tuples to strings
            labels = ['; '.join(np.array(key, dtype=str)) for key in labels]

        flat = frame.reset_index(drop=True)
        flat.columns = [score_col, 'size', metric_col_name]
        flat.loc[:, 'group'] = name
        flat.loc[:, 'group_values'] = labels
        frames.append(flat[column_order])

    return pd.concat(frames).reset_index(drop=True)

def get_df_all_groups_by_bin(df, score_fn=get_sce, score_col='ighb_score_bin'):
    """Score every group of `df` per bin and return the results as one long frame.

    Chains `get_group_df_by_bin` and `concat_group_dict_by_bin`; the metric
    column of the result is always named 'error'.
    """
    return concat_group_dict_by_bin(
        get_group_df_by_bin(df, score_fn=score_fn, score_col=score_col),
        score_col=score_col,
        metric_col_name='error',
    )

# get target for each group
def get_df_bin_target(df_calibration):
    """Build the per-(group, group value, bin) target table from the calibration set.

    The target of a cell is its current 'ighb_score_bin' minus the cell's
    calibration error as produced by `get_ce`. The returned frame has the
    columns 'group', 'group_values', 'ighb_score_bin' and 'bin_target'.
    """
    per_cell = get_df_all_groups_by_bin(df_calibration, score_fn=get_ce)
    per_cell.loc[:, 'bin_target'] = per_cell['ighb_score_bin'] - per_cell['error']
    return per_cell.drop(columns=['size', 'error'])

def get_bin_target_group_bin(worst_group_bin, df_bin_target):
    """Look up the 'bin_target' of the cell described by `worst_group_bin`.

    A row of `df_bin_target` matches when its 'ighb_score_bin', 'group' and
    'group_values' all equal the corresponding entries of `worst_group_bin`.
    The target of the first matching row is returned; an IndexError is
    raised when no row matches.
    """
    key_cols = ('ighb_score_bin', 'group', 'group_values')
    matches = np.logical_and.reduce(
        [(df_bin_target[col] == worst_group_bin[col]).to_numpy() for col in key_cols]
    )
    return df_bin_target.loc[matches, 'bin_target'].values[0]

def get_group_bin_mask(df, score_bin, group_columns, group_values):
    """Return a boolean Series selecting the rows of one (group values, bin) cell.

    A row is selected when its 'ighb_score_bin' equals `score_bin` and, for
    every (column, value) pair zipped from `group_columns` and
    `group_values`, the column rendered as a string equals the value.
    """
    selected = df['ighb_score_bin'].eq(score_bin)
    for column, wanted in zip(group_columns, group_values):
        # group values arrive as strings, so compare on the string form
        selected &= df[column].astype(str).eq(wanted)
    return selected


### Run IGHB

# Row counts of the two splits; N normalises the size-weighted errors below.
N = df_calibration.shape[0]
# NOTE(review): N_test is only referenced by the commented-out tracking code below.
N_test = df_test.shape[0]

# Start the IGHB scores from the uncalibrated bin targets.
df_calibration.loc[:, 'ighb_score_bin'] = df_calibration['score_bin']
df_test.loc[:, 'ighb_score_bin'] = df_test['score_bin']

stop = False
# Calibration-set score vectors recorded after every step; used to detect
# that a step has led back to an already visited state.
scores_bins_prev = []

# Each iteration corrects one (group, group value, bin) cell; at most
# args.max_steps corrections are applied.
for _ in tqdm(range(args.max_steps)):
    # ASCE per group on the calibration set, weighted by size / N.
    df_group_asce = get_group_df_all(df_calibration, groups, get_asce_ighb, 'asce', use_tqdm=False)
    df_group_asce.loc[:, 'weighted_asce'] = df_group_asce['size'] / N * df_group_asce['asce']
    # NOTE(review): only consumed by the commented-out progress print below.
    gasce_calibration = df_group_asce['weighted_asce'].sum()

    # # we also track how well we are doing on the test set
    # df_group_asce_test = get_group_df_all(df_test, groups, get_asce_ighb, 'asce', use_tqdm=False)
    # df_group_asce_test.loc[:, 'weighted_asce'] = df_group_asce_test['size'] / N_test * df_group_asce_test['asce']
    # gasce_test = df_group_asce_test['weighted_asce'].sum()
    # print(f'cal: {gasce_calibration:.4f}\ttest: {gasce_test:.4f}')

    # Converged: no group's weighted ASCE exceeds the target.
    if not (df_group_asce['weighted_asce'] > args.target).any():
        print("Done!")
        break

    # Find the cell with the largest size-weighted squared calibration error
    # (get_df_all_groups_by_bin scores with get_sce by default).
    df_sce_by_group_bin = get_df_all_groups_by_bin(df_calibration)
    df_sce_by_group_bin.loc[:, 'weighted_error'] = df_sce_by_group_bin['size'] / N * df_sce_by_group_bin['error']
    # argmax() is positional while .loc is label based; the two agree here
    # because concat_group_dict_by_bin returns a freshly reset 0..n-1 index.
    idx_max = df_sce_by_group_bin['weighted_error'].argmax()
    worst_group_bin = df_sce_by_group_bin.loc[idx_max]

    ighb_score_bin = worst_group_bin['ighb_score_bin']
    group_columns = worst_group_bin['group'].split('; ')
    # NOTE(review): a group value that itself contains '; ' would be split
    # into too many pieces here, and the zip in get_group_bin_mask would then
    # silently compare against a truncated value - confirm this cannot occur.
    group_values = str(worst_group_bin['group_values']).split('; ')

    # Target for the worst cell: its current bin score minus the cell's
    # calibration error (see get_df_bin_target).
    _df_bin_target = get_df_bin_target(df_calibration)
    bin_target = get_bin_target_group_bin(worst_group_bin, _df_bin_target)
    correction = bin_target - ighb_score_bin

    # correct calibration set
    mask = get_group_bin_mask(df_calibration, ighb_score_bin, group_columns, group_values)
    df_calibration.loc[mask, 'ighb_score_bin'] += correction
    # use the same corrections on the test set
    mask = get_group_bin_mask(df_test, ighb_score_bin, group_columns, group_values)
    df_test.loc[mask, 'ighb_score_bin'] += correction

    # round scores: re-derive the bin targets from the corrected calibration
    # scores, then snap both splits to the nearest of those targets
    bin_targets_ighb = get_bin_targets(df_calibration['ighb_score_bin'].values)
    df_calibration.loc[:, 'ighb_bin'] = map_to_bin(df_calibration['ighb_score_bin'].values, bin_targets_ighb) 
    df_calibration.loc[:, 'ighb_score_bin'] = df_calibration['ighb_bin'].map(bin_targets_ighb)
    df_test.loc[:, 'ighb_bin'] = map_to_bin(df_test['ighb_score_bin'].values, bin_targets_ighb)
    df_test.loc[:, 'ighb_score_bin'] = df_test['ighb_bin'].map(bin_targets_ighb)

    # Stop when the calibration scores match (up to np.isclose tolerance) any
    # previously recorded state: further steps would only repeat.
    # NOTE(review): the scan keeps going after a match is found, and every
    # step stores another full copy of the score vector.
    score_bins_curr = df_calibration['ighb_score_bin'].values.copy()
    for scores in scores_bins_prev:
        if np.isclose(scores, score_bins_curr).all():
            print('Stopping - can no longer further correct worst group.')
            stop = True
    scores_bins_prev.append(score_bins_curr)

    if stop:
        break


##### get scores

print("Calculating metrics...")

# overall and per-group summaries of the IGHB scores on both splits
df_summary = get_summary('ighb', df_calibration, df_test, get_asce_ighb, get_brier_ighb)
df_group_summary = get_group_summary('ighb', groups, df_calibration, df_test, get_asce_ighb, get_brier_ighb)

# The summaries are written first and the frames second; the start-up check
# treats a run as complete only when both files exist.
save_pickle((df_summary, df_group_summary), results_path)

# save files
save_pickle((df_calibration, df_test), file_path)
