"""
Based on: https://github.com/kuleshov-group/mdlm

"""

import math
import typing

import flash_attn
import flash_attn.layers.rotary
import huggingface_hub
import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

import logging
# Module-level logger; its own level is set to INFO (handlers are configured elsewhere).
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Enable jit fusion kernels
# NOTE(review): these are private ``torch._C`` hooks executed at import time, so
# they change process-wide TorchScript settings for every importer of this module.
# Presumably they switch off the profiling executor and allow fusion on CPU/GPU so
# that the ``@torch.jit.script`` helpers below can be fused -- confirm against the
# installed PyTorch version, since private APIs may change without notice.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


# ==== Small utils ====

def bias_dropout_add_scale(x: torch.Tensor,
                           bias: typing.Optional[torch.Tensor],
                           scale: torch.Tensor,
                           residual: typing.Optional[torch.Tensor],
                           prob: float,
                           training: bool) -> torch.Tensor:
    """Compute ``residual + scale * dropout(x + bias)``.

    Args:
        x: Input tensor.
        bias: Optional tensor added to ``x`` before dropout; skipped when None.
        scale: Multiplier applied to the dropout output.
        residual: Optional tensor added to the scaled result; skipped when None.
        prob: Dropout probability forwarded to ``F.dropout``.
        training: Forwarded to ``F.dropout``; when False dropout is a no-op.

    Returns:
        The scaled (and optionally residual-added) tensor.
    """
    # Plain ``if`` statements are used (rather than ternaries) because this
    # function is also compiled through the ``@torch.jit.script`` wrappers.
    pre_dropout = x
    if bias is not None:
        pre_dropout = x + bias
    gated = scale * F.dropout(pre_dropout, p=prob, training=training)
    if residual is not None:
        gated = residual + gated
    return gated


@torch.jit.script
def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    b: typing.Optional[torch.Tensor],
    s: torch.Tensor,
    r: typing.Optional[torch.Tensor],
    p: float
) -> torch.Tensor:
    """Scripted ``bias_dropout_add_scale`` with dropout in training mode.

    Args:
        x: Input tensor.
        b: Optional bias added before dropout.
        s: Scale applied after dropout.
        r: Optional residual added to the scaled result.
        p: Dropout probability.
    """
    return bias_dropout_add_scale(
        x=x, bias=b, scale=s, residual=r, prob=p, training=True)

@torch.jit.script
def bias_dropout_add_scale_fused_inference(
    x: torch.Tensor,
    b: typing.Optional[torch.Tensor],
    s: torch.Tensor,
    r: typing.Optional[torch.Tensor],
    p: float
) -> torch.Tensor:
    """Scripted ``bias_dropout_add_scale`` with dropout disabled (eval mode).

    Args:
        x: Input tensor.
        b: Optional bias added before the (inactive) dropout.
        s: Scale applied to the result.
        r: Optional residual added to the scaled result.
        p: Dropout probability; passed through but inactive since training=False.
    """
    return bias_dropout_add_scale(
        x=x, bias=b, scale=s, residual=r, prob=p, training=False)

def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply an affine modulation: ``x * (1 + scale) + shift``.

    A ``scale`` of zero and ``shift`` of zero leave ``x`` unchanged.
    """
    gain = 1 + scale
    return x * gain + shift


@torch.jit.script
def modulate_fused(x, shift, scale):
    """Scripted wrapper around ``modulate`` (``x * (1 + scale) + shift``)."""
    return modulate(x=x, shift=shift, scale=scale)


# ==== Rotary embed ====

class Rotary(torch.nn.Module):
    """Builds and caches cos/sin tables for rotary position embeddings.

    The tables have shape ``(1, seq_len, 3, 1, dim)``. Along axis 2 the first
    two slots carry the rotary cos/sin values while the third slot is set to
    cos=1 / sin=0, i.e. an identity rotation (presumably for the value slot of
    a packed q/k/v tensor -- confirm against the consumer).

    Args:
        dim: Rotary dimension; ``dim // 2`` inverse frequencies are created.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim, base=10_000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        # Cache of the most recently built tables (plain attributes, not buffers).
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        """Return ``(cos, sin)`` tables for the sequence length of ``x``.

        Args:
            x: Tensor whose size along ``seq_dim`` is the sequence length and
                whose device decides where the tables live.
            seq_dim: Axis of ``x`` holding the sequence length.

        Returns:
            Tuple ``(cos, sin)``, each of shape ``(1, seq_len, 3, 1, dim)``.
            The same cached tensors are returned on repeated calls with an
            unchanged sequence length and device.
        """
        seq_len = x.shape[seq_dim]
        # Rebuild when the length changes, and also when the cached tables live
        # on a different device than ``x`` (e.g. the model was moved after a
        # first forward pass); otherwise stale wrong-device tensors are returned.
        cache_stale = (
            seq_len != self.seq_len_cached
            or self.cos_cached is None
            or self.cos_cached.device != x.device
        )
        if cache_stale:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            # Third slot along axis 2 gets the identity rotation (cos=1, sin=0).
            self.cos_cached[:, :, 2, :, :].fill_(1.)
            self.sin_cached[:, :, 2, :, :].fill_(0.)
        return self.cos_cached, self.sin_cached


def apply_rotary_pos_emb(qkv, cos, sin):
    """Apply rotary embeddings to a packed qkv tensor via flash-attn.

    ``cos`` and ``sin`` are indexed as 5-D tables; only the ``[0, :, 0, 0]``
    slice and the first half of the last axis are handed to flash-attn.
    NOTE(review): the trailing underscore of ``apply_rotary_emb_qkv_`` suggests
    an in-place update of ``qkv`` -- confirm against the flash-attn version used.
    """
    cos_half_width = cos.shape[-1] // 2
    sin_half_width = sin.shape[-1] // 2
    cos_table = cos[0, :, 0, 0, :cos_half_width]
    sin_table = sin[0, :, 0, 0, :sin_half_width]
    return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos_table, sin_table)


# ==== Layers ====

class LayerNorm(nn.Module):
    """Bias-free layer norm over the last axis with a learnable gain.

    The normalization itself runs in float32 with CUDA autocast disabled; the
    gain (initialised to ones) is applied afterwards, broadcast as
    ``(1, 1, dim)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.dim = dim

    def forward(self, x):
        """Normalize ``x`` over its last ``dim`` features and scale by ``weight``."""
        # Force full precision for the statistics regardless of outer autocast.
        with torch.cuda.amp.autocast(enabled=False):
            normed = F.layer_norm(x.float(), (self.dim,))
        gain = self.weight.view(1, 1, -1)
        return normed * gain
        
        
# ==== Embedding Layer for Timesteps ====

class TimestepEmbedder(nn.Module):
    """
    Maps scalar timesteps to vectors: sinusoidal features followed by a
    two-layer MLP (Linear -> SiLU -> Linear).

    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        # Layers are created in the same order they appear in the Sequential.
        in_proj = nn.Linear(frequency_embedding_size, hidden_size, bias=True)
        activation = nn.SiLU()
        out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
        self.mlp = nn.Sequential(in_proj, activation, out_proj)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal embeddings for a batch of timesteps
          - t: 1-D Tensor with one (possibly fractional) value per batch element
          - dim: width of the returned embedding
          - max_period: sets the lowest frequency used
          - return: (N, dim) Tensor laid out as [cos | sin | optional zero pad]
        """
        half = dim // 2
        positions = torch.arange(start=0, end=half, dtype=torch.float32)
        # Frequencies are computed on CPU first, then moved to t's device.
        freqs = torch.exp(-math.log(max_period) * positions / half).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        pieces = [torch.cos(args), torch.sin(args)]
        if dim % 2:
            # Odd widths get one trailing zero column so the output has `dim` columns.
            pieces.append(torch.zeros_like(args[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, t):
        """Embed timesteps ``t`` and project them through the MLP."""
        sinusoid = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(sinusoid)



# === Core Model ===

class DDiTBlock(nn.Module):
    """Transformer block with adaLN conditioning and flash self-attention.

    The block runs, in order: modulated LayerNorm -> packed-qkv self-attention
    with rotary embeddings -> gated residual, then modulated LayerNorm -> MLP
    -> gated residual. The six modulation tensors (shift/scale/gate for the
    attention and MLP branches) come from a zero-initialised linear layer on
    the conditioning vector, so at initialisation both gates are zero and the
    block returns its input unchanged.

    Args:
        dim: Hidden width of the token stream.
        n_heads: Number of attention heads (``dim`` is split across them).
        cond_dim: Width of the conditioning vector ``c``.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        dropout: Dropout probability used in both gated residual branches.
    """

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        self.norm2 = LayerNorm(dim)

        mlp_hidden = int(dim * mlp_ratio)

        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden, bias=True),
            nn.GELU(approximate='tanh'),
            nn.Linear(mlp_hidden, dim, bias=True),
        )

        self.dropout = dropout

        # Zero init => shift = scale = gate = 0 at the start of training.
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def _bdas(self):
        """Pick the scripted bias-dropout-add-scale variant for the current mode."""
        return bias_dropout_add_scale_fused_train if self.training else bias_dropout_add_scale_fused_inference

    def forward(self, x, rotary_cos_sin, c, seqlens=None, image_features=None):
        """Run the block.

        Args:
            x: Token activations; unpacked as ``(B, S, H)``.
            rotary_cos_sin: ``(cos, sin)`` tables consumed by
                ``apply_rotary_pos_emb``.
            c: Conditioning input for ``adaLN_modulation``; the ``[:, None]``
                + ``chunk(6, dim=2)`` pattern assumes it is ``(B, cond_dim)``
                -- TODO confirm with callers.
            seqlens: Optional per-sequence lengths used to build the cumulative
                boundaries for the varlen attention kernel. When None every
                sequence is treated as full length ``S``.
            image_features: Unused here; accepted so subclasses can share the
                call signature.

        Returns:
            Tensor with the same ``(B, S, H)`` layout as ``x``.
        """
        B, S, _ = x.shape
        bdas = self._bdas()

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        # Self-attention
        x_skip = x
        x_norm = self.norm1(x)
        x_mod = modulate_fused(x_norm, shift_msa, scale_msa)

        qkv = self.attn_qkv(x_mod)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.n_heads)

        # Rotary tables are cast to the qkv dtype with autocast disabled.
        with torch.cuda.amp.autocast(enabled=False):
            cos, sin = rotary_cos_sin
            qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))

        qkv_packed = rearrange(qkv, 'b s three h d -> (b s) three h d')

        if seqlens is None:
            cu_seqlens = torch.arange(0, (B + 1) * S, step=S, dtype=torch.int32, device=qkv_packed.device)
            max_s = S
        else:
            # ``cumsum`` on integer tensors yields int64, but the flash-attn
            # varlen kernel requires int32 boundaries on the same device as
            # qkv, so cast/move explicitly.
            # NOTE(review): the boundaries index into the flattened ``(b s)``
            # layout above; confirm callers pass ``seqlens`` consistent with
            # that packing (padded rows are not removed here).
            cu_seqlens = torch.cat(
                [torch.tensor([0], device=seqlens.device, dtype=seqlens.dtype),
                 seqlens.cumsum(dim=-1)], dim=0
            ).to(device=qkv_packed.device, dtype=torch.int32)
            max_s = int(torch.max(seqlens).item())

        # Attention dropout is fixed at 0; attention is bidirectional (non-causal).
        attn_out = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
            qkv_packed, cu_seqlens, max_s, 0., causal=False
        )
        attn_out = rearrange(attn_out, '(b s) h d -> b s (h d)', b=B)

        x = bdas(self.attn_out(attn_out), None, gate_msa, x_skip, self.dropout)

        # MLP
        x_skip2 = x
        x_norm2 = self.norm2(x)
        x_mod2 = modulate_fused(x_norm2, shift_mlp, scale_mlp)
        x_mlp = self.mlp(x_mod2)
        x = bdas(x_mlp, None, gate_mlp, x_skip2, self.dropout)
        return x


class EmbeddingLayer(nn.Module):
    """Learned lookup table of shape ``(vocab_dim, dim)``.

    The table is initialised with Kaiming-uniform (``a = sqrt(5)``) and rows
    are selected by plain tensor indexing.
    """

    def __init__(self, dim, vocab_dim):
        super().__init__()
        table = torch.empty((vocab_dim, dim))
        nn.init.kaiming_uniform_(table, a=math.sqrt(5))
        self.embedding = nn.Parameter(table)

    def forward(self, x):
        """Return the embedding rows selected by index tensor ``x``."""
        table = self.embedding
        return table[x]


class DDitFinalLayer(nn.Module):
    """Output head: adaLN-modulated LayerNorm followed by a linear projection.

    Both the projection and the modulation layer start with all-zero weights
    and biases, so the head outputs zeros at initialisation.

    Args:
        hidden_size: Width of the incoming activations.
        out_channels: Width of the projected output.
        cond_dim: Width of the conditioning vector.
    """

    def __init__(self, hidden_size, out_channels, cond_dim):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)

        # Zero-initialise both linear layers.
        for layer in (self.linear, self.adaLN_modulation):
            layer.weight.data.zero_()
            layer.bias.data.zero_()

    def forward(self, x, c):
        """Modulate the normalized ``x`` with shift/scale from ``c`` and project."""
        shift_and_scale = self.adaLN_modulation(c)[:, None]
        shift, scale = shift_and_scale.chunk(2, dim=2)
        normed = self.norm_final(x)
        modulated = modulate_fused(normed, shift, scale)
        return self.linear(modulated)


# ==== DIT (OWT-compatible, optional image cond) ====

class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
    """
    Diffusion transformer over token indices, conditioned on a noise-level
    embedding and, optionally, on image features.

    When ``image_feature_dim`` is given, image features are projected by a
    small MLP to ``cond_dim`` and added to the noise-level embedding before
    being passed to every block and to the output layer.

    """
    def __init__(self, config, vocab_size: int, image_feature_dim: int = None):
        super().__init__()
        # Plain dicts are wrapped so attribute-style access works below.
        if isinstance(config, dict):
            config = omegaconf.OmegaConf.create(config)

        self.config = config
        self.vocab_size = vocab_size

        model_cfg = config.model
        hidden_size = model_cfg.hidden_size
        cond_dim = model_cfg.cond_dim
        n_heads = model_cfg.n_heads

        self.vocab_embed = EmbeddingLayer(hidden_size, vocab_size)
        self.sigma_map = TimestepEmbedder(cond_dim)
        self.rotary_emb = Rotary(hidden_size // n_heads)

        self.use_image_conditioning = image_feature_dim is not None
        if not self.use_image_conditioning:
            self.image_adapter = None
        else:
            # Projects image features into the conditioning space.
            self.image_adapter = nn.Sequential(
                nn.Linear(image_feature_dim, cond_dim, bias=True),
                nn.SiLU(),
                nn.Linear(cond_dim, cond_dim, bias=True),
            )

        self.blocks = nn.ModuleList()
        for _ in range(model_cfg.n_blocks):
            self.blocks.append(
                DDiTBlock(hidden_size,
                          n_heads,
                          cond_dim,
                          mlp_ratio=getattr(model_cfg, "mlp_ratio", 4.0),
                          dropout=model_cfg.dropout)
            )

        self.output_layer = DDitFinalLayer(hidden_size, vocab_size, cond_dim)
        self.scale_by_sigma = model_cfg.scale_by_sigma

        logger.info(f"DIT initialized – use_image_conditioning={self.use_image_conditioning}; "
                    f"cond_dim={config.model.cond_dim} (unchanged)")

    def forward(self, indices: torch.Tensor, sigma: torch.Tensor, image_features: typing.Optional[torch.Tensor] = None):
        """Run the model.

        Args:
            indices: Token index tensor fed to the vocabulary embedding.
            sigma: Noise-level values, one per batch element (consumed by
                ``TimestepEmbedder``).
            image_features: Optional image features. A 3-D tensor is
                mean-pooled over axis 1; any other rank is used as-is.
                Ignored when the model was built without image conditioning.

        Returns:
            Output of the final layer, with last axis of size ``vocab_size``.
        """
        x = self.vocab_embed(indices)
        c = F.silu(self.sigma_map(sigma))

        has_image = self.use_image_conditioning and (image_features is not None)
        if has_image:
            # Accept either global features or patch tokens; pool if needed
            pooled = image_features.mean(dim=1) if image_features.dim() == 3 else image_features
            c = c + self.image_adapter(pooled)

        rotary_cos_sin = self.rotary_emb(x)

        # bf16 autocast when the GPU supports it, otherwise float32 is requested.
        bf16_ok = hasattr(torch, "bfloat16") and torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        autocast_dtype = torch.bfloat16 if bf16_ok else torch.float32

        with torch.cuda.amp.autocast(dtype=autocast_dtype):
            for blk in self.blocks:
                x = blk(x, rotary_cos_sin, c, seqlens=None, image_features=image_features)
            x = self.output_layer(x, c)
        return x


class DDiTBlockWithCrossAttn(DDiTBlock):
    """DDiT block followed by an optional cross-attention over image features.

    The parent block runs first and unchanged. When cross-attention is active
    and image features are supplied, the token stream attends to the image
    features and the result is added as a residual. The cross-attention output
    projection is zero-initialised, so the extra branch contributes nothing at
    the start of training.

    Args:
        dim: Hidden width of the token stream.
        n_heads: Number of attention heads.
        cond_dim: Width of the conditioning vector.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        dropout: Dropout probability of the parent block.
        use_cross_attn: Request the cross-attention branch.
        image_feature_dim: Feature width of the image inputs; required for the
            cross-attention branch to be built.
    """

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4.0, dropout=0.1,
                 use_cross_attn=False, image_feature_dim=None):

        # Initialize parent class (original DDiTBlock)
        super().__init__(dim, n_heads, cond_dim, mlp_ratio, dropout)

        # The flag must reflect whether the layers were actually created:
        # forward() gates on it, so storing the raw request would lead to an
        # AttributeError when image_feature_dim is missing.
        cross_attn_enabled = bool(use_cross_attn) and image_feature_dim is not None
        if use_cross_attn and not cross_attn_enabled:
            logger.warning("use_cross_attn=True but image_feature_dim is None; "
                           "cross-attention is disabled for this block")
        self.use_cross_attn = cross_attn_enabled

        # Add cross-attention components if enabled
        if cross_attn_enabled:
            self.cross_attn_norm = LayerNorm(dim)
            self.cross_attn_q = nn.Linear(dim, dim, bias=False)
            self.cross_attn_kv = nn.Linear(image_feature_dim, 2 * dim, bias=False)
            self.cross_attn_out = nn.Linear(dim, dim, bias=False)

            # Initialize output to zero for stable training
            nn.init.zeros_(self.cross_attn_out.weight)

            logger.info(f"Added cross-attention to DDiT block (dim={dim}, img_dim={image_feature_dim})")

    def forward(self, x, rotary_cos_sin, c, seqlens=None, image_features=None):
        """Run the parent block, then the optional cross-attention branch.

        Args:
            x: Token activations, unpacked as ``(B, S, H)``.
            rotary_cos_sin: ``(cos, sin)`` tables for the parent block.
            c: Conditioning input for the parent block.
            seqlens: Optional sequence lengths forwarded to the parent block.
            image_features: Optional image features; a 3-D tensor is used as a
                sequence of key/value tokens, any other rank gets a length-1
                sequence axis inserted at position 1.

        Returns:
            Tensor with the same ``(B, S, H)`` layout as ``x``.
        """
        # Run original DDiTBlock forward pass first
        x = super().forward(x, rotary_cos_sin, c, seqlens, image_features)

        # Skip the extra branch unless it was built and features are provided
        if not self.use_cross_attn or image_features is None:
            return x

        B, S, H = x.shape
        head_dim = H // self.n_heads

        # Handle different image feature shapes
        if image_features.dim() == 3:  # patch features, already a sequence
            img_feat = image_features
        else:  # global features, expand to a length-1 sequence
            img_feat = image_features.unsqueeze(1)

        x_skip = x
        x_norm = self.cross_attn_norm(x)

        # Queries from tokens, keys/values from image features
        q = self.cross_attn_q(x_norm)
        kv = self.cross_attn_kv(img_feat)
        k, v = kv.chunk(2, dim=-1)

        # Reshape to (B, n_heads, length, head_dim) for multi-head attention
        q = q.view(B, S, self.n_heads, head_dim).transpose(1, 2)
        k = k.view(B, -1, self.n_heads, head_dim).transpose(1, 2)
        v = v.view(B, -1, self.n_heads, head_dim).transpose(1, 2)

        # Scaled dot-product attention
        scale = head_dim ** -0.5
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn_weights = F.softmax(attn_scores, dim=-1)

        cross_out = torch.matmul(attn_weights, v)
        cross_out = cross_out.transpose(1, 2).reshape(B, S, H)
        cross_out = self.cross_attn_out(cross_out)

        # Residual connection
        return x_skip + cross_out
