import argparse
import torch
import re
from datasets import Dataset
import json
import os
from vllm import SamplingParams
from unsloth import FastLanguageModel, PatchFastRL
from trl import GRPOConfig, GRPOTrainer


def main(args):
    """Fine-tune a LoRA adapter with GRPO using Unsloth, vLLM sampling and TRL.

    The function loads the JSON training records, loads the base model with
    Unsloth's fast-inference path, attaches LoRA adapters, trains with two
    reward functions (output format and boxed-answer correctness) and finally
    saves the model and tokenizer to ``args.output_dir``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``model_name_or_path``, ``data_path``, ``output_dir``,
            ``max_seq_length``, ``lora_rank``, ``learning_rate``,
            ``num_train_epochs``, ``per_device_train_batch_size``,
            ``logging_steps``, ``save_steps`` and ``report_to``.

    Raises:
        ValueError: If the data file does not contain a non-empty JSON list.
    """
    # Read and validate the dataset first so that a wrong path or a malformed
    # file fails immediately instead of after the slow model load.
    with open(args.data_path, encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list) or not data:
        raise ValueError(
            f"Expected a non-empty JSON list of records in {args.data_path}"
        )

    print(f"Loaded {len(data)} records from {args.data_path}")
    # Every record needs a "ground_truth" field: the reward functions below
    # receive that dataset column as a keyword argument.
    dataset = Dataset.from_list(data)

    PatchFastRL("GRPO", FastLanguageModel)
    # Bind the TRL classes only after the patch has run. Unsloth's GRPO
    # examples import TRL after PatchFastRL, and `vllm_sampling_params` below
    # is an Unsloth extension of GRPOConfig, so the names imported at module
    # load time may still refer to the unpatched classes.
    from trl import GRPOConfig, GRPOTrainer

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name_or_path,
        max_seq_length=args.max_seq_length,
        load_in_4bit=False,
        load_in_8bit=False,
        max_lora_rank=args.lora_rank,
        fast_inference=True,
        gpu_memory_utilization=0.95,
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=args.lora_rank,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    # Compiled once; the reward function is called on every training step.
    format_pattern = re.compile(
        r"^<think>\n.*?</think>\n<answer>\n.*?</answer>$", flags=re.DOTALL
    )

    def completion_text(completion):
        """Return the generated text of a single completion."""
        # TRL hands reward functions plain strings for standard-format
        # prompts and a list of message dicts for conversational prompts.
        if isinstance(completion, str):
            return completion
        return completion[0]["content"]

    def extract_boxed(text):
        """Return the contents of the first balanced boxed answer, or None."""
        # A non-greedy regex stops at the first "}" and therefore truncates
        # answers such as \boxed{\frac{1}{2}}; count braces instead.
        marker = "\\boxed{"
        start = text.find(marker)
        while start != -1:
            begin = start + len(marker)
            depth = 1
            for pos in range(begin, len(text)):
                ch = text[pos]
                if ch == "{":
                    depth += 1
                elif ch == "}":
                    depth -= 1
                    if depth == 0:
                        return text[begin:pos]
            # This occurrence never closed; try the next one, if any.
            start = text.find(marker, begin)
        return None

    def format_reward_func(prompts, completions, ground_truth, **kwargs):
        """Give 1.0 to completions shaped as a think block then an answer block."""
        return [
            1.0 if format_pattern.match(completion_text(completion).strip()) else 0.0
            for completion in completions
        ]

    def reward_func(prompts, completions, ground_truth, **kwargs):
        """Give 5.0 when the boxed answer equals the ground truth, else 0.0."""
        rewards = []
        for completion, gt in zip(completions, ground_truth):
            answer = extract_boxed(completion_text(completion))
            # A missing boxed answer never earns reward, even if the ground
            # truth is empty. Surrounding whitespace is ignored and the
            # ground truth is compared as text so numeric labels can match.
            is_correct = answer is not None and answer.strip() == str(gt).strip()
            rewards.append(5.0 if is_correct else 0.0)
        return rewards

    vllm_sampling_params = SamplingParams(
        min_p=0.1, top_p=1.0, top_k=-1, seed=3407,
        stop=[tokenizer.eos_token], include_stop_str_in_output=True,
    )

    training_args = GRPOConfig(
        vllm_sampling_params=vllm_sampling_params,
        temperature=1.0,
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        optim="paged_adamw_8bit",
        logging_steps=args.logging_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=1,
        num_generations=4,
        # Prompt plus completion budget is 512 + 4096 = 4608 tokens; keep
        # --max_seq_length at least that large.
        max_prompt_length=512,
        max_completion_length=4096,
        num_train_epochs=args.num_train_epochs,
        save_steps=args.save_steps,
        report_to=args.report_to,
        output_dir=args.output_dir,
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[format_reward_func, reward_func],
        args=training_args,
        train_dataset=dataset,
    )

    trainer.train()

    print(f"Saving model and tokenizer to {args.output_dir}")
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":
    # One row per command-line option: (flag, type, default, help text).
    # The rows are registered in this order, which is also the --help order.
    cli_options = (
        # Paths
        ("--model_name_or_path", str, "./model/qwen3/BioKCoT",
         "Path to the base model."),
        ("--data_path", str, "./data/rl.json",
         "Path to the training data JSON file."),
        ("--output_dir", str, "./model/qwen3/8b_RL",
         "Directory to save the final model."),
        # Model and training hyperparameters
        ("--max_seq_length", int, 4610, "Maximum sequence length."),
        ("--lora_rank", int, 32, "Rank for LoRA."),
        ("--learning_rate", float, 5e-6, "Learning rate for the training."),
        ("--num_train_epochs", int, 1, "Number of training epochs."),
        ("--per_device_train_batch_size", int, 4, "Batch size per device."),
        # Logging and saving
        ("--logging_steps", int, 10, "Log every N steps."),
        ("--save_steps", int, 100, "Save a checkpoint every N steps."),
        ("--report_to", str, "none",
         "Report to platform (e.g., 'wandb', 'tensorboard', 'none')."),
    )

    parser = argparse.ArgumentParser(description="Run GRPO training with Unsloth.")
    for flag, arg_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    args = parser.parse_args()
    main(args)