import torch
import math
import torch.nn.functional as F
from torch import nn
from torch.jit import Final

from timm.layers.mlp import Mlp, SwiGLU
from timm.models.vision_transformer import Attention
from typing import Callable, Optional, List, Union
import numpy as np

#################################################################################
#                           Helper Functions from FiT                           #
#################################################################################


def modulate(x, shift, scale):
    """Apply a per-sample affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a singleton axis inserted at position 1 so
    that a per-sample vector broadcasts across ``x``'s second axis.
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b

def create_norm(norm_type: Optional[str], dim: int, eps: float = 1e-6):
    """Create a normalization layer by name.

    Args:
        norm_type: ``None`` or ``'none'`` for no normalization,
            ``'layernorm'`` for a LayerNorm without learnable parameters, or
            ``'w_layernorm'`` for a LayerNorm with a learnable elementwise
            weight and bias.
        dim: size of the normalized (last) dimension.
        eps: numerical-stability epsilon passed to ``nn.LayerNorm``.

    Returns:
        An ``nn.Module`` applying the requested normalization.

    Raises:
        ValueError: if ``norm_type`` is not recognized.
    """
    if norm_type is None or norm_type == 'none':
        return nn.Identity()
    if norm_type == 'layernorm':
        return nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
    if norm_type == 'w_layernorm':
        # The "w_" variant carries learnable affine parameters; previously it was
        # identical to 'layernorm'. NOTE: checkpoints saved before this fix have
        # no weight/bias entries for these norms and need strict=False to load.
        return nn.LayerNorm(dim, elementwise_affine=True, eps=eps)
    raise ValueError(f"Unknown norm type: {norm_type}")





#################################################################################
#                           Semantic Embedding Layer                            #
#################################################################################

class SemanticEmbedding(nn.Module):
    """
    Project an ``n_vocab``-dimensional input to ``embed_dim`` features with a
    single linear layer, then scale the result by ``sqrt(embed_dim)``.
    """
    def __init__(self, embed_dim, n_vocab):
        super().__init__()
        self.embed_dim = embed_dim
        self.fc = nn.Linear(n_vocab, embed_dim)

    def forward(self, x):
        scale = math.sqrt(self.embed_dim)
        projected = self.fc(x)
        return projected * scale


#################################################################################
#           Embedding Layers for Patches, Timesteps and Class Labels            #
#################################################################################

class PatchEmbedder(nn.Module):
    """
    Embed per-token latent features into the model width.

    A linear projection maps the last axis from ``input_dim`` to ``embed_dim``
    and is followed by ``norm_layer(embed_dim)`` when a norm factory is given,
    or by an identity otherwise.
    """
    def __init__(self,
        input_dim,
        embed_dim,
        bias: bool = True,
        norm_layer: Optional[Callable] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.proj = nn.Linear(input_dim, embed_dim, bias=bias)
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x):
        # presumably (B, L, input_dim) -> (B, L, embed_dim); only the last axis is projected
        return self.norm(self.proj(x))

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Each timestep is first mapped to a fixed sinusoidal feature vector of size
    ``frequency_embedding_size`` and then projected to ``hidden_size`` by a
    two-layer MLP with a SiLU activation.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional or integer-valued.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) Tensor of positional embeddings (cosines first,
                 then sines). Its dtype matches ``t`` when ``t`` is floating
                 point; integer timesteps yield the promoted floating dtype,
                 because casting back to an integer dtype would truncate the
                 sin/cos values to 0/1.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        # Type promotion makes this floating point even for integer t.
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad odd dims with one zero column so the output has exactly `dim` features.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        if t.is_floating_point():
            embedding = embedding.to(dtype=t.dtype)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb





#################################################################################
#                     Sin-Cos Positional Embedding Helpers                      #
#################################################################################

def get_2d_sincos_pos_embed(embed_dim, patches_h, patches_w, cls_token=False, extra_tokens=0):
    """
    Build 2D sine/cosine positional embeddings for a patches_h x patches_w grid.

    Args:
        embed_dim: Embedding dimension
        patches_h: Number of patches along the height
        patches_w: Number of patches along the width
        cls_token: When truthy (and extra_tokens > 0), prepend zero rows
        extra_tokens: Number of zero rows to prepend in that case

    Returns:
        Array of shape (patches_h * patches_w, embed_dim), preceded by
        ``extra_tokens`` all-zero rows when ``cls_token and extra_tokens > 0``.
    """
    rows = np.arange(patches_h, dtype=np.float32)
    cols = np.arange(patches_w, dtype=np.float32)
    # meshgrid(cols, rows): plane 0 holds the column (width) coordinate and
    # plane 1 the row (height) coordinate, each of shape (patches_h, patches_w).
    planes = np.stack(np.meshgrid(cols, rows), axis=0)
    planes = planes.reshape([2, 1, patches_h, patches_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, planes)

    if not (cls_token and extra_tokens > 0):
        return pos_embed
    leading_zeros = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([leading_zeros, pos_embed], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Combine two 1D sin-cos embeddings into a 2D positional embedding.

    Args:
        embed_dim: total output dimension; must be even. Each of the two
            coordinate planes receives ``embed_dim // 2`` dimensions.
        grid: array whose first axis selects a coordinate plane; ``grid[0]``
            and ``grid[1]`` are each flattened to M positions. (As built by
            ``get_2d_sincos_pos_embed``, ``grid[0]`` is the width coordinate.)

    Returns:
        Array of shape (M, embed_dim): the embedding of ``grid[0]`` in the
        first half of the columns, that of ``grid[1]`` in the second half.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    per_plane = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[plane]) for plane in (0, 1)]
    return np.concatenate(per_plane, axis=1)



def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Build 1D sine/cosine positional embeddings.

    Args:
        embed_dim: output dimension for each position; must be even.
        pos: array of positions to encode; it is flattened to shape (M,).

    Returns:
        Array of shape (M, embed_dim): sines of ``pos * inv_freq`` in the first
        half of the columns and cosines in the second half, where
        ``inv_freq[i] = 1 / 10000 ** (i / (embed_dim / 2))``.
    """
    assert embed_dim % 2 == 0

    exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.)
    inv_freq = 1. / 10000 ** exponents                 # (D/2,)
    angles = np.outer(pos.reshape(-1), inv_freq)      # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


def compute_3d_pos_embed(image_max_size_height, image_max_size_width, d_model, max_arch_len, patch_size):
    """
    Compute 3D positional embeddings: a 1D sin-cos index axis of length
    ``max_arch_len`` combined with a 2D sin-cos patch grid.

    Args:
        image_max_size_height: Height of the image
        image_max_size_width: Width of the image
        d_model: Embedding dimension
        max_arch_len: Maximum architecture length
        patch_size: Side length of each (square) patch; sizes are floor-divided

    Returns:
        Array of shape (max_arch_len * n_rows * n_cols, d_model) in which row
        ``a * n_rows * n_cols + p`` is ``PE_1d(a) + PE_2d(p)``.
    """
    n_rows = image_max_size_height // patch_size
    n_cols = image_max_size_width // patch_size
    n_patches = n_rows * n_cols

    grid_embed = get_2d_sincos_pos_embed(d_model, n_rows, n_cols)  # (n_patches, d_model)
    index_embed = get_1d_sincos_pos_embed_from_grid(d_model, np.arange(max_arch_len))  # (max_arch_len, d_model)

    # Broadcast-add so every index along max_arch_len gets its own copy of the
    # spatial grid embedding: (max_arch_len, n_patches, d_model).
    combined = index_embed[:, None, :] + grid_embed[None, :, :]

    # Flatten (index, patch) into one sequence axis, index-major.
    return combined.reshape(max_arch_len * n_patches, d_model)




#################################################################################
#                            FiT Components                              #
#################################################################################





class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    ``adaLN_modulation`` maps the conditioning vector ``c`` to six chunks
    (shift/scale/gate for the attention branch, then for the MLP branch).
    They modulate the normalized input of each residual branch and gate its
    output.

    Args:
        hidden_size: model width.
        num_heads: number of attention heads for timm's ``Attention``.
        mlp_ratio: feed-forward hidden width as a multiple of ``hidden_size``.
        swiglu: use a SwiGLU feed-forward instead of a GELU(tanh) ``Mlp``.
        swiglu_large: with ``swiglu``, use the full ``mlp_ratio`` width rather
            than 2/3 of it.
        norm_layer: norm type passed to ``create_norm`` for both pre-norms.
        q_norm, k_norm, qk_norm_weight: accepted but not used by this block.
        qkv_bias: bias in the attention QKV projection.
        ffn_bias: bias in the feed-forward layers.
        adaln_bias: bias in the adaLN modulation layers.
        adaln_type: 'normal', 'lora' or 'swiglu'; selects the architecture of
            ``adaLN_modulation``.
        adaln_lora_dim: bottleneck width; required when ``adaln_type='lora'``.
        **block_kwargs: forwarded to timm's ``Attention``.

    Raises:
        ValueError: if ``adaln_type`` is unknown, or is 'lora' without
            ``adaln_lora_dim``.
    """
    def __init__(self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        swiglu=True,
        swiglu_large=False,
        norm_layer: str = 'layernorm',
        q_norm: Optional[str] = None,
        k_norm: Optional[str] = None,
        qk_norm_weight: bool = False,
        qkv_bias=True,
        ffn_bias=True,
        adaln_bias=True,
        adaln_type='normal',
        adaln_lora_dim: Optional[int] = None,
        **block_kwargs
    ):
        super().__init__()
        # Validate up front: an unknown type previously left the block without
        # `adaLN_modulation`, which only surfaced as an AttributeError in forward().
        if adaln_type not in ('normal', 'lora', 'swiglu'):
            raise ValueError(f"Unknown adaln_type: {adaln_type}")
        if adaln_type == 'lora' and adaln_lora_dim is None:
            raise ValueError("adaln_lora_dim must be set when adaln_type='lora'")

        self.norm1 = create_norm(norm_layer, hidden_size)
        self.norm2 = create_norm(norm_layer, hidden_size)

        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            **block_kwargs
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        if swiglu:
            if swiglu_large:
                self.mlp = SwiGLU(in_features=hidden_size, hidden_features=mlp_hidden_dim, bias=ffn_bias)
            else:
                # Shrink to ~2/3 width so SwiGLU's gated pair keeps a parameter count close to a plain MLP.
                self.mlp = SwiGLU(in_features=hidden_size, hidden_features=(mlp_hidden_dim*2)//3, bias=ffn_bias)
        else:
            self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=ffn_bias)

        if adaln_type == 'normal':
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, 6 * hidden_size, bias=adaln_bias)
            )
        elif adaln_type == 'lora':
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, adaln_lora_dim, bias=adaln_bias),
                nn.Linear(adaln_lora_dim, 6 * hidden_size, bias=adaln_bias)
            )
        else:  # adaln_type == 'swiglu'
            self.adaLN_modulation = SwiGLU(
                in_features=hidden_size, hidden_features=(hidden_size//4)*3, out_features=6*hidden_size, bias=adaln_bias
            )

    def forward(self, x, c, mask=None, global_adaln=0.0):
        """
        Apply the modulated attention and MLP residual branches.

        Args:
            x: token features; presumably (B, L, hidden_size), since the
                modulation vectors are unsqueezed at dim 1. TODO confirm.
            c: conditioning; its modulation output is chunked along dim 1 into
                six (B, hidden_size) pieces.
            mask: ignored; timm's ``Attention`` is called without a mask.
            global_adaln: added to the modulation output before chunking;
                must broadcast against it.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.adaLN_modulation(c) + global_adaln).chunk(6, dim=1)
        # timm's Attention doesn't use mask parameter - process all tokens equally
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

class AWFinalLayer(nn.Module):
    """
    The final layer for architecture and weight processing.

    It normalizes ``x``, applies a shift/scale modulation derived from the
    conditioning ``c``, and projects with a ``hidden_size -> hidden_size``
    linear layer. ``adaln_type='swiglu'`` builds the modulation network as a
    SwiGLU; any other value (e.g. 'normal', 'lora') uses SiLU + Linear.
    """
    def __init__(self, hidden_size, norm_layer: str = 'layernorm', adaln_bias=True, adaln_type='normal'):
        super().__init__()
        self.norm_final = create_norm(norm_type=norm_layer, dim=hidden_size)
        self.linear = nn.Linear(hidden_size, hidden_size, bias=True)
        self.adaLN_modulation = (
            SwiGLU(in_features=hidden_size, hidden_features=hidden_size//2, out_features=2*hidden_size, bias=adaln_bias)
            if adaln_type == 'swiglu'
            else nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, 2 * hidden_size, bias=adaln_bias)
            )
        )

    def forward(self, x, c):
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))
