import json
import os
import re

import numpy as np
from openai import OpenAI
from tqdm import tqdm

# ---------------- Judge Config ----------------

# JUDGE_MODE selects the judge backend (every backend is reached through the
# OpenAI SDK client):
#   - "deepseek"    → DeepSeek API (OpenAI compatible); this is the default
#   - "gemini"      → Gemini API (OpenAI compatible)
#   - anything else → OpenAI gpt-4o (e.g. "gpt4o")
JUDGE_MODE = os.getenv("JUDGE_MODE", "deepseek")

# Backend settings come from the environment so no secret lives in the source.
# API keys have no default; make_openai_client rejects an unset key.
# Endpoint and model defaults are the vendors' documented values and can be
# overridden with the environment variable of the same name.
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
DEEPSEEK_BASE_URL = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GEMINI_BASE_URL = os.getenv(
    "GEMINI_BASE_URL",
    "https://generativelanguage.googleapis.com/v1beta/openai/",
)
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")

# Judge prompt template, filled in with str.format by build_prompts.
# Placeholders:
#   {context}  - reference summary the answer is judged against
#   {question} - the query that was posed to the assistant
#   {answer}   - the assistant's answer under evaluation
# The "<metric>: [[Rating]]" lines define the output format that
# extract_scores searches for, so the metric labels used here and in
# extract_scores must stay in sync.
JUDGE_TEMPLATE = """
[Instruction]
Act as an impartial judge to evaluate the quality of a biography generated by an AI assistant.
Focus on three metrics:
1) Factual Accuracy
2) Relevance and Recall
3) Coherence and Structure

Provide a brief initial assessment and then conclude
with final scoring in the format:

Therefore, the final scores of the output is:
(1) Factual Accuracy: [[Rating]];
(2) Relevance and Recall: [[Rating]];
(3) Coherence and Structure: [[Rating]]

Here is the reference Wikipedia summary:
{context}

[Question]
{question}

[The Start of Assistant's Answer]
{answer}
[The End of Assistant's Answer]

[Your Evaluation]
"""


# ---------------- Helper Functions ----------------

def read_json(file_path):
    """Parse the UTF-8 encoded JSON file at *file_path* and return its content."""
    with open(file_path, mode="r", encoding="utf8") as source:
        raw_text = source.read()
    return json.loads(raw_text)


def save_json(file_path, data):
    """Serialise *data* as JSON into *file_path*, replacing any existing file.

    The output is UTF-8, indented by four spaces, and keeps non-ASCII
    characters as-is instead of escaping them.
    """
    with open(file_path, mode="w", encoding="utf8") as sink:
        json.dump(data, sink, ensure_ascii=False, indent=4)


def clean_answer(text):
    """Return *text* truncated just before its first ``"\\n####"`` marker.

    Text that does not contain the marker is returned unchanged.
    """
    marker = "\n####"
    if marker not in text:
        return text
    cut_at = text.index(marker)
    return text[:cut_at]


def build_prompts(data, truths):
    """Build one judge conversation per reference entry.

    ``data[i]`` is paired with ``truths[i]``. Each prompt fills
    ``JUDGE_TEMPLATE`` with the truth's ``"gpt_response"`` as context, the
    item's ``"query"`` as question and the cleaned ``"final_response"`` as
    answer.

    Args:
        data: sequence of dicts with ``"query"`` and ``"final_response"``.
        truths: sequence of dicts with ``"question"`` and ``"gpt_response"``.

    Returns:
        A list with one entry per truth; each entry is a list of chat
        messages (system + user) ready for a chat-completion call.

    Raises:
        ValueError: if there are fewer data items than truths, or if a data
            item's query differs from the question of its paired truth.
    """
    # Explicit raises instead of assert: asserts vanish under `python -O`,
    # which would let misaligned files be judged against the wrong reference.
    if len(data) < len(truths):
        raise ValueError(
            f"Got {len(truths)} truths but only {len(data)} model outputs"
        )

    prompts = []
    # zip stops after the last truth, so surplus data items are ignored here.
    for index, (item, truth) in enumerate(zip(data, truths)):
        if item["query"] != truth["question"]:
            raise ValueError(
                f"Query and truth question mismatch at index {index}"
            )
        prompt_text = JUDGE_TEMPLATE.format(
            context=truth["gpt_response"],
            question=item["query"],
            answer=clean_answer(item["final_response"]),
        )
        # each prompt is one list of messages
        prompts.append([
            {"role": "system", "content": "You are a helpful judge."},
            {"role": "user", "content": prompt_text},
        ])
    return prompts


def make_openai_client():
    """Create the OpenAI SDK client for the backend selected by ``JUDGE_MODE``.

    Reads the module-level constants ``<BACKEND>_API_KEY`` and
    ``<BACKEND>_BASE_URL`` of the selected backend only: ``DEEPSEEK_*`` for
    ``"deepseek"``, ``GEMINI_*`` for ``"gemini"`` and ``OPENAI_*`` for any
    other mode.

    Returns:
        An ``OpenAI`` client bound to the selected key and base URL.

    Raises:
        RuntimeError: if the selected backend's API key or base URL is
            unset or empty.
    """
    if JUDGE_MODE == "deepseek":
        api_key, base_url = DEEPSEEK_API_KEY, DEEPSEEK_BASE_URL
        problem = "DeepSeek requires DEEPSEEK_API_KEY and DEEPSEEK_BASE_URL set"
    elif JUDGE_MODE == "gemini":
        api_key, base_url = GEMINI_API_KEY, GEMINI_BASE_URL
        problem = "Gemini requires GEMINI_API_KEY and GEMINI_BASE_URL set"
    else:
        api_key, base_url = OPENAI_API_KEY, OPENAI_BASE_URL
        problem = "OpenAI key/base_url not found"

    # Explicit raise instead of assert: asserts are stripped under
    # `python -O`, which would build a client without credentials and only
    # fail later, inside the first API call.
    if not (api_key and base_url):
        raise RuntimeError(problem)
    return OpenAI(api_key=api_key, base_url=base_url)


def call_judge_api(client, messages):
    """Run one non-streaming chat completion against the judge backend.

    The model is ``DEEPSEEK_MODEL`` or ``GEMINI_MODEL`` when ``JUDGE_MODE``
    is ``"deepseek"`` or ``"gemini"``; every other mode uses ``"gpt-4o"``.

    Args:
        client: object exposing ``chat.completions.create`` (OpenAI SDK style).
        messages: list of chat messages forming one conversation.

    Returns:
        The ``content`` of the first choice's message.
    """
    # Start from the OpenAI default; the backend constants are only looked
    # up for the mode that is actually selected.
    model_name = "gpt-4o"
    if JUDGE_MODE == "gemini":
        model_name = GEMINI_MODEL
    elif JUDGE_MODE == "deepseek":
        model_name = DEEPSEEK_MODEL

    completion = client.chat.completions.create(
        messages=messages,
        model=model_name,
        stream=False,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content



# One compiled pattern per metric, in the order extract_scores returns them.
# Between the metric label and "[[" any mix of whitespace, colons and markdown
# emphasis characters is accepted, so "**Factual Accuracy**: [[ 8 ]]" parses
# like "Factual Accuracy: [[8]]". The rating itself is 1-10; "10" is tried
# first so it is not read as "1".
_SCORE_PATTERNS = tuple(
    re.compile(re.escape(label) + r"[\s*_:]*\[\[\s*(10|[1-9])")
    for label in (
        "Factual Accuracy",
        "Relevance and Recall",
        "Coherence and Structure",
    )
)


def extract_scores(text):
    """Extract the three 1-10 ratings from a judge reply.

    For each metric, the first ``<label> ... [[<rating>`` occurrence that
    carries a numeric rating is used. Occurrences without a number, such as
    an echoed ``[[Rating]]`` placeholder, are skipped instead of ending the
    search.

    Args:
        text: the judge's reply; may be ``None`` or empty.

    Returns:
        ``(fact, relev, coh)`` as ints. A metric that cannot be found is
        reported as ``0``.
    """
    # A missing or empty reply has nothing to parse; score every metric as 0.
    if not text:
        return 0, 0, 0

    scores = []
    for pattern in _SCORE_PATTERNS:
        match = pattern.search(text)
        scores.append(int(match.group(1)) if match else 0)
    return tuple(scores)


def evaluate(file_path, truth_path):
    """Judge every model output in *file_path* against *truth_path*.

    Each item is sent to the judge sequentially. The parsed scores are
    written back into the items as ``"llm_perc_score"``,
    ``"llm_relevance_score"`` and ``"llm_coherence_score"``, and *file_path*
    is overwritten with the updated items. A sidecar file named
    ``<file_path without extension>-<JUDGE_MODE>-scores.json`` receives, per
    item, the three scores plus the judge's raw reply under
    ``"judge_output"``.

    Args:
        file_path: JSON file with the model outputs; updated in place.
        truth_path: JSON file with the reference entries, aligned by index.

    Returns:
        ``(fact, rel, coh)``: the mean of each score over all items. Replies
        that could not be parsed count as 0 in these means.

    Raises:
        ValueError: if the two files do not hold the same number of entries.
    """
    data = read_json(file_path)
    truths = read_json(truth_path)

    # Check alignment before any paid API call is made; otherwise a length
    # mismatch only surfaces after the whole judging loop has run.
    if len(data) != len(truths):
        raise ValueError(
            f"{file_path} has {len(data)} entries but "
            f"{truth_path} has {len(truths)}"
        )

    prompts = build_prompts(data, truths)
    client = make_openai_client()

    records = []

    # Judge each prompt sequentially. tqdm comes first in the zip so its
    # iterator is the one that gets exhausted and the bar closes cleanly.
    for messages, item in zip(tqdm(prompts), data):
        out_text = call_judge_api(client, messages)
        fact_score, rel_score, coh_score = extract_scores(out_text)

        item["llm_perc_score"] = fact_score
        item["llm_relevance_score"] = rel_score
        item["llm_coherence_score"] = coh_score

        records.append({
            "factacc_one": fact_score,
            "relevance_one": rel_score,
            "coherence_one": coh_score,
            # Keep the raw reply so a 0 (unparsed) score can be audited.
            "judge_output": out_text,
        })

    save_json(file_path, data)

    # Derive the sidecar name from the extension-less path. A plain
    # str.replace(".json", ...) also rewrites directory names and, when the
    # path has no ".json", yields file_path itself, so the sidecar would
    # overwrite the data that was just saved.
    root, _ = os.path.splitext(file_path)
    save_json(f"{root}-{JUDGE_MODE}-scores.json", records)

    # Return averaged metrics
    fact = np.mean([rec["factacc_one"] for rec in records])
    rel = np.mean([rec["relevance_one"] for rec in records])
    coh = np.mean([rec["coherence_one"] for rec in records])
    return fact, rel, coh


# ---------------- Main ----------------

if __name__ == "__main__":
    import argparse

    # Command line: both the model-output file and the reference file are
    # mandatory.
    cli = argparse.ArgumentParser()
    cli.add_argument("--file", required=True, help="JSON file with model outputs")
    cli.add_argument("--truth", required=True, help="JSON file with reference GPT ground truths")
    options = cli.parse_args()

    print(f"> Evaluating using judge mode = {JUDGE_MODE}\n")
    fact_avg, rel_avg, coh_avg = evaluate(options.file, options.truth)

    # Report the three averaged metrics with two decimals.
    print("\nFinal average scores:")
    print(f"Factual Accuracy = {fact_avg:.2f}")
    print(f"Relevance         = {rel_avg:.2f}")
    print(f"Coherence         = {coh_avg:.2f}")
