# Copyright 2025 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import time
from datetime import datetime
from typing import Any, Dict, List

import numpy as np
import torch
import vllm
from datasets import load_from_disk
from transformers import AutoTokenizer

from metrics import cal_metrics_whole_mc

# Number of visible CUDA devices, queried once at import time. predict() passes it
# to vllm.LLM as tensor_parallel_size, i.e. the model is split across all of them.
num_gpus: int = torch.cuda.device_count()


def load_json(file_path: str) -> Any:
    """Load and return the parsed contents of a UTF-8 encoded JSON file.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The decoded JSON value. It is not necessarily a dict: the evaluation
        datasets read by this script are top-level lists of records
        (``predict`` iterates them and reads ``item['question']`` and
        ``item['answer']``).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the content is not valid UTF-8 / JSON
            (``UnicodeDecodeError`` or ``json.JSONDecodeError``).
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Configuration constants
DIR_PATH: str = ".."
# Tasks to evaluate; each must also have an entry in TEST_DATA.
TASKS: List[str] = ["gpqa_diamond", "mmlu_redux2"]
# Tag embedded in the output filename.
TEMPLATE: str = "uncertainty"

# Dataset file for each task. Only entries whose key appears in TASKS are evaluated.
TEST_DATA: Dict[str, str] = {
    "gpqa_diamond": f"{DIR_PATH}/data/general/gpqa_diamond.json",
    "mmlu_redux2": f"{DIR_PATH}/data/general/mmlu_redux2_processed_test.json",
}

# Prompt configuration: instruction appended to every question. The model is asked
# either to answer with a choice letter or to abstain with "answer: uncertainty".
# NOTE: the trailing space after "answer: C." keeps the two sentences separated;
# adjacent string literals are concatenated without any separator.
UNCERTAINTY_PROMPT: str = (
    "\nPlease reason step by step. If you are confident based on reliable knowledge, "
    "only output the choice letter in the answer field, e.g., answer: C. "
    "If the question lacks clarity, exceeds your knowledge, involves speculation, "
    "prediction, opinion, or any uncertainty, do not guess. Instead, state your "
    "limitation and output uncertainty, e.g., answer: uncertainty."
)

# Generation parameters (passed to vllm.SamplingParams in predict)
TEMPERATURE: float = 0.6
TOP_P: float = 1.0
MAX_TOKENS: int = 32768  # generation budget per sampled completion
N_SAMPLES: int = 3  # completions drawn per question
MAX_TEST: int = 100000  # cap on the number of questions used per task
SAVE: bool = True  # NOTE(review): not referenced by predict(), which always saves


def _model_suffix(model_name: str) -> str:
    """Return the last path component of *model_name*, ignoring a trailing slash."""
    # normpath strips a trailing separator so "ckpt/model/" -> "model" rather than "".
    return os.path.basename(os.path.normpath(model_name))


def _output_path(model_name: str) -> str:
    """Build the JSON path under which the raw generations for *model_name* are saved."""
    timestamp = datetime.now().strftime("%Y%m%d")
    filename = f"model_eval_out_{_model_suffix(model_name)}_{timestamp}"
    # NOTE: the directory name is the repr of TASKS (e.g. "['gpqa_diamond', 'mmlu_redux2']");
    # kept unchanged so result locations stay compatible with existing runs.
    return (
        f"{DIR_PATH}/results/{str(TASKS)}/"
        f"{filename}_template_{TEMPLATE}_temp{TEMPERATURE}_topp{TOP_P}_n{N_SAMPLES}.json"
    )


def predict(model_name: str) -> str:
    """Generate sampled answers for every configured multiple-choice task.

    Loads ``model_name`` with vLLM (tensor-parallel over ``num_gpus`` devices).
    For each task in ``TASKS`` it appends ``UNCERTAINTY_PROMPT`` to every
    question, renders the result through the tokenizer's chat template, and
    draws ``N_SAMPLES`` completions per question. The raw generations are
    written to a JSON file without any scoring.

    Args:
        model_name: Model id or local checkpoint path, passed to both
            ``vllm.LLM`` and ``AutoTokenizer.from_pretrained``.

    Returns:
        Path of the JSON file that was written. It holds a list of records with
        the keys ``task_name``, ``prompt``, ``gt`` (the ground-truth answer
        repeated once per sample), ``model_output`` (one text per sample) and
        ``output_lengths`` (generated token count per sample).
    """
    sampling_params = vllm.SamplingParams(
        n=N_SAMPLES,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_tokens=MAX_TOKENS,
        # Time-based seed: each run draws a different set of samples.
        seed=int(time.time_ns()),
    )

    model = vllm.LLM(
        model_name,
        gpu_memory_utilization=0.8,
        tensor_parallel_size=num_gpus,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def apply_template(question: str) -> str:
        """Render one user turn (question + uncertainty instruction) as a generation prompt."""
        messages = [
            {
                "content": question + UNCERTAINTY_PROMPT,
                "role": "user",
            }
        ]
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            # Extra chat-template variable; presumably switches on the "thinking"
            # mode of Qwen3-style templates -- confirm for other model families.
            enable_thinking=True,
        )

    # Surface misconfiguration instead of silently evaluating fewer tasks.
    missing = [task for task in TASKS if task not in TEST_DATA]
    if missing:
        print(f"Warning: no test data configured for tasks {missing}; skipping them.")

    to_be_saved: List[Dict[str, Any]] = []
    print("Starting inference phase...")

    for task_name, dataset_path in TEST_DATA.items():
        if task_name not in TASKS:
            continue

        data = load_json(dataset_path)[:MAX_TEST]
        prompts = [apply_template(item['question']) for item in data]
        targets = [item['answer'] for item in data]
        print(f"Inference for {task_name}")

        outputs = model.generate(prompts, sampling_params)

        # vLLM returns one RequestOutput per prompt, in the same order as the prompts.
        for output, target in zip(outputs, targets):
            to_be_saved.append({
                "task_name": task_name,
                "prompt": output.prompt,
                "gt": [target] * sampling_params.n,
                "model_output": [o.text for o in output.outputs],
                "output_lengths": [len(o.token_ids) for o in output.outputs],
            })

    # Save raw generations only (no scoring)
    filepath = _output_path(model_name)
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    print(f"Saving raw model outputs at {filepath}")

    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(to_be_saved, f, indent=4, ensure_ascii=False)

    return filepath

def main() -> None:
    """Command-line entry point: generate raw outputs for a model, then score them.

    Parses ``--model_name``, runs ``predict`` to write the generations to a
    JSON file, and passes that file's path to ``cal_metrics_whole_mc``.
    """
    parser = argparse.ArgumentParser(
        description="Run multiple-choice evaluation with an uncertainty (abstention) prompt"
    )
    parser.add_argument("--model_name", type=str, required=True, help="Path to the model")
    args = parser.parse_args()

    output_file = predict(args.model_name)
    # NOTE(review): the meaning of is_base is defined in metrics.cal_metrics_whole_mc -- confirm.
    cal_metrics_whole_mc(output_file, is_base=True)


if __name__ == "__main__":
    main()
