from typing import Any, Dict, List, Optional

import torch
import os
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import json
from omegaconf import ListConfig

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import is_torch_version

from .modeling_normalization import AdaLayerNormContinuous
from .modeling_embedding import CombinedTimestepTextProjEmbeddings
from .modeling_flux_block import FluxTransformerBlock, FluxSingleTransformerBlock

from safetensors.torch import load_file



def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Build per-position 2x2 rotation matrices for rotary position embedding.

    Args:
        pos: 2-D tensor of positions, shape ``(batch, seq)``.
        dim: Number of channels covered by this axis; must be even.
        theta: Base of the geometric frequency progression.

    Returns:
        A float32 tensor of shape ``(batch, seq, dim // 2, 2, 2)``. For every
        position/frequency pair the trailing 2x2 block is
        ``[[cos, -sin], [sin, cos]]`` of ``pos * omega``.
    """
    assert dim % 2 == 0, "The dimension must be even."

    # One frequency per channel pair, built in float64 before the final cast.
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**exponents)

    # The two-way unpack also enforces that `pos` is exactly 2-D.
    batch_size, _ = pos.shape
    angles = torch.einsum("...n,d->...nd", pos, omega)
    cos_part, sin_part = torch.cos(angles), torch.sin(angles)

    # Flattened rotation-matrix entries in row-major order, reshaped to 2x2.
    rotation = torch.stack([cos_part, -sin_part, sin_part, cos_part], dim=-1)
    return rotation.view(batch_size, -1, dim // 2, 2, 2).float()


class EmbedND(nn.Module):
    """Multi-axis rotary embedding built from one ``rope`` table per id axis.

    Attributes are stored as given: ``dim`` (not read by ``forward``),
    ``theta`` (frequency base passed to ``rope``) and ``axes_dim`` (channel
    count assigned to each id axis).
    """

    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
        super().__init__()
        self.dim, self.theta, self.axes_dim = dim, theta, axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Return the concatenated rotation tables for ``ids``.

        The last dimension of ``ids`` indexes the axes; axis ``k`` is embedded
        with ``axes_dim[k]`` channels. The per-axis tables are concatenated on
        dim ``-3`` and a singleton dimension is inserted at index 2.
        """
        per_axis = []
        for axis in range(ids.shape[-1]):
            per_axis.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))
        stacked = torch.cat(per_axis, dim=-3)
        return stacked.unsqueeze(2)


class PyramidFluxTransformer(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        axes_dims_rope: List[int] = [16, 24, 24],
        use_flash_attn: bool = False,
        use_temporal_causal: bool = True,
        interp_condition_pos: bool = True,
        use_gradient_checkpointing: bool = False,
        gradient_checkpointing_ratio: float = 0.6,
        use_audio_cross_attn: bool = False,
        init_additional_args: Dict[str, Any] = {},
        audio_input_dim: int = 1024,
    ):
        """Build the embedders, both transformer stacks and the output head.

        Args:
            patch_size: Only used to size ``proj_out``; ``self.patch_size`` is
                set to 2 at the end of construction regardless of this value.
            in_channels: Input width of ``x_embedder``; also ``out_channels``.
            num_layers: Number of ``FluxTransformerBlock`` layers.
            num_single_layers: Number of ``FluxSingleTransformerBlock`` layers.
            attention_head_dim: Per-head width.
            num_attention_heads: Head count; ``inner_dim`` is heads * head dim.
            joint_attention_dim: Input width of ``context_embedder``.
            pooled_projection_dim: Forwarded to the timestep/text embedder.
            axes_dims_rope: Per-axis channel split handed to ``EmbedND``.
            use_flash_attn: Forwarded to every block.
            use_temporal_causal: Enables the temporal causal mask in
                ``merge_input``.
            interp_condition_pos: Not referenced in this constructor.
            use_gradient_checkpointing: Initial ``gradient_checkpointing`` flag.
            gradient_checkpointing_ratio: Fraction of each stack that is
                checkpointed during training.
            use_audio_cross_attn: Adds ``audio_embedder`` and enables audio
                cross-attention in the selected layers.
            init_additional_args: May hold ``"audio_layer_idx"``, a mapping
                with optional ``"single"`` / ``"mmdit"`` index lists. When it
                is absent or empty, every single-stream layer gets audio
                cross-attention and no dual-stream layer does.
            audio_input_dim: Width of all three ``audio_embedder`` layers.
        """
        super().__init__()

        cfg = self.config
        self.out_channels = in_channels
        self.inner_dim = cfg.num_attention_heads * cfg.attention_head_dim
        self.use_audio_cross_attn = use_audio_cross_attn

        # --- embedders (construction order is kept to keep RNG use stable) ---
        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim,
            pooled_projection_dim=cfg.pooled_projection_dim,
        )
        self.context_embedder = nn.Linear(cfg.joint_attention_dim, self.inner_dim)
        self.x_embedder = torch.nn.Linear(cfg.in_channels, self.inner_dim)
        if self.use_audio_cross_attn:
            self.audio_embedder = nn.Sequential(
                nn.Linear(audio_input_dim, audio_input_dim),
                nn.GELU(approximate='tanh'),
                nn.Linear(audio_input_dim, audio_input_dim),
            )

        # --- pick the layers that receive audio cross-attention ---
        single_audio_layer_idx_list = []
        mmdit_audio_layer_idx_list = []
        if self.use_audio_cross_attn:
            audio_layer_idx = (init_additional_args or {}).get("audio_layer_idx")
            if audio_layer_idx:
                single_audio_layer_idx_list = list(audio_layer_idx.get("single", []))
                mmdit_audio_layer_idx_list = list(audio_layer_idx.get("mmdit", []))
            else:
                single_audio_layer_idx_list = list(range(cfg.num_single_layers))
                mmdit_audio_layer_idx_list = []

            print(f"single_audio_layer_idx_list {single_audio_layer_idx_list}")
            print(f"mmdit_audio_layer_idx_list {mmdit_audio_layer_idx_list}")

            assert single_audio_layer_idx_list or mmdit_audio_layer_idx_list, "single_audio_layer_idx_list and mmdit_audio_layer_idx_list cannot be both empty"

        assert isinstance(single_audio_layer_idx_list, (list, ListConfig)), f"single_audio_layer_idx_list must be a list, {single_audio_layer_idx_list}"
        assert isinstance(mmdit_audio_layer_idx_list, (list, ListConfig)), f"mmdit_audio_layer_idx_list must be a list, {mmdit_audio_layer_idx_list}"

        # --- dual-stream blocks ---
        mmdit_blocks = []
        for layer_idx in range(cfg.num_layers):
            mmdit_blocks.append(
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=cfg.num_attention_heads,
                    attention_head_dim=cfg.attention_head_dim,
                    use_flash_attn=use_flash_attn,
                    use_audio_cross_attn=layer_idx in mmdit_audio_layer_idx_list,
                )
            )
        self.transformer_blocks = nn.ModuleList(mmdit_blocks)

        # --- single-stream blocks ---
        single_blocks = []
        for layer_idx in range(cfg.num_single_layers):
            single_blocks.append(
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=cfg.num_attention_heads,
                    attention_head_dim=cfg.attention_head_dim,
                    use_flash_attn=use_flash_attn,
                    use_audio_cross_attn=layer_idx in single_audio_layer_idx_list,
                )
            )
        self.single_transformer_blocks = nn.ModuleList(single_blocks)

        # --- output head ---
        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = use_gradient_checkpointing
        self.gradient_checkpointing_ratio = gradient_checkpointing_ratio

        self.use_temporal_causal = use_temporal_causal
        if self.use_temporal_causal:
            print("Using temporal causal attention")

        self.use_flash_attn = use_flash_attn
        if self.use_flash_attn:
            print("Using Flash attention")

        # Patchify/unpatchify always use 2x2 patches, independent of `patch_size`.
        self.patch_size = 2

        self.initialize_weights()

    def initialize_weights(self):
        """Apply the model's weight-initialisation scheme in place.

        Order of operations:
          1. Xavier-uniform weights and zero biases for every Linear/Conv2d/
             Conv3d sub-module.
          2. Normal(std=0.02) weights for the timestep/text embedder linears
             and for ``context_embedder``.
          3. Zeros for the modulation linears of every block, for
             ``norm_out.linear`` and for ``proj_out``.
        """
        def _xavier_init(module):
            if not isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_xavier_init)

        # Conditioning projections use a small normal init instead of Xavier.
        embed = self.time_text_embed
        conditioning_linears = (
            embed.timestep_embedder.linear_1,
            embed.timestep_embedder.linear_2,
            embed.text_embedder.linear_1,
            embed.text_embedder.linear_2,
            self.context_embedder,
        )
        for linear in conditioning_linears:
            nn.init.normal_(linear.weight, std=0.02)

        # Zero the modulation layers of the dual-stream blocks.
        for block in self.transformer_blocks:
            for modulation in (block.norm1.linear, block.norm1_context.linear):
                nn.init.constant_(modulation.weight, 0)
                nn.init.constant_(modulation.bias, 0)

        # Zero the modulation layers of the single-stream blocks.
        for block in self.single_transformer_blocks:
            nn.init.constant_(block.norm.linear.weight, 0)
            nn.init.constant_(block.norm.linear.bias, 0)

        # Zero the output head.
        for head in (self.norm_out.linear, self.proj_out):
            nn.init.constant_(head.weight, 0)
            nn.init.constant_(head.bias, 0)


    @torch.no_grad()
    def _prepare_image_ids(self, batch_size, temp, height, width, train_height, train_width, device, start_time_stamp=0):
        """Build ``(time, row, col)`` position ids for one clip.

        Args:
            batch_size: Number of copies stacked on the leading dimension.
            temp, height, width: Token grid size of the clip.
            train_height, train_width: Reference grid size. When the clip grid
                differs, the reference coordinates are linearly resampled to
                the clip's length along that axis.
            device: Device the result is moved to.
            start_time_stamp: Value of the first temporal id.

        Returns:
            Tensor of shape ``(batch_size, temp * height * width, 3)``.
        """
        def _axis_positions(target_len, train_len):
            # Coordinates 0..train_len-1, resampled to `target_len` if needed.
            base = torch.arange(train_len)
            if target_len == train_len:
                return base.float()
            return F.interpolate(base[None, None, :].float(), target_len, mode='linear').squeeze(0, 1)

        ids = torch.zeros(temp, height, width, 3)

        # Channel 0: temporal index, offset by `start_time_stamp`.
        frame_index = torch.arange(start_time_stamp, start_time_stamp + temp)
        ids[..., 0] = ids[..., 0] + frame_index[:, None, None]

        # Channel 1: row coordinate.
        row_pos = _axis_positions(height, train_height)
        ids[..., 1] = ids[..., 1] + row_pos[None, :, None]

        # Channel 2: column coordinate.
        col_pos = _axis_positions(width, train_width)
        ids[..., 2] = ids[..., 2] + col_pos[None, None, :]

        batched = ids[None, :].repeat(batch_size, 1, 1, 1, 1)
        flat = rearrange(batched, 'b t h w c -> b (t h w) c')
        return flat.to(device=device)

    @torch.no_grad()
    def _prepare_pyramid_image_ids(self, sample, batch_size, device):
        """Build position ids for every stage of the pyramid.

        Each entry of ``sample`` is a 5-D latent or a list of them (clips). The
        last clip of a stage supplies the reference height/width; temporal ids
        keep counting across the clips of one stage.

        Returns:
            One tensor per stage: the per-clip ids concatenated on dim 1.
        """
        patch = self.patch_size
        image_ids_list = []

        for stage in sample:
            clips = stage if isinstance(stage, list) else [stage]

            ref_height = clips[-1].shape[-2] // patch
            ref_width = clips[-1].shape[-1] // patch

            stage_ids = []
            frame_offset = 0
            for clip in clips:
                _, _, temp, height, width = clip.shape
                clip_ids = self._prepare_image_ids(
                    batch_size,
                    temp,
                    height // patch,
                    width // patch,
                    ref_height,
                    ref_width,
                    device,
                    start_time_stamp=frame_offset,
                )
                stage_ids.append(clip_ids)
                frame_offset += temp

            image_ids_list.append(torch.cat(stage_ids, dim=1))

        return image_ids_list
    
    @torch.no_grad()
    def _prepare_audio_ids(self, audio_sample, device):
        """Build position ids for the audio tokens of every stage.

        For each audio tensor (dim 0 = batch, dim 1 = frames, dim 2 = tokens
        per frame) only id channel 0 is filled; channels 1 and 2 stay zero.
        With ``size = tokens_per_frame / 3``, token ``s`` of frame ``f`` gets
        ``s / size + (f - (size // 2) / size)``.

        Returns:
            One tensor per stage of shape ``(batch, frames * tokens, 3)``.
        """
        audio_ids_list = []
        for stage_audio in audio_sample:
            n_frames = stage_audio.shape[1]
            tokens_per_frame = stage_audio.shape[2]
            batch_size = stage_audio.shape[0]

            ids = torch.zeros(n_frames, tokens_per_frame, 3)

            size = tokens_per_frame / 3
            # Same for every frame, so it is built once per stage.
            within_frame = torch.arange(tokens_per_frame) / size
            for frame in range(n_frames):
                frame_offset = frame - size // 2 / size
                ids[frame, :, 0] = within_frame + frame_offset

            ids = ids[None, :].repeat(batch_size, 1, 1, 1)
            ids = rearrange(ids, 'b t s c -> b (t s) c')
            audio_ids_list.append(ids.to(device=device))

        return audio_ids_list

    def merge_input(self, sample, encoder_hidden_length, encoder_attention_mask):
        """Patchify and embed every pyramid stage and build its masks and RoPE tables.

        Args:
            sample: One entry per stage. An entry is either a 5-D latent laid
                out as ``(b, c, t, h, w)`` or a list of such latents (clips).
                For a list, the last clip defines the stage's reported
                ``temp/height/width`` and its trainable-token count.
            encoder_hidden_length: Text sequence length; must equal
                ``encoder_attention_mask.shape[1]``.
            encoder_attention_mask: Text mask; entries equal to 0 are treated
                as padding. Its rows are interleaved by stage: stage ``i``
                owns rows ``i::num_stages``.

        Returns:
            A 9-tuple ``(hidden_states, hidden_length, temp_list, height_list,
            width_list, trainable_token_list, encoder_attention_mask,
            attention_mask, image_rotary_emb)``. ``hidden_states``,
            ``attention_mask`` and ``image_rotary_emb`` are per-stage lists;
            ``encoder_attention_mask`` is returned unchanged.

        Raises:
            AssertionError: If the text mask length differs from
                ``encoder_hidden_length``.
        """
        # Normalise every stage to a list of clips once. The size loop and
        # `_prepare_pyramid_image_ids` already accepted bare tensors, but the
        # token loop iterated a bare tensor along its batch axis.
        stages = [stage if isinstance(stage, list) else [stage] for stage in sample]

        device = stages[0][-1].device
        pad_batch_size = stages[0][-1].shape[0]
        num_stages = len(stages)

        temp_list, height_list, width_list = [], [], []
        trainable_token_list = []
        for clips in stages:
            # Only the last clip of a stage contributes to the loss.
            _, _, temp, height, width = clips[-1].shape
            height = height // self.patch_size
            width = width // self.patch_size
            temp_list.append(temp)
            height_list.append(height)
            width_list.append(width)
            trainable_token_list.append(height * width * temp)

        # RoPE ids: text tokens sit at the origin, followed by the video ids.
        image_ids_list = self._prepare_pyramid_image_ids(sample, pad_batch_size, device)
        text_ids = torch.zeros(pad_batch_size, encoder_attention_mask.shape[1], 3).to(device=device)
        input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
        image_rotary_emb = [self.pos_embed(input_ids) for input_ids in input_ids_list]

        # Patchify each clip, concatenate the clips of a stage, then embed.
        hidden_states, hidden_length = [], []
        for clips in stages:
            video_tokens = [
                rearrange(
                    each_latent,
                    'b c t (h p1) (w p2) -> b (t h w) (p1 p2 c)',
                    p1=self.patch_size,
                    p2=self.patch_size,
                )
                for each_latent in clips
            ]
            video_tokens = torch.cat(video_tokens, dim=1)
            video_tokens = self.x_embedder(video_tokens)
            hidden_states.append(video_tokens)
            hidden_length.append(video_tokens.shape[1])

        assert encoder_attention_mask.shape[1] == encoder_hidden_length
        real_batch_size = encoder_attention_mask.shape[0]

        # Per-row identity ids (1..B); padded text positions get id 0.
        text_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, encoder_hidden_length)
        text_ids = text_ids.to(device)
        text_ids[encoder_attention_mask == 0] = 0

        # The same row ids for the video tokens, cut to each stage's length.
        image_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, max(hidden_length))
        image_ids = image_ids.to(device)
        image_ids_list = [
            image_ids[i_p::num_stages][:, :length]
            for i_p, length in enumerate(hidden_length)
        ]

        attention_mask = []
        for i_p, stage_image_ids in enumerate(image_ids_list):
            token_ids = torch.cat([text_ids[i_p::num_stages], stage_image_ids], dim=1)
            # A query may attend to a key only when both carry the same id.
            stage_attention_mask = rearrange(token_ids, 'b i -> b 1 i 1') == rearrange(token_ids, 'b j -> b 1 1 j')
            if self.use_temporal_causal:
                # ...and only when the key's temporal id does not exceed the query's.
                input_order_ids = input_ids_list[i_p][:, :, 0]
                temporal_causal_mask = rearrange(input_order_ids, 'b i -> b 1 i 1') >= rearrange(input_order_ids, 'b j -> b 1 1 j')
                stage_attention_mask = stage_attention_mask & temporal_causal_mask
            attention_mask.append(stage_attention_mask)

        return hidden_states, hidden_length, temp_list, height_list, width_list, trainable_token_list, encoder_attention_mask, attention_mask, image_rotary_emb

    def split_output(self, batch_hidden_states, hidden_length, temps, heights, widths, trainable_token_list):
        """Split the packed token sequence per stage and unpatchify it.

        Args:
            batch_hidden_states: Projected tokens of all stages, concatenated
                on dim 1.
            hidden_length: Token count of each stage.
            temps, heights, widths: Token grid size of each stage.
            trainable_token_list: Number of trailing tokens to keep per stage.

        Returns:
            One tensor per stage laid out as ``(b, c, t, h, w)``, holding only
            the trailing trainable tokens of that stage.
        """
        batch_size = batch_hidden_states.shape[0]
        patch = self.patch_size
        per_stage = torch.split(batch_hidden_states, hidden_length, dim=1)

        output_hidden_list = []
        for i_p in range(len(hidden_length)):
            temp, height, width = temps[i_p], heights[i_p], widths[i_p]

            # Only the trailing trainable tokens take part in the loss.
            tokens = per_stage[i_p][:, -trainable_token_list[i_p]:]

            # Undo the patchify: (tokens, p1*p2*c) -> (t, h*p1, w*p2, c).
            patches = tokens.reshape(
                batch_size, temp, height, width, patch, patch, self.out_channels // 4
            )
            frames = rearrange(patches, "b t h w p1 p2 c -> b t (h p1) (w p2) c")
            output_hidden_list.append(rearrange(frames, "b t h w c -> b c t h w"))

        return output_hidden_list

    def forward(
        self,
        sample: torch.FloatTensor, 
        encoder_hidden_states: torch.Tensor = None,
        encoder_attention_mask: torch.FloatTensor = None, 
        audio: Optional[torch.Tensor] = None,
        pooled_projections: torch.Tensor = None,
        timestep_ratio: torch.LongTensor = None,
        audio_temperature: float = 1.0,
        additional_args: Dict[str, Any] = {},
    ):
        """Run both transformer stacks over the pyramid stages.

        Args:
            sample: Per-stage video latents, consumed by ``merge_input``. The
                audio branch indexes ``sample[0][-1]``, i.e. it expects the
                first stage to be a list of clips.
            encoder_hidden_states: Text features, projected by
                ``context_embedder``.
            encoder_attention_mask: Text mask, forwarded to ``merge_input``
                and to every block.
            audio: A 4-D tensor or a list of them, unpacked here as
                ``(batch, frames, segment, feature)``. Ignored unless the
                model was built with ``use_audio_cross_attn``.
            pooled_projections: Pooled text embedding for ``time_text_embed``.
            timestep_ratio: Timestep input for ``time_text_embed``.
            audio_temperature: Forwarded unchanged to every block.
            additional_args: Optional switches, only read here:
                ``"layer_ablation"`` is a mapping whose ``"mmdit"`` /
                ``"single"`` entries (an int or a list of ints) name block
                indices to skip; ``"skip_audio"`` disables the audio branch.

        Returns:
            The per-stage list produced by ``split_output``.
        """
        ### Ablation study for layer selection
        ablation = additional_args.get("layer_ablation", {})
        skip_mmdit_idx_list = ablation.get("mmdit", []) 
        skip_single_idx_list = ablation.get("single", [])       
        
        skip_audio = additional_args.get("skip_audio", False)
        
        # Accept a single index as shorthand for a one-element list.
        if isinstance(skip_mmdit_idx_list, int):
            skip_mmdit_idx_list = [skip_mmdit_idx_list]
        if isinstance(skip_single_idx_list, int):
            skip_single_idx_list = [skip_single_idx_list]
            
        # Conditioning embedding and projected text features.
        temb = self.time_text_embed(timestep_ratio, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
        encoder_hidden_length = encoder_hidden_states.shape[1]
        
        # Per-stage tokens, lengths, attention masks and rotary tables.
        hidden_states, hidden_length, temps, heights, widths, trainable_token_list, encoder_attention_mask, attention_mask, \
                image_rotary_emb = self.merge_input(sample, encoder_hidden_length, encoder_attention_mask)
        
        # Pack all stages into one sequence along the token axis.
        hidden_states = torch.cat(hidden_states, dim=1)
            
        if (self.use_audio_cross_attn) and (audio is not None) and (not skip_audio):
            num_stages = len(hidden_length)
            if not isinstance(audio, list):
                audio = [audio]
            assert audio[0].dim() == 4, f"Expected audio[0] to have ndim=4, but got ndim={audio[0].dim()}." 

            # Audio token count per stage: frames * segment size. The segment
            # size is read from the first stage only.
            n_frame_list = []
            segment_size = audio[0].shape[2]
            for i_s in range(len(audio)): 
                n_frames = audio[i_s].shape[1] 
                n_frame_list.append(n_frames)
            audio_time_length_list = [x * segment_size for x in n_frame_list]
            
            # audio RoPE IDs
            audio_ids_list = self._prepare_audio_ids(audio, device=hidden_states.device) 
            audio_rotary_emb = [self.pos_embed(audio_ids) for audio_ids in audio_ids_list] 
            
            # merge audio input
            audio_hidden_states = []
            temporal_ids_list = []
            audio_mask_list = []

            device = audio[0][-1].device
            for audio_ in audio:
                # Temporal IDs
                batch_size, temporal, segment_size, feature_dim = audio_.shape
                
                # Every token of audio frame f gets temporal id f + 1.
                temporal_ids = torch.arange(temporal).repeat(batch_size, 1).unsqueeze(-1).expand(-1, -1, segment_size)+1 
                audio_ = rearrange(audio_, 'b f s c -> b (f s) c')              
                temporal_ids = rearrange(temporal_ids, 'b f s -> b (f s)').to(device=device) 
                
                # An audio token is valid when any of its features is non-zero.
                audio_mask = (audio_ != 0).any(dim=-1)  
                
                audio_hidden_states.append(audio_)
                temporal_ids_list.append(temporal_ids)
                audio_mask_list.append(audio_mask) 
                
            audio_hidden_states = torch.cat(audio_hidden_states, dim=1) 
            audio_hidden_states = self.audio_embedder(audio_hidden_states)
            # NOTE(review): this concatenated mask is not read again below.
            audio_mask = torch.cat(audio_mask_list, dim=1)

            # video image_ids
            # NOTE(review): assumes the first stage is a list of clips, unlike
            # `merge_input`, which also handles a bare tensor here.
            device = sample[0][-1].device 
            pad_batch_size = sample[0][-1].shape[0]
            image_ids_list = self._prepare_pyramid_image_ids(sample, pad_batch_size, device) 
            # Keep only the temporal channel of the video ids.
            image_ids_list = [image_ids[:,:,0] for image_ids in image_ids_list] 
            
            # Attention mask
            audio_attention_mask_list = []

            for temporal_ids, image_ids, this_audio_mask in zip(temporal_ids_list, image_ids_list, audio_mask_list):                
                # Position of the first video token whose temporal id is 1;
                # earlier tokens are left out of the audio mask.
                # NOTE(review): raises IndexError when no token has temporal id 1.
                indices = (image_ids == 1).nonzero(as_tuple=True)
                first_one_index = indices[1][0].item()
                
                image_ids_expanded = image_ids[:, first_one_index:].unsqueeze(-1)
                temporal_ids_expanded = temporal_ids.unsqueeze(1)
                
                # A video token may attend to audio tokens sharing its temporal id...
                audio_attn_mask = (temporal_ids_expanded == image_ids_expanded)

                # ...provided the audio token itself is valid.
                this_audio_mask = this_audio_mask.unsqueeze(1)
                audio_attn_mask = audio_attn_mask & this_audio_mask 
                
                # Hard-coded switch; the all-ones fallback below is currently unreachable.
                use_audio_mask = True
                if not use_audio_mask:
                    audio_attn_mask = torch.ones_like(audio_attn_mask)
                audio_attention_mask_list.append(audio_attn_mask.unsqueeze(1))
        else:
            # No audio conditioning: the blocks receive None for every audio input.
            if skip_audio:
                print(f"Warning: Skip Audio Cross Attention Layers")
            audio_hidden_states = None
            audio_attention_mask_list = None
            audio_time_length_list = None
            audio_rotary_emb = None
            
        # Dual-stream blocks: text and video tokens are kept as separate sequences.
        for index_block, block in enumerate(self.transformer_blocks):
            if index_block in skip_mmdit_idx_list:
                continue 
            
            # During training, blocks up to and including index
            # int(len * ratio) are recomputed in backward instead of stored.
            if self.training and self.gradient_checkpointing and (index_block <= int(len(self.transformer_blocks) * self.gradient_checkpointing_ratio)):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                # Arguments are passed positionally, so their order must
                # match the block's forward signature.
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states, 
                    encoder_attention_mask, 
                    temb,
                    attention_mask,
                    hidden_length,
                    image_rotary_emb,
                    audio_hidden_states, 
                    audio_attention_mask_list, 
                    audio_time_length_list, 
                    audio_rotary_emb,
                    audio_temperature,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    temb=temb,
                    attention_mask=attention_mask,
                    hidden_length=hidden_length,
                    image_rotary_emb=image_rotary_emb,
                    audio_hidden_states=audio_hidden_states,
                    audio_attention_mask_list=audio_attention_mask_list,
                    audio_time_length_list=audio_time_length_list,
                    audio_rotary_emb=audio_rotary_emb,
                    audio_temperature=audio_temperature
                )

        # remerge for single attention block
        num_stages = len(hidden_length)
        batch_hidden_states = list(torch.split(hidden_states, hidden_length, dim=1))
        concat_hidden_length = []

        for i_p in range(len(hidden_length)):

            # Prefix each stage's video tokens with that stage's text rows
            # (rows i_p::num_stages of the text features).
            batch_hidden_states[i_p] = torch.cat([encoder_hidden_states[i_p::num_stages], batch_hidden_states[i_p]], dim=1)

            concat_hidden_length.append(batch_hidden_states[i_p].shape[1])

        hidden_states = torch.cat(batch_hidden_states, dim=1)

        # Single-stream blocks: one joint text+video sequence per stage.
        for index_block, block in enumerate(self.single_transformer_blocks):   
            if index_block in skip_single_idx_list:
                continue
            if self.training and self.gradient_checkpointing and (index_block <= int(len(self.single_transformer_blocks) * self.gradient_checkpointing_ratio)):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb, 
                    encoder_attention_mask,
                    attention_mask,
                    concat_hidden_length,
                    image_rotary_emb,
                    audio_hidden_states,
                    audio_attention_mask_list,
                    audio_time_length_list,
                    audio_rotary_emb,
                    audio_temperature,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    encoder_attention_mask=encoder_attention_mask,  
                    attention_mask=attention_mask,
                    hidden_length=concat_hidden_length,
                    image_rotary_emb=image_rotary_emb,
                    audio_hidden_states=audio_hidden_states,
                    audio_attention_mask_list=audio_attention_mask_list,
                    audio_time_length_list=audio_time_length_list,
                    audio_rotary_emb=audio_rotary_emb,
                    audio_temperature=audio_temperature
                )

        batch_hidden_states = list(torch.split(hidden_states, concat_hidden_length, dim=1))

        # Drop the text prefix again so only video tokens reach the output head.
        for i_p in range(len(concat_hidden_length)):
            batch_hidden_states[i_p] = batch_hidden_states[i_p][:, encoder_hidden_length :, ...]
            
        hidden_states = torch.cat(batch_hidden_states, dim=1)
        hidden_states = self.norm_out(hidden_states, temb, hidden_length=hidden_length)
        hidden_states = self.proj_out(hidden_states)

        # Split per stage and unpatchify back to video-shaped tensors.
        output = self.split_output(hidden_states, hidden_length, temps, heights, widths, trainable_token_list)

        return output
    
    @classmethod
    def from_pretrained(cls, pretrained_model_path, **kwargs):
        """Build the model from a checkpoint directory.

        The directory must contain ``config.json`` and
        ``diffusion_pytorch_model.safetensors``.

        Args:
            pretrained_model_path: Directory holding the two files above.
            **kwargs: Forwarded to ``cls.from_config``.

        Returns:
            The model with the checkpoint weights loaded non-strictly: keys
            missing from the checkpoint are reported and keep their fresh
            initialisation.

        Raises:
            RuntimeError: If the config or the weights file does not exist.
            AssertionError: If the checkpoint has keys the model does not.
        """
        weights_file = os.path.join(pretrained_model_path, "diffusion_pytorch_model.safetensors")
        print(f"loaded DiT pretrained weights from {weights_file} ...")

        config_file = os.path.join(os.path.dirname(weights_file), 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)

        # Check the weights before building the model, so a missing file is
        # reported without first allocating and initialising every layer.
        if not os.path.isfile(weights_file):
            raise RuntimeError(f"{weights_file} does not exist")

        model = cls.from_config(config, **kwargs)

        state_dict = load_file(weights_file)
        # Some checkpoints nest the tensors under a "state_dict" entry.
        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### DiT missing keys: {m}")
        print(f"### DiT missing keys: {len(m)}; \n### DiT unexpected keys: {len(u)};")
        assert len(u) == 0, f"unexpected keys in checkpoint: {u}"

        return model