import os
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm

from sklearn.metrics import mean_pinball_loss, make_scorer
from sklearn.model_selection import GridSearchCV

from general_helper_fns import get_all_groups
from general_helper_fns import get_save_dir, load_data, split_data, save_pickle, load_pickle
from conformal_helper_fns import prepare_df_conformal, get_safe_threshold, processed_filtered_df
from conformal_helper_fns import get_conformal_summary, add_groups_to_filtered_dict, get_grouped_conformal_summary

from multivalid_conformal_helper_fns import get_dummies, get_df_interpolated_scores, get_conformal_helper_qreg
from multivalid_conformal_helper_fns import get_params, get_model_class

import pdb

def get_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: parsed options with the attributes ``group_col``
        (required), ``model_name``, ``split``, ``seed``, ``proportion_cal``,
        ``model_type`` and ``gridsearch``.
    """
    # One (flag, add_argument keyword options) pair per supported option.
    option_specs = (
        ("--group_col", {"type": str, "required": True}),
        ("--model_name", {"type": str, "default": 'Llama2_7B_Chat'}),
        ("--split", {"type": str, "default": '683', "choices": ['nq', '683']}),
        ("--seed", {"type": int, "default": 0}),
        ("--proportion_cal", {"type": float, "default": 0.8}),
        ("--model_type", {"type": str, "default": 'regression', "choices": ['regression', 'gbm']}),
        ("--gridsearch", {"action": 'store_true'}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()
    
args = get_args()

# define groups (note that we evaluate uncalibrated scores on all groups)
# --group_col carries one or more column names separated by ';'. str.split
# returns a one-element list when the separator is absent, so a single call
# covers both the single-column and the multi-column case.
group_cols = args.group_col.split(';')
groups = get_all_groups(group_cols)

# load data
save_dir = get_save_dir()
results_dir = os.path.join(save_dir, f'results/{args.split}/{args.model_name}/multi_conformal/{args.group_col}/seed{args.seed}')
# exist_ok=True creates the directory only when it is missing, without a
# separate existence check, so another process creating the same directory
# between "check" and "create" can no longer raise FileExistsError.
os.makedirs(results_dir, exist_ok=True)

# Load the data for the selected model and split from the save directory.
df = load_data(
    save_dir,
    model_name=args.model_name,
    split=args.split,
)

# add row for each topic corresponding to the empty set (abstention)
df = prepare_df_conformal(df)

# get group cols
# Distinct (topic, group columns) combinations, indexed by topic.
topic_and_group_cols = ['topic', *group_cols]
df_groups = df[topic_and_group_cols].drop_duplicates().set_index('topic')

# split data into calibration and test sets
df_calibration, df_test = split_data(df, args.seed, args.proportion_cal)

# get safe threshold (all facts with scores above the threshold are supported)
# The same per-topic reduction is applied to the calibration set first and to
# the test set second.
df_thresholds_calibration, df_thresholds_test = (
    subset.groupby('topic').apply(get_safe_threshold, include_groups=False)
    for subset in (df_calibration, df_test)
)
safe_thresholds = df_thresholds_calibration.values

# get model inputs
# Attach each topic's threshold to its group columns (calibration first, then test).
df_Xy_calibration, df_Xy_test = (
    df_groups.merge(thresholds.to_frame(), left_on='topic', right_index=True)
    for thresholds in (df_thresholds_calibration, df_thresholds_test)
)

# Column 0 is the regression target; every other column is a feature.
df_y_calibration, df_y_test = df_Xy_calibration[0], df_Xy_test[0]
feature_cols_calibration = [col for col in df_Xy_calibration.columns if col != 0]
feature_cols_test = [col for col in df_Xy_test.columns if col != 0]
df_X_calibration = df_Xy_calibration[feature_cols_calibration]
df_X_test = df_Xy_test[feature_cols_test]

# one hot encode
# Both sets are encoded together and then cut apart again by position, with
# the calibration rows coming first.
n_calibration = len(df_X_calibration)
df_X_combined = get_dummies(pd.concat([df_X_calibration, df_X_test]))
df_X_calibration = df_X_combined[:n_calibration]
df_X_test = df_X_combined[n_calibration:]

# add (interpolated) scores of each biography
df_interpolated_scores_calibration = get_df_interpolated_scores(df_calibration)
df_interpolated_scores_test = get_df_interpolated_scores(df_test)

# NOTE(review): these index-on-index merges use pandas' default inner join and
# df_y_calibration / df_y_test are not re-aligned afterwards. This assumes the
# merge keeps every row in its original order -- confirm.
df_X_calibration = pd.merge(
    df_X_calibration,
    df_interpolated_scores_calibration,
    left_index=True,
    right_index=True,
)
df_X_test = pd.merge(
    df_X_test,
    df_interpolated_scores_test,
    left_index=True,
    right_index=True,
)

##### Train models

model_path_qreq = os.path.join(results_dir, 'qreg_models.pkl')

# NOTE(review): results_dir is built from split, model_name, group_col and seed
# only, so this cache path does not encode --model_type, --gridsearch or
# --proportion_cal. A file written by a run with different values for those
# options would be reused here -- confirm this is intended.
if not os.path.exists(model_path_qreq):
    print("Training...")
    step = 0.05
    quantreg_dict = {}
    # One model per alpha in step, 2*step, ... (rounded to 3 decimals), below 1.
    for alpha in tqdm(np.round(np.arange(step, 1, step), 3)):
        # The scorer evaluates the pinball loss at 1 - alpha.
        pinball_scorer = make_scorer(
            mean_pinball_loss,
            alpha=1 - alpha,
            greater_is_better=False,  # maximize the negative loss
        )
        grid_search = GridSearchCV(
            get_model_class(args.model_type, alpha),
            scoring=pinball_scorer,
            param_grid=get_params(args.model_type, args.gridsearch),
            n_jobs=16,
            refit=True,
        )
        # Features and targets are passed as plain arrays, i.e. matched by position.
        grid_search.fit(df_X_calibration.values, df_y_calibration.values)

        # Keep only the refitted best estimator for this alpha.
        quantreg_dict[alpha] = grid_search.best_estimator_

    save_pickle(quantreg_dict, model_path_qreq)
else:
    print("Models already trained.")
    quantreg_dict = load_pickle(model_path_qreq)

# filter using predicted quantiles
# Same helper and same fitted models for both sets: calibration first, then test.
filtered_dict_calibration, filtered_dict_test = (
    get_conformal_helper_qreg(df_part, df_X_part, quantreg_dict)
    for df_part, df_X_part in (
        (df_calibration, df_X_calibration),
        (df_test, df_X_test),
    )
)

# add group features
# The helper's return value is not used, so it presumably updates its first
# argument in place -- verify against add_groups_to_filtered_dict.
for filtered_dict in (filtered_dict_calibration, filtered_dict_test):
    add_groups_to_filtered_dict(filtered_dict, df_groups)

##### get scores

results_path_qreq = os.path.join(results_dir, 'qreg_results.pkl')

# Summaries are computed and written only when no results file exists yet.
if not os.path.exists(results_path_qreq):
    df_summary_calibration, df_summary_test = (
        get_conformal_summary(filtered)
        for filtered in (filtered_dict_calibration, filtered_dict_test)
    )

    # Stored as a (calibration, test) tuple.
    summaries = (df_summary_calibration, df_summary_test)
    save_pickle(summaries, results_path_qreq)
else:
    print("Run already completed")
    
##### get scores by group

results_path_qreg_grouped = os.path.join(results_dir, 'qreg_results_grouped.pkl')

# Grouped summaries are computed and written only when no results file exists yet.
if not os.path.exists(results_path_qreg_grouped):
    grouped_summary_calibration, grouped_summary_test = {}, {}

    for group_columns in tqdm(groups):
        # Each entry is keyed by its group columns joined with '+'.
        group_name = '+'.join(group_columns)

        # evaluate on calibration set
        grouped_summary_calibration[group_name] = get_grouped_conformal_summary(
            group_columns, filtered_dict_calibration,
        )

        # evaluate on test set
        grouped_summary_test[group_name] = get_grouped_conformal_summary(
            group_columns, filtered_dict_test,
        )

    # Stored as a (calibration, test) tuple of dicts.
    grouped_summaries = (grouped_summary_calibration, grouped_summary_test)
    save_pickle(grouped_summaries, results_path_qreg_grouped)
else:
    print("Run already completed")
