import os
import argparse
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer, log_loss
from sklearn.linear_model import LogisticRegression

from general_helper_fns import ALL_POSSIBLE_GROUPS
from general_helper_fns import get_save_dir, load_data, split_data
from general_helper_fns import save_pickle
from calibrate_helper_fns import get_bin_targets as _get_bin_targets
from calibrate_helper_fns import get_asce, get_brier
from calibrate_helper_fns import get_summary, get_group_summary

import pdb

# NOTE(review): PROPORTION_CAL is not referenced elsewhere in this file; the
# --proportion_cal CLI option repeats 0.8 as its own default.
PROPORTION_CAL = 0.8
# Number of bins used whenever bin targets are derived from a score array.
NUM_BINS = 5


def get_bin_targets(x):
    """Return bin targets for the scores ``x`` using ``NUM_BINS`` bins.

    Thin wrapper over ``calibrate_helper_fns.get_bin_targets`` that fixes the
    number of bins. The result is used in this script as a mapping
    (bin id -> target score), judging by the ``.keys()`` / ``Series.map``
    calls on it; confirm against the helper's definition.

    Written as a ``def`` instead of a lambda assignment (PEP 8, E731) so the
    function has a real name in tracebacks and profiles.
    """
    return _get_bin_targets(x, NUM_BINS)

### Set up

def get_args(argv=None):
    """Parse the command-line options for the logistic-regression calibration run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            allows the parser to be driven programmatically (e.g. in tests).

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: via ``parser.error`` on invalid arguments, including a
            ``--proportion_cal`` outside the open interval (0, 1).
    """
    parser = argparse.ArgumentParser()
    # Target model / data whose scores are calibrated and evaluated.
    parser.add_argument("--model_name", type=str, default='Llama2_7B_Chat')
    parser.add_argument("--save_dir", type=str, default='./data/all/temperature=1.0/run=0')
    parser.add_argument("--split", type=str, default='683', choices=['nq', '683', 'all'])

    # Optional alternative model / data used as the calibration source.
    parser.add_argument("--model_name_cal", type=str, default=None)
    parser.add_argument("--save_dir_cal", type=str, default='./data/all/temperature=1.0/run=0')
    parser.add_argument("--split_cal", type=str, default='nq', choices=['nq', '683', 'all'])

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--proportion_cal", type=float, default=0.8)
    parser.add_argument("--gridsearch", action='store_true')

    args = parser.parse_args(argv)
    # A proportion of 0 or 1 (or NaN / out-of-range values) cannot produce
    # both a calibration split and a held-out test split.
    if not 0.0 < args.proportion_cal < 1.0:
        parser.error(
            f"--proportion_cal must be strictly between 0 and 1, got {args.proportion_cal}"
        )
    return args

args = get_args()

# Build the output directory for this run:
#   <save_dir>/results/<split>/<model_name>/<calibrate_subir>/seed<seed>
# The calibration sub-directory is 'calibrate', suffixed with
# '+<model_name_cal>_<split_cal>' when a separate calibration source is given.
calibrate_subir = 'calibrate'
if args.model_name_cal is not None:
    calibrate_subir += f'+{args.model_name_cal}_{args.split_cal}'
results_dir = os.path.join(args.save_dir, f'results/{args.split}/{args.model_name}/{calibrate_subir}/seed{args.seed}')
# exist_ok=True replaces the racy "check exists, then create" pattern.
os.makedirs(results_dir, exist_ok=True)

# Load the target model's data and split it into calibration / test portions.
df = load_data(args.save_dir, model_name=args.model_name, split=args.split)
df_calibration, df_test = split_data(df, args.seed, args.proportion_cal)

# When a separate calibration source is requested, replace the calibration
# portion with one drawn from that source; the test portion above is kept.
if args.model_name_cal is not None:
    df = load_data(args.save_dir_cal, model_name=args.model_name_cal, split=args.split_cal)
    df_calibration, _ = split_data(df, args.seed, args.proportion_cal)

##### set up bins

# Bin targets derived from the raw (uncalibrated) scores of the calibration split.
bin_targets_uncalibrated = get_bin_targets(df_calibration.loc[:, 'score'].values)

def map_to_bin(values, bin_targets=None):
    """Map each value to the bin whose target is nearest to it.

    Args:
        values: Iterable of numeric scores.
        bin_targets: Mapping of bin id -> target value. ``None`` (the default)
            means the module-level ``bin_targets_uncalibrated``, looked up at
            call time rather than frozen into the signature at definition time.

    Returns:
        List with one bin id per input value, in input order. When two bins are
        equally close, the one that comes first in ``bin_targets`` wins.

    Raises:
        ValueError: if ``values`` is non-empty and ``bin_targets`` is empty.
    """
    if bin_targets is None:
        bin_targets = bin_targets_uncalibrated
    # Materialise the (bin, target) pairs once instead of re-walking the
    # mapping and re-indexing it for every value.
    candidates = list(bin_targets.items())
    return [
        min(candidates, key=lambda item: abs(item[1] - num))[0]
        for num in values
    ]

# Tag every row of both splits with its nearest uncalibrated bin ('bin') and
# that bin's target value ('score_bin').
for df in (df_calibration, df_test):
    bins = map_to_bin(df['score'].values)
    df.loc[:, 'bin'] = bins
    df.loc[:, 'score_bin'] = df['bin'].map(bin_targets_uncalibrated)

# Regression inputs: the raw score as a one-column 2-D feature matrix (the
# shape scikit-learn estimators expect) and the 'correct' column as target.
X_calibration, y_calibration = df_calibration[['score']].values, df_calibration['correct'].values
X_test, y_test = df_test[['score']].values, df_test['correct'].values

# train model

print("Training model....")

# liblinear supports both the default l2 penalty and the l1 penalty searched
# below, so one base estimator serves both code paths.
model = LogisticRegression(
    solver='liblinear',
    fit_intercept=True,
    random_state=0)

if args.gridsearch:
    # Choose the l1 regularisation strength C by 5-fold CV on the calibration
    # split, scored by (negated) log loss on predicted probabilities.
    # scikit-learn's built-in 'neg_log_loss' scorer is defined as
    # make_scorer(log_loss, greater_is_better=False) on predict_proba output,
    # so this matches the previous hand-built scorer without depending on
    # make_scorer's `response_method` keyword (absent before scikit-learn 1.4).
    params_dict = dict(
        penalty=['l1'],
        C=[1e-4, 1e-3, 1e-2, 0.1, 1, 10],
    )

    search = GridSearchCV(
        model,
        scoring='neg_log_loss',
        param_grid=params_dict,
        cv=5,
        n_jobs=16,
        refit=True,
    )
    search.fit(X_calibration, y_calibration)
    # With refit=True the best configuration is retrained on the full
    # calibration split.
    model = search.best_estimator_
else:
    model.fit(X_calibration, y_calibration)

# predict_proba columns follow model.classes_ (sorted), so column 1 is the
# probability of the larger label -- presumably "correct" for 0/1 or bool
# targets; TODO confirm the label encoding of the 'correct' column.
y_pred_calibration = model.predict_proba(X_calibration)[:, 1]
y_pred_test = model.predict_proba(X_test)[:, 1]

# Store the regression's predicted probabilities next to the raw scores.
for frame, preds in ((df_calibration, y_pred_calibration), (df_test, y_pred_test)):
    frame.loc[:, 'reg_score'] = preds

# Re-derive bin targets from the calibrated calibration-split scores, then tag
# both splits with their nearest calibrated bin and its target value.
bin_targets_gcur = get_bin_targets(df_calibration['reg_score'].values)

for df in (df_calibration, df_test):
    reg_bins = map_to_bin(df['reg_score'].values, bin_targets=bin_targets_gcur)
    df.loc[:, 'reg_bin'] = reg_bins
    df.loc[:, 'reg_score_bin'] = df['reg_bin'].map(bin_targets_gcur)

##### get scores

print("Calculating metrics on histogram binning scores...")

# Metric callables that read the binned calibrated score column.
get_asce_gcur = lambda x: get_asce(x, score_col='reg_score_bin')
get_brier_gcur = lambda x: get_brier(x, score_col='reg_score_bin')

# Overall and per-group summaries for both splits, labelled 'reg'.
df_summary = get_summary('reg', df_calibration, df_test, get_asce_gcur, get_brier_gcur)
df_group_summary = get_group_summary(
    'reg',
    ALL_POSSIBLE_GROUPS,
    df_calibration,
    df_test,
    get_asce_gcur,
    get_brier_gcur,
)

results_path_reg = os.path.join(results_dir, 'reg_results.pkl')
save_pickle((df_summary, df_group_summary), results_path_reg)

# Persist both annotated splits (raw, binned and calibrated score columns)
# for later analysis.
file_path = os.path.join(results_dir, 'reg_files.pkl')
save_pickle((df_calibration, df_test), file_path)
