"""
torchrun --nnodes=1 --nproc_per_node=4 train.py

"""
import math
import numpy as np
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


from src.dit_modules import (
    TimestepEmbedder, PatchEmbedder, compute_3d_pos_embed, get_1d_sincos_pos_embed_from_grid,
    DiTBlock, AWFinalLayer,
    SemanticEmbedding
)


class NNiT(nn.Module):
    """
    Diffusion backbone that jointly denoises a network's architecture tokens
    and its weight matrices with a shared stack of DiT transformer blocks.

    Architecture tokens (one per layer) and weight patches (one sequence of
    patches per weight matrix) are embedded, given additive positional
    embeddings, concatenated along the sequence axis and processed together.
    Both streams are conditioned on a single vector obtained by merging the
    two diffusion-timestep embeddings (one for weights, one for architecture).

    NOTE(review): the original docstring mentioned "multimodal RoPE support";
    this class itself only adds precomputed positional embeddings. Any rotary
    embedding would have to live inside ``DiTBlock`` -- confirm.
    """

    def __init__(
        self,
        architecture_max_layer=6,  # e.g. input + 4 hidden + output
        architecture_n_vocab=5,  # e.g. input, output, 16, 32, 64
        weight_max_size=64,
        patch_size=8,
        hidden_size=256,
        depth=12,
        num_heads=16,
        mlp_ratio=4.0,
        learn_sigma=True,
        use_swiglu=True,
        use_swiglu_large=False,
        norm_type='layernorm',
        q_norm: Optional[str] = None,
        k_norm: Optional[str] = None,
        qk_norm_weight: bool = False,
        qkv_bias: bool = True,
        ffn_bias: bool = True,
        adaln_bias: bool = True,
        adaln_type: str = "normal",
        token_drop=0.0,
    ):
        """
        Build embedders, transformer blocks, output heads and positional buffers.

        Args:
            architecture_max_layer: Number of architecture tokens N; the model
                handles N - 1 weight matrices.
            architecture_n_vocab: Vocabulary size of the architecture tokens.
            weight_max_size: Maximum weight-matrix size, either an int (square)
                or a ``(height, width)`` pair. The patchified width is
                ``width + patch_size`` (the extra patch column is reserved for
                the bias, per the original comments).
            patch_size: Side length of the square weight patches.
            hidden_size: Transformer token width.
            depth: Number of ``DiTBlock`` layers.
            num_heads: Attention heads per block.
            mlp_ratio, use_swiglu, use_swiglu_large, norm_type, q_norm, k_norm,
            qk_norm_weight, qkv_bias, ffn_bias, adaln_bias, adaln_type:
                Forwarded unchanged to ``DiTBlock`` / ``AWFinalLayer``.
            learn_sigma: If True both output heads emit twice as many channels.
            token_drop: Probability passed to ``nn.Dropout`` applied to the
                embedded tokens.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.architecture_max_layer = architecture_max_layer
        self.architecture_n_vocab = architecture_n_vocab

        if isinstance(weight_max_size, (tuple, list)):
            self.weight_max_height, self.weight_max_width = weight_max_size
        else:
            self.weight_max_height = weight_max_size
            self.weight_max_width = weight_max_size

        # mm-diffusion does not learn sigma, but predicting it tends to
        # stabilize training, so both heads double their output channels.
        if learn_sigma:
            self.weight_out_channels = 1 * 2
            self.architecture_out_n_vocab = architecture_n_vocab * 2
        else:
            self.weight_out_channels = 1
            self.architecture_out_n_vocab = architecture_n_vocab

        self.patch_size = patch_size
        self.learn_sigma = learn_sigma
        self.depth = depth
        self.num_heads = num_heads
        self.adaln_type = adaln_type

        # Dropout applied right after the embeddings. nn.Dropout zeroes
        # individual elements, not whole tokens.
        self.token_dropout = nn.Dropout(p=token_drop)

        # --- Timestep embedding ---
        self.w_t_embedder = TimestepEmbedder(self.hidden_size)
        self.a_t_embedder = TimestepEmbedder(self.hidden_size)

        # Fuses the concatenated (weight, architecture) timestep embeddings
        # back down to a single hidden_size conditioning vector.
        self.merge_mlp = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size * 2),
            nn.SiLU(),
            nn.Linear(self.hidden_size * 2, self.hidden_size),
        )

        # --- Architecture embedding / head ---
        self.arch_opEmb = SemanticEmbedding(self.hidden_size, self.architecture_n_vocab)
        self.arch_final_layer = nn.Linear(self.hidden_size, self.architecture_out_n_vocab, bias=True)

        # --- Weight embedding / head ---
        # Linear patch embedding (like FiT).
        img_height = self.weight_max_height  # output neurons
        img_width = self.weight_max_width + self.patch_size  # input neurons + bias column

        # Linear projection: (patch_size * patch_size * 1) -> hidden_size
        self.weight_embedder = PatchEmbedder(
            input_dim=self.patch_size * self.patch_size * 1,  # flattened single-channel patch
            embed_dim=self.hidden_size,
            bias=True
        )
        self.weight_final_layer = nn.Linear(
            self.hidden_size,
            self.weight_out_channels * self.patch_size * self.patch_size,
            bias=True,
        )

        # Patch grid of one (rectangular) weight matrix.
        self.patches_h = img_height // self.patch_size
        self.patches_w = img_width // self.patch_size
        self.num_patches = self.patches_h * self.patches_w

        # Sequence lengths: N architecture tokens, N - 1 weight matrices of
        # num_patches patches each.
        self.weight_seq_len = (self.architecture_max_layer - 1) * self.num_patches
        self.arch_seq_len = self.architecture_max_layer

        # DiT transformer blocks
        self.blocks = nn.ModuleList([DiTBlock(
            hidden_size, num_heads, mlp_ratio=mlp_ratio, swiglu=use_swiglu, swiglu_large=use_swiglu_large,
            norm_layer=norm_type, q_norm=q_norm, k_norm=k_norm, qk_norm_weight=qk_norm_weight,
            qkv_bias=qkv_bias, ffn_bias=ffn_bias, adaln_bias=adaln_bias, adaln_type=adaln_type
        ) for _ in range(depth)])

        # Final layer processes the concatenated (architecture + weight) tokens.
        self.final_layer = AWFinalLayer(hidden_size, norm_layer=norm_type, adaln_bias=adaln_bias, adaln_type=adaln_type)

        self.initialize_positional_embeddings()
        self.initialize_weights()

    @staticmethod
    def _zero_linear(layer):
        """Zero a linear layer's weight and, when it has one, its bias."""
        nn.init.constant_(layer.weight, 0)
        if layer.bias is not None:
            nn.init.constant_(layer.bias, 0)

    def initialize_weights(self, pretrain_ckpt=None, ignore=None):
        """
        Apply the DiT-style initialization scheme to all sub-modules.

        Every ``nn.Linear`` gets Xavier-uniform weights and zero bias, the
        timestep-embedder MLPs get N(0, 0.02) weights, and the adaLN
        modulation / final projection layers are zeroed.

        Args:
            pretrain_ckpt: Accepted for backward compatibility; currently
                unused because checkpoint loading is not wired in here.
            ignore: Accepted for backward compatibility; currently unused for
                the same reason.
        """
        # Generic init for every linear layer in the model.
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Weight patch embedder (linear projection).
        nn.init.xavier_uniform_(self.weight_embedder.proj.weight)
        if self.weight_embedder.proj.bias is not None:
            nn.init.constant_(self.weight_embedder.proj.bias, 0)

        # Timestep embedding MLPs (weights embedder first, then architecture,
        # to keep the random-number consumption order stable).
        for embedder in (self.w_t_embedder, self.a_t_embedder):
            nn.init.normal_(embedder.mlp[0].weight, std=0.02)
            nn.init.normal_(embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in the DiT blocks. The bias may be
        # absent (e.g. adaln_bias=False), hence the bias-aware helper.
        for block in self.blocks:
            if self.adaln_type == 'normal':
                self._zero_linear(block.adaLN_modulation[-1])
            elif self.adaln_type == 'swiglu':
                self._zero_linear(block.adaLN_modulation.fc2)

        # This class does not create a global adaLN modulation module, so only
        # zero it if one has been attached; the unguarded access used to raise
        # AttributeError for adaln_type='lora'.
        # NOTE(review): 'lora' likely needs a global modulation module to be
        # built and passed to the blocks -- confirm against DiTBlock.
        global_modulation = getattr(self, 'global_adaLN_modulation', None)
        if self.adaln_type == 'lora' and global_modulation is not None:
            self._zero_linear(global_modulation[-1])

        # Zero-out output layers.
        if self.adaln_type == 'swiglu':
            self._zero_linear(self.final_layer.adaLN_modulation.fc2)
        else:   # adaln_type in ['normal', 'lora']
            self._zero_linear(self.final_layer.adaLN_modulation[-1])
        self._zero_linear(self.final_layer.linear)

    def initialize_positional_embeddings(self):
        """
        Precompute the positional embeddings and register them as buffers.

        ``weight_pos_embed`` covers the weight patches (layer, height, width)
        and ``arch_pos_embed`` covers the architecture tokens (layer index).
        """
        # 3D positional embeddings for weight patches. For N architecture
        # layers there are N - 1 weight matrices.
        weight_pos_embed = compute_3d_pos_embed(
            image_max_size_height=self.weight_max_height,
            image_max_size_width=self.weight_max_width + self.patch_size,
            d_model=self.hidden_size,
            max_arch_len=self.architecture_max_layer - 1,  # N - 1 weight matrices for N architecture tokens
            patch_size=self.patch_size
        )
        # forward() adds this buffer to a (B, weight_seq_len, hidden_size)
        # tensor, so it is expected to be (weight_seq_len, hidden_size).
        self.register_buffer('weight_pos_embed', torch.from_numpy(weight_pos_embed).float())

        # 1D positional embeddings for architecture tokens.
        # NOTE(review): an earlier comment said these are scaled by sqrt(2) to
        # match the 3D embedding norm, but no scaling is applied -- confirm
        # which behavior is intended.
        arch_pos_embed = get_1d_sincos_pos_embed_from_grid(
            embed_dim=self.hidden_size,
            pos=np.arange(self.architecture_max_layer, dtype=np.float32)
        )
        self.register_buffer('arch_pos_embed', torch.from_numpy(arch_pos_embed).float())

    def ckpt_wrapper(self, module):
        """Return a plain function that forwards its positional inputs to ``module``."""
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs
        return ckpt_forward

    def unpatchify(self, x):
        """
        Reassemble per-patch predictions into full weight matrices.

        Args:
            x: (B, S * patches_h * patches_w, patch_size**2 * weight_out_channels)
                tensor, with S = architecture_max_layer - 1 and each token laid
                out as (patch_row, patch_col, channel).

        Returns:
            (B, S, weight_out_channels, patches_h * patch_size,
            patches_w * patch_size) tensor.
        """
        B, total_patches, _ = x.shape
        # N architecture layers need N - 1 weight matrices to connect them.
        S = self.architecture_max_layer - 1
        num_patches_per_sequence = total_patches // S
        c = self.weight_out_channels
        p = self.patch_size

        # Use the patch grid computed during initialization.
        h = self.patches_h
        w = self.patches_w
        assert h * w == num_patches_per_sequence, f"Patch dimensions mismatch: {h}*{w}={h*w} != {num_patches_per_sequence}"

        x = x.reshape(B, S, h, w, p, p, c)
        # Move channels forward and interleave grid/patch axes so that the
        # final reshape yields contiguous rows (h*p) and columns (w*p).
        x = torch.einsum('bshwpqc->bschpwq', x)
        x = x.reshape(B, S, c, h*p, w*p)
        return x

    def forward(self, architecture, weight, t_arch, t_weight, **kwargs):
        """
        Predict the denoising targets for architecture tokens and weights.

        Args:
            architecture: Architecture inputs consumed by the semantic
                embedder; its embedding must have architecture_max_layer
                tokens (presumably a (B, N, n_vocab) tensor -- TODO confirm).
            weight: (B, S, 1, H, W) tensor of weight inputs, where
                S = architecture_max_layer - 1 and H, W are multiples of
                patch_size matching the positional-embedding grid.
            t_arch: (B,) tensor of diffusion timesteps for the architecture.
            t_weight: (B,) tensor of diffusion timesteps for the weights.
            **kwargs: Accepted for interface compatibility and ignored; in
                particular 'arch_mask' / 'weight_mask' are not used and the
                transformer blocks run without masking.

        Returns:
            Tuple ``(arch_out, weight_out)``: ``arch_out`` has
            architecture_out_n_vocab features per architecture token, and
            ``weight_out`` is the unpatchified
            (B, S, weight_out_channels, H, W) weight prediction.
        """
        # --- Timestep embedding ---
        t_w = self.w_t_embedder(t_weight.float())
        t_a = self.a_t_embedder(t_arch.float())
        condition_emb = self.merge_mlp(torch.cat([t_w, t_a], dim=-1))

        # --- Architecture embedding ---
        emb_arch = self.arch_opEmb(architecture)
        emb_arch = emb_arch + self.arch_pos_embed.unsqueeze(0)
        emb_arch = self.token_dropout(emb_arch)

        # --- Weight embedding ---
        B, S, C, H, W = weight.shape
        p = self.patch_size
        h_patches = H // p
        w_patches = W // p

        # Patchify: (B, S, C, H, W) -> (B, S, h_p, w_p, p, p, C), i.e. each
        # token is flattened as (patch_row, patch_col, channel), the layout
        # unpatchify() inverts.
        weight_patches = weight.reshape(B, S, C, h_patches, p, w_patches, p)
        weight_patches = weight_patches.permute(0, 1, 3, 5, 4, 6, 2)

        # Embed every matrix independently: (B*S, num_patches, p^2*C).
        emb_weight = weight_patches.reshape(B * S, h_patches * w_patches, p * p * C)
        emb_weight = self.weight_embedder(emb_weight)

        # Merge all matrices of a sample into one token sequence.
        _, num_spatial, dim = emb_weight.shape
        emb_weight = emb_weight.reshape(B, S * num_spatial, dim)
        emb_weight = emb_weight + self.weight_pos_embed.unsqueeze(0)
        emb_weight = self.token_dropout(emb_weight)

        # --- Concatenate: architecture tokens first, then weight tokens ---
        concatenated_emb = torch.cat([emb_arch, emb_weight], dim=1)

        # --- DiT transformer blocks (no masking) ---
        for block in self.blocks:
            concatenated_emb = block(
                concatenated_emb,
                condition_emb,
                mask=None
            )

        # --- Final layer ---
        aw_out = self.final_layer(concatenated_emb, condition_emb)
        assert emb_weight.size(1) + emb_arch.size(1) == aw_out.size(1), f"Split mismatch: {emb_weight.size(1)} + {emb_arch.size(1)} != {aw_out.size(1)}"
        # Split in the same order the streams were concatenated.
        arch_out, weight_out = torch.split(aw_out, [emb_arch.size(1), emb_weight.size(1)], dim=1)

        # --- Architecture head ---
        arch_out = self.arch_final_layer(arch_out)

        # --- Weight head ---
        weight_out = self.weight_final_layer(weight_out)
        weight_out = self.unpatchify(weight_out)

        return arch_out, weight_out
    

    