import os
import argparse
import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer, log_loss
from sklearn.linear_model import LogisticRegression

from general_helper_fns import get_all_groups
from general_helper_fns import get_save_dir, load_data, split_data
from general_helper_fns import save_pickle
from calibrate_helper_fns import get_bin_targets as _get_bin_targets
from calibrate_helper_fns import get_asce, get_brier
from calibrate_helper_fns import get_summary, get_group_summary

import pdb

# NOTE(review): PROPORTION_CAL is not referenced in the visible part of this
# script; --proportion_cal declares the same value as its own default.
PROPORTION_CAL = 0.8
# Passed as the second argument to the binning helper below.
NUM_BINS = 5


def get_bin_targets(x):
    """Return the bin targets for the scores *x*, using NUM_BINS bins.

    Thin wrapper around calibrate_helper_fns.get_bin_targets. The result is
    used in this script as a mapping from bin key to the bin's target score.
    """
    return _get_bin_targets(x, NUM_BINS)

### Set up

def get_args():
    """Parse the command-line options of this script.

    Options:
        --group_col       required; one column name, or several joined by ';'.
        --model_name, --save_dir, --split
                          select the data that is evaluated.
        --model_name_cal, --save_dir_cal, --split_cal
                          select the data used for calibration when
                          --model_name_cal is given.
        --seed, --proportion_cal
                          forwarded to the calibration/test split.
        --gridsearch      flag; enables the hyper-parameter search.

    Returns:
        The argparse.Namespace produced from sys.argv.
    """
    default_dir = './data/all/temperature=1.0/run=0'
    split_choices = ['nq', '683', 'all']

    # (flag, add_argument keyword arguments), in the order they are declared
    option_specs = [
        ("--group_col", dict(type=str, required=True)),
        ("--model_name", dict(type=str, default='Llama2_7B_Chat')),
        ("--save_dir", dict(type=str, default=default_dir)),
        ("--split", dict(type=str, default='683', choices=split_choices)),
        ("--model_name_cal", dict(type=str, default=None)),
        ("--save_dir_cal", dict(type=str, default=default_dir)),
        ("--split_cal", dict(type=str, default='nq', choices=split_choices)),
        ("--seed", dict(type=int, default=0)),
        ("--proportion_cal", dict(type=float, default=0.8)),
        ("--gridsearch", dict(action='store_true')),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()

args = get_args()
# Name passed to the summary helpers and used as the prefix of both output files.
gcur_name = 'gcur'
if args.gridsearch:
    gcur_name += '_gridsearch'

# define groups (note that we evaluate uncalibrated scores on all groups)
# str.split returns a one-element list when no ';' is present, so a single
# column name needs no special case.
group_cols = args.group_col.split(';')
groups = get_all_groups(group_cols)

# set up results dir
multicalibrate_subir = 'multicalibrate'
if args.model_name_cal is not None:
    # calibration data comes from another model/split: record it in the path
    multicalibrate_subir += f'+{args.model_name_cal}_{args.split_cal}'
results_dir = os.path.join(args.save_dir, f'results/{args.split}/{args.model_name}/{multicalibrate_subir}/{args.group_col}/seed{args.seed}')
# exist_ok=True removes the check-then-create race of a separate exists() test
os.makedirs(results_dir, exist_ok=True)

results_path = os.path.join(results_dir, f'{gcur_name}_results.pkl')
file_path = os.path.join(results_dir, f'{gcur_name}_files.pkl')
# Skip the run only when both output files are already present.
if os.path.exists(results_path) and os.path.exists(file_path):
    print("Run already completed.")
    # The bare exit() helper is installed by the `site` module for interactive
    # sessions and is not guaranteed to exist; raising SystemExit ends the
    # script with the same (zero) exit status.
    raise SystemExit

# load data
# The evaluated model's data provides the test split (and, by default, the
# calibration split as well).
df = load_data(args.save_dir, model_name=args.model_name, split=args.split)
df_calibration, df_test = split_data(df, args.seed, args.proportion_cal)

if args.model_name_cal is not None:
    # A separate calibration source was requested: keep df_test from above and
    # replace only the calibration split.
    df = load_data(args.save_dir_cal, model_name=args.model_name_cal, split=args.split_cal)
    df_calibration, _ = split_data(df, args.seed, args.proportion_cal)

##### set up bins

# get what the initial bins and their targets should be
# (computed from the calibration scores only)
calibration_scores = df_calibration['score'].values
bin_targets_uncalibrated = get_bin_targets(calibration_scores)

def map_to_bin(values, bin_targets=bin_targets_uncalibrated):
    """Assign each value to the bin whose target lies closest to it.

    Args:
        values: iterable of numeric scores.
        bin_targets: mapping from bin key to that bin's target value. Defaults
            to the targets computed from the uncalibrated calibration scores
            (bound once, when this function is defined).

    Returns:
        A list of bin keys, one per value, in input order. On a tie the bin
        that comes first when iterating ``bin_targets`` is chosen.
    """
    def _closest_bin(score):
        def _distance(bin_key):
            return abs(bin_targets[bin_key] - score)

        return min(bin_targets.keys(), key=_distance)

    return [_closest_bin(score) for score in values]

# map each uncalibrated score to a bin
# score_bin = uncalibrated target of the bin
# Both splits are binned against the same targets (map_to_bin's default,
# bin_targets_uncalibrated), which were computed from the calibration scores.
for df in [df_calibration, df_test]:
    # 'bin': key of the bin whose target is nearest to the row's raw score
    df.loc[:, 'bin'] = map_to_bin(df['score'].values)
    # 'score_bin': the target value of that bin, looked up by bin key
    df.loc[:, 'score_bin'] = df['bin'].map(bin_targets_uncalibrated)


### Group-Conditional Unbiased Regression ###

def get_onehot_groups(df, groups):
    """Build a one-hot membership matrix for every observed group value.

    For each entry of ``groups`` the frame is grouped by those columns, and
    every distinct value (or value combination) found becomes one indicator
    column.

    Args:
        df: DataFrame containing all columns named in ``groups``. Its index
            may be arbitrary; rows are addressed by position.
        groups: iterable of grouping specifications, each a list of column
            names (a bare column-name string is also accepted).

    Returns:
        DataFrame with a default RangeIndex, ``len(df)`` rows and one float
        column per group, named ``in_group-<cols>=<values>``. Multiple column
        names or values are joined by ``'; '``. A cell is 1.0 when the row at
        that position belongs to the group and 0.0 otherwise.
    """
    # Group a positionally indexed copy so the labels reported by ``.groups``
    # are row positions. They are used below to index a numpy array, which
    # would be wrong for a frame carrying any other index.
    df_positional = df.reset_index(drop=True)

    group_to_indices = {}
    for group_cols in groups:
        if isinstance(group_cols, str):
            # joining a bare string would split it into characters
            group_cols_name = group_cols
        else:
            group_cols_name = '; '.join(group_cols)

        grouped = df_positional.groupby(group_cols)
        for group_values, indices in grouped.groups.items():
            # group keys are tuples for multi-column groupings, scalars otherwise
            if isinstance(group_values, tuple):
                group_values = '; '.join(str(x) for x in group_values)
            else:
                group_values = str(group_values)

            group_name = group_cols_name + '=' + group_values
            group_to_indices[group_name] = indices.values

    one_hot = np.zeros((len(df), len(group_to_indices)))
    for i, indices in enumerate(group_to_indices.values()):
        one_hot[indices, i] = 1

    columns = [f'in_group-{x}' for x in group_to_indices]
    return pd.DataFrame(one_hot, columns=columns)

# one hot encode the group columns
# Both splits are encoded in a single call so they share one set of indicator
# columns; the stacked frame is then cut back at the calibration/test boundary.
df_onehot_both = get_onehot_groups(
    pd.concat([df_calibration, df_test]).reset_index(drop=True),
    groups,
)
n_calibration = len(df_calibration)
# .loc slices by label and includes the end label, hence the "- 1"
df_onehot_groups_calibration = df_onehot_both.loc[:n_calibration - 1]
df_onehot_groups_test = df_onehot_both.loc[n_calibration:]

# add the score and label to these dataframes
# (each insert goes to position 0, so 'correct' ends up before 'score')
for onehot_part, source_part in (
    (df_onehot_groups_calibration, df_calibration),
    (df_onehot_groups_test, df_test),
):
    onehot_part.insert(0, 'score', source_part['score'].values)
    onehot_part.insert(0, 'correct', source_part['correct'].values.astype(float))

# set up logistic regression
# features: every column except the label, i.e. the score plus the indicators
feature_cols_calibration = [c for c in df_onehot_groups_calibration.columns if c != 'correct']
feature_cols_test = [c for c in df_onehot_groups_test.columns if c != 'correct']
X_calibration = df_onehot_groups_calibration[feature_cols_calibration].values
X_test = df_onehot_groups_test[feature_cols_test].values
y_calibration = df_onehot_groups_calibration['correct'].values
y_test = df_onehot_groups_test['correct'].values

# train model

print("Training model....")

model = LogisticRegression(solver='liblinear', fit_intercept=True, random_state=0)

if not args.gridsearch:
    # plain fit with the estimator's default regularisation settings
    model.fit(X_calibration, y_calibration)
else:
    # Cross-validated search over the L1 penalty strength, scored by log loss
    # on predicted probabilities (lower is better, hence greater_is_better=False).
    scorer = make_scorer(log_loss, greater_is_better=False, response_method='predict_proba')
    params_dict = {
        'penalty': ['l1'],
        'C': [1e-4, 1e-3, 1e-2, 0.1, 1, 10],
    }
    search = GridSearchCV(model, scoring=scorer, param_grid=params_dict, cv=5, n_jobs=16, refit=True)
    search.fit(X_calibration, y_calibration)
    # refit=True: the best configuration has already been refit on all of X_calibration
    model = search.best_estimator_

# column 1 of predict_proba: probability of the second entry of model.classes_
y_pred_calibration = model.predict_proba(X_calibration)[:, 1]
y_pred_test = model.predict_proba(X_test)[:, 1]

df_calibration.loc[:, 'gcur_score'] = y_pred_calibration
df_test.loc[:, 'gcur_score'] = y_pred_test

# re-bin using targets derived from the new calibration-split scores
bin_targets_gcur = get_bin_targets(df_calibration['gcur_score'].values)

for df in [df_calibration, df_test]:
    df.loc[:, 'gcur_bin'] = map_to_bin(df['gcur_score'].values, bin_targets=bin_targets_gcur)
    df.loc[:, 'gcur_score_bin'] = df['gcur_bin'].map(bin_targets_gcur)

##### get scores

print("Calculating metrics...")


def get_asce_gcur(x):
    """Call get_asce on *x* with the binned GCUR score as the score column."""
    return get_asce(x, score_col='gcur_score_bin')


def get_brier_gcur(x):
    """Call get_brier on *x* with the binned GCUR score as the score column."""
    return get_brier(x, score_col='gcur_score_bin')


df_summary = get_summary(gcur_name, df_calibration, df_test, get_asce_gcur, get_brier_gcur)
df_group_summary = get_group_summary(gcur_name, groups, df_calibration, df_test, get_asce_gcur, get_brier_gcur)

# both summaries go into one pickle, as a (overall, per-group) tuple
save_pickle((df_summary, df_group_summary), results_path)

# save files
# the annotated calibration/test frames, as a (calibration, test) tuple
save_pickle((df_calibration, df_test), file_path)
