"""
Gate Modules for AnIsoNet

Gating mechanisms from the Gated DeltaNet architecture.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNormGated(nn.Module):
    """Root-mean-square normalization with an optional SiLU output gate.

    The input is normalized over its last dimension in float32, cast back
    to its original dtype, and scaled by a learnable per-feature ``weight``.
    If a ``gate`` tensor is given, the normalized result is further
    multiplied by ``silu(gate)``. That is, normalization happens first and
    gating second.

    Args:
        hidden_size: Length of the learnable ``weight`` vector. This must
            match (or broadcast against) the last dimension of the input.
        eps: Constant added to the mean square before ``rsqrt`` to avoid
            division by zero.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Initialized to ones, so the module starts as a pure RMS normalizer.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states, gate=None):
        """Normalize ``hidden_states`` and optionally gate the result.

        Args:
            hidden_states: Input tensor, normalized along its last dimension.
            gate: Optional tensor multiplied in as ``silu(gate)`` after
                normalization. It must broadcast against ``hidden_states``.

        Returns:
            The normalized (and possibly gated) tensor. Its dtype follows
            PyTorch type promotion between ``weight`` and the input dtype.
            For example, a float32 ``weight`` with bfloat16 input yields
            float32 output.
        """
        orig_dtype = hidden_states.dtype

        # Compute the statistics in float32 for numerical stability with
        # half-precision inputs.
        x = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = self.weight * (x * inv_rms).to(orig_dtype)

        if gate is None:
            return normed

        # SiLU is evaluated in float32, then cast to the input dtype before
        # the multiply.
        return normed * F.silu(gate.to(torch.float32)).to(orig_dtype)
