from typing import List, Literal, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange, repeat
import xformers.ops as xops

# Explicit (forward, backward) xformers operator pair per query dtype.
# xformers 0.0.29.post1 errors out if `op` is left as None when using
# variable sequence lengths (a BlockDiagonalMask attention bias), so an
# operator is pinned for every dtype we expect: FlashAttention for the
# half-precision types (fp16 / bf16), CUTLASS for fp32 (which FlashAttention
# does not support).
MEM_EFF_OP = {
    torch.float16: (xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
    torch.bfloat16: (xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
    torch.float32: (xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp),
}

# Allowed values for ``SDPAttention(attn_type=...)``.
ATTN_TYPE = Literal["self", "cross"]
# A pair of tensors (not referenced in the code visible here).
SEPARABLE_TENSOR = Tuple[Tensor, Tensor]


class SDPAttention(nn.Module):
    r"""Multi-head scaled dot-product attention on xformers' memory-efficient kernels.

    Attention types (``attn_type``):

    * ``"self"``: queries, keys and values are all projected from ``x`` by a
      single fused ``to_qkv`` projection.
    * ``"cross"``: queries are projected from ``x``; keys and values from
      ``context``.

    Input layouts (chosen per call in :meth:`forward`):

    * padded: ``x`` / ``context`` are 3-D ``(batch, seq, features)`` tensors.
    * chained: ``x`` / ``context`` are 2-D ``(total_tokens, features)`` tensors
      holding several sequences concatenated along the first axis, whose
      lengths are passed as ``seqlen_x`` / ``seqlen_context``.

    Args:
        dim: Total attention width summed over heads; also the output width.
        q_dim: Feature size of ``x``. Defaults to ``dim``.
        kv_dim: Feature size of ``context``. Defaults to ``dim`` for
            cross-attention; for self-attention it defaults to ``q_dim`` and
            must equal it.
        heads: Number of attention heads; must divide ``dim``.
        dropout: Attention dropout probability, applied only in training mode.
        bias: Whether the linear projections carry a bias term.
        attn_type: ``"self"`` or ``"cross"``.

    Raises:
        ValueError: If ``attn_type`` is neither ``"self"`` nor ``"cross"``.
    """

    def __init__(
        self,
        dim: int,
        q_dim: Optional[int] = None,
        kv_dim: Optional[int] = None,
        heads: int = 8,
        dropout: float = 0.0,
        bias: bool = True,
        attn_type: ATTN_TYPE = "self",
    ):
        super().__init__()

        # Anything other than "self" used to fall through to cross-attention
        # silently (e.g. a typo); reject it explicitly instead.
        if attn_type not in ("self", "cross"):
            raise ValueError(f"attn_type must be 'self' or 'cross', got {attn_type!r}")
        assert dim % heads == 0, f"dim ({dim}) must be divisible by heads ({heads})"

        q_dim = q_dim or dim
        if attn_type == "self":
            # q, k and v all come from the same input, so the key/value input
            # width defaults to -- and must equal -- the query input width.
            kv_dim = kv_dim or q_dim
            assert q_dim == kv_dim, (
                f"self-attention requires q_dim == kv_dim, got {q_dim} and {kv_dim}"
            )
            self.to_qkv = nn.Linear(q_dim, 3 * dim, bias=bias)
        else:
            kv_dim = kv_dim or dim
            self.to_q = nn.Linear(q_dim, dim, bias=bias)
            self.to_kv = nn.Linear(kv_dim, 2 * dim, bias=bias)

        self.to_out = nn.Linear(dim, dim, bias=bias)
        self.heads = heads
        self.dropout = dropout
        self.attn_type = attn_type

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; unit/zero LayerNorm affine."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # A LayerNorm built without affine parameters has weight/bias None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        seqlen_x: Optional[List[int]] = None,
        seqlen_context: Optional[List[int]] = None,
    ) -> Tensor:
        r"""Attend from ``x`` to itself (self-attention) or to ``context`` (cross-attention).

        Args:
            x: Query-side input: ``(batch, seq, q_dim)`` in padded mode, or
                ``(total_tokens, q_dim)`` in chained mode.
            context: Key/value-side input for cross-attention, in the same
                layout as ``x`` with ``kv_dim`` features. Must be ``None`` for
                self-attention.
            seqlen_x: Lengths of the sequences concatenated in ``x``. Passing
                it selects chained mode. A block-diagonal attention bias then
                keeps each query attending only within its own sequence.
                ``None`` selects padded mode.
            seqlen_context: Lengths of the sequences concatenated in
                ``context`` (chained mode only). When ``None``, xformers'
                ``BlockDiagonalMask.from_seqlens`` reuses ``seqlen_x`` for the
                keys.

        Returns:
            A tensor with the same leading shape as ``x`` and ``dim`` features.

        Note:
            Padded mode applies no attention bias. Every query attends to every
            key, including any padding positions.
        """
        if self.attn_type == "self":
            assert context is None, "self-attention does not take a context"
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        else:
            assert context is not None, "cross-attention requires a context"
            q = self.to_q(x)
            k, v = self.to_kv(context).chunk(2, dim=-1)

        mask = None
        chained = seqlen_x is not None
        if chained:
            # Treat all concatenated sequences as a single batch element; the
            # block-diagonal bias separates the individual sequences.
            q, k, v = q[None, ...], k[None, ...], v[None, ...]
            mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(seqlen_x, seqlen_context)

        assert q.dim() == k.dim() == v.dim() == 3, (
            "expected 3-D (padded) inputs, or 2-D inputs together with seqlen_x"
        )
        q, k, v = [rearrange(t, "b n (h d) -> b n h d", h=self.heads) for t in (q, k, v)]

        head_dim = q.size(-1)
        # Use the true head dim for the softmax scale. xformers' default scale
        # would be derived from the (possibly padded) last dimension below.
        scale = 1.0 / head_dim ** 0.5

        # xformers likes last dimensions that are multiples of 8. Zero-padding
        # q and k leaves the dot products unchanged; the zero-padded v channels
        # are sliced off the output afterwards.
        padding = -head_dim % 8
        if padding:
            q, k, v = [F.pad(t, (0, padding)) for t in (q, k, v)]

        # Dtypes without a pinned operator fall back to xformers' automatic
        # dispatch (op=None) instead of failing with a KeyError.
        op = MEM_EFF_OP.get(q.dtype)
        out = xops.memory_efficient_attention(
            query=q,
            key=k,
            value=v,
            attn_bias=mask,
            p=self.dropout if self.training else 0.0,
            scale=scale,
            op=op,
        )

        out = out[..., :head_dim]  # drop padded channels (no-op when padding == 0)
        out = rearrange(out, "b n h d -> b n (h d)", h=self.heads)

        if chained:
            out = out[0]

        return self.to_out(out)
