# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "trl",
#     "peft",
#     "Pillow>=9.4.0",
#     "torchvision",
#     "trackio",
#     "kernels",
# ]
# ///

"""
Without dataset streaming:

```
accelerate launch examples/scripts/dpo_vlm.py \
    --dataset_name HuggingFaceH4/rlaif-v_formatted \
    --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 32 \
    --dataset_num_proc 32 \
    --output_dir dpo_idefics_rlaif-v \
    --dtype bfloat16 \
    --gradient_checkpointing \
    --use_peft \
    --lora_target_modules all-linear
```

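With dataset streaming (NOTE: this script calls `load_dataset` without a streaming option, so `--dataset_streaming`
currently has no effect here):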

```
accelerate launch examples/scripts/dpo_vlm.py \
    --dataset_name HuggingFaceH4/rlaif-v_formatted \
    --dataset_streaming \
    --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
    --per_device_train_batch_size 2 \
    --max_steps 100 \
    --gradient_accumulation_steps 32 \
    --dataset_num_proc 32 \
    --output_dir dpo_idefics_rlaif-v \
    --dtype bfloat16 \
    --gradient_checkpointing \
    --use_peft \
    --lora_target_modules all-linear
```
"""

import os
from dataclasses import dataclass, field
import torch
from datasets import load_dataset, Image, Sequence
from transformers import AutoModelForImageTextToText, AutoProcessor
from peft import  PeftModel, PeftConfig
from PIL import Image as PILImage
import random

from trl import (
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from peft import PeftModel

from config import (
    MODULE_KEYWORDS, 
    DATASET_KEYWORDS
)
from logging_utils import TrainerLogger, rank0_print
from vco_trainer import (
    VCOConfig,
    VCOTrainer,
    DataCollatorForVisualContrastivePreference,
    make_vco_data
)



@dataclass
class DataArguments(ScriptArguments):
    """Dataset command-line options for this script, extending TRL's `ScriptArguments`."""

    # Cap on the number of training examples. The script only subsamples the train split when this
    # value is truthy, so the default (None) keeps the whole split.
    max_samples: int = field(default=None, metadata=dict(help="Maximum number of samples to use for training."))


@dataclass
class ModelArguments(ModelConfig):
    """Model command-line options for this script, extending TRL's `ModelConfig`."""

    # Key used by the script to index `MODULE_KEYWORDS` for family-specific settings.
    model_family_id: str = field(default=None, metadata=dict(help="Model family ID."))

    # The two freeze switches are only consulted by the script when PEFT is not in use.
    freeze_vit: bool = field(
        default=False,
        metadata=dict(help="Whether to freeze ViT parameters in full finetuning."),
    )
    freeze_mlp: bool = field(
        default=False,
        metadata=dict(help="Whether to freeze MLP parameters in full finetuning."),
    )


# Data-file extensions whose `datasets` packaged-builder name differs from the extension itself.
LOADER_FORMAT_ALIASES = {"jsonl": "json"}


def get_family_settings(module_keywords, family_id):
    """Return the settings registered for ``family_id`` in ``module_keywords``.

    Args:
        module_keywords: Mapping from model-family identifier to that family's settings.
        family_id: Identifier supplied through ``--model_family_id`` (``None`` when omitted).

    Returns:
        The object stored under ``family_id``.

    Raises:
        ValueError: If ``family_id`` is not a key of ``module_keywords``.
    """
    if family_id not in module_keywords:
        known_families = ", ".join(sorted(str(key) for key in module_keywords))
        raise ValueError(
            f"Unknown `model_family_id` ({family_id!r}). Pass `--model_family_id` with one of: {known_families}."
        )
    return module_keywords[family_id]


def resolve_loader_format(data_file):
    """Return the `load_dataset` builder name for ``data_file``.

    The name is the text after the last ``.`` of the path, except for extensions listed in
    ``LOADER_FORMAT_ALIASES`` (e.g. ``.jsonl`` files are read by the ``json`` builder).
    """
    extension = data_file.split(".")[-1]
    return LOADER_FORMAT_ALIASES.get(extension, extension)


def filter_grid_pinpoints(grid_pinpoints, max_side):
    """Return the grids from ``grid_pinpoints`` whose two sides are both ``<= max_side``."""
    return [grid for grid in grid_pinpoints if grid[0] <= max_side and grid[1] <= max_side]


def freeze_modules(model, module_paths, label):
    """Disable gradients for every attribute path in ``module_paths``, resolved relative to ``model``.

    Args:
        model: Object the paths are resolved against.
        module_paths: Iterable of attribute-path strings (taken from ``MODULE_KEYWORDS``).
        label: Human-readable component name used in the log header.
    """
    rank0_print(f"{label} is freezed... including:")
    for module in module_paths:
        rank0_print(f"\t{module}")
        # NOTE(security): `eval` resolves the path against the local name `model`. The paths come from the
        # project's own MODULE_KEYWORDS table; never feed externally supplied strings through this call.
        eval(f"model.{module}").requires_grad_(False)


def sync_grid_pinpoints(target_model, grid_pinpoints, label):
    """Write ``grid_pinpoints`` to ``target_model.config`` when that config already defines the attribute."""
    config = getattr(target_model, "config", None)
    if config is not None and hasattr(config, "image_grid_pinpoints"):
        config.image_grid_pinpoints = grid_pinpoints
        rank0_print(f"Sync: Updated {label}.config image_grid_pinpoints.")


# Script flow: parse arguments -> load model / reference model / processor -> optionally restrict the image
# grid -> build the dataset -> train with VCOTrainer -> save the model (or adapter) and the processor.
if __name__ == "__main__":
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    parser = TrlParser((DataArguments, VCOConfig, ModelArguments))
    data_args, training_args, model_args = parser.parse_args_and_config()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

    # Validate the family id before any model is loaded so a missing/unknown value fails immediately with a
    # readable message instead of a bare KeyError later on.
    family_settings = get_family_settings(MODULE_KEYWORDS, model_args.model_family_id)

    logger = TrainerLogger(training_args.logging_dir)

    ################
    # Model & Processor
    ################
    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)

    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        dtype=dtype,
    )
    quantization_config = get_quantization_config(model_args)
    if quantization_config is not None:
        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
        model_kwargs["device_map"] = get_kbit_device_map()
        model_kwargs["quantization_config"] = quantization_config

    model = AutoModelForImageTextToText.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
    )

    if not model_args.use_peft:
        # Freeze the requested components. The path lists are only looked up when the matching flag is set, so
        # a family without e.g. a "vision_projector" entry still works when nothing has to be frozen.
        if model_args.freeze_vit:
            freeze_modules(model, family_settings["vision_encoder"], "Vision encoder")
        if model_args.freeze_mlp:
            freeze_modules(model, family_settings["vision_projector"], "Vision projector")

        # print trainable parameters
        rank0_print("Trainable parameters:")
        for name, param in model.named_parameters():
            if param.requires_grad:
                rank0_print(f"\t{name}")

    # A separate reference model is loaded only for full fine-tuning with a non-SFT loss.
    peft_config = get_peft_config(model_args)
    if peft_config is None and training_args.loss_type != 'sft':
        ref_model = AutoModelForImageTextToText.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )
    else:
        ref_model = None

    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    # fix the potential inconsistent chat template
    processor.tokenizer.chat_template = processor.chat_template
    pad_token = training_args.pad_token or processor.tokenizer.pad_token or processor.tokenizer.eos_token
    pad_token_id = processor.tokenizer.convert_tokens_to_ids(pad_token)
    if pad_token_id is None:
        raise ValueError(
            f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
            f"`processing` ({processor.__class__.__name__}). Ensure that the `pad_token` exists "
            "in the vocabulary before using it as a padding token."
        )

    if "max_grid_side" in family_settings:
        max_side = family_settings["max_grid_side"]

        # 1. Fetch the original grids from the image processor.
        image_processor = getattr(processor, "image_processor", None)
        if image_processor is not None and hasattr(image_processor, "image_grid_pinpoints"):
            # 2. Drop every grid that has a side longer than max_side.
            new_grids = filter_grid_pinpoints(image_processor.image_grid_pinpoints, max_side)
            if not new_grids:
                raise ValueError(
                    f"`max_grid_side`={max_side} removes every entry of `image_grid_pinpoints`; "
                    "increase it so that at least one grid remains."
                )

            # 3. Update the processor configuration.
            image_processor.image_grid_pinpoints = new_grids
            rank0_print(f"Auto-filtered PROCESSOR grid points. Max side restricted to {max_side}. Count: {len(new_grids)}")

            # 4. Key fix: mirror the change in the model config. Per the original author's note this is the
            #    core of resolving the "split_with_sizes" error.
            sync_grid_pinpoints(model, new_grids, "MODEL")

            # 5. Defensive: for LoRA/PEFT-wrapped models the base model's config has to change as well.
            if hasattr(model, "base_model"):
                sync_grid_pinpoints(model.base_model, new_grids, "BASE_MODEL")

            # 6. The reference model is loaded separately and receives the same processor outputs, so its
            #    config must agree with the filtered grids too.
            if ref_model is not None:
                sync_grid_pinpoints(ref_model, new_grids, "REF_MODEL")
                if hasattr(ref_model, "base_model"):
                    sync_grid_pinpoints(ref_model.base_model, new_grids, "REF_BASE_MODEL")
        else:
            rank0_print("Warning: Could not find 'image_grid_pinpoints' in processor. Skipping resize.")

    ################
    # Dataset
    ################
    if data_args.dataset_name in DATASET_KEYWORDS:
        dataset_spec = DATASET_KEYWORDS[data_args.dataset_name]
        if 'data_files' in dataset_spec:
            # Local files: the builder is chosen from the extension of the train file.
            file_format = resolve_loader_format(dataset_spec['data_files']['train'])
            dataset = load_dataset(file_format, data_files=dataset_spec['data_files'])
        else:
            dataset = load_dataset(dataset_spec['path'])

        # The target-span columns are only looked up when response-token masking is enabled.
        use_spans = training_args.use_resp_token_mask
        fn_kwargs = {
            'prompt_col': dataset_spec['prompt_col'],
            'image1_col': dataset_spec['image1_col'],
            'image2_col': dataset_spec['image2_col'],
            'resp1_col': dataset_spec['resp1_col'],
            'resp2_col': dataset_spec['resp2_col'],
            'resp1_target_span_col': dataset_spec['resp1_target_span_col'] if use_spans else None,
            'resp2_target_span_col': dataset_spec['resp2_target_span_col'] if use_spans else None,
            'training_args': training_args,
            'processor': processor,
        }
        dataset = dataset.map(
            make_vco_data,
            batched=True,
            num_proc=training_args.dataset_num_proc,
            remove_columns=dataset['train'].column_names,
            fn_kwargs=fn_kwargs,
        )
    else:
        # Unregistered names are loaded as-is and passed to the trainer without `make_vco_data`.
        dataset = load_dataset(data_args.dataset_name)

    if data_args.max_samples:
        train_split = dataset['train']
        # Clamp to the split size: selecting indices past the end of the split raises an IndexError.
        sample_count = min(data_args.max_samples, len(train_split))
        dataset['train'] = train_split.shuffle(seed=42).select(range(sample_count))

    ################
    # Training
    ################
    if training_args.vco_type == 'v-dpo':
        training_args.ddp_find_unused_parameters = True

    trainer = VCOTrainer(
        model,
        ref_model=ref_model,
        processing_class=processor,
        args=training_args,
        data_collator=DataCollatorForVisualContrastivePreference(pad_token_id=pad_token_id),
        train_dataset=dataset['train'],
        eval_dataset=dataset['validation'] if training_args.eval_strategy != "no" else None,
        peft_config=peft_config,
        callbacks=[logger],
    )

    trainer.train()

    if model_args.use_peft:
        # Only `trainer.save_model` is called here, into a sub-directory; no merge into the base model
        # is performed by this script.
        trainer.save_model(os.path.join(training_args.output_dir, "adapter_output"))
    else:
        trainer.save_model(training_args.output_dir)

    if trainer.is_world_process_zero():
        # save the original tokenizer (as the chat template and processor config may not be correctly saved by trainer)
        # NOTE(review): this reloads the processor from `model_name_or_path`, so the `image_grid_pinpoints`
        # filtering applied above is not part of the saved processor — confirm that this is intended.
        processor = AutoProcessor.from_pretrained(
            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
        )
        processor.save_pretrained(training_args.output_dir)

    trainer.accelerator.wait_for_everyone()



