from datasets import load_from_disk
from trl import DPOConfig, DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
import torch


def parse_args(argv=None):
    """Parse command-line options for the DPO training run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Weighted Live DPO Training")

    # Model / run identification
    parser.add_argument("--model_name", default="Qwen/Qwen2.5-0.5B")
    parser.add_argument("--output_dir", default="outputs/Qwen25_05B_AWDPO_Gsm8k")
    parser.add_argument("--run_name", default="Qwen25_05B_AWDPO_gsm8k_reasoner")

    # Optimisation hyperparameters (forwarded to DPOConfig in main)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--max_prompt_length", type=int, default=2000)
    parser.add_argument("--max_completion_length", type=int, default=500)
    parser.add_argument("--per_device_train_batch_size", type=int, default=6)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--logging_steps", type=int, default=1)
    parser.add_argument("--save_steps", type=int, default=250)
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Currently unused: not forwarded to DPOConfig; kept for CLI compatibility.",
    )
    # BooleanOptionalAction provides both --bf16 and --no-bf16. The previous
    # store_true + default=True combination could never turn bf16 off.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--training_data_directory", type=str, default="/data/")
    parser.add_argument("--num_train_epochs", type=int, default=1)

    return parser.parse_args(argv)


def main(argv=None):
    """Fine-tune a causal LM with DPO and save the final model and tokenizer.

    Loads ``args.model_name`` and its tokenizer, reads the training dataset
    from ``args.training_data_directory`` with ``load_from_disk``, trains
    with ``DPOTrainer``, and writes the result to
    ``<output_dir>/final_model``.

    Args:
        argv: Optional argument list; ``None`` parses ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    # use_cache=False: the KV cache is not needed during training.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        use_cache=False,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Fall back to EOS for padding only when the tokenizer has no pad token,
    # instead of overwriting a dedicated pad token the model already defines.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    training_args = DPOConfig(
        output_dir=args.output_dir,
        run_name=args.run_name,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        num_train_epochs=args.num_train_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        bf16=args.bf16,
        per_device_train_batch_size=args.per_device_train_batch_size,
    )

    train_dataset = load_from_disk(args.training_data_directory)

    trainer = DPOTrainer(
        model=model,
        args=training_args,
        processing_class=tokenizer,
        train_dataset=train_dataset,
    )
    trainer.train()

    # Trainer.save_model goes through the Trainer's own save path, which also
    # writes the processing_class (the tokenizer passed above). The final
    # directory is therefore reloadable on its own. The previous
    # model.save_pretrained call omitted the tokenizer.
    trainer.save_model(f"{args.output_dir}/final_model")


if __name__ == "__main__":
    main()