from AbstractModels.SpikingConvolutionModel import SpikingConvolutionModel
from AbstractModels.ConvolutionModel import ConvolutionModel

class SpikingConvolutionNetwork(ConvolutionModel):
    """Encoder -> spiking convolutional model -> decoder pipeline.

    Wraps a ``SpikingConvolutionModel`` between an input encoder and an output
    decoder. Spike accounting and parameter save/get calls are delegated to
    the wrapped SNN only.
    """

    def __init__(
        self,
        encoder,
        snn: SpikingConvolutionModel,
        decoder,
        seq_length: int,
        input_scale: float):
        """Assemble the encoder / SNN / decoder pipeline.

        Args:
            encoder: Callable applied to the raw input at the start of
                ``forward``.
            snn: The spiking convolutional model. Its ``num_classes`` is
                forwarded to the base class and its ``name`` becomes this
                wrapper's ``name``.
            decoder: Callable applied to the SNN output at the end of
                ``forward``.
            seq_length: Stored as ``self.seq_length``. It is not read anywhere
                in this class. NOTE(review): presumably consumed by external
                training/encoding code; confirm.
            input_scale: Stored as ``self.input_scale``. It is not read
                anywhere in this class (``forward`` does not scale the input).
                NOTE(review): confirm whether the encoder or caller is
                expected to apply it.
        """
        # Zero-argument super() is the Python 3 idiom. is_snn=True is passed
        # to the base class, presumably to mark this model as spiking.
        super().__init__(num_classes=snn.num_classes, is_snn=True)
        self.name = f'{snn.name}'

        # Submodules are assigned after the base __init__ in the same order as
        # before, so any registration order the base class relies on is kept.
        self.encoder = encoder
        self.model = snn
        self.decoder = decoder
        self.seq_length = seq_length
        self.input_scale = input_scale

    def forward(self, x):
        """Run ``x`` through the encoder, then the SNN, then the decoder.

        Returns:
            Whatever ``self.decoder`` returns when given the SNN output.
        """
        x = self.encoder(x)
        # The SNN output is treated as voltages (per the original variable
        # name). NOTE(review): confirm against SpikingConvolutionModel.forward.
        voltages = self.model(x)
        return self.decoder(voltages)

    def get_total_spikes(self) -> float:
        """Return the wrapped SNN's accumulated spike count."""
        return self.model.get_total_spikes()

    def reset_total_spikes(self) -> None:
        """Reset the wrapped SNN's accumulated spike count."""
        self.model.reset_total_spikes()

    def spike_rate(self) -> list:
        """Return the wrapped SNN's spike rate(s) via its ``get_spike_rate``."""
        return self.model.get_spike_rate()

    def save_params(self, path: str) -> None:
        """Save the wrapped SNN's parameters to ``path``.

        Only ``self.model`` is asked to save; encoder and decoder parameters
        are not included by this method.
        """
        self.model.save_params(path)

    def get_params(self) -> dict:
        """Return the wrapped SNN's parameters (encoder/decoder excluded)."""
        return self.model.get_params()