import numpy as np

def R2_ceiling(trial_repeats, cutoff = 0.01):
    """Estimate the explainable (signal) power of repeated trials.

    Uses the signal-power estimator of Sahani & Linden (2003)::

        SP = (N * Var_t(mean_n r_n) - mean_n Var_t(r_n)) / (N - 1)

    where ``N`` is the number of repeats and variances are taken over time
    (axis 1) with ``ddof=0``. This is the same quantity computed by ``sp``.

    The estimate is returned as-is (it is not divided by the total power).
    NOTE(review): it reads as an R^2 ceiling only when each repeat is
    normalized to unit variance -- confirm that callers normalize.

    Parameters
    ----------
    trial_repeats : array_like
        Responses with repeats along axis 0 and time along axis 1, e.g.
        ``(repeats, time, channels)``. Any trailing axes are treated
        independently.
    cutoff : float
        Lower floor applied to the estimate. Noise-dominated data can
        produce negative estimates.

    Returns
    -------
    numpy.ndarray or numpy.floating
        Signal power per trailing-axis element, floored at ``cutoff``. The
        result is a scalar for ``(repeats, time)`` input.

    Raises
    ------
    ValueError
        If fewer than 2 repeats are given; signal and noise cannot be
        separated then.
    """
    trial_repeats = np.asarray(trial_repeats)
    n_repeats = trial_repeats.shape[0]
    if n_repeats < 2:
        raise ValueError(
            "R2_ceiling needs at least 2 repeats to separate signal from noise"
        )
    mean_signal = np.mean(trial_repeats, axis=0)
    # Average over repeats of each single repeat's power over time.
    total_power = np.mean(np.var(trial_repeats, axis=1), axis=0)
    # Unbiased estimator: divide by N - 1 (not N), as in Sahani & Linden.
    signal_power = (
        n_repeats * np.var(mean_signal, axis=0) - total_power
    ) / (n_repeats - 1)
    # np.maximum (rather than masked assignment) also works on 0-d results.
    return np.maximum(signal_power, cutoff)

def sp(trial_repeats, cutoff = 0.01):
    """Estimate signal power from repeated trials (Sahani & Linden, 2003).

    Computes::

        (Var_t(sum_n r_n) - sum_n Var_t(r_n)) / (N * (N - 1))

    which is algebraically equal to
    ``(N * Var_t(mean_n r_n) - mean_n Var_t(r_n)) / (N - 1)``. Here ``N`` is
    the number of repeats and variances are taken over time (axis 1) with
    ``ddof=0``.

    Parameters
    ----------
    trial_repeats : array_like
        Responses with repeats along axis 0 and time along axis 1, e.g.
        ``(repeats, time, channels)``. Any trailing axes are treated
        independently.
    cutoff : float
        Lower floor applied to the estimate. Noise-dominated data can
        produce negative estimates.

    Returns
    -------
    numpy.ndarray or numpy.floating
        Signal power per trailing-axis element, floored at ``cutoff``. The
        result is a scalar for ``(repeats, time)`` input.

    Raises
    ------
    ValueError
        If fewer than 2 repeats are given; the ``N * (N - 1)`` normalizer
        would be zero.
    """
    trial_repeats = np.asarray(trial_repeats)
    n = trial_repeats.shape[0]
    if n < 2:
        raise ValueError("sp needs at least 2 repeats to separate signal from noise")
    # Power of the summed response. Assuming noise is independent across
    # repeats, its expectation is N^2 * signal + N * noise.
    t1 = np.var(np.sum(trial_repeats, axis=0), axis=0)
    # Summed single-repeat power; in expectation N * (signal + noise).
    t2 = np.sum(np.var(trial_repeats, axis=1), axis=0)
    out = (t1 - t2) / (n * (n - 1))
    # np.maximum works for 0-d results too (masked assignment would raise
    # TypeError on a numpy scalar), and NaN still propagates.
    return np.maximum(out, cutoff)


def correlation_ceiling(trial_repeats, explainable_var_cutoff = 0.01):
    """Estimate the ceiling on correlation with the trial-averaged response.

    Returns ``sqrt(SP) / std_t(mean_n r_n)``. ``SP`` is the signal power
    estimated by ``sp`` (floored at ``explainable_var_cutoff``). The
    denominator is the standard deviation over time (axis 1) of the response
    averaged over repeats (axis 0).

    Parameters
    ----------
    trial_repeats : array_like
        Responses with repeats along axis 0 and time along axis 1, e.g.
        ``(repeats, time, channels)``.
    explainable_var_cutoff : float
        Floor passed to ``sp`` as its ``cutoff``. It is an absolute value.
        NOTE(review): the original note says each repeat is assumed to have
        unit variance, which is what makes this absolute floor meaningful --
        confirm that callers normalize.

    Returns
    -------
    numpy.ndarray or numpy.floating
        Ceiling per trailing-axis element. If the trial-averaged response is
        constant over time, the denominator is 0 and numpy division yields
        inf.
    """
    # ddof=0 matches the variances used inside sp. Mixing ddof=1 here would
    # scale the ratio by sqrt((T-1)/T), so noise-free repeats would not reach 1.
    mean_std = np.std(np.mean(trial_repeats, axis=0), axis=0, ddof=0)
    explainable_var = sp(trial_repeats, explainable_var_cutoff)
    return np.sqrt(explainable_var)/mean_std
