import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import einops
from einops import rearrange, repeat
from inspect import isfunction
from .rotary import RotaryEmbedding
from .modules import RMSNorm

# Pick the attention backend once, at import time: 'flash' when this PyTorch
# build exposes nn.functional.scaled_dot_product_attention, 'math' otherwise.
ATTENTION_MODE = (
    'flash'
    if hasattr(nn.functional, 'scaled_dot_product_attention') else 'math'
)
print(f'attention mode is {ATTENTION_MODE}')


def add_mask(sim, mask):
    """Overwrite disallowed attention logits with a very large negative value.

    ``mask`` is inverted with ``~`` and used as the fill mask, so ``True``
    entries are kept and ``False`` entries are overwritten.  A 2-D mask is
    expanded to ``(sim.shape[0], 1, n, m)``, a 3-D mask gets a singleton
    axis inserted at position 1, and a mask of any other rank is used as
    given (it must broadcast against ``sim``).

    Returns a new tensor; ``sim`` itself is left untouched.
    """
    rank = mask.ndim
    if rank == 3:
        # (b, n, m) -> (b, 1, n, m)
        mask = mask.unsqueeze(1)
    elif rank == 2:
        # (n, m) -> (b, 1, n, m), shared by every batch element
        mask = mask.expand(sim.shape[0], 1, *mask.shape)
    # Most negative finite value of sim's dtype: softmax maps it to ~0.
    fill_value = -torch.finfo(sim.dtype).max
    return sim.masked_fill(~mask, fill_value)


def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
    """Build a boolean attention mask from per-token validity masks.

    Args:
        q_shape: Shape of the query-side input; ``q_shape[0]`` is used as
            the batch size and ``q_shape[-2]`` as the number of queries.
        k_shape: Shape of the key-side input; ``k_shape[-2]`` is used as
            the number of keys.
        device: Device on which a default (all-``True``) mask is created
            when the corresponding mask argument is ``None``.
        q_mask: Optional 2-D ``(batch, queries)`` mask; ``True``/nonzero
            marks a valid query.  ``None`` means every query is valid.
        k_mask: Optional 2-D ``(batch, keys)`` mask; ``True``/nonzero marks
            a valid key.  ``None`` means every key is valid.

    Returns:
        A ``torch.bool`` tensor of shape ``(batch, 1, queries, keys)`` that
        is ``True`` exactly where both the query and the key are valid.

    Raises:
        ValueError: If a supplied mask is not 2-D.
    """
    batch, q_len, k_len = q_shape[0], q_shape[-2], k_shape[-2]

    def _as_bool_mask(mask, length, name):
        # Only allocate the all-True default when no mask was supplied.
        if mask is None:
            return torch.ones(
                (batch, length), device=device, dtype=torch.bool
            )
        if mask.ndim != 2:
            raise ValueError(
                f'{name} must be 2-D (batch, length), got {mask.ndim}-D'
            )
        # 0/1 integer or float masks are accepted; the result must be bool
        # so downstream attention treats it as a keep-mask, not a bias.
        return mask.to(torch.bool)

    q_mask = _as_bool_mask(q_mask, q_len, 'q_mask')
    k_mask = _as_bool_mask(k_mask, k_len, 'k_mask')
    # (b, 1, i, 1) & (b, 1, 1, j) -> (b, 1, i, j)
    return q_mask[:, None, :, None] & k_mask[:, None, None, :]


class Attention(nn.Module):
    """Multi-head attention with optional QK normalisation and rotary embedding.

    Queries are projected from ``x``; keys and values are projected from
    ``context`` (or from ``x`` when no context is given to ``forward``).

    Args:
        dim: Width of ``x`` and of the output; also the total width of the
            projected queries/keys/values (``num_heads * head_dim``).
        context_dim: Input width of the key/value projections.  When given,
            the layer is flagged as cross-attention and ``rope_mode`` must
            be ``'none'``.  ``None`` uses ``dim``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the q/k/v projections have a bias term.
        qk_scale: Logit scale used by the 'math' backend; a falsy value
            selects ``head_dim ** -0.5``.
        qk_norm: ``None``, ``'layernorm'`` or ``'rmsnorm'``; normalisation
            applied per head to queries and keys.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        rope_mode: ``'none'``, ``'shared'``, ``'x_only'`` or ``'dual'``.

    Raises:
        NotImplementedError: For an unknown ``qk_norm`` (at construction) or
            an unknown ``rope_mode`` (when ``forward`` is called).
    """

    def __init__(
        self,
        dim,
        context_dim=None,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        attn_drop=0.,
        proj_drop=0.,
        rope_mode='none'
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.cross_attn = context_dim is not None
        if context_dim is None:
            context_dim = dim

        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
        self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)

        if qk_norm is None:
            self.norm_q = nn.Identity()
            self.norm_k = nn.Identity()
        elif qk_norm == 'layernorm':
            self.norm_q = nn.LayerNorm(head_dim)
            self.norm_k = nn.LayerNorm(head_dim)
        elif qk_norm == 'rmsnorm':
            self.norm_q = RMSNorm(head_dim)
            self.norm_k = RMSNorm(head_dim)
        else:
            raise NotImplementedError(f'unknown qk_norm: {qk_norm!r}')

        self.attn_drop_p = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        if self.cross_attn:
            assert rope_mode == 'none', \
                'rotary embeddings are not supported for cross-attention'
        self.rope_mode = rope_mode
        if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
            self.rotary = RotaryEmbedding(dim=head_dim)
        elif self.rope_mode == 'dual':
            self.rotary_x = RotaryEmbedding(dim=head_dim)
            self.rotary_c = RotaryEmbedding(dim=head_dim)

    def _rotary(self, q, k, extras):
        """Apply the configured rotary embedding to ``q`` and ``k``.

        ``q`` and ``k`` are ``(B, H, L, D)``.  ``extras`` is the number of
        leading sequence positions that form the "c" segment: 'x_only'
        leaves that segment untouched, 'dual' runs it through its own
        rotary module, and 'shared' ignores ``extras`` entirely.
        """
        if self.rope_mode == 'shared':
            q, k = self.rotary(q=q, k=k)
        elif self.rope_mode == 'x_only':
            q_x, k_x = self.rotary(
                q=q[:, :, extras:, :], k=k[:, :, extras:, :]
            )
            q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'dual':
            q_x, k_x = self.rotary_x(
                q=q[:, :, extras:, :], k=k[:, :, extras:, :]
            )
            q_c, k_c = self.rotary_c(
                q=q[:, :, :extras, :], k=k[:, :, :extras, :]
            )
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'none':
            pass
        else:
            raise NotImplementedError(
                f'unknown rope_mode: {self.rope_mode!r}'
            )
        return q, k

    def _attn(self, q, k, v, mask_binary):
        """Run attention on ``(B, H, L, D)`` inputs; return ``(B, L, H*D)``.

        ``mask_binary`` is ``None`` or a boolean mask in which ``True``
        marks positions that may be attended to.
        """
        if ATTENTION_MODE == 'flash':
            # scaled_dot_product_attention applies dropout whenever
            # dropout_p > 0, independent of module mode, so it has to be
            # switched off explicitly outside of training.
            dropout_p = self.attn_drop_p if self.training else 0.0
            # NOTE(review): this path relies on the backend's default
            # 1/sqrt(head_dim) scaling; self.scale (and hence a custom
            # qk_scale) is only honoured by the 'math' path.
            x = F.scaled_dot_product_attention(
                q, k, v, dropout_p=dropout_p, attn_mask=mask_binary
            )
        elif ATTENTION_MODE == 'math':
            attn = (q @ k.transpose(-2, -1)) * self.scale
            if mask_binary is not None:
                attn = add_mask(attn, mask_binary)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            # Stay in (B, H, L, D): the rearrange below already moves the
            # heads next to the channel axis, so transposing here as well
            # would swap heads and sequence twice and corrupt the shape.
            x = attn @ v
        else:
            raise NotImplementedError
        return einops.rearrange(x, 'B H L D -> B L (H D)')

    def forward(self, x, context=None, context_mask=None, extras=0):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: Query-side input of shape ``(B, L, dim)``.
            context: Key/value-side input of shape ``(B, Lc, context_dim)``;
                ``None`` means self-attention over ``x``.
            context_mask: Optional ``(B, Lc)`` boolean mask, ``True`` for
                keys that may be attended to.
            extras: Number of leading sequence positions handled separately
                by the 'x_only' and 'dual' rope modes.

        Returns:
            Tensor of shape ``(B, L, dim)``.
        """
        if context is None:
            context = x

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        if context_mask is not None:
            mask_binary = create_mask(
                x.shape, context.shape, x.device, None, context_mask
            )
        else:
            mask_binary = None

        q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
        k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
        v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)

        # Norms act on the last (per-head channel) axis.
        q = self.norm_q(q)
        k = self.norm_k(k)

        q, k = self._rotary(q, k, extras)

        x = self._attn(q, k, v, mask_binary)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class JointAttention(nn.Module):
    """Joint attention over the concatenation of a context and an input stream.

    ``x`` and ``context`` each get their own q/k/v projections, QK norms
    and output projection.  Their projected tokens are concatenated along
    the sequence axis as ``[context, x]``, attended over jointly, and then
    split back into the two streams.

    Args:
        dim: Width of both ``x`` and ``context`` (and of both outputs).
        num_heads: Number of attention heads.
        qkv_bias: Whether the q/k/v projections have a bias term.
        qk_scale: Logit scale used by the 'math' backend; a falsy value
            selects ``head_dim ** -0.5``.
        qk_norm: ``None``, ``'layernorm'`` or ``'rmsnorm'``; normalisation
            applied per head to queries and keys.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after each output projection.
        rope_mode: ``'none'``, ``'shared'``, ``'x_only'`` or ``'dual'``.

    Raises:
        NotImplementedError: For an unknown ``qk_norm`` (at construction) or
            an unknown ``rope_mode`` (when ``forward`` is called).
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        attn_drop=0.,
        proj_drop=0.,
        rope_mode='none'
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(
            dim, qkv_bias
        )
        self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(
            dim, qkv_bias
        )

        self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
        self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)

        self.attn_drop_p = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj_x = nn.Linear(dim, dim)
        self.proj_drop_x = nn.Dropout(proj_drop)

        self.proj_c = nn.Linear(dim, dim)
        self.proj_drop_c = nn.Dropout(proj_drop)

        self.rope_mode = rope_mode
        if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
            self.rotary = RotaryEmbedding(dim=head_dim)
        elif self.rope_mode == 'dual':
            self.rotary_x = RotaryEmbedding(dim=head_dim)
            self.rotary_c = RotaryEmbedding(dim=head_dim)

    def _make_qkv_layers(self, dim, qkv_bias):
        """Return three fresh ``dim -> dim`` linear layers (q, k, v)."""
        return (
            nn.Linear(dim, dim, bias=qkv_bias),
            nn.Linear(dim, dim, bias=qkv_bias),
            nn.Linear(dim, dim, bias=qkv_bias),
        )

    def _make_norm_layers(self, qk_norm, head_dim):
        """Return the ``(norm_q, norm_k)`` pair selected by ``qk_norm``."""
        if qk_norm is None:
            norm_q = nn.Identity()
            norm_k = nn.Identity()
        elif qk_norm == 'layernorm':
            norm_q = nn.LayerNorm(head_dim)
            norm_k = nn.LayerNorm(head_dim)
        elif qk_norm == 'rmsnorm':
            norm_q = RMSNorm(head_dim)
            norm_k = RMSNorm(head_dim)
        else:
            raise NotImplementedError(f'unknown qk_norm: {qk_norm!r}')
        return norm_q, norm_k

    def _rotary(self, q, k, extras):
        """Apply the configured rotary embedding to ``q`` and ``k``.

        ``q`` and ``k`` are ``(B, H, L, D)``.  ``extras`` is the number of
        leading sequence positions that form the "c" segment: 'x_only'
        leaves that segment untouched, 'dual' runs it through its own
        rotary module, and 'shared' ignores ``extras`` entirely.
        """
        if self.rope_mode == 'shared':
            q, k = self.rotary(q=q, k=k)
        elif self.rope_mode == 'x_only':
            q_x, k_x = self.rotary(
                q=q[:, :, extras:, :], k=k[:, :, extras:, :]
            )
            q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'dual':
            q_x, k_x = self.rotary_x(
                q=q[:, :, extras:, :], k=k[:, :, extras:, :]
            )
            q_c, k_c = self.rotary_c(
                q=q[:, :, :extras, :], k=k[:, :, :extras, :]
            )
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'none':
            pass
        else:
            raise NotImplementedError(
                f'unknown rope_mode: {self.rope_mode!r}'
            )
        return q, k

    def _attn(self, q, k, v, mask_binary):
        """Run attention on ``(B, H, L, D)`` inputs; return ``(B, L, H*D)``.

        ``mask_binary`` is ``None`` or a boolean mask in which ``True``
        marks positions that may be attended to.
        """
        if ATTENTION_MODE == 'flash':
            # scaled_dot_product_attention applies dropout whenever
            # dropout_p > 0, independent of module mode, so it has to be
            # switched off explicitly outside of training.
            dropout_p = self.attn_drop_p if self.training else 0.0
            # NOTE(review): this path relies on the backend's default
            # 1/sqrt(head_dim) scaling; self.scale (and hence a custom
            # qk_scale) is only honoured by the 'math' path.
            x = F.scaled_dot_product_attention(
                q, k, v, dropout_p=dropout_p, attn_mask=mask_binary
            )
        elif ATTENTION_MODE == 'math':
            attn = (q @ k.transpose(-2, -1)) * self.scale
            if mask_binary is not None:
                attn = add_mask(attn, mask_binary)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            # Stay in (B, H, L, D): the rearrange below already moves the
            # heads next to the channel axis, so transposing here as well
            # would swap heads and sequence twice and corrupt the shape.
            x = attn @ v
        else:
            raise NotImplementedError
        return einops.rearrange(x, 'B H L D -> B L (H D)')

    def _cat_mask(self, x, context, x_mask=None, context_mask=None):
        """Concatenate per-token masks in ``[context, x]`` order.

        A missing mask is replaced by an all-``True`` mask sized from the
        corresponding tensor's batch and sequence dimensions.
        """
        B = x.shape[0]
        if x_mask is None:
            x_mask = torch.ones(
                B, x.shape[-2], device=x.device, dtype=torch.bool
            )
        if context_mask is None:
            context_mask = torch.ones(
                B, context.shape[-2], device=context.device, dtype=torch.bool
            )
        # Order must match the [context, x] token order used in forward().
        return torch.cat([context_mask, x_mask], dim=1)

    def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
        """Jointly attend over ``[context, x]`` and return both streams.

        Args:
            x: Input stream of shape ``(B, Lx, dim)``.
            context: Context stream of shape ``(B, Lc, dim)``.
            x_mask: Optional ``(B, Lx)`` boolean mask, ``True`` for valid
                tokens of ``x``.
            context_mask: Optional ``(B, Lc)`` boolean mask, ``True`` for
                valid tokens of ``context``.
            extras: Number of leading positions of the concatenated
                ``[context, x]`` sequence handled separately by the
                'x_only' and 'dual' rope modes.

        Returns:
            ``(x_out, context_out)`` with shapes ``(B, Lx, dim)`` and
            ``(B, Lc, dim)``.
        """
        B, Lx, C = x.shape
        _, Lc, _ = context.shape
        if x_mask is not None or context_mask is not None:
            mask = self._cat_mask(
                x, context, x_mask=x_mask, context_mask=context_mask
            )
            # Queries and keys both span the concatenated sequence; only
            # the key side is masked.
            shape = [B, Lx + Lc, C]
            mask_binary = create_mask(
                q_shape=shape,
                k_shape=shape,
                device=x.device,
                q_mask=None,
                k_mask=mask
            )
        else:
            mask_binary = None

        def split_heads(t):
            return einops.rearrange(
                t, 'B L (H D) -> B H L D', H=self.num_heads
            )

        qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
        qc, kc, vc = (
            self.to_qc(context), self.to_kc(context), self.to_vc(context)
        )

        qx, kx, vx = split_heads(qx), split_heads(kx), split_heads(vx)
        qc, kc, vc = split_heads(qc), split_heads(kc), split_heads(vc)

        qx, kx = self.norm_qx(qx), self.norm_kx(kx)
        qc, kc = self.norm_qc(qc), self.norm_kc(kc)

        # Context tokens come first along the sequence axis.
        q = torch.cat([qc, qx], dim=2)
        k = torch.cat([kc, kx], dim=2)
        v = torch.cat([vc, vx], dim=2)

        q, k = self._rotary(q, k, extras)

        x = self._attn(q, k, v, mask_binary)

        # Undo the [context, x] concatenation.
        context, x = x[:, :Lc, :], x[:, Lc:, :]

        x = self.proj_x(x)
        x = self.proj_drop_x(x)

        context = self.proj_c(context)
        context = self.proj_drop_c(context)

        return x, context
