# SPDX-License-Identifier: Apache-2.0
"""Minimal distiller MLP utilities for overlap demo."""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


def _layer_norm_batch(batch: Tensor) -> Tensor:
    if batch.dim() <= 1:
        return batch
    normalized_shape = batch.shape[1:]
    if not normalized_shape:
        return batch
    return F.layer_norm(batch, normalized_shape)


class SwiGLUBlock(nn.Module):
    """Pre-norm residual block with a SiLU-gated hidden layer.

    With ``n = norm(x)`` the block returns
    ``x + w_out(silu(w_gate(n)) * w_value(n))``.  The three linear maps
    carry no bias; ``norm`` is a ``LayerNorm`` over the last ``dim`` features.

    Args:
        dim: Feature size of the input and output.
        hidden_dim: Width of the gated hidden representation.
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # Submodule names and creation order are part of the contract:
        # they fix the state_dict keys and the order in which the default
        # initializers draw from the RNG.
        self.norm = nn.LayerNorm(dim)
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w_value = nn.Linear(dim, hidden_dim, bias=False)
        self.w_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward update and add the skip connection."""
        normed = self.norm(x)
        gated = F.silu(self.w_gate(normed)) * self.w_value(normed)
        return self.w_out(gated) + x


class TransitionMLP(nn.Module):
    """Residual SwiGLU stack with optional input projection and output norm.

    The input is first mapped to ``output_dim`` (bias-free linear layer, or
    identity when the sizes already match), then passed through
    ``num_layers`` ``SwiGLUBlock`` layers, and finally through an optional
    output normalization configured via :meth:`set_output_norm`.

    Args:
        input_dim: Size of the input's last dimension.
        output_dim: Working width of the blocks and size of the output's
            last dimension.
        num_layers: Number of ``SwiGLUBlock`` layers.
        mlp_hidden_dim: Hidden width of each block.  ``None`` or a
            non-positive value selects ``min(output_dim // 4, 384)``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_layers: int = 4,
        mlp_hidden_dim: Optional[int] = None,
    ):
        super().__init__()

        if input_dim != output_dim:
            self.project_in = nn.Linear(input_dim, output_dim, bias=False)
        else:
            self.project_in = nn.Identity()

        current_dim = output_dim
        if mlp_hidden_dim is None or mlp_hidden_dim <= 0:
            hidden_dim = min(current_dim // 4, 384)
        else:
            hidden_dim = int(mlp_hidden_dim)

        self.layers = nn.ModuleList(
            [SwiGLUBlock(dim=current_dim, hidden_dim=hidden_dim) for _ in range(num_layers)]
        )
        # Optional output norm (set by overlap runtime).
        # Registered as buffers so they follow the module across devices;
        # while they are None no output norm is applied.
        self.register_buffer("output_norm_weight", None)
        self.register_buffer("output_norm_bias", None)
        self.output_norm_eps = 1e-6
        self.output_norm_type = "rms"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project, run all blocks, then apply the output norm if one is set."""
        x = self.project_in(x)
        for layer in self.layers:
            x = layer(x)
        return self._apply_output_norm(x)

    def set_output_norm(
        self,
        *,
        weight: Optional[torch.Tensor],
        bias: Optional[torch.Tensor],
        eps: float,
        norm_type: str,
    ) -> None:
        """Install, replace, or remove the output normalization.

        Args:
            weight: Per-feature scale.  ``None`` disables the output norm
                (any ``bias`` is then discarded as well).
            bias: Optional per-feature shift added after scaling.
            eps: Epsilon added to the variance / mean square.
            norm_type: ``"layer"`` selects layer norm; any other value is
                treated as RMS norm by the forward pass.

        The tensors are stored as detached float32 copies, so later changes
        to the caller's tensors do not affect this module.
        """
        self.output_norm_eps = float(eps)
        self.output_norm_type = str(norm_type)
        if weight is None:
            self.output_norm_weight = None
            self.output_norm_bias = None
            return
        self.output_norm_weight = weight.detach().clone().to(torch.float32)
        self.output_norm_bias = (
            None if bias is None else bias.detach().clone().to(torch.float32)
        )

    def _apply_output_norm(self, hidden: torch.Tensor) -> torch.Tensor:
        """Normalize ``hidden`` over its last dimension with the stored parameters.

        Returns ``hidden`` unchanged when no weight is installed.  The stored
        float32 parameters are cast to ``hidden``'s device and dtype so the
        result keeps the dtype of the activations.
        """
        weight = self.output_norm_weight
        if weight is None:
            return hidden
        eps = self.output_norm_eps
        weight = weight.to(device=hidden.device, dtype=hidden.dtype)
        bias = self.output_norm_bias
        if bias is not None:
            bias = bias.to(device=hidden.device, dtype=hidden.dtype)

        feature_shape = (hidden.shape[-1],)
        if self.output_norm_type == "layer":
            return F.layer_norm(hidden, feature_shape, weight=weight, bias=bias, eps=eps)

        # Everything that is not "layer" is RMS-normalized.  The fused kernel
        # is only used for the exact "rms" tag and only when this torch build
        # provides it; otherwise fall back to the explicit formula.
        if self.output_norm_type == "rms" and hasattr(F, "rms_norm"):
            out = F.rms_norm(hidden, feature_shape, weight=weight, eps=eps)
        else:
            mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
            out = hidden * torch.rsqrt(mean_square + eps)
            out = out * weight
        if bias is not None:
            out = out + bias
        return out


class BatchedTransitionMLP(nn.Module):
    """Grouped SwiGLU MLP holding one weight set per group in stacked tensors.

    Every parameter carries a leading group axis of size ``num_groups``.
    The forward pass picks the weights for each batch entry via
    ``group_indices`` and evaluates all entries with batched matmuls.
    No output normalization is applied by this module.

    Args:
        num_groups: Number of independent weight sets; must be positive.
        input_dim: Size of the input's last dimension.
        output_dim: Working width of the layers and size of the output's
            last dimension.
        num_layers: Number of gated residual layers.
        mlp_hidden_dim: Hidden width of each layer.  ``None`` or a
            non-positive value selects ``min(output_dim // 4, 384)``.

    Raises:
        ValueError: If ``num_groups`` is not positive.
    """

    def __init__(
        self,
        num_groups: int,
        input_dim: int,
        output_dim: int,
        num_layers: int = 4,
        mlp_hidden_dim: Optional[int] = None,
    ):
        super().__init__()
        if num_groups <= 0:
            raise ValueError("num_groups must be > 0")
        self.num_groups = int(num_groups)
        self.input_dim = int(input_dim)
        self.output_dim = int(output_dim)
        self.num_layers = int(num_layers)
        if mlp_hidden_dim is None or mlp_hidden_dim <= 0:
            hidden_dim = min(self.output_dim // 4, 384)
        else:
            hidden_dim = int(mlp_hidden_dim)
        self.hidden_dim = hidden_dim

        # Input projection is only materialized when the widths differ.
        if self.input_dim != self.output_dim:
            self.project_in_weight = nn.Parameter(
                torch.empty(self.num_groups, self.output_dim, self.input_dim)
            )
        else:
            self.project_in_weight = None

        self.norm_weight = nn.ParameterList(
            [nn.Parameter(torch.empty(self.num_groups, self.output_dim)) for _ in range(self.num_layers)]
        )
        self.norm_bias = nn.ParameterList(
            [nn.Parameter(torch.empty(self.num_groups, self.output_dim)) for _ in range(self.num_layers)]
        )
        self.w_gate = nn.ParameterList(
            [
                nn.Parameter(torch.empty(self.num_groups, self.hidden_dim, self.output_dim))
                for _ in range(self.num_layers)
            ]
        )
        self.w_value = nn.ParameterList(
            [
                nn.Parameter(torch.empty(self.num_groups, self.hidden_dim, self.output_dim))
                for _ in range(self.num_layers)
            ]
        )
        self.w_out = nn.ParameterList(
            [
                nn.Parameter(torch.empty(self.num_groups, self.output_dim, self.hidden_dim))
                for _ in range(self.num_layers)
            ]
        )
        # NOTE(review): nn.LayerNorm defaults to eps=1e-5; confirm 1e-6 is
        # intended if these weights are exchanged with LayerNorm-based blocks.
        self.norm_eps = 1e-6
        self.reset_parameters()

    @staticmethod
    def _init_grouped_linear_(weight: Tensor) -> None:
        """Initialize a ``[G, out, in]`` weight one group slice at a time.

        ``kaiming_uniform_`` derives fan-in from the tensor shape and reads a
        3D tensor as a convolution kernel (fan_in = size(1) * size(2)), which
        would shrink the bound to ``1 / sqrt(out * in)``.  Initializing each
        ``[out, in]`` slice separately yields the ``1 / sqrt(in)`` bound that
        a stand-alone bias-free ``nn.Linear`` gets.
        """
        with torch.no_grad():
            for group_idx in range(weight.size(0)):
                nn.init.kaiming_uniform_(weight[group_idx], a=5**0.5)

    def reset_parameters(self) -> None:
        """Re-initialize all parameters: unit norms, per-group Kaiming linears."""
        if self.project_in_weight is not None:
            self._init_grouped_linear_(self.project_in_weight)
        for norm_w, norm_b in zip(self.norm_weight, self.norm_bias):
            nn.init.ones_(norm_w)
            nn.init.zeros_(norm_b)
        for gate_w, value_w, out_w in zip(self.w_gate, self.w_value, self.w_out):
            self._init_grouped_linear_(gate_w)
            self._init_grouped_linear_(value_w)
            self._init_grouped_linear_(out_w)

    def forward(self, x: Tensor, group_indices: Optional[Tensor] = None) -> Tensor:
        """Run each batch entry through the weights of its assigned group.

        Args:
            x: 3D tensor whose first axis has one entry per element of
                ``group_indices`` and whose last axis has ``input_dim``
                features.
            group_indices: 1D index tensor selecting the group for each
                entry of ``x``.  Defaults to ``arange(num_groups)``, i.e.
                entry ``i`` uses group ``i``.

        Returns:
            Tensor shaped like ``x`` but with ``output_dim`` features.

        Raises:
            ValueError: If ``x`` is not 3D, ``group_indices`` is not 1D, or
                their leading sizes disagree.
        """
        if x.dim() != 3:
            raise ValueError(f"BatchedTransitionMLP expects 3D input, got {x.shape}")
        if group_indices is None:
            group_indices = torch.arange(self.num_groups, device=x.device, dtype=torch.long)
        if group_indices.dim() != 1:
            raise ValueError("group_indices must be 1D")
        if x.size(0) != group_indices.numel():
            raise ValueError(
                f"Input batch size {x.size(0)} must match group_indices {group_indices.numel()}"
            )
        if self.project_in_weight is not None:
            w_proj = self.project_in_weight.index_select(0, group_indices)
            x = torch.bmm(x, w_proj.transpose(1, 2))
        out = x
        per_layer = zip(self.norm_weight, self.norm_bias, self.w_gate, self.w_value, self.w_out)
        for norm_w_all, norm_b_all, gate_all, value_all, out_all in per_layer:
            # Gather this layer's weights for the requested groups.
            norm_w = norm_w_all.index_select(0, group_indices)
            norm_b = norm_b_all.index_select(0, group_indices)
            w_gate = gate_all.index_select(0, group_indices)
            w_value = value_all.index_select(0, group_indices)
            w_out = out_all.index_select(0, group_indices)
            residual = out
            x_norm = _batched_layer_norm(out, weight=norm_w, bias=norm_b, eps=self.norm_eps)
            gate = F.silu(torch.bmm(x_norm, w_gate.transpose(1, 2)))
            value = torch.bmm(x_norm, w_value.transpose(1, 2))
            hidden = gate * value
            out = torch.bmm(hidden, w_out.transpose(1, 2))
            out = out + residual
        return out


def _batched_layer_norm(
    x: Tensor,
    *,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> Tensor:
    # Per-group layer norm over the last dimension.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    out = (x - mean) * torch.rsqrt(var + eps)
    if weight is not None:
        out = out * weight.unsqueeze(1)
    if bias is not None:
        out = out + bias.unsqueeze(1)
    return out


def _batched_rms_norm(
    x: Tensor,
    *,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> Tensor:
    var = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(var + eps)
    if weight is not None:
        out = out * weight.unsqueeze(1)
    if bias is not None:
        out = out + bias.unsqueeze(1)
    return out


def stack_transition_weights(predictors: list[TransitionMLP]) -> dict:
    """Stack the weights of several predictors along a new leading axis.

    The first predictor decides the layout: whether an input projection
    exists, how many layers there are, and whether (and which kind of)
    output norm is stacked.  The remaining predictors are read with the
    same layout and are not checked for agreement.

    Returns:
        A dict with
        ``"proj"``: stacked projection weights, or ``None`` when the first
        predictor's projection is ``nn.Identity``;
        ``"layers"``: one tuple per layer of
        ``(norm_weight, norm_bias, w_gate, w_value, w_out, eps)`` where
        ``eps`` comes from the first predictor's layer norm;
        ``"out_norm"``: ``(weight, bias_or_None, eps, norm_type)`` taken
        from the output-norm settings, or ``None`` when the first predictor
        has no output-norm weight.
        An empty ``predictors`` yields all-empty entries.
    """
    if not predictors:
        return {"proj": None, "layers": [], "out_norm": None}

    lead = predictors[0]

    def _stack(tensors):
        # torch.stack needs a concrete sequence, not a generator.
        return torch.stack(list(tensors), dim=0)

    if isinstance(lead.project_in, nn.Identity):
        proj_weight = None
    else:
        proj_weight = _stack(p.project_in.weight for p in predictors)

    layers = []
    for depth, lead_block in enumerate(lead.layers):
        blocks = [p.layers[depth] for p in predictors]
        layers.append(
            (
                _stack(blk.norm.weight for blk in blocks),
                _stack(blk.norm.bias for blk in blocks),
                _stack(blk.w_gate.weight for blk in blocks),
                _stack(blk.w_value.weight for blk in blocks),
                _stack(blk.w_out.weight for blk in blocks),
                float(lead_block.norm.eps),
            )
        )

    if lead.output_norm_weight is None:
        out_norm = None
    else:
        scale = _stack(p.output_norm_weight for p in predictors)
        if lead.output_norm_bias is None:
            shift = None
        else:
            shift = _stack(p.output_norm_bias for p in predictors)
        out_norm = (scale, shift, float(lead.output_norm_eps), str(lead.output_norm_type))

    return {"proj": proj_weight, "layers": layers, "out_norm": out_norm}


def batched_transition_forward_stacked(x: Tensor, stacked: dict) -> Tensor:
    """Evaluate pre-stacked transition weights on a 3D input with batched matmuls.

    Args:
        x: 3D input; its leading axis is paired with the leading axis of
            every stacked weight.
        stacked: Dict with optional keys ``"proj"`` (stacked projection
            weight, skipped unless it is a tensor), ``"layers"`` (iterable of
            ``(norm_w, norm_b, w_gate, w_value, w_out, eps)`` tuples) and
            ``"out_norm"`` (``(weight, bias, eps, norm_type)`` or ``None``).

    Returns:
        The transformed tensor.  With no projection, layers, or output norm
        the input is returned as is.

    Raises:
        ValueError: If ``x`` is not 3D.
    """
    if x.dim() != 3:
        raise ValueError(f"batched_transition_forward expects 3D input, got {x.shape}")

    proj = stacked.get("proj")
    state = torch.bmm(x, proj.transpose(1, 2)) if isinstance(proj, torch.Tensor) else x

    # Each layer: pre-norm, SiLU-gated hidden product, output map, skip add.
    for ln_scale, ln_shift, gate_w, value_w, out_w, ln_eps in stacked.get("layers", []):
        normed = _batched_layer_norm(state, weight=ln_scale, bias=ln_shift, eps=ln_eps)
        gated = F.silu(torch.bmm(normed, gate_w.transpose(1, 2))) * torch.bmm(
            normed, value_w.transpose(1, 2)
        )
        state = torch.bmm(gated, out_w.transpose(1, 2)) + state

    final_norm = stacked.get("out_norm")
    if final_norm is None:
        return state

    scale, shift, norm_eps, norm_kind = final_norm
    # "layer" selects layer norm; every other tag is RMS-normalized.
    if norm_kind == "layer":
        return _batched_layer_norm(state, weight=scale, bias=shift, eps=norm_eps)
    return _batched_rms_norm(state, weight=scale, bias=shift, eps=norm_eps)


def batched_transition_forward(x: Tensor, predictors: list[TransitionMLP]) -> Tensor:
    """Run several TransitionMLP instances in one batched pass.

    The predictors' weights are stacked on every call and then applied to
    ``x`` group by group.

    Args:
        x: [G, N, D] input tensor.
        predictors: list of TransitionMLP, length G.

    Returns:
        The batched output, or ``x`` itself when ``predictors`` is empty.

    Raises:
        ValueError: If the leading size of ``x`` differs from the number of
            predictors.
    """
    if predictors:
        expected_groups = len(predictors)
        actual_groups = x.size(0)
        if actual_groups != expected_groups:
            raise ValueError(
                f"batched_transition_forward expects G={expected_groups} in input, got {actual_groups}"
            )
        return batched_transition_forward_stacked(x, stack_transition_weights(predictors))
    return x
