import numpy as np
from regression.regression_utils import concat_time_series_inputs
from regression.ceilings import R2_ceiling, sp

def correlation_loss(input_time_series, target_time_series):
    '''
    Per-channel Pearson correlation between two time series.

    input_time_series: array of shape (batch, channels)

    target_time_series: array of shape (batch, channels)

    Returns a 1D array (channels,) holding the correlation of each channel,
    computed along axis 0. The value is the correlation itself (it is not
    negated), so larger means a better match. A channel whose standard
    deviation is zero is divided by zero and comes out as nan.
    '''
    def _zscore(series):
        # Centre and scale every channel using the population std (ddof=0).
        centered = series - np.mean(series, axis=0)[None, :]
        return centered / np.std(series, axis=0)[None, :]

    # The mean of the product of z-scores is the Pearson correlation.
    z_input = _zscore(input_time_series)
    z_target = _zscore(target_time_series)
    return np.mean(z_input * z_target, axis=0)

def r2_loss(input_time_series, target_time_series):
    '''
    Per-channel coefficient of determination (R^2) of a prediction.

    input_time_series: predicted values, shape (batch, channels)

    target_time_series: observed values, shape (batch, channels)

    Returns a 1D array (channels,) equal to 1 - MSE / Var(target), where both
    terms are averaged along axis 0. nan results (e.g. 0/0 for a constant
    target that is predicted exactly) are replaced by 0 and infinities by
    large finite numbers through np.nan_to_num.
    '''
    residuals = target_time_series - input_time_series
    mean_squared_residual = np.mean(np.square(residuals), axis=0)
    target_variance = np.var(target_time_series, axis=0)
    unexplained_fraction = mean_squared_residual / target_variance
    return np.nan_to_num(1 - unexplained_fraction)

def brain_score(correlation_ceiling, ceiling_cutoff = None):
    '''
    Build a loss function returning the mean ceiling-normalized correlation.

    correlation_ceiling: value(s) the per-channel correlations are divided by
    (presumably a per-channel array of shape (channels,) -- confirm with callers)

    ceiling_cutoff: None/float -> If not None, only channels whose ceiling is
    strictly greater than this value contribute to the mean.

    Returns a function _loss(input_time_series, target_time_series) that
    computes correlation_loss for the two series, divides it by the ceiling and
    returns the mean over the (optionally filtered) channels.
    '''
    def _loss(input_time_series, target_time_series):
        per_channel_corr = correlation_loss(input_time_series, target_time_series)
        normalized_scores = per_channel_corr / correlation_ceiling
        if ceiling_cutoff is None:
            return np.mean(normalized_scores)
        # Keep only the channels whose ceiling clears the cutoff.
        reliable = correlation_ceiling > ceiling_cutoff
        return np.mean(normalized_scores[reliable])
    return _loss

def brain_score_1(trial_repeats, sp_cutoff = 0.01, ceiling_cutoff = None):
    '''
    Build a loss function scoring predictions against trial-averaged data.

    trial_repeats: repeated measurements; averaged over axis 0 to obtain the
    reference response (presumably (trials, timepoints, channels) -- confirm)

    sp_cutoff: forwarded unchanged to sp() together with trial_repeats

    ceiling_cutoff: None/float -> If not None, only channels whose nc value
    (see below) is strictly greater than this value contribute to the mean.

    Returns a function _loss(input_time_series) that returns the mean over
    channels of cov(prediction, mean response) / (std(prediction) * sqrt(s)),
    where s is the value returned by sp().
    '''
    signal_power = sp(trial_repeats, sp_cutoff)
    mean_response = np.mean(trial_repeats, axis=0)

    def _loss(input_time_series):
        # Population covariance (ddof=0) between prediction and mean response.
        centered_prediction = input_time_series - input_time_series.mean(axis=0)[None, :]
        centered_response = mean_response - mean_response.mean(axis=0)[None, :]
        covariance = np.mean(centered_prediction * centered_response, axis=0)

        prediction_std = np.std(input_time_series, axis=0)
        root_signal_power = np.sqrt(signal_power)
        cc_norm = covariance / prediction_std * 1 / root_signal_power

        if ceiling_cutoff is None:
            return np.mean(cc_norm)

        # NOTE(review): this mask is derived from the prediction's std rather
        # than from the measured response -- confirm that this is intended.
        nc = prediction_std / root_signal_power
        return np.mean(cc_norm[nc > ceiling_cutoff])
    return _loss
    
def concat_time_series_loss(input_time_series_list, target_time_series_list, loss_fn):
    '''
    Evaluate loss_fn on the concatenation of several time series.

    input_time_series_list: passed to concat_time_series_inputs

    target_time_series_list: passed to concat_time_series_inputs

    loss_fn: callable taking (input_time_series, target_time_series)

    Returns whatever loss_fn returns for the two concatenated series.
    '''
    # Inputs are concatenated first, then targets, then the loss is applied.
    concatenated = [
        concat_time_series_inputs(series_list)
        for series_list in (input_time_series_list, target_time_series_list)
    ]
    return loss_fn(*concatenated)

def spe_and_cc_norm(orig_data, data_pred, data_norm=True, max_flooring=None):
    '''
    Computes the signal power explained (SPE) and the cc_norm of a model given
    the observed and predicted values.
    Assumes normalization unless data_norm is set to False

    orig_data: 3D numpy array (trials, timepoints, voxels)

    data_pred: 2D numpy array (timepoints, voxels)

    data_norm: bool -> Set to False if not pre-normalized. When True the total
    power (TP) of every voxel is taken to be 1; when False it is estimated as
    the variance across time (ddof=1) averaged over trials.

    max_flooring: None/float (0-1) -> If not None, compute cc_norm in an
    alternate way, as the plain correlation divided by cc_max.
    NOTE(review): the value itself is currently not used. The step that floors
    cc_max by max_flooring is disabled, so any non-None value only switches on
    the alternate computation. Voxels whose cc_max is nan are set to 0 by
    np.nan_to_num and therefore produce inf/nan in cc_norm.

    Returns:
        (SPE, cc_norm, None) when max_flooring is None, otherwise
        (SPE, cc_norm, cc_max, corrs). Every array has shape (voxels,).

    Raises:
        ZeroDivisionError if orig_data contains a single trial, because the
        signal power estimate divides by (num_trials - 1).

    According to Schoppe: https://www.frontiersin.org/articles/10.3389/fncom.2016.00010/full
    '''
    num_trials = len(orig_data)
    # Trial-averaged response, shape (timepoints, voxels).
    y = np.mean(orig_data, axis=0)

    # Total power of the recorded signal, per voxel.
    if not data_norm:
        variance_across_time = np.var(orig_data, axis=1, ddof=1)
        TP = np.mean(variance_across_time, axis=0)
    else:
        TP = np.zeros(orig_data.shape[2]) + 1

    # Signal power estimate and the share of it explained by the prediction.
    var_y = np.var(y, axis=0, ddof=1)
    SP = (1 / (num_trials - 1)) * ((num_trials * var_y) - TP)
    SPE = (var_y - np.var(y - data_pred, axis=0, ddof=1)) / SP

    # Per-voxel sample covariance (ddof=1) between mean response and
    # prediction, computed for all voxels at once. Evaluated in float64,
    # which is what np.cov does for each voxel pair.
    y64 = np.asarray(y, dtype=np.float64)
    pred64 = np.asarray(data_pred, dtype=np.float64)
    num_timepoints = y64.shape[0]
    centered_products = (y64 - y64.mean(axis=0)) * (pred64 - pred64.mean(axis=0))
    covs = np.sum(centered_products, axis=0) / (num_timepoints - 1)

    if max_flooring is None:
        cc_norm = np.sqrt(1 / SP) * (covs / np.sqrt(np.var(data_pred, axis=0, ddof=1)))
        return SPE, cc_norm, None

    # Alternate route: plain Pearson correlation divided by its maximum
    # attainable value cc_max.
    cc_max = np.nan_to_num(1 / (np.sqrt(1 + ((1 / num_trials) * ((TP / SP) - 1)))))
    corrs = covs / np.std(y64, axis=0, ddof=1) / np.std(pred64, axis=0, ddof=1)
    # np.corrcoef clips rounding overshoot into [-1, 1]; do the same here.
    corrs = np.clip(corrs, -1, 1)
    cc_norm = corrs / cc_max
    return SPE, cc_norm, cc_max, corrs