# forward_forward/models/layers/normalizer.py

import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, List

class Normalizer(nn.Module):
    """
    Normalize input activities using the specified mode and dimensions.
    Optionally normalizes each channel independently before a global normalization.

    Args:
        mode (str): Normalization mode, one of ["l1", "l2", "std"] (case-insensitive).
            "std" divides by the (unbiased, torch default) standard deviation,
            "l1" by the sum of absolute values and "l2" by the Euclidean norm,
            each computed over the reduced dimensions.
        zero_mean (bool): Whether to subtract the mean (over the same dimensions)
            before dividing.
        eps (float): Lower bound applied to the divisor to avoid division by zero.
        norm_dims (Optional[Union[int, Tuple[int, ...], List[int]]]): Dimensions to
            normalize over in the global step. If None, normalize jointly over all
            dimensions except the batch dimension 0 (original behavior).
        channel_dims (Optional[Union[int, List[int]]]): Dimensions treated as channels
            for the channel-wise step. Defaults to dim=1 (standard channel dim in PyTorch).
        normalize_channels_first (bool): If True, first normalize each
            (sample, channel) slice independently over the remaining
            non-batch, non-channel dimensions, then apply the global normalization.
            Defaults to False.

    Raises:
        ValueError: If ``mode`` is not a supported mode.

    Note:
        With ``mode="std"``, reducing over a single element yields NaN, because the
        unbiased standard deviation is undefined there.
    """

    _MODES = ("l1", "l2", "std")

    def __init__(
        self,
        mode: str = "std",
        zero_mean: bool = True,
        eps: float = 1e-6,
        norm_dims: Optional[Union[int, Tuple[int, ...], List[int]]] = None,
        channel_dims: Optional[Union[int, List[int]]] = None,
        normalize_channels_first: bool = False,
    ):
        super().__init__()
        mode = mode.lower()
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.mode = mode
        self.zero_mean = zero_mean
        self.eps = eps
        # Kept as given (None / int / tuple / list); converted to a tuple in forward().
        self.norm_dims = norm_dims
        self.channel_dims = (1,) if channel_dims is None else (
            (channel_dims,) if isinstance(channel_dims, int) else tuple(channel_dims)
        )
        self.normalize_channels_first = normalize_channels_first

    def _normalize(self, x: torch.Tensor, dims: Tuple[int, ...]) -> torch.Tensor:
        """Normalize ``x`` over ``dims`` according to ``self.mode`` / ``self.zero_mean``."""
        if self.zero_mean:
            x = x - x.mean(dim=dims, keepdim=True)

        if self.mode == "std":
            norm = x.std(dim=dims, keepdim=True)
        elif self.mode == "l2":
            norm = x.pow(2).sum(dim=dims, keepdim=True).sqrt()
        else:  # "l1" (validated in __init__)
            norm = x.abs().sum(dim=dims, keepdim=True)

        # Clamp the divisor so near-constant / all-zero slices do not divide by zero.
        return x / norm.clamp(min=self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Step 1 (optional): per-(sample, channel) normalization over the remaining dims.
        if self.normalize_channels_first and x.dim() > 1:
            channel_set = {d % x.dim() for d in self.channel_dims}
            per_channel_dims = tuple(
                d for d in range(1, x.dim()) if d not in channel_set
            )
            # With no dims left to reduce (e.g. a (batch, channels) input) the step is skipped.
            if per_channel_dims:
                x = self._normalize(x, per_channel_dims)

        # Step 2: global normalization.
        if self.norm_dims is None:
            if x.dim() < 2:
                raise ValueError(
                    "Normalizer with norm_dims=None expects an input with a batch "
                    f"dimension and at least one feature dimension, got shape {tuple(x.shape)}"
                )
            # Reducing jointly over all non-batch dims is equivalent to flattening them,
            # but needs no view/reshape, so non-contiguous inputs work too.
            dims = tuple(range(1, x.dim()))
        elif isinstance(self.norm_dims, int):
            dims = (self.norm_dims,)
        else:
            dims = tuple(self.norm_dims)

        return self._normalize(x, dims)

    def extra_repr(self) -> str:
        return (
            f"mode={self.mode!r}, zero_mean={self.zero_mean}, eps={self.eps}, "
            f"norm_dims={self.norm_dims}, channel_dims={self.channel_dims}, "
            f"normalize_channels_first={self.normalize_channels_first}"
        )
