from slayer_model.utils import torch_to_cupy
from ..metrics import R2Accumulator, RSEAccumulator, R2AccumulatorPaper

# Helper methods for evaluating the model on a dataset


def eval_set(data, sample_layer, fit_layer, ts, fit_idcs, num_spikes=50, pad=25, r2_paper=False):
    """Evaluate a sample-layer / fit-layer pipeline over every batch in ``data``.

    For each ``(x, y)`` pair yielded by ``data``, spike trains are produced by
    ``sample_layer.compute_full_trains`` and turned into voltages by
    ``fit_layer.compute_in_voltage`` (requested as CuPy via ``return_cupy=True``).
    Each enabled accumulator is then fed ``(y, v_fit)``.

    Args:
        data: Iterable of ``(x, y)`` pairs. It must expose ``.mean`` when
            ``r2_paper`` is true.
        sample_layer: Object providing ``compute_full_trains``.
        fit_layer: Object providing ``compute_in_voltage``.
        ts: Time axis forwarded to both layer calls.
        fit_idcs: Indices forwarded to ``compute_in_voltage`` as ``fit_idcs``.
        num_spikes: Forwarded to ``compute_full_trains``.
        pad: Forwarded to ``compute_full_trains``.
        r2_paper: Also accumulate ``R2AccumulatorPaper`` seeded with
            ``data.mean``.

    Returns:
        ``(r2, rse)``, or ``(r2, rse, r2_paper)`` when ``r2_paper`` is true.
        Each entry is the ``reduce()`` result of the matching accumulator.
    """
    accumulators = [R2Accumulator(), RSEAccumulator()]
    if r2_paper:
        accumulators.append(R2AccumulatorPaper(data.mean))

    for inputs, target in data:
        trains = sample_layer.compute_full_trains(
            inputs, ts, progress=False, num_spikes=num_spikes, pad=pad
        )
        # Drop the input reference before the voltage computation so its
        # memory can be reclaimed while the next (presumably large) buffer
        # is allocated.
        del inputs
        voltage = fit_layer.compute_in_voltage(
            trains, ts, return_cupy=True, fit_idcs=fit_idcs, method='c'
        )
        for acc in accumulators:
            acc.accumulate(target, voltage)
        del trains
        del voltage
        del target

    return tuple(acc.reduce() for acc in accumulators)

def eval_set_torch(data, net, fit_idcs, r2_paper=False):
    """Evaluate a torch network over every batch in ``data``.

    For each ``(x, y)`` pair yielded by ``data``, the network output is
    sliced along its last axis with ``fit_idcs``. The target and the sliced
    prediction are converted to CuPy with ``torch_to_cupy``, and every enabled
    accumulator receives the same CuPy pair. This matches ``eval_set``, where
    all accumulators see identical inputs.

    The forward pass runs under ``torch.no_grad()``: only metrics are
    computed, so no autograd graph or saved activations are needed.

    Args:
        data: Iterable of ``(x, y)`` torch-tensor pairs. It must expose
            ``.mean`` when ``r2_paper`` is true.
        net: Callable torch module. Its output is indexed as ``[:, :, fit_idcs]``,
            so it is assumed to be at least 3-D (TODO confirm axis meaning).
        fit_idcs: Indices selecting the fitted channels from the last axis of
            the network output.
        r2_paper: Also accumulate ``R2AccumulatorPaper`` seeded with
            ``data.mean``.

    Returns:
        ``(r2, rse)``, or ``(r2, rse, r2_paper)`` when ``r2_paper`` is true.
        Each entry is the ``reduce()`` result of the matching accumulator.
    """
    import torch

    accumulators = [R2Accumulator(), RSEAccumulator()]
    if r2_paper:
        accumulators.append(R2AccumulatorPaper(data.mean))

    with torch.no_grad():
        for x, y in data:
            v_fit = net(x)[:, :, fit_idcs]
            del x
            y_cp = torch_to_cupy(y)
            v_fit_cp = torch_to_cupy(v_fit)
            # Previously the paper accumulator received the raw torch tensors
            # while the others got CuPy arrays; feed all of them the same pair.
            for acc in accumulators:
                acc.accumulate(y_cp, v_fit_cp)
            # The CuPy arrays may alias the torch buffers, so release both
            # sides to let the memory be reclaimed before the next batch.
            del v_fit, v_fit_cp, y, y_cp

    return tuple(acc.reduce() for acc in accumulators)