# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import os
import requests
from einops import rearrange
from tqdm import tqdm 
# import scipy.stats as stats
from huggingface_hub import hf_hub_download

import torch
import torch.nn as nn
import torch.nn.functional as F


from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaRMSNorm

from bifrost.models.diffusion import Diffusion
from bifrost.models.mar.utils import sample_orders, random_masking, patchify
from bifrost.models.mask_schedule import get_mask_chedule, mask_or_random_replace_tokens

from bifrost.models.flowar.flowmodel import SimpleTransformerAdaLN
from bifrost.models.flowar.flowloss import SILoss

from diffusers import FluxPipeline
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory


def model_name_to_cls(cls_name):
    """Resolve a model identifier from the config to the class implementing it.

    Matching is by substring and is tested in the fixed order below, so the
    first identifier contained in ``cls_name`` wins. The implementing module
    is imported only when its branch is selected.

    Args:
        cls_name: Identifier string taken from a ``*.cls`` config field, or
            ``None`` when no class is configured.

    Returns:
        The matching model class, or ``None`` when ``cls_name`` is ``None``.

    Raises:
        ValueError: If ``cls_name`` contains none of the known identifiers.
    """
    # This check has to come first: the substring tests below raise TypeError
    # on None, which used to make the "no class configured" case unreachable.
    if cls_name is None:
        return None

    if "ShallowUViTEncoder" in cls_name:
        from bifrost.models.uvit import ShallowUViTEncoder
        return ShallowUViTEncoder
    if "ShallowUViTDecoder" in cls_name:
        from bifrost.models.uvit import ShallowUViTDecoder
        return ShallowUViTDecoder
    if 'AutoencoderKL' in cls_name:
        from bifrost.models.autoencoder.autoencoder_kl import AutoencoderKL
        return AutoencoderKL
    if 'Qwen2_5_VLForConditionalGeneration' in cls_name:
        from bifrost.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
        return Qwen2_5_VLForConditionalGeneration
    if 'Qwen2ForCausalLM' in cls_name:
        from bifrost.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
        return Qwen2ForCausalLM
    if 'VQ-16' in cls_name:
        from bifrost.models.llamagen.vq_model import VQ_16
        return VQ_16
    if 'VQ-8' in cls_name:
        from bifrost.models.llamagen.vq_model import VQ_8
        return VQ_8
    if 'magvitv2' in cls_name:
        from bifrost.models.magvitv2.modeling_magvitv2 import MAGVITv2
        return MAGVITv2
    if "MlpProjector" in cls_name:
        from bifrost.models.projector import MlpProjector
        return MlpProjector
    if "VisionHead" in cls_name:
        from bifrost.models.projector import VisionHead
        return VisionHead
    if 'black-forest-labs/FLUX.1-dev' in cls_name:
        from bifrost.models.diffusion_decoder import FLUXControlNetDiffusionDecoder
        return FLUXControlNetDiffusionDecoder

    raise ValueError(f"class_name {cls_name} is invalid.")



class MultiModalityCausalLM(PreTrainedModel):

    def __init__(self, config, **kwargs):
        """Assemble the vision-language backbone and the visual generation modules.

        Which sub-modules are created depends on ``config``:

        * ``vision_language_model`` is always loaded with ``from_pretrained``
          in bfloat16; its ``visual`` tower is deleted or frozen on request.
        * With ``use_discrete_visual_tokenizer``: a frozen VQ tokenizer, an
          aligner, a head and a visual-token embedding table.
        * Otherwise: an optional continuous VAE, plus either the ``'mar'``
          mask-token / linear-head modules or a generation encoder/decoder
          pair with linear aligners.
        * With ``e2e_training`` or ``vae_*_ctrlnet_training``: a diffusion
          decoder.

        Args:
            config: Model configuration; every field read below must exist.
            **kwargs: Accepted for signature compatibility; not used here.

        Raises:
            RuntimeError: If a checkpoint download does not return HTTP 200.
            ValueError: If the discrete tokenizer class is not supported.
        """
        super().__init__(config)

        self.use_clip_visual_encoder = config.use_clip_visual_encoder
        self.vision_denoising_type = config.vision_denoising_type
        self.vision_head_type = config.vision_head_type
        self.vision_loss_type = config.vision_loss_type
        self.batch_size_t2i = config.batch_size_t2i
        self.t2i_resolution = config.t2i_resolution
        self.num_visual_gen_tokens = config.num_visual_gen_tokens
        self.vision_pos_emb_type = config.vision_pos_emb_type
        self.lambda_gpu = config.lambda_gpu
        self.use_2d_query_tokens = config.use_2d_query_tokens
        self.e2e_training = config.e2e_training
        self.ctrlnet_training = config.ctrlnet_training
        self.remove_vae = config.remove_vae
        self.proportion_empty_prompts = config.proportion_empty_prompts
        self.lambda_clip = config.lambda_clip
        self.vae_w_ctrlnet_training = config.vae_w_ctrlnet_training
        self.vae_wo_ctrlnet_training = config.vae_wo_ctrlnet_training
        self.inner_dim = config.inner_dim
        self.config = config

        # Seconds to wait for a connection / between received bytes while
        # downloading checkpoints, so a stalled server cannot hang start-up.
        download_timeout = 60

        # vision-language model
        vision_language_model_config = config.vision_language_model_config
        vision_language_model_cls = model_name_to_cls(vision_language_model_config.cls)
        self.vision_language_model = vision_language_model_cls.from_pretrained(
            vision_language_model_config.params.model_name_or_path,
            torch_dtype=torch.bfloat16,
        )

        if vision_language_model_config.params.remove_vision_und_encoder:
            del self.vision_language_model.visual
        elif 'vision_language_model.visual' in vision_language_model_config.params.frozen_modules_in_vlm:
            self.vision_language_model.visual.eval()
            for p in self.vision_language_model.visual.parameters():
                p.requires_grad = False

        if config.add_vision_branch:
            # Add the vision-specific attention / MLP (and optionally norms) to
            # every decoder layer of the backbone.
            num_hidden_layers = self.vision_language_model.model.config.num_hidden_layers
            for i in range(num_hidden_layers):
                layer = self.vision_language_model.model.layers[i]
                layer.self_attn.add_vision_attention()
                layer.add_vision_mlp()
                if not config.add_vision_branch_reuse_layernorm:
                    layer.add_vision_layernorms()
            if not config.add_vision_branch_reuse_layernorm:
                self.vision_language_model.model.add_vision_norm()

        # diffusion module
        self.diffusion = Diffusion(config)

        if config.use_discrete_visual_tokenizer:
            ##### tokenizer #####
            vision_gen_tokenizer_config = config.vision_gen_tokenizer_config
            vision_gen_tokenizer_cls = model_name_to_cls(vision_gen_tokenizer_config.cls)
            if vision_gen_tokenizer_config.cls in ['VQ-16', 'VQ-8']:
                # Download the checkpoint to a local path.
                url = vision_gen_tokenizer_config.params.model_name_or_path
                local_dir = "vq_llamagen"
                local_path = os.path.join(local_dir, url.split("/")[-1])
                os.makedirs(local_dir, exist_ok=True)
                # The context manager releases the streamed connection even if
                # writing the file fails part-way.
                with requests.get(url, stream=True, timeout=download_timeout) as response:
                    if response.status_code != 200:
                        raise RuntimeError(f"Failed to download model. HTTP Status: {response.status_code}")
                    with open(local_path, "wb") as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
                print(f"Downloaded model to {local_path}")
                # Load the checkpoint.
                self.vq_model = vision_gen_tokenizer_cls(
                    codebook_size=vision_gen_tokenizer_config.params.image_token_size,
                    codebook_embed_dim=vision_gen_tokenizer_config.params.n_embed)
                # NOTE(security): torch.load unpickles the downloaded file; only
                # point model_name_or_path at a trusted URL (or switch to
                # weights_only=True once the checkpoint format is confirmed).
                checkpoint = torch.load(local_path, map_location="cpu")
                self.vq_model.load_state_dict(checkpoint["model"])
            elif vision_gen_tokenizer_config.cls == 'magvitv2':
                self.vq_model = vision_gen_tokenizer_cls.from_pretrained(vision_gen_tokenizer_config.params.model_name_or_path)
            else:
                # model_name_to_cls matches by substring, the branches above by
                # exact name; fail clearly instead of on a missing self.vq_model.
                raise ValueError(
                    f"Unsupported discrete visual tokenizer: {vision_gen_tokenizer_config.cls}"
                )

            # The tokenizer is always frozen.
            self.vq_model.eval()
            for p in self.vq_model.parameters():
                p.requires_grad = False

            ##### aligner #####
            vision_gen_aligner_config = config.vision_gen_aligner_config
            gen_vision_aligner_cls = model_name_to_cls(vision_gen_aligner_config.cls)
            self.vision_gen_aligner = gen_vision_aligner_cls(vision_gen_aligner_config.params)
            for p in self.vision_gen_aligner.parameters():
                p.requires_grad = True

            ##### head #####
            vision_gen_head_config = config.vision_gen_head_config
            get_vision_head_cls = model_name_to_cls(vision_gen_head_config.cls)
            self.vision_gen_head = get_vision_head_cls(vision_gen_head_config.params)
            self.image_token_size = vision_gen_tokenizer_config.params.image_token_size
            for p in self.vision_gen_head.parameters():
                p.requires_grad = True

            ##### embed tokens #####
            self.visual_embed_tokens = torch.nn.Embedding(vision_gen_tokenizer_config.params.image_token_size, vision_gen_tokenizer_config.params.n_embed)

        else:
            # vision generation vae
            # NOTE: when model_name_or_path matches none of the cases below, no
            # vision_gen_vae_model attribute is created; forward() checks for
            # that with hasattr().
            vision_gen_vae_config = config.vision_gen_vae_config
            vision_gen_vae_model_cls = model_name_to_cls(vision_gen_vae_config.cls)
            if vision_gen_vae_config.params.model_name_or_path == 'black-forest-labs/FLUX.1-dev':
                self.vision_gen_vae_model = vision_gen_vae_model_cls.from_pretrained(
                    vision_gen_vae_config.params.model_name_or_path,
                    subfolder='vae',
                    token=vision_gen_vae_config.params.huggingface_token
                )
                self.in_channels = 64
            elif vision_gen_vae_config.params.model_name_or_path == 'stabilityai/sdxl-vae':
                self.vision_gen_vae_model = vision_gen_vae_model_cls.from_pretrained(
                    vision_gen_vae_config.params.model_name_or_path,
                    token=vision_gen_vae_config.params.huggingface_token
                )
                self.in_channels = 16
            elif vision_gen_vae_config.params.model_name_or_path == 'pretrained_models/vae/kl16.ckpt':
                local_dir = "vae"
                local_path = os.path.join(local_dir, vision_gen_vae_config.params.model_name_or_path.split("/")[-1])
                os.makedirs(local_dir, exist_ok=True)
                headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
                with requests.get(
                    "https://www.dropbox.com/scl/fi/hhmuvaiacrarfg28qxhwz/kl16.ckpt?rlkey=l44xipsezc8atcffdp4q7mwmh&dl=0",
                    stream=True,
                    headers=headers,
                    timeout=download_timeout,
                ) as r:
                    # Without this check an HTTP error page would be written to
                    # disk and later loaded as if it were the checkpoint.
                    if r.status_code != 200:
                        raise RuntimeError(f"Failed to download model. HTTP Status: {r.status_code}")
                    print("Downloading KL-16 VAE...")
                    with open(local_path, 'wb') as f:
                        for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=254):
                            if chunk:
                                f.write(chunk)

                from bifrost.models.autoencoder.vae_kl import AutoencoderKL
                self.vision_gen_vae_model = AutoencoderKL(
                    embed_dim=16,
                    ch_mult=(1, 1, 2, 2, 4),
                    ckpt_path=local_path
                )
                self.in_channels = 16

            if vision_gen_vae_config.params.frozen_vision_gen_vae:
                self.vision_gen_vae_model.eval()
                for p in self.vision_gen_vae_model.parameters():
                    p.requires_grad = False

            if self.vision_denoising_type == 'mar':
                self.in_channels = self.inner_dim

                if not self.use_2d_query_tokens:
                    self.mask_token = nn.Parameter(torch.zeros(1, 1, self.inner_dim))
                    if self.vision_pos_emb_type == 'learnable_pos_emb':
                        if (self.vae_wo_ctrlnet_training or self.vae_w_ctrlnet_training) and not self.config.vae_scale_by_4:
                            # Same 4x token count that forward() uses in this mode.
                            self.learnable_pos_emb = nn.Parameter(torch.zeros(1, self.num_visual_gen_tokens * 4, self.inner_dim))
                        else:
                            self.learnable_pos_emb = nn.Parameter(torch.zeros(1, self.num_visual_gen_tokens, self.inner_dim))
                    else:
                        self.learnable_pos_emb = None
                else:
                    self.learnable_2d_query_tokens = nn.Parameter(torch.zeros(1, self.num_visual_gen_tokens, self.inner_dim))

                # vision loss head
                if self.vision_head_type == 'linear':
                    if self.vae_wo_ctrlnet_training or self.vae_w_ctrlnet_training:
                        self.vision_head = nn.Linear(self.inner_dim, self.vision_gen_vae_model.config.latent_channels, bias=True)
                        from bifrost.flux_cnet_diffusers.controlnet_flux import NHWCDownsampleBlock, NHWCUpsampleBlock, NHWCDownsampleBlock2, NHWCUpsampleBlock2
                        # NOTE(review): 2048 is hard-coded here; presumably it
                        # has to equal inner_dim -- confirm.
                        if self.config.vae_scale_by_4:
                            self.down_sampling_block = NHWCDownsampleBlock2(2048)
                            self.up_sampling_block = NHWCUpsampleBlock2(2048)
                        else:
                            self.down_sampling_block = NHWCDownsampleBlock(2048)
                            self.up_sampling_block = NHWCUpsampleBlock(2048)
                    else:
                        self.vision_head = nn.Linear(self.inner_dim, self.in_channels, bias=True)

            else:
                # vision generation encoder
                vision_gen_enc_config = config.vision_gen_enc_config
                vision_gen_enc_cls = model_name_to_cls(vision_gen_enc_config.cls)
                self.vision_gen_enc_model = vision_gen_enc_cls(**vision_gen_enc_config.params)
                self.vision_gen_enc_aligner = nn.Linear(768, 2048, bias=True)
                if vision_gen_enc_config.params.frozen_vision_gen_encdec:
                    self.vision_gen_enc_model.eval()
                    for p in self.vision_gen_enc_model.parameters():
                        p.requires_grad = False
                    self.vision_gen_enc_aligner.eval()
                    for p in self.vision_gen_enc_aligner.parameters():
                        p.requires_grad = False

                # vision generation decoder
                vision_gen_dec_config = config.vision_gen_dec_config
                vision_gen_dec_cls = model_name_to_cls(vision_gen_dec_config.cls)
                self.vision_gen_dec_model = vision_gen_dec_cls(**vision_gen_dec_config.params)
                self.vision_gen_dec_aligner_norm = LlamaRMSNorm(2048, eps=1e-06)
                self.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)

        if self.e2e_training:
            diffusion_decoder_cls = model_name_to_cls(self.config.diffusion_decoder_config.cls)
            self.diffusion_decoder = diffusion_decoder_cls(self.config.diffusion_decoder_config.params, input_cond_dim=self.inner_dim, remove_vae=self.remove_vae)

        if self.vae_wo_ctrlnet_training or self.vae_w_ctrlnet_training:
            # NOTE: when e2e_training is also set, this replaces the decoder
            # built just above (the conditioning width becomes 64).
            diffusion_decoder_cls = model_name_to_cls(self.config.diffusion_decoder_config.cls)
            self.diffusion_decoder = diffusion_decoder_cls(self.config.diffusion_decoder_config.params, input_cond_dim=64, remove_vae=self.remove_vae)
            self.vision_embed_proj_in = nn.Linear(self.vision_gen_vae_model.config.latent_channels, self.inner_dim, bias=True)

            if self.vae_wo_ctrlnet_training:
                # Without ControlNet training, the text encoders, the FLUX
                # transformer and the ControlNet are dropped from the decoder.
                del self.diffusion_decoder.text_encoder_one
                del self.diffusion_decoder.text_encoder_two
                del self.diffusion_decoder.flux_transformer
                del self.diffusion_decoder.flux_controlnet


    # NOTE: this forward pass is used for training only; it computes and returns the loss.
    def forward(
            self,
            lm_flow=None,
            t2i_flow=None,
            mmu_flow=None,
            **kwargs,
    ):
        """Compute the text-to-image training loss for one batch.

        Only ``t2i_flow`` is consumed; ``lm_flow`` and ``mmu_flow`` are
        accepted but ignored by this implementation.

        Args:
            lm_flow: Unused.
            t2i_flow: Mapping that provides ``pixel_values``, ``input_ids``,
                ``attention_mask``, ``image_position_mask``, ``position_ids``,
                ``t2i_image_grid_thw`` and ``ar_mask``; optionally
                ``image_clip_embs``; and ``image_labels`` (the prompts) when
                ``e2e_training`` is enabled.
            mmu_flow: Unused.
            **kwargs: Per-step options. Read here: ``num_visual_gen_tokens``
                (required), ``batch_size_t2i`` (``None`` means every row of the
                batch), ``skip_text_part2``, ``add_timestep_token``,
                ``add_vision_soi_eoi_tokens`` and ``add_vision_branch``. Any
                other key is ignored.

        Returns:
            ``{'loss': loss}``. Depending on the training mode, ``loss`` is
            the MLLM reconstruction loss, the ControlNet flow-matching loss,
            or ``lambda_clip * MLLM loss + ControlNet loss``.

        Raises:
            NotImplementedError: If ``skip_text_part2`` is falsy or
                ``add_timestep_token`` is truthy (neither path is implemented)
                while the MLLM branch is active.
            RuntimeError: If a loss required by the active training mode could
                not be computed with the current configuration.
        """
        num_visual_gen_tokens = kwargs.get('num_visual_gen_tokens')
        if (self.vae_wo_ctrlnet_training or self.vae_w_ctrlnet_training) and not self.config.vae_scale_by_4:
            # Matches the 4x-sized learnable_pos_emb created in __init__.
            num_visual_gen_tokens = num_visual_gen_tokens * 4

        batch_size_t2i = kwargs.get('batch_size_t2i', None)
        add_vision_branch = kwargs.get("add_vision_branch")
        add_timestep_token = kwargs.get("add_timestep_token")
        skip_text_part2 = kwargs.get("skip_text_part2")
        add_vision_soi_eoi_tokens = kwargs.get("add_vision_soi_eoi_tokens")

        # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*
        # Build formatted sequences for class-conditional/text-to-image generation
        # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*

        pixel_values_or_image_ids = t2i_flow["pixel_values"]
        image_clip_embs = t2i_flow['image_clip_embs'] if 'image_clip_embs' in t2i_flow else None
        input_ids_t2i = t2i_flow["input_ids"]
        attention_mask = t2i_flow["attention_mask"]
        image_position_mask = t2i_flow['image_position_mask']
        position_ids = t2i_flow['position_ids']
        image_grid_thw = t2i_flow['t2i_image_grid_thw']
        ar_mask = t2i_flow['ar_mask']
        # No timestep embedding is computed in this forward pass.
        t_emb = None
        # Only filled on the lambda_gpu path below; needed by the e2e branch.
        pixel_latents = None

        position_ids = rearrange(position_ids, "bsz k c -> k bsz c")

        # ---- Encode the target image (no gradients through the encoders) ----
        with torch.no_grad():
            if not hasattr(self, 'vision_gen_vae_model'):  # lora FT clip encoder
                if self.lambda_gpu and image_clip_embs is not None:
                    # Pre-computed embeddings were shipped with the batch.
                    input_embs_img = image_clip_embs.to(self.dtype)
                else:
                    input_embs_img = self.vision_language_model.visual(
                        pixel_values_or_image_ids.to(self.dtype),
                        grid_thw=image_grid_thw,
                        same_grid_images=False,
                    )
                    input_embs_img = rearrange(input_embs_img, "(b hw) c -> b hw c", b=pixel_values_or_image_ids.shape[0])
                if self.lambda_gpu and (self.e2e_training or self.vae_w_ctrlnet_training or self.vae_wo_ctrlnet_training):
                    # Encode the image with the diffusion decoder's VAE.
                    pixel_latents = self.diffusion_decoder.vae.encode(pixel_values_or_image_ids.to(self.dtype)).latent_dist.sample()

            elif self.vision_gen_vae_model.config._name_or_path == 'stabilityai/sdxl-vae':
                vae = self.vision_gen_vae_model
                input_embs_img = vae.encode(pixel_values_or_image_ids.to(self.dtype)).latent_dist.sample()
                input_embs_img = input_embs_img * vae.config.scaling_factor

            elif self.vision_gen_vae_model.config._name_or_path == 'black-forest-labs/FLUX.1-dev':
                vae = self.vision_gen_vae_model
                input_embs_img = vae.encode(pixel_values_or_image_ids.to(self.dtype)).latent_dist.sample()
                input_embs_img = (input_embs_img - vae.config.shift_factor) * vae.config.scaling_factor

        # ---- Project the latents into the LLM embedding space ----
        if self.vae_wo_ctrlnet_training or (self.vae_w_ctrlnet_training and not self.ctrlnet_training):
            if hasattr(self, 'vision_gen_vae_model'):
                input_embs_img = rearrange(input_embs_img, "b c h w -> b h w c")
                z_emb = self.vision_embed_proj_in(input_embs_img)
                z_emb = self.down_sampling_block(z_emb)
            else:
                # The packed latents are also the regression target below.
                input_embs_img = FluxControlNetPipeline._pack_latents(
                    input_embs_img,
                    pixel_values_or_image_ids.shape[0],
                    input_embs_img.shape[1],
                    input_embs_img.shape[2],
                    input_embs_img.shape[3],
                )
                z_emb = self.vision_embed_proj_in(input_embs_img)

        if not self.ctrlnet_training:
            # ---- Replace masked visual positions / use query tokens ----
            if self.vision_denoising_type in ['ar', 'mar', 'xar', 'flowar']:
                if self.use_2d_query_tokens:
                    z_emb = self.learnable_2d_query_tokens.repeat(input_embs_img.shape[0], 1, 1)
                else:
                    if self.vae_wo_ctrlnet_training or self.vae_w_ctrlnet_training:
                        visible_emb = z_emb
                    else:
                        visible_emb = input_embs_img
                    mask = ar_mask.unsqueeze(-1)  # mask=1 for masked positions
                    z_emb = self.mask_token * mask + visible_emb * (1 - mask)
                    if self.vision_pos_emb_type == 'learnable_pos_emb':
                        z_emb = z_emb + self.learnable_pos_emb

            # ---- Combine text and visual embeddings ----
            # input_ids carries placeholders for the visual positions at its
            # tail; drop them (plus the optional timestep placeholder and two
            # extra positions when SOI/EOI tokens are enabled) to keep the text.
            time_token_length = 1 if add_timestep_token else 0
            soi_eoi_length = 2 if add_vision_soi_eoi_tokens else 0
            num_trailing = num_visual_gen_tokens + time_token_length + soi_eoi_length
            text_embs_part1 = self.vision_language_model.model.embed_tokens(input_ids_t2i[:, :-num_trailing])

            if not skip_text_part2:
                # Previously this fell through and crashed on an undefined
                # input_embeddings when calling the language model.
                raise NotImplementedError(
                    "forward() only builds the input sequence for skip_text_part2=True."
                )
            if add_timestep_token:
                if t_emb is None:
                    # Previously this crashed on None.unsqueeze().
                    raise NotImplementedError(
                        "add_timestep_token=True needs a timestep embedding, "
                        "but forward() does not compute one."
                    )
                input_embeddings = torch.cat((text_embs_part1, t_emb.unsqueeze(1), z_emb), dim=1)
            else:
                input_embeddings = torch.cat((text_embs_part1, z_emb), dim=1)

            # ---- Run the vision-language model ----
            if (
                self.vision_language_model.config._name_or_path
                in ["Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-1.5B-Instruct", "Qwen/Qwen2.5-3B-Instruct"]
                and self.vision_pos_emb_type == '1drope'
            ):
                # These backbones get no explicit position ids with 1D RoPE.
                position_ids = None

            vlm_outputs = self.vision_language_model(
                inputs_embeds=input_embeddings.to(self.dtype),
                use_cache=None,
                attention_mask=attention_mask.to(self.dtype),
                past_key_values=None,
                return_dict=True,
                position_ids=position_ids,
                image_position_mask=image_position_mask if add_vision_branch is not None else None,
                t_emb=t_emb,
            )
            hidden_states = vlm_outputs['hidden_states']

        # ---- MLLM loss: regress the visual targets from the hidden states ----
        loss_t2i_clip = None
        # batch_size_t2i=None selects every row, consistent with the slice below
        # (``None > 0`` used to raise TypeError).
        has_t2i_rows = batch_size_t2i is None or batch_size_t2i > 0
        if has_t2i_rows and not self.ctrlnet_training:

            if self.diffusion.vision_denoising_type == 'mar':
                # The visual positions are the last ones in the sequence.
                denoised_hidden_states = hidden_states[:batch_size_t2i, -num_visual_gen_tokens:]

                if self.vae_wo_ctrlnet_training or self.vae_w_ctrlnet_training:
                    denoised_hidden_states = self.up_sampling_block(denoised_hidden_states)
                    denoised_hidden_states = self.vision_head(denoised_hidden_states)
                    denoised_hidden_states = rearrange(denoised_hidden_states, "b (h w ) c -> b h w c", h=int(denoised_hidden_states.shape[1]**0.5))
                    model_pred = denoised_hidden_states
                    target = input_embs_img
                else:
                    # The unpacked values are unused; unpacking is kept because
                    # it fails fast unless the target has exactly three dims.
                    _, _, _ = input_embs_img.shape

                    if self.vision_head_type == 'linear':
                        model_pred = self.vision_head(denoised_hidden_states)
                        target = input_embs_img

                if self.vision_loss_type == 'mse':
                    loss_t2i_clip = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        ###########################################
        ############   clip controlnet ############
        ###########################################

        loss_t2i_vae = None
        if self.e2e_training:
            if pixel_latents is None:
                raise RuntimeError(
                    "e2e_training needs VAE latents, which are only computed when "
                    "lambda_gpu is set and the model has no vision_gen_vae_model."
                )

            vae_config = self.diffusion_decoder.vae.config
            pixel_latents_tmp = (pixel_latents - vae_config.shift_factor) * vae_config.scaling_factor
            pixel_latents = FluxControlNetPipeline._pack_latents(
                pixel_latents_tmp,
                pixel_values_or_image_ids.shape[0],
                pixel_latents_tmp.shape[1],
                pixel_latents_tmp.shape[2],
                pixel_latents_tmp.shape[3],
            )

            # Conditioning for the ControlNet depends on the training mode.
            if not self.ctrlnet_training:
                control_image = model_pred
            elif self.vae_w_ctrlnet_training:
                control_image = pixel_latents
            else:
                control_image = input_embs_img

            latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids(
                batch_size=pixel_latents_tmp.shape[0],
                height=pixel_latents_tmp.shape[2] // 2,
                width=pixel_latents_tmp.shape[3] // 2,
                device=pixel_values_or_image_ids.device,
                dtype=pixel_values_or_image_ids.dtype,
            )

            bsz = pixel_latents.shape[0]
            noise = torch.randn_like(pixel_latents).to(self.device).to(dtype=self.dtype)
            # Sample a random timestep for each image
            # for weighting schemes where we sample timesteps non-uniformly
            u = compute_density_for_timestep_sampling(
                weighting_scheme="logit_normal",
                batch_size=bsz,
                logit_mean=0.0,
                logit_std=1.0,
                mode_scale=1.29,
            )
            scheduler = self.diffusion_decoder.noise_scheduler_copy
            indices = (u * scheduler.config.num_train_timesteps).long()
            timesteps = scheduler.timesteps[indices].to(device=pixel_latents.device)

            # Add noise according to flow matching.
            sigmas = self.diffusion_decoder.get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype, device=self.device)
            noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise

            # handle guidance
            guidance_scale = 3.5
            if self.diffusion_decoder.flux_transformer.config.guidance_embeds:
                guidance_vec = torch.full(
                    (noisy_model_input.shape[0],),
                    guidance_scale,
                    device=self.device,
                    dtype=self.dtype,
                )
            else:
                guidance_vec = None

            # calc prompts
            texts = t2i_flow['image_labels']
            prompt_embeds, pooled_prompt_embeds, text_ids = self.diffusion_decoder.compute_embeddings(texts, self.proportion_empty_prompts, self.dtype, is_train=True, device=self.device)

            # Inputs shared by the ControlNet and the transformer.
            timestep_in = timesteps / 1000
            pooled_projections = pooled_prompt_embeds.to(dtype=self.dtype)
            encoder_hidden_states = prompt_embeds.to(dtype=self.dtype)
            txt_ids = text_ids[0].to(dtype=self.dtype)

            controlnet_block_samples, controlnet_single_block_samples = self.diffusion_decoder.flux_controlnet(
                hidden_states=noisy_model_input,
                controlnet_cond=control_image,
                timestep=timestep_in,
                guidance=guidance_vec,
                pooled_projections=pooled_projections,
                encoder_hidden_states=encoder_hidden_states,
                txt_ids=txt_ids,
                img_ids=latent_image_ids,
                return_dict=False,
            )

            noise_pred = self.diffusion_decoder.flux_transformer(
                hidden_states=noisy_model_input,
                timestep=timestep_in,
                guidance=guidance_vec,
                pooled_projections=pooled_projections,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_block_samples=[sample.to(dtype=self.dtype) for sample in controlnet_block_samples] if controlnet_block_samples is not None else None,
                controlnet_single_block_samples=[sample.to(dtype=self.dtype) for sample in controlnet_single_block_samples] if controlnet_single_block_samples is not None else None,
                txt_ids=txt_ids,
                img_ids=latent_image_ids,
                return_dict=False,
            )[0]

            # Target of the prediction under the interpolation used above.
            loss_t2i_vae = F.mse_loss(noise_pred.float(), (noise - pixel_latents).float(), reduction="mean")

        # ---- Combine the losses according to the training mode ----
        if not self.ctrlnet_training and loss_t2i_clip is None:
            raise RuntimeError(
                "The MLLM loss was not computed: it requires batch_size_t2i > 0, "
                "diffusion.vision_denoising_type == 'mar' and vision_loss_type == 'mse'."
            )
        if self.ctrlnet_training and loss_t2i_vae is None:
            raise RuntimeError(
                "ctrlnet_training requires e2e_training so that the ControlNet loss is computed."
            )

        if self.ctrlnet_training:  # controlnet-only training
            loss_t2i = loss_t2i_vae
        elif self.e2e_training:  # e2e training with both losses
            loss_t2i = loss_t2i_clip * self.lambda_clip + loss_t2i_vae
        else:  # MLLM-only training
            loss_t2i = loss_t2i_clip

        return {'loss': loss_t2i}
