from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def _broadcast_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Cast ``mask`` to ``x``'s dtype/device and give it three dimensions.

    ``x`` is consulted only for its dtype and device. The mask is reshaped by
    rank:

    * 1-D ``(n,)``    -> ``(1, 1, n)``
    * 2-D ``(a, n)``  -> ``(a, 1, n)``
    * 3-D             -> returned as-is (after the dtype/device cast)

    Raises:
        ValueError: If the mask has any other number of dimensions.
    """
    aligned = mask.to(dtype=x.dtype, device=x.device)
    rank = aligned.dim()

    if rank == 3:
        return aligned
    if rank == 2:
        # Insert a singleton axis in the middle position.
        return aligned.unsqueeze(1)
    if rank == 1:
        # Two leading singleton axes; the mask occupies the last axis.
        return aligned.view(1, 1, -1)

    raise ValueError(f"Unsupported mask shape: {aligned.shape}")


class MaskableLlamaMLP(nn.Module):
    """MLP wrapper that can multiply a mask into the intermediate activations.

    The wrapper reuses the sub-modules of an existing MLP rather than copying
    them: ``up_proj`` and ``down_proj`` are required, while ``gate_proj``,
    ``act_fn`` (falls back to ``F.silu``) and ``dropout`` are picked up only
    when the wrapped module has them.

    With no mask set the forward pass is:

    * gated   (``gate_proj`` present): ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``
    * ungated (``gate_proj`` absent):  ``down_proj(act_fn(up_proj(x)))``

    followed by ``dropout`` when one is present.
    """

    def __init__(self, mlp: nn.Module):
        """Adopt the projections, activation and dropout of ``mlp``."""
        super().__init__()
        self.up_proj = mlp.up_proj
        self.gate_proj = getattr(mlp, "gate_proj", None)
        self.down_proj = mlp.down_proj
        self.act_fn = getattr(mlp, "act_fn", F.silu)
        self.dropout = getattr(mlp, "dropout", None)
        # Held as a plain attribute; it is cast to the activation's dtype and
        # device on every forward pass.
        self.intermediate_mask: Optional[torch.Tensor] = None

    def set_mask(self, mask: Optional[torch.Tensor]) -> None:
        """Install ``mask`` for subsequent forward passes; ``None`` disables masking."""
        self.intermediate_mask = mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP, applying ``intermediate_mask`` when one is set.

        In the gated path the mask multiplies the ``up_proj`` output before it
        is combined with the activated gate. In the ungated path the mask
        multiplies the activated ``up_proj`` output.
        """
        mask = self.intermediate_mask

        if self.gate_proj is None:
            # Ungated: up_proj is evaluated exactly once. (It used to be run a
            # second time and the result thrown away.)
            hidden = self.act_fn(self.up_proj(x))
            if mask is not None:
                hidden = hidden * _broadcast_mask(hidden, mask)
        else:
            gate = self.act_fn(self.gate_proj(x))
            up = self.up_proj(x)
            if mask is not None:
                up = up * _broadcast_mask(up, mask)
            hidden = gate * up

        down = self.down_proj(hidden)
        if self.dropout is not None:
            down = self.dropout(down)
        return down


def _resolve_llama_layers(model: nn.Module):
    """Return the decoder layers regardless of PEFT wrappers (LLaMA/Qwen-style).

    Starting at ``model`` and then following ``base_model`` links, each object
    is probed in this order and the first match is returned:

    1. ``obj.model.layers``
    2. ``obj.model.decoder.layers``
    3. ``obj.layers``
    4. ``obj.decoder.layers``
    5. ``obj.transformer.h``

    Raises:
        AttributeError: If none of the visited objects exposes one of these
            attributes, including when the ``base_model`` chain leads back to
            an object that was already visited.
    """
    # A ``base_model`` attribute can resolve to an object already inspected
    # (Hugging Face's ``PreTrainedModel.base_model`` falls back to the model
    # itself). Track visited objects by identity so the walk terminates with
    # the AttributeError below instead of spinning forever.
    visited = []
    current = model
    while current is not None and not any(current is seen for seen in visited):
        visited.append(current)

        inner = getattr(current, "model", None)
        if inner is not None:
            if hasattr(inner, "layers"):
                return inner.layers
            decoder = getattr(inner, "decoder", None)
            if decoder is not None and hasattr(decoder, "layers"):
                return decoder.layers

        if hasattr(current, "layers"):
            return current.layers
        decoder = getattr(current, "decoder", None)
        if decoder is not None and hasattr(decoder, "layers"):
            return decoder.layers

        transformer = getattr(current, "transformer", None)
        if transformer is not None and hasattr(transformer, "h"):
            return transformer.h

        current = getattr(current, "base_model", None)
    raise AttributeError("Unable to locate transformer layers on the provided model.")


def wrap_model_with_maskable_ffn(model: nn.Module) -> None:
    """Replace each resolved layer's ``mlp`` with a ``MaskableLlamaMLP`` in place.

    Layers whose ``mlp`` is already a ``MaskableLlamaMLP`` are left untouched,
    so calling this more than once does not nest wrappers.
    """
    for block in _resolve_llama_layers(model):
        ffn = block.mlp
        if isinstance(ffn, MaskableLlamaMLP):
            continue
        block.mlp = MaskableLlamaMLP(ffn)


def set_layer_mask(model: nn.Module, layer_idx: int, mask: Optional[torch.Tensor]) -> None:
    """Call ``set_mask(mask)`` on the ``mlp`` of the layer at ``layer_idx``.

    The layer container is located with ``_resolve_llama_layers``; passing
    ``None`` as ``mask`` forwards ``None`` to the layer's ``set_mask``.
    """
    target_mlp = _resolve_llama_layers(model)[layer_idx].mlp
    target_mlp.set_mask(mask)


def clear_all_masks(model: nn.Module) -> None:
    """Call ``set_mask(None)`` on the ``mlp`` of every resolved layer."""
    for block in _resolve_llama_layers(model):
        block.mlp.set_mask(None)


def apply_inference_mask(mlp: MaskableLlamaMLP, mask: torch.Tensor) -> None:
    """Bake ``mask`` into ``mlp.up_proj`` by scaling its weight rows and bias.

    The mask is cast to the dtype and device of ``up_proj.weight`` and
    flattened to 1-D. Row ``i`` of the weight is multiplied by ``mask[i]`` and
    the result is stored back as the weight's ``.data``. When ``up_proj`` has a
    bias, its entries are scaled in place by the same factors. No other
    projection is modified.
    """
    proj = mlp.up_proj
    weight = proj.weight

    row_scale = mask.to(dtype=weight.dtype, device=weight.device).view(-1)

    # Trailing singleton axis so each factor spans one whole weight row.
    weight.data = weight.data * row_scale.unsqueeze(-1)

    bias = proj.bias
    if bias is not None:
        bias.data *= row_scale
