from scipy.spatial.distance import pdist, squareform
from scipy.stats import zscore
import numpy as np

def corr(X,Y,axis=0):
    """Correlate matching slices of X and Y along ``axis``.

    x1 is correlated with y1, x2 with y2, and so on: both inputs are
    z-scored along ``axis`` and the mean of the element-wise product of the
    standardized values is taken along that same axis.

    Args:
        X: array-like of observations.
        Y: array-like that multiplies element-wise with X after z-scoring.
        axis: axis along which to standardize and average (default 0).

    Returns:
        The mean product of z-scores, with ``axis`` reduced away.
    """
    standardized_x = zscore(X, axis=axis)
    standardized_y = zscore(Y, axis=axis)
    return np.mean(standardized_x * standardized_y, axis=axis)

def crosscorr(X,Y,axis=0):
    """Compute the pair-wise correlation of all variables in X with all variables in Y.

    Args:
        X: 2-D array of shape (num_samples, nvars_x).
        Y: 2-D array of shape (num_samples, nvars_y), with the same number of
            samples as X.
        axis: kept for signature compatibility only. It is not used: samples
            are always taken along axis 0 and variables along the last axis.

    Returns:
        float32 array of shape (nvars_x, nvars_y) whose entry [i, j] is the
        correlation of X[:, i] with Y[:, j].
    """
    # Work in float32 so the result dtype matches what callers already receive.
    z_x = zscore(np.asarray(X, dtype=np.float32), axis=0)
    z_y = zscore(np.asarray(Y, dtype=np.float32), axis=0)
    num_samples = z_x.shape[0]

    # The mean over samples of z_x[:, i] * z_y[:, j] for every (i, j) pair is a
    # single matrix product; this avoids materializing two
    # (num_samples, nvars_x, nvars_y) temporaries.
    return np.dot(z_x.T, z_y) / num_samples
    
def get_subj_predictions(subj_id, nlp_feature_name):
    """Return encoding-model predictions for one subject (data-specific, to be filled in).

    Args:
        subj_id: ID of the subject whose predictions are requested.
        nlp_feature_name: name of the input features the encoding model was
            trained on.

    Returns:
        Once implemented: a matrix of size num_samples by num_brain_sources
        (voxels, sensors, electrodes, etc.).

    Raises:
        NotImplementedError: always, until the data-specific loading code
            replaces this placeholder.
    """
    # Fail loudly instead of returning an unbound (or stale module-level) name.
    raise NotImplementedError(
        "get_subj_predictions is data-specific and must be filled in"
    )

def get_subj_test_data(subj_id):
    """Return the test data for one subject (data-specific, to be filled in).

    Args:
        subj_id: ID of the subject whose test data is requested.

    Returns:
        Once implemented: a matrix of size num_samples by num_brain_sources
        (voxels, sensors, electrodes, etc.).

    Raises:
        NotImplementedError: always, until the data-specific loading code
            replaces this placeholder.
    """
    # Fail loudly instead of returning an unbound (or stale module-level) name.
    raise NotImplementedError(
        "get_subj_test_data is data-specific and must be filled in"
    )

def compute_isc(subj_id1, subj_id2):
    """Inter-subject correlation (ISC) between the test data of two subjects.

    Loads each subject's test data and correlates them with ``corr`` using its
    default axis, so matching columns of the two matrices are correlated.

    Returns:
        isc: a vector of length num_brain_sources.
    """
    first_subj_data, second_subj_data = (
        get_subj_test_data(sid) for sid in (subj_id1, subj_id2)
    )
    return corr(first_subj_data, second_subj_data)

def compute_per_subj_isc(subj_id1, all_subj_ids):
    """Compute the ISC of one subject w.r.t. every other subject, averaged.

    Args:
        subj_id1: ID of the reference subject.
        all_subj_ids: iterable of subject IDs; entries equal to ``subj_id1``
            are skipped.

    Returns:
        isc: a vector of length num_brain_sources, the mean of the pairwise
        ISCs between ``subj_id1`` and each other subject.

    Raises:
        ValueError: if ``all_subj_ids`` contains no subject other than
            ``subj_id1``.
    """
    pairwise_iscs = [
        compute_isc(subj_id1, subj_id2)
        for subj_id2 in all_subj_ids
        if subj_id2 != subj_id1
    ]
    if not pairwise_iscs:
        raise ValueError(
            "need at least one subject other than subj_id1 to compute ISC"
        )
    # One row per other subject; average across subjects.
    return np.mean(np.vstack(pairwise_iscs), 0)

def compute_residuals(subj_id1):
    """Compute the pairwise residuals of predicting each brain source from every other.

    Entry [i, j, :] is ``test_data[:, j] - r_ij * test_data[:, i]``, where
    ``r_ij`` is the correlation between sources i and j.
    NOTE(review): using the correlation as the prediction weight presumably
    assumes each source is z-scored over samples - confirm against the data.

    Args:
        subj_id1: ID of the subject whose test data is used.

    Returns:
        float64 tensor of dimensions num brain sources x num brain sources x
        num samples.
    """
    # Rows are brain sources, columns are samples.
    sources = np.asarray(get_subj_test_data(subj_id1), dtype=np.float64).T
    # pdist's 'correlation' metric is 1 - r, so this recovers r for each pair.
    brain_corrs = 1-squareform(pdist(sources, 'correlation'))

    # Broadcast over (i, j, sample) instead of looping over every source pair.
    predicted = brain_corrs[:, :, np.newaxis] * sources[:, np.newaxis, :]
    return sources[np.newaxis, :, :] - predicted

def compute_source_residuals(subj_id1, all_subj_ids):
    """Correlate one subject's pairwise source residuals with every other subject's.

    For each other subject, the residual tensors from ``compute_residuals``
    (num brain sources x num brain sources x num samples) are correlated over
    the sample axis, giving one value per source pair. These are then averaged
    across subjects.

    Args:
        subj_id1: ID of the reference subject.
        all_subj_ids: iterable of subject IDs; entries equal to ``subj_id1``
            are skipped.

    Returns:
        Matrix of size num brain sources x num brain sources.

    Raises:
        ValueError: if ``all_subj_ids`` contains no subject other than
            ``subj_id1``.
    """
    residuals_subj1 = compute_residuals(subj_id1)
    per_subject_corrs = []
    for subj_id2 in all_subj_ids:
        if subj_id2 == subj_id1:
            continue

        residuals_subj2 = compute_residuals(subj_id2)
        # Samples live on the last axis of the residual tensor.
        per_subject_corrs.append(corr(residuals_subj1, residuals_subj2, axis=2))

    if not per_subject_corrs:
        raise ValueError(
            "need at least one subject other than subj_id1 to compute source residuals"
        )
    # Stack on a new leading subject axis so the mean keeps the source-pair shape.
    return np.mean(np.stack(per_subject_corrs), 0)

    
subj_id = 1
all_subj_ids = [1,2,3,4,5,6]
# data-specific, to be filled in: name of the NLP features used to train the encoding model
nlp_feature_name = 'nlp_features'
predictions = get_subj_predictions(subj_id, nlp_feature_name)
test = get_subj_test_data(subj_id)

# encoding model performance for subject
encoding_performance = corr(predictions, test)

# ISC of this subject w.r.t. all other subjects; computed once and reused by
# every normalization below (np.sqrt yields NaN wherever the ISC is negative)
isc = compute_per_subj_isc(subj_id, all_subj_ids)

# normalized encoding performance by ISC
norm_encoding_performance = encoding_performance/np.sqrt(isc)

# functional connectivity
func_connectivity = 1-squareform(pdist(test.T, metric='correlation'))

# source generalization
# NOTE(review): correlating test with itself repeats the functional
# connectivity above; crosscorr(predictions, test) may have been intended - confirm
source_generalization = crosscorr(test, test)

# normalized source generalization by ISC
normalized_generalization = source_generalization/np.sqrt(isc)

# source residuals
source_residuals = np.sqrt(compute_source_residuals(subj_id, all_subj_ids))

# normalized source residuals by ISC
normalized_source_residuals = source_residuals/np.sqrt(isc)