# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


import os

import io
import re
import copy
import random
import json
import wget

import numpy as np
from google.cloud import storage
from packaging import version
import dataclasses
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List, Union
import transformers
import tokenizers

import torch
from torch import nn
import torch.distributed as dist
import torch.nn.functional as F
from torch.utils.data import Dataset

from bifrost.train.bifrost_trainer import BifrostTrainer
from bifrost.train.prompting_utils import UniversalPrompting, create_attention_mask_predict_next, create_attention_mask_for_mmu, create_attention_mask_for_mmu_vit

from bifrost.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from bifrost.models.configuration_bifrost import MultiModalityConfig
from bifrost.models.modeling_bifrost import MultiModalityCausalLM

from bifrost.conversation import Conversation
from bifrost.utils import IS_XLA_AVAILABLE, count_params, MyTrainerState


# On TPU hosts, select the PJRT TPU runtime and force XLA_USE_BF16 off.
# These are set before the remaining project imports below, presumably so that
# torch_xla / bifrost modules imported afterwards observe them -- confirm.
if IS_XLA_AVAILABLE:
    os.environ['PJRT_DEVICE'] = 'TPU'
    # NOTE(review): presumably disables torch_xla's implicit float->bf16 conversion;
    # model dtype is instead chosen explicitly in train() from --bf16.
    os.environ['XLA_USE_BF16'] = '0' 

from bifrost.train.data_loader import make_supervised_data_module


from data.parquet import RefinedWebDataset
from datasets import load_dataset, interleave_datasets

import logging
from ezcolorlog import root_logger as logger
# NOTE(review): log_rank0() logs at INFO level; with the level set to WARNING those
# records are presumably filtered out (assuming root_logger behaves like a standard
# logging.Logger) -- confirm this suppression is intended.
logger.setLevel(logging.WARNING)


# Process-local rank; assigned from TrainingArguments.local_rank inside train().
# While it is None, print_rank0 / log_rank0 emit nothing.
local_rank = None

# Parse the environment variable as a boolean flag. A plain bool(os.environ.get(...))
# treats every non-empty string -- including "0" and "false" -- as True.
XLA_DISABLE_FUNCTIONALIZATION = (
    os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', '').strip().lower()
    in ('1', 'true', 'yes', 'on')
)

# Global switch consulted by print_rank0 / log_rank0.
PRINT_LOGS = True


def print_rank0(*args):
    """Forward *args to print() on the main process only.

    Output is produced only when the module-level ``local_rank`` is 0 or -1 and
    ``PRINT_LOGS`` is truthy; every other rank stays silent.
    """
    is_main_process = local_rank == 0 or local_rank == -1
    should_print = is_main_process and PRINT_LOGS
    if should_print:
        print(*args)


def log_rank0(log):
    """Log *log* at INFO level through the module logger, on the main process only.

    Nothing is emitted when the module-level ``local_rank`` is neither 0 nor -1, or
    when ``PRINT_LOGS`` is falsy. ``stacklevel=2`` attributes the record to the
    caller of this helper rather than to this function.
    """
    if local_rank not in (0, -1):
        return
    if not PRINT_LOGS:
        return
    logger.info(log, stacklevel=2)


# True when the installed `tokenizers` version is at least 0.14 (the comparison is
# >=, despite the "GREATER_THAN" in the name). Not referenced in the code visible here.
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')



# Backend-specific FSDP setup, executed at import time.
# - XLA/TPU: replace XlaFullyShardedDataParallel._shard_parameters_ with the project's
#   own implementation from bifrost.utils. NOTE(review): the message printed below
#   says "default", but the method installed is the bifrost override.
# - Otherwise: import PyTorch's FSDP wrapper and CPUOffload. Neither name is
#   referenced in the code visible in this file.
if IS_XLA_AVAILABLE: # and training_args.bf16:
    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel
    from bifrost.utils import _shard_parameters_
    XlaFullyShardedDataParallel._shard_parameters_ = _shard_parameters_
    print("######## using default _shard_parameters_ func ##########")
else:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import CPUOffload
# Disabled DeepSpeed ZeRO alternative, kept for reference:
# else: # for gpu
#     from deepspeed import zero
#     from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus



def find_all_linear_names(model):
    """Return the qualified names of ``torch.nn.Linear`` submodules to use as LoRA targets.

    Any module whose qualified name contains one of the multimodal keywords below
    (vision generation/understanding components, ``lm_head``, ``diffloss``) is
    skipped, so LoRA adapters are attached only to the remaining linear layers.

    Args:
        model: the ``torch.nn.Module`` to scan.

    Returns:
        A list of full module names as yielded by ``model.named_modules()``, in
        traversal order. The order is deterministic across processes, unlike
        iterating a ``set`` of strings, whose order depends on per-process hash
        randomization.
    """
    multimodal_keywords = (
        'vq_model', 'vision_gen_aligner', 'vision_gen_head', 'visual_embed_tokens',
        '.visual.', 'vision_head', 'lm_head', 'diffloss',
    )
    # named_modules() yields each qualified name once, so no explicit dedup is needed.
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and not any(keyword in name for keyword in multimodal_keywords)
    ]



@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Command-line arguments for Bifrost training, parsed by ``HfArgumentParser`` in ``train()``.

    Extends ``transformers.TrainingArguments`` with model-assembly, data, multi-task
    batching, LoRA and end-to-end diffusion-decoder options, and redeclares several
    parent fields (optimizer, LR schedule, logging/saving cadence, FSDP) with
    different defaults.
    """

    ## model 
    vision_language_model: str = field(default=None)  # becomes the "cls" of the VLM config in train()
    vision_language_model_name: str = field(default=None)  # train() also derives the LLM hidden size from this name
    vision_gen_vae: str = field(default=None)  # train() maps a fixed set of names (or None) to a VAE latent channel count
    vision_gen_enc: str = field(default=None)  # not read by train(); encoder config there is hard-coded
    vision_gen_dec: str = field(default=None)  # not read by train(); decoder config there is hard-coded
    vision_gen_tokenizer: str = field(default='VQ-16')  # train() handles 'VQ-16', 'VQ-8', 'magvitv2'

    # frozen_llm: bool = field(default = True)
    frozen_modules_in_vlm: Optional[str] = field(default=None)  # JSON-encoded; train() decodes it with json.loads
    remove_vision_und_encoder: bool = field(default = True)
    remove_vae: bool = field(default = False)
    frozen_vision_gen_vae: bool = field(default = True)
    frozen_vision_gen_encdec: bool = field(default = False)
    fully_trainable: bool = field(default=False)
    
    # vocab_size: int = field(default=58498)
    num_visual_gen_tokens: int = field(default=256) # for 256*256 res images. FLUX VAE scales by 8, UNet scales by 2
    max_seq_length: int = field(default=60)


    ## data
    dataset_list: str = field(default='["imagenet1k"]')  # looks JSON-encoded (a list of dataset names) -- confirm in the data loader
    dataset_path_list: str = field(default=None)
    # use_cached_dataset: bool = field(default=False)
    t2i_resolution: int = field(default=256) # 384
    mmu_resolution: int = field(default=256) # 336
    dataloader_num_workers: int = field(default = 1) # TODO: changed from 4 to 8  (NOTE(review): stale -- the default here is 1)
    random_flip: bool = field(default=True)
    shuffle: bool = field(default=True)
    dataloader_drop_last: bool = field(default=False)
    dataloader_persistent_workers: bool = field(default=False)
    lambda_gpu: bool = field(default=False)

    
    ## optimizer & lr (HF trainer, follow showo)
    optim: str = field(default="adamw_torch")
    scale_lr: bool = field(default=False)
    weight_decay: float = field(default=0.0)
    adam_beta1: float = field(default=0.9)
    adam_beta2: float = field(default=0.95)
    adam_epsilon: float = field(default=1e-8)
    max_grad_norm: float = field(default=1.0)
    
    lr_scheduler_type: str = field(default="constant_with_warmup") # cosine
    learning_rate: float = field(default=1e-4)
    warmup_steps: int = field(default=5000)
    max_steps: int = field(default=0)
    num_train_epochs: int = field(default=1000)

    ## training 
    gradient_accumulation_steps: Optional[int] = field(default=1)
    gradient_checkpointing: bool = field(default=False)

    # Per-task batch sizes. train() overwrites per_device_train_batch_size with
    # max(batch_size_t2i, batch_size_lm, batch_size_mmu) * (number of tasks whose size is > 0).
    batch_size_t2i: Optional[int] = field(default=2) # 15
    batch_size_lm: Optional[int] = field(default=0) # 4
    batch_size_mmu: Optional[int] = field(default=0) # 10

    t2i_coeff: Optional[float] = field(default=1.0)
    lm_coeff: Optional[float] = field(default=0.0)
    mmu_coeff: Optional[float] = field(default=0.0)
    remove_unused_columns: Optional[bool] = field(default=False) # hard lesson: with a custom collate_fn, this must be False.
    ignore_data_skip: Optional[bool] = field(default=True) # resume

    cond_dropout_prob: float = field(default=0.0)  # forwarded to the Conversation prompt builder in train()
    proportion_empty_prompts: float = field(default=0.0)
    label_smoothing: float = field(default=0.0)


    ## log, save, eval (HF trainer)
    output_dir: Optional[str] = field(default=None)
    huggingface_token: str = field(default=None)  # passed to the VLM and VAE configs in train()
   
    resume_from_checkpoint: Optional[str] = field(default=None)  # when set, train() passes it to trainer.train()
    logging_steps: int = field(default=1) # log every step; works well for now
    logging_first_step: bool = field(default = True) # TODO: experiment
    eval_steps: int = field(default=100000000) # never eval
    save_steps: int = field(default=1000)
    save_total_limit: int = field(default = 10000)
    save_strategy: str = field(default = "steps")
    log_task_specific_loss: bool = field(default = False)  # train() calls trainer.initialize_task_specific_loss() when set


    ## FSDP params (from cambrian-1)
    fsdp: str = field(default="")
    fsdp_config: str = field(default=None)
    bf16: bool = field(default = False)  # train() casts the whole model to bfloat16 when set, else float32
    is_fsdp_enabled: bool = field(default = True)  # copied onto the trainer instance in train()


    timestep_sampling_strategy: str = field(default='uniform')
    vision_denoising_type: str = field(default='continuous_diffusion')
    add_timestep_token: bool = field(default=True)
    add_vision_gen_mask_token: bool = field(default=False)  # train() adds 1 to the image token vocabulary size
    add_vision_soi_token: bool = field(default=False)
    add_vision_soi_eoi_tokens: bool = field(default=False)  # train() adds 2 to the image token vocabulary size
    add_separate_rope_for_vision: bool = field(default=False)  # not forwarded into the model config by train()
    
    vision_head_type: str = field(default=None)
    vision_loss_type: str = field(default=None)
    vision_pos_emb_type: str = field(default='2drope')
    lambda_clip: float = field(default=0.25)

    # masks
    full_vision_mask: bool = field(default=False)
    precise_prompt_mask: bool = field(default=False)
    skip_text_part2: bool = field(default=False)

    # architecture
    add_vision_branch: bool = field(default=False)
    add_vision_branch_reuse_layernorm: bool = field(default=False)
    use_discrete_visual_tokenizer: bool = field(default=False)
    use_clip_visual_encoder: bool = field(default=False)

    # lora
    use_lora: bool = field(default=False)
    lora_r: int = field(default=320)
    lora_alpha: int = field(default=320) 
    lora_dropout: float = field(default=0.05)
    # NOTE(review): train() builds LoRA targets with find_all_linear_names(model) and does
    # not read this field.
    lora_target_modules: List[str] = field(default_factory=lambda: ["c_attn", "c_proj", "w1", "w2"])
    lora_weight_path: str = field(default="")  # not read directly by train()
    lora_bias: str = field(default="none") # to keep all params in backbone frozen
    q_lora: bool = field(default=False)  # not read directly by train()
    use_rslora: bool = field(default=False)

    # 2d query tokens
    use_2d_query_tokens: bool = field(default=False)

    # e2e training 
    e2e_training: bool = field(default=False)
    ctrlnet_training: bool = field(default=False)
    pretrained_diffusion_decoder_name_or_path: str = field(default="black-forest-labs/FLUX.1-dev")
    num_single_layers: int = field(default=0)
    num_double_layers: int = field(default=4)
    diffusion_decoder_text_dropout_prob: float = field(default=0.0)
    vae_w_ctrlnet_training: bool = field(default=False)
    vae_wo_ctrlnet_training: bool = field(default=False)
    vae_scale_by_4: bool = field(default=False)
    
    




# Latent channel count for each supported --vision_gen_vae value (None: no VAE configured).
_VAE_LATENT_CHANNELS = {
    'black-forest-labs/FLUX.1-dev': 16,
    'stabilityai/sdxl-vae': 4,
    'pretrained_models/vae/kl16.ckpt': 16,
    None: None,
}

# Base codebook size and checkpoint location for each supported --vision_gen_tokenizer value.
_VISION_GEN_TOKENIZERS = {
    'VQ-16': (16384, "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds16_c2i.pt"),
    'VQ-8': (16384, "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds8_c2i.pt"),
    'magvitv2': (8192, "showlab/magvitv2"),
}

# Hidden size of each supported --vision_language_model_name backbone.
_LLM_HIDDEN_SIZES = {
    'Qwen/Qwen2.5-0.5B-Instruct': 896,
    'Qwen/Qwen2.5-1.5B-Instruct': 1536,
    'Qwen/Qwen2.5-VL-3B-Instruct': 2048,
    'Qwen/Qwen2.5-3B-Instruct': 2048,
    'Qwen/Qwen2.5-VL-7B-Instruct': 3584,
    'Qwen/Qwen2.5-7B-Instruct': 3584,
    'Qwen/Qwen2.5-14B-Instruct': 5120,     # 48 layers
    'Qwen/Qwen2.5-VL-32B-Instruct': 5120,  # 64 layers
    'Qwen/Qwen2.5-32B-Instruct': 5120,     # 64 layers
    'Qwen/Qwen2.5-VL-72B-Instruct': 8192,
    'Qwen/Qwen2.5-72B-Instruct': 8192,
}


def _lookup_or_raise(table, key, arg_name):
    """Return ``table[key]``, raising a ValueError that names the offending argument for unknown keys."""
    try:
        return table[key]
    except KeyError:
        raise ValueError(
            f"Unsupported {arg_name}={key!r}; supported values: {sorted(map(repr, table))}"
        ) from None


def _patch_libraries():
    """Install the process-wide monkeypatches this script relies on (TrainerState, RMSNorm, Conv2d, Linear)."""
    transformers.trainer_callback.TrainerState = MyTrainerState # re-write TrainerState class

    # TPU note: the original LLaMA RMSNorm implementation converts dtypes incorrectly here.
    # That is harmless on GPU but breaks TPU training. This version normalizes in float32,
    # then casts (weight * normalized) back to the input dtype.
    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        output = (self.weight * hidden_states).to(input_dtype)
        return output

    transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = forward
    transformers.models.mistral.modeling_mistral.MistralRMSNorm.forward = forward

    # Conv2d: when a bias exists, cast weight and bias to the input dtype before convolving.
    def new_forward_conv(self, input):
        if self.bias is None:
            return self._conv_forward(input, self.weight, self.bias)
        return self._conv_forward(input, self.weight.to(input.dtype), self.bias.to(input.dtype)) ### this is good, but not work for debug sdxl

    nn.Conv2d.forward = new_forward_conv

    # Linear: when a bias exists, cast only the bias (not the weight) and the output to the
    # input dtype. NOTE(review): unlike the Conv2d patch, the weight is not cast -- confirm intended.
    def new_forward_linear(self, input):
        if self.bias is None:
            return F.linear(input, self.weight, self.bias)
        return F.linear(input, self.weight, self.bias.to(input.dtype)).to(input.dtype)

    nn.Linear.forward = new_forward_linear


def train(INDEX=0, attn_implementation=None):
    """Parse TrainingArguments, build the Bifrost model and data module, and run training.

    Args:
        INDEX: process index; only used in a log message. Presumably supplied by a
            multi-process launcher (e.g. an XLA spawn) -- confirm.
        attn_implementation: unused; kept for signature compatibility.

    Raises:
        ValueError: if ``vision_gen_vae``, ``vision_gen_tokenizer`` or
            ``vision_language_model_name`` is not one of the supported values.
    """
    global local_rank
    parser = transformers.HfArgumentParser((TrainingArguments))
    training_args = parser.parse_args_into_dataclasses()[0]

    # Every enabled task (batch size > 0) contributes one slot of the largest
    # per-task batch size to the per-device batch.
    task_batch_sizes = (training_args.batch_size_t2i, training_args.batch_size_lm, training_args.batch_size_mmu)
    max_batch_size_per_task = max(task_batch_sizes)
    num_tasks = sum(batch_size > 0 for batch_size in task_batch_sizes)
    training_args.per_device_train_batch_size = max_batch_size_per_task * num_tasks

    local_rank = training_args.local_rank
    # Log only now: log_rank0 emits nothing while local_rank is still None.
    log_rank0(f"Training on index {INDEX}. Local rank: {local_rank}")

    #############
    #  fix lib  #
    #############

    _patch_libraries()

    #############
    #   model   #
    #############

    vae_channels = _lookup_or_raise(_VAE_LATENT_CHANNELS, training_args.vision_gen_vae, 'vision_gen_vae')

    # frozen_modules_in_vlm is an optional JSON-encoded string. When it is not given,
    # forward None instead of calling json.loads(None), which raises TypeError.
    # NOTE(review): confirm the VLM builder accepts None for this parameter.
    frozen_modules_in_vlm = (
        json.loads(training_args.frozen_modules_in_vlm)
        if training_args.frozen_modules_in_vlm is not None
        else None
    )

    ##### VLM #####
    vision_language_model_config = {
        "cls": training_args.vision_language_model,
        "model_type": "vision_language_model",
        "params": {
            'model_name_or_path': training_args.vision_language_model_name, # 3.08B for LLM, 668M for CLIP vision_und_enc
            "load_from_pretrained": True,
            'remove_vision_und_encoder': training_args.remove_vision_und_encoder,
            'frozen_modules_in_vlm': frozen_modules_in_vlm,
            'huggingface_token': training_args.huggingface_token,
        },
    }

    ##### VAE #####
    # NOTE(review): the encoder params spell "elementwize_affine" while the decoder uses
    # "elementwise_affine" -- confirm which keyword ShallowUViTEncoder expects.
    vision_gen_enc_config = {
        "cls": "ShallowUViTEncoder", "model_type": "vision_gen_enc",
        "params": {"block_out_channels": [768], "elementwize_affine": True, "hidden_size": 2048, "input_channels": vae_channels, "kernel_size": 2, "layers_in_middle": 2, "norm_eps": 1e-06, "num_extra_tensors": 5, "padding": 0, "stride": 2, "use_bias": True, "use_mid_block": True, 'frozen_vision_gen_encdec': training_args.frozen_vision_gen_encdec,
        },
    }
    vision_gen_dec_config = {
        "cls": "ShallowUViTDecoder", "model_type": "vision_gen_dec",
        "params": {"block_out_channels": [768], "elementwise_affine": True, "hidden_size": 2048, "in_channels": 768, "layers_in_middle": 2, "norm_eps": 1e-06, "out_channels": vae_channels, "upsamples": 1, "use_bias": True, "use_mid_block": True, 'frozen_vision_gen_encdec': training_args.frozen_vision_gen_encdec,
        },
    }
    vision_gen_vae_config = {
        "cls": "AutoencoderKL", "model_type": "vision_gen_vae",
        "params": {
            'model_name_or_path': training_args.vision_gen_vae, # 83M params for FLUX VAE
            "load_from_pretrained": True,
            'frozen_vision_gen_vae': training_args.frozen_vision_gen_vae,
            'huggingface_token': training_args.huggingface_token,
        },
    }

    ##### tokenizer #####
    image_token_size, model_name_or_path = _lookup_or_raise(
        _VISION_GEN_TOKENIZERS, training_args.vision_gen_tokenizer, 'vision_gen_tokenizer'
    )
    if training_args.add_vision_gen_mask_token:
        image_token_size += 1  # one extra id for the mask token
    if training_args.add_vision_soi_eoi_tokens:
        image_token_size += 2  # two extra ids for the soi/eoi tokens

    n_embed = _lookup_or_raise(_LLM_HIDDEN_SIZES, training_args.vision_language_model_name, 'vision_language_model_name')

    vision_gen_tokenizer_config = {
        "cls": training_args.vision_gen_tokenizer, "model_type": "vision_gen_tokenizer",
        "params": {"image_token_size": image_token_size, "n_embed": 8, "model_name_or_path": model_name_or_path}
    }
    vision_gen_aligner_config = {
        "cls": "MlpProjector", "model_type": "vision_gen_aligner",
        "params": {"depth": 2, "input_dim": 8, "n_embed": n_embed, "projector_type": "mlp_gelu"}
    }
    vision_gen_head_config = {
        "cls": "VisionHead", "model_type": "vision_gen_head",
        "params": {"image_token_embed": n_embed, "image_token_size": image_token_size, "n_embed": n_embed}
    }

    ##### Diffusion Decoder #####
    diffusion_decoder_config = {
        "cls": training_args.pretrained_diffusion_decoder_name_or_path,
        "model_type": "diffusion_decoder",
        "params": {
            'pretrained_model_name_or_path': training_args.pretrained_diffusion_decoder_name_or_path,
            "revision": None,
            "variant": None,
            "controlnet_model_name_or_path": None,
            'num_single_layers': training_args.num_single_layers,
            'num_double_layers': training_args.num_double_layers,
            "diffusion_decoder_text_dropout_prob": training_args.diffusion_decoder_text_dropout_prob,
        },
    }

    model_config = {
        'vision_gen_enc_config': vision_gen_enc_config,
        'vision_gen_dec_config': vision_gen_dec_config,
        'vision_language_model_config': vision_language_model_config,
        'vision_gen_vae_config': vision_gen_vae_config,
        "vision_gen_tokenizer_config": vision_gen_tokenizer_config,
        "vision_gen_aligner_config": vision_gen_aligner_config,
        "vision_gen_head_config": vision_gen_head_config,
        "diffusion_decoder_config": diffusion_decoder_config,
        'timestep_sampling_strategy': training_args.timestep_sampling_strategy,
        'vision_denoising_type': training_args.vision_denoising_type,
        'max_seq_length': training_args.max_seq_length,
        'num_visual_gen_tokens': training_args.num_visual_gen_tokens,
        "add_vision_branch": training_args.add_vision_branch,
        "add_vision_branch_reuse_layernorm": training_args.add_vision_branch_reuse_layernorm,
        "use_discrete_visual_tokenizer": training_args.use_discrete_visual_tokenizer,
        "add_timestep_token": training_args.add_timestep_token,
        "skip_text_part2": training_args.skip_text_part2,
        "add_vision_gen_mask_token": training_args.add_vision_gen_mask_token,
        "add_vision_soi_eoi_tokens": training_args.add_vision_soi_eoi_tokens,
        "add_vision_soi_token": training_args.add_vision_soi_token,
        "vision_head_type": training_args.vision_head_type,
        "vision_loss_type": training_args.vision_loss_type,
        "vision_pos_emb_type": training_args.vision_pos_emb_type,
        "fully_trainable": training_args.fully_trainable,
        "use_clip_visual_encoder": training_args.use_clip_visual_encoder,
        "batch_size_t2i": training_args.batch_size_t2i,
        "t2i_resolution": training_args.t2i_resolution,
        "lambda_gpu": training_args.lambda_gpu,
        "use_2d_query_tokens": training_args.use_2d_query_tokens,
        "e2e_training": training_args.e2e_training,
        "ctrlnet_training": training_args.ctrlnet_training,
        "remove_vae": training_args.remove_vae,
        "proportion_empty_prompts": training_args.proportion_empty_prompts,
        "lambda_clip": training_args.lambda_clip,
        "vae_w_ctrlnet_training": training_args.vae_w_ctrlnet_training,
        "vae_wo_ctrlnet_training": training_args.vae_wo_ctrlnet_training,
        "inner_dim": n_embed,
        "vae_scale_by_4": training_args.vae_scale_by_4
    }

    model_config = MultiModalityConfig(**model_config)
    model = MultiModalityCausalLM(model_config)

    if training_args.use_lora:
        log_rank0("Adding LoRA adapters...")
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            # LoRA targets are discovered from the model, not from training_args.lora_target_modules.
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            use_rslora=training_args.use_rslora,
            task_type="CAUSAL_LM", # TODO: need to double check task_type here: https://github.com/huggingface/peft/blob/v0.8.2/src/peft/utils/peft_types.py#L68-L73
        )
        model = get_peft_model(model, lora_config)

    # Cast every parameter to bfloat16 when --bf16 is set, otherwise to float32.
    # (--fp16 is not applied to the model here.)
    model = model.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float32)

    #############
    #   data    #
    #############

    print("=====> Configuring data module...")
    log_rank0("Configuring data module...")

    log_rank0(f"cond_dropout_prob = {training_args.cond_dropout_prob}")
    print("############## dataloader_drop_last = ", training_args.dataloader_drop_last)

    conversation_config = {
        'processor_name_or_path': vision_language_model_config['params']['model_name_or_path'],
        'full_vision_mask': training_args.full_vision_mask,
        'precise_prompt_mask': training_args.precise_prompt_mask,
        "add_timestep_token": training_args.add_timestep_token,
        "cond_dropout_prob": training_args.cond_dropout_prob,
        "add_vision_soi_eoi_tokens": training_args.add_vision_soi_eoi_tokens,
        "add_vision_soi_token": training_args.add_vision_soi_token,
        "vision_pos_emb_type": training_args.vision_pos_emb_type,
        "max_seq_length": training_args.max_seq_length,
    }
    uni_prompting = Conversation(**conversation_config)

    data_module = make_supervised_data_module(tokenizer=None,
                                              uni_prompting=uni_prompting,
                                              training_args=training_args,
                                              )

    #############
    #  trainer  #
    #############

    log_rank0("Configuring trainer...")
    print("=====> Configuring trainer...")
    trainer = BifrostTrainer(model=model,
                             args=training_args,
                             **data_module)
    trainer.is_fsdp_enabled = training_args.is_fsdp_enabled
    if training_args.log_task_specific_loss:
        trainer.initialize_task_specific_loss()

    if training_args.resume_from_checkpoint is not None:
        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    else:
        trainer.train()



if __name__ == "__main__":
    # Single-process entry point: run training as process index 0 (train()'s default).
    train(INDEX=0)
