from typing import Optional
from functools import partial

import torch

from transformers.cache_utils import Cache, DynamicCache
from torch.nn.attention.flex_attention import (
    _mask_mod_signature,
    BlockMask,
    create_block_mask,
)


def get_am(mask, b, idx):
    """Return ``mask[b, idx]``, or ``True`` where ``idx`` lies past the mask's last axis.

    Args:
        mask: Tensor indexed as ``mask[b, idx]`` (presumably a boolean
            ``(batch, seq_len)`` attention mask -- confirm against callers).
        b: Batch index tensor.
        idx: Position index tensor; values may be ``>= mask.shape[-1]``.

    Returns:
        A tensor holding ``mask[b, idx]`` for in-range positions and ``True``
        for positions at or beyond ``mask.shape[-1]``.
    """
    seq_len = mask.shape[-1]
    # ``torch.where`` evaluates both branches, so the lookup itself has to be
    # in bounds. Clamp the index used for the read; ``torch.where`` then
    # discards the clamped value for every out-of-range position.
    safe_idx = torch.clamp(idx, max=seq_len - 1)
    return torch.where(idx < seq_len, mask[b, safe_idx], True)


def causal_mask_fn(b, h, q_idx, kv_idx):
    """Causal mask predicate: true where the key position is not after the query.

    ``b`` and ``h`` are accepted to fit the four-argument mask-mod call shape
    but are not used.
    """
    return kv_idx <= q_idx


def causal_attention_mask_fn(b, h, q_idx, kv_idx, mask):
    """Causal mask predicate combined with a per-position attention mask.

    A (query, key) pair is kept only when the key is not after the query and
    ``get_am`` reports both the query position and the key position as set in
    ``mask`` for batch ``b``. ``h`` is unused.
    """
    is_causal = q_idx >= kv_idx
    query_allowed = get_am(mask, b, q_idx)
    key_allowed = get_am(mask, b, kv_idx)
    return is_causal & query_allowed & key_allowed


def get_mask_mod_w_offset(mask_mod: _mask_mod_signature, _offset: torch.Tensor):
    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + _offset, kv)
    return _mask_mod


class AttentionMask:
    def __init__(
        self, 
        attention_mask: Optional[torch.Tensor],
        target_len: Optional[int] = None,
        max_seq_len: int = 12288,
        offset: Optional[torch.Tensor] = None,
        device: str = "cuda",
    ):
        self.target_len = target_len or max_seq_len
        self.max_seq_len = max_seq_len
        self.offset = torch.tensor(offset, dtype=torch.int64) if isinstance(offset, int) else offset

        if attention_mask is not None:
            self.attention_mask = attention_mask.bool()

        causal_mask_fn = causal_mask if attention_mask is None else partial(causal_attention_mask, mask=self.attention_mask)

        if self.offset is not None:
            causal_mask_fn = get_mask_mod_w_offset(causal_mask_fn, self.offset)
            self.block_mask = create_block_mask(causal_mask_fn, None, None, self.target_len, self.max_seq_len, device=device, _compile=True)
        else:
            # this is prefilling, so we don't need to worry about the offset
            self.block_mask = create_block_mask(causal_mask_fn, None, None, self.target_len, self.target_len, device=device, _compile=True)

    def get_block_mask(self, input_pos=None):
        if input_pos is None:
            return self.block_mask
        else:
            # can only handle 1 input position at a time
            assert input_pos.shape[-1] == 1
            block_index = input_pos // self.block_mask.BLOCK_SIZE[0]
            mask = self.block_mask[:, :, block_index]
            mask.mask_mod = self.block_mask.mask_mod
            mask.seq_lengths = (1, self.L)
            return mask
