# Copyright 2025 CollabLLM team and/or its affiliates
# Copyright 2025 Bytedance Ltd. and/or its affiliates

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from recipe.collabllm.utils import extract_json, parse_messages

# Prompt template for the LLM judge that grades an AI doctor's conversation.
# compute_score() fills it with str.format(), so the three placeholders
# {ground_truth}, {recommendation} and {chat_history} must stay in sync with
# that call, and the literal braces of the JSON example are doubled ({{ }}).
# A backslash at the end of a line inside the literal is a line continuation:
# the backslash and the newline are dropped from the resulting string.
# The judge is asked for a JSON object with "thought" and "score", where the
# score is one of 0, 1, 2, 8, 9, 10; compute_score() divides that score by 10.
DIAGNOSIS_PROMPT = '''You are a highly professional and meticulous Medical Evaluator and LLM Judge. \
Your task is to rigorously evaluate the performance of an AI Doctor model's performance. You will compare the AI's Final Diagnosis and Recommendations against the Ground Truth.

Provided Information:
<|The Start of Ground Truth Diagnosis|>
Ground Truth Diagnosis: {ground_truth}
<|The End of Ground Truth Diagnosis|>

<|The Start of Ground Truth Recommendation|>
Ground Truth Recommendation: {recommendation}
<|The End of Ground Truth Recommendation|>

<|The Start of The Conversation|>
{chat_history}
<|The End of The Conversation|>

## Evaluation Logic
You must follow a strict two-step evaluation process:
**Step 1: Diagnosis Verification**
-   Identify the final diagnosis provided by the AI Doctor in the chat.
-   Compare it with the <|Ground Truth Diagnosis|>.
-   **Rule:** If the diagnosis is WRONG or the AI doctor refuese/fails to provide a diagnosis, the Total Score is **0**. You stop here.

**Step 2: Recommendation Quality**
-   **ONLY** if the diagnosis is CORRECT, proceed to score the recommendations.
-   Compare the AI's advice with <|Ground Truth Recommendation|>.
-   **Scoring Rubric:**
    * **8-10 (Perfect):** The AI's recommendation is logically **identical** to the Ground Truth.
            *Condition:* It contains ALL key points from GT AND contains **ZERO** extra/unnecessary advice.
    * **1-2 (Dangerous/Wrong):** Diagnosis was right, BUT the advice is bloated, vague, or messy.
            *Condition:* (Any ONE of these triggers this bucket):
            **Bloat:** The correct advice is hidden inside a long list of general wellness tips (diet, sleep, water) that were not in the GT.
            **Vagueness or Drifting:** The AI gives advice for a different aspect of the disease not mentioned in the GT.
    * **0 (Failure):** Diagnosis is incorrect.
    **You MUST assign a score in 0, 1, 2, 8, 9, or 10. Scores 3, 4, 5, 6, 7 are FORBIDDEN.**

Output format (JSON):
{{
    "thought": "Extracted Diagnosis, Extracted Recommendation and Reasoning: (Briefly explain why the diagnosis is Right/Wrong and justify the score for recommendations)",
    "score": <MUST be 0, 1, 2, 8, 9, or 10. DO NOT output 3-7.>
}}

Double check if the JSON object is formatted correctly. Ensure that all fields are present and properly structured. \
Use " or """ to wrap up the thought and use single quotes inside the "thought" field to avoid JSON escape issues.

Your evaluation:
'''


async def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):
    """Grade a conversation with an LLM judge and return a normalized score.

    The conversation is rendered with ``parse_messages``, substituted into
    ``DIAGNOSIS_PROMPT`` together with the ground-truth diagnosis and
    recommendation, and sent as a single user message to a chat-completion
    backend. ``litellm`` is used when it can be imported; otherwise the
    ``openai`` async client is used. The judge's reply is parsed with
    ``extract_json``.

    Args:
        data_source: Not used by this function.
        messages: Conversation to evaluate; passed to ``parse_messages`` with
            ``strip_sys_prompt=True``.
        ground_truth: Mapping providing the ``"diagnosis"`` and
            ``"recommendation"`` entries that are inserted into the prompt.
        extra_info: Not used by this function.
        **kwargs: Forwarded unchanged to the completion call
            (``litellm.acompletion`` or ``chat.completions.create``).

    Returns:
        The judge's ``"score"`` value converted to ``float`` and divided by 10.

    Raises:
        ImportError: If neither ``litellm`` nor ``openai`` can be imported.
        KeyError: If ``ground_truth`` lacks ``"diagnosis"`` or
            ``"recommendation"``.
        AssertionError: If the parsed judge reply is not a dict or is missing
            the ``"score"`` or ``"thought"`` key.
    """
    # Prefer litellm; fall back to the openai SDK when it is not installed.
    try:
        import litellm

        use_litellm = True
    except ImportError:
        import openai

        use_litellm = False

    chat_history = parse_messages(messages, strip_sys_prompt=True)
    prompt = DIAGNOSIS_PROMPT.format(
        ground_truth=ground_truth["diagnosis"],
        recommendation=ground_truth["recommendation"],
        chat_history=chat_history,
    )
    judge_messages = [{"role": "user", "content": prompt}]

    if use_litellm:
        completion = await litellm.acompletion(messages=judge_messages, **kwargs)
    else:
        # The client is created per call (API key is read from the
        # environment), so scope it with ``async with`` to release its HTTP
        # resources as soon as the request finishes.
        async with openai.AsyncOpenAI() as client:
            completion = await client.chat.completions.create(
                messages=judge_messages,
                **kwargs,
            )

    verdict = extract_json(completion.choices[0].message.content)

    # Raise explicitly instead of using ``assert`` statements: asserts are
    # removed under ``python -O``, which would let malformed judge output
    # through. AssertionError is kept so existing handlers still match.
    if not isinstance(verdict, dict):
        raise AssertionError(f"Expected a dict, got {type(verdict)}")
    if not {"score", "thought"}.issubset(verdict.keys()):
        raise AssertionError(f"Expected keys not found from {verdict.keys()}")

    # The prompt asks for a score on a 0-10 scale; map it to a 0-1 scale.
    return float(verdict["score"]) / 10
