

import torch
from datasets import load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor
from datasets import load_from_disk
from accelerate import Accelerator
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.state import AcceleratorState
from datetime import timedelta
from trl import (
    DPOConfig,
    DPOTrainer,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from peft import LoraConfig
import os
from transformers import MllamaForConditionalGeneration, AutoProcessor
import os
# Process-wide environment flags, set once at import time before any
# training objects are created.
os.environ.update(
    {
        # Disable Weights & Biases logging (presumably honoured by the
        # Trainer's W&B integration — confirm for the installed version).
        "WANDB_DISABLED": "true",
        # NOTE(review): presumably marks the AWS EFA libfabric provider as
        # fork-safe for forked worker processes — confirm on the cluster.
        "FI_EFA_FORK_SAFE": "1",
    }
)

if __name__ == "__main__":
    # Build the Accelerator first so the distributed process group is
    # initialised with a 24-hour timeout instead of the default one.
    process_group_kwargs = InitProcessGroupKwargs(
        timeout=timedelta(seconds=86400)  # 24 hours
    )
    accelerator = Accelerator(
        gradient_accumulation_steps=4,
        kwargs_handlers=[process_group_kwargs],
    )
    accelerator.print(f"{AcceleratorState()}")

    # --output_dir and the other Trainer flags land in `training_args`
    # (DPOConfig), not in `script_args`.
    parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    # Default checkpoint, used only when --model_name_or_path is not given.
    # Policy, reference model and processor all load from the same path so
    # they cannot silently diverge.
    model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    model_path = model_config.model_name_or_path or model_id

    ################
    # Model & Processor
    ################
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    if torch_dtype is None:
        # No --torch_dtype on the command line: keep the bfloat16 default
        # this script has always loaded the model with.
        torch_dtype = torch.bfloat16
    quantization_config = get_quantization_config(model_config)

    # Shared from_pretrained kwargs for the policy and the reference model,
    # so the ModelConfig CLI flags (revision, attention backend,
    # quantization, dtype) are honoured instead of ignored.
    model_kwargs = dict(
        revision=model_config.model_revision,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        trust_remote_code=model_config.trust_remote_code,
        low_cpu_mem_usage=True,
    )

    model = MllamaForConditionalGeneration.from_pretrained(model_path, **model_kwargs)
    # Separate reference model for the DPO loss; required because no PEFT
    # adapter is used (with PEFT, pass ref_model=None instead).
    ref_model = MllamaForConditionalGeneration.from_pretrained(model_path, **model_kwargs)

    processor = AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=model_config.trust_remote_code,
        # NOTE(review): do_image_splitting is an Idefics-style option; confirm
        # the Mllama processor actually honours it.
        do_image_splitting=False,
    )
    tokenizer = processor.tokenizer

    # Plain "USER: ... ASSISTANT: ..." chat template. The image placeholder
    # must match the processor's own image token (presumably "<|image|>" for
    # Mllama — confirm), otherwise the processor will not pair the text
    # with the supplied images. Falls back to "<image>" if the processor
    # exposes no image_token attribute.
    image_token = getattr(processor, "image_token", "<image>")
    processor.chat_template = (
        "{% if not add_generation_prompt is defined %}"
        "{% set add_generation_prompt = false %}{% endif %}"
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}"
        "{% for item in message['content'] %}"
        "{% if item['type'] == 'text' %}{{ item['text'] }}"
        "{% elif item['type'] == 'image' %}" + image_token + "{% endif %}"
        "{% endfor %}"
        "{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}ASSISTANT: {% endif %}"
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if script_args.ignore_bias_buffers:
        # torch distributed hack: exclude boolean buffers from DDP syncing.
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    ################
    # Dataset
    ################
    dataset = load_from_disk(script_args.dataset_name)

    # Fixed training hyper-parameters for this script. Only the output
    # directory is taken from the parsed command line (ScriptArguments has
    # no output_dir field, so it must come from the parsed DPOConfig).
    training_args = DPOConfig(
        output_dir=training_args.output_dir,
        bf16=True,
        gradient_checkpointing=False,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        num_train_epochs=2,
        dataset_num_proc=32,  # tokenization will use 32 processes
        dataloader_num_workers=32,  # data loading will use 32 workers
        logging_steps=10,
    )

    # The Trainer handles device placement and distributed wrapping itself.
    # Accelerator.prepare() returns a Trainer unchanged, so it is not used.
    trainer = DPOTrainer(
        model,
        ref_model=ref_model,  # needed when not using peft
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        processing_class=processor,
    )

    ################
    # Training
    ################
    trainer.train()

    # Save the final model weights to the configured output directory.
    trainer.save_model(training_args.output_dir)

# nohup accelerate launch dpo.py --dataset_name "/it1_dpopairs_llama11b"   --model_name_or_path "meta-llama/Llama-3.2-11B-Vision-Instruct" --output_dir "./checkpoints/dpotrained_it1_dpopairs_llama11b"   --bf16 > dpo.log 2>&1 &
    