from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
from datasets import load_dataset
import torch
import ast
from reward import *
import argparse

# Command-line interface: --out_name is mandatory and is later appended to
# "output_llm/" to form the directory this run writes to.
argparser = argparse.ArgumentParser()
argparser.add_argument(
    "--out_name",
    required=True,
    type=str,
)
args = argparser.parse_args()

# Hub identifier of the checkpoint that is fine-tuned; the tokenizer is loaded
# from the same identifier so the two stay in sync.
model_name = 'Qwen/Qwen3-4B-Instruct-2507'

# The model is loaded without an explicit device map; the training config in
# this script points at a DeepSpeed config file, so placement is presumably
# handled by the trainer/launcher -- TODO confirm against the launch command.
model = AutoModelForCausalLM.from_pretrained(
    model_name
)
# Tokenizers are plain CPU-side objects with no device placement, so no
# `device_map` is passed here (the stray kwarg had no effect on tokenization).
tokenizer = AutoTokenizer.from_pretrained(model_name)

# NOTE(review): `data_files` is an empty string -- this looks like a
# placeholder; supply the path to the JSON training data before running.
dataset = load_dataset("json", data_files="")
# Only the "train" split of the loaded dataset is used for training.
dataset = dataset['train']

from transformers import GenerationConfig

# Token budget: prompts may use up to `max_prompt_length` tokens and the
# completion gets whatever remains of `max_seq_length`.
max_prompt_length = 848
max_seq_length = 1500

# GRPO hyper-parameters, grouped by concern. Values are fixed for this script;
# only the output directory depends on the command line.
training_args = GRPOConfig(
    # --- optimisation ---
    optim="adamw_8bit",
    learning_rate=1e-6,
    lr_scheduler_type="cosine",
    warmup_ratio=0.01,
    max_grad_norm=1.0,
    num_train_epochs=1,
    # --- batching ---
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # accumulate 4 micro-batches per optimizer step
    # NOTE(review): TRL ties the effective batch size to `num_generations`
    # (divisibility requirement) -- confirm against the number of processes
    # used at launch.
    num_generations=8,  # Decrease if out of memory
    # --- sequence lengths ---
    max_prompt_length=max_prompt_length,
    max_completion_length=max_seq_length - max_prompt_length,
    # --- sampling during generation ---
    temperature=1.0,
    top_p=0.9,
    # --- logging / checkpointing ---
    logging_steps=1,
    save_strategy="epoch",
    report_to="none",  # no experiment tracker; Weights & Biases can be used instead
    output_dir="output_llm/" + args.out_name,
    # --- distributed backend ---
    deepspeed="ds_config.json",
)

# Assemble the GRPO trainer from the model, tokenizer, dataset and config
# created above; a single reward function (imported from `reward`) scores the
# sampled completions.
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    reward_funcs=[parse_and_calculate_reward_tree_normal_grpo],
)
trainer.train()

# After training, write the model and the tokenizer to the same run directory.
final_dir = "output_llm/" + args.out_name
trainer.save_model(final_dir)
tokenizer.save_pretrained(final_dir)