from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig
from trl.trainer.grpo_trainer import MCTSGRPOTrainer
from datasets import load_dataset
import torch
import ast
from reward import *
import argparse
import dataclasses
import json

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
argparser = argparse.ArgumentParser(
    description="Fine-tune a causal LM with MCTSGRPOTrainer.",
)
argparser.add_argument(
    "--out_name",
    type=str,
    required=True,
    help="Run name; outputs are written under output_llm/<out_name>.",
)
argparser.add_argument(
    "--learning_rate",
    type=float,
    default=1e-6,
    help="Learning rate handed to GRPOConfig (default: 1e-6).",
)
argparser.add_argument(
    "--mcts_mode",
    type=str,
    required=True,
    help="Forwarded unchanged to MCTSGRPOTrainer(mcts_mode=...).",
)
# The training file is supplied by the caller: an empty `data_files` string
# cannot resolve to any file. The option is required so that a missing path
# is reported here, before the (slow) model load below.
argparser.add_argument(
    "--data_files",
    type=str,
    required=True,
    help="Path (or glob pattern) of the JSON/JSONL training file(s).",
)
args = argparser.parse_args()

# ---------------------------------------------------------------------------
# Model, tokenizer and training data
# ---------------------------------------------------------------------------
model_name = "Qwen/Qwen3-4B-Instruct-2507"

# NOTE(review): no dtype is requested here, so the library default applies;
# confirm that this matches the precision settings in ds_config.json.
model = AutoModelForCausalLM.from_pretrained(model_name)

# `device_map` is a model-placement option and has no meaning for a
# tokenizer, so it is not passed here. It is intentionally NOT moved to the
# model either: placement is left to the DeepSpeed launch configured below.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The "json" builder puts files given as a plain string into the "train"
# split, which is the only split this script uses.
dataset = load_dataset("json", data_files=args.data_files)["train"]

from transformers import GenerationConfig

# Length budgets: the completion gets whatever is left of `max_seq_length`
# once the prompt allowance has been subtracted.
max_seq_length = 1500
max_prompt_length = 848

training_args = GRPOConfig(
    # --- optimisation ---
    learning_rate=args.learning_rate,
    lr_scheduler_type="cosine",
    warmup_ratio=0.01,
    optim="adamw_8bit",
    max_grad_norm=1.0,
    # --- batching / schedule ---
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    # --- generation / sampling ---
    num_generations=8,  # lower this if generation runs out of memory
    max_prompt_length=max_prompt_length,
    max_completion_length=max_seq_length - max_prompt_length,
    temperature=1.0,
    top_p=0.9,
    # --- logging, checkpoints, runtime ---
    logging_steps=1,
    save_strategy="epoch",
    report_to="tensorboard",  # Weights & Biases is an alternative
    output_dir="output_llm/" + args.out_name,
    deepspeed="ds_config.json",
)

# Assemble the trainer. A single reward function is registered; it is
# presumably provided by the wildcard `reward` import at the top of the file.
# `mcts_mode` is passed through from the command line untouched.
trainer = MCTSGRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    reward_funcs=[parse_and_calculate_reward_tree_new],
    mcts_mode=args.mcts_mode,
)
# Train, then persist the final model. `save_model` is called by every
# process, exactly as before; the trainer decides internally which ranks
# actually write.
trainer.train()
trainer.save_model(training_args.output_dir)

# Snapshot of the run configuration. `to_dict()` is the library's own
# serialiser for training arguments: it yields JSON-friendly values and masks
# hub tokens instead of writing them to disk in clear text.
cfg_dict = training_args.to_dict()

# The script is launched with DeepSpeed, so several processes may reach this
# point. Only the process(es) that the training arguments designate for
# saving write the extra artefacts, so ranks do not race on the same files.
if training_args.should_save:
    tokenizer.save_pretrained(training_args.output_dir)
    with open(f"{training_args.output_dir}/grpo_config.json", "w", encoding="utf-8") as f:
        json.dump(cfg_dict, f, ensure_ascii=False, indent=2)
