import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import xformers.ops as xops


class SDPAttention(nn.Module):
    """Pre-norm multi-head attention with a selectable attention kernel.

    Queries are computed from ``x``; keys and values are computed from
    ``context`` when it is given and from ``x`` otherwise (self-attention).
    Each of q/k/v has its own ``LayerNorm`` followed by a bias-free linear
    projection. The attention itself runs either through xformers'
    ``memory_efficient_attention`` or through PyTorch's
    ``F.scaled_dot_product_attention``; both paths end with the ``to_out``
    projection and return a tensor of shape ``(b, n, dim)``.

    Args:
        dim: Feature width of ``x`` and ``context``; must be divisible by
            ``heads``.
        heads: Number of attention heads.
        dropout: Attention dropout probability, applied only in training mode.
        use_memory_efficient_attn: Use the xformers kernel when true,
            otherwise ``F.scaled_dot_product_attention``.

    Raises:
        ValueError: If ``dim`` is not divisible by ``heads``.
    """

    def __init__(self, dim, heads=8, dropout=0.0, use_memory_efficient_attn=True):
        super().__init__()

        if dim % heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by heads ({heads})"
            )

        self.norm_q = nn.LayerNorm(dim)
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_out = nn.Linear(dim, dim, bias=False)

        self.heads = heads
        self.dropout = dropout
        self.use_memory_efficient_attn = use_memory_efficient_attn

    def forward(self, x, context=None, context_mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: Query source of shape ``(b, n, dim)``.
            context: Optional key/value source of shape ``(b, m, dim)``.
            context_mask: Optional boolean mask of shape ``(b, m)`` over the
                key/value positions; ``True`` marks positions that may be
                attended to (the ``F.scaled_dot_product_attention``
                convention), in both kernel paths.

        Returns:
            Tensor of shape ``(b, n, dim)`` after the output projection.
        """
        kv_source = x if context is None else context

        q = self.to_q(self.norm_q(x))
        k = self.to_k(self.norm_k(kv_source))
        v = self.to_v(self.norm_v(kv_source))

        if self.use_memory_efficient_attn:
            out = self._xformers_attention(q, k, v, context_mask)
        else:
            out = self._sdp_attention(q, k, v, context_mask)

        # The output projection belongs to both kernel paths.
        return self.to_out(out)

    def _split_heads(self, t):
        """Reshape ``(b, n, h*d)`` into ``(b, n, h, d)``."""
        return t.unflatten(-1, (self.heads, -1))

    def _additive_bias(self, context_mask, q_len, dtype):
        """Turn a ``(b, m)`` keep-mask into a ``(b, h, q_len, m)`` additive bias."""
        keep = context_mask.bool()
        batch, kv_len = keep.shape
        # xformers wants tensor biases backed by memory whose last dimension
        # is a multiple of 8: allocate the padded tensor and return a slice.
        padded_len = -(-kv_len // 8) * 8
        bias = torch.zeros(
            batch, self.heads, q_len, padded_len, dtype=dtype, device=keep.device
        )
        # 0 where attention is allowed, -inf where the key is masked out.
        bias[..., :kv_len].masked_fill_(~keep[:, None, None, :], float("-inf"))
        return bias[..., :kv_len]

    def _xformers_attention(self, q, k, v, context_mask):
        """Run xformers memory-efficient attention; returns ``(b, n, h*d)``."""
        out_dtype = q.dtype
        # Cast tensors to float16 before attention computation; xformers
        # takes the (b, n, h, d) layout.
        q, k, v = (self._split_heads(t.half()) for t in (q, k, v))

        attn_bias = None
        if context_mask is not None:
            # The bias must share the dtype of the half-precision query.
            attn_bias = self._additive_bias(context_mask, q.size(1), q.dtype)

        out = xops.memory_efficient_attention(
            q,
            k,
            v,
            attn_bias=attn_bias,
            p=self.dropout if self.training else 0.0,
        )
        # Merge heads and restore the projection dtype so ``to_out`` accepts it.
        return out.flatten(-2).to(out_dtype)

    def _sdp_attention(self, q, k, v, context_mask):
        """Run ``F.scaled_dot_product_attention``; returns ``(b, n, h*d)``."""
        # SDPA takes the (b, h, n, d) layout.
        q, k, v = (self._split_heads(t).transpose(1, 2) for t in (q, k, v))

        attn_mask = None
        if context_mask is not None:
            # (b, m) -> (b, 1, 1, m): broadcast over heads and query positions.
            attn_mask = context_mask[:, None, None, :]

        out = F.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        # (b, h, n, d) -> (b, n, h*d)
        return out.transpose(1, 2).flatten(-2)


class Memoeryeff_Attention(nn.Module):
    """Pre-norm multi-head attention over packed sequences via xformers.

    Queries are computed from ``x``; keys and values from ``context`` when it
    is given and from ``x`` otherwise. The attention itself is delegated to
    ``mem_efficient_attn_seq``, which treats its inputs as packed
    ``(n, features)`` tensors and separates the individual sequences with a
    block-diagonal mask built from per-sequence lengths.

    Args:
        dim: Width of the key/value input and of every projection output.
        q_dim: Width of the query input. When ``None``, ``v_dim`` is used,
            and when that is ``None`` as well, ``dim``.
        v_dim: Fallback width of the query input (see ``q_dim``).
            NOTE(review): despite its name this only sizes the query path;
            keys and values are always ``dim`` wide. Behavior kept as-is.
        heads: Number of attention heads.
        dropout: Attention dropout probability, applied only in training mode.
    """

    def __init__(self, dim, q_dim=None, v_dim=None, heads=8, dropout=0.0):
        super().__init__()

        # Only the query input width varies between configurations.
        if q_dim is not None:
            q_in = q_dim
        elif v_dim is not None:
            q_in = v_dim
        else:
            q_in = dim

        self.norm_q = nn.LayerNorm(q_in)
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

        self.to_q = nn.Linear(q_in, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_out = nn.Linear(dim, dim, bias=False)

        self.heads = heads
        self.dropout = dropout

    def forward(self, x, context=None, context_mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: Query source; presumably packed as ``(n, q_in)`` because
                ``mem_efficient_attn_seq`` expects 2-D inputs -- confirm
                against callers.
            context: Optional key/value source, packed the same way.
            context_mask: Required despite the ``None`` default. A sequence
                whose first item holds the per-sequence query lengths and
                whose optional second item holds the key/value lengths; with
                a single item the same lengths are used for both.

        Returns:
            The attention output after the ``to_out`` projection.

        Raises:
            ValueError: If ``context_mask`` is ``None`` or empty.
        """
        if context_mask is None or len(context_mask) == 0:
            raise ValueError(
                "context_mask must hold the query sequence lengths "
                "(and optionally the key/value sequence lengths)"
            )

        q_seqlen = context_mask[0]
        kv_seqlen = context_mask[0] if len(context_mask) == 1 else context_mask[1]

        kv_source = x if context is None else context
        q = self.to_q(self.norm_q(x))
        k = self.to_k(self.norm_k(kv_source))
        v = self.to_v(self.norm_v(kv_source))

        # Cast tensors to float16 before attention computation
        out = mem_efficient_attn_seq(
            query=q.half(),
            key=k.half(),
            value=v.half(),
            q_seqlen=q_seqlen,
            kv_seqlen=kv_seqlen,
            num_heads=self.heads,
            dropout_p=self.dropout if self.training else 0.0,
        )
        # Match the projection's parameter dtype (float32 for a default module).
        out = out.to(self.to_out.weight.dtype)
        return self.to_out(out)


def mem_efficient_attn_seq(
    query,
    key,
    value,
    q_seqlen,
    kv_seqlen,
    num_heads: int,
    dropout_p: float,
):
    """Run xformers memory-efficient attention over packed sequences.

    Args:
        query: Packed queries of shape ``(n_q, num_heads * d)``.
        key: Packed keys of shape ``(n_kv, num_heads * d)``.
        value: Packed values of shape ``(n_kv, num_heads * d)``.
        q_seqlen: Per-sequence query lengths, as a list or a tensor.
        kv_seqlen: Per-sequence key/value lengths, as a list or a tensor.
        num_heads: Number of attention heads the feature axis is split into.
        dropout_p: Attention dropout probability passed to xformers.

    Returns:
        Float32 tensor of shape ``(n_q, num_heads * d)``.
    """
    # xformers attention expects shape (1, n, h, d)
    packed_q, packed_k, packed_v = (
        rearrange(tensor, "n (h d) -> () n h d", h=num_heads)
        for tensor in (query, key, value)
    )

    # from_seqlens is given plain Python lengths rather than tensors.
    q_lengths = q_seqlen.tolist() if isinstance(q_seqlen, torch.Tensor) else q_seqlen
    kv_lengths = (
        kv_seqlen.tolist() if isinstance(kv_seqlen, torch.Tensor) else kv_seqlen
    )

    # Build the block-diagonal attention bias outside of autograd.
    with torch.no_grad():
        block_mask = xops.fmha.BlockDiagonalMask.from_seqlens(
            q_seqlen=q_lengths,
            kv_seqlen=kv_lengths,
        )

    attended = xops.memory_efficient_attention(
        packed_q,
        packed_k,
        packed_v,
        attn_bias=block_mask,
        p=dropout_p,
    )

    # Drop the singleton batch axis, merge the heads, and return float32.
    return rearrange(attended, "() n h d -> n (h d)").float()
