# Spiking ResNet built from lava.lib.dl.slayer.block.cuba blocks (CUBA LIF). It keeps the ResNet structure (conv stem + 4 residual stages + global average pooling + dense head), with spikes propagated through every timestep.

import torch
import torch.nn as nn
import lava.lib.dl.slayer as slayer


def default_neuron_params():
    """Return a fresh dict of CUBA-LIF neuron parameters.

    The values are a reasonable starting point and are meant to be tuned.
    A new dict is built on every call, so callers may mutate the result
    without affecting other layers or later calls.

    To enable dropout, add ``"dropout": slayer.neuron.Dropout(p=0.05)``
    to the returned dict.
    """
    return dict(
        threshold=1.25,
        current_decay=0.25,
        voltage_decay=0.03,
        tau_grad=0.03,
        scale_grad=3.0,
        requires_grad=True,
    )


class SpikingBasicBlock(nn.Module):
    """
    ResNet BasicBlock implemented with Lava-DL CUBA LIF blocks.

    Main path: two 3x3 spiking convs (the first one may be strided).
    Shortcut: identity, or a 1x1 strided spiking conv when the spatial size
    or channel count changes. The two paths are summed element-wise, so the
    output is not strictly binary (it can reach 2 where both paths spike).

    Input / output shape: [N, C, H, W, T]
    """
    expansion = 1

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 neuron_params=None):
        super().__init__()
        if neuron_params is None:
            neuron_params = default_neuron_params()

        # First 3×3 spiking conv
        self.conv1 = slayer.block.cuba.Conv(
            neuron_params,
            in_features=in_channels,
            out_features=out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            weight_norm=True,
            delay=True,
        )

        # Second 3×3 spiking conv
        self.conv2 = slayer.block.cuba.Conv(
            neuron_params,
            in_features=out_channels,
            out_features=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            weight_norm=True,
            delay=True,
        )

        # Optional downsample for residual path
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.downsample = slayer.block.cuba.Conv(
                neuron_params,
                in_features=in_channels,
                out_features=out_channels * self.expansion,
                kernel_size=1,
                stride=stride,
                padding=0,
                weight_norm=True,
                delay=True,
            )
        else:
            self.downsample = None

    @staticmethod
    def _step_latency(module):
        """Return the fixed timestep lag (0 or 1) that `module` adds.

        Lava-DL blocks with ``delay_shift`` enabled (the library default;
        none of the convs above override it) shift their output by one
        timestep. Modules without the attribute (or ``None``) add no lag.
        """
        return 1 if getattr(module, "delay_shift", False) else 0

    @staticmethod
    def _shift_time(x, steps):
        """Delay `x` by `steps` along the last (time) axis, zero-filling.

        Shifts of `T` or more steps yield an all-zero tensor.
        """
        if steps <= 0:
            return x
        num_steps = x.shape[-1]
        return nn.functional.pad(x, (steps, 0))[..., :num_steps]

    def forward(self, x):
        """
        x: [N, C_in, H, W, T]
        returns: [N, C_out, H', W', T]
        """
        identity = x

        out = self.conv1(x)     # spike conv + LIF
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Align both paths in time before adding. The main path passes
        # through two blocks and the shortcut through zero (identity) or one
        # (downsample). Without this, spikes from different timesteps would
        # be summed. Learnable axonal delays (delay=True) are not
        # compensated here; they are trained.
        lag = (self._step_latency(self.conv1)
               + self._step_latency(self.conv2)
               - self._step_latency(self.downsample))
        if lag > 0:
            identity = self._shift_time(identity, lag)
        elif lag < 0:
            out = self._shift_time(out, -lag)

        # Residual add in spike domain
        return out + identity


class SpikingResNet(nn.Module):
    """
    Spiking ResNet backbone using Lava-DL SLAYER blocks.

    - Uses CUBA LIF blocks (cuba.Input, cuba.Conv, cuba.Pool, cuba.Dense)
    - Expects inputs as [N, C_in, H, W, T] (NCHWT)
    - Global average pooling over H, W -> [N, C, T] before Dense classifier

    You can:
      * swap SpikingBasicBlock with your own residual block (e.g. a MixedLIF
        variant). It must accept the same constructor keywords and expose
        `expansion`, and for `export_hdf5` it must expose `conv1`, `conv2`
        and `downsample`.
      * change `layers` for ResNet-18/34 depths. ResNet-50 additionally
        needs a bottleneck block (expansion = 4, three convs), which
        `export_hdf5` would not fully collect as written.
      * plug this into your SSL pipeline.
    """

    def __init__(self,
                 block,
                 layers,            # e.g., [3, 4, 6, 3] for ResNet-34
                 num_classes=1000,
                 in_channels=3,
                 neuron_params=None):
        super().__init__()
        if neuron_params is None:
            neuron_params = default_neuron_params()
        self.neuron_params = neuron_params

        # Running channel count; _make_layer updates it after each stage.
        self.inplanes = 64

        # Input block (optionally does analog→current scaling)
        self.input_block = slayer.block.cuba.Input(
            neuron_params,
            weight=None,   # set a scale if you feed analog frames
            bias=None,
            delay_shift=True,
        )

        # ResNet stem: 7×7 conv, stride=2, then 3×3 pool, stride=2
        self.conv1 = slayer.block.cuba.Conv(
            neuron_params,
            in_features=in_channels,
            out_features=64,
            kernel_size=7,
            stride=2,
            padding=3,
            weight_norm=True,
            delay=True,
        )

        self.pool1 = slayer.block.cuba.Pool(
            neuron_params,
            kernel_size=3,
            stride=2,
            padding=1,
        )

        # Residual stages
        self.layer1 = self._make_layer(block, 64,  layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Dense classifier head (operates on [N, C, T])
        self.fc = slayer.block.cuba.Dense(
            neuron_params,
            in_neurons=512 * block.expansion,
            out_neurons=num_classes,
            weight_norm=True,
            delay=True,
        )

    def _make_layer(self, block, planes, blocks, stride):
        """
        Create one ResNet stage (e.g., conv2_x) with `blocks` residual blocks.

        Only the first block uses `stride` (and may downsample). Updates
        `self.inplanes` to `planes * block.expansion` for the next stage.
        """
        layers = []
        # First block may downsample
        layers.append(
            block(
                in_channels=self.inplanes,
                out_channels=planes,
                stride=stride,
                neuron_params=self.neuron_params,
            )
        )
        self.inplanes = planes * block.expansion

        # Remaining blocks keep same spatial resolution
        for _ in range(1, blocks):
            layers.append(
                block(
                    in_channels=self.inplanes,
                    out_channels=planes,
                    stride=1,
                    neuron_params=self.neuron_params,
                )
            )

        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: input tensor of shape [N, C_in, H, W, T].
               - The Input block is built with `neuron_params` and
                 `delay_shift=True`. Even with weight=None / bias=None it
                 presumably still applies neuron dynamics and a one-step
                 shift rather than passing spikes through unchanged.
                 Confirm against the Lava-DL docs before treating it as an
                 identity for pre-encoded spikes.
               - If you have analog frames, you can make the Input block
                 act as an encoder by giving a weight scale.

        Returns:
            logits_spike: [N, num_classes, T] spike tensor (per-timestep logits).
                          You can sum/mean over T for classification.
        """

        # Encode / normalize input
        x = self.input_block(x)          # [N, C_in, H, W, T]

        # Stem
        x = self.conv1(x)                # [N, 64, H/2, W/2, T]
        x = self.pool1(x)                # [N, 64, H/4, W/4, T]

        # Residual stages
        x = self.layer1(x)               # [N, 64,  H/4,  W/4,  T]
        x = self.layer2(x)               # [N, 128, H/8,  W/8,  T]
        x = self.layer3(x)               # [N, 256, H/16, W/16, T]
        x = self.layer4(x)               # [N, 512, H/32, W/32, T]

        # Global average pooling over spatial dims
        # x: [N, 512, H', W', T] -> [N, 512, T]
        x = x.mean(dim=(2, 3))

        # Dense spiking classifier
        logits_spike = self.fc(x)        # [N, num_classes, T]

        return logits_spike

    def _flatten_blocks(self):
        """Return the network's blocks in the order they are exported.

        The order is input_block, conv1, pool1; then, for each residual block
        of layer1..layer4, its conv1, conv2 and downsample (if present);
        then fc.
        """
        blocks = [self.input_block, self.conv1, self.pool1]
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            for res_block in stage:
                blocks.extend((res_block.conv1, res_block.conv2))
                if res_block.downsample is not None:
                    blocks.append(res_block.downsample)
        blocks.append(self.fc)
        return blocks

    def export_hdf5(self, filename: str):
        """
        Optional: export to HDF5 for Loihi via lava.lib.dl.netx.
        You can also mimic the NMNIST example in Lava docs.

        Each block that has an `export_hdf5` method is written to the
        subgroup ``layer/<i>``, where `i` is its index in `_flatten_blocks()`.
        Only per-block parameters are stored. The file does NOT encode:
          - the residual additions,
          - the shortcut topology (downsample blocks are simply listed after
            their conv2),
          - the spatial mean before `fc`.
        A loader that reads ``layer/*`` as a plain sequence will presumably
        not reproduce `forward`. Verify against netx before deploying.
        """
        import h5py

        with h5py.File(filename, "w") as h:
            layer_group = h.create_group("layer")
            for i, b in enumerate(self._flatten_blocks()):
                if hasattr(b, "export_hdf5"):
                    b.export_hdf5(layer_group.create_group(f"{i}"))


# Convenience constructors
def spiking_resnet18(num_classes=1000, in_channels=3, neuron_params=None):
    """Build a spiking ResNet-18: BasicBlock stages with depths 2-2-2-2."""
    stage_depths = [2] * 4
    return SpikingResNet(
        SpikingBasicBlock,
        stage_depths,
        num_classes=num_classes,
        in_channels=in_channels,
        neuron_params=neuron_params,
    )


def spiking_resnet34(num_classes=1000, in_channels=3, neuron_params=None):
    """Build a spiking ResNet-34: BasicBlock stages with depths 3-4-6-3."""
    config = dict(
        block=SpikingBasicBlock,
        layers=[3, 4, 6, 3],
        num_classes=num_classes,
        in_channels=in_channels,
        neuron_params=neuron_params,
    )
    return SpikingResNet(**config)
