from matplotlib import pyplot as plt

from BACKEND import cp, np, to_gpu


# Evaluation metrics. The accumulator classes compute each metric with an equivalent
# online (batch-by-batch) formulation because the whole dataset cannot be held in memory.

def rse(y_true, y_pred):
    """
    Root relative squared error computed over the full arrays at once.

    RSE = sqrt( sum((y_true - y_pred)^2) / (sum((y_true - mean_0(y_true))^2) + 1e-8) ),
    where mean_0 is the mean over axis 0. This is the in-memory reference
    for ``RSEAccumulator``.

    :param y_true: targets; axis 0 is the sample axis. Any rank >= 1.
    :param y_pred: predictions, broadcast-compatible with ``y_true``.
    :return: scalar RSE.
    """
    y_pred = to_gpu(y_pred)
    y_true = to_gpu(y_true)
    # keepdims=True broadcasts the per-position mean back over axis 0 for any
    # input rank; indexing with [None, :, :] only worked for 3-D arrays.
    deviation = y_true - y_true.mean(axis=0, keepdims=True)
    # The 1e-8 guard avoids 0/0 for a constant target and matches the
    # denominator used by RSEAccumulator.reduce.
    return cp.sqrt(
        cp.square(y_true - y_pred).sum() /
        (cp.square(deviation).sum() + 1e-8)
    )

class RSEAccumulator:
    """
    Streaming root relative squared error, fed one batch at a time.

    Three running quantities are kept:
      * ``acc_num``          -- sum of squared prediction errors,
      * ``acc_denom_square`` -- sum of squared targets,
      * ``acc_denom_exp``    -- per-position sum of targets over axis 0,
    plus ``N``, the number of samples seen along axis 0. ``reduce`` evaluates

        sqrt( acc_num / (acc_denom_square - (1/N) * sum(acc_denom_exp^2) + 1e-8) ),

    which equals the squared-deviation-from-the-axis-0-mean denominator
    without holding all data in memory.
    """

    def __init__(self):
        self.reset()

    def accumulate(self, y_true, y_pred):
        """Fold one batch (samples on axis 0) into the running sums."""
        starting_fresh = self.N == 0
        self.N += y_true.shape[0]
        batch_err = cp.square(y_pred - y_true).sum()
        batch_sq = cp.square(y_true).sum()
        batch_sum = y_true.sum(axis=0)
        if starting_fresh:
            self.acc_num = batch_err
            self.acc_denom_square = batch_sq
            self.acc_denom_exp = batch_sum
            return
        self.acc_num += batch_err
        self.acc_denom_square += batch_sq
        self.acc_denom_exp += batch_sum

    def reduce(self, reset=True):
        """
        Return the RSE over everything accumulated so far.

        :param reset: clear the running sums after reading them.
        :raises ValueError: if nothing has been accumulated.
        """
        if self.N == 0:
            raise ValueError("Accumulator is empty")
        inv_n = 1 / self.N
        centred_ss = self.acc_denom_square - inv_n * cp.square(self.acc_denom_exp).sum() + 1e-8
        value = cp.sqrt(self.acc_num / centred_ss)
        if reset:
            self.reset()
        return value

    def reset(self):
        """Forget all accumulated batches."""
        self.N = 0
        self.acc_num = None
        self.acc_denom_square = None
        self.acc_denom_exp = None

def r2(y_true, y_pred, neuron_wise=False):
    """
    Coefficient of determination (R^2) computed over the full arrays at once.

    For every position (every index except axis 0) the ratio
    SSE / (SST + 1e-8) is formed, where SSE is the sum over axis 0 of squared
    residuals and SST the sum over axis 0 of squared deviations from the
    axis-0 mean. The ratios are averaged and subtracted from 1.

    :param y_true: targets; axis 0 is the sample axis.
    :param y_pred: predictions with the same shape as ``y_true``.
    :param neuron_wise: if True, average the ratios only over the last axis
        and return one value per remaining index (one per neuron for input
        laid out as (N, n_neurons, n_ts) -- assumed layout, confirm with
        caller); otherwise average over everything and return a scalar.
    :return: R^2 as a scalar, or an array when ``neuron_wise`` is True.
    """
    y_pred = to_gpu(y_pred)
    y_true = to_gpu(y_true)

    residual_ss = cp.square(y_pred - y_true).sum(axis=0)
    centred = y_true - y_true.mean(axis=0, keepdims=True)
    total_ss = cp.square(centred).sum(axis=0) + 1e-8
    ratio = residual_ss / total_ss

    averaged = ratio.mean(axis=-1) if neuron_wise else ratio.mean()
    return 1 - averaged

class R2AccumulatorPaper:
    """
    Streaming, pooled R^2 measured against a fixed mean.

    The total sum of squares is taken around ``mean`` supplied at
    construction (not around the mean of the accumulated targets), and a
    single pooled value is returned:

        1 - sum((y_pred - y_true)^2) / sum((y_true - mean)^2)

    No epsilon is added to the denominator.

    :param mean: reference value (or broadcastable array) subtracted from
        the targets when forming the denominator.
    """

    def __init__(self, mean):
        self.mean = mean
        self.reset()

    def accumulate(self, y_true, y_pred):
        """Fold one batch into the pooled residual and total sums of squares."""
        residual = cp.square(y_pred - y_true).sum()
        spread = cp.square(y_true - self.mean).sum()
        if self.acc_num is not None:
            self.acc_num += residual
            self.acc_denom += spread
        else:
            self.acc_num = residual
            self.acc_denom = spread

    def reduce(self, reset=True):
        """
        Return the pooled R^2 of everything accumulated so far.

        :param reset: clear the running sums (but not ``mean``) afterwards.
        :raises ValueError: if nothing has been accumulated.
        """
        if self.acc_num is None:
            raise ValueError("Accumulator is empty")
        score = 1 - (self.acc_num / self.acc_denom)
        if reset:
            self.reset()
        return score

    def reset(self):
        """Forget all accumulated batches; ``mean`` is kept."""
        self.acc_num = None
        self.acc_denom = None

class MSEAccumulator:
    """
    Streaming mean squared error, fed one batch at a time.

    ``accumulate`` adds per-sample MSE values, summed over the batch (axis 0),
    to a running total; ``reduce`` divides that total by the number of
    samples N seen so far.

    :param n_neurons: when non-zero, keep a float64 running total of length
        ``n_neurons`` -- one entry per index of axis 1, with the squared
        error averaged over the last axis only. When 0, keep a single scalar
        total, averaging the squared error over every non-sample axis.
    """

    def __init__(self, n_neurons=0):
        self.n_neurons = n_neurons
        self.neuron_wise = (n_neurons > 0)
        self.reset()

    def accumulate(self, y_true, y_pred):
        """
        Fold one batch into the running total.

        :param y_true: targets; axis 0 is the sample axis. In neuron-wise mode
            the layout is assumed to be (N, n_neurons, n_ts) -- confirm with caller.
        :param y_pred: predictions with the same shape as ``y_true``.
        """
        self.N += y_true.shape[0]
        sq_err = cp.square(y_pred - y_true)
        if self.neuron_wise:
            # Average over the last axis, sum over samples -> one value per neuron.
            self.acc += sq_err.mean(axis=-1).sum(axis=0)
        else:
            if sq_err.ndim > 1:
                # Per-sample MSE over all non-sample axes in one reduction, so
                # any input rank works (the chained mean(axis=1).mean(axis=-1)
                # form only handled 3-D input).
                sq_err = sq_err.mean(axis=tuple(range(1, sq_err.ndim)))
            self.acc += sq_err.sum(axis=0)

    def reduce(self, reset=True, neuron_wise=False):
        """
        Return the accumulated MSE (running total / N).

        :param reset: clear the accumulator after reading it.
        :param neuron_wise: ignored; kept for signature compatibility. The
            output shape is fixed by ``n_neurons`` at construction.
        :raises ValueError: if nothing has been accumulated, consistent with
            the other accumulators in this module.
        """
        if self.N == 0:
            raise ValueError("Accumulator is empty")
        res = self.acc / self.N
        if reset:
            self.reset()
        return res

    def reset(self):
        """Clear the sample count and restore a zero running total."""
        self.N = 0
        # Recreate the float64 per-neuron buffer: resetting to a plain 0 would
        # make subsequent totals adopt the dtype of the incoming batches.
        if self.n_neurons:
            self.acc = cp.zeros(self.n_neurons, dtype=cp.float64)
        else:
            self.acc = 0

class R2Accumulator:
    """
    Streaming coefficient of determination, fed one batch at a time.

    Per-position running sums (every index except axis 0) are kept:
      * ``acc_num``          -- squared residuals,
      * ``acc_denom``        -- targets,
      * ``acc_denom_square`` -- squared targets,
    plus ``N``, the number of samples seen along axis 0. ``reduce`` rebuilds
    the total sum of squares per position as
    acc_denom_square - (1/N) * acc_denom^2 + 1e-8 and returns 1 minus the
    mean of acc_num / that value -- the batch-wise counterpart of ``r2``.
    """

    def __init__(self):
        self.reset()

    def accumulate(self, y_true, y_pred):
        """Fold one batch (samples on axis 0) into the per-position running sums."""
        is_first = self.N == 0
        self.N += y_true.shape[0]
        sse = cp.square(y_pred - y_true).sum(axis=0)
        col_sum = y_true.sum(axis=0)
        col_sq = cp.square(y_true).sum(axis=0)
        if not is_first:
            self.acc_num += sse
            self.acc_denom += col_sum
            self.acc_denom_square += col_sq
        else:
            self.acc_num = sse
            self.acc_denom = col_sum
            self.acc_denom_square = col_sq

    def reduce(self, reset=True, neuron_wise=False):
        """
        Return R^2 over everything accumulated so far.

        :param reset: clear the running sums after reading them.
        :param neuron_wise: average the per-position ratios only over the
            last axis (one value per remaining index) instead of over all.
        :raises ValueError: if nothing has been accumulated.
        """
        if self.N == 0:
            raise ValueError("Accumulator is empty")
        inv_n = 1 / self.N
        sst = self.acc_denom_square - inv_n * cp.square(self.acc_denom) + 1e-8
        ratio = self.acc_num / sst
        if neuron_wise:
            score = 1 - ratio.mean(axis=-1)
        else:
            score = 1 - ratio.mean()
        if reset:
            self.reset()
        return score

    def reset(self):
        """Forget all accumulated batches."""
        self.N = 0
        self.acc_num = None
        self.acc_denom = None
        self.acc_denom_square = None


class SpikeCountAccumulator:
    """
    Streaming average spike count per sample.

    Each batch is summed over its last axis and over axis 0, and the
    remaining values are averaged; ``reduce`` divides the running total by
    the number of samples N seen along axis 0.
    """

    def __init__(self):
        self.reset()

    def accumulate(self, spikes_binned):
        """
        Fold one batch into the running spike total.

        :param spikes_binned: array with samples on axis 0; presumably laid
            out as (N, n_neurons, n_bins) -- confirm with caller.
        """
        self.n_spikes += spikes_binned.sum(axis=-1).sum(axis=0).mean()
        self.N += spikes_binned.shape[0]

    def reduce(self, reset=True):
        """
        Return the accumulated total divided by the sample count.

        :param reset: clear the accumulator after reading it.
        :raises ValueError: if nothing has been accumulated.
        """
        if self.N == 0:
            raise ValueError("Accumulator is empty")
        mean = self.n_spikes / self.N
        if reset:
            self.reset()
        return mean

    def reset(self):
        """Forget all accumulated batches."""
        self.N = 0
        self.n_spikes = 0

    # Backward-compatible alias for the original misspelled method name.
    resset = reset


def _mean_response_cosine(X):
    """Return the (n_neurons, n_neurons) per-sample cosine-similarity matrix averaged over samples."""
    X = to_gpu(X)
    N, n_neurons, _ = X.shape
    corrs = cp.zeros((n_neurons, n_neurons))
    # Norm of each neuron's response vector in each sample: shape (N, n_neurons).
    norms = cp.linalg.norm(X, axis=-1)
    for n in range(N):
        # The outer product of the norms is the cosine denominator; the extra
        # factor N turns the running sum into a mean, and 1e-6 avoids 0/0 for
        # all-zero response vectors.
        corrs += X[n] @ X[n].T / (norms[n, None, :] * norms[n, :, None] * N + 1e-6)
    return corrs


def _to_host(arr):
    """Return a host (NumPy) view of ``arr``: CuPy arrays expose ``.get()``, NumPy arrays do not."""
    get = getattr(arr, "get", None)
    return get() if callable(get) else arr


def neuron_response_correlation(X):
    """
    Plot the sample-averaged cosine similarity between neuron responses.

    For each sample n, X[n] @ X[n].T is divided by the outer product of the
    row norms of X[n] (times N, plus 1e-6) and summed over samples. The
    resulting (n_neurons, n_neurons) matrix is drawn with ``matshow`` using
    the 'seismic' colormap, a colorbar is added, and ``plt.show()`` is called.

    :param X: (N, n_neurons, n_ts)
    :return: None
    """
    # Compute first so a failing computation does not leave an empty figure open.
    corrs = _mean_response_cosine(X)
    fig, ax = plt.subplots(figsize=(12, 12))
    # NOTE(review): BACKEND may yield CuPy or NumPy arrays; _to_host handles both.
    m = ax.matshow(_to_host(corrs), cmap='seismic')
    fig.colorbar(mappable=m, ax=ax)
    plt.show()


if __name__ == '__main__':
    # Sanity check: the streaming RSEAccumulator should agree with the
    # in-memory reference ``rse`` (up to floating-point/epsilon differences),
    # whether the data arrives in one batch or in several.
    rse_acc = RSEAccumulator()
    y_true = cp.random.standard_normal(size=(100, 100, 100))
    y_pred = cp.random.standard_normal(size=(100, 100, 100))
    print(f"Ref: {rse(y_true, y_pred)}")

    rse_acc.accumulate(y_true, y_pred)
    # reduce() resets the accumulator by default, so no explicit reset() is needed.
    print(f"Full acc: {rse_acc.reduce()}")

    for start, stop in ((0, 20), (20, 50), (50, None)):
        rse_acc.accumulate(y_true[start:stop], y_pred[start:stop])
    print(f"Piecewise acc: {rse_acc.reduce()}")
