# [dev] WANDB_MODE=disabled PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=16 --time=30000 accelerate launch --config_file scripts/accelerate_configs/single_gpu.yaml src/train.py --model_name_or_path /mnt/lustrenew/mllm_safety-shared/models/huggingface/meta-llama/Llama-3.2-11B-Vision
import re
import json
import pathlib
from dataclasses import dataclass

import torch
import transformers
import trl
import accelerate

from src import utils


@dataclass
class ScriptArguments:
    """Script-specific options, parsed by `trl.TrlParser` in the `__main__` block."""

    # Path of the data config file; passed to `utils.parse_data_config`.
    data_config_path: str = "data/animals/config_image.yaml"
    # Override string passed to `utils.parse_data_config` together with the config path.
    data_overwrite_args: str = "" # e.g. --data_overwrite_args "data.train[0].images_dirs[0]=/new/path/to/images,..."
    # NOTE(review): not referenced anywhere else in this file; presumably a worker
    # count for dataset preprocessing — confirm before relying on it.
    num_proc: int = 8
    # When True, the collate functions also set the labels of the prompt tokens
    # to -100, in addition to the padding tokens.
    mask_prompt: bool = False

@dataclass
class SFTConfig(trl.SFTConfig):
    """`trl.SFTConfig` with this script's default training settings.

    Only default values are overridden here. The class is parsed by
    `trl.TrlParser` in the `__main__` block, which afterwards also sets
    `gradient_checkpointing_kwargs`, `remove_unused_columns` and
    `dataset_kwargs` on the parsed instance.
    """

    output_dir: str = "models/tmp"
    report_to: str = "wandb"
    overwrite_output_dir: bool = True
    seed: int = 42  # also passed to `transformers.set_seed` in the `__main__` block
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-5
    lr_scheduler_type: str = "cosine"
    bf16: bool = True
    num_train_epochs: float = 20
    logging_steps: float = 1
    eval_strategy: str = "epoch"
    # Checkpoint saving is disabled by default; "epoch" is the alternative noted inline.
    save_strategy: str = "no" # "epoch"
    save_only_model: bool = True
    eval_on_start: bool = True


if __name__ == "__main__":
    # Parse the three argument groups (script, training, model).
    parser = trl.TrlParser((ScriptArguments, SFTConfig, trl.ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    transformers.set_seed(training_args.seed)

    # These settings are always forced, whatever was parsed. The custom collator
    # below needs the raw dataset columns, so column pruning and the trainer's
    # own dataset preparation are switched off.
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    ################
    # Model, Tokenizer & Processor
    ################
    # "auto" and None are passed through unchanged; any other value names a torch dtype.
    if model_args.torch_dtype in ["auto", None]:
        torch_dtype = model_args.torch_dtype
    else:
        torch_dtype = getattr(torch, model_args.torch_dtype)

    quantization_config = trl.get_quantization_config(model_args)
    # A k-bit device map is only requested when the model is quantized.
    device_map = None
    if quantization_config is not None:
        device_map = trl.get_kbit_device_map()
    model_kwargs = {
        "revision": model_args.model_revision,
        "attn_implementation": model_args.attn_implementation,
        "torch_dtype": torch_dtype,
        "device_map": device_map,
        "quantization_config": quantization_config,
    }

    # The concrete model class is looked up by the first architecture name
    # listed in the checkpoint's config.
    model_config = transformers.PretrainedConfig.from_pretrained(model_args.model_name_or_path)
    # if model_config.architectures[0] == "Gemma3ForConditionalGeneration":
    #     model_kwargs["attn_implementation"] = "eager"
    model_cls = getattr(transformers, model_config.architectures[0])
    model = model_cls.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )

    # Processor options; some are model-family specific.
    processor_kwargs = {"padding_side": "right"}
    qwen_vl_classes = (
        transformers.Qwen2VLForConditionalGeneration,
        transformers.Qwen2_5_VLForConditionalGeneration,
    )
    if isinstance(model, qwen_vl_classes):
        processor_kwargs.update(min_pixels=32*28*28, max_pixels=128*28*28)
    if isinstance(model, transformers.LlavaForConditionalGeneration):
        processor_kwargs["add_prefix_space"] = True
    processor = transformers.AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        **processor_kwargs,
    )

    ################
    # Create a data collator to encode text and image pairs
    ################

    def collate_fn_text(examples):
        """Collate text-only examples into a padded batch with causal-LM labels.

        Args:
            examples: Sequence of dataset rows. Each row must provide
                ``"prompt_response"`` (the full training text) and, when
                ``script_args.mask_prompt`` is set, ``"prompt"``.

        Returns:
            The processor output for the batch with an added ``"labels"``
            tensor: a copy of ``input_ids`` in which padding positions (and,
            optionally, prompt positions) are set to -100.
        """
        texts = [example["prompt_response"] for example in examples]
        batch = processor(text=texts, return_tensors="pt", padding=True)
        pad_token_id = processor.tokenizer.pad_token_id

        labels = batch["input_ids"].clone()
        # Mask padding by position (attention mask) rather than by token id:
        # the pad id can be shared with a real token (commonly EOS), and such
        # tokens inside the text must keep their labels.
        if "attention_mask" in batch:
            labels[batch["attention_mask"] == 0] = -100
        else:
            labels[labels == pad_token_id] = -100

        if script_args.mask_prompt:
            prompts = [example["prompt"] for example in examples]
            prompt_batch = processor(text=prompts, return_tensors="pt", padding=True)
            # Number of real (non-padding) tokens in each prompt.
            if "attention_mask" in prompt_batch:
                prompt_lens = prompt_batch["attention_mask"].sum(-1)
            else:
                prompt_lens = (prompt_batch["input_ids"] != pad_token_id).sum(-1)
            # Relies on right padding and on the prompt tokenizing to a prefix
            # of the full text, so the first `prompt_len` positions are the prompt.
            for i, prompt_len in enumerate(prompt_lens.tolist()):
                labels[i, :prompt_len] = -100

        batch["labels"] = labels
        return batch

    def collate_fn_image_text(examples):
        """Collate image-text examples into a padded batch with causal-LM labels.

        Args:
            examples: Sequence of dataset rows. Each row must provide
                ``"image"`` and ``"prompt_response"``; the latter is formatted
                with the processor-specific ``image_prefix``. When
                ``script_args.mask_prompt`` is set, ``"prompt"`` is required
                too and is formatted the same way.

        Returns:
            The processor output for the batch with an added ``"labels"``
            tensor: a copy of ``input_ids`` in which padding positions, image
            tokens and (optionally) prompt positions are set to -100.

        Raises:
            NotImplementedError: If the processor type has no image prefix
                defined here.
        """
        # Pick the image placeholder text and image container per processor family.
        if isinstance(processor, transformers.MllamaProcessor):
            image_prefix = "<|image|><|begin_of_text|>"
            images = [[example["image"]] for example in examples]
        elif isinstance(processor, transformers.Gemma3Processor):
            image_prefix = "<start_of_image> "
            images = [[example["image"]] for example in examples]
        elif isinstance(processor, (transformers.LlavaProcessor, transformers.LlavaNextProcessor)):
            image_prefix = "USER: <image>\n ASSISTANT:"
            images = [[example["image"]] for example in examples]
        elif isinstance(processor, (transformers.Qwen2VLProcessor, transformers.Qwen2_5_VLProcessor)):
            import qwen_vl_utils
            image_prefix = "<|vision_start|><|image_pad|><|vision_end|>"
            images = [[qwen_vl_utils.fetch_image({"image": example["image"]})] for example in examples]
        else:
            raise NotImplementedError(f"Unsupported processor type: {type(processor).__name__}")

        texts = [example["prompt_response"].format(image_prefix=image_prefix) for example in examples]
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
        pad_token_id = processor.tokenizer.pad_token_id

        labels = batch["input_ids"].clone()
        # Mask padding by position (attention mask) rather than by token id:
        # the pad id can be shared with a real token (commonly EOS), and such
        # tokens inside the text must keep their labels.
        if "attention_mask" in batch:
            labels[batch["attention_mask"] == 0] = -100
        else:
            labels[labels == pad_token_id] = -100

        if script_args.mask_prompt:
            # TODO: still assumes the formatted prompt tokenizes to an exact
            # prefix of the formatted full text; verify per model.
            prompts = [example["prompt"].format(image_prefix=image_prefix) for example in examples]
            prompt_batch = processor(text=prompts, images=images, return_tensors="pt", padding=True)
            # Number of real (non-padding) tokens in each prompt.
            if "attention_mask" in prompt_batch:
                prompt_lens = prompt_batch["attention_mask"].sum(-1)
            else:
                prompt_lens = (prompt_batch["input_ids"] != pad_token_id).sum(-1)
            # Relies on right padding: the first `prompt_len` positions are the prompt.
            for i, prompt_len in enumerate(prompt_lens.tolist()):
                labels[i, :prompt_len] = -100

        # Ignore the image tokens in the loss computation (model specific).
        if isinstance(processor, (transformers.Qwen2VLProcessor, transformers.Qwen2_5_VLProcessor)):
            image_token_names = ("<|vision_start|>", "<|image_pad|>", "<|vision_end|>")
        else:
            image_token_names = (processor.image_token,)
        for token in image_token_names:
            image_token_id = processor.tokenizer.convert_tokens_to_ids(token)
            labels[labels == image_token_id] = -100

        batch["labels"] = labels
        return batch

    def collate_fn(examples):
        """Dispatch a batch to the image-text or the text-only collator.

        The decision uses the first example only: if its ``"prompt_response"``
        contains the ``{image_prefix}`` placeholder, the whole batch goes to
        `collate_fn_image_text`; otherwise it goes to `collate_fn_text`.
        """
        has_image_placeholder = "{image_prefix}" in examples[0]["prompt_response"]
        collate = collate_fn_image_text if has_image_placeholder else collate_fn_text
        return collate(examples)

    ################
    # Dataset
    ################
    # Per the Accelerate docs, the local main process runs this block first and
    # the other processes enter it only after the main one has left.
    main_process_first = accelerate.PartialState().local_main_process_first()
    with main_process_first:
        data_config = utils.parse_data_config(
            script_args.data_config_path,
            script_args.data_overwrite_args,
        )
        split_configs = utils.parse_train_and_eval_config(data_config)
        train_configs, eval_configs = split_configs
        dataset = utils.get_train_dataset(train_configs)

    ################
    # Training
    ################
    class RankEvalCallback(transformers.trainer_callback.TrainerCallback):
        """Run the project's rank/prob evaluations after each trainer evaluation.

        Results are written, on the main process only, to
        ``<output_dir>/checkpoint-<step>/eval/log.json`` together with the
        latest entry of the trainer's log history.
        """

        def on_evaluate(
            self,
            args: transformers.TrainingArguments,
            state: transformers.trainer_callback.TrainerState,
            control: transformers.trainer_callback.TrainerControl,
            **kwargs
        ):
            """Evaluate every configured eval set and dump the results as JSON.

            Which evaluations run is decided by ``data_config["eval_mode"]``:
            ``"rank"`` enables `eval_rank` and ``"prob"`` enables `eval_prob`.
            Entries of `eval_configs` that are None are skipped.
            """
            # from src.eval_fast import eval_rank
            from src.eval import eval_rank, eval_prob

            model = kwargs["model"]
            eval_mode = data_config["eval_mode"]

            # (result key, evaluation function) pairs, in the order they are run.
            evaluators = (("rank", eval_rank), ("prob", eval_prob))

            results = {}
            for mode_name, evaluate in evaluators:
                if mode_name not in eval_mode:
                    continue
                mode_results = {}
                results[mode_name] = mode_results
                for eval_idx, eval_config in enumerate(eval_configs):
                    if eval_config is None:
                        continue
                    partial_results = evaluate(
                        model=model,
                        processor=processor,
                        data_config=eval_config,
                        per_device_eval_batch_size=args.per_device_eval_batch_size,
                    )
                    # Prefix each template key with the index of its eval config.
                    mode_results.update(
                        {
                            f"eval-{eval_idx}.{template_key}": template_result
                            for template_key, template_result in partial_results.items()
                        }
                    )

            # Only the main process writes the results file.
            if not accelerate.PartialState().is_main_process:
                return

            # The latest log entry is expected to be the evaluation log.
            latest_log = state.log_history[-1]
            assert "eval_loss" in latest_log
            results["log_history"] = latest_log

            step = int(latest_log["step"])
            results_path = pathlib.Path(args.output_dir, f"checkpoint-{step}", "eval", "log.json")
            results_path.parent.mkdir(parents=True, exist_ok=True)
            with results_path.open("w", encoding="utf-8") as f:
                json.dump(results, f, ensure_ascii=False, indent=4)


    # The training set is also used as the trainer's eval set. The callback is
    # passed as a class; the Transformers docs state that a callback class is
    # instantiated by the trainer.
    peft_config = trl.get_peft_config(model_args)
    trainer = trl.SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=dataset,
        data_collator=collate_fn,
        processing_class=processor.tokenizer,
        peft_config=peft_config,
        callbacks=[RankEvalCallback],
    )

    trainer.train()

    # Record the final training arguments next to the outputs (main process only).
    if accelerate.PartialState().is_main_process:
        args_path = pathlib.Path(training_args.output_dir) / "training_args.json"
        with args_path.open("w", encoding="utf-8") as f:
            json.dump(training_args.to_dict(), f, ensure_ascii=False, indent=4)
