# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import json
from time import sleep

import aiohttp

from verl.utils.reward_score.math_reward import last_boxed_only_string, remove_boxed

# Base address of the HTTP server; post_request() appends "/<endpoint>" to it.
BASE_URL = "http://localhost:30000"
# NOTE(review): not referenced anywhere in the visible part of this file.
API_KEY = "EMPTY"
# Total number of attempts get_response() makes before giving up.
MAX_RETRIES = 3
# Base back-off in seconds; get_response() waits BASE_DELAY * 2**attempt between attempts.
BASE_DELAY = 2
# NOTE(review): not referenced anywhere in the visible part of this file.
MAX_WORKERS = 32
# Value sent as the "model" field of the chat-completions request.
MODEL_NAME = "genrm-demo"
# Prompt sent to the model; filled via str.format(problem=..., solution=...).
# "{{}}" is an escaped brace pair, so the formatted prompt contains a literal
# "\boxed{}". compute_reward() reads the verdict back from the last \boxed{...}
# in the model's reply and compares it with the string "True".
GENRM_PROMPT_TEMPLATE = """
The following is a math problem and an AI solution:

[Math Problem]

{problem}

[AI Solution]

{solution}

Your task is to review and critique the solution step by step, and output whether the AI solution is correct.

Please put your final answer (i.e., 'True' or 'False') in \\boxed{{}}.
""".strip()


async def post_request(payload, endpoint):
    """POST ``payload`` as JSON to ``{BASE_URL}/{endpoint}`` and return the parsed reply.

    Args:
        payload: JSON-serializable request body.
        endpoint: Path appended to ``BASE_URL`` after a ``/`` separator.

    Returns:
        The response body decoded with ``json.loads``.

    Raises:
        Exception: Whatever aiohttp raises while connecting, sending or
            reading, and ``json.JSONDecodeError`` if the body is not valid
            JSON. Errors propagate unchanged so the caller can retry.
    """
    url = f"{BASE_URL}/{endpoint}"
    # total=None: no overall time limit is imposed on the request.
    timeout = aiohttp.ClientTimeout(total=None)
    # ``async with`` closes the session on every exit path and, unlike a
    # try/finally around a local variable, cannot hit an unbound ``session``
    # when construction itself fails.
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(url, json=payload) as resp:
            body = await resp.text()
            return json.loads(body)

async def get_response(problem, solution_str, ground_truth):
    """Ask the generative reward model to critique ``solution_str`` for ``problem``.

    The request is attempted up to ``MAX_RETRIES`` times with exponential
    back-off (``BASE_DELAY * 2**attempt`` seconds) between attempts.

    Args:
        problem: Problem text inserted into the prompt template.
        solution_str: Candidate solution inserted into the prompt template.
        ground_truth: Unused here; kept for interface compatibility.

    Returns:
        The ``content`` field of the first choice in the model's reply.

    Raises:
        ConnectionRefusedError: If every attempt failed. The last underlying
            exception is attached as ``__cause__``.
    """
    prompt = GENRM_PROMPT_TEMPLATE.format(problem=problem, solution=solution_str)
    messages = [{"role": "user", "content": prompt}]
    # The payload is identical for every attempt, so build it once.
    data = {"model": MODEL_NAME, "messages": messages}
    last_error = None
    for attempt in range(MAX_RETRIES):
        try:
            output = await post_request(data, "v1/chat/completions")
            return output["choices"][0]["message"]["content"]
        except Exception as e:
            last_error = e
            if attempt < MAX_RETRIES - 1:
                print("Exception: ", repr(e))
                delay = BASE_DELAY * (2**attempt)
                print(f"Retrying in {delay} seconds...")
                # Non-blocking wait: time.sleep() here would stall the event
                # loop and every other coroutine scheduled on it.
                await asyncio.sleep(delay)
            else:
                print(f"Failed after {MAX_RETRIES} attempts. Error: {e}")

    raise ConnectionRefusedError(f"Failed to run the model for {prompt}!") from last_error


def compute_reward(response):
    """Convert the reward model's textual verdict into a numeric reward.

    The verdict is taken from the last ``\\boxed{...}`` expression found in
    ``response``. The reward is ``1.0`` only when that boxed content is
    exactly the string ``"True"``. A missing box, any other content, or an
    error during extraction yields ``0.0`` (errors are printed, not raised).

    Args:
        response: Text returned by the reward model.

    Returns:
        ``1.0`` or ``0.0``.
    """
    try:
        boxed = last_boxed_only_string(response)
        if boxed is None:
            return 0.0
        verdict = remove_boxed(boxed)
        return float(verdict == "True")
    except Exception as err:
        # Best effort: an unparsable reply counts as an incorrect verdict.
        print(err)
        return 0.0


async def compute_score(data_source, solution_str, ground_truth, extra_info):
    """Score a solution: rule-based on the test split, GenRM-based otherwise.

    Args:
        data_source: Dataset identifier forwarded to ``default_compute_score``.
        solution_str: Candidate solution to score.
        ground_truth: Reference answer forwarded to the scorers.
        extra_info: Mapping that must contain ``"split"``; for any split other
            than ``"test"`` it must also contain ``"question"``.

    Returns:
        For ``split == "test"``, whatever ``default_compute_score`` returns.
        Otherwise the ``0.0``/``1.0`` reward derived from the reward model's
        reply (``0.0`` when the reply is ``None``).

    Raises:
        KeyError: If a required key is missing from ``extra_info``.
        ConnectionRefusedError: Propagated from ``get_response`` when the
            reward model cannot be reached.
    """
    if extra_info["split"] == "test":
        # Function-scope import, as in the original code. It is only needed on
        # this path, and the rule-based score is only computed here, so the
        # GenRM path neither pays for it nor fails because of it.
        from verl.utils.reward_score import default_compute_score

        return default_compute_score(data_source, solution_str, ground_truth, extra_info)

    problem = extra_info["question"]
    response = await get_response(problem, solution_str, ground_truth)
    if response is None:
        return 0.0
    return compute_reward(response)
