import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional, Any, List


class DropPath(nn.Module):
    """Stochastic depth: drop whole samples of a residual branch during training.

    In training mode each sample along dim 0 is kept with probability
    ``1 - drop_prob``. Kept samples are rescaled by ``1 / (1 - drop_prob)`` so the
    expected output equals the input. In eval mode, or when ``drop_prob`` is 0,
    the input tensor itself is returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample; must lie in ``[0, 1]``.

    Raises:
        ValueError: If ``drop_prob`` is outside ``[0, 1]``.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f'drop_prob must be in [0, 1], got {drop_prob}')
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One draw per sample, broadcast over all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        if keep_prob > 0.0:
            return x.div(keep_prob) * random_tensor
        # drop_prob == 1: every sample is dropped. Dividing by zero first would
        # turn the all-zero mask into NaN (inf * 0), so apply the mask directly.
        return x * random_tensor

    def extra_repr(self) -> str:
        return f'drop_prob={self.drop_prob}'


class SelfAttention(nn.Module):
    """Multi-head self-attention with optional cosine attention and a KV cache.

    Input ``x`` is ``[B, L, C]``. A single fused projection produces Q, K and V.
    Q and V have learnable biases; K's bias is a zero buffer, so it is not a
    trained parameter.

    With ``attn_l2_norm=True``, Q and K are L2-normalised per head. Q is then
    multiplied by a learned per-head factor, stored in log space (initialised to
    ``log(4)``) and clamped to at most ``log(100)`` before ``exp``. The extra
    logit scale is 1. Otherwise logits are scaled by ``0.25 / sqrt(head_dim)``.

    When caching is enabled (see ``kv_caching``), keys/values from each call are
    appended along the sequence axis and reused by later calls.
    """

    def __init__(
        self,
        block_idx: int,
        embed_dim: int = 768,
        num_heads: int = 12,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        attn_l2_norm: bool = True,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0

        self.block_idx = block_idx
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.attn_l2_norm = attn_l2_norm

        if not attn_l2_norm:
            self.scale = 0.25 / math.sqrt(self.head_dim)
        else:
            self.scale = 1.0
            init_log_mul = torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log()
            self.scale_mul_1H11 = nn.Parameter(init_log_mul, requires_grad=True)
            self.max_scale_mul = torch.log(torch.tensor(100.0)).item()

        # Fused QKV weight; the bias is assembled per call from q_bias / zero / v_bias.
        self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.q_bias = nn.Parameter(torch.zeros(embed_dim))
        self.v_bias = nn.Parameter(torch.zeros(embed_dim))
        self.register_buffer('zero_k_bias', torch.zeros(embed_dim))

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0 else nn.Identity()
        self.attn_drop = attn_drop

        self.caching = False
        self.cached_k = None
        self.cached_v = None

    def kv_caching(self, enable: bool):
        """Switch KV caching on or off; switching off drops any cached keys/values."""
        self.caching = enable
        if enable:
            return
        self.cached_k = self.cached_v = None

    def forward(self, x: torch.Tensor, attn_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch, seq_len, width = x.shape

        fused_bias = torch.cat([self.q_bias, self.zero_k_bias, self.v_bias])
        projected = F.linear(input=x, weight=self.mat_qkv.weight, bias=fused_bias)
        # [B, L, 3, H, D] -> [3, B, H, L, D]
        heads = projected.view(batch, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        query, key, value = heads.unbind(0)

        if self.attn_l2_norm:
            head_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()
            query = F.normalize(query, dim=-1).mul(head_mul)
            key = F.normalize(key, dim=-1)

        if self.caching:
            # Prepend previously cached keys/values along the sequence axis.
            if self.cached_k is not None:
                key = torch.cat([self.cached_k, key], dim=2)
                value = torch.cat([self.cached_v, value], dim=2)
            self.cached_k, self.cached_v = key, value

        logits = query @ key.transpose(-2, -1) * self.scale
        if attn_bias is not None:
            logits = logits + attn_bias

        weights = F.softmax(logits, dim=-1)
        if self.training and self.attn_drop > 0:
            weights = F.dropout(weights, p=self.attn_drop)

        # [B, H, L, D] -> [B, L, H, D] -> [B, L, C]
        context = (weights @ value).transpose(1, 2).reshape(batch, seq_len, width)
        return self.proj_drop(self.proj(context))

    def extra_repr(self) -> str:
        return f'attn_l2_norm={self.attn_l2_norm}, caching={self.caching}'


class FFN(nn.Module):
    """Two-layer GELU MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Input width.
        hidden_features: Hidden width; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        drop: Dropout probability shared by both dropout sites; a non-positive
            value uses ``nn.Identity`` instead.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0
    ):
        super().__init__()
        hidden = hidden_features or in_features
        width_out = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, width_out)
        self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class GeneAdaLNSelfAttn(nn.Module):
    """Self-attention + FFN block modulated by adaptive LayerNorm.

    The condition is turned into six vectors, in the order
    ``gamma1, gamma2, scale1, scale2, shift1, shift2``. Each sub-layer receives
    ``LN(x) * (1 + scale) + shift``, where LN has no affine parameters. Its output
    is multiplied by the matching ``gamma`` and added to the residual through
    ``drop_path``.

    With ``shared_aln=False`` the six vectors come from ``SiLU -> Linear`` applied
    to ``condition_embed``. With ``shared_aln=True`` they are
    ``ada_gss + condition_embed``; a 2-D condition is broadcast to all six slots.
    """

    def __init__(
        self,
        block_idx: int,
        embed_dim: int = 768,
        condition_dim: int = 768,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_eps: float = 1e-6,
        shared_aln: bool = False,
        attn_l2_norm: bool = True,
    ):
        super().__init__()

        self.block_idx = block_idx
        self.embed_dim = embed_dim
        self.condition_dim = condition_dim
        self.shared_aln = shared_aln

        self.attn = SelfAttention(
            block_idx=block_idx,
            embed_dim=embed_dim,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            attn_l2_norm=attn_l2_norm,
        )
        self.ffn = FFN(
            in_features=embed_dim,
            hidden_features=int(embed_dim * mlp_ratio),
            drop=drop_rate,
        )

        # Affine-free LayerNorm: scale/shift come from the condition instead.
        self.ln_wo_grad = nn.LayerNorm(embed_dim, eps=norm_eps, elementwise_affine=False)

        if shared_aln:
            self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
        else:
            self.ada_lin = nn.Sequential(
                nn.SiLU(inplace=False),
                nn.Linear(condition_dim, 6 * embed_dim),
            )

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def enable_kv_cache(self, enable: bool = True):
        """Forward the KV-cache switch to the attention sub-module."""
        self.attn.kv_caching(enable)

    def _modulation(self, condition_embed: torch.Tensor, batch: int, width: int):
        """Split the condition into the six modulation tensors (normally ``[B, 1, C]``)."""
        if not self.shared_aln:
            return self.ada_lin(condition_embed).view(batch, 1, 6, width).unbind(2)
        if condition_embed.dim() == 2:
            # [B, C] -> [B, 1, 1, C], broadcast across the six slots of ada_gss.
            condition_embed = condition_embed.unsqueeze(1).unsqueeze(2)
        return (self.ada_gss + condition_embed).unbind(2)

    def forward(
        self,
        x: torch.Tensor,                    # [B, L, C]
        condition_embed: torch.Tensor,      # [B, C] or [B, 1, 6, C]
        attn_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch, _, width = x.shape
        gamma1, gamma2, scale1, scale2, shift1, shift2 = self._modulation(condition_embed, batch, width)

        attn_in = self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1)
        x = x + self.drop_path(self.attn(attn_in, attn_mask).mul_(gamma1))

        ffn_in = self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2)
        x = x + self.drop_path(self.ffn(ffn_in).mul(gamma2))
        return x

    def extra_repr(self) -> str:
        return f'shared_aln={self.shared_aln}, block_idx={self.block_idx}'


class GeneAdaLNBeforeHead(nn.Module):
    """Final adaptive LayerNorm applied before an output head.

    ``SiLU -> Linear`` maps the condition to a ``(scale, shift)`` pair. The output
    is ``LN(x) * (1 + scale) + shift``, using a LayerNorm without affine
    parameters.
    """

    def __init__(self, embed_dim: int, condition_dim: int, norm_eps: float = 1e-6):
        super().__init__()
        self.embed_dim = embed_dim
        self.condition_dim = condition_dim

        self.ln_wo_grad = nn.LayerNorm(embed_dim, eps=norm_eps, elementwise_affine=False)
        self.ada_lin = nn.Sequential(
            nn.SiLU(inplace=False),
            nn.Linear(condition_dim, 2 * embed_dim),
        )

    def forward(self, x: torch.Tensor, condition_embed: torch.Tensor) -> torch.Tensor:
        # [*, 2*C] -> [N, 1, 2, C]; slot 0 is scale, slot 1 is shift.
        modulation = self.ada_lin(condition_embed).view(-1, 1, 2, self.embed_dim)
        scale = modulation[:, :, 0]
        shift = modulation[:, :, 1]
        return self.ln_wo_grad(x) * (1 + scale) + shift


class ConditionProcessor(nn.Module):
    """Fuse histology features and spatial coordinates into one condition vector.

    Histology features go through LayerNorm -> Linear -> GELU -> Dropout ->
    Linear -> LayerNorm. Coordinates go through Linear -> GELU -> Linear ->
    LayerNorm. A fixed sinusoidal encoding of coordinate columns 0 and 1 is then
    added, using ``spatial_hidden_dim // 2`` channels each. The two embeddings are
    concatenated and projected to ``condition_embed_dim``.

    Args:
        histology_dim: Width of the input histology feature vector.
        spatial_dim: Width of the coordinate input. The sinusoidal encoding reads
            columns 0 and 1, so at least two columns are required.
        condition_embed_dim: Width of the returned condition embedding.
        histology_hidden_dim: Width of the processed histology embedding.
        spatial_hidden_dim: Width of the processed spatial embedding.
        dropout: Dropout probability used in the histology branch and after the
            final projection.
    """

    def __init__(
        self,
        histology_dim: int = 1024,
        spatial_dim: int = 2,
        condition_embed_dim: int = 768,
        histology_hidden_dim: int = 512,
        spatial_hidden_dim: int = 256,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.histology_dim = histology_dim
        self.spatial_dim = spatial_dim
        self.condition_embed_dim = condition_embed_dim

        self.histology_processor = nn.Sequential(
            nn.LayerNorm(histology_dim),
            nn.Linear(histology_dim, histology_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(histology_hidden_dim, histology_hidden_dim),
            nn.LayerNorm(histology_hidden_dim)
        )

        self.spatial_processor = nn.Sequential(
            nn.Linear(spatial_dim, spatial_hidden_dim // 2),
            nn.GELU(),
            nn.Linear(spatial_hidden_dim // 2, spatial_hidden_dim),
            nn.LayerNorm(spatial_hidden_dim)
        )

        # Channels of sinusoidal encoding per coordinate axis (x and y).
        self.pos_encoding_dim = spatial_hidden_dim // 2
        div_term = torch.exp(torch.arange(0, self.pos_encoding_dim, 2).float() *
                             (-math.log(10000.0) / self.pos_encoding_dim))
        self.register_buffer('div_term', div_term)

        total_dim = histology_hidden_dim + spatial_hidden_dim
        self.final_projection = nn.Sequential(
            nn.Linear(total_dim, condition_embed_dim),
            nn.LayerNorm(condition_embed_dim),
            nn.Dropout(dropout)
        )

    def _sinusoidal(self, coord: torch.Tensor) -> torch.Tensor:
        """Encode a ``[B, 1]`` coordinate column as ``[B, pos_encoding_dim]`` sin/cos features."""
        # Dtype follows coord/div_term, so the encoding matches the module's precision.
        angles = coord * self.div_term  # [B, ceil(pos_encoding_dim / 2)]
        # Interleave: even columns hold sin, odd columns hold cos. When
        # pos_encoding_dim is odd there is one spare cos column; trim it.
        encoded = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1).flatten(1)
        return encoded[:, :self.pos_encoding_dim]

    def forward(
        self,
        histology_features: torch.Tensor,  # [B, histology_dim]
        spatial_coords: torch.Tensor       # [B, spatial_dim]
    ) -> torch.Tensor:                     # [B, condition_embed_dim]
        histology_embed = self.histology_processor(histology_features)  # [B, histology_hidden_dim]
        spatial_embed = self.spatial_processor(spatial_coords)           # [B, spatial_hidden_dim]

        pos_encoding = torch.cat(
            [self._sinusoidal(spatial_coords[:, 0:1]), self._sinusoidal(spatial_coords[:, 1:2])],
            dim=1,
        )
        # 2 * (spatial_hidden_dim // 2) is one short when spatial_hidden_dim is
        # odd; zero-pad the trailing channel so the sum is well defined.
        missing = spatial_embed.shape[1] - pos_encoding.shape[1]
        if missing > 0:
            pos_encoding = F.pad(pos_encoding, (0, missing))

        spatial_embed = spatial_embed + pos_encoding
        condition_features = torch.cat([histology_embed, spatial_embed], dim=1)
        return self.final_projection(condition_features)

    def extra_repr(self) -> str:
        return (f'histology_dim={self.histology_dim}, spatial_dim={self.spatial_dim}, '
                f'condition_embed_dim={self.condition_embed_dim}')


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input.

    The table is stored as a ``[max_len, 1, d_model]`` buffer ``pe``. Even
    channels hold ``sin(pos * freq)`` and odd channels hold ``cos(pos * freq)``.
    Any ``d_model`` (odd or even) is supported.

    NOTE(review): ``forward`` slices ``pe`` by ``x.size(0)`` and broadcasts over
    dim 1, so ``x`` is expected to be sequence-first (e.g. ``[S, B, d_model]``).
    Other modules in this file use ``[B, L, C]``; confirm against callers.
    """

    def __init__(self, d_model: int, max_len: int = 2000):
        super().__init__()

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd channels, which is one fewer than the number
        # of frequencies when d_model is odd; trim so the assignment fits.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe = pe.unsqueeze(1)  # [max_len, 1, d_model]

        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:x.size(0), :]


class GeneVARTransformer(nn.Module):
    """Deprecated model class.

    Any keyword arguments are accepted and ignored. Construction always fails
    with ``NotImplementedError("deprecated")``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        reason = "deprecated"
        raise NotImplementedError(reason)