from BACKEND import np, sp, cp
from model.layer import SRMLayer
from model.utils.enums import OutType

# WIP: MultiLayerSRM is not used anywhere yet.

class MultiLayerSRM:
    """Feed-forward stack of ``SRMLayer`` objects evaluated in sequence.

    Every layer except the last produces spike trains that become the input
    of the next layer. The last layer is read out either as voltage or as
    spike trains (see ``compute_full_trains``).
    """

    def __init__(self, layers: list[SRMLayer]) -> None:
        """Store the layers in evaluation order.

        Args:
            layers: Layers ordered from input to output. The list object is
                stored as-is (not copied).
        """
        self.layers = layers

    def compute_full_trains(self, s_in, eval_ts, out: OutType = OutType.VOLT):
        """Propagate input spikes through all layers and read out the last one.

        Args:
            s_in: Input for the first layer, in whatever form
                ``SRMLayer.compute_full_trains`` accepts.
            eval_ts: Evaluation time points, forwarded unchanged to every layer.
            out: ``OutType.VOLT`` returns the last layer's
                ``compute_in_voltage`` result; ``OutType.SPIKE`` returns the
                last layer's ``compute_full_trains`` result.

        Returns:
            The last layer's output, as selected by ``out``.

        Raises:
            ValueError: If ``out`` is not ``OutType.VOLT``/``OutType.SPIKE``,
                or if the network has no layers.
        """
        # Validate before the forward pass so a bad ``out`` fails immediately
        # instead of after propagating through every hidden layer.
        if out not in (OutType.VOLT, OutType.SPIKE):
            raise ValueError(f"Invalid output type: {out}")
        if not self.layers:
            raise ValueError("MultiLayerSRM has no layers")

        *hidden, last = self.layers
        for layer in hidden:
            # Hidden layers always emit spike trains, which feed the next layer.
            s_in = layer.compute_full_trains(s_in, eval_ts, progress=False)

        if out == OutType.VOLT:
            return last.compute_in_voltage(s_in, eval_ts)
        return last.compute_full_trains(s_in, eval_ts, progress=False)

    def fit_sample(self, s_in, s_out, eval_ts, pred, pred_params, rcond=1e-8):
        """Fit the network to a single input/output sample (not implemented).

        This is a placeholder. The meaning of the parameters is not defined
        by any code yet. ``rcond`` presumably will be a cutoff for a
        least-squares / pseudo-inverse solve (TODO confirm when implementing).

        Raises:
            NotImplementedError: Always. A silent no-op would let callers
                believe the network had been fitted.
        """
        raise NotImplementedError("MultiLayerSRM.fit_sample is not implemented yet")