import os
import argparse
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

# Project root used to build dataset and model-output paths.
# Evaluates to None when the environment variable is not set.
TAMPERING_HOME = os.environ.get("TAMPERING_HOME")

def parse_args(argv: "list[str] | None" = None):
    """Parse command-line options for DPO training.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``. Passing an explicit list lets the parser be
            used from tests or notebooks.

    Returns:
        argparse.Namespace with ``model_name``, ``data_files``,
        ``output_model_name`` and ``learning_rate``.

    Notes:
        The default ``--data_files`` lives under ``$TAMPERING_HOME``. If that
        variable is unset, there is no sensible default, so the option becomes
        required rather than silently defaulting to ``"None/datasets/..."``.
    """
    parser = argparse.ArgumentParser(description="Train a reward model with DPO")
    parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B",
                        help="Base model (Hugging Face hub id or local path) "
                             "to fine-tune")

    # Read the environment at call time, as the original default did.
    home = os.getenv("TAMPERING_HOME")
    default_data_files = (
        None if home is None
        else f"{home}/datasets/hhrlhf/rm/hhrlhf_RM_5120_ai_pref_implicit.jsonl"
    )
    parser.add_argument("--data_files", type=str,
                        default=default_data_files,
                        required=default_data_files is None,
                        help="Path to the JSON/JSONL preference data used for "
                             "training (required when TAMPERING_HOME is unset)")
    parser.add_argument("--output_model_name", type=str, default="RM",
                        help="Name of the output directory under "
                             "$TAMPERING_HOME/models/dpo/")
    parser.add_argument("--learning_rate", type=float, default=5e-5,
                        help="Learning rate for DPO training")
    return parser.parse_args(argv)

def main():
    """Fine-tune a causal LM with DPO on a preference dataset.

    Reads options from the command line via ``parse_args``. Trains for one
    epoch with TRL's ``DPOTrainer``, saving checkpoints every 10 optimizer
    steps to ``$TAMPERING_HOME/models/dpo/<output_model_name>``. Metrics are
    reported to TensorBoard and Weights & Biases.

    Raises:
        SystemExit: if the ``TAMPERING_HOME`` environment variable is unset,
            because the output directory is built from it.
    """
    args = parse_args()

    # Fail fast, before any expensive loading. Interpolating None would
    # otherwise send checkpoints to the relative path "None/models/dpo/...".
    if TAMPERING_HOME is None:
        raise SystemExit(
            "TAMPERING_HOME is not set; it is needed for the output directory "
            "$TAMPERING_HOME/models/dpo/<output_model_name>."
        )
    output_dir = f"{TAMPERING_HOME}/models/dpo/{args.output_model_name}"

    # Load the data before the model so a bad --data_files path errors out
    # immediately instead of after the model weights have been loaded.
    dataset = load_dataset(
        "json",
        data_files=args.data_files,
        split="train",
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Pad with EOS, overriding any pad token the tokenizer already defines.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

    training_args = DPOConfig(
        # Training: 4 sequences x 16 accumulation steps = 64 examples per
        # device per optimizer step.
        per_device_train_batch_size=4,
        gradient_accumulation_steps=16,
        learning_rate=args.learning_rate,
        lr_scheduler_type="constant",
        num_train_epochs=1,
        # DPO beta; per the TRL docs, larger values keep the policy closer to
        # the reference model.
        beta=0.1,
        # Logging
        report_to=["tensorboard", "wandb"],
        # Output: a checkpoint every 10 optimizer steps.
        save_strategy="steps",
        save_steps=10,
        output_dir=output_dir,
        # Reproducibility
        seed=42,
        data_seed=42,
    )

    # No ref_model is passed, so (per the TRL docs) the trainer derives the
    # frozen reference model from `model`.
    trainer = DPOTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )
    trainer.train()
    
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()