# %%
# train_grpo.py
import multiprocessing as mp
import random

import mlflow
import numpy as np
import torch
from transformers import set_seed
from trl import GRPOConfig

from meta_alignment.config import TrainingConfig
from meta_alignment.constant import STOP_WORDS
from meta_alignment.dataset import get_dataset
from meta_alignment.reward_funcs import get_reward_funcs, get_reward_weights
from meta_alignment.trainer import AdaptiveGRPOTrainer

# Single seed shared by every RNG seeded in this script and by the trainer config.
SEED = 42


def _seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Value passed to every seeding call below.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    set_seed(seed)
    # Disable cuDNN autotuning and ask for deterministic algorithms.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main() -> "AdaptiveGRPOTrainer":
    """Parse the training config, build the GRPO trainer and run training.

    The steps run in the same order as the original flat script: parse
    arguments, seed RNGs, force the ``spawn`` start method, load the datasets
    and reward functions, select the MLflow experiment, then train.

    Returns:
        The trainer instance after ``train()`` has returned.
    """
    args = TrainingConfig().parse_args()

    _seed_everything(SEED)

    # Avoid issues with forking. With "spawn", child processes re-import the
    # main module, which is why this function is only invoked under the
    # ``__name__ == "__main__"`` guard at the bottom of the file.
    mp.set_start_method("spawn", force=True)

    # Prepare datasets and rewards
    train_dataset, eval_dataset = get_dataset(args, train_size=6000)
    reward_funcs = get_reward_funcs(args)
    weights = get_reward_weights(args)

    mlflow.set_experiment(args.experiment_name)

    training_args = GRPOConfig(
        output_dir=args.log_dir,
        run_name=args.run_name,
        # Keep the seed the trainer applies in sync with the one set above.
        seed=SEED,
        use_vllm=True,
        vllm_mode=args.vllm_mode,
        save_steps=100,
        save_total_limit=1,
        eval_strategy="steps",
        eval_steps=100,
        reward_weights=weights,
        max_completion_length=args.max_completion_length,
        log_completions=args.log_completions,
        num_train_epochs=args.num_epochs,
        beta=args.beta,
        scale_rewards=args.scale_rewards,
        # Stop sequences are only passed to generation for the "hh" task.
        generation_kwargs={"stop": STOP_WORDS} if args.task == "hh" else None,
        vllm_gpu_memory_utilization=0.23,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
    )

    trainer = AdaptiveGRPOTrainer(
        target_kl=args.target_kl,
        model=args.model_dir,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    return trainer


if __name__ == "__main__":
    # Bound at module level so an interactive (# %%) session can still
    # inspect the trainer after the run.
    trainer = main()
