from __future__ import annotations
import torch
from .spec import GateSpec
from .masked import MaskedFeedForward, MaskedLinear


def _outer_mask(row_gate: torch.Tensor, col_gate: torch.Tensor) -> torch.Tensor:
    """Return the outer product of *row_gate* and *col_gate*.

    Entry ``[i, j]`` of the result is ``row_gate[i] * col_gate[j]``, so the
    output has one row per row-gate element and one column per column-gate
    element. Both inputs must be 1-D, as required by the outer product.
    """
    mask = row_gate.outer(col_gate)
    return mask


def apply_ffn_rowcol_gates(
    gates_1d: torch.Tensor, spec: GateSpec, wrappers: list[MaskedFeedForward]
):
    """Install row/column gate masks on the feed-forward wrappers.

    For every entry in ``spec.mats`` the row gate and column gate are sliced
    out of ``gates_1d`` using the entry's ``r_start``/``r_len`` and
    ``c_start``/``c_len``; their outer product becomes the weight mask of the
    wrapper at ``wrappers[layer_idx]``. The row gate itself is stored as the
    bias mask, but only when the corresponding dense layer has a bias.

    Args:
        gates_1d: Flat gate tensor indexed by the offsets in ``spec``
            (presumably 1-D, since the slices feed an outer product).
        spec: Gate layout; only ``spec.mats`` is read here.
        wrappers: Wrappers indexed by each entry's ``layer_idx``.

    Raises:
        ValueError: If an entry's ``kind`` is neither ``"dense1"`` nor
            ``"dense2"``.
    """
    for mat in spec.mats:
        row_span = slice(mat.r_start, mat.r_start + mat.r_len)
        col_span = slice(mat.c_start, mat.c_start + mat.c_len)
        row_gate = gates_1d[row_span]
        col_gate = gates_1d[col_span]
        weight_mask = _outer_mask(row_gate, col_gate)
        target = wrappers[mat.layer_idx]

        if mat.kind == "dense1":
            target.m_w1 = weight_mask
            # The bias mask follows the row gate (one entry per mask row).
            if target.ffn.dense_1.bias is not None:
                target.m_b1 = row_gate
            continue

        if mat.kind == "dense2":
            target.m_w2 = weight_mask
            if target.ffn.dense_2.bias is not None:
                target.m_b2 = row_gate
            continue

        raise ValueError(f"Unknown FFN gate kind {mat.kind}")


def apply_attn_rowcol_gates(
    gates_1d: torch.Tensor, spec: GateSpec, wrappers: list[MaskedLinear]
):
    """Install row/column gate masks on the masked linear wrappers.

    For every entry in ``spec.mats`` the row gate and column gate are sliced
    out of ``gates_1d`` using the entry's ``r_start``/``r_len`` and
    ``c_start``/``c_len``. Their outer product is assigned to ``m_w`` on the
    wrapper at ``wrappers[layer_idx]``; the row gate is assigned to ``m_b``
    only when the wrapped ``base`` layer has a bias.

    Args:
        gates_1d: Flat gate tensor indexed by the offsets in ``spec``
            (presumably 1-D, since the slices feed an outer product).
        spec: Gate layout; only ``spec.mats`` is read here.
        wrappers: Wrappers indexed by each entry's ``layer_idx``.
    """
    for mat in spec.mats:
        row_span = slice(mat.r_start, mat.r_start + mat.r_len)
        col_span = slice(mat.c_start, mat.c_start + mat.c_len)
        row_gate = gates_1d[row_span]
        col_gate = gates_1d[col_span]

        target = wrappers[mat.layer_idx]
        target.m_w = _outer_mask(row_gate, col_gate)

        has_bias = target.base.bias is not None
        if has_bias:
            # The bias mask follows the row gate (one entry per mask row).
            target.m_b = row_gate
