import os

import torch
from torch import nn
import torch.nn.functional as F
from pydoc import locate
from einops import rearrange
from jaxtyping import Float

from ..layers import (
    AdaRMSNorm,
    FeedForwardBlock,
    RMSNorm,
    scale_for_cosine_sim,
    scale_for_cosine_sim_qkv,
    scale_for_cosine_sim_kv,
    apply_wd,
    checkpoint,
    zero_init,
)

# ===============================================


# flash-attn is an optional dependency. When it cannot be imported the name is
# bound to None so that the rest of the module can test for its presence
# (see use_flash) instead of failing at import time.
try:
    import flash_attn
except ImportError:
    flash_attn = None


def get_use_flash_attention():
    """Tell whether the ``USE_FLASH`` environment variable allows flash attention.

    The variable is read on every call. It counts as enabled only when its
    value is exactly the string ``"1"``; an unset variable defaults to ``"1"``.

    Returns:
        bool: ``True`` when flash attention is allowed by the environment.
    """
    flag = os.getenv("USE_FLASH", "1")
    return flag == "1"


def use_flash(x):
    """Decide whether the flash-attn kernels may be used for tensor ``x``.

    All of the following must hold, checked in this order with
    short-circuiting (so ``x`` is not inspected when an earlier check fails):

    * the ``USE_FLASH`` environment switch is on,
    * the ``flash_attn`` package was importable,
    * ``x`` lives on a CUDA device,
    * ``x`` is ``float16`` or ``bfloat16``.

    Args:
        x: Tensor whose device and dtype are inspected.

    Returns:
        bool: ``True`` only when every condition above is satisfied.
    """
    return (
        get_use_flash_attention()
        and flash_attn is not None
        and x.device.type == "cuda"
        and x.dtype in (torch.float16, torch.bfloat16)
    )


class GenericTransformerLayer(nn.Module):
    """Transformer layer: a self-attention block followed by a feed-forward block.

    Args:
        d_model: Width of the token features.
        pos_enc_cls: Positional-encoding class specification, forwarded
            unchanged to ``GenericAttentionBlock``.
        pos_enc_params: Keyword arguments for the positional encoding.
            ``None`` (the default) means no extra arguments.
        self_attn_params: Extra keyword arguments for
            ``GenericAttentionBlock`` (e.g. ``use_flash``, ``compile``).
            ``None`` means none.
        ffn_params: Extra keyword arguments for ``FeedForwardBlock``.
            ``None`` means none.
        d_head: Per-head dimension handed to the attention block.
        d_cond_norm: Conditioning width handed to both sub-blocks, or ``None``
            for unconditioned normalisation in the attention block.
        dropout: Dropout probability handed to both sub-blocks.
        ff_expand: Multiplier applied to ``d_model`` to obtain the value
            passed as ``FeedForwardBlock``'s second argument (presumably its
            hidden width -- confirm against ``FeedForwardBlock``).
    """

    def __init__(
        self,
        d_model,
        pos_enc_cls,
        pos_enc_params=None,
        self_attn_params=None,
        ffn_params=None,
        d_head=64,
        d_cond_norm=None,
        dropout=0.0,
        ff_expand=3,
    ):
        super().__init__()
        # Defaults are None rather than ``{}`` so that no dict object is shared
        # between calls; normalise them here before unpacking.
        pos_enc_params = {} if pos_enc_params is None else pos_enc_params
        self_attn_params = {} if self_attn_params is None else self_attn_params
        ffn_params = {} if ffn_params is None else ffn_params

        d_ff = d_model * ff_expand

        self.self_attn = GenericAttentionBlock(
            d_model,
            pos_enc_cls,
            pos_enc_params,
            d_head,
            d_cond_norm,
            dropout,
            **self_attn_params,
        )
        self.ff = FeedForwardBlock(d_model, d_ff, d_cond_norm, dropout, **ffn_params)

    def forward(self, x, pos, **kwargs):
        """Apply self-attention and then the feed-forward block.

        Args:
            x: Token features.
            pos: Token positions, consumed by the attention block only.
            **kwargs: Forwarded to both sub-blocks. The attention block
                requires a ``check_dict`` entry and optionally reads
                ``cond_norm`` and ``keep_indices``.

        Returns:
            The transformed token features.
        """
        # Both sub-blocks are invoked through the project's ``checkpoint``
        # helper (presumably activation checkpointing -- confirm in ..layers).
        x = checkpoint(self.self_attn, x, pos, **kwargs)
        x = checkpoint(self.ff, x, **kwargs)
        return x


class GenericCrossTransformerLayer(nn.Module):
    """Transformer layer: self-attention, cross-attention, then feed-forward.

    Args:
        d_model: Width of the token features.
        d_cross: Width of the cross-attended (context) features.
        pos_enc_cls: Positional-encoding class specification for the
            self-attention block.
        cross_pos_enc_cls: Positional-encoding class specification for the
            cross-attention block.
        pos_enc_params: Keyword arguments for the self-attention positional
            encoding. ``None`` (the default) means no extra arguments.
        cross_pos_enc_params: Keyword arguments for the cross-attention
            positional encoding. ``None`` means none.
        self_attn_params: Extra keyword arguments for
            ``GenericAttentionBlock``. ``None`` means none.
        ffn_params: Extra keyword arguments for ``FeedForwardBlock``.
            ``None`` means none.
        cross_attn_params: Extra keyword arguments for
            ``GenericCrossAttentionBlock``. ``None`` means none.
        d_head: Per-head dimension handed to both attention blocks.
        d_cond_norm: Conditioning width handed to all sub-blocks, or ``None``.
        dropout: Dropout probability handed to all sub-blocks.
        ff_expand: Multiplier applied to ``d_model`` to obtain the value
            passed as ``FeedForwardBlock``'s second argument (presumably its
            hidden width -- confirm against ``FeedForwardBlock``).
    """

    def __init__(
        self,
        d_model,
        d_cross,
        pos_enc_cls,
        cross_pos_enc_cls,
        pos_enc_params=None,
        cross_pos_enc_params=None,
        self_attn_params=None,
        ffn_params=None,
        cross_attn_params=None,
        d_head=64,
        d_cond_norm=None,
        dropout=0.0,
        ff_expand=3,
    ):
        super().__init__()
        # Defaults are None rather than ``{}`` so that no dict object is shared
        # between calls; normalise them here before unpacking.
        pos_enc_params = {} if pos_enc_params is None else pos_enc_params
        cross_pos_enc_params = {} if cross_pos_enc_params is None else cross_pos_enc_params
        self_attn_params = {} if self_attn_params is None else self_attn_params
        ffn_params = {} if ffn_params is None else ffn_params
        cross_attn_params = {} if cross_attn_params is None else cross_attn_params

        d_ff = d_model * ff_expand

        self.self_attn = GenericAttentionBlock(
            d_model,
            pos_enc_cls,
            pos_enc_params,
            d_head,
            d_cond_norm,
            dropout,
            **self_attn_params,
        )
        self.cross_attn = GenericCrossAttentionBlock(
            d_model,
            d_cross,
            cross_pos_enc_cls,
            cross_pos_enc_params,
            d_head,
            d_cond_norm,
            dropout,
            **cross_attn_params,
        )
        self.ff = FeedForwardBlock(d_model, d_ff, d_cond_norm, dropout, **ffn_params)

    def forward(self, x: Float[torch.Tensor, "B L C"], pos, cross_mask=None, **kwargs):
        """Apply self-attention, (optionally gated) cross-attention and feed-forward.

        Args:
            x: Token features.
            pos: Token positions, passed to both attention blocks.
            cross_mask: Optional per-sample gate. It is indexed as
                ``cross_mask[:, None, None]``, so it is expected to be 1-D
                over the batch (assumption -- confirm with the caller). Where
                it is 1 the cross-attention output is kept; where it is 0 the
                features from before cross-attention are restored. Boolean
                masks are accepted and treated as 0/1.
            **kwargs: Forwarded to every sub-block. Must contain
                ``check_dict``; the cross-attention block additionally
                requires ``x_cross`` and ``pos_cross``.

        Returns:
            The transformed token features.
        """
        x = checkpoint(self.self_attn, x, pos, **kwargs)
        x_skip = x
        x = checkpoint(self.cross_attn, x, pos, **kwargs)

        if cross_mask is not None:
            # Record that the optional argument was actually consumed.
            kwargs["check_dict"]["cross_mask"] = True
            gate = cross_mask[:, None, None]
            if gate.dtype == torch.bool:
                # ``1 - bool_tensor`` is rejected by PyTorch; use a 0/1 gate
                # in the feature dtype instead.
                gate = gate.to(x.dtype)
            x = x_skip * (1 - gate) + x * gate

        x = checkpoint(self.ff, x, **kwargs)
        return x


class GenericCrossAttentionBlock(nn.Module):
    """Pre-norm multi-head cross-attention with a residual connection.

    Queries are projected from ``x`` and keys/values from ``x_cross``. The
    same positional-embedding module is applied to queries (with ``pos``) and
    to keys (with ``pos_cross``). Attention is computed with an explicit
    softmax scale of 1.0 after queries/keys have gone through the project's
    ``scale_for_cosine_sim*`` helpers with a learned per-head ``scale``.

    Args:
        d_model: Width of ``x`` and of the block output.
        d_cross: Width of ``x_cross``.
        pos_enc_cls: Dotted path resolved with ``pydoc.locate`` to the
            positional-embedding class; it is instantiated as
            ``cls(d_head, n_heads, **pos_enc_params)``.
        pos_enc_params: Keyword arguments for the positional-embedding class.
        d_head: Per-head dimension; ``n_heads = d_model // d_head``.
        d_cond_norm: When not ``None`` the query-side norm is an
            ``AdaRMSNorm`` conditioned on a vector of this width; otherwise a
            plain ``RMSNorm``.
        dropout: Dropout probability applied to the attention output before
            the output projection.
        use_flash: Request the flash-attn kernel; it is only used when
            ``use_flash(q)`` also allows it at call time.
        compile: When ``True`` (the default, matching the previous
            unconditional behaviour) ``forward`` is wrapped in
            ``torch.compile`` unless the flash path is requested and enabled
            by the environment. Mirrors the ``compile`` switch of
            ``GenericAttentionBlock``.
    """

    def __init__(
        self,
        d_model: int,
        d_cross: int,
        pos_enc_cls,
        pos_enc_params,
        d_head: int = 64,
        d_cond_norm: int | None = None,
        dropout: float = 0.0,
        use_flash: bool = False,
        compile: bool = True,
    ):
        super().__init__()
        self.d_head = d_head
        self.n_heads = d_model // d_head
        if d_cond_norm is not None:
            self.norm = AdaRMSNorm(d_model, d_cond_norm)
        else:
            self.norm = RMSNorm(d_model)
        self.norm_cross = RMSNorm(d_cross)

        self.q_proj = apply_wd(nn.Linear(d_model, d_model, bias=False))
        # Keys and values are produced by one projection and split later.
        self.kv_proj = apply_wd(nn.Linear(d_cross, d_model * 2, bias=False))
        # One learned scale per head, initialised to 10.
        self.scale = nn.Parameter(torch.full([self.n_heads], 10.0))
        self.pos_emb = locate(pos_enc_cls)(d_head, self.n_heads, **pos_enc_params)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False)))

        self.use_flash = use_flash
        # The flash path is left uncompiled; otherwise compile on request.
        if compile and not (self.use_flash and get_use_flash_attention()):
            self.forward = torch.compile(self.forward)

    def extra_repr(self):
        return f"d_head={self.d_head},"

    def forward(
        self,
        x: Float[torch.Tensor, "b l d"],
        pos: Float[torch.Tensor, "b l 2"],
        x_cross: Float[torch.Tensor, "b l' d'"],
        pos_cross: Float[torch.Tensor, "b l' n"],
        check_dict: dict[str, bool],
        cond_norm: Float[torch.Tensor, "b d"] | None = None,
        **kwargs,
    ) -> Float[torch.Tensor, "b l d"]:
        """Cross-attend from ``x`` to ``x_cross`` and add the result to ``x``.

        Args:
            x: Query-side token features.
            pos: Positions of the query tokens.
            x_cross: Context features providing keys and values.
            pos_cross: Positions of the context tokens.
            check_dict: Mutable record of which optional inputs were consumed;
                ``"x_cross"`` and ``"pos_cross"`` are always set to ``True``,
                ``"cond_norm"`` only when conditioning is supplied.
            cond_norm: Optional conditioning vector passed to the norm.
            **kwargs: Accepted and ignored so callers can pass a shared
                keyword set to every block.

        Returns:
            ``x`` plus the projected attention output.
        """
        check_dict["x_cross"] = True
        check_dict["pos_cross"] = True

        skip = x
        if cond_norm is not None:
            x = self.norm(x, cond_norm)
            check_dict["cond_norm"] = True
        else:
            x = self.norm(x)
        x_cross = self.norm_cross(x_cross)
        q = self.q_proj(x)
        kv = self.kv_proj(x_cross)

        # Positions are cast to the activation dtype before being embedded.
        pos = pos.to(q.dtype)
        pos_cross = pos_cross.to(q.dtype)
        theta = self.pos_emb(pos)
        theta_cross = self.pos_emb(pos_cross)

        if self.use_flash and use_flash(q):
            # flash-attn layout: (batch, length, [k/v,] heads, head_dim).
            q = rearrange(q, "n l (nh e) -> n l nh e", e=self.d_head)
            kv = rearrange(kv, "n l (t nh e) -> n l t nh e", t=2, e=self.d_head)
            q, kv = scale_for_cosine_sim_kv(q, kv, self.scale, 1e-6)
            # Pair the key angles with zeros for the value slot so the packed
            # kv tensor can go through apply_emb in one call (a zero angle
            # presumably leaves the values unchanged -- confirm in pos_emb).
            theta_cross = torch.stack((theta_cross, torch.zeros_like(theta_cross)), dim=-3)
            q = self.pos_emb.apply_emb(q, theta)
            kv = self.pos_emb.apply_emb(kv, theta_cross)
            x = flash_attn.flash_attn_kvpacked_func(q, kv, softmax_scale=1.0)
            x = rearrange(x, "n l nh e -> n l (nh e)")
        else:
            # SDPA layout: (batch, heads, length, head_dim).
            q = rearrange(q, "n l (nh e) -> n nh l e", e=self.d_head)
            k, v = rearrange(kv, "n l (t nh e) -> t n nh l e", t=2, e=self.d_head)
            q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None], 1e-6)
            # Swap the length and head axes of the angles to match q and k.
            theta = theta.movedim(-2, -3)
            q = self.pos_emb.apply_emb(q, theta)
            theta_cross = theta_cross.movedim(-2, -3)
            k = self.pos_emb.apply_emb(k, theta_cross)
            x = F.scaled_dot_product_attention(q, k, v, scale=1.0)
            x = rearrange(x, "n nh l e -> n l (nh e)")

        x = self.dropout(x)
        x = self.out_proj(x)
        return x + skip


class GenericAttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention with a residual connection.

    Queries, keys and values come from one fused projection. Queries and keys
    go through the project's ``scale_for_cosine_sim*`` helpers with a learned
    per-head ``scale`` and then through the positional embedding; attention
    itself uses an explicit softmax scale of 1.0.

    Args:
        d_model: Width of the input and output features.
        pos_enc_cls: Dotted path resolved with ``pydoc.locate`` to the
            positional-embedding class; it is instantiated as
            ``cls(d_head, n_heads, **pos_enc_params)``.
        pos_enc_params: Keyword arguments for the positional-embedding class.
        d_head: Per-head dimension; ``n_heads = d_model // d_head``.
        d_cond_norm: When not ``None`` the norm is an ``AdaRMSNorm``
            conditioned on a vector of this width; otherwise ``RMSNorm``.
        dropout: Dropout probability applied to the attention output before
            the output projection.
        use_flash: Request the flash-attn kernel; it is only used when
            ``use_flash(qkv)`` allows it and no ``keep_indices`` mask is given.
        compile: Wrap ``forward`` in ``torch.compile`` unless the flash path
            is requested and enabled by the environment.
    """

    def __init__(
        self,
        d_model: int,
        pos_enc_cls,
        pos_enc_params,
        d_head: int = 64,
        d_cond_norm: int | None = None,
        dropout: float = 0.0,
        use_flash: bool = False,
        compile: bool = False,
    ):
        super().__init__()
        self.d_head = d_head
        self.n_heads = d_model // d_head
        if d_cond_norm is not None:
            self.norm = AdaRMSNorm(d_model, d_cond_norm)
        else:
            self.norm = RMSNorm(d_model)
        self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False))
        # One learned scale per head, initialised to 10.
        self.scale = nn.Parameter(torch.full([self.n_heads], 10.0))
        self.pos_emb = locate(pos_enc_cls)(d_head, self.n_heads, **pos_enc_params)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False)))

        self.use_flash = use_flash
        if compile and not (self.use_flash and get_use_flash_attention()):
            self.forward = torch.compile(self.forward)

    def extra_repr(self):
        return f"d_head={self.d_head},"

    @staticmethod
    def _dropped_token_mask(keep_indices, batch_size, seq_len, device):
        """Return a (batch, length) bool mask that is True for tokens NOT kept."""
        dropped = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
        # keep_indices[b] holds the positions kept for sample b; the per-sample
        # loop allows a different number of kept tokens in each sample.
        for b in range(batch_size):
            dropped[b, keep_indices[b]] = False
        return dropped

    def forward(self, x, pos, check_dict, cond_norm=None, keep_indices=None, **kwargs):
        """Self-attend over ``x`` and add the result to ``x``.

        Args:
            x: Token features, indexed as (batch, length, channels).
            pos: Token positions fed to the positional embedding.
            check_dict: Mutable record of which optional inputs were consumed;
                ``"cond_norm"`` / ``"keep_indices"`` are set to ``True`` when
                the corresponding argument is used.
            cond_norm: Optional conditioning vector passed to the norm.
            keep_indices: Optional per-sample collection of token positions
                that take part in attention. Other tokens are never attended
                to, and their own attention output is set to zero, so they
                pass through the block unchanged apart from the residual.
            **kwargs: Accepted and ignored so callers can pass a shared
                keyword set to every block.

        Returns:
            ``x`` plus the projected attention output.
        """
        skip = x
        if cond_norm is not None:
            x = self.norm(x, cond_norm)
            check_dict["cond_norm"] = True
        else:
            x = self.norm(x)
        qkv = self.qkv_proj(x)
        # Positions are cast to the activation dtype before being embedded.
        pos = pos.to(qkv.dtype)
        theta = self.pos_emb(pos)

        # The packed flash-attn call below is made without any mask, so it is
        # only eligible when no keep_indices mask has to be honoured;
        # otherwise fall through to the SDPA path.
        if keep_indices is None and self.use_flash and use_flash(qkv):
            # flash-attn layout: (batch, length, q/k/v, heads, head_dim).
            qkv = rearrange(qkv, "n l (t nh e) -> n l t nh e", t=3, e=self.d_head)
            qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6)
            # Same angles for q and k, zeros for the value slot, so the packed
            # tensor can go through apply_emb in one call (a zero angle
            # presumably leaves the values unchanged -- confirm in pos_emb).
            theta = torch.stack((theta, theta, torch.zeros_like(theta)), dim=-3)
            qkv = self.pos_emb.apply_emb(qkv, theta)
            x = flash_attn.flash_attn_qkvpacked_func(qkv, softmax_scale=1.0)
            x = rearrange(x, "n l nh e -> n l (nh e)")
        else:
            # SDPA layout: (batch, heads, length, head_dim).
            q, k, v = rearrange(qkv, "n l (t nh e) -> t n nh l e", t=3, e=self.d_head)
            q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None], 1e-6)
            # Swap the length and head axes of the angles to match q and k.
            theta = theta.movedim(-2, -3)
            q = self.pos_emb.apply_emb(q, theta)
            k = self.pos_emb.apply_emb(k, theta)

            if keep_indices is not None:
                check_dict["keep_indices"] = True
                batch_size, seq_len = x.shape[0], x.shape[1]
                dropped = self._dropped_token_mask(keep_indices, batch_size, seq_len, q.device)

                # Additive bias over keys only, broadcast across heads and
                # queries: -inf removes dropped keys from every softmax. This
                # avoids materialising a (batch, heads, length, length) mask.
                key_bias = torch.zeros(
                    (batch_size, 1, 1, seq_len), dtype=q.dtype, device=q.device
                ).masked_fill(dropped[:, None, None, :], float("-inf"))
                x = F.scaled_dot_product_attention(q, k, v, attn_mask=key_bias, scale=1.0)

                # Dropped queries must not produce an attention output. Zero
                # their rows explicitly instead of masking the whole row with
                # -inf, whose softmax is undefined and can yield NaN.
                x = x.masked_fill(dropped[:, None, :, None], 0.0)
            else:
                x = F.scaled_dot_product_attention(q, k, v, scale=1.0)

            x = rearrange(x, "n nh l e -> n l (nh e)")

        x = self.dropout(x)
        x = self.out_proj(x)
        return x + skip
