import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from typing import Optional
from torch.jit import Final
from timm.layers import Mlp, DropPath, use_fused_attn
import math
from flash_attn.flash_attn_interface import flash_attn_func
from models.modules.rotary import RotaryEmbedding, RoPEAttention
    
class FlashAttention(nn.Module):
    """Multi-head self-attention backed by the ``flash_attn`` kernel.

    The fused kernel is used when ``use_flash_attn`` is set and the projected
    tensors are CUDA fp16/bf16 (the only inputs the flash-attn kernels
    accept). In every other case the module computes the same attention with
    plain matmul + softmax, so it also works on CPU and in fp32.

    Args:
        dim: Embedding dimension ``C`` of the input tokens.
        num_heads: Number of attention heads; must divide ``dim``.
        qkv_bias: Whether the fused QKV projection has a bias term.
        qk_norm: If True, apply ``norm_layer(head_dim)`` to queries and keys.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        norm_layer: Constructor for the optional query/key normalisation.
    """
    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'Embedding dimension must be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Switch to enable/disable the flash-attn kernel at runtime.
        self.use_flash_attn = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, N, C]
        Returns:
            Tensor of shape [B, N, C]
        """
        B, N, C = x.shape

        # 1) Project to Q, K, V, kept in flash-attn's native layout
        #    [B, N, num_heads, head_dim] so the kernel needs no transposes.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q, k = self.q_norm(q), self.k_norm(k)

        # flash-attn only ships CUDA kernels for half-precision inputs.
        flash_ok = (
            self.use_flash_attn
            and v.is_cuda
            and v.dtype in (torch.float16, torch.bfloat16)
        )

        if flash_ok:
            # ---- Use FlashAttention ----
            # Normalisation layers may upcast under autocast; the kernel
            # requires q, k and v to share one dtype.
            q, k = q.to(v.dtype), k.to(v.dtype)
            out = flash_attn_func(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
                softmax_scale=self.scale,
            )  # -> [B, N, num_heads, head_dim]
        else:
            # ---- Standard attention ----
            q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> [B, num_heads, N, head_dim]
            attn = (q * self.scale) @ k.transpose(-2, -1)  # -> [B, num_heads, N, N]
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            out = (attn @ v).transpose(1, 2)  # -> [B, N, num_heads, head_dim]

        # 2) Merge heads back into the channel dimension, then project.
        x = out.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    
class AlternatingAttention(nn.Module):
    """Self-attention that runs over either the channel or the patch axis.

    The input token axis is treated as a flattened ``(C, T)`` grid with the
    channel index varying slowest (``C = in_chans``). Blocks with an even
    ``block_idx`` attend across the ``C`` channels of each patch position;
    blocks with an odd ``block_idx`` attend across the ``T`` patches of each
    channel.

    Args:
        dim: Embedding dimension of the tokens.
        num_heads: Number of attention heads; must divide ``dim``.
        qkv_bias: Whether the fused QKV projection has a bias term.
        qk_norm: If True, apply ``norm_layer(head_dim)`` to queries and keys.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        norm_layer: Constructor for the optional query/key normalisation.
        in_chans: Number of channels ``C`` folded into the token axis.
        block_idx: Index of the enclosing block; its parity picks the axis.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
            in_chans = 23,
            block_idx   = 0
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'Embedding dimension must be divisible by num_heads'
        per_head = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = per_head
        self.scale = per_head ** -0.5
        # Fused (scaled_dot_product_attention) path is switched off by default.
        self.fused_attn = False

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        if qk_norm:
            self.q_norm = norm_layer(per_head)
            self.k_norm = norm_layer(per_head)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Even blocks attend over channels, odd blocks over patches.
        self.do_channel_attn = (block_idx % 2 == 0)
        self.in_chans = in_chans

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention along the axis selected at construction time.

        Args:
            x: Tensor of shape [B, C * T, D] with ``C == in_chans``.
        Returns:
            Tensor with the same shape as ``x``.
        """
        _, num_tokens, _ = x.shape
        num_patches = num_tokens // self.in_chans

        if not self.do_channel_attn:
            # Each channel becomes its own sequence of T patches.
            per_channel = rearrange(x, 'B (C T) D -> (B C) T D', C=self.in_chans)
            per_channel = self.forward_attn(per_channel)
            return rearrange(per_channel, '(B C) T D -> B (C T) D', C=self.in_chans)

        # Each patch position becomes its own sequence of C channels.
        per_patch = rearrange(x, 'B (C T) D -> (B T) C D', C=self.in_chans)
        per_patch = self.forward_attn(per_patch)
        return rearrange(per_patch, '(B T) C D -> B (C T) D', T=num_patches)

    def forward_attn(self, x):
        """Multi-head self-attention over the middle axis of ``x`` ([B, N, C])."""
        batch, length, width = x.shape
        packed = self.qkv(x).reshape(batch, length, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)  # each [B, heads, N, head_dim]
        q = self.q_norm(q)
        k = self.k_norm(k)

        if self.fused_attn:
            drop_p = self.attn_drop.p if self.training else 0.
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
        else:
            scores = (q * self.scale) @ k.transpose(-2, -1)
            weights = self.attn_drop(scores.softmax(dim=-1))
            out = weights @ v

        # Merge heads back into the channel dimension before projecting.
        out = out.transpose(1, 2).reshape(batch, length, width)
        return self.proj_drop(self.proj(out))

class Attention(nn.Module):
    """Vanilla multi-head self-attention.

    Args:
        dim: Embedding dimension of the tokens.
        num_heads: Number of attention heads; must divide ``dim``.
        qkv_bias: Whether the fused QKV projection has a bias term.
        qk_norm: If True, apply ``norm_layer(head_dim)`` to queries and keys.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        norm_layer: Constructor for the optional query/key normalisation.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'Embedding dimension must be divisible by num_heads'
        per_head = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = per_head
        self.scale = per_head ** -0.5
        # Fused (scaled_dot_product_attention) path is switched off by default.
        self.fused_attn = False

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        if qk_norm:
            self.q_norm = norm_layer(per_head)
            self.k_norm = norm_layer(per_head)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the token axis of ``x`` ([B, N, C]) and return [B, N, C]."""
        batch, length, width = x.shape
        packed = self.qkv(x).reshape(batch, length, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)  # each [B, heads, N, head_dim]
        q = self.q_norm(q)
        k = self.k_norm(k)

        if self.fused_attn:
            drop_p = self.attn_drop.p if self.training else 0.
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
        else:
            scores = (q * self.scale) @ k.transpose(-2, -1)
            weights = self.attn_drop(scores.softmax(dim=-1))
            out = weights @ v

        # Merge heads back into the channel dimension before projecting.
        out = out.transpose(1, 2).reshape(batch, length, width)
        return self.proj_drop(self.proj(out))

class LayerScale(nn.Module):
    """Learnable per-channel scaling of the last dimension.

    Args:
        dim: Size of the last dimension of the tensors to be scaled.
        init_values: Initial value of every entry of the scale vector.
        inplace: If True, scale the input tensor in place and return it.
    """
    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied by the learned scale vector."""
        if self.inplace:
            # Mutates and returns the caller's tensor.
            return x.mul_(self.gamma)
        return x * self.gamma

class CustomAttentionBlock(nn.Module):
    """Pre-norm transformer block with a selectable attention module.

    Computes ``x + drop_path(ls(attn(norm1(x))))`` followed by
    ``x + drop_path(ls(mlp(norm2(x))))``.

    Args:
        dim: Embedding dimension of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Forwarded to the attention module (not used by 'rope').
        qk_norm: Forwarded to the attention module (not used by 'rope').
        proj_drop: Dropout for the attention output projection and the MLP.
        attn_drop: Dropout applied to the attention weights.
        init_values: Initial LayerScale value; a falsy value disables
            LayerScale on both residual branches.
        drop_path: Stochastic-depth rate; values <= 0 disable it.
        act_layer: Activation constructor passed to the MLP.
        norm_layer: Normalisation constructor for ``norm1``/``norm2`` and,
            where supported, the attention module.
        mlp_layer: Constructor of the feed-forward module.
        num_channels: Passed as ``in_chans`` to 'alternating' attention.
        attention_type: One of 'default', 'flash', 'alternating', 'rope'.
        block_idx: Passed to 'alternating' attention to choose its axis.

    Raises:
        ValueError: If ``attention_type`` is not one of the supported names.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
        drop_path: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        mlp_layer: nn.Module = Mlp,
        num_channels: int = 23,
        attention_type: str = 'default',  # 'default' | 'flash' | 'alternating' | 'rope'
        block_idx = 0,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)

        # Keyword arguments shared by the three in-file attention variants.
        shared_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )

        # 1) Select attention module based on attention_type
        if attention_type == 'default':
            self.attn = Attention(dim, **shared_kwargs)
        elif attention_type == 'flash':
            self.attn = FlashAttention(dim, **shared_kwargs)
        elif attention_type == 'alternating':
            self.attn = AlternatingAttention(
                dim, block_idx=block_idx, in_chans=num_channels, **shared_kwargs
            )
        elif attention_type == 'rope':
            # RoPEAttention takes its own, smaller set of arguments.
            self.attn = RoPEAttention(
                dim=dim,
                num_heads=num_heads,
                rope_dim=dim // num_heads,
                attn_drop=attn_drop,
                proj_drop=proj_drop
            )
        else:
            raise ValueError(f"Unknown attention_type {attention_type}")

        # 2) Residual-branch modifiers for the attention path
        if init_values:
            self.ls1 = LayerScale(dim, init_values=init_values)
        else:
            self.ls1 = nn.Identity()
        if drop_path > 0.:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()

        # 3) Feed-forward path
        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        if init_values:
            self.ls2 = LayerScale(dim, init_values=init_values)
        else:
            self.ls2 = nn.Identity()
        if drop_path > 0.:
            self.drop_path2 = DropPath(drop_path)
        else:
            self.drop_path2 = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the attention and MLP residual branches in sequence."""
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path2(self.ls2(mlp_out))
