import torch
import torch.nn as nn
import numpy as np
import math
from dpm_model.DiT import TimestepEmbedder, DiTBlock, get_2d_sincos_pos_embed
from config.base_config import Config


class DiT_txt_trunc(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    1. No patchify and unpatchify
    2. tailor the classifier-condition

    The module owns a timestep embedder, a stack of ``config.DiT_blocks``
    ``DiTBlock`` layers, and three linear projections that map between
    ``config.embed_dim`` and the internal width (512).

    Config fields read by this class: ``embed_dim``, ``num_frames``,
    ``DiT_blocks`` and ``feature_align_init`` (one of ``'normal'``,
    ``'eye'``, ``'xavier_uniform'``).

    NOTE(review): ``forward`` only runs the timestep embedder and the blocks.
    ``feature_aligner_txt``, ``feature_aligner_vid``, ``feature_recover`` and
    ``pos_embed`` are created and initialised here but never applied inside
    this class -- presumably the caller applies them; confirm.
    """

    def __init__(self, config: Config):
        """Build the sub-modules and initialise their weights.

        Args:
            config: project configuration object; see the class docstring
                for the fields that are read.
        """
        super().__init__()
        self.config = config
        hidden_size = 512  # make consistent with CLIP
        depth = config.DiT_blocks   # number of DiT blocks
        num_heads = 16
        mlp_ratio = 4.0

        self.num_heads = num_heads

        # Projections from the external embedding width into the block width.
        self.feature_aligner_txt = nn.Linear(self.config.embed_dim, hidden_size)
        self.feature_aligner_vid = nn.Linear(self.config.embed_dim, hidden_size)

        # Projection from the block width back to the external embedding width.
        self.feature_recover = nn.Linear(hidden_size, self.config.embed_dim)

        self.t_embedder = TimestepEmbedder(hidden_size)
        # Frozen (requires_grad=False) buffer-like parameter, one row per frame.
        # NOTE(review): it stays all-zero in this class; nothing here fills it
        # with a sin-cos table -- confirm whether that is intended.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.config.num_frames, hidden_size),
            requires_grad=False,
        )

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.initialize_weights()

    def initialize_weights(self):
        """Initialise every sub-module's parameters.

        Order matters because later steps overwrite earlier ones:
        1. Xavier-uniform weights / zero bias for every ``nn.Linear``.
        2. Normal(std=0.02) weights for the timestep-embedding MLP.
        3. The scheme named by ``config.feature_align_init`` for the three
           alignment / recovery projections (bias always zero).
        4. Zero weights and bias for the last adaLN modulation layer of
           every block.

        Raises:
            NotImplementedError: if ``config.feature_align_init`` is not one
                of ``'normal'``, ``'eye'`` or ``'xavier_uniform'``.
        """
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Weight initialisers selectable through the config. The in-place
        # ``xavier_uniform_`` is used instead of the deprecated
        # ``xavier_uniform`` alias, which emits a warning on every call.
        weight_initializers = {
            'normal': lambda weight: nn.init.normal_(weight, std=0.02),
            'eye': nn.init.eye_,
            'xavier_uniform': nn.init.xavier_uniform_,
        }
        scheme = self.config.feature_align_init
        if scheme not in weight_initializers:
            raise NotImplementedError(
                f"Unsupported feature_align_init {scheme!r}; "
                f"expected one of {sorted(weight_initializers)}"
            )
        init_weight = weight_initializers[scheme]
        # Same layer order as before (txt, vid, recover) so that random-number
        # consumption, and therefore seeded results, are unchanged.
        for layer in (
            self.feature_aligner_txt,
            self.feature_aligner_vid,
            self.feature_recover,
        ):
            init_weight(layer.weight)
            nn.init.constant_(layer.bias, 0)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x, t):
        """Run ``x`` through every DiT block, conditioned on the timestep.

        Args:
            x: input passed to the first block; assumes the layout expected
                by ``DiTBlock`` with last dimension 512 -- TODO confirm.
            t: timesteps, forwarded as-is to ``self.t_embedder``.

        Returns:
            The output of the last block (or ``x`` unchanged when there are
            no blocks).
        """
        # The timestep embedding is the only conditioning signal.
        c = self.t_embedder(t)
        for block in self.blocks:
            x = block(x, c)
        return x

