import os
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
from utils import env_pred_score, response_text_without_think

# Sequence-length budgets. MAX_PROMPT_LENGTH is also baked into the dataset
# file name, and MAX_COMPLETION_LENGTH into the save-directory name.
MAX_PROMPT_LENGTH = 4096
MAX_COMPLETION_LENGTH = 8192
LOAD_MODEL_NAME = "<your_model_name>"  # placeholder: set to a model id or local path
SAVE_MODEL_NAME = f"Qwen3-4B-{MAX_COMPLETION_LENGTH}"

# Every project path below is resolved relative to this script's directory.
_SCRIPT_DIR = os.path.dirname(__file__)

# Training data: ../data/env/trainset_<MAX_PROMPT_LENGTH>_grpo.jsonl
data_path = os.path.join(
    _SCRIPT_DIR, "..", "data", "env", f"trainset_{MAX_PROMPT_LENGTH}_grpo.jsonl"
)

# Checkpoint directory: ../model/env_grpo/<SAVE_MODEL_NAME>
model_save_path = os.path.join(_SCRIPT_DIR, "..", "model", "env_grpo", SAVE_MODEL_NAME)

load_model_path = LOAD_MODEL_NAME

# DeepSpeed configuration file that lives next to this script.
ds_config_path = os.path.join(_SCRIPT_DIR, "ds_z3.yaml")

dataset = load_dataset("json", data_files=data_path, split="train")
# GRPOTrainer reads prompts from a "prompt" column; accept "problem" as an alias.
if "problem" in dataset.column_names:
    dataset = dataset.rename_column("problem", "prompt")

def reward_event_acc(completions, **kwargs):
    """Score each completion against its reference solution.

    Passed to ``GRPOTrainer`` as ``reward_funcs``. Extra dataset columns are
    forwarded to reward functions as keyword arguments (the training config in
    this script sets ``remove_unused_columns=False`` so they are kept); this
    function reads the ``solution`` column.

    Args:
        completions: Generated completions, one per sample.
        **kwargs: Per-sample dataset columns. Must include ``solution``, a
            sequence aligned index-for-index with ``completions``.

    Returns:
        One score per completion, computed as
        ``env_pred_score(solution, response_text_without_think(completion))``.

    Raises:
        KeyError: If no ``solution`` column was passed.
        ValueError: If ``solution`` and ``completions`` differ in length.
    """
    # Fail loudly instead of returning a short or empty reward list. The
    # trainer expects exactly one reward per completion, and a silent mismatch
    # would otherwise surface later as an unrelated-looking error or skew rewards.
    if "solution" not in kwargs:
        raise KeyError("reward_event_acc requires a 'solution' column in the dataset")
    solutions = kwargs["solution"]
    if len(solutions) != len(completions):
        raise ValueError(
            f"reward_event_acc got {len(completions)} completions "
            f"but {len(solutions)} solutions"
        )
    return [
        env_pred_score(solution, response_text_without_think(completion))
        for completion, solution in zip(completions, solutions)
    ]


# GRPO hyper-parameters, grouped by concern. GRPOConfig takes keyword
# arguments only, so their order carries no meaning.
training_args = GRPOConfig(
    # Output and logging.
    output_dir=model_save_path,
    run_name="4B-GRPO",
    report_to="tensorboard",
    logging_steps=1,

    # Checkpointing: model only, every 20 steps, at most 5 kept.
    save_only_model=True,
    save_steps=20,
    save_total_limit=5,

    # Optimisation schedule and GRPO loss coefficients.
    num_train_epochs=1,
    learning_rate=1e-6,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    beta=1e-3,
    epsilon=0.2,

    # Precision, memory and distributed backend.
    bf16=True,
    gradient_checkpointing=True,
    # NOTE(review): transformers documents `deepspeed=` as a DeepSpeed JSON
    # config (path or dict). Confirm ds_z3.yaml actually holds DeepSpeed JSON
    # and is not an `accelerate` YAML (those go to `accelerate launch --config_file`).
    deepspeed=ds_config_path,

    # Batching.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=64,

    # Sampling and sequence lengths.
    num_generations=4,
    temperature=1.0,
    max_prompt_length=MAX_PROMPT_LENGTH,
    max_completion_length=MAX_COMPLETION_LENGTH,

    # Keep extra dataset columns (reward_event_acc reads "solution").
    remove_unused_columns=False,

    # Generation through an external vLLM server, with a generous timeout.
    # To colocate vLLM with training instead, swap the mode for:
    #   vllm_mode='colocate', vllm_gpu_memory_utilization=0.35,
    #   vllm_tensor_parallel_size=8
    use_vllm=True,
    vllm_mode='server',
    vllm_server_timeout=36000,
)

trainer = GRPOTrainer(
    model=load_model_path,
    args=training_args,
    train_dataset=dataset,
    reward_funcs=reward_event_acc,
)
trainer.train()
