import fire
import mlflow
from datasets import load_dataset
from trl import RewardConfig, RewardTrainer


def main(
    data_dir: str = "helpful-base",
    model_name: str = "Qwen/Qwen3-32B",
    lr: float = 3e-6,
    per_device_batch_size: int = 4,
) -> None:
    """Train a reward model on a subset of ``Anthropic/hh-rlhf`` and evaluate it.

    The run is tracked under the MLflow experiment ``reward-model/hh`` and
    checkpoints are written below ``logs/reward-model/hh/<data_dir>/<model>``.
    After one training epoch the final evaluation metrics are printed.

    Args:
        data_dir: Sub-directory of the ``Anthropic/hh-rlhf`` dataset to load;
            forwarded to ``load_dataset`` for both the train and test splits.
        model_name: Value handed to ``RewardTrainer`` as ``model`` (the
            default is a Hugging Face Hub style ``org/name`` id).
        lr: Learning rate passed to ``RewardConfig``.
        per_device_batch_size: Per-device batch size used for both training
            and evaluation.
    """
    experiment_name = "reward-model/hh"
    mlflow.set_experiment(experiment_name)

    # "org/name" style ids would otherwise add an extra directory level.
    model_dir_name = model_name.replace("/", "_")

    train_dataset = load_dataset("Anthropic/hh-rlhf", split="train", data_dir=data_dir)
    eval_dataset = load_dataset("Anthropic/hh-rlhf", split="test", data_dir=data_dir)

    args = RewardConfig(
        output_dir=f"logs/{experiment_name}/{data_dir}/{model_dir_name}",
        run_name=f"{data_dir}/{model_name}",
        # Be explicit about the tracker: the experiment selected above is only
        # useful if MLflow reporting is enabled, and relying on the library
        # default for `report_to` makes that depend on the installed version.
        report_to="mlflow",
        num_train_epochs=1,
        eval_strategy="steps",
        eval_steps=100,
        learning_rate=lr,
        per_device_train_batch_size=per_device_batch_size,
        # Keep evaluation on the same batch size as training; otherwise the
        # library default is used and the CLI knob does not bound eval memory.
        per_device_eval_batch_size=per_device_batch_size,
        save_total_limit=2,
    )
    trainer = RewardTrainer(
        args=args,
        model=model_name,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()

    # Final evaluation on the test split after training has finished.
    res = trainer.evaluate()
    print(res)


# Script entry point: Fire builds the command-line interface from main()'s
# keyword arguments (e.g. --data_dir, --model_name, --lr).
if __name__ == "__main__":
    fire.Fire(component=main)
