import os
import re
import importlib
from importlib import import_module

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer, apply_chat_template
import hydra

from inference_rlhf.code.query_builders.qwen import QwenQueryBuilder
from inference_rlhf.code.tasks.math import extract_answer, _extract_groundtruth

# Point wandb logging at a fixed entity/project via its environment variables.
# These are assigned unconditionally at import time, so they override any
# values already present in the shell environment.
os.environ.update(
    WANDB_ENTITY="anonymous",
    WANDB_PROJECT="llm-exploration-rl-training",
)

@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg):
    """Train Qwen2-0.5B-Instruct with GRPO on the MATH-lighteval train split.

    Prompts are built by a ``QwenQueryBuilder`` configured from ``cfg``. The
    only reward is exact-match accuracy: the answer extracted from each
    completion is compared with the answer extracted from the reference
    solution.

    Args:
        cfg: Hydra config composed from ``configs/master``. Fields read here:
            ``cfg.task.name`` (selects the task module that provides the
            prompt formats), ``cfg.task.TASK_DESC``, ``cfg.shots``,
            ``cfg.policy`` (passed to ``QwenQueryBuilder``) and
            ``cfg.policy.answer_patterns`` (passed to ``extract_answer``).
    """
    dataset = load_dataset("DigitalLearningGmbH/MATH-lighteval", split="train")

    # Keep only problems whose reference solution yields an extractable final
    # answer, so every remaining example has a ground truth to compare against.
    dataset = dataset.filter(lambda x: _extract_groundtruth(x["solution"]) is not None)

    # The task module provides the question/answer formats and separator.
    # The module name is absolute, so no ``package`` anchor is needed.
    task_module = importlib.import_module(f"inference_rlhf.code.tasks.{cfg.task.name}")
    query_builder = QwenQueryBuilder(
        cfg=cfg.policy,
        task_desc=cfg.task.TASK_DESC,
        shots=cfg.shots,
        question_format=task_module.QUESTION_FORMAT,
        answer_format=task_module.ANSWER_FORMAT,
        sep=task_module.SEP,
    )

    def format_chat_example(example, query_builder):
        """Build the prompt for one dataset row.

        Renders the query text with ``query_builder`` (task description
        included, no chat template applied there). Returns the dict produced
        by TRL's ``apply_chat_template`` for a ``{"prompt": text}`` example,
        using the builder's tokenizer.
        """
        text = query_builder.build_query(example["problem"], include_task_desc=True, apply_chat_template=False)
        return apply_chat_template({"prompt": text}, query_builder.tokenizer)

    print("Mapping dataset...")
    # The builder is passed via ``fn_kwargs`` instead of a wrapping lambda.
    # Existing columns (notably "solution") are kept, and GRPOTrainer passes
    # them to the reward function as keyword arguments.
    dataset = dataset.map(
        format_chat_example,
        fn_kwargs={"query_builder": query_builder},
        num_proc=4,
    )

    # Resolve the config lookup once rather than on every reward call.
    answer_patterns = cfg.policy.answer_patterns

    def reward_acc(completions, solution, **kwargs):
        """Return 1.0 per completion whose extracted answer matches the ground truth, else 0.0.

        ``solution`` holds the reference solutions for the batch, aligned with
        ``completions``. Other forwarded columns arrive in ``kwargs`` and are
        ignored.

        NOTE(review): keep this function's name. TRL appears to label the
        logged reward metric with ``__name__`` — confirm before renaming.
        """
        model_answers = [extract_answer(c, answer_patterns, strict=True) for c in completions]
        ground_truths = [_extract_groundtruth(s) for s in solution]
        # A failed extraction (None) is never correct, even if the ground
        # truth were also None.
        return [
            1.0 if pred is not None and pred == gt else 0.0
            for pred, gt in zip(model_answers, ground_truths)
        ]

    training_args = GRPOConfig(
        output_dir="Qwen2-0.5B-GRPO",
        logging_steps=10,
        bf16=True,
        gradient_checkpointing=True,
        max_completion_length=512,
        max_prompt_length=2048,
    )

    # NOTE(review): the policy model is hard-coded here, while the prompt
    # tokenizer comes from ``cfg.policy`` via the query builder — confirm
    # the two refer to the same model family.
    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs=reward_acc,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()

if __name__ == "__main__":
    # main() is called with no arguments: the @hydra.main decorator on main
    # supplies its ``cfg`` parameter.
    main()