import asyncio
import json
from typing import Literal

from openai import AsyncOpenAI, OpenAI
from pydantic import BaseModel

# Base URL of the OpenAI-compatible endpoint that serves the judge model.
BASE_URL = "http://localhost:8005/v1"
# Model name sent with every grading request in this file.
EVAL_MODEL_NAME = "refinement_sft"

# Blocking and asyncio clients bound to the same endpoint. The API key is the
# literal string "EMPTY" (presumably the local server does not validate it —
# confirm against the deployment).
client = OpenAI(base_url=BASE_URL, api_key="EMPTY")
async_client = AsyncOpenAI(base_url=BASE_URL, api_key="EMPTY")


# Structured verdict expected from the judge model. The JSON schema generated
# from this model is sent to the server as `guided_json` by the graders below.
# NOTE: documented with `#` comments instead of a class docstring on purpose:
# pydantic copies a class docstring into the generated schema as its
# "description", which would change the payload sent to the server.
class EvaluationResponse(BaseModel):
    # Free-text justification for the verdict.
    reasoning: str
    # "1" if the first story wins, "2" if the second does (per the system prompt).
    final_answer: Literal["1", "2"]


async def get_qwen3_response(prompt):
    """Ask the judge model to compare two stories and return its raw reply.

    Sends ``prompt`` as the user message of a chat completion, with output
    constrained to the ``EvaluationResponse`` JSON schema via ``guided_json``
    and ``enable_thinking`` switched off in the chat template kwargs.

    Args:
        prompt: Text of the user message (the comparison task for the judge).

    Returns:
        The content of the first choice's message, or ``""`` when that content
        is empty or ``None``. If the request raises, a JSON object string with
        ``"reasoning"`` set to ``"Error: <message>"`` and ``"final_answer"``
        set to the sentinel ``"Invalid"``.
    """
    system_prompt = "You are an expert creative writing judge. Be very critical with your evaluation and always end your evaluation with either 1 or 2 depending on whether the first or second story is better aligned with the criterion."

    try:
        chat_response = await async_client.chat.completions.create(
            model=EVAL_MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            max_completion_tokens=256,
            temperature=0.3,
            top_p=0.95,
            presence_penalty=1.5,
            extra_body={
                "top_k": 20,
                "guided_json": EvaluationResponse.model_json_schema(),
                "chat_template_kwargs": {"enable_thinking": False},
            },
        )
        # Normalise a missing/empty message body to an empty string.
        return chat_response.choices[0].message.content or ""
    except Exception as e:
        # Deliberate best-effort: callers get a sentinel payload, not an exception.
        print(f"Error calling vLLM model: {e}")
        # json.dumps escapes quotes, backslashes and newlines in the error
        # text, so the fallback is always parseable JSON (a hand-built
        # f-string is not when the message contains such characters).
        return json.dumps(
            {"reasoning": f"Error: {e}", "final_answer": "Invalid"},
            ensure_ascii=False,
        )


def qwen3_grader(prompts) -> list[str]:

    async def process_prompts():
        return await asyncio.gather(*[get_qwen3_response(prompt) for prompt in prompts])

    result_text_list = asyncio.run(process_prompts())

    return result_text_list


def qwen25_grader(prompts) -> list[str]:
    """Grade a batch of prompts with one blocking text-completion request.

    The system prompt is prepended to each prompt (separated by a newline)
    and all prompts are sent together through the completions endpoint, with
    output constrained to the ``EvaluationResponse`` JSON schema.

    Args:
        prompts: Iterable of comparison-task texts to grade.

    Returns:
        One stripped completion text per prompt, ordered like ``prompts``.
        An empty batch returns ``[]`` without contacting the server. If the
        request raises, every entry is the same JSON object string with
        ``"reasoning"`` set to ``"Error: <message>"`` and ``"final_answer"``
        set to the sentinel ``"Invalid"``.
    """
    system_prompt = "You are an expert creative writing judge. Be very critical with your evaluation and always end your evaluation with either 1 or 2 depending on whether the first or second story is better aligned with the criterion."

    full_prompts = [f"{system_prompt}\n{prompt}" for prompt in prompts]
    if not full_prompts:
        # Nothing to grade: skip the round-trip (and the error log an empty
        # batch would otherwise trigger).
        return []

    try:
        response = client.completions.create(
            model=EVAL_MODEL_NAME,
            prompt=full_prompts,
            max_tokens=256,
            temperature=0.3,
            extra_body={"guided_json": EvaluationResponse.model_json_schema()},
        )

        # Batched completions are matched back to their prompts through each
        # choice's `index`; arrival order is not relied upon.
        ordered_choices = sorted(response.choices, key=lambda choice: choice.index)
        return [choice.text.strip() for choice in ordered_choices]

    except Exception as e:
        # Deliberate best-effort: callers get sentinel payloads, not an exception.
        print(f"Error calling vLLM model: {e}")
        # json.dumps escapes quotes, backslashes and newlines in the error
        # text, so the fallback is always parseable JSON.
        error_payload = json.dumps(
            {"reasoning": f"Error: {e}", "final_answer": "Invalid"},
            ensure_ascii=False,
        )
        return [error_payload] * len(full_prompts)
