from typing import Callable

from AbstractModels.Model import Model

from SNN.Layers.NeuronConfig import NeuronConfig

import torch

import json

class SpikingModel(Model):
    """Base class for spiking neural-network models.

    Stores an ``encoder``, a ``decoder`` and a ``NeuronConfig``, and provides
    helpers that aggregate spike statistics over every submodule (as yielded
    by ``self.modules()``) that exposes a ``spikes`` attribute and, for rate
    computations, a non-zero ``total`` attribute. Subclasses must implement
    :meth:`forward` and :meth:`get_params`.
    """

    def __init__(
        self,
        num_classes: int | None = None,
        encoder: Callable | torch.nn.Module | None = None,
        decoder: Callable | torch.nn.Module | None = None,
        config: NeuronConfig | None = None,
    ):
        """Initialise the model.

        Args:
            num_classes: Forwarded unchanged to the ``Model`` base class.
            encoder: Stored as ``self.encoder``; defaults to a fresh
                ``torch.nn.Identity()`` per instance.
            decoder: Stored as ``self.decoder``; defaults to a fresh
                ``torch.nn.Identity()`` per instance.
            config: Stored as ``self.config``; defaults to a fresh
                ``NeuronConfig()`` per instance.
        """
        super().__init__(num_classes=num_classes)
        # Defaults are created here rather than in the signature: default
        # argument values are evaluated once at definition time, so every
        # model relying on them would otherwise share one Identity module and
        # one NeuronConfig object (mutations/hooks leaking across models).
        self.encoder: Callable | torch.nn.Module = (
            encoder if encoder is not None else torch.nn.Identity()
        )
        self.decoder: Callable | torch.nn.Module = (
            decoder if decoder is not None else torch.nn.Identity()
        )
        self.config: NeuronConfig = config if config is not None else NeuronConfig()
        # Not populated in this class; presumably filled in by subclasses — TODO confirm.
        self.params = None

    def forward(self, x: torch.Tensor, epoch: None | int = None) -> torch.Tensor:
        """Run the model on ``x``; must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement forward()")

    def _iter_spike_counts(self):
        """Yield ``(layer, spikes)`` for each submodule whose ``spikes`` is not ``None``."""
        for layer in self.modules():
            spikes = getattr(layer, "spikes", None)
            if spikes is not None:
                yield layer, spikes

    def _iter_spike_records(self):
        """Yield ``(spikes, total)`` for spike-recording submodules with a non-zero ``total``.

        Submodules without a ``total`` attribute are skipped, as are those
        whose ``total`` compares equal to 0 (avoids division by zero).
        """
        for layer, spikes in self._iter_spike_counts():
            total = getattr(layer, "total", 0)
            if total != 0:
                yield spikes, total

    def get_total_spikes(self) -> float:
        """Return the sum of ``spikes`` over all submodules that record them.

        Submodules whose ``spikes`` is ``None`` are ignored. Returns 0 when no
        submodule records spikes. If ``spikes`` values are tensors, the result
        is a tensor.
        """
        return sum(spikes for _, spikes in self._iter_spike_counts())

    def get_avg_spike_rate(self) -> float:
        """Return the pooled spike rate, in percent, across all recording layers.

        Computed as ``100 * sum(spikes) / sum(total)`` over submodules with a
        non-``None`` ``spikes`` and a non-zero ``total``. Returns ``0.0`` when
        no such submodule exists.
        """
        spikes = 0
        total = 0
        for layer_spikes, layer_total in self._iter_spike_records():
            spikes += layer_spikes
            total += layer_total
        if total == 0:
            return 0.0
        return (spikes / total) * 100

    def get_spike_rate(self) -> list:
        """Return the per-layer spike rate, in percent, in ``self.modules()`` order.

        Only submodules with a non-``None`` ``spikes`` and a non-zero ``total``
        contribute an entry (``100 * spikes / total``).
        """
        return [(spikes / total) * 100 for spikes, total in self._iter_spike_records()]

    def reset_total_spikes(self) -> None:
        """Call ``reset_spikes()`` on every submodule that records spikes.

        A submodule is reset when it has a ``spikes`` attribute and a callable
        ``reset_spikes``; modules lacking the method (e.g. the model itself, if
        it happens to store ``spikes``) are skipped instead of raising.
        """
        for layer in self.modules():
            reset = getattr(layer, "reset_spikes", None)
            if hasattr(layer, "spikes") and callable(reset):
                reset()

    def save_params(self, path: str) -> None:
        """Write the result of :meth:`get_params` to ``path`` as indented JSON.

        The parameters are computed before the file is opened, so a failing
        ``get_params`` does not truncate or create an empty file.
        """
        params = self.get_params()
        with open(path, "w", encoding="utf-8") as f:
            json.dump(params, f, indent=4)

    def get_params(self) -> dict:
        """Return a JSON-serialisable dict of model parameters; subclasses must implement.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement get_params()")