import json
import os
import re
from typing import Any, Dict, List, Union

from swift.llm import InferRequest


class PRM:
    """Base interface for process reward models.

    Concrete reward models override ``__call__``; this base implementation
    is only a placeholder and always raises.
    """

    def __call__(self, **kwargs) -> List[Any]:
        """Score a batch; must be provided by subclasses.

        Args:
            **kwargs: Keyword arguments; unused by the base class.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()


# System prompt used as the "system" message of every grading request built in
# this module. It instructs the grader model to end its reply with
# ``**Reward: <value>**``; the PRM classes below look for the ``Reward:`` marker
# when extracting the score from the reply.
# NOTE(review): "appare" in rule 1 is a typo for "appear". It is left untouched
# here because this text is sent to the model at runtime.
SYSTEM = """
You are a process reward model, give the reward value of the answer, you must follow the instructions below:

1. Output a float reward value between -1.0 and 1.0, -1.0 means the worst answer, 1.0 means the best answer, please think step by step to give your reasons and thoughts, but the reward must appare at the end with this format: **Reward: your-reward-value**.

2. The answer may be incomplete, you must give the reward by the existing part of the answer, taking into account semantic coherence, logical correctness, and clarity.

3. A ground truth answer will be given to you, it may be not the best one, consider it as a reference example.

Begin!
"""  # noqa

# Template for the "user" message of a grading request. The PRM classes below
# fill in three placeholders:
#   #query#         - the JSON-encoded conversation preceding the answer
#   #ground_truth#  - the reference answer
#   #response#      - the content of the assistant message being graded
QUERY = """
The original question or the previous conversation:

#query#

Here is the ground truth as the reference:

#ground_truth#

Given the upper information, give your reward(-1.0~1.0) of the following answer:

#response#
"""


class QwenMaxPRM(PRM):
    """Process reward model that grades answers with ``qwen-max`` on DashScope.

    Every request is turned into a grading prompt (``SYSTEM`` + ``QUERY``) and
    sent, one at a time, through DashScope's OpenAI-compatible endpoint. The
    reward is read from the ``Reward: <value>`` marker in the reply.
    """

    # ``Reward:`` followed by a number, tolerating whitespace and the markdown
    # ``*`` emphasis that the prompt asks the model to wrap the value in.
    _REWARD_RE = re.compile(r"Reward:[\s*]*([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)")
    # The three placeholders of the ``QUERY`` template.
    _PLACEHOLDER_RE = re.compile(r"#(query|ground_truth|response)#")

    @staticmethod
    def _get_messages(request: Union[InferRequest, Dict]) -> List[Dict[str, Any]]:
        """Return the message list of a request given as a dict or an object."""
        if isinstance(request, dict):
            return request["messages"]
        return request.messages

    @classmethod
    def _build_messages(cls, request: Union[InferRequest, Dict], ground_truth: str) -> List[Dict[str, str]]:
        """Build the system/user messages that ask the grader to score one answer."""
        messages = cls._get_messages(request)
        answer = messages[-1]
        assert answer["role"] == "assistant", "the last message must be the assistant answer to grade"

        previous = messages[:-1]
        # ``previous`` is empty when the conversation holds only the answer.
        if previous and previous[0]["role"] == "system":
            previous = previous[1:]

        values = {
            # Keep non-ASCII text readable for the grader instead of \uXXXX escapes.
            "query": json.dumps(previous, ensure_ascii=False),
            "ground_truth": ground_truth,
            "response": answer["content"],
        }
        # Single-pass substitution: inserted text is never re-scanned, so a
        # ground truth or query that happens to contain "#response#" stays intact.
        query = cls._PLACEHOLDER_RE.sub(lambda match: values[match.group(1)], QUERY)
        return [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": query},
        ]

    @classmethod
    def _parse_reward(cls, content: Any) -> float:
        """Extract the reward from a grader reply, falling back to ``0.0``.

        The prompt requires the reward to be the last thing in the reply, so
        when several ``Reward: <number>`` markers are present the last one wins.
        Text after the number (closing ``**``, punctuation) is ignored.
        """
        if not isinstance(content, str):
            # The API may return no text content at all.
            return 0.0
        found = cls._REWARD_RE.findall(content)
        return float(found[-1]) if found else 0.0

    def __call__(
        self,
        infer_requests: List[Union[InferRequest, Dict]],
        ground_truths: List[str],
        **kwargs
    ) -> List[float]:
        """Grade each request's final assistant message against its ground truth.

        Args:
            infer_requests: Requests (objects with ``messages`` or dicts with a
                ``"messages"`` key) whose last message is the assistant answer.
            ground_truths: Reference answers, paired with ``infer_requests``
                by position.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            One reward per (request, ground truth) pair; ``0.0`` when no reward
            could be parsed from the grader's reply.

        Raises:
            AssertionError: If a request's last message is not from the assistant.
        """
        # TODO: check request_config
        from openai import OpenAI

        client = OpenAI(
            api_key=os.getenv("DASHSCOPE_API_KEY"),
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
        )

        rewards = []
        for request, ground_truth in zip(infer_requests, ground_truths):
            completion = client.chat.completions.create(
                model="qwen-max",
                messages=self._build_messages(request, ground_truth),
            )
            rewards.append(self._parse_reward(completion.choices[0].message.content))
        return rewards


class ClientPRM(PRM):
    """Process reward model that grades answers through a swift ``InferClient``.

    All grading prompts (``SYSTEM`` + ``QUERY``) of a batch are submitted in one
    ``infer`` call; each reward is read from the ``Reward: <value>`` marker in
    the corresponding reply.
    """

    # ``Reward:`` followed by a number, tolerating whitespace and the markdown
    # ``*`` emphasis that the prompt asks the model to wrap the value in.
    _REWARD_RE = re.compile(r"Reward:[\s*]*([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)")
    # The three placeholders of the ``QUERY`` template.
    _PLACEHOLDER_RE = re.compile(r"#(query|ground_truth|response)#")

    def __init__(self, api_key=None, base_url=None, model=None):
        """Create the inference client.

        Args:
            api_key: API key; defaults to the ``DASHSCOPE_API_KEY`` environment
                variable.
            base_url: Endpoint URL; defaults to DashScope's compatible-mode URL.
            model: Grader model name; defaults to ``"qwen-plus"``.
        """
        from swift.llm import InferClient

        if api_key is None:
            api_key = os.getenv("DASHSCOPE_API_KEY")
        if base_url is None:
            base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
        if model is None:
            model = "qwen-plus"
        self.infer_engine = InferClient(base_url=base_url, api_key=api_key)
        # NOTE(review): presumably relaxes error handling in the engine - confirm
        # against the InferClient implementation.
        self.infer_engine.strict = False
        self.infer_kwargs = {
            "model": model,
        }

    @staticmethod
    def _get_messages(request: Union[InferRequest, Dict]) -> List[Dict[str, Any]]:
        """Return the message list of a request given as a dict or an object."""
        if isinstance(request, dict):
            return request["messages"]
        return request.messages

    @classmethod
    def _build_messages(cls, request: Union[InferRequest, Dict], ground_truth: str) -> List[Dict[str, str]]:
        """Build the system/user messages that ask the grader to score one answer."""
        messages = cls._get_messages(request)
        answer = messages[-1]
        assert answer["role"] == "assistant", "the last message must be the assistant answer to grade"

        previous = messages[:-1]
        # ``previous`` is empty when the conversation holds only the answer.
        if previous and previous[0]["role"] == "system":
            previous = previous[1:]

        values = {
            # Keep non-ASCII text readable for the grader instead of \uXXXX escapes.
            "query": json.dumps(previous, ensure_ascii=False),
            "ground_truth": ground_truth,
            "response": answer["content"],
        }
        # Single-pass substitution: inserted text is never re-scanned, so a
        # ground truth or query that happens to contain "#response#" stays intact.
        query = cls._PLACEHOLDER_RE.sub(lambda match: values[match.group(1)], QUERY)
        return [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": query},
        ]

    @classmethod
    def _parse_reward(cls, content: Any) -> float:
        """Extract the reward from a grader reply, falling back to ``0.0``.

        The prompt requires the reward to be the last thing in the reply, so
        when several ``Reward: <number>`` markers are present the last one wins.
        Text after the number (closing ``**``, punctuation) is ignored.
        """
        if not isinstance(content, str):
            # The engine may return no text content at all.
            return 0.0
        found = cls._REWARD_RE.findall(content)
        return float(found[-1]) if found else 0.0

    def __call__(
        self,
        infer_requests: List[Union[InferRequest, Dict]],
        ground_truths: List[str],
        **kwargs
    ) -> List[float]:
        """Grade each request's final assistant message against its ground truth.

        Args:
            infer_requests: Requests (dicts with a ``"messages"`` key or objects
                with ``messages``) whose last message is the assistant answer.
            ground_truths: Reference answers, paired with ``infer_requests``
                by position.
            **kwargs: ``request_config`` is forwarded to the inference engine
                if present; other keys are ignored.

        Returns:
            One reward per response returned by the engine; ``0.0`` when no
            reward could be parsed from a reply.

        Raises:
            AssertionError: If a request's last message is not from the assistant.
        """
        request_config = kwargs.get("request_config")
        prm_infer_requests = [
            InferRequest(messages=self._build_messages(request, ground_truth))
            for request, ground_truth in zip(infer_requests, ground_truths)
        ]

        responses = self.infer_engine.infer(
            prm_infer_requests, request_config=request_config, **self.infer_kwargs
        )
        return [self._parse_reward(response.choices[0].message.content) for response in responses]


# Registry mapping a short name to the PRM class it selects.
prms: Dict[str, type] = dict(
    qwen_max=QwenMaxPRM,
    client=ClientPRM,
)
