"""
Credits: https://github.com/alansun17904/circuit-alignment
"""

from numpy.linalg import inv, svd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.linear_model import Ridge, RidgeCV
import time
from scipy.stats import zscore


def corr(X, Y):
    """Column-wise Pearson correlation between ``X`` and ``Y``.

    Both inputs are z-scored along axis 0 with ``scipy.stats.zscore`` (default
    settings). They are then multiplied element-wise and averaged over axis 0,
    giving one correlation value per column.
    """
    x_standardized = zscore(X)
    y_standardized = zscore(Y)
    return np.mean(x_standardized * y_standardized, axis=0)


def R2(Pred, Real):
    """Column-wise coefficient of determination of ``Pred`` against ``Real``.

    Computed along axis 0 as ``1 - mean((Real - Pred)**2) / var(Real)``.
    ``np.nan_to_num`` is applied to the result. A zero-variance column
    therefore yields 0 (for 0/0) or a very large negative finite number
    (for x/0) rather than nan/-inf.
    """
    mean_sq_residual = np.mean((Real - Pred) ** 2, axis=0)
    real_variance = np.var(Real, axis=0)
    unexplained = mean_sq_residual / real_variance
    return np.nan_to_num(1 - unexplained)


def R2r(Pred, Real):
    """Signed square root of ``R2(Pred, Real)``.

    Returns ``sign(R2) * sqrt(|R2|)`` per column, so a negative R2 stays
    negative after the square root.

    Args:
        Pred: predictions, same shape as ``Real``.
        Real: observed values.

    Returns:
        An array with one value per column, or a NumPy scalar when ``R2``
        returns a scalar (1-D inputs).
    """
    r2 = R2(Pred, Real)
    # np.sign / np.sqrt work on NumPy scalars as well as arrays. An in-place
    # boolean-mask assignment does not work on the scalar that R2 produces
    # for 1-D inputs.
    return np.sign(r2) * np.sqrt(np.abs(r2))


def ridge(X, Y, lmbda):
    """Closed-form ridge regression without an intercept.

    Solves ``(X^T X + lmbda * I) W = X^T Y`` for ``W``.

    Args:
        X: design matrix; its second axis indexes features.
        Y: targets with ``X.shape[0]`` rows (1-D or 2-D).
        lmbda: L2 penalty.

    Returns:
        Weights with ``X.shape[1]`` rows (and one column per target if Y is 2-D).

    Raises:
        numpy.linalg.LinAlgError: if the regularized Gram matrix is singular.
    """
    gram = X.T.dot(X) + lmbda * np.eye(X.shape[1])
    # Solving the linear system is cheaper and numerically more stable than
    # forming the explicit inverse and multiplying.
    return np.linalg.solve(gram, X.T.dot(Y))


def ridge_by_lambda(X, Y, Xval, Yval, lambdas=np.array([0.1, 1, 10, 100, 1000])):
    """Validation error (``1 - R2``) of closed-form ridge for each penalty.

    Args:
        X, Y: training inputs and 2-D targets.
        Xval, Yval: validation inputs and targets.
        lambdas: sequence of L2 penalties to evaluate.

    Returns:
        Array of shape ``(len(lambdas), Y.shape[1])``. Row ``i`` holds the
        per-column ``1 - R2`` on the validation set for ``lambdas[i]``.
    """
    error = np.zeros((len(lambdas), Y.shape[1]))
    # X^T X and X^T Y do not depend on the penalty: compute them only once.
    gram = X.T.dot(X)
    xty = X.T.dot(Y)
    identity = np.eye(X.shape[1])
    for idx, lmbda in enumerate(lambdas):
        # Solve the regularized system instead of inverting it.
        weights = np.linalg.solve(gram + lmbda * identity, xty)
        error[idx] = 1 - R2(np.dot(Xval, weights), Yval)
    return error


def ridge_sk(X, Y, lmbda):
    """Ridge weights fitted with scikit-learn's ``Ridge``, without an intercept.

    Only ``coef_`` is returned, and predictions in this module are formed as
    ``Xval @ weights``. Fitting an intercept would therefore silently discard
    it and leave coefficients estimated on centred data. ``fit_intercept=False``
    makes the fit solve the same no-intercept problem as ``ridge``.

    Args:
        X: design matrix.
        Y: targets.
        lmbda: L2 penalty (scikit-learn's ``alpha``).

    Returns:
        ``coef_.T`` — one row per feature (and one column per target if Y is 2-D).
    """
    rd = Ridge(alpha=lmbda, fit_intercept=False)
    rd.fit(X, Y)
    return rd.coef_.T


def ridgeCV_sk(X, Y, lmbdas):
    """Ridge weights with the penalty selected by scikit-learn's ``RidgeCV``.

    Only ``coef_`` is returned, so an intercept fitted by scikit-learn would be
    discarded and the coefficients would have been estimated on centred data.
    ``fit_intercept=False`` keeps the fit consistent with the no-intercept
    ``X @ weights`` predictions used in this module.

    NOTE(review): with its defaults ``RidgeCV`` picks one alpha shared by all
    targets — confirm that is intended (``cross_val_ridge`` picks per target).

    Args:
        X: design matrix.
        Y: targets.
        lmbdas: candidate L2 penalties (scikit-learn's ``alphas``).

    Returns:
        ``coef_.T`` — one row per feature (and one column per target if Y is 2-D).
    """
    rd = RidgeCV(alphas=lmbdas, fit_intercept=False)
    rd.fit(X, Y)
    return rd.coef_.T


def ridge_by_lambda_sk(X, Y, Xval, Yval, lambdas=np.array([0.1, 1, 10, 100, 1000])):
    """Validation error (``1 - R2``) of ``ridge_sk`` fits for each penalty.

    Returns an array of shape ``(len(lambdas), Y.shape[1])``. Row ``i`` is the
    per-column ``1 - R2`` on ``(Xval, Yval)`` for the fit with ``lambdas[i]``.
    """
    scores = np.zeros((len(lambdas), Y.shape[1]))
    for row, penalty in enumerate(lambdas):
        predicted = np.dot(Xval, ridge_sk(X, Y, penalty))
        scores[row] = 1 - R2(predicted, Yval)
    return scores


def ridge_svd(X, Y, lmbda):
    """Ridge regression (no intercept) solved through the thin SVD of ``X``.

    With ``X = U diag(s) Vt`` the ridge solution is
    ``W = Vt.T diag(s / (s**2 + lmbda)) U.T Y``.

    Args:
        X: design matrix; its second axis indexes features.
        Y: targets with ``X.shape[0]`` rows (1-D or 2-D).
        lmbda: L2 penalty.

    Returns:
        Weights with ``X.shape[1]`` rows.
    """
    U, s, Vt = svd(X, full_matrices=False)
    d = s / (s**2 + lmbda)
    # The weights live in feature space, so the left factor is V = Vt.T
    # (shape X.shape[1] x k), not Vt. Scaling V's columns by d is the same as
    # V @ diag(d) without building the diagonal matrix.
    return (Vt.T * d).dot(U.T.dot(Y))


def ridge_by_lambda_svd(X, Y, Xval, Yval, lambdas=np.array([0.1, 1, 10, 100, 1000])):
    """Validation error (``1 - R2``) of SVD-based ridge for each penalty.

    The thin SVD ``X = U diag(s) Vt`` is computed once. For each penalty the
    weights are ``Vt.T diag(s / (s**2 + lmbda)) U.T Y``.

    Args:
        X, Y: training inputs and 2-D targets.
        Xval, Yval: validation inputs and targets.
        lambdas: sequence of L2 penalties to evaluate.

    Returns:
        Array of shape ``(len(lambdas), Y.shape[1])``. Row ``i`` holds the
        per-column ``1 - R2`` on the validation set for ``lambdas[i]``.
    """
    error = np.zeros((len(lambdas), Y.shape[1]))
    U, s, Vt = svd(X, full_matrices=False)
    # V (= Vt.T) maps back to feature space. U^T Y does not depend on the
    # penalty, so it is projected once.
    V = Vt.T
    uty = U.T.dot(Y)
    for idx, lmbda in enumerate(lambdas):
        d = s / (s**2 + lmbda)
        weights = (V * d).dot(uty)
        error[idx] = 1 - R2(np.dot(Xval, weights), Yval)
    return error


def kernel_ridge(X, Y, lmbda):
    """Ridge regression (no intercept) in its dual / kernel form.

    Computes ``W = X^T (X X^T + lmbda * I)^{-1} Y``. This equals the primal
    ridge solution but solves a system sized by the number of rows of ``X``
    rather than its number of columns.

    Args:
        X: design matrix.
        Y: targets with ``X.shape[0]`` rows.
        lmbda: L2 penalty.

    Returns:
        Weights with ``X.shape[1]`` rows.

    Raises:
        numpy.linalg.LinAlgError: if the regularized kernel matrix is singular.
    """
    kernel = X.dot(X.T) + lmbda * np.eye(X.shape[0])
    # Solve for the dual coefficients rather than forming the inverse.
    dual_coef = np.linalg.solve(kernel, Y)
    return X.T.dot(dual_coef)


def kernel_ridge_by_lambda(X, Y, Xval, Yval, lambdas=np.array([0.1, 1, 10, 100, 1000])):
    """Validation error (``1 - R2``) of kernel-form ridge for each penalty.

    Args:
        X, Y: training inputs and 2-D targets.
        Xval, Yval: validation inputs and targets.
        lambdas: sequence of L2 penalties (array or list).

    Returns:
        Array of shape ``(len(lambdas), Y.shape[1])``. Row ``i`` holds the
        per-column ``1 - R2`` on the validation set for ``lambdas[i]``.
    """
    # len() instead of .shape[0] so a plain list of penalties works, as in
    # the other *_by_lambda helpers.
    error = np.zeros((len(lambdas), Y.shape[1]))
    # The kernel matrix X X^T does not depend on the penalty.
    kernel = X.dot(X.T)
    identity = np.eye(X.shape[0])
    for idx, lmbda in enumerate(lambdas):
        dual_coef = np.linalg.solve(kernel + lmbda * identity, Y)
        weights = X.T.dot(dual_coef)
        error[idx] = 1 - R2(np.dot(Xval, weights), Yval)
    return error


def kernel_ridge_svd(X, Y, lmbda):
    """Kernel-form ridge (no intercept) computed from the thin SVD of ``X.T``.

    With ``X.T = U diag(s) Vt`` the weights are
    ``U diag(s / (s**2 + lmbda)) Vt Y``. They have ``X.shape[1]`` rows.
    """
    left, sing_vals, right_t = svd(X.T, full_matrices=False)
    shrinkage = sing_vals / (sing_vals**2 + lmbda)
    # Scaling the rows of right_t is the same as diag(shrinkage) @ right_t.
    projector = left.dot(shrinkage[:, None] * right_t)
    return projector.dot(Y)


def kernel_ridge_by_lambda_svd(
    X, Y, Xval, Yval, lambdas=np.array([0.1, 1, 10, 100, 1000])
):
    """Validation error (``1 - R2``) of SVD kernel ridge for each penalty.

    The thin SVD ``X.T = U diag(s) Vt`` is computed once. For each penalty the
    weights are ``U diag(s / (s**2 + lmbda)) Vt Y``.

    Args:
        X, Y: training inputs and 2-D targets.
        Xval, Yval: validation inputs and targets.
        lambdas: sequence of L2 penalties to evaluate.

    Returns:
        Array of shape ``(len(lambdas), Y.shape[1])``. Row ``i`` holds the
        per-column ``1 - R2`` on the validation set for ``lambdas[i]``.
    """
    error = np.zeros((len(lambdas), Y.shape[1]))
    U, s, Vt = svd(X.T, full_matrices=False)
    # Vt @ Y does not depend on the penalty. Projecting once avoids building
    # an (X.shape[1] x X.shape[0]) matrix for every lambda.
    vty = Vt.dot(Y)
    for idx, lmbda in enumerate(lambdas):
        d = s / (s**2 + lmbda)
        # U * d scales U's columns, i.e. U @ diag(d).
        weights = (U * d).dot(vty)
        error[idx] = 1 - R2(np.dot(Xval, weights), Yval)
    return error


def cross_val_ridge(
    train_features,
    train_data,
    n_splits=10,
    lambdas=np.array([10**i for i in range(-6, 10)]),
    method="plain",
    do_plot=False,
):
    """Fit ridge weights with a per-target penalty chosen by K-fold CV.

    For every fold of ``KFold(n_splits=n_splits)`` (no shuffling), the
    validation error ``1 - R2`` of each penalty in ``lambdas`` is computed for
    every column of ``train_data``, and the errors are summed over folds. Each
    column then gets the penalty with the smallest summed error and is refit
    on all rows.

    Args:
        train_features: inputs; rows are samples, columns are features.
        train_data: 2-D targets with the same number of rows.
        n_splits: number of CV folds.
        lambdas: candidate L2 penalties.
        method: solver — one of "plain", "svd", "kernel_ridge",
            "kernel_ridge_svd", "ridge_sk". Unknown names raise KeyError.
        do_plot: if True, draw the per-fold cost, the summed cost and the final
            weights with matplotlib (figures are created, not shown).

    Returns:
        Weights of shape ``(train_features.shape[1], train_data.shape[1])``.
    """
    ridge_1 = dict(
        plain=ridge_by_lambda,
        svd=ridge_by_lambda_svd,
        kernel_ridge=kernel_ridge_by_lambda,
        kernel_ridge_svd=kernel_ridge_by_lambda_svd,
        ridge_sk=ridge_by_lambda_sk,
    )[method]
    ridge_2 = dict(
        plain=ridge,
        svd=ridge_svd,
        kernel_ridge=kernel_ridge,
        kernel_ridge_svd=kernel_ridge_svd,
        ridge_sk=ridge_sk,
    )[method]

    if do_plot:
        # Optional dependency: only imported when plotting is requested.
        import matplotlib.pyplot as plt

    n_voxels = train_data.shape[1]
    nL = len(lambdas)
    # Summed validation error (1 - R2), one row per penalty, one column per target.
    r_cv = np.zeros((nL, n_voxels))

    kf = KFold(n_splits=n_splits)
    for trn, val in kf.split(train_data):
        cost = ridge_1(
            train_features[trn],
            train_data[trn],
            train_features[val],
            train_data[val],
            lambdas=lambdas,
        )
        if do_plot:
            plt.figure()
            plt.imshow(cost, aspect="auto")
        r_cv += cost
    if do_plot:
        plt.figure()
        plt.imshow(r_cv, aspect="auto", cmap="RdBu_r")

    argmin_lambda = np.argmin(r_cv, axis=0)
    weights = np.zeros((train_features.shape[1], n_voxels))
    # Refit once per penalty for all targets that chose it; much faster than
    # refitting each target separately.
    for idx_lambda in range(nL):
        idx_vox = argmin_lambda == idx_lambda
        if not idx_vox.any():
            # No target selected this penalty. Skip the refit: it saves a
            # solve and avoids handing the solver a zero-column target array,
            # which scikit-learn's input validation rejects.
            continue
        weights[:, idx_vox] = ridge_2(
            train_features, train_data[:, idx_vox], lambdas[idx_lambda]
        )
    if do_plot:
        plt.figure()
        plt.imshow(weights, aspect="auto", cmap="RdBu_r", vmin=-0.5, vmax=0.5)

    return weights
