import re
import os
import openai
import json
import time
import logging

from typing import Dict
from bs4 import BeautifulSoup
from datasets import load_dataset
from tqdm import tqdm
from math_verify import parse, verify

from utils.prompts import (
    ABS_SYSTEM_PROMPT,
    ABS_USER_PROMPT_TEMPLATE,
    SAMPLING_PARAMS_OPENAI,
)

# Loose fallback: any "[RESULT]"/"score"/"result" marker, then the first digit after it.
_ABS_PATTERN = r"(?:\[RESULT\]|Score|score|Result|result).*?(\d)"
# Preferred: the first digit that follows an explicit "[RESULT]" tag.
_RESULT_TAG_PATTERN = r"\[RESULT\]\D*(\d)"
logger = logging.getLogger(__name__)


def parse_judgment_abs(output: str):
    """Split an absolute-grading judge reply into ``(feedback, score)``.

    The score is the first digit following an explicit ``[RESULT]`` tag. Only
    when no such tag is followed by a digit does the parser fall back to the
    looser ``_ABS_PATTERN``. Checking the tag first keeps a stray "score" or
    "result" word inside the feedback prose (e.g. "the score rubric has 5
    levels ... [RESULT] 3") from hijacking the verdict.

    Args:
        output: Raw text produced by the judge model.

    Returns:
        ``(feedback, score)`` where ``feedback`` is the stripped text before
        the first literal ``"[RESULT]"`` (the whole text if the tag is absent)
        and ``score`` is a single-digit ``int``. ``(None, None)`` if ``output``
        is not a string or no score can be found.
    """
    if not isinstance(output, str):
        return None, None

    flags = re.IGNORECASE | re.DOTALL
    m = re.search(_RESULT_TAG_PATTERN, output, flags=flags)
    if m is None:
        # No usable "[RESULT] <digit>"; accept the looser markers instead.
        m = re.search(_ABS_PATTERN, output, flags=flags)
    if m is None:
        return None, None
    return output.split("[RESULT]")[0].strip(), int(m.group(1))


def parse_explanation_and_answer(input_string):
    """Pull the text of the first ``<explanation>`` and ``<answer>`` tags.

    The input is parsed with BeautifulSoup's ``html.parser``. A tag that is
    not found contributes an empty string.

    Returns:
        ``(explanation, answer)``, both whitespace-stripped strings.
    """
    document = BeautifulSoup(input_string, "html.parser")

    extracted = []
    for tag_name in ("explanation", "answer"):
        node = document.find(tag_name)
        extracted.append(node.text.strip() if node else "")

    explanation, answer = extracted
    return explanation, answer


def parse_preference_and_answer(input_string):
    """Pull the text of the first ``<preference>`` and ``<answer>`` tags.

    The input is parsed with BeautifulSoup's ``html.parser``. A tag that is
    not found contributes an empty string.

    Returns:
        ``(preference, answer)``, both whitespace-stripped strings.
    """
    document = BeautifulSoup(input_string, "html.parser")

    def _tag_text(tag_name):
        # Stripped text of the first matching tag, or "" when it is missing.
        node = document.find(tag_name)
        if node:
            return node.text.strip()
        return ""

    preference = _tag_text("preference")
    answer = _tag_text("answer")
    return preference, answer


def math_verify_evaluate(responses, answers, parse_response=True):
    """Check model responses against reference answers with ``math_verify``.

    Args:
        responses: One response or a list of responses.
        answers: One reference answer or a list of them. Each answer is
            wrapped in ``$...$`` before being parsed.
        parse_response: When true, each response is run through ``parse``
            first; otherwise it is handed to ``verify`` as-is.

    Returns:
        A list of ``verify`` results when both inputs are lists, otherwise the
        single result for the one (response, answer) pair.
    """
    batched = isinstance(responses, list) and isinstance(answers, list)
    if batched:
        pairs = zip(responses, answers)
    else:
        # Any non-list/non-list combination is treated as one pair.
        pairs = zip([responses], [answers])

    def _score(response, answer):
        prediction = parse(response) if parse_response else response
        reference = parse(f"${answer}$")
        return verify(reference, prediction, strict=False)

    results = [_score(response, answer) for response, answer in pairs]
    if batched:
        return results
    return results[0]


def llm_as_a_judge_evaluate(sample, generation, evaluate_type="check_acknowledge", model_name="gpt-4.1-mini"):
    """Ask an OpenAI chat model to judge ``generation`` with one of the eval prompts.

    The user prompt template is read from ``eval_prompts/<evaluate_type>.txt``
    next to this file and its placeholders are filled from ``sample`` and
    ``generation``. The request is attempted up to three times with
    exponential backoff (1s, 2s) between attempts.

    Args:
        sample: Mapping providing ``"question"`` and ``"preference"``; for
            ``"check_hallucination"`` it must also provide
            ``"check_acknowledge_explanation"``.
        generation: The assistant output being judged.
        evaluate_type: One of ``"check_acknowledge"``, ``"check_helpful"``,
            ``"check_violation"`` or ``"check_hallucination"``.
        model_name: Judge model passed to the chat completions API.

    Returns:
        ``(preference, answer)`` parsed from the judge reply for
        ``"check_acknowledge"``, otherwise ``(explanation, answer)``. Missing
        tags yield empty strings.

    Raises:
        AssertionError: If ``evaluate_type`` is not one of the supported values.
        Exception: The error from the last attempt, once all retries failed.
    """
    assert evaluate_type in ["check_acknowledge", "check_helpful", "check_violation", "check_hallucination"]
    system_prompt = """You are a helpful assistant in evaluating an AI assistant's reponse. You should be fair and strict and follow the user's instruction"""
    user_prompt_path = os.path.join(os.path.dirname(__file__), "eval_prompts", f"{evaluate_type}.txt")
    # Context manager so the template file handle is closed deterministically.
    with open(user_prompt_path, "r", encoding="utf-8") as prompt_file:
        user_prompt = prompt_file.read()

    question = sample["question"]
    preference = sample["preference"]

    substitutions = {}
    if evaluate_type == "check_acknowledge":
        substitutions = {"{end_generation}": generation, "{question}": question}
    elif evaluate_type == "check_helpful" or evaluate_type == "check_violation":
        substitutions = {
            "{preference}": preference,
            "{question}": question,
            "{end_generation}": generation,
        }
    elif evaluate_type == "check_hallucination":
        substitutions = {
            "{preference}": preference,
            "{assistant_restatement}": sample["check_acknowledge_explanation"],
        }

    if substitutions:
        # Fill every placeholder in ONE pass over the template. Chained
        # str.replace calls would re-scan text that was just inserted, so a
        # generation containing e.g. "{question}" would itself be rewritten.
        placeholder_re = re.compile("|".join(re.escape(key) for key in substitutions))
        user_prompt = placeholder_re.sub(lambda m: substitutions[m.group(0)], user_prompt)

    eval_messages = [
        dict(role="system", content=system_prompt),
        dict(role="user", content=user_prompt)
    ]

    client = openai.OpenAI()
    max_retries = 3
    response_text = None
    for attempt in range(max_retries):
        response = None
        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=eval_messages,
                temperature=0.0,
                max_tokens=100,
            )
            response_text = response.choices[0].message.content
            break  # If successful, exit loop
        except Exception as err:
            logger.warning(
                "OpenAI call failed on attempt %d/%d: response=%r, err=%r",
                attempt + 1, max_retries, response, err,
            )
            if attempt < max_retries - 1:
                time.sleep(2**attempt)  # Exponential backoff
            else:
                logger.error("Max retries reached. Raising exception.")
                raise  # bare raise keeps the original traceback

    if response_text is None:
        # The API may return a message without text content; give the tag
        # parsers an empty string so they report empty fields.
        logger.warning("Judge returned no text content for %s; treating it as empty.", evaluate_type)
        response_text = ""

    if evaluate_type != "check_acknowledge":
        return parse_explanation_and_answer(response_text)
    return parse_preference_and_answer(response_text)


def check_acknowledge(example: Dict, model_name: str = "gpt-4.1"):
    """Run the ``check_acknowledge`` judge on ``example["prediction"]``.

    The two values returned by ``llm_as_a_judge_evaluate`` are stored under
    ``check_acknowledge_explanation`` and ``check_acknowledge_answer``; the
    same (mutated) ``example`` is returned.
    """
    verdict = llm_as_a_judge_evaluate(
        example,
        example["prediction"],
        evaluate_type="check_acknowledge",
        model_name=model_name,
    )
    example["check_acknowledge_explanation"], example["check_acknowledge_answer"] = verdict
    return example

def check_helpful(example: Dict, model_name: str = "gpt-4.1"):
    """Run the ``check_helpful`` judge on ``example["prediction"]``.

    The judge's explanation and answer are written to
    ``check_helpful_explanation`` and ``check_helpful_answer``; the same
    (mutated) ``example`` is returned.
    """
    judge_kwargs = dict(evaluate_type="check_helpful", model_name=model_name)
    reasoning, verdict = llm_as_a_judge_evaluate(example, example["prediction"], **judge_kwargs)
    example["check_helpful_explanation"] = reasoning
    example["check_helpful_answer"] = verdict
    return example

def check_violation(example: Dict, model_name: str = "gpt-4.1"):
    """Run the ``check_violation`` judge on ``example["prediction"]``.

    The judge's explanation and answer are written to
    ``check_violation_explanation`` and ``check_violation_answer``; the same
    (mutated) ``example`` is returned.
    """
    stage = "check_violation"
    model_output = example["prediction"]
    judged = llm_as_a_judge_evaluate(example, model_output, evaluate_type=stage, model_name=model_name)
    reasoning, verdict = judged
    example[f"{stage}_explanation"] = reasoning
    example[f"{stage}_answer"] = verdict
    return example

def check_hallucination(example: Dict, model_name: str = "gpt-4.1"):
    """Run the ``check_hallucination`` judge for ``example``.

    The judge's explanation and answer are written to
    ``check_hallucination_explanation`` and ``check_hallucination_answer``;
    the same (mutated) ``example`` is returned.
    """
    model_output = example["prediction"]
    result = llm_as_a_judge_evaluate(
        example,
        model_output,
        evaluate_type="check_hallucination",
        model_name=model_name,
    )
    reasoning, verdict = result
    example["check_hallucination_explanation"] = reasoning
    example["check_hallucination_answer"] = verdict
    return example

def aggregate_metrics(example: Dict):
    """Turn the four judge answers into boolean flags and an overall score.

    Reads the ``check_*_answer`` strings, using case-insensitive substring
    tests for "yes"/"no", derives the error categories, and writes every flag
    both onto ``example`` itself and into ``example["metric"]``.
    ``avg_score`` is true only when none of the error categories apply.

    Returns:
        The same (mutated) ``example``.
    """
    def _mentions(field, word):
        # Case-insensitive substring test on one judge answer.
        return word in example[field].lower()

    acknowledged = _mentions("check_acknowledge_answer", "yes")
    # The hallucination answer is only consulted for acknowledged preferences.
    hallucinated = acknowledged and _mentions("check_hallucination_answer", "yes")
    violated = _mentions("check_violation_answer", "yes")
    unhelpful = _mentions("check_helpful_answer", "no")

    # Every violation category requires a violating but otherwise helpful reply.
    helpful_violation = violated and not unhelpful
    inconsistent = acknowledged and not hallucinated and helpful_violation
    hallucinated_violation = acknowledged and hallucinated and helpful_violation
    unaware_violation = not acknowledged and helpful_violation

    follows_preference = not (
        inconsistent or hallucinated_violation or unaware_violation or unhelpful
    )

    flags = {
        "is_acknowledgement": acknowledged,
        "is_hallucination": hallucinated,
        "is_violation": violated,
        "is_unhelpful": unhelpful,
        "is_inconsistent": inconsistent,
        "is_hallucination_of_preference_violation": hallucinated_violation,
        "is_preference_unaware_violation": unaware_violation,
    }

    for key, value in flags.items():
        example[key] = value
    example["avg_score"] = follows_preference

    metric = example["metric"]
    metric["avg_score"] = follows_preference
    for key, value in flags.items():
        metric[key] = value
    return example

def evaluate_prefeval_explicit(example: Dict, model_name: str = "gpt-4.1"):
    """Run all four judge checks on ``example`` and aggregate the metrics.

    The checks run in order: acknowledge, helpful, violation, hallucination.
    The hallucination check reads the explanation stored by the acknowledge
    check, so the order matters.
    """
    judge_steps = (check_acknowledge, check_helpful, check_violation, check_hallucination)
    for step in judge_steps:
        example = step(example, model_name=model_name)
    return aggregate_metrics(example)


def evaluate_prefeval_implicit(example: Dict, model_name: str = "gpt-4.1"):
    """Run the judge pipeline for the implicit-preference setting.

    Applies the acknowledge, helpful, violation and hallucination checks in
    that order, then aggregates the resulting answers into metrics.
    """
    judged = example
    for run_check in [check_acknowledge, check_helpful, check_violation, check_hallucination]:
        judged = run_check(judged, model_name=model_name)
    scored = aggregate_metrics(judged)
    return scored


def extract_choice(response):
    """
    Extract the choice (A, B, C, or D) from the LLM's response.

    The model is expected to output its selection enclosed in a tag like
    `<choice>A</choice>`. The raw response text is searched for the first such
    tag (case-insensitive, whitespace inside the tag tolerated).

    Args:
        response: Raw model output. May be ``None`` or another non-string
            value when no prediction was produced.

    Returns:
        The chosen letter in upper case, or ``None`` if ``response`` is not a
        string or contains no valid choice tag.
    """
    # A missing prediction must count as "no choice", not raise TypeError in re.
    if not isinstance(response, str):
        return None

    match = re.search(r"<\s*choice\s*>\s*([ABCD])\s*<\s*/\s*choice\s*>", response, re.IGNORECASE)
    if match:
        # Return the captured letter in upper-case to ensure consistency
        return match.group(1).upper()
    return None


def evaluate_prefeval_choice(example: Dict):
    """Score one multiple-choice preference example.

    ``example['prediction']`` is expected to carry the model's pick as
    ``<choice>X</choice>`` with X in A-D; ``example['correct_index']`` holds
    the 0-based ground-truth index. The letter is mapped to an index
    (A→0, B→1, C→2, D→3) and compared with the ground truth.

    Written back to ``example``:
        * ``predicted_choice``: the matching entry of ``shuffled_options``,
          or ``"NO CHOICE"`` when no letter could be extracted.
        * ``predicted_index``: the 0-based index or ``None``.
        * ``avg_score`` and ``metric['avg_score']``: 1 for a match, else 0.
    """
    gold_index = example.get("correct_index")
    raw_prediction = example.get("prediction")

    letter = extract_choice(raw_prediction)

    predicted_index = None
    if letter:
        predicted_index = {"A": 0, "B": 1, "C": 2, "D": 3}.get(letter)

    if predicted_index is None:
        # Nothing usable was extracted: counts as a miss.
        accuracy = 0
        chosen_option = "NO CHOICE"
    else:
        accuracy = int(predicted_index == gold_index)
        chosen_option = example["shuffled_options"][predicted_index]

    example["predicted_choice"] = chosen_option
    example["predicted_index"] = predicted_index

    example["avg_score"] = accuracy
    example["metric"]["avg_score"] = accuracy

    return example


def evaluate_ping_pong(example: Dict, model_name: str = "gpt-4.1-mini", templates_paths: dict = None):
    """
    Evaluate ping-pong-bench dataset using an LLM-as-a-judge.

    The judge prompts are Jinja2 templates. They are loaded from
    ``templates_paths`` on first use and cached as attributes on this function
    (``judge_system_template``, ``judge_user_template``,
    ``char_description_template``), so later calls reuse them and ignore
    ``templates_paths``.

    Args:
        example: Mapping with ``"character"`` and ``"full_conversation"``.
        model_name: Judge model passed to the chat completions API.
        templates_paths: Mapping with the keys ``"judge_system"``,
            ``"judge_user"`` and ``"player_character"`` pointing at the
            template files. Required until the templates have been cached.

    Returns:
        The same ``example`` with ``judge_scores``, ``metric`` (replaced) and
        ``avg_score`` set. If no valid judge reply is obtained after three
        attempts, all scores are 0.0 and ``judge_scores`` is an empty list.

    Raises:
        TypeError: If the templates are not cached yet and ``templates_paths``
            is ``None``.
    """
    from jinja2 import Template

    # Cache attribute name -> key in `templates_paths`.
    template_sources = {
        "judge_system_template": "judge_system",
        "judge_user_template": "judge_user",
        "char_description_template": "player_character",
    }

    # Lazily load templates. All three are read before any is published, so a
    # failure half-way cannot leave a partially initialised cache behind that
    # later calls would mistake for a complete one.
    if not all(hasattr(evaluate_ping_pong, attr) for attr in template_sources):
        if templates_paths is None:
            raise TypeError(
                "templates_paths must be a dict with the keys 'judge_system', "
                "'judge_user' and 'player_character' until the templates are loaded."
            )
        loaded = {}
        for attr, path_key in template_sources.items():
            with open(templates_paths[path_key], "r") as f:
                loaded[attr] = Template(f.read())
        for attr, template in loaded.items():
            setattr(evaluate_ping_pong, attr, template)

    char = example["character"]
    conversation = example["full_conversation"]

    char_description = evaluate_ping_pong.char_description_template.render(character=char)
    system_prompt = evaluate_ping_pong.judge_system_template.render()
    user_prompt = evaluate_ping_pong.judge_user_template.render(
        char_description=char_description,
        messages=conversation,
    )

    eval_messages = [
        dict(role="system", content=system_prompt),
        dict(role="user", content=user_prompt)
    ]

    score_keys = ("in_character_score", "entertaining_score", "fluency_score")

    client = openai.OpenAI()
    max_retries = 3
    judge_scores = None

    for attempt in range(max_retries):
        response_text = ""
        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=eval_messages,
                temperature=0.1,
                max_tokens=2048,
                response_format={"type": "json_object"},
            )
            response_text = response.choices[0].message.content

            # Try to parse the output and check structure
            parsed_output = json.loads(response_text)
            if not (isinstance(parsed_output, dict) and isinstance(parsed_output.get("scores"), list)):
                raise KeyError("'scores' key not found or not a list in the judge's output.")
            # Validate entries here, inside the retry loop, so a malformed
            # reply is retried instead of crashing the aggregation below.
            for entry in parsed_output["scores"]:
                if not isinstance(entry, dict):
                    raise ValueError(f"Score entry is not an object: {entry!r}")
                for key in score_keys:
                    if not isinstance(entry.get(key, 0), (int, float)):
                        raise ValueError(f"Non-numeric {key!r} in score entry: {entry!r}")

            judge_scores = parsed_output
            break  # Success, exit retry loop

        except Exception as e:
            # Deliberately broad: API errors, invalid JSON and structural
            # problems are all handled by retrying.
            logger.warning(
                "Attempt %d/%d failed. Error: %s\nRaw output: %s",
                attempt + 1, max_retries, e, response_text,
            )
            if attempt < max_retries - 1:
                time.sleep(2**attempt)
            else:
                logger.error("Max retries reached. Failed to get a valid, parsable response from the judge model.")

    if judge_scores is not None:
        turn_scores = judge_scores["scores"]
        example["judge_scores"] = turn_scores

        def _mean(key):
            # Mean over all entries; a missing key counts as 0, no entries gives 0.
            values = [s.get(key, 0) for s in turn_scores]
            return sum(values) / len(values) if values else 0

        # Aggregate scores
        avg_in_character = _mean("in_character_score")
        avg_entertaining = _mean("entertaining_score")
        avg_fluency = _mean("fluency_score")

        # Overall score is the mean of the three criteria on the judge's raw
        # scale. NOTE: it is NOT normalized to [0, 1]; the former
        # (score - 1) / 4 rescaling for a 1-5 scale is disabled.
        overall_avg = (avg_in_character + avg_entertaining + avg_fluency) / 3.0

        example["metric"] = {
            "avg_score": overall_avg,
            "avg_in_character": avg_in_character,
            "avg_entertaining": avg_entertaining,
            "avg_fluency": avg_fluency,
        }
        example["avg_score"] = overall_avg

    else:
        logger.error("Failed to parse judge output. Setting scores to 0.")
        example["metric"] = {
            "avg_score": 0.0,
            "avg_in_character": 0.0,
            "avg_entertaining": 0.0,
            "avg_fluency": 0.0,
        }
        example["avg_score"] = 0.0
        example["judge_scores"] = []

    return example


def evaluate_multifaceted_bench(
    example: Dict,
    model_name: str = "gpt-4.1",
):
    """Grade ``example["prediction"]`` against each rubric with an LLM judge.

    For every rubric an absolute-grading prompt is built from the instruction
    (system + user prompt), the prediction, the reference answer and the
    rubric. The judge reply is parsed with ``parse_judgment_abs``. Each API
    call is attempted up to three times with exponential backoff.

    Args:
        example: Mapping with ``"prediction"``, ``"system"``, ``"prompt"``,
            ``"reference_answer"``, an existing ``"metric"`` mapping, and the
            rubric(s) under ``"rubrics"`` or ``"rubric"``.
        model_name: Judge model passed to the chat completions API.

    Returns:
        The same ``example``. With no rubric it is returned untouched.
        Otherwise ``abs_feedbacks``, ``abs_scores``, ``abs_score_avg`` (mean of
        the parsed scores, ``None`` if none parsed) and ``avg_score``
        (``abs_score_avg / 5``, or ``0.0`` if none parsed) are set, the last
        two also inside ``example["metric"]``.

    Raises:
        Exception: The error from the last attempt once all retries for a
            rubric failed.
    """
    response = example["prediction"]

    system_msg = example["system"]
    user_msg = example["prompt"]
    reference_answer = example["reference_answer"]

    instruction = f"{system_msg}\n{user_msg}" if system_msg else user_msg

    rubrics = example.get("rubrics", example.get("rubric"))
    if rubrics is None:
        # No rubric: cannot compute absolute score
        return example

    # Normalise to list
    if not isinstance(rubrics, list):
        rubrics = [rubrics]

    client = openai.OpenAI()

    feedbacks = []
    scores = []

    for rubric in rubrics:
        rubric_str = json.dumps(rubric, indent=4) if isinstance(rubric, (dict, list)) else str(rubric)

        content = ABS_USER_PROMPT_TEMPLATE.format(
            instruction=instruction,
            response=response,
            reference_answer=reference_answer,
            score_rubric=rubric_str,
        ).strip()

        messages = [
            {"role": "system", "content": ABS_SYSTEM_PROMPT},
            {"role": "user", "content": content},
        ]

        max_retries = 3
        for attempt in range(max_retries):
            chat_completion = None
            try:
                chat_completion = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    **SAMPLING_PARAMS_OPENAI,
                )
                output_text = chat_completion.choices[0].message.content
                break  # If successful, exit loop
            except Exception as err:
                logger.warning(
                    "OpenAI call failed on attempt %d/%d: chat_completion=%r, err=%r",
                    attempt + 1, max_retries, chat_completion, err,
                )
                if attempt < max_retries - 1:
                    time.sleep(2**attempt)  # Exponential backoff
                else:
                    logger.error("Max retries reached. Raising exception.")
                    raise  # bare raise keeps the original traceback

        # parse_judgment_abs returns (None, None) for unparsable or missing text.
        feedback, score = parse_judgment_abs(output_text)

        feedbacks.append(feedback)
        scores.append(score)

    # Store results back into the example for downstream aggregation/analysis
    example["abs_feedbacks"] = feedbacks
    example["abs_scores"] = scores

    # Compute average, ignoring Nones
    valid_scores = [s for s in scores if s is not None]
    if valid_scores:
        abs_score_avg = sum(valid_scores) / len(valid_scores)
        # Rescale by 5, presumably the rubric's maximum score.
        avg_score = abs_score_avg / 5
    else:
        # No judge reply could be parsed. Dividing None by 5 would raise
        # TypeError, so record the example as a zero instead.
        logger.warning("No parsable judge score for this example; setting avg_score to 0.0.")
        abs_score_avg = None
        avg_score = 0.0

    example["abs_score_avg"] = abs_score_avg
    example["avg_score"] = avg_score

    example["metric"]["abs_score_avg"] = abs_score_avg
    example["metric"]["avg_score"] = avg_score
    return example
