import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat


class SDPAttention(nn.Module):
    """Pre-norm multi-head attention (self- or cross-attention).

    Queries, keys and values are layer-normalised, linearly projected, split
    into ``heads`` heads, attended, merged again and passed through
    ``to_out``.  The attention kernel is either
    ``xops.memory_efficient_attention`` (inputs cast to float16) or
    ``F.scaled_dot_product_attention``, selected by
    ``use_memory_efficient_attn``.
    """

    def __init__(self, dim, heads=8, dropout=0.0, use_memory_efficient_attn=True):
        """
        Args:
            dim: feature width of the inputs and the output; must be divisible
                by ``heads`` for the head split to succeed.
            heads: number of attention heads.
            dropout: attention dropout probability, applied only in training.
            use_memory_efficient_attn: use the xformers kernel instead of
                ``F.scaled_dot_product_attention``.
        """
        super().__init__()

        self.norm_q = nn.LayerNorm(dim)
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_out = nn.Linear(dim, dim, bias=False)

        self.heads = heads
        self.dropout = dropout
        self.use_memory_efficient_attn = use_memory_efficient_attn

    def forward(self, x, context=None, context_mask=None):
        """
        Args:
            x: query input of shape (b, n, dim).
            context: optional key/value input of shape (b, m, dim); when
                omitted, ``x`` attends to itself.
            context_mask: optional boolean mask of shape (b, m).  ``True``
                marks key positions that take part in attention, in both
                kernels.

        Returns:
            Tensor of shape (b, n, dim) in the dtype of ``x``.
        """
        kv_input = x if context is None else context

        q = self.to_q(self.norm_q(x))
        k = self.to_k(self.norm_k(kv_input))
        v = self.to_v(self.norm_v(kv_input))

        if self.use_memory_efficient_attn:
            out = self._xformers_attention(q, k, v, context_mask)
            # The kernel ran in float16; restore the caller's dtype so the
            # output projection sees the dtype its weights expect.
            out = out.to(x.dtype)
        else:
            out = self._sdp_attention(q, k, v, context_mask)

        # The output projection applies to both kernels.
        return self.to_out(out)

    def _dropout_p(self):
        """Return the dropout probability for the current train/eval mode."""
        return self.dropout if self.training else 0.0

    def _xformers_attention(self, q, k, v, context_mask):
        """Attend with xformers; returns merged heads (b, n, dim) in float16."""
        # xformers expects (b, n, h, d); the inputs are cast to float16.
        query = rearrange(q.half(), "b n (h d) -> b n h d", h=self.heads)
        key = rearrange(k.half(), "b n (h d) -> b n h d", h=self.heads)
        value = rearrange(v.half(), "b n (h d) -> b n h d", h=self.heads)

        attn_bias = None
        if context_mask is not None:
            keep = repeat(
                context_mask, "b m -> b h n m", h=self.heads, n=query.size(1)
            )
            # Additive bias: 0 where attention is allowed, -inf where the mask
            # is False, matching the boolean-mask convention of the SDPA path.
            # NOTE(review): xformers may impose extra layout/alignment
            # constraints on dense biases - confirm against its documentation.
            attn_bias = torch.zeros_like(keep, dtype=query.dtype).masked_fill(
                ~keep, float("-inf")
            )

        out = xops.memory_efficient_attention(
            query,
            key,
            value,
            attn_bias=attn_bias,
            p=self._dropout_p(),
        )
        return rearrange(out, "b n h d -> b n (h d)")

    def _sdp_attention(self, q, k, v, context_mask):
        """Attend with torch SDPA; returns merged heads (b, n, dim)."""
        q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.heads)

        attn_mask = None
        if context_mask is not None:
            # Broadcast the key mask over heads and query positions.
            attn_mask = rearrange(context_mask, "b n -> b () () n")

        out = F.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            attn_mask=attn_mask,
            dropout_p=self._dropout_p(),
        )
        return rearrange(out, "b h n d -> b n (h d)", h=self.heads)


class Memoeryeff_Attention(nn.Module):
    """Pre-norm multi-head attention over packed (un-batched) sequences.

    Inputs are layer-normalised and projected, then handed to
    ``mem_efficient_rotary_attn_seq`` together with per-sequence lengths so
    that each packed sequence only attends within itself.
    """

    def __init__(self, dim, q_dim=None, v_dim=None, heads=8, dropout=0.0):
        """
        Args:
            dim: width of the key/value input and of the output.
            q_dim: width of the query input; takes precedence over ``v_dim``.
            v_dim: fallback width of the query input when ``q_dim`` is None.
                NOTE(review): despite its name this sizes the query path, not
                the values; kept as-is so existing checkpoints still load -
                confirm the intent with the callers.
            heads: number of attention heads.
            dropout: attention dropout probability, applied only in training.
        """
        super().__init__()

        # Only the query-side input width varies between configurations.
        if q_dim is not None:
            query_in = q_dim
        elif v_dim is not None:
            query_in = v_dim
        else:
            query_in = dim

        self.norm_q = nn.LayerNorm(query_in)
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

        self.to_q = nn.Linear(query_in, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_out = nn.Linear(dim, dim, bias=False)

        self.heads = heads
        self.dropout = dropout

    def forward(self, x, context=None, context_mask=None):
        """
        Args:
            x: packed query input of shape (n, q_width).
            context: optional packed key/value input of shape (m, dim); when
                omitted, ``x`` is used for keys and values as well.
            context_mask: required indexable holding the sequence lengths:
                ``(q_seqlen,)`` for self-attention or
                ``(q_seqlen, kv_seqlen)`` for cross-attention.

        Returns:
            float32 tensor of shape (n, dim).

        Raises:
            TypeError: if ``context_mask`` is None.
        """
        if context_mask is None:
            # Fail early with a clear message instead of an opaque
            # "NoneType is not subscriptable" after the projections have run.
            raise TypeError(
                "context_mask is required: pass (q_seqlen,) or "
                "(q_seqlen, kv_seqlen)"
            )

        kv_input = x if context is None else context
        q = self.to_q(self.norm_q(x))
        k = self.to_k(self.norm_k(kv_input))
        v = self.to_v(self.norm_v(kv_input))

        q_seqlen = context_mask[0]
        kv_seqlen = context_mask[0] if len(context_mask) == 1 else context_mask[1]

        # The attention kernel is run on float16 inputs.
        out = mem_efficient_rotary_attn_seq(
            query=q.half(),
            key=k.half(),
            value=v.half(),
            q_seqlen=q_seqlen,
            kv_seqlen=kv_seqlen,
            num_heads=self.heads,
            dropout_p=self.dropout if self.training else 0.0,
        )
        # Back to float32 before the output projection.
        out = out.float()
        return self.to_out(out)


import xformers.ops as xops


def mem_efficient_rotary_attn_seq(
    query,
    key,
    value,
    q_seqlen,
    kv_seqlen,
    num_heads: int,
    dropout_p: float,
):
    """Run xformers attention over packed sequences with a block-diagonal mask.

    Despite the name, no rotary embedding is applied inside this function.

    Args:
        query: packed queries of shape (n, h * d).
        key: packed keys of shape (m, h * d).
        value: packed values of shape (m, h * d).
        q_seqlen: per-sequence query lengths (list, or tensor converted via
            ``tolist``).
        kv_seqlen: per-sequence key/value lengths (list, or tensor converted
            via ``tolist``).
        num_heads: number of heads ``h`` to split the feature axis into.
        dropout_p: attention dropout probability passed to xformers.

    Returns:
        float32 tensor of shape (n, h * d).
    """

    def _split_heads(packed):
        # xformers wants (1, n, h, d): a single "batch" holding all sequences.
        return rearrange(packed, "n (h d) -> () n h d", h=num_heads)

    def _as_list(lengths):
        return lengths.tolist() if isinstance(lengths, torch.Tensor) else lengths

    q_heads = _split_heads(query)
    k_heads = _split_heads(key)
    v_heads = _split_heads(value)

    q_lengths = _as_list(q_seqlen)
    kv_lengths = _as_list(kv_seqlen)

    # The mask only encodes sequence boundaries; build it outside autograd.
    with torch.no_grad():
        block_mask = xops.fmha.BlockDiagonalMask.from_seqlens(
            q_seqlen=q_lengths,
            kv_seqlen=kv_lengths,
        )

    attended = xops.memory_efficient_attention(
        q_heads,
        k_heads,
        v_heads,
        attn_bias=block_mask,
        p=dropout_p,
    )

    merged = rearrange(attended, "() n h d -> n (h d)")
    return merged.float()


class ElementwiseVectorMul(nn.Module):
    """Per-head elementwise affine map: ``y[..., h, :] = x * weight[h] + bias[h]``.

    Each of the ``heads`` heads owns its own scale (and optional shift) vector
    of length ``feature_dim``; the input is broadcast across a new head axis.
    """

    def __init__(
        self,
        feature_dim,
        heads,
        bias=True,
        weight_init_std=1.0,
        bias_init_std=0.02,
        device=None,
        dtype=None,
    ):
        """
        Args:
            feature_dim: size of the last axis of the input.
            heads: number of independent scale/shift vectors.
            bias: whether to learn an additive per-head shift.
            weight_init_std: std of the normal init used for ``weight``.
            bias_init_std: std of the normal init used for ``bias``.
            device: device on which the parameters are created.
            dtype: dtype with which the parameters are created.
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.dim = feature_dim
        self.heads = heads
        self.weight_init_std = weight_init_std
        self.bias_init_std = bias_init_std

        self.weight = nn.Parameter(torch.empty((heads, feature_dim), **factory_kwargs))

        if bias:
            self.bias = nn.Parameter(
                torch.empty((heads, feature_dim), **factory_kwargs)
            )
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-draw ``weight`` (and ``bias``, if present) from zero-mean normals."""
        # Normal distribution should work since it is only elementwise.
        # Goal is: input stdev = output stdev
        nn.init.normal_(self.weight, mean=0, std=self.weight_init_std)
        if self.bias is not None:
            nn.init.normal_(self.bias, mean=0, std=self.bias_init_std)

    def forward(self, x):
        """
        Args:
            x: input tensor with shape (..., feature_dim)

        Returns:
            output tensor with shape (..., heads, feature_dim)
        """
        # Insert a head axis so (heads, feature_dim) parameters broadcast.
        scaled = x[..., None, :] * self.weight
        if self.bias is None:
            return scaled
        return scaled + self.bias

    def extra_repr(self) -> str:
        # The feature width is stored as ``self.dim``; there is no
        # ``self.feature_dim`` attribute on this module.
        return "feature_dim={}, heads={}, bias={}".format(
            self.dim, self.heads, self.bias is not None
        )


class GraphInvariantAttention(nn.Module):
    """Attention block with biases with graph position invariances.

    Feature queries/keys/values are projected with linear layers; positional
    queries/keys are produced by per-head elementwise scaling of the
    normalised positions and concatenated to the feature heads before
    ``F.scaled_dot_product_attention``.  The positional part of the attended
    values is returned separately, summed over heads.
    """

    def __init__(self, dim_feat, dim_pos, heads=8, dropout=0.0):
        """
        Args:
            dim_feat: feature width; must be divisible by ``heads`` for the
                head split to succeed.
            dim_pos: width of the positional vectors.
            heads: number of attention heads.
            dropout: attention dropout probability, applied only in training.
        """
        super().__init__()

        self.norm_q = nn.LayerNorm(dim_feat)
        self.norm_k = nn.LayerNorm(dim_feat)
        self.norm_v = nn.LayerNorm(dim_feat)

        self.norm_pq = nn.LayerNorm(dim_pos)
        self.norm_pk = nn.LayerNorm(dim_pos)

        # for features
        self.to_qx = nn.Linear(dim_feat, dim_feat, bias=False)
        self.to_kx = nn.Linear(dim_feat, dim_feat, bias=False)
        self.to_vx = nn.Linear(dim_feat, dim_feat, bias=False)

        # for position
        self.to_qp = ElementwiseVectorMul(dim_pos, heads, bias=False)
        self.to_kp = ElementwiseVectorMul(dim_pos, heads, bias=False)

        # output projections and combination weights
        self.head_weights = nn.Linear(heads, 1, bias=False)
        self.to_outx = nn.Linear(dim_feat, dim_feat, bias=False)

        self.heads = heads
        self.dropout = dropout

    def forward(self, x, p_x, kv=None, p_kv=None, kv_mask=None, ignore_pos=False):
        """
        Args:
            x: query features of shape (b, n, dim_feat).
            p_x: query positions of shape (b, n, dim_pos); not read when
                ``ignore_pos`` is True.
            kv: optional key/value features of shape (b, m, dim_feat);
                defaults to ``x``.
            p_kv: optional key/value positions of shape (b, m, dim_pos);
                defaults to ``p_x``.
            kv_mask: optional mask of shape (b, m), broadcast over heads and
                query positions and passed to SDPA as ``attn_mask``.
            ignore_pos: when True, run plain feature-only attention.

        Returns:
            ``(out_feat, out_pos)`` with ``out_feat`` of shape
            (b, n, dim_feat) and ``out_pos`` of shape (b, n, dim_pos), or
            ``(out_feat, None)`` when ``ignore_pos`` is True.
        """
        kv_input = x if kv is None else kv

        q_feat = self.to_qx(self.norm_q(x))
        k_feat = self.to_kx(self.norm_k(kv_input))
        v_feat = self.to_vx(self.norm_v(kv_input))

        q = rearrange(q_feat, "b n (h d) -> b h n d", h=self.heads)
        k = rearrange(k_feat, "b n (h d) -> b h n d", h=self.heads)
        v = rearrange(v_feat, "b n (h d) -> b h n d", h=self.heads)
        feat_head_dim = q.shape[-1]

        if not ignore_pos:
            # The positional branch is only needed when its result is used.
            q_pos = self.norm_pq(p_x)
            kv_pos = self.norm_pk(p_kv if p_kv is not None else p_x)

            # ElementwiseVectorMul: (b, n, dpos) -> (b, n, h, dpos)
            q_pos = rearrange(self.to_qp(q_pos), "b n h dpos -> b h n dpos")
            k_pos = rearrange(self.to_kp(kv_pos), "b n h dpos -> b h n dpos")
            # Values carry the normalised positions, copied to every head.
            v_pos = repeat(kv_pos, "b n dpos -> b h n dpos", h=self.heads)

            assert (
                k_pos.shape == v_pos.shape
            ), f"k_pos shape: {k_pos.shape}, v_pos shape: {v_pos.shape}"
            assert q_pos.shape[0] == k_pos.shape[0]
            assert q_pos.shape[1] == k_pos.shape[1]
            assert q_pos.shape[3] == k_pos.shape[3]

            q = torch.cat([q, q_pos], dim=-1)
            k = torch.cat([k, k_pos], dim=-1)
            v = torch.cat([v, v_pos], dim=-1)

        attn_mask = None
        if kv_mask is not None:
            attn_mask = rearrange(kv_mask, "b n -> b () () n")

        out = F.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )

        if ignore_pos:
            # Treat it like SDP attention
            out_feat = rearrange(out, "b h n d -> b n (h d)", h=self.heads)
            return self.to_outx(out_feat), None

        # Per-head combination weights 1 + w_h (the head axis is dim 1).
        # Deriving them from the parameter keeps its dtype and device, so a
        # half/bfloat16 module is not silently promoted to float32.
        combination_weight = 1.0 + self.head_weights.weight.flatten()  # (h,)
        combination_weight = combination_weight[None, :, None, None]  # (1, h, 1, 1)
        out = out * combination_weight

        out_feat = out[..., :feat_head_dim]  # b h n dfeat
        out_feat = rearrange(out_feat, "b h n d -> b n (h d)", h=self.heads)
        out_feat = self.to_outx(out_feat)

        out_pos = out[..., feat_head_dim:]  # b h n dpos
        out_pos = out_pos.sum(dim=1)  # Reduce along head dimension

        return out_feat, out_pos
