# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Classes and functions for fine-tuning a text generation model with LoRA.
Modified from https://github.com/tloen/alpaca-lora
"""
import os
import sys

sys.path.append(os.getcwd())

import copy
from dataclasses import dataclass, field
import logging
from packaging import version
from typing import Optional, Union, List, Dict
import warnings
import torch
import transformers
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
    HfArgumentParser,
    LlamaForCausalLM,
    # MistralForCausalLM,
    # MixtralForCausalLM,
    OPTForCausalLM,
    set_seed,
    GenerationConfig,
)
from transformers.trainer_utils import is_main_process, PREFIX_CHECKPOINT_DIR
from transformers.integrations import is_deepspeed_zero3_enabled

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
)

from models.chatglm import ChatGLMTokenizer, ChatGLMForConditionalGeneration, ChatGLMConfig
from models.llama import LlamaForCausalLMModelForContrastiveLearning
from models.llava_next import LlavaNextConfig, LlavaNextForConditionalGeneration, LlavaNextImageProcessor
from models.recur_qwen import Qwen2Config as RecurQwen2Config, Qwen2ForCausalLM as RecurQwen2ForCausalLM
# from models.qwen import QWenTokenizer, QWenLMHeadModel, QWenConfig, QWenLMHeadModelForContrastiveLearning
from utils.callbacks import HaltTrainingCallback, PeftSaveCallback
from utils.data_utils import (
    prepare_datasets,
    DataArguments, 
    ImageProcessingArguments,
    simple_image_preprocessor,
    TokenizedPromptProcessor, 
    TokenizedChatProcessor,
    TokenizedPromptProcessorWithDA,
    TokenizedChatProcessorWithDA,
    DataCollatorForSeq2SeqForAllKeys,
)
from utils.peft_utils import MoLoraConfig, load_peft_weights_into_model
from utils.metric_utils import Seq2SeqMetricsOnSeqIDs, Seq2SeqMetricsOnGenerationSeqIDs
from utils.rl_utils import (
    MyAutoModelForCausalLMWithValueHead, load_rl_weights_into_model 
)
from utils.trainer_utils import count_parameters, RenameCKPTFiles
from global_const import CKPT_FOLDER, BASE_MODELS, MODEL_SIZE, DTYPE_CLASS

logger = logging.getLogger(__name__)


# NOTE: the registries below only list models that need special attention;
# by default the auto class is used for everything else.

# Model prefix -> configuration class.
CONFIG_CLASS = dict([
    ("chatglm", ChatGLMConfig),
    ("llava-mistral", LlavaNextConfig),
    ("recur-qwen2.5", RecurQwen2Config),
])

# Model prefix -> tokenizer class.
MODEL_TOKENIZER = dict([
    ("chatglm", ChatGLMTokenizer),
])

# Model prefix -> image processor class (multimodal models only).
IMAGE_PROCESSOR = dict([
    ("llava-mistral", LlavaNextImageProcessor),
])

# Model prefix -> model class.
MODEL_CLASS = dict([
    ("chatglm", ChatGLMForConditionalGeneration),
    ("galactica", OPTForCausalLM),
    # ("mistral", MistralForCausalLM),
    # ("mixtral", MixtralForCausalLM),
    ("llava-mistral", LlavaNextForConditionalGeneration),
    ("recur-qwen2.5", RecurQwen2ForCausalLM),
])

# Model prefix -> model class used for contrastive learning; every llama
# variant listed here shares the same implementation.
MODEL_CLASS_FOR_CL = dict.fromkeys(
    ("llama", "llama2", "codellama"),
    LlamaForCausalLMModelForContrastiveLearning,
)


@dataclass
class ModelArguments(MoLoraConfig):
    """ Arguments related to the model itself.

    Some related arguments are not listed here but in LoraConfig and PeftConfig.
    See virtual_env/lib/python3.x/site-packages/peft/tuners/lora.py for more info.
        inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.

        r (`int`): Lora attention dimension.  # however passing r leads to ambiguous option; we use lora_r instead.
        target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to.
            For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'
        lora_alpha (`float`): The alpha parameter for Lora scaling.
        lora_dropout (`float`): The dropout probability for Lora layers.
        merge_weights (`bool`):
            Whether to merge the weights of the Lora layers with the base transformer model in `eval` mode.
        enable_lora ( `List[bool]`): Used with `lora.MergedLinear`. TODO what does this do?
        bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'
            none: do not mark any bias as trainable 
            all: make all biases in base model as trainable
            lora_only: make all biases in LoraLayer as trainable
        modules_to_save (`List[str]`): List of modules apart from LoRA layers to be set as trainable
            and saved in the final checkpoint.

    After parsing, `__post_init__` normalizes the values:
        * `task_type` is forced to "CAUSAL_LM";
        * placeholder strings ('None', '', ' ') for `model_name_or_path` become `None`, and a
          leading `~` is expanded to the user's home directory;
        * a `target_modules` string of the form "[...]" is evaluated into a list;
        * `lora_r`, when given, is copied to `r`;
        * `recur_times == 0` is rewritten to 1.

    Raises:
        AssertionError: if any of the validated options has an unsupported value.
    """
    base_model: str = field(
        default="llama-7b", metadata={"help": "The name of the model to be finetuned."}
    )
    model_name_or_path: str = field(
        default=None,
        metadata={
            "help": "The LoRA model checkpoint for weights initialization."
            "Don't set if you want to train a LoRA model from scratch."
        },
    )
    load_in_8bit: bool = field(
        default=False, metadata={"help": "Whether to load base model in 8bit to lower down hardware requirements."}
    )
    torch_dtype: str = field(
        default='float16', metadata={"help": "Whether to store base model weights in float16, bfloat16 or float32."}
    )
    lora_r: int = field(
        default=None, metadata={"help": "Lora attention dimension, aka LoraConfig.r"}
    )
    target_modules: str = field(  # LoraConfig uses Optional[Union[List[str], str]], which may leads to wrong parsing.
        default=None,
        metadata={
            "help": "List of module names or regex expression of the module names to replace with Lora."
            "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
        },
    )
    use_flash_attn: bool = field(
        default=False, metadata={"help": "Whether to use flash attention when it is available."}
    )
    # customization
    adapter_name: str = field(
        default="default", metadata={"help": "Name for LoRA adapter."}
    )
    # DPO related
    DPO_loss_weight: float = field(
        default=0, metadata={"help": "the weight of DPO loss of training. if weight>0, DPO loss works"}
    )
    DPO_loss_beta: float = field(
        default=0.1, metadata={"help": "the beta of DPO loss of training. default is 0.1"}
    )
    DPO_loss_inference_free: bool = field(
        default=False, metadata={"help": "If True, we ignore the reference model logprobs and implicitly use a reference model that assigns equal probability to all responses"}
    )
    # ORM, PRM, OVM related
    summary_dropout_prob: float = field(
        default=0.1, metadata={"help": "Dropout probs for value head."}
    )
    value_loss_fct_type: str = field(
        default="bce",
        metadata={
            "help": "Loss function type for SFT of value function; currently support "
            "`bce`, `ce`, `mse`, `stepwise_cl`, `samplewise_cl`."
        },
    )
    value_loss_weight: float = field(
        default=1.0, metadata={"help": "Weight for value_loss when both token_loss and value_loss are calculated."}
    )
    # recurrent networks related
    recur_strategy: str = field(
        default="blockwise", metadata={"help": "Recurrence strategy; must be either `blockwise` or `layerwise`."}
    )
    recur_times: int = field(
        default=1,
        metadata={
            "help": "Number of recurrent times to apply. Must be non-negative; "
            "0 and 1 both mean not using recurrent layers."
        },
    )
    num_prelude_layers: int = field(
        default=4, metadata={"help": "Number of prelude layers. Must be non-negative."}
    )
    num_coda_layers: int = field(
        default=4, metadata={"help": "Number of coda layers. Must be non-negative."}
    )
    input_injection_type: str = field(
        default="None", metadata={"help": "Type of input injecton, if None, no injection, if add, add hiddent state and input embedding."}
    )
    state_init_strategy: str = field(
        default="None", metadata={"help": "Strategy of recurrent state initialization, None means initialize with input embedding"}
    )
    init_std: str = field(
        default="None", metadata={"help": "std of recurrent state initialization, defult 'takase'"}
    )
    attn_to_recur_key_values: bool = field(
        default=False, metadata={"help": "Whether to allow recurrent layers to attend to themselves in the previous round."}
    )
    ln_after_recur: bool = field(
        default=False, metadata={"help": "Whether to apply layer norm after each recurrent iteration."}
    )

    def __post_init__(self):
        """Normalize placeholder values and validate the parsed model options."""
        super().__post_init__()
        self.task_type = "CAUSAL_LM"  # one of {"SEQ_CLS" "SEQ_2_SEQ_LM" "CAUSAL_LM" "TOKEN_CLS"}
        # check base_model: the prefix is everything before the last '-' (e.g. "llama-7b" -> "llama")
        model_prefix = self.base_model.rsplit('-', 1)[0]
        assert (model_prefix in BASE_MODELS) and (self.base_model in MODEL_SIZE[model_prefix]), NotImplementedError(
            'Model {} not supported.'.format(self.base_model))
        assert self.torch_dtype in ['float16', 'bfloat16', 'float32'], ValueError(
            'torch_dtype must be float16, bfloat16 or float32; got {}'.format(self.torch_dtype))
        # check lora
        if self.model_name_or_path in {'None', '', ' '}:  # TODO make lora accept ckpt
            self.model_name_or_path = None
        if isinstance(self.model_name_or_path, str) and self.model_name_or_path.startswith('~'):
            # expanduser handles both "~/path" and "~user/path"; blindly gluing the current
            # user's home onto the remainder would corrupt the latter form.
            self.model_name_or_path = os.path.expanduser(self.model_name_or_path)
        if (
            isinstance(self.target_modules, str) 
            and self.target_modules.startswith('[') 
            and self.target_modules.endswith(']') 
        ):
            # NOTE(review): eval() executes arbitrary code taken from the command line / config.
            # Acceptable only while the arguments are trusted; consider ast.literal_eval.
            self.target_modules = eval(self.target_modules)
        if isinstance(self.target_modules, list):
            assert all(isinstance(module, str) for module in self.target_modules), ValueError(
                'target_modules must be a list of strings, got {}'.format(self.target_modules))
        if self.lora_r is not None:
            # `lora_r` is the unambiguous command-line alias for `r`
            self.r = self.lora_r
        assert self.summary_dropout_prob >= 0 and self.summary_dropout_prob < 1.0, ValueError(
            "summary_dropout_prob must be in range [0, 1); got {}".format(self.summary_dropout_prob)
        )
        assert self.value_loss_fct_type in ["bce", "ce", "mse", "stepwise_cl", "samplewise_cl"], NotImplementedError(
            "Loss type {} is not implemented.".format(self.value_loss_fct_type))
        assert self.value_loss_weight > 0, NotImplementedError(
            "value_loss_weight must be positive; got {}.".format(self.value_loss_weight))
        assert self.recur_strategy in ["blockwise", "layerwise"], NotImplementedError(
            "recur_strategy must be either 'blockwise' or 'layerwise'; got {}.".format(self.recur_strategy))
        # one placeholder per reported value, so that none of them is silently dropped from the message
        assert self.recur_times >= 0 and self.num_prelude_layers >= 0 and self.num_coda_layers >= 0, NotImplementedError(
            "recur_times, num_prelude_layers, and num_coda_layers must be non-negative; got {}, {}, {}.".format(
                self.recur_times, self.num_prelude_layers, self.num_coda_layers))
        if self.recur_times == 0:  # 0 and 1 both mean not using recurrent layers
            self.recur_times = 1


@dataclass
class ExtendedDataArguments(DataArguments):
    """ Arguments pertaining to what data we are going to input our model for training and eval.

    Some related arguments are not listed here but inherited (presumably from DataArguments -- verify
    against utils.data_utils):
        train_file (`str`): The path of training data file.
        validation_file (`str`): (Optional) the path of evaluation data file to evaluate the perplexity on.
        validation_split (`float`): The percentage of the train set used as validation set in case there's no validation_file.
        max_seq_length (`int`): The maximum total input sequence length after tokenization.

    Several options are received as strings and converted in `__post_init__`: the placeholder
    strings such as "None" become `None`, and strings that look like a list ("[...]") or a
    dict ("{...}") are evaluated into the corresponding Python object. Values that are already
    parsed (e.g. a list or dict coming from a JSON config) are left untouched.

    Raises:
        AssertionError: if `roles_to_predict`, `role_tags` or `role_map` cannot be turned into
            the expected list / dict.
    """
    train_on_inputs: bool = field(
        default=False,
        metadata={"help": "Whether to calculate loss on all (instructions, inputs, outputs) tokens or only outputs."},
    )
    disable_caching: bool = field(
        default=False,
        metadata={"help": "Whether to disable dataset caching and clear cache."},
    )
    data_generation_task: str = field(
        default='SimpleMathFormulation',
        metadata={
            "help": "Optionally a str or list of generation task names."
            "For example, 'SimpleMathFormulation', 'LPProblemDescriptionGeneration', 'IncorrectMathChecking' "
        }
    )
    data_augmentations: str = field(
        default=None,
        metadata={
            # braces need no escaping in a plain string; '\{' is an invalid escape sequence
            "help": "Optionally a list of data augmentation method names or their dict of configs."
            "For example, 'symmetric_shuffle', '[symmetric_shuffle]', '{'name': 'symmetric_shuffle'}' "
        }
    )
    force_postprocessor: bool = field(
        default=False,
        metadata={
            "help": "Force to call postprocessor, e.g., TokenizedPromptProcessorWithDA and TokenizedChatProcessorWithDA."
            "This is useful if data contains some stochastic process that does sampling during training. For example,"
            "there are multiple augmented Q or A (or both) for a single instance."
        }
    )
    data_const_deduplicate: bool = field(
        default=False,
        metadata={"help": "Whether to use deduplicated constraints as model outputs."},
    )
    keep_origin_instance: bool = field(
        default=False,
        metadata={"help": "Whether to keep all the information in instance."},
    )
    filter_data_by_indices: str = field(
        default=None,
        metadata={"help": "Whether to filter the training data by indices. May provide a list or a file path that saves the list."},
    )
    # ValueLabelPrediction related
    end_of_step_id: int = field(
        default=-1,
        metadata={
            "help": "A token id indicating the end of step (position to predict value label)."
            "could be None or negative to indicate that the value label should be predicted at every token."
        },
    )
    false_eos_ids: str = field(
        default="[]",
        metadata={"help": "List[List[int]] indicate the cases that should not be considered as the end of step."},
    )
    # The defaults below are real ints: a string default is only converted by argparse, so it
    # would leak through as `str` when the dataclass is built from a dict / JSON file or directly.
    num_positive_value_samples: int = field(
        default=1,
        metadata={"help": "Number of positive samples for Contrastive Learning of value function."},
    )
    num_negative_value_samples: int = field(
        default=1,
        metadata={"help": "Number of negative samples for Contrastive Learning of value function."},
    )
    # ChatRecords
    roles_to_predict: str = field(
        default=None,
        metadata={"help": "Roles whose utterances to predict in a chat."},
    )
    role_tags: str = field(
        default=None,
        metadata={
            "help": "Some tokenizers may have special tokens for the role tags (e.g., <｜User｜>, <｜Assistant｜>), instead of"
            "standard format <bos_token> + role (e.g., <|im_start|>assistant). For these tokenizers, need to provide"
            "role_tags dict in order to get a correct label, e.g., {'user': '<｜User｜>','assistant':'<｜Assistant｜>'}"
        },
    )
    role_map: str = field(
        default=None,
        metadata={
            "help": "Some tokenizers cannot handle the roles in chat. Provide a role map if you do not want the role to be"
            "automatically detected and mapped. "
        },
    )
    
    def __post_init__(self):
        """Convert string-encoded options into Python objects and validate their types.

        NOTE(review): the conversions rely on eval(), which executes arbitrary code taken from
        the command line / config. Acceptable only while the arguments are trusted; consider
        ast.literal_eval.
        """
        super().__post_init__()
        # Every membership / prefix test below is guarded with isinstance(..., str): an already
        # parsed list or dict is unhashable (set membership raises TypeError) and has no
        # startswith().
        if (
            isinstance(self.data_generation_task, str)
            and self.data_generation_task.startswith("[")
            and self.data_generation_task.endswith("]")
        ):
            self.data_generation_task = eval(self.data_generation_task)
        if isinstance(self.data_augmentations, str):
            if self.data_augmentations in {"None", "", " "}:
                self.data_augmentations = None
            elif (
                (self.data_augmentations.startswith("{") and self.data_augmentations.endswith("}"))
                or (self.data_augmentations.startswith("[") and self.data_augmentations.endswith("]"))
            ):
                self.data_augmentations = eval(self.data_augmentations)
        # self.filter_data_by_indices can also be a file, but the handling of file is done inside data loader,
        # so that its content will not be logged in run_lora
        if isinstance(self.filter_data_by_indices, str):
            if self.filter_data_by_indices == "None":
                self.filter_data_by_indices = None
            elif self.filter_data_by_indices.startswith("[") and self.filter_data_by_indices.endswith("]"):
                self.filter_data_by_indices = eval(self.filter_data_by_indices)

        if isinstance(self.false_eos_ids, str):
            if self.false_eos_ids.startswith("[") and self.false_eos_ids.endswith("]"):
                self.false_eos_ids = eval(self.false_eos_ids)
            else:
                self.false_eos_ids = []
        elif not isinstance(self.false_eos_ids, list):
            # anything that is neither a list nor a list-like string means "no false EOS cases"
            self.false_eos_ids = []
        # ChatRecords
        if isinstance(self.roles_to_predict, str):
            if self.roles_to_predict in {"None", ""}:
                self.roles_to_predict = None
            elif self.roles_to_predict.startswith("["):
                self.roles_to_predict = eval(self.roles_to_predict)
        assert self.roles_to_predict is None or isinstance(self.roles_to_predict, list), RuntimeError(
            "Fail to parse {}; expect it to be a list in str.".format(self.roles_to_predict))
        if isinstance(self.role_tags, str):
            if self.role_tags in {"None", ""}:
                self.role_tags = None
            elif self.role_tags.startswith("{"):
                self.role_tags = eval(self.role_tags)
        assert self.role_tags is None or isinstance(self.role_tags, dict), RuntimeError(
            "Fail to parse {}; expect it to be a dict in str.".format(self.role_tags))
        if isinstance(self.role_map, str):
            if self.role_map in {"None", ""}:
                self.role_map = None
            elif self.role_map.startswith("{"):
                self.role_map = eval(self.role_map)
        assert self.role_map is None or isinstance(self.role_map, dict), RuntimeError(
            "Fail to parse {}; expect it to be a dict in str.".format(self.role_map))
    
    @staticmethod
    def _is_subtask(task1, task2):
        """Return True when every task named in `task1` also appears in `task2`.

        Either argument may be a single task name (str) or a collection of names;
        a single name is treated as a one-element list.
        """
        task1 = [task1] if isinstance(task1, str) else task1
        task2 = [task2] if isinstance(task2, str) else task2
        return all(task in task2 for task in task1)


@dataclass
class ExtendedTrainingArguments(TrainingArguments):
    """ Arguments related to training.

    Some related arguments are not listed here but in TrainingArguments.
    See virtual_env/lib/python3.x/site-packages/transformers/training_args.py for more info.
        overwrite_output_dir (`bool`, *optional*, defaults to `False`)
        do_train (`bool`, *optional*, defaults to `False`)
        do_eval (`bool`, *optional*, defaults to `False`)
        do_predict (`bool`, *optional*, defaults to `False`)
        evaluation_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`, could be `"no"`, `"steps"`, `"epoch"`)
        per_device_train_batch_size (`int`, *optional*, defaults to 8)
        per_device_eval_batch_size (`int`, *optional*, defaults to 8)
        gradient_accumulation_steps (`int`, *optional*, defaults to 1)
        learning_rate (`float`, *optional*, defaults to 5e-5)
        lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`)
        weight_decay (`float`, *optional*, defaults to 0)
        adam_beta1, adam_beta2 (`float`, *optional*, defaults to 0.9 and 0.999)
        adam_epsilon (`float`, *optional*, defaults to 1e-8)
        max_grad_norm (`float`, *optional*, defaults to 1.0)
        num_train_epochs(`float`, *optional*, defaults to 3.0)
        warmup_steps (`int`, *optional*, defaults to 0)
        eval_steps (`int`, *optional*)
        logging_dir (`str`, *optional*, default to *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***)
        logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`)
        logging_steps (`int`, *optional*, defaults to 500)
        save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`)
        save_steps (`int`, *optional*, defaults to 500)
        save_total_limit (`int`, *optional*, defaults to None)
        seed (`int`, *optional*, defaults to 42)
        bf16 (`bool`, *optional*, defaults to `False`):
            Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
            NVIDIA architecture or using CPU (no_cuda).
        fp16 (`bool`, *optional*, defaults to `False`) 
            whether use mixed precision training.
        load_best_model_at_end (`bool`, *optional*, defaults to `False`)
        gradient_checkpointing (`bool`, *optional*, defaults to `False`):
            If True, use gradient checkpointing to save memory at the expense of slower backward pass.

    In addition to the declared fields, `__post_init__` sets `device_map` (and, for multi-process
    runs, `ddp_find_unused_parameters`), optionally rescales `learning_rate`, and converts a
    list-like `validation_metrics` string into a list.

    Raises:
        NotImplementedError: if `evaluation_method` is not 'teacher_forcing' or 'autoregressive'.
        AssertionError: if both `fp16` and `bf16` are enabled.
    """
    output_dir: str = field(
        default="../output/", metadata={"help": "The output parent dir to save training logs and lora model ckpt."}
    )
    scale_learning_rate_to_batch_size: bool = field(
        default=True, metadata={"help": "Whether to scale the learning according to the world size."}
    )
    ignore_trainer_state: bool = field(
        default=False,
        metadata={
            "help": "When continue training from a checkpoint, whether to load the trainer_state.json or not."
            "If False, will continue from the epochs_trained in trainer_state.json"
        },
    )
    debug_mode: bool = field(
        default=False, metadata={"help": "In debug mode, norms of parameters will be added to tensorboard."},
    )
    halt_step: int = field(
        default=-1, metadata={"help": "whether to stop training after a certain number of steps."},
    )
    estimate_mem_using_deepspeed_zero: bool = field(
        default=False, metadata={"help": "Call deepspeed zero to estimate the memory needed under each stage."},
    )
    validation_metrics: str = field(
        default="accuracy", metadata={"help": "The metrics used to evaluate validation data."}
    )
    evaluation_method: str = field(
        default="teacher_forcing", metadata={"help": "The validation method used in training, 'autoregressive' is like generation"}
    )

    def __post_init__(self):
        """Normalize placeholders, run the parent initialization, then derive the extra settings."""
        # Only string placeholders are normalized. The isinstance guard matters because a
        # non-hashable value (e.g. a dict config) would make the set-membership test raise TypeError.
        if isinstance(self.deepspeed, str) and self.deepspeed in {'None', '', ' '}:
            self.deepspeed = None
        if not self.do_eval:  # if not do_eval, set evaluation_strategy to no; this needs to be done before __post_init__
            # NOTE(review): newer transformers releases rename this option to `eval_strategy`;
            # confirm against the pinned transformers version.
            self.evaluation_strategy = "no"
        super().__post_init__()
        if self.world_size != 1:
            # multi-process run: place the whole model on this process' local device
            self.device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
            self.ddp_find_unused_parameters = False
        else:
            self.device_map = 'auto'
        if self.scale_learning_rate_to_batch_size:
            # scale linearly with the effective batch size, relative to a reference batch size of 128
            effective_batch_size = (
                self.per_device_train_batch_size * self.gradient_accumulation_steps * self.world_size
            )
            self.learning_rate *= effective_batch_size / 128
            logger.info('The scaled learning rate is {}'.format(self.learning_rate))
        if (
            isinstance(self.validation_metrics, str)
            and self.validation_metrics.startswith("[")
            and self.validation_metrics.endswith("]")
        ):
            # NOTE(review): eval() executes arbitrary code taken from the command line / config.
            # Acceptable only while the arguments are trusted; consider ast.literal_eval.
            self.validation_metrics = eval(self.validation_metrics)
        if self.evaluation_method not in ['teacher_forcing', 'autoregressive']:
            raise NotImplementedError(f"The evaluation method {self.evaluation_method} is not a valid method")
        assert not (self.fp16 and self.bf16), ValueError("fp16 and bf16 cannot be activated simultaneously.")


@dataclass
class GenerationArguments:
    """ Arguments related to the model generation.

    GenerationConfig contains a comprehensive list of arguments. But unlike TrainingArguments, GenerationConfig 
    is not a dataclass with fields, meaning that we could not parse it like we parse other arguments. So here we 
    make this dataclass.

    For more arguments, see transformers.generation.configuration_utils.GenerationConfig.

    Below arguments are available in GenerationConfig, that are not listed here but may be useful later. 
    bad_words_ids(`List[List[int]]`, *optional*):
        List of token ids that are not allowed to be generated. In order to get the token ids of the words that
        should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
        add_special_tokens=False).input_ids`.
    force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
        List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of
        words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this
        triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
        can allow different forms of each word.
    generation_kwargs:
        Additional generation kwargs will be forwarded to the `generate` function of the model. Kwargs that are not
        present in `generate`'s signature will be used in the model forward pass.

    `__post_init__` additionally sets `pad_token_id`, `bos_token_id` and `eos_token_id` to fixed values
    (0, 1 and 2); they are plain attributes, not dataclass fields, so they are not parsed as options.

    Raises:
        AssertionError: if `ensemble_method` is set to anything other than 'reward_model' or 'simple_voting'.
    """
    # Parameters that control the length of the output
    min_new_tokens: Optional[int] = field(
        default=None,
        metadata={"help": "The minimum numbers of tokens to generate."},
    )
    max_new_tokens: Optional[int] = field(
        default=None,
        metadata={"help": "The maximum numbers of tokens to generate."},
    )
    # Parameters that control the generation strategy
    do_sample: bool = field(
        default=False,
        metadata={"help": "Whether or not to use sampling ; use greedy decoding otherwise."},
    )
    num_beams: int = field(
        default=1,
        metadata={"help": "Number of beams for beam search. 1 means no beam search."},
    )
    # Parameters for manipulation of the model output logits
    temperature: float = field(
        default=1.0,
        metadata={"help": "The value used to modulate the next token probabilities."},
    )
    top_k: int = field(
        default=50,
        metadata={
            "help": "If set, will keep the top k with largest logits for generation."
        },
    )
    top_p: float = field(
        default=1.0,
        metadata={
            "help": "If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to"
            "`top_p` or higher are kept for generation."
        },
    )
    length_penalty: float = field(
        default=1.0,
        metadata={
            "help": "Exponential penalty to the length that is used with beam-based generation. "
            "`length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences."
        },
    )
    renormalize_logits: bool = field(
        default=True,
        metadata={
            "help": "Whether to renormalize the logits after applying all the logits processors or warpers (including the custom"
            "ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits"
            "are normalized but some logit processors or warpers break the normalization."
        },
    )
    # Parameters that define the output variables of `generate`
    num_return_sequences: int = field(
        default=1, metadata={
            "help": "The number of independently computed returned sequences for each element in the batch."},
    )
    # others
    streaming: bool = field(
        default=False,
        metadata={"help": "Whether stream-output the generated text."},
    )
    ensemble_method: Optional[str] = field(
        default=None,
        metadata={"help": "Whether or not to return the probability of 'yes' or 'no', only use for reward model"},
    )
    output_answer_probs: bool = field(
        default=False,
        metadata={"help": "Whether to output the logits and prob for the instruction-answer pair, only for DPO."},
    )

    def __post_init__(self):
        """Set the fixed special-token ids and normalize / validate `ensemble_method`."""
        # Hard-coded special token ids, named after the corresponding GenerationConfig attributes.
        self.pad_token_id = 0  # GenerationConfig
        self.bos_token_id = 1
        self.eos_token_id = 2
        # placeholder strings coming from the command line mean "no ensemble"
        if self.ensemble_method in {"None", "", " "}:
            self.ensemble_method = None
        if self.ensemble_method:
            assert self.ensemble_method in {"reward_model", "simple_voting"}, NotImplementedError(
                "ensemble_method {} not used.".format(self.ensemble_method))


def parse_args(args):
    """Parse the command-line (or JSON) arguments into the five argument dataclasses.

    Args:
        args: list of argument strings. If it contains exactly one element and that element
            ends with ".json", it is treated as the path to a JSON file holding the arguments.

    Returns:
        A tuple `(model_args, data_args, training_args, generation_args, image_processing_args)`.

    Raises:
        ValueError: if training is requested into an existing output directory that already
            contains checkpoint folders and `overwrite_output_dir` is not set.
        AssertionError: if `torch_dtype` is inconsistent with the `fp16` / `bf16` flags.
    """
    # We keep distinct sets of args, for a cleaner separation of concerns.
    # parser = HfArgumentParser({
    #     "model_args": ModelArguments,
    #     "data_args": ExtendedDataArguments,
    #     "training_args": ExtendedTRLTrainingArguments,
    #     "generation_args": GenerationArguments,
    #     "image_processing_args": ImageProcessingArguments
    # })
    parser = HfArgumentParser((
        ModelArguments, ExtendedDataArguments, ExtendedTrainingArguments, GenerationArguments, ImageProcessingArguments
    ))
    if len(args) == 1 and args[0].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments. Note that the path is args[0]: passing the
        # whole list to os.path.abspath raises TypeError.
        parsed = parser.parse_json_file(json_file=os.path.abspath(args[0]))
    else:
        parsed = parser.parse_args_into_dataclasses(args)
    model_args, data_args, training_args, generation_args, image_processing_args = parsed

    # Refuse to silently train on top of existing checkpoints.
    if (
        os.path.exists(training_args.output_dir)
        and any(_dir.startswith(PREFIX_CHECKPOINT_DIR) for _dir in os.listdir(training_args.output_dir))
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # check torch_dtype and fp16/bf16 options
    if model_args.torch_dtype == "float16":
        assert training_args.fp16 and (not training_args.bf16), ValueError(
            "When torch_dtype is float16, fp16 must be True and bf16 must be False.")
    elif model_args.torch_dtype == "bfloat16":
        assert (not training_args.fp16) and training_args.bf16, ValueError(
            "When torch_dtype is bfloat16, fp16 must be False and bf16 must be True.")
        # silence this specific, expected warning when running in bfloat16
        warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")

    return model_args, data_args, training_args, generation_args, image_processing_args



def get_tokenizer(base_model_prefix, base_model_path):
    """Load the tokenizer of a base model and patch its special tokens.

    Args:
        base_model_prefix: Model family key (e.g. ``'llama2'``, ``'qwen2.5'``). It selects the
            tokenizer class from ``MODEL_TOKENIZER`` (``AutoTokenizer`` when the key is absent)
            and the family-specific special-token fix-ups applied below.
        base_model_path: Location handed to ``from_pretrained``.

    Returns:
        The tokenizer, created with left padding and left truncation.

    Raises:
        AssertionError: If the pad token id equals the bos or eos token id, or (for
            ``'llava-mistral'``) if ``<image>`` is not found exactly once among the added tokens.
    """
    padding_side = 'left'
    tokenizer_class = MODEL_TOKENIZER.get(base_model_prefix, AutoTokenizer)
    tokenizer = tokenizer_class.from_pretrained(
        base_model_path,
        padding_side=padding_side,  # allow batched inference
        truncation_side=padding_side,
    )
    if base_model_prefix in ['llama', 'llama2', 'codellama', 'mistral', 'mixtral']:  # mistral uses LlamaTokenizer
        tokenizer.add_special_tokens({'pad_token': '<unk>'})  # set pad_token to '<unk>'
    elif base_model_prefix in ['llama3', 'llama3.3']:
        if tokenizer.pad_token is None:  # llama3 does not set this but llama3.3 has <|finetune_right_pad_id|>
            tokenizer.add_special_tokens({'pad_token': '<|reserved_special_token_0|>'})
    elif base_model_prefix == 'chatglm':
        tokenizer.add_special_tokens({'eos_token': '<eop>'})  # set eos_token to '<eop>'
    elif base_model_prefix == 'galactica':
        tokenizer.add_special_tokens({'bos_token': '<s>'})
        tokenizer.add_special_tokens({'eos_token': '</s>'})
        tokenizer.add_special_tokens({'pad_token': '<pad>'})
        tokenizer.add_special_tokens({'unk_token': '<unk>'})
    elif base_model_prefix == 'qwen':
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # tokenizer.pad_token_id = tokenizer.eos_token_id
    elif base_model_prefix == 'llava-mistral':
        tokenizer.add_special_tokens({'pad_token': '<unk>'})  # set pad_token to '<unk>'
        # locate the id of the "<image>" placeholder among the added tokens; it must be unique
        image_token_index = [key for key, value in tokenizer.added_tokens_decoder.items() if value.content == "<image>"]
        assert len(image_token_index) == 1
        tokenizer.image_token_index = image_token_index[0]
    elif base_model_prefix in ['qwen1.5','qwen2', 'qwen2-math', 'qwen2.5', 'qwen2.5-math', 'recur-qwen2.5']:
        # logging merges extra positional args into the message with %-formatting, so the
        # message needs a placeholder; without one the record fails to format and is lost.
        logger.info("tokenizer.pad_token: %s", tokenizer.pad_token)
        logger.info("tokenizer.bos_token_id: %s", tokenizer.bos_token_id)
        logger.info("tokenizer.eos_token_id: %s", tokenizer.eos_token_id)
        tokenizer.unk_token_id = tokenizer.pad_token_id
        tokenizer.add_special_tokens({'bos_token': '<|im_start|>'})  # set bos_token to '<|im_start|>'
    elif base_model_prefix in ['DeepSeek-R1-Distill-Qwen']:
        tokenizer.add_special_tokens({'pad_token': '<|fim_pad|>'})
        tokenizer.unk_token_id = tokenizer.pad_token_id
    elif base_model_prefix in ['granite-guardian-3.1']:
        tokenizer.add_special_tokens({'pad_token': '<fim_pad>'})
        tokenizer.unk_token_id = tokenizer.pad_token_id

    # padding must be distinguishable from sequence boundaries
    assert tokenizer.pad_token_id not in [tokenizer.bos_token_id, tokenizer.eos_token_id], RuntimeError(
        "pad_token collidess with bos_token/eos_token, which may result in unexpected output scores.")

    return tokenizer

def get_generation_config(data_args, generation_args, tokenizer):
    """Create the GenerationConfig used for generation during training/evaluation.

    The config is initialised from a deep copy of the fields of ``generation_args``, then the
    pad/bos/eos token ids of ``tokenizer`` and the ``role_tags``/``role_map`` of ``data_args``
    are written onto it. The resulting config is logged and returned.
    """
    # deep copy so the config never shares mutable values with generation_args
    init_kwargs = copy.deepcopy(vars(generation_args))
    config = GenerationConfig(**init_kwargs)

    # special token ids come from the tokenizer
    for token_attr in ("pad_token_id", "bos_token_id", "eos_token_id"):
        setattr(config, token_attr, getattr(tokenizer, token_attr))

    # role_tags and role_map come from the data arguments
    for role_attr in ("role_tags", "role_map"):
        setattr(config, role_attr, getattr(data_args, role_attr))

    logger.info(config)
    return config


def get_image_processor(base_model_prefix, image_processing_args):
    """Instantiate the image processor for a model family, if it has one.

    Args:
        base_model_prefix: Model family key; only ``'llava-mistral'`` gets an image processor.
        image_processing_args: Object holding the image processing options; only read for
            ``'llava-mistral'``.

    Returns:
        An instance of ``IMAGE_PROCESSOR[base_model_prefix]`` for ``'llava-mistral'``,
        otherwise ``None``.
    """
    # every other family gets no image processor
    if base_model_prefix not in ('llava-mistral',):
        return None

    processor_cls = IMAGE_PROCESSOR[base_model_prefix]
    opts = image_processing_args
    processor_kwargs = {
        "do_resize": opts.do_resize,
        "size": opts.image_size,
        "image_grid_pinpoints": opts.image_grid_pinpoints,
        "resample": opts.image_resample,
        "do_center_crop": opts.do_center_crop,
        "crop_size": opts.crop_size,
        "do_normalize": opts.do_normalize,
        "image_mean": opts.image_mean,
        "image_std": opts.image_std,
    }
    return processor_cls(**processor_kwargs)


def get_data(data_args, training_args, tokenizer, image_processor=None):
    """Choose the dataset processors and columns, then build the train/validation datasets.

    Two pipelines exist. The "WithDA" processors run as post-processors when data augmentations
    are configured, the task is not a subtask of the basic tasks, or ``force_postprocessor`` is
    set; in that case ``training_args.remove_unused_columns`` is forced to ``False``. Otherwise
    the plain tokenizing processors run as pre-processors.

    Side effects: ``data_args.do_tokenization`` is set to ``False`` for the ``SimpleRL`` task.

    Returns:
        ``(train_data, validation_data)`` as produced by ``prepare_datasets``.

    Raises:
        ValueError: If ``data_args.prompt_templates`` contains neither ``'qa'`` nor ``'chat'``.
    """
    basic_tasks = [
        "SimpleMathFormulation", "SimpleQA", "SimpleRL", "ValueLabelPrediction", "ValueLabelPredictionWithoutGeneration",
        "ValueLabelContrastiveLearning", "ValueLabelContrastiveLearningWithoutGeneration",
    ]
    if data_args.data_generation_task in {"SimpleRL"}:
        data_args.do_tokenization = False

    use_postprocessors = (
        data_args.data_augmentations
        or not data_args._is_subtask(data_args.data_generation_task, basic_tasks)
        or data_args.force_postprocessor
    )

    # resolve the template family once; 'qa' takes precedence over 'chat'
    templates = data_args.prompt_templates
    if 'qa' in templates:
        template_mode = 'qa'
    elif 'chat' in templates:
        template_mode = 'chat'
    else:
        raise ValueError("prompt_templates should contain either qa or chat; got {}".format(data_args.prompt_templates))

    if use_postprocessors:
        if template_mode == 'qa':
            preprocessor = eval_preprocessor = None
            postprocessor = TokenizedPromptProcessorWithDA(tokenizer, data_args)
            eval_postprocessor = TokenizedPromptProcessorWithDA(tokenizer, data_args, is_eval=True)
        else:
            preprocessor = eval_preprocessor = simple_image_preprocessor
            postprocessor = TokenizedChatProcessorWithDA(tokenizer, data_args, image_processor=image_processor)
            eval_postprocessor = TokenizedChatProcessorWithDA(
                tokenizer, data_args, image_processor=image_processor, is_eval=True)
        columns_to_read = [
            "instruction", "input", "language", "objective_description", "var_description", "output",
            "constraint_description", "incorrect output", "inferenced_logprobs",
            "wrong_instances_output", "wrong_instances_logprobs",
            "chat", "metadata", "solveResult", "solveResult_explanation",
            "value_labels", "image", "risk_name",
        ]
        # need to keep these columns for data augmentation.
        training_args.remove_unused_columns = False
    else:
        if template_mode == 'qa':
            preprocessor = TokenizedPromptProcessor(tokenizer, data_args)
            eval_preprocessor = TokenizedPromptProcessor(tokenizer, data_args, is_eval=True)
        else:
            preprocessor = TokenizedChatProcessor(tokenizer, data_args, image_processor=image_processor)
            eval_preprocessor = TokenizedChatProcessor(
                tokenizer, data_args, image_processor=image_processor, is_eval=True)
        postprocessor = eval_postprocessor = None
        columns_to_read = [
            "instruction", "input", "output", "language",
            "chat", "metadata", "solveResult", "solveResult_explanation",
            "value_labels", "image", "risk_name",
        ]

    # use dataset.cleanup_cache_files() to remove cache
    train_split, validation_split = prepare_datasets(
        data_args,
        preprocessor_fn=preprocessor,
        eval_preprocessor_fn=eval_preprocessor,
        postprocessor_fn=postprocessor,
        eval_postprocessor_fn=eval_postprocessor,
        columns_to_read=columns_to_read,
    )
    return train_split, validation_split


def get_data_collator(data_args, model_args, tokenizer):
    """Return the batch collator matching the training objective.

    ``DataCollatorForSeq2SeqForAllKeys`` is used when the DPO loss is enabled
    (``model_args.DPO_loss_weight > 0``) or when the data generation task matches one of the
    value-label tasks; otherwise ``DataCollatorForSeq2Seq`` is used. Both are built with
    padding enabled, lengths padded to a multiple of 8 and PyTorch tensors as output.
    """
    value_label_tasks = (
        "ValueLabelPrediction",
        "ValueLabelPredictionWithoutGeneration",
        "ValueLabelContrastiveLearning",
        "ValueLabelContrastiveLearningWithoutGeneration",
    )
    is_value_label_task = any(
        data_args._is_subtask(task_name, data_args.data_generation_task)
        for task_name in value_label_tasks
    )

    needs_all_keys = model_args.DPO_loss_weight > 0 or is_value_label_task
    collator_cls = DataCollatorForSeq2SeqForAllKeys if needs_all_keys else DataCollatorForSeq2Seq
    return collator_cls(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)


def load_model(
    base_model_prefix, 
    base_model_path, 
    tokenizer, 
    training_args, 
    model_args, 
    data_args, 
    prepare_model_for_peft=True, 
    model_path=None, 
    _print=None
    ):
    """Build the model config, load the base model and stack the optional wrappers on it.

    Steps, in order: build a family-specific config; load the weights with ``from_pretrained``;
    optionally prepare for int8 training; optionally wrap with PEFT (and load adapter weights
    from ``model_path``); optionally wrap with a value head (and load its weights from
    ``model_path``); apply the DeepSpeed ZeRO-3 leaf-module setting for Mixtral; report
    parameter counts on the main process; patch ``state_dict`` for old transformers versions.

    Args:
        base_model_prefix: Model family key; selects the config branch and the model class.
        base_model_path: Location handed to ``from_pretrained`` for both config and model.
        tokenizer: Its ``pad_token_id`` is copied into the config; for ``'llava-mistral'`` its
            ``image_token_index`` is checked against the config.
        training_args: Read for ``gradient_checkpointing``, ``deepspeed``, ``device_map``,
            ``local_rank``, ``estimate_mem_using_deepspeed_zero`` and ``world_size``.
        model_args: Read for dtype, 8-bit loading, flash attention, DPO and recurrence options
            and the adapter name. NOTE: ``model_args.load_in_8bit`` is set to ``False`` in place
            when ``training_args.deepspeed`` is truthy.
        data_args: Read for ``data_generation_task`` to detect value-label tasks.
        prepare_model_for_peft: When true, wrap the model with ``get_peft_model``.
        model_path: Optional checkpoint location to load PEFT / value-head weights from.
        _print: Callable used for status messages; defaults to the built-in ``print``.

    Returns:
        The loaded (and possibly wrapped) model.
    """
    if _print is None:
        _print = print
    model_config = None
    # ---- build the config; every branch looks up the config class the same way ----
    if base_model_prefix in {'llama', 'llama2', 'codellama'}:
        # llama family: pad token id plus the DPO loss options
        config_class = CONFIG_CLASS.get(base_model_prefix, AutoConfig)
        model_config = config_class.from_pretrained(base_model_path)
        model_config.pad_token_id = tokenizer.pad_token_id
        model_config.DPO_loss_weight = model_args.DPO_loss_weight
        model_config.DPO_loss_beta = model_args.DPO_loss_beta
        model_config.DPO_loss_inference_free = model_args.DPO_loss_inference_free
    elif base_model_prefix in {'llava-mistral'}:
        # llava: the pad token id lives on the nested text config
        config_class = CONFIG_CLASS.get(base_model_prefix, AutoConfig)
        model_config = config_class.from_pretrained(base_model_path)
        model_config.text_config.pad_token_id = tokenizer.pad_token_id
        # tokenizer and config must agree on the id of the image placeholder token
        assert tokenizer.image_token_index == model_config.image_token_index
    elif base_model_prefix == 'qwen':
        # imported lazily so the dependency is only needed for this family
        from models.qwen.megatron_utils import compile_megatron_dependencies
        config_class = CONFIG_CLASS.get(base_model_prefix, AutoConfig)
        model_config = config_class.from_pretrained(base_model_path)
        model_config = copy.deepcopy(model_config)
        # precision flags are derived from the requested torch dtype
        model_config.fp16 = model_args.torch_dtype == "float16"
        model_config.bf16 = model_args.torch_dtype == "bfloat16"
        model_config.tensor_model_parallel_size = 1
        model_config.micro_batch_size = 1
        model_config.masked_softmax_fusion = True
        model_config.gradient_accumulation_fusion = False
        model_config.DPO_loss_weight = model_args.DPO_loss_weight
        model_config.DPO_loss_beta = model_args.DPO_loss_beta
        model_config.DPO_loss_inference_free = model_args.DPO_loss_inference_free
        compile_megatron_dependencies(model_config)
    elif base_model_prefix == 'recur-qwen2.5':
        # recurrent qwen: copy the recurrence options from model_args onto the config
        config_class = CONFIG_CLASS.get(base_model_prefix, AutoConfig)
        model_config = config_class.from_pretrained(base_model_path)
        model_config.pad_token_id = tokenizer.pad_token_id
        model_config.recur_strategy = model_args.recur_strategy
        model_config.recur_times = model_args.recur_times
        model_config.num_prelude_layers = model_args.num_prelude_layers
        model_config.num_coda_layers = model_args.num_coda_layers
        model_config.input_injection_type = model_args.input_injection_type
        model_config.state_init_strategy = model_args.state_init_strategy
        model_config.init_std = model_args.init_std
        model_config.attn_to_recur_key_values = model_args.attn_to_recur_key_values
        model_config.ln_after_recur = model_args.ln_after_recur
    else:
        # default: only the pad token id is set
        config_class = CONFIG_CLASS.get(base_model_prefix, AutoConfig)
        model_config = config_class.from_pretrained(base_model_path)
        model_config.pad_token_id = tokenizer.pad_token_id
    # set gradient_checkpointing in model_config so that transformers (>=4.35)
    # could avoid the bug of no trainable parameters when using lora.
    model_config.gradient_checkpointing = training_args.gradient_checkpointing
    # load model
    if training_args.deepspeed:
        _print("Deepspeed does not work with load_in_8bit=True, so we set to it False.")
        # NOTE: mutates the caller's model_args
        model_args.load_in_8bit = False
    if model_args.DPO_loss_weight > 0:
        # plain indexing: a family without an entry in MODEL_CLASS_FOR_CL raises KeyError
        model_class = MODEL_CLASS_FOR_CL[base_model_prefix]
    else:
        model_class = MODEL_CLASS.get(base_model_prefix, AutoModelForCausalLM)
    model = model_class.from_pretrained(
        base_model_path,
        load_in_8bit=model_args.load_in_8bit,
        torch_dtype=DTYPE_CLASS[model_args.torch_dtype],
        device_map=None if training_args.deepspeed else training_args.device_map,  # Deepspeed does not work with device_map
        config=model_config,
        use_flash_attention_2=model_args.use_flash_attn,
    )
    model.config.use_cache = False  # do not return the last key/values attentions
    if model_args.load_in_8bit:
        model = prepare_model_for_int8_training(model)

    if prepare_model_for_peft:
        # NOTE(review): model_args is passed where a PEFT config is expected — presumably
        # get_peft_model is a project wrapper or model_args is config-compatible; confirm.
        model = get_peft_model(model, model_args, adapter_name=model_args.adapter_name)
        # properly load LoRA checkpoints
        if model_path:
            model, valid_ckpt_path = load_peft_weights_into_model(model, model_path, model_args.adapter_name)
            _print('PEFT model weights loaded from {}.'.format(valid_ckpt_path))

    # true when the task matches any of the four value-label task names
    _is_ValueLabelPrediction_task = (
        data_args._is_subtask("ValueLabelPrediction", data_args.data_generation_task)
        or data_args._is_subtask("ValueLabelPredictionWithoutGeneration", data_args.data_generation_task)
        or data_args._is_subtask("ValueLabelContrastiveLearning", data_args.data_generation_task)
        or data_args._is_subtask("ValueLabelContrastiveLearningWithoutGeneration", data_args.data_generation_task)
    )
    if _is_ValueLabelPrediction_task:
        # wrap the (possibly PEFT-wrapped) model with a value head
        model = MyAutoModelForCausalLMWithValueHead(model, model_args)
        # NOTE(review): set unconditionally, even when prepare_model_for_peft is False — confirm intended.
        model.is_peft_model = True
        if model_path:
            model, valid_ckpt_path = load_rl_weights_into_model(model, model_path)
            _print('Value head weights loaded from {}.'.format(valid_ckpt_path))

    if is_deepspeed_zero3_enabled():
        # NOTE(review): the string below is a bare expression used as a comment; its last
        # sentence is unfinished. The code that follows marks MixtralSparseMoeBlock as a
        # ZeRO-3 leaf module for the mixtral family.
        """ To optimize parameter partitioning, Deepspeed traces the sequence of submodule/parameter fetches 
        during a training/inference iteration, and stores this fetch sequence in a trace cache. In mixture-of-
        experts models like Mixtral
        """
        from deepspeed.utils import set_z3_leaf_modules
        if base_model_prefix == "mixtral":
            from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
            set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
            _print("Set MixtralSparseMoeBlock as leaf modules in DeepSpeed Zero3 mode.")

    # reporting is done on the main process only
    if is_main_process(training_args.local_rank):
        count_parameters(model)
        # model.print_trainable_parameters()
        if training_args.estimate_mem_using_deepspeed_zero:
            from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
            estimate_zero3_model_states_mem_needs_all_live(
                model, num_gpus_per_node=training_args.world_size, num_nodes=1)

    # configure state_dict to keep all those with lora prefix (and optionally trainable biases)
    # at transformers v4.28.1, this has to be done explicitly;
    # at transformers v4.31.0, this is done implicitly by transformers, so we no longer need it.
    if version.parse(transformers.__version__) < version.parse('4.29'):
        old_state_dict = model.state_dict
        # bind a replacement state_dict method onto this instance; it passes the original
        # full state dict through get_peft_model_state_dict
        model.state_dict = (
            lambda self, *_, **__: get_peft_model_state_dict(
                self, old_state_dict()
            )
        ).__get__(model, type(model))

    return model


def get_trainer(model, training_args, model_args, generation_config, train_data, validation_data, data_collator, tokenizer):
    """Assemble callbacks, the metric function and the Trainer.

    Args:
        model: Model to train/evaluate.
        training_args: Read for ``halt_step``, ``save_strategy``, ``do_train``, ``do_eval``,
            ``evaluation_method`` and ``validation_metrics``; also passed to the trainer.
        model_args: Passed to ``PeftSaveCallback``; ``DPO_loss_weight > 0`` selects the DPO trainer.
        generation_config: Forwarded to the trainer.
        train_data: Training dataset; only passed on when ``training_args.do_train`` is set.
        validation_data: Validation dataset; only passed on when ``training_args.do_eval`` is set.
        data_collator: Batch collator forwarded to the trainer.
        tokenizer: Forwarded to the trainer and the metric function; its ``padding_side`` tells
            the metric function on which side sequences are padded.

    Returns:
        The configured trainer instance.
    """
    callbacks = []
    if training_args.halt_step > -1:
        callbacks.append(HaltTrainingCallback(halt_step=training_args.halt_step))
    if training_args.save_strategy in ['steps', 'epoch']:
        callbacks.append(PeftSaveCallback(model_args))

    # Setup evaluation
    compute_metrics = None
    if validation_data and training_args.do_eval:
        if training_args.evaluation_method == 'autoregressive':
            metrics_class = Seq2SeqMetricsOnGenerationSeqIDs
        else:
            metrics_class = Seq2SeqMetricsOnSeqIDs
        # Take the padding side from the tokenizer itself: `padding_side` is not a name in this
        # function's scope, and the tokenizer is the authority on how batches are padded.
        compute_metrics = metrics_class(
            training_args.validation_metrics, tokenizer=tokenizer, padding_side=tokenizer.padding_side)

    # imported lazily; the DPO variant is only needed when the DPO loss is enabled
    if model_args.DPO_loss_weight > 0:
        from utils.trainer_utils import Zero3SaveTrainable16bitModelDPOTrainer as trainer_class
    else:
        from utils.trainer_utils import Zero3SaveTrainable16bitModelTrainer as trainer_class

    trainer = trainer_class(
        model=model,
        args=training_args,
        train_dataset=train_data if training_args.do_train else None,
        eval_dataset=validation_data if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        compute_metrics=compute_metrics,
        evaluation_method=training_args.evaluation_method,
        generation_config=generation_config,
    )

    return trainer

        
def do_training(args):
    """Run the whole fine-tuning pipeline from command-line style arguments.

    Step 1 parses the arguments and configures logging. Step 2 builds the tokenizer, generation
    config, image processor, datasets, collator, model and trainer. Step 3 (only when
    ``training_args.do_train`` is set) trains, saves the model to ``training_args.output_dir``
    and, on the main process, writes ``train_results.txt`` and ``trainer_state.json`` there.

    Args:
        args: List of argument strings (e.g. ``sys.argv[1:]``), forwarded to ``parse_args``.
    """
    # Step 1. Prepare arguments and logging
    model_args, data_args, training_args, generation_args, image_processing_args = parse_args(args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
    )

    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()

    # print
    logger.info('Available cuda devices: {}'.format(torch.cuda.device_count()))
    logger.info('Pytorch version: {}; transformers version {}'.format(torch.__version__, transformers.__version__))
    logger.info("\nData related arguments:\n %s", data_args)
    logger.info("\nModel arguments:\n %s", model_args)
    logger.info("\nTraining/evaluation parameters:\n %s", training_args)
    logger.info("\nTraining generation parameters:\n %s", generation_args)

    # Step 2. Configure tokenizer, dataset, model and training
    base_model_path = os.path.join(CKPT_FOLDER, model_args.base_model)
    # the family prefix is everything before the last '-' of the base model name
    base_model_prefix = model_args.base_model.rsplit('-', 1)[0]
    # Set seed before initializing model.
    set_seed(training_args.seed)
    # set up tokenizer
    tokenizer = get_tokenizer(base_model_prefix, base_model_path)
    generation_config = get_generation_config(data_args, generation_args, tokenizer)
    # image processor
    image_processor = get_image_processor(base_model_prefix, image_processing_args)
    # load datasets
    train_data, validation_data = get_data(data_args, training_args, tokenizer, image_processor)
    data_collator = get_data_collator(data_args, model_args, tokenizer)
    logger.info("========================= Datasets loaded. =========================")
    # configure model_config, load model, load lora, etc
    # only an existing directory counts as a checkpoint to resume from
    model_path = (
        model_args.model_name_or_path
        if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
        else None
    )
    model = load_model(base_model_prefix, base_model_path, tokenizer, training_args, model_args, data_args, model_path=model_path, _print=logger.info)
    logger.info("========================= Model configured. =========================")

    trainer = get_trainer(model, training_args, model_args, generation_config, train_data, validation_data, data_collator, tokenizer)
    logger.info("========================= Trainer configured. =========================")

    # Step 3. Train the model
    if training_args.do_train:
        # rename trainer_state.json, so to avoid trainer.train loading trainer_state.json to check epochs_trained
        ckpt_file_handler = RenameCKPTFiles(model_path)
        hide_trainer_state = model_path is not None and training_args.ignore_trainer_state
        if hide_trainer_state:
            ckpt_file_handler.rename_files()

        try:
            train_result = trainer.train(resume_from_checkpoint=model_path)
            model.save_pretrained(training_args.output_dir)
        finally:
            # rename trainer_state.json back, also when training fails or is interrupted, so
            # the source checkpoint directory is never left in its renamed state
            if hide_trainer_state:
                ckpt_file_handler.restore_file_names()

        output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
        if is_main_process(training_args.local_rank):
            with open(output_train_file, "w") as writer:
                for key, value in sorted(train_result.metrics.items()):
                    writer.write(f"{key} = {value}\n")

            # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
            trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))

        logger.info("========================= Training completed. =========================")


# Script entry point: forward the command-line arguments (program name excluded) to do_training.
if __name__ == "__main__":
    do_training(sys.argv[1:])
