import numpy as np
from scipy import signal
import pdb
from copy import deepcopy

from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.preprocessing import StandardScaler

#from pyuoi.linear_model.var import VAR
from FCCA.cov_util import form_lag_matrix

def decimate_(X, q):
    """Downsample every column of a 2D array by the factor ``q``.

    Each column ``X[:, i]`` is passed independently through
    ``scipy.signal.decimate`` (which low-pass filters before downsampling),
    and the results are reassembled with columns in their original order.

    Parameters
    ----------
    X : ndarray
        Array indexed as ``X[:, i]``; axis 0 is the axis that gets decimated.
    q : int
        Downsampling factor forwarded to ``scipy.signal.decimate``.

    Returns
    -------
    ndarray
        Decimated data, one column per column of ``X``.
    """
    n_columns = X.shape[1]
    # Build a (n_columns, n_decimated_samples) stack, one row per input column
    per_column = [signal.decimate(X[:, col], q) for col in range(n_columns)]

    # Transpose so samples run along axis 0 again
    return np.array(per_column).T

# If X has trial structure, each trial must be normalized separately
def standardize(X):
    """Standardize features with scikit-learn's ``StandardScaler``.

    When ``X`` carries trial structure (a list of 2D arrays, or a 3D array
    whose first axis indexes trials), every trial is fit and transformed on
    its own, so each trial is normalized with its own statistics.

    Parameters
    ----------
    X : list of ndarray, or ndarray
        Either a list of 2D arrays (one per trial), a 3D array indexed as
        ``X[trial, ...]``, or a single 2D array.

    Returns
    -------
    list of ndarray, or ndarray
        A list when ``X`` is a list; otherwise an ndarray.
    """
    # fit_transform re-fits the scaler on every call, so one instance can be
    # reused across trials without leaking statistics between them.
    scaler = StandardScaler()

    # isinstance (instead of an exact type comparison) also accepts list
    # subclasses as a collection of trials.
    if isinstance(X, list):
        return [scaler.fit_transform(trial) for trial in X]

    if np.ndim(X) == 3:
        # Iterating over the leading axis yields one 2D trial at a time
        return np.array([scaler.fit_transform(trial) for trial in X])

    return scaler.fit_transform(X)

def _apply_lag(X, Z, lag):
    """Offset each (x, z) trial pair so that row t of x lines up with row t + lag of z."""
    if lag > 0:
        # Drop the last `lag` rows of x and the first `lag` rows of z
        X = [x[:-lag, :] for x in X]
        Z = [z[lag:, :] for z in Z]
    elif lag < 0:
        # Negative lag: drop the first |lag| rows of x and the last |lag| of z
        X = [x[-lag:, :] for x in X]
        Z = [z[:lag, :] for z in Z]
    # lag == 0: inputs are returned untouched
    return X, Z


def lr_preprocess(Xtest, Xtrain, Ztest, Ztrain, trainlag, testlag, decoding_window):
    """Standardize, lag, window and flatten data ahead of linear decoding.

    Steps, in order:
      1. Wrap un-trialed (2D) inputs in single-element lists.
      2. Standardize each of the four data sets independently (per trial).
      3. Shift X against Z by ``trainlag`` (train) and ``testlag`` (test).
      4. Expand X with ``form_lag_matrix(x, decoding_window)`` and trim Z to
         match: the first ``decoding_window // 2`` rows of Z are dropped and
         Z is then truncated to the number of rows of the windowed X.
      5. Concatenate all trials along axis 0.

    Parameters
    ----------
    Xtest, Xtrain : ndarray or list of ndarray
        Predictor data. Each trial is sliced as ``x[rows, :]``.
    Ztest, Ztrain : ndarray or list of ndarray
        Target data with the same trial structure as the predictors.
    trainlag : int
        Lag (in samples) applied to the training pair. Positive values pair
        ``x[t]`` with ``z[t + trainlag]``; negative values do the reverse.
    testlag : int
        Lag (in samples) applied to the test pair, same convention.
    decoding_window : int
        Window length forwarded to ``form_lag_matrix``.
        NOTE(review): presumably the number of consecutive samples stacked
        per output row -- confirm against FCCA.cov_util.form_lag_matrix.

    Returns
    -------
    Xtest, Xtrain, Ztest, Ztrain : ndarray
        Processed arrays with trials concatenated along axis 0.
    """
    # Only Xtrain is inspected; the other inputs are assumed to share its
    # trial structure. Un-trialed data becomes a one-trial list.
    if np.ndim(Xtrain) == 2:
        Xtrain = [Xtrain]
        Xtest = [Xtest]

        Ztrain = [Ztrain]
        Ztest = [Ztest]

    Ztrain = standardize(Ztrain)
    Xtrain = standardize(Xtrain)

    Ztest = standardize(Ztest)
    Xtest = standardize(Xtest)

    # Apply the lags. The test split must use `testlag`: slicing it with
    # `trainlag` shifts it by the wrong amount, and empties the test arrays
    # whenever trainlag == 0 (x[:-0] / z[:0]) while testlag != 0.
    Xtrain, Ztrain = _apply_lag(Xtrain, Ztrain, trainlag)
    Xtest, Ztest = _apply_lag(Xtest, Ztest, testlag)

    # Apply decoding window
    Xtrain = [form_lag_matrix(x, decoding_window) for x in Xtrain]
    Xtest = [form_lag_matrix(x, decoding_window) for x in Xtest]

    # Align targets with the windowed predictors: skip half a window at the
    # start, then truncate to the number of rows the lag matrix produced.
    half_window = decoding_window // 2

    Ztrain = [z[half_window:, :] for z in Ztrain]
    Ztrain = [z[:x.shape[0], :] for z, x in zip(Ztrain, Xtrain)]

    Ztest = [z[half_window:, :] for z in Ztest]
    Ztest = [z[:x.shape[0], :] for z, x in zip(Ztest, Xtest)]

    # Flatten trial structure as regression will not care about it
    Xtrain = np.concatenate(Xtrain)
    Xtest = np.concatenate(Xtest)
    Ztrain = np.concatenate(Ztrain)
    Ztest = np.concatenate(Ztest)

    return Xtest, Xtrain, Ztest, Ztrain


def lr_decoder(Xtest, Xtrain, Ztest, Ztrain, trainlag, testlag, decoding_window=1):
    """Fit a linear decoder on the training data and score it on the test data.

    The inputs are first passed through ``lr_preprocess``; an ordinary
    least-squares ``LinearRegression`` (with intercept) is then fit on the
    processed training pair and evaluated on the processed test pair.

    Parameters
    ----------
    Xtest, Xtrain : ndarray or list of ndarray
        Predictor data for the test and training splits.
    Ztest, Ztrain : ndarray or list of ndarray
        Target data for the test and training splits.
    trainlag, testlag : int
        Lags forwarded to ``lr_preprocess``.
    decoding_window : int, optional
        Window length forwarded to ``lr_preprocess`` (default 1).

    Returns
    -------
    lr_r2_pos : float
        ``r2_score`` of the predictions against the processed test targets,
        computed over all target columns.
    decodingregressor : LinearRegression
        The fitted regressor.
    """
    Xtest, Xtrain, Ztest, Ztrain = lr_preprocess(
        Xtest, Xtrain, Ztest, Ztrain, trainlag, testlag, decoding_window
    )

    decodingregressor = LinearRegression(fit_intercept=True)
    decodingregressor.fit(Xtrain, Ztrain)
    Zpred = decodingregressor.predict(Xtest)

    # Score over every target column. An earlier variant first scored the
    # slice [..., 0:Z_dim], which spans the whole array and was immediately
    # overwritten, so only this single computation is kept.
    lr_r2_pos = r2_score(Ztest, Zpred)
    return lr_r2_pos, decodingregressor
