import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import Final

from timm.layers import use_fused_attn
from timm.layers import Mlp, DropPath

from typing import Optional

import numpy as np


class GenreCrossAttention(nn.Module):
    """Multi-head cross-attention from input features to a fixed genre-embedding bank.

    Every row of ``x`` (shape ``(..., dim)``) is projected to a single query
    token, which attends over ``M`` key/value tokens derived from the ``(M, E)``
    array stored at ``ge_path`` (e.g. 512-d CLIP text features). The bank is
    registered as the ``genre_embed`` buffer.

    Key/value construction is selected by ``wkv``:

    * ``"linear"``: a learned ``nn.Linear(E, 2 * dim)`` maps each embedding to a
      key and a value.
    * ``"direct"``: each embedding is zero-padded from ``E`` to ``dim`` and used
      unchanged as both key and value (requires ``E <= dim``).

    Args:
        dim: feature size of ``x`` and of the output; must be divisible by
            ``num_heads``.
        ge_path: path to a ``.npy`` file holding a 2-D ``(M, E)`` array.
        num_heads: number of attention heads.
        qkv_bias: add a bias to the query and (``"linear"`` mode) key/value
            projections.
        qk_norm: apply ``norm_layer(head_dim)`` to queries and keys.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        norm_layer: normalization layer class used when ``qk_norm`` is True.
        wkv: ``"linear"`` or ``"direct"`` (see above).

    Raises:
        ValueError: if ``wkv`` is unknown, the embedding array is not 2-D, or
            ``"direct"`` mode is requested with ``E > dim``.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            ge_path: str,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
            wkv="linear",
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        # Reject typos up front; otherwise forward() would hit unbound k/v.
        if wkv not in ("linear", "direct"):
            raise ValueError(f"wkv must be 'linear' or 'direct', got {wkv!r}")
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        # Load the bank first so the key/value path is sized from its real width.
        genre_embed = torch.from_numpy(np.load(ge_path))
        if genre_embed.ndim != 2:
            raise ValueError(
                f"genre embeddings must be a 2-D (M, E) array, got shape {tuple(genre_embed.shape)}"
            )
        embed_dim = genre_embed.shape[1]
        if wkv == "direct" and embed_dim > dim:
            raise ValueError(
                f"wkv='direct' needs embedding width <= dim, got {embed_dim} > {dim}"
            )

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.wkv = wkv

        if wkv == "linear":
            self.kv = nn.Linear(embed_dim, dim * 2, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.register_buffer("genre_embed", genre_embed)

    def _genre_kv(self, ref: torch.Tensor):
        """Return ``(k, v)``, each ``(1, num_heads, M, head_dim)``, from the genre bank.

        The bank is cast to ``ref``'s dtype and device before use.
        """
        embed = self.genre_embed.to(ref)
        M = embed.shape[0]
        if self.wkv == "linear":
            kv = self.kv(embed).reshape(1, M, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)
            return k, v
        if self.wkv == "direct":
            # Padding is allocated alongside the bank (same device/dtype); a
            # negative width (E > dim) raises instead of silently cropping.
            pad = embed.new_zeros(M, self.num_heads * self.head_dim - embed.shape[1])
            kv = torch.cat((embed, pad), dim=-1)
            # 4-D (1, H, M, head_dim), matching the "linear" branch and q's rank.
            kv = kv.reshape(1, M, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
            return kv, kv
        raise ValueError(f"wkv must be 'linear' or 'direct', got {self.wkv!r}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend from each row of ``x`` to the genre bank.

        Args:
            x: tensor of shape ``(..., dim)``; every row is an independent query.

        Returns:
            Tensor with the same shape as ``x``.
        """
        *lead, C = x.shape

        # One query token per row: (R, H, 1, head_dim), R = number of rows.
        q = self.q(x).reshape(-1, 1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        q = self.q_norm(q)

        k, v = self._genre_kv(x)
        k = self.k_norm(k)

        # k/v carry batch size 1 and broadcast against the R query rows.
        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # (R, H, 1, head_dim) -> (R, 1, H, head_dim) -> (..., dim)
        x = x.transpose(1, 2).reshape(*lead, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class LayerScale(nn.Module):
    """Scale the last dimension of the input by a learnable vector ``gamma``.

    Args:
        dim: length of ``gamma`` (the size of the input's last dimension).
        init_values: initial value of every entry of ``gamma``.
        inplace: when True, ``forward`` multiplies the input in place with
            ``mul_`` and returns that same tensor.
    """

    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        initial = torch.ones(dim) * init_values
        self.gamma = nn.Parameter(initial)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.inplace:
            return x * self.gamma
        return x.mul_(self.gamma)

class Block_CA(nn.Module):
    """Pre-norm residual block around a single ``GenreCrossAttention``.

    ``forward`` returns ``x + drop_path(layer_scale(cross_attn(norm(x))))``.
    Layer scale is applied only when ``init_values`` is truthy, and drop-path
    only when ``drop_path > 0``; otherwise each step is an ``nn.Identity``.

    ``act_layer`` and ``mlp_layer`` are accepted but not used by this block.
    All attention-related arguments are forwarded to ``GenreCrossAttention``.
    """

    def __init__(
            self,
            dim: int,
            ge_path: str,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: nn.Module = nn.GELU,
            norm_layer: nn.Module = nn.LayerNorm,
            mlp_layer: nn.Module = Mlp,
            wkv="linear",
    ) -> None:
        super().__init__()

        attn_options = dict(
            ge_path=ge_path,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            wkv=wkv,
        )
        self.cross_attn = GenreCrossAttention(dim, **attn_options)

        self.cross_norm = norm_layer(dim)

        if init_values:
            self.cross_ls = LayerScale(dim, init_values=init_values)
        else:
            self.cross_ls = nn.Identity()

        if drop_path > 0.:
            self.cross_drop_path = DropPath(drop_path)
        else:
            self.cross_drop_path = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual branch: normalize -> attend -> scale -> stochastic depth.
        branch = self.cross_norm(x)
        branch = self.cross_attn(branch)
        branch = self.cross_ls(branch)
        branch = self.cross_drop_path(branch)
        return x + branch
