import numpy as np
import pandas as pd
from sklearn.linear_model import QuantileRegressor
from sklearn.ensemble import GradientBoostingRegressor

from conformal_helper_fns import processed_filtered_df

def get_dummies(df_X):
    """Encode a feature frame into an all-integer indicator matrix.

    Columns with at most two distinct non-null values (``nunique`` ignores
    NaN, so constant columns are included) are collapsed into one 0/1 column:
    1 where the value equals the column's first-row value, 0 otherwise.
    Columns with more than two distinct values are one-hot expanded with
    ``pd.get_dummies`` into ``<col>_<value>`` indicator columns.

    Args:
        df_X: Feature DataFrame. It is NOT modified; a copy is encoded.

    Returns:
        A new DataFrame in which every column has been cast to ``int``.
    """
    # Encode a copy so the caller's frame is left untouched.
    df_X = df_X.copy()
    non_binary_columns = df_X.columns[df_X.nunique() > 2]
    # A frame with no rows has no first value to compare against.
    if len(df_X) > 0:
        for col in df_X.columns:
            if col in non_binary_columns:
                continue
            # Whole-column assignment replaces the column (and its dtype)
            # rather than writing ints into the existing bool/object storage.
            df_X[col] = (df_X[col] == df_X[col].iloc[0]).astype(int)
    # pd.get_dummies emits bool indicators; cast everything to int.
    return pd.get_dummies(df_X, columns=non_binary_columns).astype(int)

def get_interpolated(input, interp_size=25):
    """Linearly resample a 1-D sequence to ``interp_size`` points.

    The values are treated as samples taken at positions 0, 1, ..., N-1.
    They are evaluated with ``np.interp`` at ``interp_size`` evenly spaced
    positions covering the same span, so the first and last samples are
    preserved.

    Args:
        input: 1-D sequence of numbers (list, ndarray or Series). The name
            shadows the builtin but is kept for caller compatibility.
        interp_size: Number of output points.

    Returns:
        A float ndarray of length ``interp_size``.
    """
    values = input
    n_samples = len(values)
    # Sample positions 0..N-1 as floats (identical to linspace(0, N-1, N)).
    source_positions = np.arange(n_samples, dtype=float)
    target_positions = np.linspace(0, n_samples - 1, interp_size)
    return np.interp(target_positions, source_positions, values)

def get_df_interpolated_scores(df, score_col='score', interp_size=25):
    """Build one fixed-length row of resampled scores per topic.

    For each ``topic`` group, the ``score_col`` values (taken in the group's
    row order, which groupby preserves) are linearly resampled to
    ``interp_size`` points with ``get_interpolated``. Topics with different
    numbers of rows thereby become comparable feature vectors.

    Args:
        df: DataFrame with a ``topic`` column and a ``score_col`` column.
        score_col: Name of the column to resample.
        interp_size: Number of resampled points per topic. The default of 25
            matches ``get_interpolated``'s own default.

    Returns:
        DataFrame indexed by topic (groupby order), with columns
        ``interp_0`` ... ``interp_{interp_size-1}``.
    """
    per_topic = df.groupby('topic')[score_col].apply(
        lambda scores: get_interpolated(scores, interp_size=interp_size)
    )
    # Each entry is an ndarray; expand them into one column per point.
    df_interpolated_scores = pd.DataFrame(per_topic.tolist(), index=per_topic.index)
    df_interpolated_scores.columns = [f'interp_{x}' for x in df_interpolated_scores.columns]
    return df_interpolated_scores

def filter_facts_qreg(df_input, df_X, quantreg_dict, alpha):
    """Keep rows whose score exceeds their topic's predicted quantile.

    The fitted model stored under ``quantreg_dict[alpha]`` predicts one
    threshold per row of ``df_X`` (``df_X`` is indexed by topic). The
    thresholds are attached to ``df_input`` as column ``q`` via an inner
    merge on ``topic``; rows whose topic is absent from ``df_X`` are dropped.

    Args:
        df_input: Row-level frame with ``topic`` and ``score`` columns.
        df_X: Per-topic feature matrix whose index holds the topic labels.
        quantreg_dict: Mapping from alpha to an object with ``.predict``.
        alpha: Key selecting the model to use.

    Returns:
        The merged rows with ``score > q`` (strict inequality).
    """
    model = quantreg_dict[alpha]
    thresholds = pd.DataFrame(
        model.predict(df_X.values),
        columns=['q'],
        index=df_X.index,
    )
    merged = df_input.merge(thresholds, left_on='topic', right_index=True)
    keep = merged['score'] > merged['q']
    return merged[keep]

def get_conformal_helper_qreg(df_input, df_X, quantreg_dict, step=0.05):
    """Filter ``df_input`` at every alpha on a grid and post-process each result.

    For each alpha in ``step, 2*step, ...`` below 1 (rounded to 3 decimals),
    rows are filtered with ``filter_facts_qreg`` and summarised with
    ``processed_filtered_df``.

    Args:
        df_input: Row-level frame with ``topic``, ``score`` and ``text`` columns.
        df_X: Per-topic feature matrix indexed by topic.
        quantreg_dict: Mapping from rounded alpha to a fitted model; it must
            contain every alpha on the grid.
        step: Spacing of the alpha grid.

    Returns:
        dict mapping alpha -> ``[df_filtered, df_filtered_no_abstentions,
        df_all_supported_with_abstentions, df_all_supported,
        df_frac_supported]``.
    """
    alphas = np.round(np.arange(step, 1, step), 3)
    # A float-step arange can overshoot `stop`, and rounding can push a level
    # to 0.0 or 1.0; keep only levels strictly inside (0, 1).
    alphas = alphas[(alphas > 0) & (alphas < 1)]

    df_filtered_dict = {}
    for alpha in alphas:
        df_filtered = filter_facts_qreg(df_input, df_X, quantreg_dict, alpha)
        df_all_supported_with_abstentions, df_all_supported, df_frac_supported = processed_filtered_df(df_filtered)
        # Rows with empty 'text' are treated as abstentions (per the naming here).
        df_filtered_no_abstentions = df_filtered[df_filtered['text'] != '']
        df_filtered_dict[alpha] = [
            df_filtered,
            df_filtered_no_abstentions,
            df_all_supported_with_abstentions,
            df_all_supported,
            df_frac_supported,
        ]

    return df_filtered_dict

# Hyperparameter grids and estimator factory for the quantile-regression models
def get_params(model_type, gridsearch=False):
    """Return the hyperparameter grid for a quantile model type.

    The keys match constructor arguments of the estimators built in
    ``get_model_class``. For 'regression', ``alpha`` is the estimator's
    penalty strength, not the miscoverage level passed to ``get_model_class``.

    Args:
        model_type: 'regression' or 'gbm'.
        gridsearch: If True, return a multi-value search grid; otherwise
            return the fixed (non-searched) setting.

    Returns:
        dict mapping parameter name -> list of candidate values (possibly empty).

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """
    if model_type == 'regression':
        if gridsearch:
            return dict(
                alpha=[1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
            )
        # A single setting with zero penalty.
        return dict(
            alpha=[0],
        )
    if model_type == 'gbm':
        if gridsearch:
            return dict(
                max_depth=[3, 5, 7, 9],
                min_samples_split=[2, 5, 10, 20],
                subsample=[0.5, 0.75, 1.0],
                min_samples_leaf=[1, 2, 4, 8],
            )
        # Empty grid: keep the estimator's defaults.
        return dict()
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'regression' or 'gbm'")
    
def get_model_class(model_type, alpha):
    """Construct an unfitted estimator for the (1 - alpha) quantile.

    Args:
        model_type: 'regression' for a linear ``QuantileRegressor`` (with
            intercept) or 'gbm' for a ``GradientBoostingRegressor`` with
            quantile loss (``random_state=0`` for reproducibility).
        alpha: Miscoverage level; the model targets quantile ``1 - alpha``.

    Returns:
        A new, unfitted scikit-learn estimator.

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """
    if model_type == 'regression':
        return QuantileRegressor(quantile=1-alpha, fit_intercept=True)
    if model_type == 'gbm':
        return GradientBoostingRegressor(loss='quantile', alpha=1-alpha, random_state=0)
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'regression' or 'gbm'")