from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedFeedForward(nn.Module):
    """Run a wrapped feed-forward block with masks applied to its two linear layers.

    The wrapped module is expected to expose ``dense_1`` and ``dense_2``
    attributes, each providing ``weight`` and ``bias``. If the wrapped module
    also has ``intermediate_act_fn``, ``dropout`` or ``LayerNorm`` attributes,
    they are used; otherwise the activation falls back to ``F.gelu`` and the
    dropout / normalization steps are skipped.

    The mask attributes start as ``None`` and are assigned by the caller after
    construction. Each mask is multiplied with the matching parameter, so it
    must be broadcastable against that parameter.
    """

    def __init__(self, original_ffn: nn.Module):
        super().__init__()
        self.ffn = original_ffn
        # Weight masks are mandatory before forward(); bias masks are optional.
        self.m_w1: Optional[torch.Tensor] = None
        self.m_b1: Optional[torch.Tensor] = None
        self.m_w2: Optional[torch.Tensor] = None
        self.m_b2: Optional[torch.Tensor] = None

    @staticmethod
    def _masked_linear(
        x: torch.Tensor,
        layer: nn.Module,
        weight_mask: torch.Tensor,
        bias_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Apply ``layer`` to ``x`` with its weight (and optionally bias) masked."""
        weight = layer.weight * weight_mask
        bias = layer.bias
        # The bias is only masked when both the bias and its mask exist.
        if bias is not None and bias_mask is not None:
            bias = bias * bias_mask
        return F.linear(x, weight, bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the masked feed-forward output for ``x``.

        Steps: masked ``dense_1`` -> activation -> masked ``dense_2`` ->
        optional dropout -> residual add with the input -> optional LayerNorm.

        Raises:
            AssertionError: if ``m_w1`` or ``m_w2`` has not been set.
        """
        if self.m_w1 is None or self.m_w2 is None:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type and message are
            # kept so existing callers see the same failure.
            raise AssertionError("Mask tensors must be set before forward()")
        residual = x

        x = self._masked_linear(x, self.ffn.dense_1, self.m_w1, self.m_b1)
        act = getattr(self.ffn, "intermediate_act_fn", None)
        x = act(x) if act is not None else F.gelu(x)

        x = self._masked_linear(x, self.ffn.dense_2, self.m_w2, self.m_b2)
        drop = getattr(self.ffn, "dropout", None)
        if drop is not None:
            x = drop(x)
        x = x + residual
        ln = getattr(self.ffn, "LayerNorm", None)
        if ln is not None:
            x = ln(x)
        return x


class MaskedLinear(nn.Module):
    """Linear layer wrapper that can mask the wrapped layer's weight and bias.

    Both masks start as ``None``. A ``None`` weight mask leaves the weight
    untouched; the bias is masked only when the wrapped layer has a bias and a
    bias mask has been assigned. Masks are multiplied with the parameters, so
    they must be broadcastable against them.
    """

    def __init__(self, base: nn.Linear):
        super().__init__()
        self.base = base
        # Optional masks, assigned by the caller after construction.
        self.m_w: Optional[torch.Tensor] = None
        self.m_b: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped layer to ``x`` using the masked parameters."""
        weight = self.base.weight
        if self.m_w is not None:
            weight = weight * self.m_w

        bias = self.base.bias
        if bias is None or self.m_b is None:
            # Nothing to mask: pass the bias (possibly None) through as-is.
            return F.linear(x, weight, bias)
        return F.linear(x, weight, bias * self.m_b)
