# coding=utf-8
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import shlex
import ast
import json
import io 
import math
import copy
import wandb
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
from einops import rearrange
from google.cloud import storage

import dataclasses
from dataclasses import dataclass, field, fields

import typing
from typing import Dict, Optional, Sequence, List, Union, get_args

import transformers
from transformers import AutoTokenizer, AutoProcessor

import torch
import torchvision
import torch.nn.functional as F
from torchvision import transforms

from bifrost.models.configuration_bifrost import MultiModalityConfig
from bifrost.models.modeling_bifrost import MultiModalityCausalLM
from bifrost.conversation import Conversation
from bifrost.models.diffusion import retrieve_timesteps
from bifrost.utils import _load_from_checkpoint, _load_from_checkpoint_mllm
from bifrost.train.utils import image_transform
from bifrost.models.mar.utils import sample_orders, random_masking

from diffusers import DDPMScheduler, FlowMatchEulerDiscreteScheduler
from diffusers.utils import load_image
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Command-line options for FID sample generation with the MLLM + FLUX ControlNet.

    Subclasses ``transformers.TrainingArguments`` and is parsed with
    ``transformers.HfArgumentParser`` in the ``__main__`` block, so inherited
    options such as ``output_dir`` and ``seed`` are available as well. Most
    fields below are forwarded into the model / conversation configs so the
    model is rebuilt with the architecture its checkpoint expects.
    """

    ## model 
    vision_language_model: str = field(default='Qwen2_5_VLForConditionalGeneration')  # VLM class name; the Qwen2.5-VL value also makes the data collator load an AutoProcessor
    vision_language_model_name: str = field(default='Qwen/Qwen2.5-VL-3B-Instruct')  # hub id of the VLM; also selects the hidden size passed as `inner_dim`
    vision_gen_vae: Optional[str] = field(default=None)  # not referenced in this file
    vision_gen_enc: str = field(default='ShallowUViTEncoder')  # not referenced in this file
    vision_gen_dec: str = field(default='ShallowUViTDecoder')  # not referenced in this file
    vision_gen_tokenizer: str = field(default='magvitv2')  # not referenced in this file

    ## processing
    t2i_resolution: int = field(default=224) # 384  -- prompt image size; FLUX output size is t2i_resolution * 16 // 14
    num_visual_gen_tokens: int = field(default=64) # for 256*256 res images. FLUX VAE scales by 8, UNet scales by 2
    max_seq_length: int = field(default=26)  # forwarded to both the model config and the Conversation prompt builder

    ## inference 
    batch_size: int = field(default = 1) # inference batch size (not referenced in this file; batching uses num_images_per_batch)
    bf16: bool = field(default = True)  # overrides the parent default; the script casts the model to bfloat16 regardless
    cfg_weight: float = field(default=1.0)  # classifier-free guidance scale for the MLLM token sampler
    cfg_schedule: str = field(default='constant')  # 'constant' or 'linear' (guidance ramps up as tokens are unmasked)
    temperature: float = field(default=1.0)  # not referenced in this file

    ## log, save, eval (HF trainer)
    huggingface_token: Optional[str] = field(default=None)
    validation_prompts_file: str = field(default="./validation_prompts/imagenet_validation_prompts.txt")  # not referenced in this file

    # t2i gen params
    timestep_sampling_strategy: str = field(default='uniform')
    vision_denoising_type: str = field(default='mar')
    add_timestep_token: bool = field(default=False)
    add_vision_gen_mask_token: bool = field(default=False)
    add_vision_soi_token: bool = field(default=False)
    add_vision_soi_eoi_tokens: bool = field(default=False)

    vision_head_type: str = field(default='linear')  # the sampling loop only implements 'linear'
    vision_loss_type: str = field(default='mse')
    vision_pos_emb_type: str = field(default='learnable_pos_emb')  # 'learnable_pos_emb' adds model.learnable_pos_emb to the visual tokens

    # masks
    full_vision_mask: bool = field(default=True)
    precise_prompt_mask: bool = field(default=True)
    add_vision_branch: bool = field(default=True)
    add_vision_branch_reuse_layernorm: bool = field(default=False)
    use_discrete_visual_tokenizer: bool = field(default=False)
    skip_text_part2: bool = field(default=True)
    proportion_empty_prompts: float = field(default=0.0)  # not read here; the model config hard-codes 0.0
    
    # lora (not referenced in this file)
    use_lora: bool = field(default=False)
    lora_r: int = field(default=320)
    lora_alpha: int = field(default=320) # this is a good post explaining lora: https://medium.com/@fartypantsham/what-rank-r-and-alpha-to-use-in-lora-in-llm-1b4f025fd133
    lora_dropout: float = field(default=0.05)
    lora_target_modules: List[str] = field(default_factory=lambda: ["c_attn", "c_proj", "w1", "w2"])
    lora_weight_path: str = field(default="")
    lora_bias: str = field(default="none") # to keep all params in backbone frozen
    q_lora: bool = field(default=False)
    use_rslora: bool = field(default=False)
    use_clip_visual_encoder: bool = field(default=True)

    lambda_gpu: bool = field(default = True)  # selects the checkpoint-loading path used by the __main__ block

    # 2d query tokens
    use_2d_query_tokens: bool = field(default=False)

    # end-to-end (e2e) / ControlNet training options
    e2e_training: bool = field(default=False)  # not read here; the model config hard-codes e2e_training=True
    ctrlnet_training: bool = field(default=False)
    pretrained_diffusion_decoder_name_or_path: str = field(default="black-forest-labs/FLUX.1-dev")
    num_single_layers: int = field(default=1)
    num_double_layers: int = field(default=4)
    diffusion_decoder_text_dropout_prob: float = field(default=0.0)
    vae_w_ctrlnet_training: bool = field(default=False)
    vae_wo_ctrlnet_training: bool = field(default=False)

    # load from checkpoints 
    resume_from_checkpoint_mllm: Optional[str] = field(default = None)  # MLLM checkpoint to load
    resume_from_checkpoint_ctrlnet: Optional[str] = field(default = None)  # directory containing a "flux_controlnet" subfolder
    
    # fid 
    num_fid_samples: int = field(default=10_000)  # total images to generate; the script rounds it up to a multiple of num_images_per_batch
    num_classes: int = field(default=1000)  # used by the validation data collator when building the class-prompt list
    num_images_per_batch: int = field(default = 10)  # class prompts per batch (each batch also carries as many empty prompts)
    eval_with_prompt: bool = field(default=True)  # not referenced in this file
    



def mask_by_order(mask_len, order, bsz, seq_len):
    """Return a boolean mask that marks the first ``mask_len`` positions of each order.

    Args:
        mask_len: number of positions to mask per row; a Python int or a
            one-element tensor (fractional values are truncated).
        order: integer tensor of shape ``(bsz, seq_len)``; each row lists token
            positions in the order they are to be masked.
        bsz: batch size (rows of the returned mask).
        seq_len: sequence length (columns of the returned mask).

    Returns:
        Bool tensor of shape ``(bsz, seq_len)`` on ``order.device``. Entry
        ``[b, order[b, k]]`` is True for every ``k < mask_len``.
    """
    # Allocate on the same device as `order` instead of hard-coding CUDA, so the
    # helper also works on CPU and on non-default GPUs.
    device = order.device
    num_masked = int(mask_len)
    masking = torch.zeros(bsz, seq_len, device=device)
    masking = torch.scatter(
        masking,
        dim=-1,
        index=order[:, :num_masked],
        src=torch.ones(bsz, seq_len, device=device),
    ).bool()
    return masking





@dataclass
class DataCollatorForValidation(object):
    """Builds class-conditional text-to-image prompt batches for FID sampling.

    Each call returns ``2 * num_images_per_batch`` prompts. The first half are
    empty strings (the unconditional branch for classifier-free guidance). The
    second half are class names read from a JSON class index.

    ``@dataclass`` is kept only for backward compatibility. ``__init__`` is
    hand-written, so the decorator just adds a field-less ``__repr__``/``__eq__``
    (and, because eq is generated, makes instances unhashable).
    """

    def __init__(self, uni_prompting, args = None,
                 class_index_path="data/ImageNet/ImageNet_class_index_new.json"):
        """
        Args:
            uni_prompting: prompt builder exposing
                ``t2i_prompt(texts, img_h=..., img_w=..., num_visual_gen_tokens=...)``
                that returns a 5-tuple.
            args: parsed ``TrainingArguments``; required despite the ``None`` default.
            class_index_path: JSON file whose *values* are the class-name prompts.

        Raises:
            ValueError: if ``args`` is None or the class index holds no names.
        """
        if args is None:
            raise ValueError("DataCollatorForValidation requires `args`.")

        self.args = args
        self.uni_prompting = uni_prompting

        self.t2i_image_processor = None 
        if args.vision_language_model == 'Qwen2_5_VLForConditionalGeneration':
            self.t2i_image_processor = AutoProcessor.from_pretrained(args.vision_language_model_name)

        with open(class_index_path, "r", encoding="utf-8") as f:
            image_text_pairs = json.load(f)
        class_names = list(image_text_pairs.values())
        if not class_names:
            raise ValueError(f"No class names found in {class_index_path!r}.")

        # Cycle through the class names to produce exactly one prompt per FID
        # sample. Repeating the list `num_fid_samples // num_classes` times gave
        # too few prompts (or none) whenever num_fid_samples was not a multiple
        # of num_classes, leaving the final batches short or empty. Whenever the
        # old scheme had enough prompts, the first num_fid_samples are unchanged.
        num_samples = self.args.num_fid_samples
        self.imagenet_class_names = [class_names[k % len(class_names)] for k in range(num_samples)]


    def prepare_t2i(self, texts):
        """Tokenise ``texts`` with ``uni_prompting.t2i_prompt`` at the configured resolution.

        Returns the 5-tuple from ``t2i_prompt``: input ids, labels, attention
        mask, image-position mask and position ids.
        """
        input_ids_t2i, labels_t2i, attention_mask_t2i, image_position_mask_t2i, position_ids_t2i = self.uni_prompting.t2i_prompt(
            texts, 
            img_h=self.args.t2i_resolution,
            img_w=self.args.t2i_resolution,
            num_visual_gen_tokens=self.args.num_visual_gen_tokens
        )
        
        return input_ids_t2i, labels_t2i, attention_mask_t2i, image_position_mask_t2i, position_ids_t2i


    def __call__(self, i) -> Dict[str, torch.Tensor]:
        """Build batch number ``i``.

        Returns:
            ``{'t2i_flow': {...}}``. The inner dict holds the outputs of
            ``prepare_t2i``, plus:
              * ``t2i_image_grid_thw`` — ``[num_images_per_batch, res // 14, res // 14]``;
              * ``ar_mask`` — all ones, shape
                ``(2 * num_images_per_batch, num_visual_gen_tokens)``;
              * ``image_labels`` — the raw prompt strings, unconditional ("") first.

        Raises:
            IndexError: if ``i`` is past the last full batch of class names.
        """
        num_images = self.args.num_images_per_batch
        # `// 14` presumably matches the vision encoder's patch size — TODO confirm.
        grid_size = self.args.t2i_resolution // 14
        image_grid_thw = torch.tensor([num_images, grid_size, grid_size])

        start = i * num_images
        class_prompts = self.imagenet_class_names[start:start + num_images]
        if len(class_prompts) != num_images:
            # A short batch would desynchronise the uncond/cond halves downstream.
            raise IndexError(
                f"Batch index {i} is out of range for {len(self.imagenet_class_names)} "
                f"class prompts with {num_images} images per batch."
            )
        prompts = [""] * num_images + class_prompts

        input_ids_t2i, labels_t2i, attention_mask_t2i, image_position_mask_t2i, position_ids_t2i = self.prepare_t2i(prompts)

        batch = {}
        batch['t2i_flow'] = {
            "input_ids": input_ids_t2i,
            "labels": labels_t2i,
            "attention_mask": attention_mask_t2i,
            'image_position_mask': image_position_mask_t2i,
            "position_ids": position_ids_t2i,
            "t2i_image_grid_thw": image_grid_thw,
            'ar_mask': torch.ones((2 * num_images, self.args.num_visual_gen_tokens)),
            'image_labels': prompts,
            }

        return batch





if __name__ == '__main__':
    # Generate FID samples: for every batch of ImageNet class prompts, the MLLM
    # fills in `num_visual_gen_tokens` visual tokens with MAR-style masked
    # sampling and classifier-free guidance. Those tokens then condition a FLUX
    # ControlNet, which renders images with and without the class-name prompt.

    parser = transformers.HfArgumentParser((TrainingArguments))
    args = parser.parse_args_into_dataclasses()[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Validate options before any multi-GB checkpoint is loaded. Unsupported
    # values previously surfaced only mid-run as NameErrors.
    if args.vision_head_type != 'linear':
        raise ValueError(
            f"Unsupported vision_head_type {args.vision_head_type!r}: only 'linear' is implemented for sampling."
        )
    if args.cfg_schedule not in ('linear', 'constant'):
        raise ValueError(f"Unsupported cfg_schedule {args.cfg_schedule!r}: expected 'linear' or 'constant'.")
    if not args.lambda_gpu:
        # The GCS branch called `_load_from_checkpoint` with `resume_from_checkpoint`,
        # `local_checkpoint_path` and `num_ckpts`, none of which are defined in
        # this script, so it could only ever fail with a NameError.
        raise NotImplementedError(
            "Loading GCS checkpoints (lambda_gpu=False) is not supported by this script; "
            "keep lambda_gpu enabled and pass --resume_from_checkpoint_mllm."
        )

    # Round up so that every batch is full.
    args.num_fid_samples = math.ceil(args.num_fid_samples / args.num_images_per_batch) * args.num_images_per_batch


    #######################
    #    model -- MLLM    #
    #######################

    ##### VLM #####
    vision_language_model_config = {
        "cls": args.vision_language_model,
        "model_type": "vision_language_model",
        "params": {
            'model_name_or_path': args.vision_language_model_name, # 3.08B for LLM, 668M for CLIP vision_und_enc
            "load_from_pretrained": True,
            'remove_vision_und_encoder': False,
            'frozen_modules_in_vlm': ["vision_language_model.lm_head", "vision_language_model.model.embed_tokens", "vision_language_model.model.layers", "vision_language_model.model.norm"],
            'huggingface_token': args.huggingface_token,
        },
    }

    vision_gen_vae_config = {
        "cls": "AutoencoderKL", "model_type": "vision_gen_vae",
        "params": {
            'model_name_or_path': None, # 83M params for FLUX VAE
            "load_from_pretrained": True,
            'frozen_vision_gen_vae': False,
            'huggingface_token': args.huggingface_token,
        },
    }

    # Hidden size of each supported language backbone, passed as `inner_dim`.
    hidden_size_by_model = {
        'Qwen/Qwen2.5-0.5B-Instruct': 896,
        'Qwen/Qwen2.5-1.5B-Instruct': 1536,
        'Qwen/Qwen2.5-VL-3B-Instruct': 2048,
        'Qwen/Qwen2.5-3B-Instruct': 2048,
        'Qwen/Qwen2.5-VL-7B-Instruct': 3584,
        'Qwen/Qwen2.5-7B-Instruct': 3584,
        'Qwen/Qwen2.5-14B-Instruct': 5120,  # 48 layers
        'Qwen/Qwen2.5-VL-32B-Instruct': 5120,  # 64 layers
        'Qwen/Qwen2.5-32B-Instruct': 5120,  # 64 layers
        'Qwen/Qwen2.5-VL-72B-Instruct': 8192,
        'Qwen/Qwen2.5-72B-Instruct': 8192,
    }
    try:
        n_embed = hidden_size_by_model[args.vision_language_model_name]
    except KeyError:
        # Previously an unknown name left `n_embed` unbound (NameError below).
        raise ValueError(
            f"Unknown vision_language_model_name {args.vision_language_model_name!r}; "
            f"expected one of {sorted(hidden_size_by_model)}."
        ) from None

    ##### Diffusion Decoder #####
    diffusion_decoder_config = {
        "cls": args.pretrained_diffusion_decoder_name_or_path,
        "model_type": "diffusion_decoder",
        "params": {
            'pretrained_model_name_or_path': args.pretrained_diffusion_decoder_name_or_path, 
            "revision": None,
            "variant": None,
            "controlnet_model_name_or_path": None,
            'num_single_layers': args.num_single_layers,
            'num_double_layers': args.num_double_layers,
            "diffusion_decoder_text_dropout_prob": args.diffusion_decoder_text_dropout_prob,
        },
    }

    model_config = {
        'vision_language_model_config': vision_language_model_config,
        'vision_gen_vae_config': vision_gen_vae_config,
        "diffusion_decoder_config": diffusion_decoder_config,
        'timestep_sampling_strategy': args.timestep_sampling_strategy,
        'vision_denoising_type': args.vision_denoising_type,
        'max_seq_length': args.max_seq_length,
        'num_visual_gen_tokens': args.num_visual_gen_tokens,
        "add_vision_branch": args.add_vision_branch,
        "add_vision_branch_reuse_layernorm": args.add_vision_branch_reuse_layernorm,
        "use_discrete_visual_tokenizer": args.use_discrete_visual_tokenizer,
        "add_timestep_token": args.add_timestep_token,
        "skip_text_part2": args.skip_text_part2,
        "add_vision_gen_mask_token": args.add_vision_gen_mask_token,
        "add_vision_soi_eoi_tokens": args.add_vision_soi_eoi_tokens,
        "add_vision_soi_token": args.add_vision_soi_token,
        "vision_head_type": args.vision_head_type,
        "vision_loss_type": args.vision_loss_type,
        "vision_pos_emb_type": args.vision_pos_emb_type,
        "use_clip_visual_encoder": args.use_clip_visual_encoder,
        "batch_size_t2i": 1,
        "t2i_resolution": args.t2i_resolution,
        "lambda_gpu": args.lambda_gpu,
        "use_2d_query_tokens": args.use_2d_query_tokens,
        "e2e_training": True,
        "ctrlnet_training": args.ctrlnet_training,
        "proportion_empty_prompts": 0.0,
        "lambda_clip": 0.0,
        "remove_vae": False,
        "vae_w_ctrlnet_training": args.vae_w_ctrlnet_training,
        "vae_wo_ctrlnet_training": args.vae_wo_ctrlnet_training,
        "inner_dim": n_embed,
    }

    model_config = MultiModalityConfig(**model_config)
    model = MultiModalityCausalLM(model_config)

    # lambda_gpu is guaranteed True by the validation above.
    _load_from_checkpoint_mllm(args.resume_from_checkpoint_mllm, model, muted_keys = ['diffusion_decoder'])

    model.eval()
    model = model.to(torch.bfloat16).to(device)


    ###########################
    #  model -- FLUX CtrlNet  #
    ###########################

    from diffusers import FluxTransformer2DModel
    from diffusers.training_utils import free_memory
    from bifrost.models.flux_cnet.pipeline_flux_controlnet import FluxControlNetPipeline
    from bifrost.models.flux_cnet.controlnet_flux import FluxControlNetModel

    flux_transformer = FluxTransformer2DModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="transformer", revision=None, variant=None)
    # The trained ControlNet's config and weights come entirely from the
    # checkpoint. The original code built a ControlNet via `from_transformer`
    # and then called `from_pretrained` on that instance, discarding it.
    # Assumes FluxControlNetModel inherits diffusers' ModelMixin, whose
    # `from_pretrained` is a classmethod.
    flux_controlnet = FluxControlNetModel.from_pretrained(os.path.join(args.resume_from_checkpoint_ctrlnet, "flux_controlnet"))

    flux_controlnet_pipeline = FluxControlNetPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        controlnet=flux_controlnet,
        transformer=flux_transformer,
        torch_dtype=torch.bfloat16,
    )
    flux_controlnet_pipeline.enable_model_cpu_offload()


    #################
    #      data     #
    #################

    # set up processor
    conversation_config = {
        'processor_name_or_path': vision_language_model_config['params']['model_name_or_path'],
        'full_vision_mask': args.full_vision_mask,
        'precise_prompt_mask': args.precise_prompt_mask,
        "add_timestep_token": args.add_timestep_token,
        "cond_dropout_prob": 0.0,
        "add_vision_soi_eoi_tokens": args.add_vision_soi_eoi_tokens,
        "add_vision_soi_token": args.add_vision_soi_token,
        "vision_pos_emb_type": args.vision_pos_emb_type,
        "max_seq_length": args.max_seq_length
    }
    uni_prompting = Conversation(**conversation_config) ## TODO: need to set max_seq_length>60 for open-world image generation

    data_collator = DataCollatorForValidation(uni_prompting=uni_prompting, args=args)


    #################
    #   inference   #
    #################

    output_dir = args.output_dir
    output_dir_w_prompt = os.path.join(output_dir, 'val_images_w_prompt')
    output_dir_wo_prompt = os.path.join(output_dir, 'val_images_wo_prompt')
    os.makedirs(output_dir_w_prompt, exist_ok=True)
    os.makedirs(output_dir_wo_prompt, exist_ok=True)

    bsz = args.num_images_per_batch
    num_tokens = args.num_visual_gen_tokens
    # FLUX output side length, e.g. 224 -> 256.
    image_size = args.t2i_resolution * 16 // 14

    def generate_images(prompt_embeds, pooled_prompt_embeds, control_image, seed):
        """Run the FLUX ControlNet pipeline once and return its list of PIL images."""
        # Seeded CPU generator (the MLLM is offloaded to CPU at this point, as in
        # the original code). Both prompt variants of a batch get the same seed,
        # so they differ only in their text conditioning.
        generator = torch.Generator(device='cpu').manual_seed(seed)
        # Autocast in bf16 to match the pipeline's torch_dtype. A bare
        # torch.autocast('cuda') defaults to float16.
        with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):
            return flux_controlnet_pipeline(
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                control_image=control_image,
                num_inference_steps=28,
                controlnet_conditioning_scale=0.7,
                guidance_scale=3.5,
                generator=generator,
                height=image_size,
                width=image_size,
            ).images

    for i in tqdm(range(args.num_fid_samples // bsz)):

        ##########################
        #####  part 1: MLLM  #####
        ##########################

        # The MLLM is offloaded to CPU after each batch to make room for FLUX.
        model = model.to(device)

        t2i_flow = data_collator(i)['t2i_flow']
        input_ids_t2i = t2i_flow["input_ids"]
        attention_mask = t2i_flow["attention_mask"].to(device)
        image_position_mask = t2i_flow['image_position_mask'].to(device)
        position_ids = t2i_flow['position_ids']
        image_labels = t2i_flow['image_labels']  # bsz empty prompts, then bsz class names

        if position_ids.shape[1] == 3:
            position_ids = rearrange(position_ids, "bsz k c -> k bsz c")
        position_ids = position_ids.to(device)

        # Everything below is inference. Without no_grad, the mask token and
        # learnable position embedding (model parameters) would build an
        # autograd graph across all sampling steps.
        with torch.no_grad():
            text_embs_part1 = model.vision_language_model.model.embed_tokens(input_ids_t2i[:, :-num_tokens].to(device))

            # All visual tokens start masked, and are revealed in a random order.
            mask = torch.ones(bsz, num_tokens).to(device)
            tokens = model.mask_token.repeat(bsz, num_tokens, 1)
            orders = sample_orders(bsz, num_tokens).to(device)

            num_iter = num_tokens
            for step in tqdm(range(num_iter), leave=False):
                cur_tokens = tokens.clone()

                # Duplicate the batch: rows [:bsz] pair with the empty prompts
                # (unconditional), rows [bsz:] with the class prompts.
                tokens = torch.cat([tokens, tokens], dim=0)
                mask = torch.cat([mask, mask], dim=0)

                if args.vision_pos_emb_type == 'learnable_pos_emb':
                    tokens = tokens + model.learnable_pos_emb

                input_embeddings = torch.cat((text_embs_part1, tokens), dim=1)
                outputs = model.vision_language_model(
                    inputs_embeds=input_embeddings,
                    use_cache=None,
                    attention_mask=attention_mask,
                    past_key_values=None,
                    return_dict=True,
                    position_ids=position_ids,
                    image_position_mask=image_position_mask,
                    t_emb=None,
                )
                hidden_states = outputs['hidden_states']
                # vision_head_type == 'linear' is guaranteed by the start-up check.
                model_pred = model.vision_head(hidden_states[:, -num_tokens:])

                # Mask ratio for the next round, following MaskGIT and MAGE.
                mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
                mask_len = torch.Tensor([np.floor(num_tokens * mask_ratio)]).to(device)

                # Keep at least one token masked for the next iteration, and never
                # mask more than were masked in this one.
                mask_len = torch.maximum(torch.Tensor([1]).to(device), torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))

                # Masking for the next iteration, and the positions to predict now.
                mask_next = mask_by_order(mask_len[0], orders, bsz, num_tokens)
                if step >= num_iter - 1:
                    mask_to_pred = mask[:bsz].bool()
                else:
                    mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
                mask = mask_next

                # Split the predictions per half before selecting positions. This
                # stays correct however many tokens each image predicts this step.
                # Slicing the flattened selection at `bsz` only worked with exactly
                # one token per image per step.
                pred_idx = mask_to_pred.nonzero(as_tuple=True)
                uncond_pred = model_pred[:bsz][pred_idx]
                cond_pred = model_pred[bsz:][pred_idx]

                # CFG schedule follows Muse.
                if args.cfg_schedule == "linear":
                    cfg_iter = 1 + (args.cfg_weight - 1) * (num_tokens - mask_len[0]) / num_tokens
                else:  # "constant"
                    cfg_iter = args.cfg_weight

                sampled_token_latent = uncond_pred + cfg_iter * (cond_pred - uncond_pred)
                cur_tokens[pred_idx] = sampled_token_latent.to(cur_tokens.dtype)
                tokens = cur_tokens

        model = model.cpu()
        free_memory()

        ###########################
        ##### part 2: ctrlnet #####
        ###########################

        uncond_labels, cond_labels = image_labels[:bsz], image_labels[bsz:]
        with torch.no_grad():
            prompt_embeds_cond, pooled_prompt_embeds_cond, _ = flux_controlnet_pipeline.encode_prompt(cond_labels, prompt_2=cond_labels)
            prompt_embeds_uncond, pooled_prompt_embeds_uncond, _ = flux_controlnet_pipeline.encode_prompt(uncond_labels, prompt_2=uncond_labels)

        control_image = tokens.detach()
        # Offset the seed per batch. A fixed seed gave every batch identical FLUX
        # noise, so the repeated samples of each class were strongly correlated.
        batch_seed = args.seed + i
        images_w_prompt = generate_images(prompt_embeds_cond, pooled_prompt_embeds_cond, control_image, batch_seed)
        images_wo_prompt = generate_images(prompt_embeds_uncond, pooled_prompt_embeds_uncond, control_image, batch_seed)

        start = i * bsz
        for offset, (img_w, img_wo) in enumerate(zip(images_w_prompt, images_wo_prompt)):
            # Replace path separators as well as spaces so a class name can't
            # redirect the file into a (non-existent) subdirectory.
            img_text = cond_labels[offset].replace(" ", "_").replace("/", "_")
            filename = f"{start + offset:06d}_{img_text}.png"
            img_w.save(os.path.join(output_dir_w_prompt, filename))
            img_wo.save(os.path.join(output_dir_wo_prompt, filename))

        tqdm.write(f"{start} => {start + bsz}")



