import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable, Optional

from slayer_model.layer import SRMLayer

# Convenient network wrapper class

class SRMNetwork(nn.Module):
    """
    Complete SRM network using Lava-DL enhanced layers.

    The network is a plain sequential stack: one ``SRMLayer`` is built per
    entry of ``layers`` and the layers are applied in list order.
    """

    def __init__(self, layers: list):
        """
        Args:
            layers: List of layer configurations. Each entry is unpacked as
                keyword arguments into the ``SRMLayer`` constructor.
        """
        super().__init__()

        # nn.ModuleList registers every layer as a submodule of this network.
        self.layers = nn.ModuleList(
            [SRMLayer(**layer_config) for layer_config in layers]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through network

        Args:
            x: Input tensor (batch_size, n_inputs, n_timesteps) or continuous

        Returns:
            Output spikes or voltage (the output of the last layer; ``x``
            itself when the network has no layers)
        """
        current = x

        for layer in self.layers:
            current = layer(current)

        return current

    def init_fluct_rg(self, x, mu_u, xi):
        """
        Run ``init_fluct_rg`` on every layer, in order.

        Each layer is initialised with the activation that reaches it: the
        layer's ``init_fluct_rg`` is called first, then ``x`` is propagated
        through that (now initialised) layer to obtain the next layer's input.

        Args:
            x: Input fed to the first layer.
            mu_u: Forwarded unchanged to each layer's ``init_fluct_rg``.
            xi: Forwarded unchanged to each layer's ``init_fluct_rg``.
        """
        for layer in self.layers:
            layer.init_fluct_rg(x, mu_u, xi)
            # The forward call is kept for the last layer as well, in case a
            # layer's forward pass updates internal state.
            x = layer(x)

    def _require_layers(self, count: int, caller: str) -> None:
        """Raise ``IndexError`` unless the network has at least ``count`` layers."""
        n_layers = len(self.layers)
        if n_layers < count:
            raise IndexError(
                f"{caller} requires a network with at least {count} layers, "
                f"but this network has {n_layers}"
            )

    def load_from_sswim(self, sample_layer, fit_layer):
        """
        Load the first two layers via their own ``load_from_sswim`` methods.

        Args:
            sample_layer: Passed to ``self.layers[0].load_from_sswim``.
            fit_layer: Passed to ``self.layers[1].load_from_sswim``.

        Raises:
            IndexError: If the network has fewer than two layers. The check
                happens before any layer is touched, so a failure never
                leaves the network partially loaded.
        """
        self._require_layers(2, "load_from_sswim")
        self.layers[0].load_from_sswim(sample_layer)
        self.layers[1].load_from_sswim(fit_layer)

    def get_regularisation_weights(self, neuron_wise: bool = False):
        """
        Return the regularisation weights reported by the second layer.

        Args:
            neuron_wise: Forwarded to the layer's
                ``get_regularisation_weights``.

        Returns:
            Whatever ``self.layers[1].get_regularisation_weights`` returns.

        Raises:
            IndexError: If the network has fewer than two layers.
        """
        self._require_layers(2, "get_regularisation_weights")
        return self.layers[1].get_regularisation_weights(neuron_wise=neuron_wise)


# Utility functions for easy network creation
def create_srm_network_from_config(config: dict) -> SRMNetwork:
    """
    Build an ``SRMNetwork`` from a configuration dictionary.

    Kept as a thin factory so the interface stays compatible with the
    PyTorch version of the model.

    Args:
        config: Configuration dictionary; its ``'layers'`` entry is handed to
            ``SRMNetwork`` as the list of layer configurations.

    Returns:
        The constructed ``SRMNetwork``.

    Raises:
        KeyError: If ``config`` has no ``'layers'`` entry.
    """
    layer_configs = config['layers']
    network = SRMNetwork(layers=layer_configs)
    return network