import numpy as np 
import scipy as sp
import argparse
import pandas as pd
from sklearn.preprocessing import StandardScaler
import h5py
import os
from matplotlib import pyplot as plt
from scipy.stats import zscore
import pickle as pkl                                         

from numpy.linalg import inv, svd
from sklearn.model_selection import KFold
from sklearn.linear_model import Ridge, RidgeCV
import time 

def get_CV_ind_specificSplits(vectorSplits, n_folds):
    """Return a per-row fold label array built from cumulative split points.

    ``vectorSplits[k]`` is the exclusive end row of fold ``k``; fold ``k``
    starts where fold ``k - 1`` ended (fold 0 starts at row 0). The result is
    a float array of length ``vectorSplits[n_folds - 1]`` whose entry ``r``
    is the fold number that row ``r`` belongs to.
    """
    total_rows = vectorSplits[n_folds - 1]
    labels = np.zeros(total_rows)
    lower = 0
    for fold in range(n_folds):
        upper = total_rows if fold == n_folds - 1 else vectorSplits[fold]
        labels[lower:upper] = fold
        # The next fold begins where this fold's split point lies.
        lower = vectorSplits[fold]
    return labels


def ridge(X, Y, lmbda):
    """Closed-form ridge regression weights.

    Solves ``(X^T X + lmbda * I) W = X^T Y`` for ``W``.

    Args:
        X: design matrix, samples x features.
        Y: targets, samples (1-D) or samples x targets (2-D).
        lmbda: non-negative regularization strength.

    Returns:
        Weights with shape ``(n_features,)`` or ``(n_features, n_targets)``,
        matching the dimensionality of ``Y``.
    """
    regularized_gram = X.T.dot(X) + lmbda * np.eye(X.shape[1])
    # Solving the linear system is cheaper and numerically more stable than
    # forming the explicit inverse and multiplying by it.
    return np.linalg.solve(regularized_gram, X.T.dot(Y))

def ridge_by_lambda(X, Y, Xval, Yval, lambdas=np.array([0.1,1,10,100,1000])):
    """Validation error ``1 - R2`` of ridge fits for each candidate lambda.

    Args:
        X: training design matrix, samples x features.
        Y: training targets, samples x targets (2-D; ``Y.shape[1]`` is used).
        Xval: validation design matrix with the same feature columns as ``X``.
        Yval: validation targets with the same target columns as ``Y``.
        lambdas: 1-D array of regularization strengths to try.

    Returns:
        Array of shape ``(len(lambdas), n_targets)``; row ``k`` holds the
        per-target validation error ``1 - R2`` for ``lambdas[k]``.
    """
    # X^T X and X^T Y do not depend on lambda: compute them once rather than
    # once per candidate (each is an O(n * p^2) / O(n * p * t) product).
    gram = X.T.dot(X)
    cross = X.T.dot(Y)
    identity = np.eye(X.shape[1])
    error = np.zeros((lambdas.shape[0], Y.shape[1]))
    for idx, lmbda in enumerate(lambdas):
        weights = np.linalg.solve(gram + lmbda * identity, cross)
        error[idx] = 1 - R2(np.dot(Xval, weights), Yval)
    return error

def R2(Pred, Real):
    """Per-column coefficient of determination along axis 0.

    ``R2 = 1 - mean((Real - Pred)**2) / var(Real)`` (population variance).

    Columns of ``Real`` with zero variance have an undefined R2; they are
    reported as 0 instead of producing division-by-zero warnings and a
    -1.8e308 sentinel (which previously overflowed when summed across CV
    folds). Any remaining NaN (e.g. from NaN inputs) is also mapped to 0.

    Args:
        Pred: predictions, same shape as ``Real``.
        Real: observed values, samples (x targets).

    Returns:
        Array of R2 values, one per column (a scalar for 1-D inputs).
    """
    SSres = np.mean((Real - Pred) ** 2, 0)
    SStot = np.var(Real, 0)
    with np.errstate(divide='ignore', invalid='ignore'):
        r2 = 1 - SSres / SStot
    r2 = np.where(SStot == 0, 0.0, r2)
    return np.nan_to_num(r2)

def cross_val_ridge(train_features, train_data, n_splits=10,
                    lambdas=np.array([10**i for i in range(-6,10)]),
                    method='plain',
                    do_plot=False):
    """Fit ridge weights per target, choosing each target's lambda by K-fold CV.

    For each fold of ``KFold(n_splits=n_splits)`` (contiguous, unshuffled),
    the validation error ``1 - R2`` is computed for every candidate lambda and
    every target column and summed over folds. Each target gets the lambda
    with the lowest summed error; final weights are then refit on all of
    ``train_features`` / ``train_data`` with that lambda. Progress is printed
    to stdout.

    Args:
        train_features: 2-D array, samples x features.
        train_data: 2-D array, samples x targets.
        n_splits: number of internal CV folds.
        lambdas: 1-D array of candidate regularization strengths.
        method: solver name; only ``'plain'`` is available (any other value
            raises KeyError).
        do_plot: if True, draw matplotlib figures of each fold's error, the
            summed error, and the final weights.

    Returns:
        weights: features x targets array of ridge weights.
        best_lambdas: 1-D array with the lambda chosen for each target.
    """
    ridge_1 = dict(plain=ridge_by_lambda)[method]
    ridge_2 = dict(plain=ridge)[method]
    if do_plot:
        import matplotlib.pyplot as plt

    n_lambdas = lambdas.shape[0]
    cv_error = np.zeros((n_lambdas, train_data.shape[1]))

    kf = KFold(n_splits=n_splits)
    start_t = time.time()
    for icv, (trn, val) in enumerate(kf.split(train_data)):
        print('ntrain = {}'.format(len(trn)))
        cost = ridge_1(train_features[trn], train_data[trn],
                       train_features[val], train_data[val],
                       lambdas=lambdas)
        if do_plot:
            plt.figure()
            plt.imshow(cost, aspect='auto')
        cv_error += cost
        if icv % 3 == 0:
            print(icv)
        print('average iteration length {}'.format((time.time() - start_t) / (icv + 1)))
    if do_plot:
        plt.figure()
        plt.imshow(cv_error, aspect='auto', cmap='RdBu_r')

    argmin_lambda = np.argmin(cv_error, axis=0)
    weights = np.zeros((train_features.shape[1], train_data.shape[1]))
    # Refit once per lambda that some target actually selected, batched over
    # all targets sharing it (much faster than iterating over targets).
    # Lambdas chosen by no target are skipped rather than paying for a
    # features x features solve with zero right-hand sides.
    for idx_lambda in np.unique(argmin_lambda):
        idx_vox = argmin_lambda == idx_lambda
        weights[:, idx_vox] = ridge_2(train_features, train_data[:, idx_vox],
                                      lambdas[idx_lambda])
    if do_plot:
        plt.figure()
        plt.imshow(weights, aspect='auto', cmap='RdBu_r', vmin=-0.5, vmax=0.5)

    return weights, lambdas[argmin_lambda]


def corr(X,Y):
    """Column-wise Pearson correlation between ``X`` and ``Y``.

    Both inputs are z-scored along axis 0 (population std); the mean of the
    element-wise product of the standardized columns is the correlation.
    Constant columns standardize to NaN, so their correlation is NaN.
    """
    standardized_product = zscore(X) * zscore(Y)
    return standardized_product.mean(axis=0)

 def get_encoding_model(data, features, splits = [769, 769 + 795, 769 + 795 + 763,769 + 795 + 763 + 778]):
    """
   Given fMRI data, feature matrix, a list of the TR at the end of each of the folds 
   from cross validation this function creates an encoding model with 10 fold internal cross validation to pick 
   lambda. This is designed for 4 runs of data, as in HCP.
    Args:
        data: brain data, TRs x source (ROI, Voxel, etc)
        features: TRs x stimuli reprentation (ie. ELMo embedding)
        splits: List of the last TR in each cross validation fold. For example the default is 
        the number of the last TR from each of the 4 runs in the HCP data 
    Returns:
        corrs: encoding model predictive performance 
        preds_all: predictions TR x Source (ROI, Voxel, etc)
        weights_all: weights learned by the encoding model
        lambdas_all: lambas selected for ridge, based off the 10 fold internal cross val 
    """

    n,v = data.shape
    p = features.shape[1]
    corrs = np.zeros((len(splits),v))
    R2s = np.zeros((len(splits),v))
    ind = get_CV_ind_specificSplits(splits, len(splits))
    preds_all = np.zeros_like(data)
    weights_all = {}
    lambdas_all = np.zeros((len(splits), v))

    for i in range(len(splits)):
        # so zscore per fold (test), and out of fold (train) 
        train_data = np.nan_to_num(zscore(data[ind!=i]))
        train_features = np.nan_to_num(zscore(features[ind!=i]))
        test_data = np.nan_to_num(zscore(data[ind==i]))
        test_features = np.nan_to_num(zscore(features[ind==i]))
        weights, lambdas = cross_val_ridge(train_features, train_data)

        preds = np.dot(test_features,weights)
        preds_all[ind==i] = preds
        weights_all["fold_{fold_num}".format(fold_num = i)] = weights
        lambdas_all[i, :] = lambdas
    corrs = corr(preds_all, data)
  
    return corrs, preds_all, weights_all, lambdas_all


