from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import r2_score
from scipy.stats import kendalltau
from constrained_linear_regression import ConstrainedLinearRegression
import sklearn
import pickle 
import numpy as np 
from sklearn.model_selection import KFold
import torch
from itertools import combinations 
from timeit import default_timer as timer


def save_obj(obj, name):
    """Serialize *obj* with pickle into the file ``name + '.pkl'``.

    The file is created (or truncated) and written with
    ``pickle.HIGHEST_PROTOCOL``.
    """
    target_path = name + '.pkl'
    with open(target_path, 'wb+') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)


def load_obj(name):
    """Print and then unpickle the file ``name + '.pkl'``, returning its object.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on files
    produced by a trusted source (e.g. ``save_obj``).
    """
    source_path = name + '.pkl'
    print(source_path)
    with open(source_path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded


def sample_mvnn(model, scale, num_samples, num_high_samples, num_val_samples=1000, seed=42, bundle_size=5):
    """
    Evaluate an MVNN on every fixed-size bundle and draw regression data from it.

    Every 0/1 vector over the model's ``input_dim`` items with exactly
    ``bundle_size`` ones is enumerated (in ``itertools.combinations`` order)
    and scored by ``model``; scores are multiplied by ``scale``. From this full
    set, ``num_samples`` training rows and ``num_val_samples`` validation rows
    are drawn uniformly without replacement (the two draws are disjoint), and
    the ``num_high_samples`` highest-scoring bundles are appended to the
    training rows.

    NOTE(review): the appended top bundles are not excluded from the random
    draws, so they may duplicate training rows or also appear in the
    validation set.

    Parameters
    ----------
    model : callable
        Called on a float32 tensor of shape (n_bundles, input_dim);
        ``model.layers[0].in_features`` must give ``input_dim``. Assumed to
        return one value per bundle, shape (n_bundles, 1) or (n_bundles,).
    scale : float
        Factor applied to every model output.
    num_samples : int
        Number of randomly drawn training bundles.
    num_high_samples : int
        Number of top-valued bundles appended to the training set (0 allowed).
    num_val_samples : int
        Number of randomly drawn validation bundles.
    seed : int
        Seed for NumPy's *global* RNG (reseeded as a side effect).
    bundle_size : int
        Number of items in every bundle (defaults to the previous fixed 5).

    Returns
    -------
    X, Y_scaled, X_val, Y_val_scaled, X_full, Y_full_scaled : np.ndarray
        Training inputs/targets (random draw followed by the top bundles),
        validation inputs/targets, and the full enumerated set with its scores.

    Raises
    ------
    ValueError
        If ``num_high_samples`` is negative or larger than the number of
        bundles, or (from ``np.random.choice``) if
        ``num_samples + num_val_samples`` exceeds the number of bundles.
    """
    np.random.seed(seed)

    input_dim = model.layers[0].in_features

    # One row per bundle: a 0/1 indicator vector with `bundle_size` ones.
    bundles = list(combinations(range(input_dim), bundle_size))
    X_full = np.zeros((len(bundles), input_dim))
    for row, bundle in enumerate(bundles):
        X_full[row, list(bundle)] = 1

    # Pure inference: avoid building an autograd graph over every bundle.
    with torch.no_grad():
        Y_full = model(torch.from_numpy(X_full).float())
    Y_full_scaled = Y_full.detach().numpy() * scale

    n_full = X_full.shape[0]
    if not 0 <= num_high_samples <= n_full:
        raise ValueError(
            f'num_high_samples must be in [0, {n_full}], got {num_high_samples}'
        )

    # Disjoint random train / validation draws from the full set.
    all_indices = np.random.choice(n_full, num_samples + num_val_samples, replace=False)
    train_idx = all_indices[:num_samples]
    val_idx = all_indices[num_samples:]
    X = X_full[train_idx]
    Y_scaled = Y_full_scaled[train_idx]
    X_val = X_full[val_idx]
    Y_val_scaled = Y_full_scaled[val_idx]

    # Explicit start index: `sorted_args[-0:]` would select *every* row when
    # num_high_samples == 0.
    sorted_args = np.argsort(Y_full_scaled, axis=0)
    top_args = sorted_args[n_full - num_high_samples:]
    # Reshape with explicit sizes so the k == 0 case works and 1-D model
    # outputs keep matching the shape of Y_scaled.
    X_high = X_full[top_args].reshape(num_high_samples, input_dim)
    Y_high_scaled = Y_full_scaled[top_args].reshape((num_high_samples,) + Y_scaled.shape[1:])

    X = np.concatenate((X, X_high), axis=0)
    Y_scaled = np.concatenate((Y_scaled, Y_high_scaled), axis=0)

    return X, Y_scaled, X_val, Y_val_scaled, X_full, Y_full_scaled

def poly_regression_mvnn(X, Y, X_val, Y_val, linear_projection=False, seed=42, model_type='clr', alpha=0.1, ridge=0, fit_intercept=False, n_courses=None):
    """
    Fit a linear model to MVNN samples and report validation metrics.

    Unless ``linear_projection`` is True, the inputs are expanded with
    ``PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)``,
    i.e. the original columns followed by all pairwise products. The mapping is
    fitted on the training inputs only and then applied to the validation inputs.

    Parameters
    ----------
    X, Y : array-like
        Training inputs and targets.
    X_val, Y_val : array-like
        Validation inputs and targets.
    linear_projection : bool
        If True, use the raw columns; otherwise add pairwise interaction terms.
    seed : int
        Currently unused; kept for interface compatibility.
    model_type : {'lr', 'lasso', 'clr'}
        'lr' -> sklearn ``LinearRegression``; 'lasso' -> sklearn ``Lasso(alpha)``;
        'clr' -> ``ConstrainedLinearRegression`` with per-coefficient bounds.
    alpha : float
        Lasso penalty ('lasso'), also passed as ``lasso=`` for 'clr'.
    ridge : float
        Ridge penalty passed to ``ConstrainedLinearRegression`` ('clr' only).
    fit_intercept : bool
        Passed to ``ConstrainedLinearRegression`` ('clr' only).
    n_courses : int or None
        'clr' only: the first ``n_courses`` coefficients (the base course
        values) are bounded to [0, 100], all others (interaction adjustments)
        to [-200, 200]. ``None`` uses the number of columns of ``X``.

    Returns
    -------
    (linear_model, mae, mae_squared, r2, tau)
        The fitted model, validation MAE, mean squared error, R^2 and Kendall's
        tau. The metrics are NaN when computing them raises ``ValueError``
        (e.g. NaN predictions) or ``FloatingPointError``.

    Raises
    ------
    ValueError
        For an unknown ``model_type`` or an out-of-range ``n_courses``.
    """
    # `import sklearn` alone does not guarantee the `linear_model` subpackage is loaded.
    from sklearn.linear_model import Lasso, LinearRegression

    if linear_projection:
        X_poly, X_val_poly = X, X_val
    else:
        poly = PolynomialFeatures(degree=2, include_bias=False, interaction_only=True)
        X_poly = poly.fit_transform(X)
        # Reuse the mapping fitted on the training data; do not refit on validation.
        X_val_poly = poly.transform(X_val)

    if model_type == 'lr':
        linear_model = LinearRegression()
        linear_model.fit(X_poly, Y)

    elif model_type == 'lasso':
        linear_model = Lasso(alpha=alpha)
        linear_model.fit(X_poly, Y)

    elif model_type == 'clr':
        linear_model = ConstrainedLinearRegression(max_iter=300, lasso=alpha, ridge=ridge, fit_intercept=fit_intercept)
        n_features = X_poly.shape[1]
        if n_courses is None:
            # The leading columns of X_poly are exactly the original columns of X.
            n_courses = np.shape(X)[1]
        if not 0 <= n_courses <= n_features:
            raise ValueError(f'n_courses must be in [0, {n_features}], got {n_courses}')
        # GUI constraints: base course values in [0, 100],
        # interaction adjustments in [-200, 200].
        min_coef = np.repeat(-200, n_features)
        min_coef[:n_courses] = 0
        max_coef = np.repeat(200, n_features)
        max_coef[:n_courses] = 100
        print('X_poly shape: ', X_poly.shape, 'Y shape: ', Y.shape, 'max_coef shape: ', max_coef.shape, 'min_coef shape: ', min_coef.shape)
        train_start = timer()
        linear_model.fit(X_poly, Y, max_coef=max_coef, min_coef=min_coef)
        train_end = timer()
        print(f'Model was fit in {train_end - train_start} seconds')

    else:
        raise ValueError(f"Unknown model_type {model_type!r}; expected 'lr', 'lasso' or 'clr'")

    # validate the model using the validation sets
    Y_val_pred = linear_model.predict(X_val_poly)

    try:
        # Match the target's shape: an (n,) prediction minus an (n, 1) target
        # would otherwise silently broadcast to (n, n).
        Y_val_pred = np.reshape(Y_val_pred, np.shape(Y_val))
        errors = Y_val_pred - Y_val
        mae = np.mean(np.abs(errors))
        mae_squared = np.mean(errors ** 2)
        # r2_score takes (y_true, y_pred); R^2 is not symmetric.
        r2 = r2_score(Y_val, Y_val_pred)
        tau, _ = kendalltau(Y_val, Y_val_pred)

        print("MAE: ", mae, "MAE Squared: ", mae_squared, "R2: ", r2, "KT: ", tau)
    except (ValueError, FloatingPointError):
        (mae, mae_squared, r2, tau) = (np.nan, np.nan, np.nan, np.nan)
        print('one of the generalization metrics returned a nan for projection!')

    return linear_model, mae, mae_squared, r2, tau


def _iter_train_val_splits(Xs, Ys, X_vals, Y_vals, kfold, n_splits):
    """Yield (X_train, Y_train, X_val, Y_val) for every fit check_config performs."""
    for i, X_given in enumerate(Xs):
        print(f'Starting model {i}')
        if not kfold:
            yield X_given, Ys[i], X_vals[i], Y_vals[i]
            continue

        # Pool the given train and validation data, then re-split with shuffled K-fold.
        X_total = np.concatenate((X_given, X_vals[i]), axis=0)
        Y_total = np.concatenate((Ys[i], Y_vals[i]), axis=0)
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
        for train_index, val_index in kf.split(X_total):
            yield X_total[train_index], Y_total[train_index], X_total[val_index], Y_total[val_index]


def check_config(Xs, Ys, X_vals, Y_vals, model_type='lr', alpha=0.1, ridge=0, fit_intercept=False, kfold=False, linear_projection=False, n_splits=5):
    """
    Fit and evaluate one regression configuration on several datasets.

    For each index ``i`` the data ``(Xs[i], Ys[i], X_vals[i], Y_vals[i])`` is
    either used as given (``kfold=False``) or pooled and re-split into
    ``n_splits`` shuffled folds (``kfold=True``, fixed ``random_state=42``).
    Each resulting split is fitted with ``poly_regression_mvnn``.

    Parameters
    ----------
    Xs, Ys, X_vals, Y_vals : sequences of arrays
        Per-dataset training / validation inputs and targets (indexed by the
        position in ``Xs``).
    model_type, alpha, ridge, fit_intercept :
        Forwarded to ``poly_regression_mvnn``. For 'clr' the targets are
        flattened with ``ravel()`` first.
    kfold : bool
        Use K-fold cross validation instead of the given split.
    linear_projection : bool
        Forwarded to ``poly_regression_mvnn`` (False = add pairwise
        interaction features). Previously this required argument was never
        passed, which made every call fail.
    n_splits : int
        Number of folds when ``kfold`` is True.

    Returns
    -------
    (models, maes, mae_squareds, r2s, taus)
        The fitted models (list) and one NumPy array per metric, with one entry
        per fit in the order the fits were performed.
    """
    models, maes, mae_squareds, r2s, taus = [], [], [], [], []

    for X_train, Y_train, X_val, Y_val in _iter_train_val_splits(Xs, Ys, X_vals, Y_vals, kfold, n_splits):
        if model_type == 'clr':
            Y_val = Y_val.ravel()
            Y_train = Y_train.ravel()

        linear_model, mae, mae_squared, r2, tau = poly_regression_mvnn(
            X_train, Y_train, X_val, Y_val,
            linear_projection=linear_projection,
            seed=42,
            model_type=model_type,
            alpha=alpha,
            ridge=ridge,
            fit_intercept=fit_intercept,
        )

        models.append(linear_model)
        maes.append(mae)
        mae_squareds.append(mae_squared)
        r2s.append(r2)
        taus.append(tau)

    return models, np.array(maes), np.array(mae_squareds), np.array(r2s), np.array(taus)