# Copyright 2025 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import time
from datetime import datetime
from typing import Any, Dict, List

import numpy as np
import torch
import vllm
from datasets import load_from_disk
from transformers import AutoTokenizer

from metrics import cal_metrics_whole
# Count of CUDA devices visible to this process; predict() uses it as the
# tensor-parallel degree.
num_gpus: int = torch.cuda.device_count()

# ---------------------------------------------------------------------------
# Paths, task selection and naming
# ---------------------------------------------------------------------------
# Root directory under which both the dataset and the results folder live.
DIR_PATH: str = ".."
# Only dataset splits whose name appears here are evaluated.
TASKS: List[str] = [
    "aime",
    "amc",
    "math",
    "minerva",
    "olympiad_bench",
]
# Label embedded in the name of the results file.
TEMPLATE: str = "uncertainty"
DATASET_NAME: str = DIR_PATH + "/data/math"

# ---------------------------------------------------------------------------
# Instruction appended to every problem statement
# ---------------------------------------------------------------------------
UNCERTAINTY_PROMPT: str = (
    "\nPlease reason step by step. "
    "If confident based on reliable knowledge, provide a clear answer "
    "and box it with \\boxed{}.\n"
    "If the question lacks clarity, exceeds your knowledge, "
    "involves speculation, prediction, opinion, or any uncertainty, "
    "do not guess. "
    "State your limitation and output \\boxed{uncertainty}."
)

# ---------------------------------------------------------------------------
# Sampling settings
# ---------------------------------------------------------------------------
TEMPERATURE: float = 0.6
TOP_P: float = 1.0
# Upper bound on generated tokens per completion.
MAX_TOKENS: int = 32_768
# Number of completions sampled for each prompt.
N_SAMPLES: int = 3
# Cap on how many problems are taken from each task.
MAX_TEST: int = 100_000
# NOTE(review): not referenced anywhere in the visible code.
SAVE: bool = True


def predict(model_name: str) -> str:
    """Generate completions for the math evaluation tasks and save them.

    For every split of the dataset stored at ``DATASET_NAME`` whose name is
    listed in ``TASKS``, each problem gets ``UNCERTAINTY_PROMPT`` appended, is
    rendered through the tokenizer's chat template, and ``N_SAMPLES``
    completions are sampled with vLLM. Only the raw generations are written
    out; no scoring happens in this function.

    Args:
        model_name: Model identifier or path, passed to both ``vllm.LLM`` and
            ``AutoTokenizer.from_pretrained``. Trailing ``/`` characters are
            ignored when deriving the output filename.

    Returns:
        Path of the JSON file that was written. The file holds one record per
        problem with the keys ``task_name``, ``prompt``, ``gt``,
        ``model_output`` and ``output_lengths``.
    """
    sampling_params = vllm.SamplingParams(
        n=N_SAMPLES,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_tokens=MAX_TOKENS,
        # Clock-derived seed, so separate invocations use different seeds.
        seed=time.time_ns(),
    )

    model = vllm.LLM(
        model_name,
        gpu_memory_utilization=0.9,
        tensor_parallel_size=num_gpus,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def apply_template(question: str) -> str:
        """Render a single-turn user message for *question* as prompt text."""
        messages = [
            {
                "content": question + UNCERTAINTY_PROMPT,
                "role": "user",
            }
        ]
        # tokenize=False returns the rendered string instead of token ids.
        # NOTE(review): enable_thinking is forwarded to the chat template;
        # presumably it toggles the model's reasoning mode -- confirm against
        # the tokenizer's template.
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )

    to_be_saved: List[Dict[str, Any]] = []
    print("Starting inference phase...")

    for task_name, dataset in load_from_disk(DATASET_NAME).items():
        if task_name not in TASKS:
            continue

        # Problems and reference answers are truncated identically so they
        # stay aligned position by position.
        questions = dataset["problem"][:MAX_TEST]
        targets = dataset["answer"][:MAX_TEST]

        prompts = [apply_template(question) for question in questions]
        print(f"Inference for {task_name}")

        outputs = model.generate(prompts, sampling_params)

        for output, target in zip(outputs, targets):
            to_be_saved.append({
                "task_name": task_name,
                "prompt": output.prompt,
                # One copy of the reference answer per sampled completion, so
                # "gt" has the same length as "model_output".
                "gt": [target] * sampling_params.n,
                "model_output": [o.text for o in output.outputs],
                "output_lengths": [len(o.token_ids) for o in output.outputs],
            })

    # Save raw generations only (no scoring).
    # Strip trailing slashes first: a path such as "ckpt/my-model/" would
    # otherwise yield an empty suffix, making different models share the
    # same output filename.
    model_name_suffix = model_name.rstrip("/").split("/")[-1]
    timestamp = datetime.now().strftime("%Y%m%d")
    filename = f"model_eval_out_{model_name_suffix}_{timestamp}"
    filepath = (
        f"{DIR_PATH}/results/{str(TASKS)}/"
        f"{filename}_template_{TEMPLATE}_temp{TEMPERATURE}_topp{TOP_P}_n{N_SAMPLES}.json"
    )

    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    print(f"Saving raw model outputs at {filepath}")

    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(to_be_saved, f, indent=4, ensure_ascii=False)

    return filepath

def main():
    """Parse the command line, run inference, then compute metrics.

    The only option is the required ``--model_name``. Its value is passed to
    ``predict``, and the path that ``predict`` returns is handed to
    ``cal_metrics_whole`` with ``is_base=True``.
    """
    cli = argparse.ArgumentParser(description="Run Qwen3 math evaluation")
    cli.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Path to the model",
    )
    model_name = cli.parse_args().model_name

    cal_metrics_whole(predict(model_name), is_base=True)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
