from __future__ import annotations

import copy

import torch
import torch.nn as nn
from torch import Tensor


class ACETransformerEncoderLayer(nn.TransformerEncoderLayer):
    """
    Drop-in replacement for TransformerEncoderLayer where the self-attention
    block performs cross-attention from the full sequence to the context
    subsequence. It infers the number of context tokens ``Nc`` from the
    provided attention mask (zeros in the first row denote context columns).

    Expected mask semantics (matching TNP non-AR mask):
    - mask shape [L, L] float, zeros in [:, :Nc], -inf elsewhere. A boolean
      mask (False in [:, :Nc], True elsewhere) is counted the same way.
    - the ``Nc`` context tokens occupy the first ``Nc`` sequence positions.

    With such a mask every query sees exactly the first ``Nc`` positions, so
    attending to those keys directly matches masked full self-attention while
    scoring only L x Nc query/key pairs instead of L x L.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        dropout: float,
        *,
        batch_first: bool = True,
        **kwargs,
    ):
        # Extra keyword arguments (e.g. ``activation``, ``layer_norm_eps``,
        # ``norm_first``) are forwarded unchanged to nn.TransformerEncoderLayer.
        super().__init__(d_model, nhead, dim_feedforward, dropout, batch_first=batch_first, **kwargs)

    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Tensor | None,
        key_padding_mask: Tensor | None,
        is_causal: bool = False,
    ) -> Tensor:
        """Attend from every position of ``x`` to its first ``Nc`` positions.

        Args:
            x: layer input, ``(B, L, D)`` when ``batch_first`` else ``(L, B, D)``;
                an unbatched ``(L, D)`` input is sliced along dim 0.
            attn_mask: the TNP mask described on the class; only its first row
                is read, to count the context columns ``Nc``.
            key_padding_mask: optional padding mask over all ``L`` positions;
                it is cut down to the first ``Nc`` columns so it lines up with
                the context keys.
            is_causal: accepted for signature compatibility and ignored. The
                hint describes a causal ``attn_mask``, but this cross-attention
                runs with no mask at all, so forwarding it would be wrong.

        Returns:
            The attention output after ``dropout1``, shaped like ``x``.

        Raises:
            ValueError: if ``attn_mask`` is None (Nc cannot be inferred).
        """
        if attn_mask is None:
            raise ValueError("ACE layer requires an attention mask to infer Nc")
        # Zeros (or False, for a boolean mask) in the first row mark the context columns.
        num_ctx = int((attn_mask[0] == 0).sum().item())

        # Slice the context along the sequence axis, which depends on the layout.
        if x.dim() == 3 and self.self_attn.batch_first:
            ctx = x[:, :num_ctx]
        else:
            ctx = x[:num_ctx]
        if key_padding_mask is not None:
            # Keys are only the Nc context tokens, so keep only their padding columns.
            key_padding_mask = key_padding_mask[..., :num_ctx]

        # Queries: the whole sequence; keys/values: the context prefix only.
        y = self.self_attn(
            x,
            ctx,
            ctx,
            attn_mask=None,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(y)


def _layer_hparams(layer: nn.TransformerEncoderLayer) -> tuple[int, int, int, float]:
    """Read ``(d_model, nhead, dim_feedforward, dropout)`` back off an existing layer.

    The width and head count come from the self-attention module and the
    feed-forward width from ``linear1``. The dropout probability is taken from
    ``dropout1`` and defaults to ``0.0`` if that module has no ``p`` attribute.
    """
    attn = layer.self_attn
    ffn_width = layer.linear1.out_features
    drop_prob = getattr(layer.dropout1, "p", 0.0)
    return attn.embed_dim, attn.num_heads, ffn_width, drop_prob


def _copy_layer_weights(src: nn.TransformerEncoderLayer, dst: nn.TransformerEncoderLayer) -> None:
    """Copy every learnable tensor of ``src`` into the matching tensor of ``dst`` in place.

    Covers the self-attention input/output projections, both feed-forward
    linears and both layer norms. Copies run under ``torch.no_grad()`` using
    ``Tensor.copy_``, which casts into ``dst``'s existing dtype/device.

    Optional biases are handled so the copy stays numerically equivalent:
    - both ``None``: nothing to copy;
    - ``src`` has none but ``dst`` does: ``dst``'s bias is zeroed, since adding
      a zero bias is the same as having no bias;
    - ``src`` has one but ``dst`` does not: ``dst`` cannot reproduce ``src``.

    Raises:
        ValueError: if ``dst`` lacks a bias tensor that ``src`` has.
    """
    # Weights are always present on these modules.
    weights = (
        (src.self_attn.in_proj_weight, dst.self_attn.in_proj_weight),
        (src.self_attn.out_proj.weight, dst.self_attn.out_proj.weight),
        (src.linear1.weight, dst.linear1.weight),
        (src.linear2.weight, dst.linear2.weight),
        (src.norm1.weight, dst.norm1.weight),
        (src.norm2.weight, dst.norm2.weight),
    )
    # Biases may be None, e.g. on layers built without bias.
    biases = (
        ("self_attn.in_proj_bias", src.self_attn.in_proj_bias, dst.self_attn.in_proj_bias),
        ("self_attn.out_proj.bias", src.self_attn.out_proj.bias, dst.self_attn.out_proj.bias),
        ("linear1.bias", src.linear1.bias, dst.linear1.bias),
        ("linear2.bias", src.linear2.bias, dst.linear2.bias),
        ("norm1.bias", src.norm1.bias, dst.norm1.bias),
        ("norm2.bias", src.norm2.bias, dst.norm2.bias),
    )
    with torch.no_grad():
        for src_w, dst_w in weights:
            dst_w.copy_(src_w)
        for name, src_b, dst_b in biases:
            if dst_b is None:
                if src_b is not None:
                    raise ValueError(f"cannot copy {name}: destination layer has no bias")
            elif src_b is None:
                dst_b.zero_()
            else:
                dst_b.copy_(src_b)


def patch_tnp_encoder_with_ace(model: nn.Module) -> nn.Module:
    """
    Replace model.encoder (TransformerEncoder) with an equivalent encoder built
    from ACETransformerEncoderLayer, copying all weights for parity.

    Intended for TNPD/TNPND which use non-AR masks (context-visible only).

    Parity covers more than the per-layer weights:
    - each new layer takes the old layer's ``batch_first``, ``norm_first``,
      activation and LayerNorm ``eps``;
    - the old encoder's optional final ``norm`` is deep-copied;
    - the new encoder is put in the same train/eval mode as the old one.

    Layer hyperparameters are read from the first layer only, because
    nn.TransformerEncoder builds its layers as clones of a single layer.

    Args:
        model: module expected to hold an ``nn.TransformerEncoder`` at
            ``model.encoder``.

    Returns:
        ``model`` itself. It is returned untouched when ``model.encoder`` is
        missing, is not an ``nn.TransformerEncoder``, or has no layers.
        Otherwise ``model.encoder`` is replaced in place.
    """
    old_enc = getattr(model, "encoder", None)
    if not isinstance(old_enc, nn.TransformerEncoder) or len(old_enc.layers) == 0:
        return model

    first_layer = old_enc.layers[0]
    d_model, n_heads, d_ff, p = _layer_hparams(first_layer)
    # Follow the old encoder's own placement rather than the model's first parameter.
    ref_param = next(old_enc.parameters())

    new_layer = ACETransformerEncoderLayer(
        d_model, n_heads, d_ff, p, batch_first=first_layer.self_attn.batch_first
    )
    new_enc = nn.TransformerEncoder(
        new_layer,
        num_layers=len(old_enc.layers),
        # A final LayerNorm on the old encoder changes its output, so keep it (None stays None).
        norm=copy.deepcopy(old_enc.norm),
    )
    new_enc = new_enc.to(device=ref_param.device, dtype=ref_param.dtype)

    for src, dst in zip(old_enc.layers, new_enc.layers):
        # Mirror settings that the constructor call above does not pass;
        # forward() reads them as attributes at call time.
        dst.norm_first = src.norm_first
        dst.activation = src.activation
        if hasattr(src, "activation_relu_or_gelu"):
            dst.activation_relu_or_gelu = src.activation_relu_or_gelu
        dst.norm1.eps = src.norm1.eps
        dst.norm2.eps = src.norm2.eps
        _copy_layer_weights(src, dst)

    # A freshly built module starts in training mode. Match the old encoder so
    # dropout stays off for a model in eval mode.
    new_enc.train(old_enc.training)
    model.encoder = new_enc
    return model

