# coding=utf-8
# modified from https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py
# Part of the code for packing multiple samples into one block_size sequence with a block-causal attention mask (requires disabling Flash Attention 2) is copied from https://github.com/AlongWY/

import argparse
import bisect
import json
import logging
import math
import os
import random
import sys
import time
import multiprocessing
import json
import numpy as np
import pandas as pd
from collections import defaultdict
from functools import reduce
from typing import Union, List, Dict, Optional, Tuple
from itertools import chain
from pathlib import Path
from copy import deepcopy
from collections import OrderedDict
import shutil
from time import strftime, localtime

import datasets
import torch
import torch.nn as nn
from datetime import timedelta
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import set_seed, gather_object, InitProcessGroupKwargs
from datasets import load_dataset, load_from_disk, DatasetDict, concatenate_datasets
from peft import LoraConfig, get_peft_model, TaskType, PeftModelForCausalLM
from peft.utils.other import prepare_model_for_kbit_training
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    SchedulerType,
    default_data_collator,
    get_scheduler,
    GenerationConfig
)
import transformers.models.llama.modeling_llama
from modules.modeling_llama import custom_forward, custom_forward_causalLM, prepare_inputs_for_generation, make_custom_attention_mask
from modules.custom_optimization import custom_get_cosine_schedule_with_warmup
from modules.init_embeddings import init_embeddings_normal, CodebookEmbedding
from evaluation.eval_score_mmbench import mmbench_eval
from utils import get_nb_trainable_parameters, read_with_orjsonl, write_with_orjsonl, write_with_orjsonl_extend, get_grad_norm

logger = get_logger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def parse_betas(value):
    """Parse a ``--betas`` command-line value into a ``(beta1, beta2)`` tuple of floats.

    Accepts two numbers separated by a comma and/or whitespace, e.g. ``"0.9,0.95"``.
    Intended as an argparse ``type=`` converter (``type=tuple`` would split the
    string into single characters).

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not exactly two floats.
    """
    parts = value.replace(",", " ").split()
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(f"expected two comma-separated floats for --betas, got {value!r}")
    try:
        return tuple(float(part) for part in parts)
    except ValueError as err:
        raise argparse.ArgumentTypeError(f"expected two comma-separated floats for --betas, got {value!r}") from err


def parse_args():
    """Build the command-line parser, parse ``sys.argv`` and print a short hardware summary.

    Returns:
        argparse.Namespace with all training options.
    """
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")

    # Meta settings
    parser.add_argument("--project_name", type=str, default="MLLM", help="Name of the project to log to.")
    parser.add_argument('--group_name', type=str, default='pre-train', help='group name')
    parser.add_argument('--run_name', type=str, default='pre-train', help='run name')
    parser.add_argument("--output_dir", type=str, default="YOUR_ROOT_PATH/model/checkpoint/MLLM/pre_train", help="Where to store the final model.")

    # Data settings
    parser.add_argument("--dataset_dir", type=str, default="YOUR_ROOT_PATH/data/MLLM", help="Path to the dataset directory.")
    parser.add_argument('--pretrain_dataset_name', type=str, choices=['wit', 'oi', 'ic', 'oi+wit', 'oi+wit+ic', 'oi+ic', 'mmc4', 'mmc4-pairs'], default='wit', help='pretrain dataset names')
    parser.add_argument('--concrete_type', type=str, choices=['concrete', 'abstract', 'all', 'concrete-strict', 'abstract-strict'], default='all', help='concrete type of wit')
    parser.add_argument('--oi_time', type=int, default=1, help='number of oi times for training')
    parser.add_argument('--ic_num', type=int, default=-1, help='number of ic samples for training')
    parser.add_argument('--ic_jd', action="store_true", help="Whether to use JourneyDB ic samples.")
    parser.add_argument('--ic_denoise_prob', type=float, default=0.0, help='denoise prob of ic')
    parser.add_argument('--ic_long', action="store_true", help="Whether to use long ic samples.")
    parser.add_argument('--turn_index', type=int, default=0, help='turn index of dataset')
    parser.add_argument("--block_size", type=int, default=2048, help="Optional input sequence length after tokenization. The training dataset will be truncated in block of this size for training. Default to the model max input length for single sentence inputs (take into account special tokens).")
    parser.add_argument("--eval_block_size", type=int, default=1024, help="Optional input sequence length after tokenization. The evaluation dataset will be truncated in block of this size for evaluation. Default to the model max input length for single sentence inputs (take into account special tokens).")
    parser.add_argument("--use_custom_attention_mask", action="store_true", help="Whether to use custom attention mask. Used for distinguishing different samples in the same block by custom masks.")
    parser.add_argument('--process_batch_size', type=int, default=1000, help='process batch size')
    parser.add_argument("--process_num_workers", type=int, default=multiprocessing.cpu_count(), help="The number of processes to use for the preprocessing.")
    parser.add_argument("--max_image_num", type=int, default=24, help="The maximum number of images in a multi-image-caption sample. Used for padding.")
    parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets")
    parser.add_argument("--compress_batch", action="store_true", help="Whether to compress the batch to save memory.")

    # Model settings
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument("--model_name_or_path", type=str, default='YOUR_ROOT_PATH/model/llama2-1229/Llama-2-7b-hf', help="Path to pretrained model or model identifier from huggingface.co/models.")
    parser.add_argument("--visual_codebook", type=str, default='YOUR_ROOT_PATH/model/LaVIT-7B-v2', help="Path to pretrained visual codebook.")
    parser.add_argument("--config_name", type=str, default=None, help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).")
    parser.add_argument("--use_xformers", action="store_true", help="Whether to use xformers.")
    parser.add_argument("--use_flash_attention_2", action="store_true", help="Whether to use Flash Attention 2.")
    parser.add_argument("--model_type", type=str, default=None, help="Model type to use if training from scratch.", choices=MODEL_TYPES)
    parser.add_argument("--low_cpu_mem_usage", action="store_true", help="It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. If passed, LLM loading time and RAM consumption will be benefited.")
    parser.add_argument("--unidirectional_loss", action="store_true", help="Whether to calculate only unidirectional loss, which means only text-to-image or image-to-text loss for text-first or image-first samples")
    parser.add_argument("--uni_and_bi", action="store_true", help="Whether to calculate both unidirectional and bidirectional loss.")
    # With --uni_and_bi, each sample gets the unidirectional label mask with these probabilities
    # and keeps full (bidirectional) labels otherwise.
    parser.add_argument("--uni_image_prob", type=float, default=0.5, help="With --uni_and_bi, the probability of calculating only the unidirectional (image-to-text) loss for an image-first sample.")
    parser.add_argument("--uni_text_prob", type=float, default=0.5, help="With --uni_and_bi, the probability of calculating only the unidirectional (text-to-image) loss for a text-first sample.")
    parser.add_argument("--remove_related_part", action="store_true", help="Whether to remove the related part of input")
    parser.add_argument("--remove_all_only_ic", action="store_true", help="Whether to remove all and become ic samples")
    parser.add_argument("--remove_des", action="store_true", help="Whether to remove the descriptions")
    parser.add_argument("--only_obj", action="store_true", help="Whether to only preserve object labels and descriptions")
    parser.add_argument('--image_first_prob', type=float, default=0.5, help='image first prob')
    parser.add_argument("--expand_vocab", type=str, default="normal", help="How to expand the language vocab to vision-language vocab.", choices=["normal", "random", "factorized"])
    parser.add_argument("--factorized_linear_mlp", action="store_true", help="Whether to use mlp as factorized linear.")
    parser.add_argument("--vl_vocab_size", type=int, default=48386, help="The vocab size of vision-language vocab.")
    parser.add_argument("--img_size", type=int, default=224, help="The image size")
    parser.add_argument("--max_image_length", type=int, default=256, help="The max image length")
    parser.add_argument("--image_start_token_id", type=int, default=32000, help="The start token id of image tokens.")
    parser.add_argument("--image_end_token_id", type=int, default=32001, help="The end token id of image tokens.")

    # Peft settings
    parser.add_argument("--unfreeze", type=str, default='embed_tokens,lm_head,norm', choices=['none', 'embed_tokens,lm_head', 'embed_tokens,lm_head,norm', 'all'], help="Whether to freeze the model.")
    parser.add_argument("--use_lora", action="store_true", help="Whether to use LoRA.")
    parser.add_argument("--lora_rank", type=int, default=16, help="Lora attention dimension.") # 4, 8, 16
    parser.add_argument("--lora_alpha", type=int, default=32, help="The alpha parameter for Lora scaling.") # 8, 16, 32, 64 (lora_rank * 2/4/8)
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="Dropout for LoRA.")
    parser.add_argument("--lora_bias", type=str, default="none", help="Bias for LoRA. Note that LLaMA-2 has no bias.")
    parser.add_argument("--lora_target_modules", type=str, default="gate_proj,down_proj,up_proj", help="The modules to apply LoRA.") # q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj
    parser.add_argument("--lora_modules_to_save", type=str, default="embed_tokens,lm_head,norm", help="The modules to save for LoRA.")
    parser.add_argument("--lora_name_or_path", type=str, default=None, help="Path to lora model saved by save_pretrained")

    # Training settings
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument("--max_train_steps", type=int, default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.")
    parser.add_argument("--max_eval_steps", type=int, default=None, help="Total number of evaluation steps to perform.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=32, help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--per_device_train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader.")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=16, help="Batch size (per device) for the evaluation dataloader.")
    parser.add_argument("--per_device_eval_down_batch_size", type=int, default=16, help="Batch size (per device) for the evaluation downstream dataloader.")
    parser.add_argument("--eval_frequency", type=int, default=1, help="Evaluate n times every epoch")
    parser.add_argument("--eval_downstream_dataset_dir", type=str, default='YOUR_ROOT_PATH/data/MLLM/Evaluation', help="Path to the evaluation dataset directory.")
    parser.add_argument('--eval_downstream_dataset_name', type=str, default=None, help='Dataset name for evaluation')
    parser.add_argument('--eval_downstream_prompt_settings', type=str, default='zero_shot', help='Prompt settings')
    parser.add_argument("--checkpointing_frequency", type=int, default=2, help="Checkpointing n times every epoch")
    parser.add_argument("--save_total_limit", type=int, default=4, help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="If the training should continue from a checkpoint folder that saved by save_state.")
    parser.add_argument("--gradient_checkpointing", action="store_true", help="Whether to enable gradient checkpointing to save memory at the expense of slower backward pass.")
    parser.add_argument("--with_tracking", action="store_true", help="Whether to enable experiment trackers for logging.")
    parser.add_argument("--report_to", type=str, default="all", help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`, `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.'))

    # Optimizer settings
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Initial learning rate (after the potential warmup period) to use.")
    parser.add_argument("--loss_split", type=str, default="v2", choices=[None, "v1", "v2"], help="Whether to split the loss.")
    parser.add_argument("--loss_scale_visual", type=float, default=1.0, help="The scale factor for visual loss.")
    parser.add_argument("--lr_multi_visual", type=float, default=1.0, help="The multiplier for visual parameters.")
    # parser.add_argument("--lr_multi_vl_vocab", action="store_true", help="Whether to use different learning rate for the visual and textual parts of vocab.")
    parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay to use.")
    # `type=tuple` would turn "0.9,0.95" into a tuple of characters; parse two floats instead.
    # The tuple default is not passed through `type` by argparse, so it stays (0.9, 0.95).
    parser.add_argument("--betas", type=parse_betas, default=(0.9, 0.95), help="Betas for AdamW optimizer, given as two comma-separated floats, e.g. 0.9,0.95.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-18, help="Epsilon for AdamW optimizer.")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm. Will be set to 0.0 if < 0.0.")
    parser.add_argument("--lr_scheduler_type", type=SchedulerType, default="cosine", help="The scheduler type to use.", choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"])
    parser.add_argument("--num_warmup_steps", type=float, default=0.1, help="Number of steps for the warmup in the lr scheduler.")
    parser.add_argument("--min_lr_ratio", type=float, default=0.1, help="The minimum ratio lr to the initial lr in the lr scheduler.")
    parser.add_argument("--custom_lr_scheduler", action="store_true", help="Whether to use custom lr scheduler. It is a linear warmup and cosine decay scheduler, but decay to min_lr.")

    args = parser.parse_args()

    print('Number of available cores:', multiprocessing.cpu_count())
    print('Number of available gpus:', torch.cuda.device_count())

    try:
        print('GPU model name:', torch.cuda.get_device_name(0))
        print('GPU memory size:', torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024, 'GB')
    except Exception:
        # torch raises (e.g. RuntimeError/AssertionError) when no CUDA device is usable;
        # this summary is informational only, so report and continue.
        print('No GPU available.')

    return args


def binary_search_for_fit(numbers, capacity):
    """
    Return the index of the largest element of ``numbers`` that is <= ``capacity``.

    ``numbers`` must be sorted in ascending order. When several equal elements fit,
    the index of the last one is returned. Returns -1 if no element fits
    (including when ``numbers`` is empty).
    """
    # bisect_right returns the insertion point after every entry <= capacity,
    # so the element just before it is the largest one that still fits.
    return bisect.bisect_right(numbers, capacity) - 1


def efficient_greedy_knapsack(numbers, capacity):
    """
    Greedily pack ``numbers`` into bins ("knapsacks") whose sums do not exceed ``capacity``.

    Each bin is filled by repeatedly taking the largest remaining number that still
    fits, found by binary search over a sorted copy of ``numbers``. The caller's list
    is not modified.

    Args:
        numbers: item sizes (e.g. token-sequence lengths).
        capacity: maximum total size of one bin.

    Returns:
        A list of bins, each a list of the item sizes placed in it.

    Raises:
        ValueError: if any item is larger than ``capacity`` (it could never be placed,
            which previously made this loop run forever).
    """
    remaining = sorted(numbers)  # sorted copy: binary search needs order; caller's list stays intact
    if remaining and remaining[-1] > capacity:
        raise ValueError(f"Item size {remaining[-1]} is greater than capacity {capacity}")

    knapsacks = []
    while remaining:
        current_knapsack = []
        remaining_capacity = capacity

        while True:
            # Index of the largest remaining item that fits; -1 when none fits.
            index = bisect.bisect_right(remaining, remaining_capacity) - 1
            if index < 0:
                break
            item = remaining.pop(index)
            current_knapsack.append(item)
            remaining_capacity -= item

        knapsacks.append(current_knapsack)

    return knapsacks


def main():
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
    # in the environment
    accelerator_kwargs = {}
    # Process-group init timeout of 36000 s (10 h). NOTE(review): presumably this long so that ranks do not
    # time out while another process is busy with slow preprocessing — confirm.
    accelerator_kwargs["kwargs_handlers"] = [InitProcessGroupKwargs(timeout=timedelta(seconds=36000))]
    if args.with_tracking:
        accelerator_kwargs["log_with"] = args.report_to
        accelerator_kwargs["project_dir"] = args.output_dir

    # On macOS, force CPU execution without mixed precision.
    if sys.platform in ["darwin"]:
        accelerator_kwargs["cpu"] = True
        accelerator_kwargs["mixed_precision"] = "no"

    # Optionally replace the LLaMA attention implementation via the project's xformers hijack.
    if args.use_xformers:
        from modules.llama_attn_hijack import hijack_llama_attention
        hijack_llama_attention(use_xformers=args.use_xformers)
        print("LLaMA Attention Hijacked!")

    # Monkey-patch the LLaMA forward passes with the project's custom versions (modules.modeling_llama).
    # The condition on the next line is disabled, so the patch is always applied.
    # if args.use_custom_attention_mask or args.expand_vocab == "factorized":
    transformers.models.llama.modeling_llama.LlamaModel.forward = custom_forward
    transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = custom_forward_causalLM

    # The generation-time input preparation is only patched when downstream evaluation is requested.
    if args.eval_downstream_dataset_name:
        transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = prepare_inputs_for_generation

    # Gradient accumulation is delegated to the Accelerator.
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        **accelerator_kwargs
    )

    accelerator.print(f"# Process: {accelerator.num_processes} \nGlobal Index: {accelerator.process_index}\nLocal Index:{accelerator.local_process_index}")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    # Verbose library logging only on the local main process; other processes log errors only.
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name)
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_name, legacy=False, use_fast=not args.use_slow_tokenizer)
    elif args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path, legacy=False, use_fast=not args.use_slow_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    # Fall back to EOS for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Register the image delimiters as extra special tokens, keeping any existing additional special tokens.
    # NOTE(review): presumably these receive ids 32000/32001, matching the --image_start_token_id /
    # --image_end_token_id defaults for a 32000-token LLaMA-2 vocab — confirm.
    tokenizer.add_special_tokens({'additional_special_tokens': ['<image>']}, replace_additional_special_tokens=False)
    tokenizer.add_special_tokens({'additional_special_tokens': ['</image>']}, replace_additional_special_tokens=False)
    accelerator.print(tokenizer)
    # tokenizer.add_tokens([f"<image_{str(i)}>" for i in range(args.vl_vocab_size - args.image_start_token_id - 2)]) # 48386-32000-2

    # Map the accelerator's mixed-precision mode to a torch dtype ("auto" for any other mode).
    # NOTE(review): presumably used later as the dtype for loading the model weights — confirm.
    logger.info(f'mixed_precision: {accelerator.mixed_precision}')
    if accelerator.mixed_precision == 'fp16':
        torch_dtype = torch.float16
    elif accelerator.mixed_precision == 'bf16':
        torch_dtype = torch.bfloat16
    elif accelerator.mixed_precision == 'no':
        torch_dtype = torch.float32
    else:
        torch_dtype = "auto"
    logger.info(f"torch_dtype: {torch_dtype}")

    # Padding value for each feature built by the grouping/padding helpers below.
    # "labels" uses -100, presumably the loss ignore index (PyTorch CrossEntropyLoss default) — confirm
    # against custom_forward_causalLM; "image_index" -1 fills unused image start/end slots.
    key_mapping = {
        "input_ids": tokenizer.eos_token_id,
        "attention_mask": 0,
        "position_ids": 0,
        "labels": -100,
        "image_ids": 0,
        "image_index": -1,
    }

    def pad_to_block_size(inputs, pad_id, padding_side="right", block_size=args.block_size):
        """Pad one token sequence, or a list of sequences, with ``pad_id`` up to ``block_size``.

        Args:
            inputs: a flat list of ints (one sequence) or a list of lists (several
                sequences). An empty list is treated as an empty flat sequence.
            pad_id: value used for padding.
            padding_side: "right" appends the padding; any other value prepends it.
            block_size: target length (defaults to ``args.block_size``).

        Returns:
            For a flat list, a new padded list. For a list of lists, the same outer
            list with every inner list padded in place.

        Raises:
            ValueError: if any sequence is longer than ``block_size``.
        """
        # An empty list has no first element to inspect. Treat it as an empty flat
        # sequence so that e.g. a sample with no image tokens (empty image_starts)
        # is padded instead of crashing with IndexError.
        if len(inputs) == 0 or isinstance(inputs[0], int):
            if len(inputs) > block_size:
                raise ValueError(f"Input length {len(inputs)} is greater than block_size {block_size}")
            padding = [pad_id] * (block_size - len(inputs))
            return inputs + padding if padding_side == "right" else padding + inputs

        longest = max(map(len, inputs))
        if longest > block_size:
            raise ValueError(f"Input length {longest} is greater than block_size {block_size}")
        for example in inputs:
            padding = [pad_id] * (block_size - len(example))
            if padding_side == "right":
                example.extend(padding)
            else:
                example[:0] = padding
        return inputs

    def group_texts_eval(examples):
        """Left-pad tokenized evaluation prompts to ``args.eval_block_size``.

        Left padding keeps every prompt's last token at the end of its row.

        Args:
            examples: batch dict with token lists under ``"input_tokens"`` and an
                ``"input_index"`` column that is passed through unchanged.

        Returns:
            dict with ``input_ids``, ``position_ids``, ``attention_mask`` and
            ``input_index``; plus ``image_ids``/``image_starts``/``image_ends`` when
            ``args.expand_vocab == "factorized"``.
        """
        target_len = args.eval_block_size
        prompts = examples['input_tokens']
        seq_lengths = [len(tokens) for tokens in prompts]

        padded_ids, padded_positions, padded_masks = [], [], []
        for tokens, n in zip(prompts, seq_lengths):
            padded_ids.append(pad_to_block_size(tokens, key_mapping["input_ids"], padding_side="left", block_size=target_len))
            padded_positions.append(pad_to_block_size(list(range(n)), key_mapping["position_ids"], padding_side="left", block_size=target_len))
            padded_masks.append(pad_to_block_size([1] * n, key_mapping['attention_mask'], padding_side="left", block_size=target_len))

        batch = {
            "input_ids": padded_ids,
            "position_ids": padded_positions,
            "attention_mask": padded_masks,
            "input_index": examples["input_index"],
        }

        if args.expand_vocab == "factorized":
            batch["image_ids"], batch["image_starts"], batch["image_ends"] = prepare_factorized_inputs(padded_ids, padding_side="left", block_size=target_len)

        return batch
    
    def find_sub_list(main_list, sub_list, count_limit=1):
        """Return the start index of the ``count_limit``-th occurrence of ``sub_list`` in ``main_list``.

        Occurrences may overlap. Returns -1 when there are fewer than ``count_limit``
        occurrences (always the case when ``count_limit`` < 1).
        """
        width = len(sub_list)
        seen = 0
        for start in range(len(main_list) - width + 1):
            if main_list[start:start + width] != sub_list:
                continue
            seen += 1
            if seen == count_limit:
                return start
        return -1

    def remove_related_part(texts, image_first_list):
        """Truncate each token sequence at its "related part" boundary, keeping the final token.

        The boundary is:
          * for non-"wit" datasets, and for image-first "wit" samples: the second
            occurrence of ``[13, 13]``;
          * for text-first "wit" samples: just after the first
            ``args.image_end_token_id`` that is directly followed by ``[13, 13]``.
        Sequences without a boundary are left unchanged. ``texts`` is modified in
        place and also returned.

        Args:
            texts: list of token-id lists.
            image_first_list: per-sample flags (only read for "wit" datasets).
        """
        # NOTE(review): 13 is presumably the LLaMA tokenizer id of "\n", so [13, 13]
        # would be a blank line — confirm against the tokenizer.
        blank_line = [13, 13]
        # Previously hard-coded as 32001; use the configured id so --image_end_token_id is honoured.
        image_end_then_blank = [args.image_end_token_id] + blank_line
        is_wit = "wit" in args.pretrain_dataset_name

        for i, tokens in enumerate(texts):
            if is_wit and not image_first_list[i]:
                bound_index = find_sub_list(tokens, image_end_then_blank, 1)
                if bound_index != -1:
                    bound_index += 1  # keep the image-end token itself
            else:
                bound_index = find_sub_list(tokens, blank_line, 2)
            if bound_index != -1:
                texts[i] = tokens[:bound_index] + [tokens[-1]]
        return texts

    def prepare_factorized_inputs(input_ids, padding_side="right", block_size=args.block_size):
        """Derive the factorized-vocabulary image inputs from already padded ``input_ids``.

        For every id sequence this collects:
          * the ids greater than ``args.image_end_token_id``, shifted down by
            ``args.image_end_token_id + 1`` (so the smallest such id maps to 0), padded
            to ``block_size`` with ``key_mapping["image_ids"]``;
          * the index right after each ``args.image_start_token_id``; and
          * the index of each ``args.image_end_token_id``.
        The two index lists are padded to ``args.max_image_num`` with ``key_mapping["image_index"]``.

        Returns:
            ``(image_ids, image_starts, image_ends)``, each with one entry per sequence.
        """
        first_visual_id = args.image_end_token_id + 1
        raw_ids, raw_starts, raw_ends = [], [], []
        for sequence in input_ids:
            arr = np.array(sequence)
            is_visual = arr > args.image_end_token_id
            raw_ids.append((arr[is_visual] - first_visual_id).tolist())
            start_positions = np.nonzero(arr == args.image_start_token_id)[0]
            raw_starts.append((start_positions + 1).tolist())
            end_positions = np.nonzero(arr == args.image_end_token_id)[0]
            raw_ends.append(end_positions.tolist())

        def _pad_rows(rows, pad_value, length):
            # Pad every row to ``length`` on the requested side.
            return [pad_to_block_size(row, pad_value, padding_side=padding_side, block_size=length) for row in rows]

        image_ids = _pad_rows(raw_ids, key_mapping["image_ids"], block_size)
        image_starts = _pad_rows(raw_starts, key_mapping["image_index"], args.max_image_num)
        image_ends = _pad_rows(raw_ends, key_mapping["image_index"], args.max_image_num)
        return image_ids, image_starts, image_ends

    def group_texts_compress(examples):
        """Pack several tokenized samples into each ``args.block_size`` row (right padding).

        Sample lengths are grouped with ``efficient_greedy_knapsack`` so that the samples
        of one group fit together in a single block. No 1-D ``attention_mask`` is emitted,
        because a flat mask cannot stop packed samples from attending to each other.
        Instead, ``lengths`` lists the per-sample lengths of every row so that a
        block-causal mask can be built from it (presumably via
        ``make_custom_attention_mask`` — confirm in the model/collator code).

        Args:
            examples: batch dict holding token-id lists under ``"input_tokens"`` or
                ``"combined_desc"``. ``"image_first"`` is also read when
                ``--remove_related_part``, ``--unidirectional_loss`` or ``--uni_and_bi`` is set.

        Returns:
            dict with ``input_ids``, ``position_ids`` (restarting at 0 for each packed
            sample) and ``labels``, each padded to ``args.block_size``, plus ``lengths``.
            ``image_ids``/``image_starts``/``image_ends`` are added when
            ``args.expand_vocab == "factorized"``.

        Raises:
            ValueError: if no known text key is present in ``examples``.
        """
        if 'input_tokens' in examples:
            text_key = 'input_tokens'
        elif 'combined_desc' in examples:
            text_key = 'combined_desc'
        else:
            raise ValueError("No valid text key in examples")

        if args.remove_related_part:
            examples[text_key] = remove_related_part(examples[text_key], examples["image_first"])
        samples = examples[text_key]

        def _make_labels(tokens, idx):
            # Label ids for sample ``idx``: ``tokens`` itself, or a copy with one modality
            # replaced by key_mapping["labels"] according to the loss flags.
            if not (args.unidirectional_loss or args.uni_and_bi):
                return tokens
            image_first = examples["image_first"][idx]
            if not args.unidirectional_loss:
                # --uni_and_bi: apply the unidirectional mask only with the per-direction probability.
                uni_prob = args.uni_image_prob if image_first else args.uni_text_prob
                apply_mask = random.random() < uni_prob
                if not apply_mask:
                    return tokens
            label = np.array(tokens)
            if image_first:
                # Image-first: ignore ids >= image_start -> image-to-text loss only.
                label[label >= args.image_start_token_id] = key_mapping["labels"]
            else:
                # Text-first: ignore text ids (below image_start, above eos) -> text-to-image loss only.
                label[(label < args.image_start_token_id) & (label > tokenizer.eos_token_id)] = key_mapping["labels"]
            return label.tolist()

        # The knapsacks only contain lengths, so bucket sample indices by length to map
        # each packed length back to a not-yet-used sample of that length.
        length2examples_idx = defaultdict(list)
        for idx, sample in enumerate(samples):
            length2examples_idx[len(sample)].append(idx)

        knapsacks = efficient_greedy_knapsack([len(x) for x in samples], args.block_size)

        input_ids, position_ids, labels, lengths = [], [], [], []
        for knapsack in knapsacks:
            row_tokens, row_positions, row_labels, row_lengths = [], [], [], []
            for length in knapsack:
                idx = length2examples_idx[length].pop()
                tokens = samples[idx]
                row_tokens.extend(tokens)
                row_labels.extend(_make_labels(tokens, idx))
                # Positions restart at 0 for every packed sample.
                row_positions.extend(range(length))
                row_lengths.append(length)

            input_ids.append(pad_to_block_size(row_tokens, key_mapping["input_ids"], block_size=args.block_size))
            position_ids.append(pad_to_block_size(row_positions, key_mapping["position_ids"], block_size=args.block_size))
            labels.append(pad_to_block_size(row_labels, key_mapping["labels"], block_size=args.block_size))
            # Per-sample lengths of this row, for building the block-causal attention mask.
            lengths.append(row_lengths)

        grouped = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "labels": labels,
            "lengths": lengths,
        }

        if args.expand_vocab == "factorized":
            grouped["image_ids"], grouped["image_starts"], grouped["image_ends"] = prepare_factorized_inputs(input_ids, padding_side="right", block_size=args.block_size)

        return grouped

    def group_texts(examples):
        """Pad every tokenized sample to its own ``args.block_size`` row (right padding, no packing).

        Args:
            examples: batch dict holding token-id lists under ``"input_tokens"`` or
                ``"combined_desc"``. ``"image_first"`` is also read when
                ``--remove_related_part``, ``--unidirectional_loss`` or ``--uni_and_bi`` is set.

        Returns:
            dict with ``input_ids``, ``position_ids``, ``labels`` and ``attention_mask``
            (all padded to ``args.block_size``) and ``lengths`` (a one-element list per
            row). ``image_ids``/``image_starts``/``image_ends`` are added when
            ``args.expand_vocab == "factorized"``.

        Raises:
            ValueError: if no known text key is present in ``examples``, or (from
                ``pad_to_block_size``) a sample is longer than ``args.block_size``.
        """
        if 'input_tokens' in examples:
            text_key = 'input_tokens'
        elif 'combined_desc' in examples:
            text_key = 'combined_desc'
        else:
            raise ValueError("No valid text key in examples")

        if args.remove_related_part:
            examples[text_key] = remove_related_part(examples[text_key], examples["image_first"])
        samples = examples[text_key]

        def _make_labels(tokens, idx):
            # Label ids for sample ``idx``: ``tokens`` itself, or a copy with one modality
            # replaced by key_mapping["labels"] according to the loss flags.
            if not (args.unidirectional_loss or args.uni_and_bi):
                return tokens
            image_first = examples["image_first"][idx]
            if not args.unidirectional_loss:
                # --uni_and_bi: apply the unidirectional mask only with the per-direction probability.
                uni_prob = args.uni_image_prob if image_first else args.uni_text_prob
                apply_mask = random.random() < uni_prob
                if not apply_mask:
                    return tokens
            label = np.array(tokens)
            if image_first:
                # Image-first: ignore ids >= image_start -> image-to-text loss only.
                label[label >= args.image_start_token_id] = key_mapping["labels"]
            else:
                # Text-first: ignore text ids (below image_start, above eos) -> text-to-image loss only.
                label[(label < args.image_start_token_id) & (label > tokenizer.eos_token_id)] = key_mapping["labels"]
            return label.tolist()

        input_ids = [pad_to_block_size(x, key_mapping["input_ids"], block_size=args.block_size) for x in samples]
        # Labels are built in sample order so --uni_and_bi consumes random numbers in the same order as before.
        labels = [
            pad_to_block_size(_make_labels(tokens, idx), key_mapping["labels"], block_size=args.block_size)
            for idx, tokens in enumerate(samples)
        ]

        seq_lengths = [len(x) for x in samples]
        position_ids = [pad_to_block_size(list(range(n)), key_mapping["position_ids"], block_size=args.block_size) for n in seq_lengths]
        attention_mask = [pad_to_block_size([1] * n, key_mapping['attention_mask'], block_size=args.block_size) for n in seq_lengths]

        grouped = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "labels": labels,
            # vanilla 1-D causal attention mask
            "attention_mask": attention_mask,
            # one length per row, so a block-causal mask can also be built
            "lengths": [[n] for n in seq_lengths],
        }

        if args.expand_vocab == "factorized":
            grouped["image_ids"], grouped["image_starts"], grouped["image_ends"] = prepare_factorized_inputs(input_ids, padding_side="right", block_size=args.block_size)

        return grouped

    def build_gen_config():
        """Return the ``GenerationConfig`` used for deterministic beam-search decoding.

        Decoding settings:
          * ``do_sample=False`` with ``num_beams=5``.
          * Between 2 (``min_new_tokens``) and 10 (``max_new_tokens``) new tokens.
          * ``length_penalty=1.0`` and ``early_stopping=False``.
          * bos/eos/pad ids copied from ``tokenizer``.
          * Every id in ``[args.image_start_token_id, args.vl_vocab_size)`` is
            passed as ``suppress_tokens``, so those ids are never generated.
            Presumably this is the visual-token range appended to the text
            vocabulary; confirm against the tokenizer setup.
        """
        # The LLaMA-2 sampling knobs (top_k=50, top_p=0.9, temperature=0.6) are
        # deliberately not set: generation here is pure beam search.
        suppressed_ids = list(range(args.image_start_token_id, args.vl_vocab_size))
        return GenerationConfig(
            max_new_tokens=10,
            min_new_tokens=2,
            do_sample=False,
            num_beams=5,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            length_penalty=1.0,
            early_stopping=False,
            suppress_tokens=suppressed_ids,
        )

    # Note that with `batched=True`, each `.map` call below processes `args.process_batch_size` examples together, so a
    # grouping function that drops a per-batch remainder does so once for every batch of that size. You can adjust
    # `args.process_batch_size`, but a higher value might be slower to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    with accelerator.main_process_first():
        if 'wit' in args.pretrain_dataset_name:
            data_cache_name = 'cache' if args.expand_vocab != "factorized" else f"cache-{args.expand_vocab}"
            
            if args.concrete_type == 'all':
                dataset_name = 'WITv4'
            else:
                dataset_name = f"WITv4-{args.concrete_type}"
            
            if args.remove_all_only_ic: 
                dataset_name = f"{dataset_name}-IC"
            elif args.remove_des:
                dataset_name = f"{dataset_name}-nodes"
            
            if args.remove_related_part:
                data_cache_name = f"{data_cache_name}-nor"
            if args.unidirectional_loss:
                data_cache_name = f"{data_cache_name}-uni"
            if args.uni_and_bi:
                data_cache_name = f"{data_cache_name}-uni-bi-{args.uni_image_prob}-{args.uni_text_prob}"

            dataset_cache_path = os.path.join(args.dataset_dir, f"{dataset_name}-{data_cache_name}", f"turn_{args.turn_index}")
            
            if os.path.exists(dataset_cache_path) and not args.overwrite_cache:
                wit_datasets = load_from_disk(dataset_cache_path)
                logger.info(f"Load cache from {dataset_cache_path}")
            else:
                wit_datasets = load_from_disk(os.path.join(args.dataset_dir, dataset_name, f"turn_{args.turn_index}"))
                column_names = wit_datasets['train'].column_names
                if not args.remove_all_only_ic:
                    wit_datasets['dev'] = wit_datasets['dev'].remove_columns('used_page_split')
                logger.info(wit_datasets)

                if args.remove_all_only_ic or args.compress_batch: # shuffle dataset
                    wit_datasets['train'] = wit_datasets['train'].shuffle(seed=args.seed)
                    wit_datasets['dev'] = wit_datasets['dev'].shuffle(seed=args.seed)

                wit_datasets.cleanup_cache_files()
                
                wit_datasets = wit_datasets.map(
                    group_texts_compress if args.compress_batch else group_texts,
                    batched=True,
                    batch_size=args.process_batch_size,
                    num_proc=args.process_num_workers,
                    remove_columns=column_names,
                    # keep_in_memory=True,
                    desc="Transform datasets to pre-train data",
                )

                wit_datasets.save_to_disk(dataset_cache_path, max_shard_size="20GB")
                logger.info(wit_datasets)
                logger.info(f"Save cache to {dataset_cache_path}")
                exit(-1)
            
            accelerator.print(f"wit_datasets: {wit_datasets}")
            accelerator.print(f"column_names: {wit_datasets['train'].column_names}")

        if 'oi' in args.pretrain_dataset_name: 
            data_cache_name = 'cache' if args.expand_vocab != "factorized" else f"cache-{args.expand_vocab}"
            # oi_datasets
            # oi has four parts, we will combined them here, use dataset.select to implement the different ratios
            if args.remove_all_only_ic: 
                dataset_name = "OIv4-IC"
            elif args.remove_des:
                dataset_name = "OIv4-nodes"
                if args.only_obj:
                    dataset_name = "OIv4-nodes-only_obj_label"
            elif args.only_obj:
                dataset_name = "OIv4-only_obj_label"
            else:
                dataset_name = "OIv4"
            
            if args.remove_related_part:
                data_cache_name = f"{data_cache_name}-nor"
            if args.unidirectional_loss:
                data_cache_name = f"{data_cache_name}-uni"
            if args.uni_and_bi:
                data_cache_name = f"{data_cache_name}-uni-bi-{args.uni_image_prob}-{args.uni_text_prob}"
            
            dataset_cache_path = os.path.join(args.dataset_dir, f"{dataset_name}-{data_cache_name}", f"turn_{args.turn_index}")
            print(dataset_cache_path, os.path.exists(dataset_cache_path), args.overwrite_cache)
            if os.path.exists(dataset_cache_path) and not args.overwrite_cache:
                oi_datasets = load_from_disk(dataset_cache_path)
                logger.info(f"Load cache from {dataset_cache_path}")
            else:
                oi_parts = ['OpenImage', 'Object365', 'V3DET', 'Visual_Genome']
                oi_datasets = DatasetDict()
                for part in oi_parts:
                    if part in ['V3DET', 'Visual_Genome']:
                        if args.turn_index == 0:
                            turn_index_list = list(range(3))
                        else:
                            turn_index_list = list(range(3, 6))
                    else:
                        turn_index_list = [args.turn_index]
                    
                    for turn_index in turn_index_list:
                        oi_datasets_part = load_from_disk(os.path.join(args.dataset_dir, f"{dataset_name}", f"{part}", f"turn_{turn_index}"))
                        column_names = oi_datasets_part['train'].column_names
                        logger.info(oi_datasets_part['train'])
                        if part == 'OpenImage':
                            oi_datasets['train'] = oi_datasets_part['train']
                            oi_datasets['dev'] = oi_datasets_part['dev']
                        else:
                            # TODO: use dataset.select to implement the different ratios
                            oi_datasets['train'] = concatenate_datasets([oi_datasets['train'], oi_datasets_part['train']])
                            oi_datasets['dev'] = concatenate_datasets([oi_datasets['dev'], oi_datasets_part['dev']])
                        logger.info(oi_datasets['train'])
                
                if args.remove_all_only_ic or args.compress_batch: # shuffle dataset
                    oi_datasets['train'] = oi_datasets['train'].shuffle(seed=args.seed)
                    oi_datasets['dev'] = oi_datasets['dev'].shuffle(seed=args.seed)
                
                oi_datasets.cleanup_cache_files()

                oi_datasets = oi_datasets.map(
                    group_texts_compress if args.compress_batch else group_texts,
                    batched=True,
                    batch_size=args.process_batch_size,
                    num_proc=args.process_num_workers,
                    remove_columns=column_names,
                    # keep_in_memory=True,
                    desc="Transform datasets to pre-train data",
                )

                oi_datasets.save_to_disk(dataset_cache_path, max_shard_size="20GB")
                logger.info(oi_datasets)
                logger.info(f"Save cache to {dataset_cache_path}")
                exit(-1)

            accelerator.print(f"oi_datasets: {oi_datasets}")
            accelerator.print(f"column_names: {oi_datasets['train'].column_names}")

        if 'ic' in args.pretrain_dataset_name:
            # ic_datasets need to shuffle here, since we will patch multiple image-caption pairs into a sample
            data_cache_name = 'cache' if args.expand_vocab != "factorized" else f"cache-{args.expand_vocab}"
            dataset_name = "IC"
            # if args.ic_jd:
            #     dataset_name = f"{dataset_name}-with_JD"
            
            if args.unidirectional_loss:
                data_cache_name = f"{data_cache_name}-uni"
            if args.uni_and_bi:
                data_cache_name = f"{data_cache_name}-uni-bi-{args.uni_image_prob}-{args.uni_text_prob}"
            
            if args.ic_num != -1:
                data_cache_name = f"{data_cache_name}-{args.ic_num}"
            if args.ic_long:
                dataset_cache_path = os.path.join(args.dataset_dir, f"{dataset_name}-{data_cache_name}", f"long_turn_{args.turn_index}")
            else:
                dataset_cache_path = os.path.join(args.dataset_dir, f"{dataset_name}-{data_cache_name}", f"turn_{args.turn_index}_{args.ic_denoise_prob}")
            print(dataset_cache_path, os.path.exists(dataset_cache_path), args.overwrite_cache)
            if os.path.exists(dataset_cache_path) and not args.overwrite_cache:
                ic_datasets = load_from_disk(dataset_cache_path)
                logger.info(f"Load cache from {dataset_cache_path}")
            else:
                ic_parts = ['Merged_new', 'laion-coco-aesthetic']
                if args.ic_jd:
                    ic_parts.append('JourneyDB')
                ic_datasets = DatasetDict()
                for part in ic_parts:
                    turn_index_list = [args.turn_index]
                    
                    for turn_index in turn_index_list:
                        if args.ic_long:
                            ic_datasets_part = load_from_disk(os.path.join(args.dataset_dir, f"{dataset_name}", f"{part}", f"long_turn_{turn_index}"))
                        else:
                            ic_datasets_part = load_from_disk(os.path.join(args.dataset_dir, f"{dataset_name}", f"{part}", f"turn_{turn_index}_{args.ic_denoise_prob}"))
                        column_names = ic_datasets_part['train'].column_names
                        logger.info(ic_datasets_part['train'])
                        if part == 'Merged_new':
                            ic_datasets['train'] = ic_datasets_part['train']
                            ic_datasets['dev'] = ic_datasets_part['dev']
                        else:
                            # TODO: use dataset.select to implement the different ratios
                            ic_datasets['train'] = concatenate_datasets([ic_datasets['train'], ic_datasets_part['train']])
                            ic_datasets['dev'] = concatenate_datasets([ic_datasets['dev'], ic_datasets_part['dev']])
                        logger.info(ic_datasets['train'])
                
                if args.remove_all_only_ic or args.compress_batch: # shuffle dataset
                    ic_datasets['train'] = ic_datasets['train'].shuffle(seed=args.seed)
                    ic_datasets['dev'] = ic_datasets['dev'].shuffle(seed=args.seed)
                    if args.ic_num != -1:
                        ic_datasets['train'] = ic_datasets['train'].select(range(args.ic_num))
                
                ic_datasets.cleanup_cache_files()

                ic_datasets = ic_datasets.map(
                    group_texts_compress if args.compress_batch else group_texts,
                    batched=True,
                    batch_size=args.process_batch_size,
                    num_proc=args.process_num_workers,
                    remove_columns=column_names,
                    # keep_in_memory=True,
                    desc="Transform datasets to pre-train data",
                )

                ic_datasets.save_to_disk(dataset_cache_path, max_shard_size="20GB")
                logger.info(ic_datasets)
                logger.info(f"Save cache to {dataset_cache_path}")
                exit(-1)

            accelerator.print(f"ic_datasets: {ic_datasets}")
            accelerator.print(f"column_names: {ic_datasets['train'].column_names}")

        if 'mmc4' in args.pretrain_dataset_name:
            data_cache_name = 'cache' if args.expand_vocab != "factorized" else f"cache-{args.expand_vocab}"
            dataset_name = 'MMC4'
            
            if 'pairs' in args.pretrain_dataset_name:
                data_cache_name = f"{data_cache_name}-pairs"

            if args.unidirectional_loss:
                data_cache_name = f"{data_cache_name}-uni"
            if args.uni_and_bi:
                data_cache_name = f"{data_cache_name}-uni-bi-{args.uni_image_prob}-{args.uni_text_prob}"

            dataset_cache_path = os.path.join(args.dataset_dir, 'Interleaved', f"{dataset_name}-{data_cache_name}")
            
            if os.path.exists(dataset_cache_path) and not args.overwrite_cache:
                mmc4_datasets = load_from_disk(dataset_cache_path)
                logger.info(f"Load cache from {dataset_cache_path}")
            else:
                if 'pairs' in args.pretrain_dataset_name:
                    mmc4_datasets = load_from_disk(os.path.join(args.dataset_dir, 'Interleaved', dataset_name, "datasets_pairs"))
                else:
                    mmc4_datasets = load_from_disk(os.path.join(args.dataset_dir, 'Interleaved', dataset_name, "datasets"))
                column_names = mmc4_datasets['train'].column_names

                logger.info(mmc4_datasets)

                if args.remove_all_only_ic or args.compress_batch: # shuffle dataset
                    mmc4_datasets['train'] = mmc4_datasets['train'].shuffle(seed=args.seed)
                    # mmc4_datasets['dev'] = mmc4_datasets['dev'].shuffle(seed=args.seed)
                    mmc4_datasets['dev'] = mmc4_datasets['train'].select(range(100000)).shuffle(seed=args.seed)

                mmc4_datasets.cleanup_cache_files()
                
                mmc4_datasets = mmc4_datasets.map(
                    group_texts_compress if args.compress_batch else group_texts,
                    batched=True,
                    batch_size=args.process_batch_size,
                    # num_proc=args.process_num_workers,
                    remove_columns=column_names,
                    # keep_in_memory=True,
                    desc="Transform datasets to pre-train data",
                )

                mmc4_datasets.save_to_disk(dataset_cache_path, max_shard_size="20GB")
                logger.info(mmc4_datasets)
                logger.info(f"Save cache to {dataset_cache_path}")
                exit(-1)
            
            accelerator.print(f"mmc4_datasets: {mmc4_datasets}")
            accelerator.print(f"column_names: {mmc4_datasets['train'].column_names}")

        
        if '+' in args.pretrain_dataset_name:
            combined_datasets = DatasetDict()
            if 'oi+wit' in args.pretrain_dataset_name:
                combined_datasets["train"] = concatenate_datasets([oi_datasets["train"], wit_datasets["train"]])
                combined_datasets["dev"] = concatenate_datasets([oi_datasets["dev"], wit_datasets["dev"]])
                accelerator.print("use both oi and wit datasets")
            else:
                combined_datasets["train"] = concatenate_datasets([oi_datasets["train"]] * args.oi_time)
                combined_datasets["dev"] = concatenate_datasets([oi_datasets["dev"]])
                accelerator.print(f"use oi datasets for {args.oi_time} times")
            if 'ic' in args.pretrain_dataset_name:
                combined_datasets["train"] = concatenate_datasets([combined_datasets["train"], ic_datasets["train"]])
                combined_datasets["dev"] = concatenate_datasets([combined_datasets["dev"], ic_datasets["dev"]])
                accelerator.print("use ic datasets")
        else:
            if 'wit' in args.pretrain_dataset_name:
                combined_datasets = wit_datasets
            elif 'oi' in args.pretrain_dataset_name:
                combined_datasets = oi_datasets
            elif 'ic' in args.pretrain_dataset_name:
                combined_datasets = ic_datasets
            elif 'mmc4' in args.pretrain_dataset_name:
                combined_datasets = mmc4_datasets
            accelerator.print(f"use only {args.pretrain_dataset_name} datasets")

        accelerator.print(f"combined_datasets: {combined_datasets}")

        if args.eval_downstream_dataset_name:
            data_cache_name = 'cache' if args.expand_vocab != "factorized" else f"cache-{args.expand_vocab}"
            args.eval_downstream_prompt_settings = args.eval_downstream_prompt_settings.split(',')
            ed_dataset_cache_path = os.path.join(args.eval_downstream_dataset_dir, args.eval_downstream_dataset_name, f"eval-{data_cache_name}")
            if os.path.exists(ed_dataset_cache_path) and not args.overwrite_cache:
                ed_dataset = load_from_disk(ed_dataset_cache_path)
                logger.info(f"Load {args.eval_downstream_dataset_name} from cache")
            else:
                tokenized_datasets = load_from_disk(os.path.join(args.eval_downstream_dataset_dir, args.eval_downstream_dataset_name, 'eval'))
                for prompt_setting in list(tokenized_datasets.keys()):
                    if prompt_setting not in set(args.eval_downstream_prompt_settings):
                        tokenized_datasets.pop(prompt_setting)
                column_names = tokenized_datasets['zero_shot'].column_names
                logger.info(tokenized_datasets)
                
                ed_dataset = tokenized_datasets.map(
                    group_texts_eval,
                    batched=True,
                    batch_size=args.process_batch_size,
                    num_proc=args.process_num_workers,
                    remove_columns=column_names,
                    # keep_in_memory=True,
                    desc="Transform datasets to evaluation data",
                )
                
                ed_dataset.save_to_disk(ed_dataset_cache_path, max_shard_size="20GB")
                logger.info(ed_dataset)
                logger.info(f"Save cache to {ed_dataset_cache_path}")
            accelerator.print(f"eval_downstream_dataset: {ed_dataset}")
            logger.info(f"column_names: {ed_dataset['zero_shot'].column_names}")
        
    train_dataset = combined_datasets["train"]

    # drop last incomplete batch in train_dataset
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    if train_dataset.num_rows % total_batch_size != 0:
        train_dataset = train_dataset.select(range(train_dataset.num_rows - train_dataset.num_rows % total_batch_size))

    if 'validation' in combined_datasets:
        eval_dataset = combined_datasets["validation"]
    elif 'dev' in combined_datasets:
        eval_dataset = combined_datasets["dev"]
    else:
        eval_dataset = combined_datasets["test"]
    
    if args.eval_downstream_dataset_name:
        eval_downstream_dataset = ed_dataset['zero_shot']

    num_attention_heads = config.num_attention_heads

    def collate_fn(batch: List[Dict]):
        """Collate pre-processed rows into a dict of tensors.

        This collator is passed to both the train and the eval ``DataLoader``.

        Args:
            batch: list of row dicts as produced by the dataset ``.map`` step.
                Every row is assumed to share the keys of ``batch[0]``.

        Returns:
            A dict holding one ``torch.long`` tensor per row field, excluding
            ``lengths`` and ``attention_mask``. An ``attention_mask`` entry is
            then added according to configuration:

            * ``args.compress_batch`` and ``args.use_custom_attention_mask``:
              the mask comes from ``make_custom_attention_mask``, fed with each
              row's ``lengths``. It uses right padding and ``args.block_size``,
              and is presumably a block-causal mask over packed samples.
            * otherwise, if ``args.use_flash_attention_2``: the rows' own
              ``attention_mask`` lists, stacked as ``torch.long``.
            * otherwise: no ``attention_mask`` key is added.
        """
        excluded = ("lengths", "attention_mask")
        new_batch = {
            key: torch.tensor([row[key] for row in batch], dtype=torch.long)
            for key in batch[0]
            if key not in excluded
        }

        if args.compress_batch and args.use_custom_attention_mask:
            new_batch["attention_mask"] = make_custom_attention_mask(
                [row["lengths"] for row in batch],
                block_size=args.block_size,
                padding_side="right",
                use_xformers=args.use_xformers,
                device=accelerator.device,
                torch_dtype=torch_dtype,
                num_attention_heads=num_attention_heads,
            )
        elif args.use_flash_attention_2:
            new_batch["attention_mask"] = torch.tensor(
                [row["attention_mask"] for row in batch], dtype=torch.long
            )

        return new_batch
    
    def collate_fn_eval(batch: List[Dict]):
        """Collate evaluation rows into a dict of ``torch.long`` tensors.

        Every field of ``batch[0]`` except ``lengths`` is stacked across rows
        under the same key. ``attention_mask`` is always taken from the rows
        themselves and placed last, so every row must carry one.

        Args:
            batch: list of row dicts sharing the keys of ``batch[0]``.

        Returns:
            Dict mapping field name to a stacked ``torch.long`` tensor.
        """
        collated = {}
        for field in batch[0]:
            # `lengths` is not needed here; `attention_mask` is appended below.
            if field in ("lengths", "attention_mask"):
                continue
            collated[field] = torch.tensor([row[field] for row in batch], dtype=torch.long)

        # Use the row-provided padding mask. A custom mask built with
        # make_custom_attention_mask (padding_side="left", args.eval_block_size)
        # was used here previously and is currently disabled.
        masks = [row["attention_mask"] for row in batch]
        collated["attention_mask"] = torch.tensor(masks, dtype=torch.long)
        return collated

    # DataLoaders creation:
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size
    )
    eval_dataloader = DataLoader(
        eval_dataset, collate_fn=collate_fn, batch_size=args.per_device_eval_batch_size
    )

    if args.eval_downstream_dataset_name:
        eval_downstream_dataloader = DataLoader(
            eval_downstream_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_down_batch_size
        )

    # # Log a few random samples from the training set:
    # for index in random.sample(range(train_dataset.num_rows), 3):
    #     logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
    # logger.info(f"Sample {train_dataset.num_rows} of the training set: {train_dataset[-1]}.")

    if args.model_name_or_path:
        logger.info(f"Load model from: {args.model_name_or_path}")
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
            low_cpu_mem_usage=args.low_cpu_mem_usage,
            device_map={"": accelerator.device},
            torch_dtype=torch_dtype,
            use_flash_attention_2=args.use_flash_attention_2,
            # attn_implementation="eager",
            # use_flash_attention_2=accelerator.device.type == "cuda" and not args.use_custom_attention_mask,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForCausalLM.from_config(
            config,
        )

    # Need to do this for gpt2, because it doesn't have an official pad token.
    if model.config.pad_token_id is None:
        # tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.eos_token_id

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    input_embedding_size = model.get_input_embeddings().weight.shape
    # expand the language vocab to vision-language vocab
    model.resize_token_embeddings(args.vl_vocab_size)
    logger.info(f"Resized vocab from {input_embedding_size[0]} to {args.vl_vocab_size}")

    if args.expand_vocab in ["normal", "factorized"]:
        init_embeddings_normal(model, args.image_start_token_id)
        if args.expand_vocab == "factorized":
            visual_codebook_weights = torch.load(os.path.join(args.visual_codebook, 'visual_tokenizer', 'tokenizer_encoder.bin'))['quantize.embedding.weight']
            model.model.visual_codebook = CodebookEmbedding(num_tokens=visual_codebook_weights.shape[0], codebook_dim=visual_codebook_weights.shape[1])
            model.model.visual_codebook.load_state_dict(OrderedDict({'weight' : visual_codebook_weights}))
            if args.factorized_linear_mlp:
                model.model.visual_factorized_linear = nn.Sequential(
                    nn.Linear(visual_codebook_weights.shape[1], input_embedding_size[1], bias=False),
                    nn.SiLU(),
                    nn.Linear(input_embedding_size[1], input_embedding_size[1], bias=False)
                )
                
            else:
                model.model.visual_factorized_linear = nn.Linear(visual_codebook_weights.shape[1], input_embedding_size[1], bias=False)
            model.model._init_weights(model.model.visual_factorized_linear)

    if args.unfreeze == "all": 
        logger.info("Unfreeze all parameters")
        for p in model.parameters():
            p.requires_grad = True
    elif args.unfreeze == "none":
        logger.info("Freeze all parameters")
        for p in model.parameters():
            p.requires_grad = False
    else:
        unfreeze_names = args.unfreeze.split(',')
        if args.expand_vocab == "factorized":
            unfreeze_names.extend(["visual_codebook", "visual_factorized_linear"])
        logger.info(f"Unfreeze {unfreeze_names} parameters")
        for name, param in model.named_parameters():
            if any([unfreeze_name in name for unfreeze_name in unfreeze_names]):
                param.requires_grad = True
            else:
                param.requires_grad = False
    
    # # cast the small parameters (e.g. layernorm) to fp32 for stability
    # # but LLaMA doesn't have bias
    # for p in model.parameters():
    #     if p.ndim == 1:
    #         p.data = p.data.to(torch.float32)
    
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    if args.use_lora:
        logger.info(f"Use LoRA: {args.use_lora}")
        if args.lora_name_or_path is None:
            lora_target_modules = args.lora_target_modules.split(',')
            lora_modules_to_save = args.lora_modules_to_save.split(',')
            if args.expand_vocab == "factorized":
                lora_modules_to_save.extend(["visual_codebook", "visual_factorized_linear"])
            logger.info(f"LoRA target modules: {lora_target_modules}")
            logger.info(f"LoRA modules to save: {lora_modules_to_save}")
            
            peft_config = LoraConfig(
                r=args.lora_rank,
                lora_alpha=args.lora_alpha,
                lora_dropout=args.lora_dropout,
                bias=args.lora_bias,
                fan_in_fan_out=args.model_name_or_path == "gpt2",
                task_type=TaskType.CAUSAL_LM,
                target_modules=lora_target_modules,
                modules_to_save=lora_modules_to_save,
            )
            model = get_peft_model(model, peft_config)
        else:
            model = PeftModelForCausalLM.from_pretrained(
                model,
                model_id=args.lora_name_or_path,
                is_trainable=True,
            )

        logger.info(model.base_model_torch_dtype)
        # model.print_trainable_parameters()
    
    accelerator.print(model.generation_config)
    accelerator.print(model)
    get_nb_trainable_parameters(model, logger, torch_dtype)
    
    # print all parameter and their grad
    accelerator.print("#########Trainable Parameters#########")
    for name, param in model.named_parameters():
        if param.requires_grad:
            accelerator.print(name, param.shape, param.requires_grad)
    accelerator.print("#########Frozen Parameters#########")
    for name, param in model.named_parameters():
        if not param.requires_grad:
            accelerator.print(name, param.shape, param.requires_grad)

    # Handle the repository creation
    if args.output_dir is None:
        logger.warning(
            "There is no `args.output_dir` specified! Model checkpoints will not be saved."
        )
        exit()
    
    if accelerator.is_main_process and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    # Creates Dummy Optimizer if `optimizer` was specified in the config file else creates Adam Optimizer
    if (
            getattr(accelerator.state, "deepspeed_plugin", None) is None
            or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config or args.custom_lr_scheduler
    ):
        optimizer_cls = torch.optim.AdamW
        logger.info("Use torch.optim.AdamW")
    else:
        from accelerate.utils import DummyOptim
        optimizer_cls = DummyOptim
        logger.info("Use accelerate.utils.DummyOptim")

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "norm.weight"]
    
    if args.lr_multi_visual > 1.0:
        # when use lora, maybe also multi for embed_tokens and lm_head?
        lr_multi_visual_names = ["embed_tokens", "lm_head"]
        if args.expand_vocab == "factorized":
            lr_multi_visual_names.extend(["visual_codebook", "visual_factorized_linear"])
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and not any(mvn in n for mvn in lr_multi_visual_names) and p.requires_grad],
                "weight_decay": args.weight_decay,
                "lr": args.learning_rate,
            },
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and any(mvn in n for mvn in lr_multi_visual_names) and p.requires_grad],
                "weight_decay": args.weight_decay,
                "lr": args.learning_rate * args.lr_multi_visual,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": 0.0,
                "lr": args.learning_rate,
            },
        ]
    else:
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": 0.0,
            },
        ]

    optimizer = optimizer_cls(
        optimizer_grouped_parameters,
        lr=args.learning_rate,
        betas=args.betas,
        eps=args.adam_epsilon,
        weight_decay=args.weight_decay,
    )

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps / accelerator.num_processes)
    assert num_update_steps_per_epoch * total_batch_size == len(train_dataset), f"{num_update_steps_per_epoch} * {total_batch_size} != {len(train_dataset)}"
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    if args.num_warmup_steps < 1:
        num_warmup_steps = int(args.num_warmup_steps * args.max_train_steps)
    else:
        num_warmup_steps = int(args.num_warmup_steps)

    # Creates Dummy Scheduler if `scheduler` was specified in the config file else creates `args.lr_scheduler_type` Scheduler
    if (
            getattr(accelerator.state, "deepspeed_plugin", None) is None
            or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config or args.custom_lr_scheduler
    ):
        if args.custom_lr_scheduler:
            lr_scheduler = custom_get_cosine_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=args.max_train_steps,
                min_lr_ratio=args.min_lr_ratio,
            )
            logger.info("Use torch.optim.lr_scheduler, linear warmup with cosine decay to min_lr")
        else:
            lr_scheduler = get_scheduler(
                name=args.lr_scheduler_type,
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=args.max_train_steps,
            )
            logger.info(f"Use torch.optim.lr_scheduler, {args.lr_scheduler_type}")
    else:
        from accelerate.utils import DummyScheduler
        lr_scheduler = DummyScheduler(
            optimizer, 
            total_num_steps=args.max_train_steps * accelerator.num_processes, # since accelerate's prepare will divide accelerator.num_processes in _prepare_deepspeed()
            warmup_num_steps=num_warmup_steps
        )
        logger.info("Use accelerate.utils.DummyScheduler")
    
    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if args.with_tracking:
        experiment_config = vars(args)
        # TensorBoard cannot log Enums, need the raw value
        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
        accelerator.init_trackers(project_name=args.project_name, config=experiment_config, init_kwargs={"wandb": {"name": f"{args.run_name}_{strftime('%Y-%m-%d_%H:%M:%S',localtime())}", "group": args.group_name}})

    # Prepare everything with our `accelerator`.
    if args.eval_downstream_dataset_name:
        model, optimizer, train_dataloader, eval_dataloader, eval_downstream_dataloader, lr_scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, eval_dataloader, eval_downstream_dataloader, lr_scheduler
        )
    else:
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
        )

    # # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
    # if accelerator.distributed_type == DistributedType.TPU:
    #     model.tie_weights()

    # Figure out how many steps we should evaluate and save the Accelerator states
    num_eval_steps = num_update_steps_per_epoch // args.eval_frequency
    num_checkpointing_steps = num_update_steps_per_epoch // args.checkpointing_frequency
    # num_checkpointing_steps = args.checkpointing_frequency * num_eval_steps

    logger.info(f"train_dataloader: {len(train_dataloader)}, max_train_steps: {args.max_train_steps}, num_warmup_steps: {num_warmup_steps}, num_eval_steps: {num_eval_steps}, num_checkpointing_steps: {num_checkpointing_steps}")
    
    get_nb_trainable_parameters(model, logger)

    # Train!
    logger.info("***** Running training *****")
    logger.info(f"  Num processes = {accelerator.num_processes}")
    logger.info(f"  Process Index = {accelerator.process_index}")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    logger.info(f"  Checkpointing steps = {num_checkpointing_steps}")
    logger.info(f"  Evaluation steps = {num_eval_steps}")
    logger.info(f"  Warmup steps = {num_warmup_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, ncols=120, mininterval=5)
    completed_steps = 0
    starting_epoch = 0
    best_metric = None

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint is not None:
        if args.resume_from_checkpoint == "latest":
            # Get the most recent checkpoint
            # dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
            dirs = [os.path.join(args.output_dir, f.name) for f in os.scandir(args.output_dir) if f.is_dir() and any(resume_name in f.name for resume_name in ['step_', 'epoch_'])]
            dirs.sort(key=lambda path: os.path.getctime(os.path.join(args.output_dir, path)))
            # Sorts folders by date modified, most recent checkpoint is the last
            checkpoint_path = dirs[-1]
            logger.info(f"Resumed from latest checkpoint: {checkpoint_path}")
        else:
            checkpoint_path = args.resume_from_checkpoint
            logger.info(f"Resumed from checkpoint: {checkpoint_path}")
        
        with open(os.path.join(checkpoint_path, "best_metric.json"), "r") as f:
            best_metric = json.load(f)["best_metric"]
        
        logger.info(f"Resumed from best metric: {best_metric}")

        accelerator.load_state(checkpoint_path)
        # Extract `epoch_{i}` or `step_{i}`
        training_difference = os.path.splitext(os.path.basename(checkpoint_path))[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
            completed_steps = starting_epoch * num_update_steps_per_epoch
        else:
            # need to multiply `gradient_accumulation_steps` to reflect real steps
            completed_steps = int(training_difference.replace("step_", ""))
            resume_step = completed_steps * args.gradient_accumulation_steps
            starting_epoch = resume_step // len(train_dataloader)
            resume_step -= starting_epoch * len(train_dataloader)

    # update the progress_bar if load from checkpoint
    progress_bar.update(completed_steps)

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(
                train_dataloader, resume_step)
            logger.info(f"Epoch {epoch}, Step {completed_steps}, Resume from {resume_step}, Training: start")
        else:
            active_dataloader = train_dataloader
            logger.info(f"Epoch {epoch}, Step {completed_steps}, Training: start")

        progress_bar.set_postfix_str("", refresh=False)
        progress_bar.set_description_str("Training: ", refresh=False)
        
        for step, batch in enumerate(active_dataloader):
            with accelerator.autocast(), accelerator.accumulate(model):
                outputs = model(**batch, custom_mask=args.use_custom_attention_mask, use_xformers=args.use_xformers, loss_split=args.loss_split, loss_scale_visual=args.loss_scale_visual, image_start_token_id=args.image_start_token_id)
                if args.loss_split:
                    loss, v_loss, t_loss = outputs.loss, outputs.v_loss, outputs.t_loss
                else:
                    loss = outputs.loss
                if accelerator.sync_gradients:
                    # miss the last accumulate step's grad, but it's ok to reflect the grad norm
                    try:
                        embed_tokens_grad_norm = get_grad_norm(model, name=["embed_tokens"])
                        lm_head_grad_norm = get_grad_norm(model, name=["lm_head"])
                        norm_grad_norm = get_grad_norm(model, name=["norm"])
                    except Exception as e:
                        print(f"Error in step {completed_steps} to get grad norm: {e}")
                        embed_tokens_grad_norm = 0
                        lm_head_grad_norm = 0
                        norm_grad_norm = 0
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    if grad_norm is None:
                        grad_norm = model.get_global_grad_norm()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            
            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.loss_split:
                    loss_dict = accelerator.reduce({
                        "loss": loss.detach(),
                        "v_loss": v_loss,
                        "t_loss": t_loss,
                        },
                        reduction="mean"
                    )
                    progress_bar.set_postfix_str(f"loss: {loss_dict['loss']:.3f}, {loss_dict['v_loss']:.3f}, {loss_dict['t_loss']:.3f}", refresh=False)
                else:
                    loss = accelerator.reduce(loss.detach(), reduction="mean")
                    progress_bar.set_postfix_str(f"loss: {loss:.3f}", refresh=False)
                
                if args.with_tracking:
                    # progress_bar.set_description_str("Logging training metrics: ")
                    if args.loss_split:
                        accelerator.log(
                            {
                                "train/loss": loss_dict["loss"],
                                "train/v_loss": loss_dict["v_loss"],
                                "train/t_loss": loss_dict["t_loss"],
                                "learning_rate": lr_scheduler.get_last_lr()[0],
                                "grad_norm": grad_norm,
                                "embed_tokens_norm": embed_tokens_grad_norm,
                                "lm_head_norm": lm_head_grad_norm,
                                "norm_grad_norm": norm_grad_norm,
                            },
                            step=completed_steps,
                        )
                    else:
                        accelerator.log(
                            {
                                "train/loss": loss,
                                "learning_rate": lr_scheduler.get_last_lr()[0],
                            },
                            step=completed_steps,
                        )
                progress_bar.update(1)
                completed_steps += 1

                if completed_steps % num_checkpointing_steps == 0 or completed_steps >= args.max_train_steps:
                    accelerator.wait_for_everyone()
                    model.eval()
                    # progress_bar.set_description_str("Saving checkpoint by steps:")
                    progress_bar.set_postfix_str("", refresh=False)
                    progress_bar.set_description_str("Saving Checkpoints:")
                    if completed_steps < args.max_train_steps:
                        output_dir = os.path.join(args.output_dir, f"step_{completed_steps}")
                        logger.info(
                            f"Epoch {epoch}, Step {completed_steps}: saving checkpoints at {output_dir}"
                        )
                        accelerator.save_state(output_dir)
                        if accelerator.is_main_process:
                            with open(os.path.join(output_dir, "best_metric.json"), "w") as f:
                                json.dump({"best_metric": best_metric}, f)
                    else:
                        # save the last checkpoint
                        output_dir = os.path.join(args.output_dir, f"last_{completed_steps}")
                        logger.info(
                            f"Epoch {epoch}, Step {completed_steps}: saving the last checkpoint at {output_dir}"
                        )
                        accelerator.unwrap_model(model).save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, safe_serialization=False)
                        if accelerator.is_main_process:
                            tokenizer.save_pretrained(output_dir)
                    
                    if args.save_total_limit is not None:
                        # Only keep the last top-N checkpoints
                        checkpoint_dirs = [os.path.join(args.output_dir, f.name) for f in os.scandir(args.output_dir) if f.is_dir() and any(resume_name in f.name for resume_name in ['step_', 'epoch_'])]
                        checkpoint_dirs.sort(key=lambda path: os.path.getctime(path), reverse=True)
                        # dirs.sort(dirs, key=lambda x: int(x.split("step_")[-1]), reverse=True)
                        
                        accelerator.wait_for_everyone()
                        checkpoints_to_delete_paths = checkpoint_dirs[args.save_total_limit:]
                        if len(checkpoints_to_delete_paths) > 0:
                            logger.info(f"Epoch {epoch}, Step {completed_steps}: deleting older checkpoints {checkpoints_to_delete_paths}")
                            
                            if accelerator.is_main_process:
                                for checkpoints_to_delete_path in checkpoints_to_delete_paths:
                                    shutil.rmtree(checkpoints_to_delete_path)
                        accelerator.wait_for_everyone()
                    model.train()

                if completed_steps % num_eval_steps == 0:
                    accelerator.wait_for_everyone()
                    model.eval()
                    # progress_bar.set_description_str("Computing eval metrics: ")
                    progress_bar.set_postfix_str("", refresh=False)
                    progress_bar.set_description_str("Evaluation: ")
                    losses, v_losses, t_losses = [], [], []

                    for eval_step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process, ncols=120, mininterval=5)):
                        with accelerator.autocast(), torch.no_grad():
                            outputs = model(**batch, custom_mask=args.use_custom_attention_mask, use_xformers=args.use_xformers, loss_split=args.loss_split, loss_scale_visual=args.loss_scale_visual, image_start_token_id=args.image_start_token_id)
                        if args.loss_split:
                            loss, v_loss, t_loss = outputs.loss, outputs.v_loss, outputs.t_loss
                        else:
                            loss = outputs.loss
                        
                        if args.loss_split:
                            loss_dict = accelerator.gather_for_metrics({
                                "loss": loss.repeat(args.per_device_eval_batch_size),
                                "v_loss": v_loss.repeat(args.per_device_eval_batch_size),
                                "t_loss": t_loss.repeat(args.per_device_eval_batch_size),
                                },
                            )
                            losses.append(loss_dict["loss"])
                            v_losses.append(loss_dict["v_loss"])
                            t_losses.append(loss_dict["t_loss"])
                        else:
                            losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))

                        if args.max_eval_steps is not None and eval_step + 1 >= args.max_eval_steps:
                            break
                        
                        # break # TODO: Delete here after debugging

                    losses = torch.cat(losses)
                    if args.loss_split:
                        v_losses = torch.cat(v_losses)
                        t_losses = torch.cat(t_losses)
                    try:
                        eval_loss = torch.mean(losses)
                        perplexity = math.exp(eval_loss)
                        if args.loss_split:
                            eval_v_loss = torch.mean(v_losses)
                            eval_t_loss = torch.mean(t_losses)
                    except OverflowError:
                        perplexity = float("inf")
                        logger.info(
                            f"Epoch {epoch}, Step {completed_steps}, Evaluation: overflow detected, eval_loss: {eval_loss}")

                    if args.loss_split:
                        progress_bar.set_postfix_str(f"loss: {eval_loss:.3f}, {eval_v_loss:.3f}, {eval_t_loss:.3f}")
                    else:
                        progress_bar.set_postfix_str(f"loss: {eval_loss:.3f}")
                    
                    if args.with_tracking:
                        if args.loss_split:
                            accelerator.log(
                                {
                                    "eval/perplexity": perplexity,
                                    "eval/loss": eval_loss,
                                    "eval/v_loss": eval_v_loss,
                                    "eval/t_loss": eval_t_loss,
                                    "epoch": epoch,
                                    "step": completed_steps,
                                },
                                step=completed_steps,
                            )
                        else:
                            accelerator.log(
                                {
                                    "eval/perplexity": perplexity,
                                    "eval/loss": eval_loss,
                                    "epoch": epoch,
                                    "step": completed_steps,
                                },
                                step=completed_steps,
                            )
                    
                    accelerator.wait_for_everyone()
                    
                    if best_metric is None or best_metric > perplexity:
                        progress_bar.set_postfix_str("", refresh=False)
                        progress_bar.set_description_str("Saving Best Checkpoints: ")
                        best_metric = perplexity
                        best_metric_checkpoint = os.path.join(
                            args.output_dir, f"best_{completed_steps}_ppl_{best_metric:.3f}")
                        if args.loss_split:
                            logger.info(
                                f"Epoch {epoch}, Step {completed_steps}, Evaluation: perplexity: {perplexity} eval_loss: {eval_loss}, eval_v_loss: {eval_v_loss}, eval_t_loss: {eval_t_loss}, new best metric {best_metric}, saving checkpoints at {best_metric_checkpoint}"
                            )
                        else:
                            logger.info(
                                f"Epoch {epoch}, Step {completed_steps}, Evaluation: perplexity: {perplexity} eval_loss: {eval_loss}, new best metric {best_metric}, saving checkpoints at {best_metric_checkpoint}"
                            )
                        
                        # accelerator.save_state(best_metric_checkpoint)
                        accelerator.unwrap_model(model).save_pretrained(best_metric_checkpoint, is_main_process=accelerator.is_main_process, save_function=accelerator.save, safe_serialization=False)
                        if accelerator.is_main_process:
                            tokenizer.save_pretrained(best_metric_checkpoint)
                        
                        if args.save_total_limit is not None:
                            # Only keep the last top-N checkpoints
                            checkpoint_dirs = [os.path.join(args.output_dir, f.name) for f in os.scandir(args.output_dir) if f.is_dir() and any(resume_name in f.name for resume_name in ['best_'])]
                            checkpoint_dirs.sort(key=lambda path: os.path.getctime(path), reverse=True)
                            # checkpoint_dirs.sort(key=lambda path: float(path.split("ppl_")[-1]), reverse=True)

                            accelerator.wait_for_everyone()
                            checkpoints_to_delete_paths = checkpoint_dirs[args.save_total_limit:]
                            if len(checkpoints_to_delete_paths) > 0:
                                logger.info(f"Epoch {epoch}, Step {completed_steps}: deleting older best checkpoints {checkpoints_to_delete_paths}")
                                
                                if accelerator.is_main_process:
                                    for checkpoints_to_delete_path in checkpoints_to_delete_paths:
                                        shutil.rmtree(checkpoints_to_delete_path)
                            accelerator.wait_for_everyone()
                    else:
                        if args.loss_split:
                            logger.info(
                                f"Epoch {epoch}, Step {completed_steps}, Evaluation: perplexity: {perplexity} eval_loss: {eval_loss}, eval_v_loss: {eval_v_loss}, eval_t_loss: {eval_t_loss}, worse than best metric"
                            )
                        else:
                            logger.info(
                                f"Epoch {epoch}, Step {completed_steps}, Evaluation: perplexity: {perplexity} eval_loss: {eval_loss}, worse than best metric"
                            )
                    
                    if args.with_tracking:
                        accelerator.log(
                            {
                                "eval/best_metric": best_metric,
                            },
                            step=completed_steps,
                        )

                    if args.eval_downstream_dataset_name:
                        progress_bar.set_postfix_str("", refresh=False)
                        progress_bar.set_description_str("Evaluation Downstream: ")
                        outputs_list = []
                        input_index_list = []

                        for eval_downstream_step, batch in enumerate(eval_downstream_dataloader):
                            with accelerator.autocast(), torch.no_grad():
                                input_index_list.extend(batch.pop("input_index").cpu().tolist())
                                outputs = model.generate(
                                    **batch,
                                    use_xformers=args.use_xformers,
                                    generation_config=build_gen_config(),
                                ).cpu()
                                # print(outputs.shape)
                                outputs = outputs[:, args.eval_block_size:].tolist()
                                # print(outputs)
                                # print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
                                outputs_list.extend(outputs)
                        
                        outputs = tokenizer.batch_decode(outputs_list, skip_special_tokens=True)
                        outputs = [text.strip() for text in outputs]
                        output_dict = [{"input_index": input_index, "response": output} for input_index, output in zip(input_index_list, outputs)]
                        eval_result_cache_path = os.path.join(args.output_dir, f"eval_{args.eval_downstream_dataset_name}_{args.eval_downstream_prompt_settings[0]}_step_{completed_steps}")
                        local_eval_result_cache_file = f"{eval_result_cache_path}_rank_{accelerator.process_index}.jsonl"

                        write_with_orjsonl(output_dict, local_eval_result_cache_file)
                        logger.info(f"Write to {local_eval_result_cache_file}")
                        accelerator.wait_for_everyone()

                        if accelerator.is_main_process:
                            eval_result_cache_files = [f"{eval_result_cache_path}_rank_{i}.jsonl" for i in range(accelerator.num_processes)]
                            eval_result_cache_files_exist = [f for f in eval_result_cache_files if os.path.exists(f)]
                            logger.info(f"{eval_result_cache_files_exist} local eval result files exist")
                            # assert len(eval_result_cache_files_exist) == accelerator.num_processes, f"Only {len(eval_result_cache_files_exist)} eval result files exist, but accelerator.num_processes is {accelerator.num_processes}"
                            eval_result_cache_file_merge = f"{eval_result_cache_path}.jsonl"
                            if os.path.exists(eval_result_cache_file_merge):
                                os.remove(eval_result_cache_file_merge)
                            # merge all local eval result
                            for eval_file in eval_result_cache_files_exist:
                                eval_result = read_with_orjsonl(eval_file)
                                write_with_orjsonl_extend(eval_result, eval_result_cache_file_merge)
                            for eval_file in eval_result_cache_files_exist:
                                os.remove(eval_file)
                            num_lines = len(read_with_orjsonl(eval_result_cache_file_merge))
                            logger.info(f"Write all {num_lines} lines to {eval_result_cache_file_merge}")

                            match_rate, acc_match, acc, rolling_acc = mmbench_eval(args, os.path.join(args.eval_downstream_dataset_dir, args.eval_downstream_dataset_name, 'datasets'), prediction_path=eval_result_cache_file_merge)
                            if rolling_acc is None:
                                logger.info(f'Evaluation Downstream: prediction length {len(predictions)} != dataset length {len(origin_dataset)}')
                                progress_bar.set_postfix_str("fail")
                            else:
                                if args.with_tracking:
                                    accelerator.log(
                                        {
                                            "downstream/match_rate": match_rate,
                                            "downstream/acc_match": acc_match,
                                            "downstream/acc": acc,
                                            "downstream/rolling_acc": rolling_acc,
                                            "epoch": epoch,
                                            "step": completed_steps,
                                        },
                                        step=completed_steps,
                                    )
                                progress_bar.set_postfix_str(f"ds: {rolling_acc:.3f}, {acc:.3f}, {acc_match:.3f}, {match_rate:.3f}")

                        accelerator.wait_for_everyone()
                    
                    model.train()
                
                progress_bar.set_postfix_str("", refresh=False)
                progress_bar.set_description_str("Training: ", refresh=False)
            
            if completed_steps >= args.max_train_steps:
                logger.info(
                    f"Epoch {epoch}, Step {completed_steps} Training: end"
                )
                break

        accelerator.wait_for_everyone()

    if args.with_tracking:
        accelerator.end_training()


if __name__ == "__main__":
    # Launch the training/evaluation pipeline only when this file is run
    # directly; importing the module must not start training.
    main()
