from __future__ import annotations

from typing import Sequence

import torch


# Additive mask value for blocked attention positions: a large finite negative
# number (not a true -inf).
NEG_INF = -1_000_000_000.0


def _as_int_list(values: Sequence[int] | torch.Tensor) -> list[int]:
    """Return *values* as a plain list of Python ints.

    Tensors are detached and copied to CPU before conversion; any other
    input is iterated as-is.  Every element goes through ``int()``.
    """
    items = values.detach().cpu().tolist() if isinstance(values, torch.Tensor) else values
    return list(map(int, items))


def build_bec_mask(
    batch_B: Sequence[int] | torch.Tensor,
    batch_E: Sequence[int] | torch.Tensor,
    batch_C: Sequence[int] | torch.Tensor,
    batch_windows: Sequence[Sequence[tuple[int, int, int, int]]],
    device: torch.device | str | None = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Build additive attention masks for sequences laid out as board | explanation | claim.

    Example i holds B_i board tokens, then E_i explanation tokens, then C_i claim
    tokens, so L_i = B_i + E_i + C_i.  Inside the L_i x L_i region an entry is 0.0
    where the query may attend to the key and NEG_INF otherwise:

      * board queries see board keys causally (k <= q);
      * explanation queries see every board key and explanation keys causally;
      * claim queries see only the explanation keys named by their windows.

    batch_B, batch_E, batch_C:
        1D tensors or lists of ints of length batch_size, giving B, E, C for each example.

    batch_windows:
        List of length batch_size.
        batch_windows[i] is a list of tuples (q_start, q_end, e_start, e_end)
        with global indices for that example.  Both ranges are inclusive; the
        query range must lie in the claim block and the key range in the
        explanation block.  Windows of an example with L_i == 0 are ignored.

    device, dtype:
        Passed to the tensor constructors for the returned mask.

    Returns:
        bec_mask: (batch_size, L_max, L_max) tensor of additive masks.
                  For positions q>=L_i or k>=L_i in example i, leave mask at 0.0
                  (they will be masked by padding separately).

    Raises:
        ValueError: if the inputs disagree on batch size, if any of B, E, C is
            negative, or if a window falls outside its block.
    """
    batch_B_list = _as_int_list(batch_B)
    batch_E_list = _as_int_list(batch_E)
    batch_C_list = _as_int_list(batch_C)

    batch_size = len(batch_B_list)
    if len(batch_E_list) != batch_size or len(batch_C_list) != batch_size:
        raise ValueError("batch_B, batch_E, and batch_C must have the same length")
    if len(batch_windows) != batch_size:
        raise ValueError("batch_windows must have one entry per batch example")

    sizes = list(zip(batch_B_list, batch_E_list, batch_C_list))

    # Validate sizes before allocating: negative sizes can make l_max negative,
    # and torch.zeros would then fail with a RuntimeError instead of the
    # ValueError this function promises.
    if any(min(b, e, c) < 0 for b, e, c in sizes):
        raise ValueError("B, E, and C must be non-negative")

    lengths = [b + e + c for b, e, c in sizes]
    l_max = max(lengths, default=0)
    bec_mask = torch.zeros((batch_size, l_max, l_max), device=device, dtype=dtype)
    # NOTE(review): NEG_INF presumably is not representable in narrow dtypes
    # such as float16 -- confirm before requesting a half-precision mask.
    blocked_value = torch.tensor(NEG_INF, device=device, dtype=dtype)

    for i, ((b, e, _c), length) in enumerate(zip(sizes, lengths)):
        if length == 0:
            continue

        explanation_start = b
        explanation_end = b + e
        claim_start = explanation_end
        claim_end = length

        # Start with every real position blocked, then open the allowed pairs.
        bec_mask[i, :length, :length] = blocked_value

        if b:
            # Board block: lower triangle (k <= q) is allowed.
            board_tri = torch.tril(torch.ones((b, b), device=device, dtype=torch.bool))
            bec_mask[i, :b, :b].masked_fill_(board_tri, 0.0)

        if e:
            # Explanation queries: all board keys, plus causal explanation keys.
            bec_mask[i, explanation_start:explanation_end, :b] = 0.0
            explanation_tri = torch.tril(torch.ones((e, e), device=device, dtype=torch.bool))
            bec_mask[i, explanation_start:explanation_end, explanation_start:explanation_end].masked_fill_(
                explanation_tri,
                0.0,
            )

        # Claim queries stay fully blocked except for their explicit windows.
        for q_start, q_end, e_start, e_end in batch_windows[i]:
            q_start = int(q_start)
            q_end = int(q_end)
            e_start = int(e_start)
            e_end = int(e_end)

            if not (claim_start <= q_start <= q_end < claim_end):
                raise ValueError(
                    f"Window query range {(q_start, q_end)} must lie in claim block "
                    f"[{claim_start}, {claim_end}) for batch item {i}"
                )
            if not (explanation_start <= e_start <= e_end < explanation_end):
                raise ValueError(
                    f"Window key range {(e_start, e_end)} must lie in explanation block "
                    f"[{explanation_start}, {explanation_end}) for batch item {i}"
                )

            # Window bounds are inclusive, hence the +1 on the slice ends.
            bec_mask[i, q_start : q_end + 1, e_start : e_end + 1] = 0.0

    return bec_mask


def combine_with_padding_mask(bec_mask: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Convert a tokenizer-style attention_mask to an additive key padding mask and combine it
    with a BEC mask.

    bec_mask: (batch, L_max, L_max), additive mask.
    attention_mask: (batch, L_max), with 1 (or True) for real tokens and 0 (or False)
        for padding.  Integer, float and bool masks are accepted.
    returns: (batch, 1, L_max, L_max), additive attention mask in bec_mask's dtype.
        NEG_INF is added to every padded key column, for all queries.
    """
    # `== 0` instead of `1 - attention_mask`: the result is the same for 0/1
    # masks, and it also works for bool masks, on which torch rejects the `-`
    # operator.
    is_padding = attention_mask == 0
    # (batch, L) -> (batch, 1, 1, L) so the penalty broadcasts over heads and queries.
    pad_mask = is_padding.unsqueeze(1).unsqueeze(2).to(dtype=bec_mask.dtype) * NEG_INF
    return bec_mask.unsqueeze(1) + pad_mask


def _assert_allowed(mask: torch.Tensor, q: int, allowed_keys: set[int], length: int) -> None:
    """Assert that row *q* of a 2D additive mask matches *allowed_keys*.

    For each key index below *length*, the entry must be exactly 0.0 when the
    key is in *allowed_keys* and exactly NEG_INF otherwise.
    """
    for k in range(length):
        value = float(mask[q, k].item())
        if k not in allowed_keys:
            assert value == NEG_INF, f"Expected q={q}, k={k} to be blocked, got {value}"
            continue
        assert value == 0.0, f"Expected q={q}, k={k} to be allowed, got {value}"


def _assert_raises_value_error(fn) -> None:
    """Call *fn* with no arguments and require that it raises ValueError.

    Raises AssertionError if *fn* returns normally; any other exception
    raised by *fn* propagates unchanged.
    """
    try:
        fn()
    except ValueError:
        pass
    else:
        raise AssertionError("Expected ValueError")


if __name__ == "__main__":
    # --- Single example: 2 board, 3 explanation and 2 claim tokens. ---
    n_board, n_expl, n_claim = 2, 3, 2
    total = n_board + n_expl + n_claim
    single = build_bec_mask(
        batch_B=[n_board],
        batch_E=[n_expl],
        batch_C=[n_claim],
        batch_windows=[[(5, 6, 2, 3)]],
    )
    assert single.shape == (1, total, total)
    item = single[0]

    # Expected allowed keys per query row: rows 0-1 are board, rows 2-4 are
    # explanation, rows 5-6 are claim and only see the window keys 2-3.
    expected_keys = [
        {0},
        {0, 1},
        {0, 1, 2},
        {0, 1, 2, 3},
        {0, 1, 2, 3, 4},
        {2, 3},
        {2, 3},
    ]
    for q, keys in enumerate(expected_keys):
        _assert_allowed(item, q=q, allowed_keys=keys, length=total)

    # The same expectation as a dense matrix, compared in one shot.
    expected = torch.tensor(
        [[0.0 if k in keys else NEG_INF for k in range(total)] for keys in expected_keys],
        dtype=single.dtype,
    )
    assert torch.equal(item, expected), item

    # --- Padding: an all-ones attention mask must leave the BEC mask unchanged. ---
    all_real = torch.ones((1, total), dtype=torch.long)
    combined = combine_with_padding_mask(single, all_real)
    assert combined.shape == (1, 1, total, total)
    assert torch.equal(combined[0, 0], item)

    # With the last two keys padded, those key columns are blocked for every
    # query, while an unpadded window key stays open.
    tail_padded = torch.tensor([[1, 1, 1, 1, 1, 0, 0]], dtype=torch.long)
    combined_padded = combine_with_padding_mask(single, tail_padded)
    for q, k in ((0, 5), (3, 6)):
        assert float(combined_padded[0, 0, q, k].item()) <= NEG_INF
    assert float(combined_padded[0, 0, 5, 2].item()) == 0.0

    # --- Batch of two examples with different lengths (3 and 6), tensor inputs. ---
    pair = build_bec_mask(
        batch_B=torch.tensor([1, 2]),
        batch_E=torch.tensor([1, 2]),
        batch_C=torch.tensor([1, 2]),
        batch_windows=[
            [(2, 2, 1, 1)],
            [(4, 5, 2, 3)],
        ],
    )
    assert pair.shape == (2, 6, 6)

    short = pair[0]
    for q, keys in enumerate([{0}, {0, 1}, {1}]):
        _assert_allowed(short, q=q, allowed_keys=keys, length=3)
    # Everything beyond the short example's length stays at 0.0.
    assert torch.equal(pair[0, 3:, :], torch.zeros_like(pair[0, 3:, :]))
    assert torch.equal(pair[0, :, 3:], torch.zeros_like(pair[0, :, 3:]))

    full = pair[1]
    for q, keys in enumerate([{0}, {0, 1}, {0, 1, 2}, {0, 1, 2, 3}, {2, 3}, {2, 3}]):
        _assert_allowed(full, q=q, allowed_keys=keys, length=6)

    # --- Windows reaching outside the claim / explanation blocks are rejected. ---
    bad_windows = [(4, 5, 2, 3), (5, 6, 1, 3), (5, 6, 2, 5)]
    for window in bad_windows:
        _assert_raises_value_error(lambda w=window: build_bec_mask([2], [3], [2], [[w]]))

    print("BEC mask sanity check passed.")
