from SNN.Layers.NeuronConfig import NeuronConfig

import torch
import json

class SpikingConvolutionModel(torch.nn.Module):
    """Base wrapper that unrolls a stateful network over the time axis.

    The per-step network lives in ``self.model``. It is ``None`` here and must
    be assigned before ``forward`` is called (presumably by subclasses —
    verify). It is invoked as ``v, state = self.model(x_t, state)`` once per
    time step.

    Spike bookkeeping is duck-typed. Any submodule that exposes ``spikes`` /
    ``spike_rate`` attributes (and ``reset_spikes()``) is picked up by the
    helper methods below.
    """

    def __init__(
        self,
        num_classes: int,
        config: "NeuronConfig | None" = None,
    ):
        """Create the wrapper.

        Args:
            num_classes: Expected size of the last dimension of each per-step
                output. It is also used to shape the result for zero-length
                input.
            config: Neuron configuration. When omitted (or ``None``), a fresh
                ``NeuronConfig()`` is created for this instance.
        """
        super().__init__()
        self.num_classes = num_classes
        # Built per instance. A ``NeuronConfig()`` default in the signature
        # would be evaluated once at class-definition time and shared by
        # every model constructed without an explicit config.
        self.config = config if config is not None else NeuronConfig()

        self.name: str = None
        # Per-step callable: (input_t, state) -> (output_t, new_state).
        self.model: torch.nn.Module = None
        # Not read by the code in this class; presumably used by subclasses or
        # training code — verify before removing.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Not read by the code in this class.
        self.params = None

    def forward(self, x):
        """Run ``self.model`` step by step over the leading axis of ``x``.

        Args:
            x: Input whose axis 0 is treated as time and axis 1 as batch.

        Returns:
            The per-step outputs of ``self.model`` stacked along a new
            leading time axis. Each step's output is assumed to be
            ``(batch_size, num_classes)`` — confirm for your model. The
            result keeps the dtype of the model output rather than that of
            ``x``.

        The recurrent ``state`` starts as ``None`` and is threaded from one
        step to the next.
        """
        seq_length, batch_size = x.shape[0], x.shape[1]
        if seq_length == 0:
            # Nothing to unroll: return an empty (0, batch, num_classes) tensor.
            return torch.zeros(
                0, batch_size, self.num_classes, device=x.device, dtype=x.dtype
            )

        outputs = []
        state = None
        for ts in range(seq_length):
            v, state = self.model(x[ts], state)
            outputs.append(v)
        # Stack rather than copy into a buffer typed like ``x``. Spike inputs
        # are often bool/int, and copying float voltages into such a buffer
        # silently truncates them.
        return torch.stack(outputs)

    def get_total_spikes(self) -> float:
        """Sum the non-None ``spikes`` attribute over all modules.

        ``self.modules()`` includes this module itself. Returns ``0`` when no
        module reports spikes. Otherwise the result has whatever type adding
        the layers' ``spikes`` values produces.
        """
        return sum(
            layer.spikes
            for layer in self.modules()
            if getattr(layer, "spikes", None) is not None
        )

    def get_spike_rate(self) -> list:
        """Collect the non-None ``spike_rate`` of every module, in ``self.modules()`` order."""
        return [
            layer.spike_rate
            for layer in self.modules()
            if getattr(layer, "spike_rate", None) is not None
        ]

    def reset_total_spikes(self) -> None:
        """Call ``reset_spikes()`` on every module that has a ``spikes`` attribute."""
        for layer in self.modules():
            if hasattr(layer, "spikes"):
                layer.reset_spikes()

    def save_params(self, path: str) -> None:
        """Write ``self.get_params()`` to ``path`` as indented UTF-8 JSON.

        The params are serialized before the file is opened. A failing
        ``get_params()`` or a non-JSON-serializable value therefore leaves any
        existing file untouched instead of truncating it.
        """
        payload = json.dumps(self.get_params(), indent=4)
        with open(path, "w", encoding="utf-8") as f:
            f.write(payload)

    def get_params(self) -> dict:
        """Return the JSON-serializable parameters written by ``save_params``.

        Subclasses must override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_params()"
        )