from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Tuple

import flax.linen as nn
import jax.numpy as jnp

# -------------------------
# Utility layers / helpers
# -------------------------

def _maybe_flatten(x: jnp.ndarray) -> jnp.ndarray:
    """Collapse every non-batch axis of ``x`` into one feature axis.

    Inputs with more than two dimensions, e.g. ``(N, H, W, C)``, are reshaped
    to ``(N, H*W*C)``. Inputs with two or fewer dimensions are returned
    unchanged.

    The feature size is computed explicitly instead of using ``-1`` so that an
    empty batch (``N == 0``) still reshapes cleanly: a ``-1`` dimension cannot
    be inferred when the remaining dimensions multiply to zero.
    """
    if x.ndim <= 2:
        return x
    batch_size, *feature_dims = x.shape
    return x.reshape((batch_size, math.prod(feature_dims)))

# NOTE: this import belongs with the other imports at the top of the file.
from typing import Literal


class ConvBNAct(nn.Module):
    """Convolution, then batch normalisation, then ReLU.

    NOTE: a second class named ``ConvBNAct`` is defined further down in this
    module. Once the module has finished importing, the name ``ConvBNAct``
    refers to that later definition, not to this one.

    Attributes:
        features: number of output channels of the convolution.
        kernel_size: convolution kernel size (required, no default).
        strides: convolution strides.
        padding: padding mode handed to ``nn.Conv``.
        use_bias: whether the convolution has a bias term.
        bn_decay: value passed to ``nn.BatchNorm`` as ``momentum``.
        bn_eps: value passed to ``nn.BatchNorm`` as ``epsilon``.
    """

    features: int
    kernel_size: tuple[int, int]
    strides: tuple[int, int] = (1, 1)
    padding: str = "SAME"
    use_bias: bool = False
    bn_decay: float = 0.99  # original author's note: good for batch=256
    bn_eps: float = 1e-5

    @nn.compact
    def __call__(self, x, *, train: bool):
        """Apply conv -> BN -> ReLU; ``train`` selects batch vs. running statistics."""
        conv_out = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            use_bias=self.use_bias,
        )(x)
        normalized = nn.BatchNorm(
            use_running_average=not train,
            momentum=self.bn_decay,
            epsilon=self.bn_eps,
        )(conv_out)
        return nn.relu(normalized)


class ResidualBlock(nn.Module):
    """Residual branch of two 3x3 convolutions with an identity skip.

    The branch is ``ConvBNAct`` followed by a plain conv + batch norm; the
    unmodified input is added to the branch output before the final ReLU, so
    the input must be shape-compatible with a ``features``-channel tensor.

    NOTE: a second class named ``ResidualBlock`` is defined further down in
    this module; after import the name refers to that later definition.

    Attributes:
        features: number of channels produced by both convolutions.
    """

    features: int

    @nn.compact
    def __call__(self, x, *, train: bool = True):
        """Return ``relu(x + branch(x))``; ``train`` toggles batch-norm statistics."""
        # The kernel size is passed explicitly: the ConvBNAct declared directly
        # above this class has no default for ``kernel_size``, so omitting it
        # only works by accident of which ConvBNAct definition is bound.
        h = ConvBNAct(self.features, kernel_size=(3, 3))(x, train=train)
        h = nn.Conv(self.features, (3, 3), padding="SAME", use_bias=False)(h)
        h = nn.BatchNorm(use_running_average=not train, momentum=0.9, epsilon=1e-5)(h)
        return nn.relu(x + h)


# -----------
# MLP
# -----------

@dataclass
class MLPConfig:
    """Plain-data settings for an MLP.

    Field names and defaults mirror the attributes of the ``MLP`` module
    defined in this file.
    """

    width: int = 512           # mirrors MLP.width (size of each hidden Dense layer)
    depth: int = 2             # mirrors MLP.depth (number of hidden layers)
    num_classes: int = 10      # mirrors MLP.num_classes (size of the output layer)
    dropout_rate: float = 0.0  # mirrors MLP.dropout_rate (0.0 means no dropout)


class MLP(nn.Module):
    """Fully connected classifier over flattened inputs.

    Accepts ``(N, H, W, C)`` or ``(N, D)`` inputs; anything with more than two
    dimensions is flattened to ``(N, D)`` first. The network is ``depth``
    repetitions of Dense(``width``) -> ReLU (-> Dropout), followed by a final
    Dense(``num_classes``) that produces the logits.

    Attributes:
        width: units in every hidden layer.
        depth: number of hidden layers.
        num_classes: size of the output layer.
        dropout_rate: dropout probability after each hidden ReLU; dropout
            layers are only created when this is greater than zero.
        is_mlp: marker attribute for downstream code that checks it.
    """

    width: int = 512
    depth: int = 2
    num_classes: int = 10
    dropout_rate: float = 0.0

    # Marker for downstream code that checks this attribute.
    is_mlp: bool = True

    @nn.compact
    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
        """Return logits; dropout is only active when ``train`` is True."""
        # Decide once whether dropout layers are created at all.
        apply_dropout = bool(self.dropout_rate) and self.dropout_rate > 0.0

        hidden = _maybe_flatten(x)
        for _layer in range(self.depth):
            hidden = nn.relu(nn.Dense(self.width)(hidden))
            if apply_dropout:
                dropout = nn.Dropout(rate=self.dropout_rate, deterministic=not train)
                hidden = dropout(hidden)

        logits = nn.Dense(self.num_classes)(hidden)
        return logits


# -----------
# ResNet9
# -----------

@dataclass
class ResNet9Config:
    """Plain-data settings for ResNet9.

    Field names and defaults mirror the attributes of the ``ResNet9`` module
    defined in this file, including the ``norm`` selector.
    """

    num_classes: int = 10    # size of the classifier output
    stem_channels: int = 64  # channels of the first conv; later stages are multiples of it
    # If your inputs are 1-channel (e.g., MNIST), just pass C=1 tensors;
    # the convolutions are channel-agnostic.
    norm: Literal["bn", "gn"] = "bn"  # "bn" = BatchNorm, "gn" = GroupNorm (same default as ResNet9.norm)

def Norm(num_channels: int, kind: Literal["bn", "gn"], *, train: bool):
    """Build the normalisation layer selected by ``kind``.

    Args:
        num_channels: channel count of the tensor to be normalised; only used
            to size the GroupNorm groups.
        kind: ``"bn"`` for ``nn.BatchNorm`` or ``"gn"`` for ``nn.GroupNorm``.
        train: for BatchNorm, batch statistics are used when True and running
            averages when False. Ignored for GroupNorm.

    Returns:
        An unapplied Flax normalisation module.

    Raises:
        ValueError: if ``kind`` is neither ``"bn"`` nor ``"gn"``.
    """
    if kind == "bn":
        return nn.BatchNorm(use_running_average=not train, momentum=0.9, epsilon=1e-5)
    if kind != "gn":
        # Fail loudly instead of silently treating a typo as GroupNorm.
        raise ValueError(f"Unknown norm kind {kind!r}; expected 'bn' or 'gn'.")

    # GroupNorm: aim for about 8 channels per group (a single group when there
    # are fewer than 8 channels).
    groups = max(1, num_channels // 8)
    # nn.GroupNorm rejects a group count that does not divide the channel
    # count, so step down to the nearest divisor. This is a no-op whenever
    # num_channels is a multiple of 8.
    while num_channels % groups:
        groups -= 1
    return nn.GroupNorm(num_groups=groups)

class ConvBNAct(nn.Module):
    """Convolution, then a selectable normalisation layer, then ReLU.

    Attributes:
        features: number of output channels of the convolution.
        kernel_size: convolution kernel size.
        strides: convolution strides.
        padding: padding mode handed to ``nn.Conv``.
        use_bias: whether the convolution has a bias term.
        norm: ``"bn"`` for BatchNorm or ``"gn"`` for GroupNorm (see ``Norm``).
    """

    features: int
    kernel_size: Tuple[int, int] = (3, 3)
    strides: Tuple[int, int] = (1, 1)
    padding: str = "SAME"
    use_bias: bool = False
    norm: Literal["bn","gn"] = "bn"  # which normalisation layer to build

    @nn.compact
    def __call__(self, x, *, train: bool = True):
        """Apply conv -> norm -> ReLU; ``train`` is forwarded to the norm factory."""
        conv_out = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            use_bias=self.use_bias,
        )(x)
        norm_layer = Norm(self.features, self.norm, train=train)
        return nn.relu(norm_layer(conv_out))

class ResidualBlock(nn.Module):
    """Two-convolution residual branch with an identity skip connection.

    The branch is ``ConvBNAct`` followed by a bias-free 3x3 conv and a
    normalisation layer; the unmodified input is added to the branch output
    before the final ReLU.

    Attributes:
        features: number of channels produced by both convolutions.
        norm: ``"bn"`` or ``"gn"``; used for both normalisation layers.
    """

    features: int
    norm: Literal["bn","gn"] = "bn"  # which normalisation layer to build

    @nn.compact
    def __call__(self, x, *, train: bool = True):
        """Return ``relu(x + branch(x))``."""
        branch = ConvBNAct(self.features, norm=self.norm)(x, train=train)
        branch = nn.Conv(
            features=self.features,
            kernel_size=(3, 3),
            padding="SAME",
            use_bias=False,
        )(branch)
        branch = Norm(self.features, self.norm, train=train)(branch)
        skip_sum = x + branch
        return nn.relu(skip_sum)
class ResNet9(nn.Module):
    """ResNet9-style classifier built from ``ConvBNAct`` and ``ResidualBlock``.

    Layout: a stem conv at ``stem_channels``, then three stages at 2x, 4x and
    8x ``stem_channels``. Each stage is conv -> 2x2 max-pool -> residual
    block. The spatial axes are then averaged and a Dense layer produces the
    logits.

    Attributes:
        num_classes: size of the classifier output.
        stem_channels: channels of the stem conv; stage widths are multiples.
        is_mlp: marker attribute, False for this model.
        norm: ``"bn"`` or ``"gn"``; forwarded to every conv/residual block.
    """

    num_classes: int = 10
    stem_channels: int = 64
    is_mlp: bool = False
    norm: Literal["bn","gn"] = "bn"  # which normalisation layer to build

    @nn.compact
    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
        """Return logits; the mean over axes (1, 2) implies a 4-D input."""
        x = ConvBNAct(self.stem_channels, norm=self.norm)(x, train=train)

        # Three identical stages that differ only in width.
        for width_multiplier in (2, 4, 8):
            stage_channels = self.stem_channels * width_multiplier
            x = ConvBNAct(stage_channels, norm=self.norm)(x, train=train)
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID")
            x = ResidualBlock(stage_channels, norm=self.norm)(x, train=train)

        pooled = x.mean(axis=(1, 2))
        return nn.Dense(self.num_classes)(pooled)

# -------------------------
# ResNet50 (bottleneck v1)
# -------------------------
class BottleneckBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus skip.

    The skip path is the identity unless a projection (strided 1x1 conv +
    batch norm) is needed. A projection is used when any of these hold:
    ``use_projection`` is set, the input channel count differs from
    ``out_channels``, or ``stride`` is not 1. The stride case matters because
    an identity skip would keep the input's spatial size while the main branch
    is downsampled, so the final addition would fail or silently broadcast.

    Attributes:
        out_channels: channels after the final 1x1 conv (the block output width).
        stride: spatial stride applied by the 3x3 conv (and by the projection).
        expansion: ratio between ``out_channels`` and the inner 3x3 width; the
            inner width is ``out_channels // expansion``.
        use_projection: force a projection on the skip path.
    """

    out_channels: int      # channels after the final 1x1 (the block output width)
    stride: int = 1        # spatial stride on the 3x3 conv
    expansion: int = 4     # bottleneck expansion
    use_projection: bool = False

    @nn.compact
    def __call__(self, x, *, train: bool = True):
        """Return ``relu(main(x) + skip(x))``; ``train`` toggles batch-norm statistics."""
        in_channels = x.shape[-1]
        mid_channels = self.out_channels // self.expansion

        # The skip must match the main branch in channels and spatial size.
        needs_projection = (
            self.use_projection
            or in_channels != self.out_channels
            or self.stride != 1
        )

        residual = x
        if needs_projection:
            residual = nn.Conv(
                features=self.out_channels,
                kernel_size=(1, 1),
                strides=(self.stride, self.stride),
                use_bias=False,
                padding="SAME",
                name="proj_conv",
            )(x)
            residual = nn.BatchNorm(
                use_running_average=not train, momentum=0.9, epsilon=1e-5, name="proj_bn",
            )(residual)

        # 1x1 reduce
        x = nn.Conv(
            features=mid_channels, kernel_size=(1, 1), strides=(1, 1), use_bias=False, padding="SAME", name="conv1",
        )(x)
        x = nn.BatchNorm(use_running_average=not train, momentum=0.9, epsilon=1e-5, name="bn1")(x)
        x = nn.relu(x)

        # 3x3 (carries the block's stride)
        x = nn.Conv(
            features=mid_channels, kernel_size=(3, 3), strides=(self.stride, self.stride),
            use_bias=False, padding="SAME", name="conv2",
        )(x)
        x = nn.BatchNorm(use_running_average=not train, momentum=0.9, epsilon=1e-5, name="bn2")(x)
        x = nn.relu(x)

        # 1x1 expand (no ReLU here; it is applied after the residual addition)
        x = nn.Conv(
            features=self.out_channels, kernel_size=(1, 1), strides=(1, 1), use_bias=False, padding="SAME", name="conv3",
        )(x)
        x = nn.BatchNorm(use_running_average=not train, momentum=0.9, epsilon=1e-5, name="bn3")(x)

        x = nn.relu(x + residual)
        return x


class ResNetStage(nn.Module):
    """A sequence of ``BottleneckBlock`` modules sharing one output width.

    The first block uses ``first_stride`` and always projects its skip path;
    the remaining ``num_blocks - 1`` blocks use stride 1 without a forced
    projection.

    Attributes:
        out_channels: output channels of every block in the stage.
        num_blocks: total number of blocks (the first block is always built).
        first_stride: stride of the first block.
        expansion: bottleneck expansion forwarded to every block.
    """

    out_channels: int
    num_blocks: int
    first_stride: int
    expansion: int = 4

    @nn.compact
    def __call__(self, x, *, train: bool = True):
        """Run the blocks in order and return the final activation."""
        # (stride, use_projection) per block: only the head may downsample.
        head_spec = [(self.first_stride, True)]
        tail_specs = [(1, False)] * (self.num_blocks - 1)

        for block_stride, force_projection in head_spec + tail_specs:
            block = BottleneckBlock(
                out_channels=self.out_channels,
                stride=block_stride,
                expansion=self.expansion,
                use_projection=force_projection,
            )
            x = block(x, train=train)
        return x


@dataclass
class ResNet50Config:
    """Plain-data settings for ResNet50.

    Field names and defaults mirror the attributes of the ``ResNet50`` module
    defined in this file.
    """

    num_classes: int = 1000

    # Channels of the stem; stages scale as 1x, 2x, 4x, 8x of
    # base_channels * expansion.
    base_channels: int = 64

    # If True, use a CIFAR-friendly stem (3x3 conv, no initial 7x7/2 + maxpool).
    # If False, use the ImageNet stem: 7x7/2 conv + 3x3 maxpool/2.
    small_input: bool = False

    # Bottleneck expansion.
    expansion: int = 4


class ResNet50(nn.Module):
    """Bottleneck ResNet with stage depths (3, 4, 6, 3).

    Attributes:
        num_classes: size of the classifier output.
        base_channels: stem width; pre-expansion stage widths are 1x, 2x, 4x
            and 8x this value.
        small_input: choose the CIFAR-style stem (single 3x3/1 conv) instead
            of the ImageNet-style stem (7x7/2 conv + 3x3/2 max-pool).
        expansion: bottleneck expansion; each stage outputs
            ``width * expansion`` channels.
        is_mlp: marker attribute kept for compatibility, False for this model.
    """

    num_classes: int = 1000
    base_channels: int = 64
    small_input: bool = False
    expansion: int = 4
    is_mlp: bool = False  # compatibility

    @nn.compact
    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
        """Expects NHWC (HWC per-example). Use small_input=True for CIFAR-10."""
        # Both stems are one ConvBNAct; they differ only in kernel and stride.
        if self.small_input:
            # CIFAR-style: no aggressive downsampling, spatial size is kept.
            # (A second 3x3 SAME ConvBNAct could optionally follow this one.)
            stem_kernel, stem_strides = (3, 3), (1, 1)
        else:
            # ImageNet-style: 7x7 stride-2 conv, max-pooled below.
            stem_kernel, stem_strides = (7, 7), (2, 2)

        x = ConvBNAct(
            self.base_channels,
            kernel_size=stem_kernel,
            strides=stem_strides,
            padding="SAME",
        )(x, train=train)
        if not self.small_input:
            x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")

        # (width multiplier, block count, first stride) per stage. Pre-expansion
        # widths are 64, 128, 256, 512 at the default base_channels; the first
        # stage keeps stride 1 for both stems.
        stage_plan = ((1, 3, 1), (2, 4, 2), (4, 6, 2), (8, 3, 2))
        for width_multiplier, block_count, stage_stride in stage_plan:
            stage_width = self.base_channels * width_multiplier
            x = ResNetStage(
                out_channels=stage_width * self.expansion,
                num_blocks=block_count,
                first_stride=stage_stride,
                expansion=self.expansion,
            )(x, train=train)

        pooled = x.mean(axis=(1, 2))  # global average pool
        logits = nn.Dense(self.num_classes)(pooled)
        return logits


# -------------------------
# WideResNet (WRN-16-4)
# -------------------------
class _WRNBasic(nn.Module):
    """Pre-activation basic block: (BN -> ReLU -> 3x3 conv) twice, plus skip.

    Dropout, when enabled, sits between the first convolution and the second
    batch norm. The skip path is the raw block input, passed through a strided
    1x1 conv only when the stride or the channel count changes.

    Attributes:
        out_channels: channels produced by both 3x3 convolutions.
        stride: stride of the first 3x3 conv (and of the skip projection).
        dropout_rate: dropout probability; no dropout layer is created unless
            it is greater than zero.
    """

    out_channels: int
    stride: int = 1
    dropout_rate: float = 0.0

    @nn.compact
    def __call__(self, x, *, train: bool = True):
        """Return ``skip(x) + branch(x)``; no activation follows the sum."""
        input_channels = x.shape[-1]
        stride_2d = (self.stride, self.stride)
        use_dropout = bool(self.dropout_rate) and self.dropout_rate > 0.0

        # First BN -> ReLU -> conv; this conv carries the stride.
        branch = nn.BatchNorm(use_running_average=not train, momentum=0.9, epsilon=1e-5)(x)
        branch = nn.relu(branch)
        branch = nn.Conv(
            self.out_channels, (3, 3), strides=stride_2d, padding="SAME", use_bias=False,
        )(branch)

        if use_dropout:
            branch = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(branch)

        # Second BN -> ReLU -> conv at stride 1.
        branch = nn.BatchNorm(use_running_average=not train, momentum=0.9, epsilon=1e-5)(branch)
        branch = nn.relu(branch)
        branch = nn.Conv(self.out_channels, (3, 3), padding="SAME", use_bias=False)(branch)

        # Project the skip only when its shape would not match the branch.
        shortcut = x
        if input_channels != self.out_channels or self.stride != 1:
            shortcut = nn.Conv(
                self.out_channels, (1, 1), strides=stride_2d, padding="SAME", use_bias=False,
            )(x)
        return shortcut + branch


class _WRNBlock(nn.Module):
    """A group of ``_WRNBasic`` blocks with a common output width.

    Only the first block uses ``stride``; the remaining ``num_layers - 1``
    blocks use stride 1.

    Attributes:
        out_channels: output channels of every block in the group.
        num_layers: total number of blocks (the first block is always built).
        stride: stride of the first block.
        dropout_rate: dropout probability forwarded to every block.
    """

    out_channels: int
    num_layers: int
    stride: int
    dropout_rate: float

    @nn.compact
    def __call__(self, x, *, train: bool = True):
        """Run the blocks in order and return the final activation."""
        # Per-block strides: the head may downsample, the rest never do.
        block_strides = [self.stride] + [1] * (self.num_layers - 1)
        for block_stride in block_strides:
            block = _WRNBasic(
                self.out_channels,
                stride=block_stride,
                dropout_rate=self.dropout_rate,
            )
            x = block(x, train=train)
        return x


@dataclass
class WideResNetConfig:
    """Plain-data settings for WideResNet.

    Field names and defaults mirror the attributes of the ``WideResNet``
    module defined in this file.
    """

    # Total depth, expected to be 6n+4, which gives n = (depth - 4) / 6.
    depth: int = 16
    # Width multiplier applied to the group widths.
    widen_factor: int = 4
    # Size of the classifier output.
    num_classes: int = 10
    # Dropout probability inside each basic block (0.0 means no dropout).
    dropout_rate: float = 0.0
    # Stem width (standard WRN stem).
    base_channels: int = 16


class WideResNet(nn.Module):
    """Wide residual network: 3x3 stem, three ``_WRNBlock`` groups, BN/ReLU head.

    ``depth`` must have the form ``6n + 4``; each group then holds ``n`` basic
    blocks. The groups use strides 1, 2, 2.

    Attributes:
        depth: total depth, must satisfy ``(depth - 4) % 6 == 0``.
        widen_factor: width multiplier applied to the group widths.
        num_classes: size of the classifier output.
        dropout_rate: dropout probability forwarded to every basic block.
        base_channels: stem width.
        is_mlp: marker attribute, False for this model.

    Raises:
        ValueError: when called with a ``depth`` that is not of the form
            ``6n + 4``.
    """

    depth: int = 16
    widen_factor: int = 4
    num_classes: int = 10
    dropout_rate: float = 0.0
    base_channels: int = 16

    is_mlp: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
        """Return logits; the mean over axes (1, 2) implies a 4-D input."""
        # Validate with an exception rather than ``assert`` so the check still
        # runs when Python is started with -O.
        if (self.depth - 4) % 6 != 0:
            raise ValueError("WRN depth must be 6n+4")
        # depth = 6n + 4 => n residual blocks per group
        blocks_per_group = (self.depth - 4) // 6

        # CIFAR-style stem (HWC input)
        x = nn.Conv(self.base_channels, (3, 3), padding="SAME", use_bias=False)(x)

        # (width multiplier, stride) per group.
        # NOTE(review): the group widths are 2x/4x/8x of
        # base_channels * widen_factor, whereas the earlier "C, 2C, 4C"
        # comment and the usual WRN layout suggest 1x/2x/4x. The widths are
        # left unchanged here to preserve parameter shapes; confirm the intent.
        group_plan = ((2, 1), (4, 2), (8, 2))
        for width_multiplier, group_stride in group_plan:
            x = _WRNBlock(
                out_channels=self.base_channels * width_multiplier * self.widen_factor,
                num_layers=blocks_per_group,
                stride=group_stride,
                dropout_rate=self.dropout_rate,
            )(x, train=train)

        # The blocks end on a bare sum, so normalise and activate once more.
        x = nn.BatchNorm(use_running_average=not train, momentum=0.9, epsilon=1e-5)(x)
        x = nn.relu(x)
        x = x.mean(axis=(1, 2))  # GAP
        x = nn.Dense(self.num_classes)(x)
        return x

# -------------------------
# ResNet18 (v1, CIFAR style)
# -------------------------
class _BasicBlockV1(nn.Module):
    """Basic residual block: two 3x3 conv + BN layers, ReLU after the sum.

    The skip path is the identity unless ``use_projection`` is set, in which
    case it is a strided 1x1 conv followed by batch norm. The block itself
    does not check shapes; the caller decides when a projection is needed.

    Attributes:
        out_channels: channels produced by both 3x3 convolutions.
        stride: stride of the first conv (and of the projection, if any).
        use_projection: replace the identity skip with conv + BN.
        bn_momentum: ``momentum`` for every batch-norm layer.
        bn_eps: ``epsilon`` for every batch-norm layer.
    """

    out_channels: int
    stride: int = 1
    use_projection: bool = False
    bn_momentum: float = 0.9
    bn_eps: float = 1e-5

    @nn.compact
    def __call__(self, x, *, train: bool = True):
        """Return ``relu(main(x) + skip(x))``."""
        # Settings shared by the batch-norm layers of this block.
        bn_kwargs = dict(
            use_running_average=not train,
            momentum=self.bn_momentum,
            epsilon=self.bn_eps,
        )
        stride_2d = (self.stride, self.stride)
        shortcut = x

        # Main path: strided conv -> BN -> ReLU -> conv -> BN.
        out = nn.Conv(
            self.out_channels, (3, 3), strides=stride_2d, padding="SAME", use_bias=False,
        )(x)
        out = nn.relu(nn.BatchNorm(**bn_kwargs)(out))
        out = nn.Conv(
            self.out_channels, (3, 3), strides=(1, 1), padding="SAME", use_bias=False,
        )(out)
        out = nn.BatchNorm(**bn_kwargs)(out)

        # Skip path: projected only on request.
        if self.use_projection:
            shortcut = nn.Conv(
                self.out_channels, (1, 1), strides=stride_2d, padding="SAME", use_bias=False,
            )(shortcut)
            shortcut = nn.BatchNorm(**bn_kwargs)(shortcut)

        return nn.relu(out + shortcut)


class _ResNetStageV1(nn.Module):
    """A sequence of ``_BasicBlockV1`` modules sharing one output width.

    The first block uses ``first_stride`` and projects its skip path when the
    stride is not 1 or the input channel count differs from ``out_channels``.
    The remaining ``num_blocks - 1`` blocks keep the shape and use identity
    skips.

    Attributes:
        out_channels: output channels of every block in the stage.
        num_blocks: total number of blocks (the first block is always built).
        first_stride: stride of the first block.
        bn_momentum: batch-norm momentum forwarded to every block.
        bn_eps: batch-norm epsilon forwarded to every block.
    """

    out_channels: int
    num_blocks: int
    first_stride: int
    bn_momentum: float = 0.9
    bn_eps: float = 1e-5

    @nn.compact
    def __call__(self, x, *, train: bool = True):
        """Run the blocks in order and return the final activation."""
        # The head block needs a projection whenever its output shape differs
        # from its input shape.
        head_needs_projection = (
            self.first_stride != 1 or x.shape[-1] != self.out_channels
        )
        block_specs = [(self.first_stride, head_needs_projection)]
        block_specs += [(1, False)] * (self.num_blocks - 1)

        for block_stride, project in block_specs:
            x = _BasicBlockV1(
                out_channels=self.out_channels,
                stride=block_stride,
                use_projection=project,
                bn_momentum=self.bn_momentum,
                bn_eps=self.bn_eps,
            )(x, train=train)
        return x


@dataclass
class ResNet18Config:
    """Plain-data settings for ResNet18.

    Field names and defaults mirror the attributes of the ``ResNet18`` module
    defined in this file.
    """

    # Size of the classifier output.
    num_classes: int = 10
    # Stem width.
    base_channels: int = 64
    # CIFAR use-case: select the small-input stem.
    small_input: bool = True
    # Batch-norm momentum and epsilon shared by all layers.
    bn_momentum: float = 0.9
    bn_eps: float = 1e-5


class ResNet18(nn.Module):
    """Basic-block ResNet with stage depths (2, 2, 2, 2).

    Attributes:
        num_classes: size of the classifier output.
        base_channels: stem width; stage widths are 1x, 2x, 4x and 8x this.
        small_input: choose the CIFAR-style stem (3x3/1 conv) instead of the
            ImageNet-style stem (7x7/2 conv + 3x3/2 max-pool).
        bn_momentum: momentum for every batch-norm layer.
        bn_eps: epsilon for every batch-norm layer.
        is_mlp: marker attribute, False for this model.
    """

    num_classes: int = 10
    base_channels: int = 64
    small_input: bool = True
    bn_momentum: float = 0.9
    bn_eps: float = 1e-5
    is_mlp: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
        """Return logits; the mean over axes (1, 2) implies a 4-D input."""
        # Both stems are conv -> BN -> ReLU; they differ in kernel and stride,
        # and the ImageNet variant adds a max-pool afterwards.
        if self.small_input:
            # CIFAR-style: 3x3 conv at stride 1 keeps the spatial size.
            stem_kernel, stem_strides = (3, 3), (1, 1)
        else:
            # Optional ImageNet-style stem.
            stem_kernel, stem_strides = (7, 7), (2, 2)

        x = nn.Conv(
            self.base_channels, stem_kernel, strides=stem_strides, padding="SAME", use_bias=False,
        )(x)
        x = nn.BatchNorm(
            use_running_average=not train, momentum=self.bn_momentum, epsilon=self.bn_eps,
        )(x)
        x = nn.relu(x)
        if not self.small_input:
            x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")

        # (width multiplier, block count, first stride) per stage. Widths are
        # 64, 128, 256, 512 at the default base_channels; the first stage
        # keeps stride 1.
        stage_plan = ((1, 2, 1), (2, 2, 2), (4, 2, 2), (8, 2, 2))
        for width_multiplier, block_count, stage_stride in stage_plan:
            x = _ResNetStageV1(
                out_channels=self.base_channels * width_multiplier,
                num_blocks=block_count,
                first_stride=stage_stride,
                bn_momentum=self.bn_momentum,
                bn_eps=self.bn_eps,
            )(x, train=train)

        pooled = x.mean(axis=(1, 2))  # global average pool
        logits = nn.Dense(self.num_classes)(pooled)
        return logits



# Public API: each model is exported together with its config dataclass.
__all__ = [
    "MLP", "MLPConfig",
    "ResNet9", "ResNet9Config",
    "ResNet18", "ResNet18Config",
    "ResNet50", "ResNet50Config",
    "WideResNet", "WideResNetConfig",
]
