# Load the all-MiniLM-L6-v2 sentence-embedding model once, at import time;
# semantic_similarity() below reuses this single module-level instance.
# NOTE(review): the argument starts with '/', so it reads as an absolute local
# path rather than a model-hub name -- confirm that the model directory really
# lives at the filesystem root in the target environment.
from sentence_transformers import SentenceTransformer, util
embedding_model = SentenceTransformer('/all-MiniLM-L6-v2')

def semantic_similarity(ans1, ans2):
    """
    Return the cosine similarity between the embeddings of two answers.

    Both inputs are encoded independently with the module-level
    ``embedding_model`` and compared with ``util.pytorch_cos_sim``.

    Args:
        ans1: First answer text.
        ans2: Second answer text.

    Returns:
        The similarity score as a plain Python number. The value is returned
        exactly as computed; it is not clamped or rescaled, so callers should
        not assume it is non-negative.
    """
    # Encode each answer separately (no batching) and keep the tensors in order.
    vectors = [
        embedding_model.encode(sentence, convert_to_tensor=True)
        for sentence in (ans1, ans2)
    ]
    # The pairwise similarity of two single embeddings is a 1x1 tensor;
    # .item() unwraps it into a scalar.
    return util.pytorch_cos_sim(*vectors).item()

import re

def extract_answer(text: str, task_type: str = "default") -> str:
    """
    Extract the final answer string from model output based on task type.

    Args:
        text: Raw model output text.
        task_type: Type of task. Handled explicitly:
            - "year_parity": returns "even"/"odd" if either word appears
              (case-insensitive, whole word, "even" checked first); otherwise
              derives the parity from the last integer in the text.
            - "yes_no": returns "yes" if the whole word "yes" or "true"
              appears, else "no" if the whole word "no" or "false" appears
              (case-insensitive).
            - "open_qa": returns the last non-blank line, stripped.
            Any other value (including "multi_step_arithmetic",
            "single_step_arithmetic" and "default") returns the last number
            found in the text.

    Returns:
        Extracted answer string. If not found, returns "[invalid]".
    """
    lowered = text.lower()

    if task_type == "year_parity":
        # Whole-word matching: a plain substring test would find "even"
        # inside "seven"/"eleven" and "odd" inside "oddly".
        if re.search(r"\beven\b", lowered):
            return "even"
        if re.search(r"\bodd\b", lowered):
            return "odd"
        # No explicit verdict: fall back to the parity of the last integer.
        integers = re.findall(r"(\d+)", text)
        if integers:
            return "even" if int(integers[-1]) % 2 == 0 else "odd"
        return "[invalid]"

    if task_type == "yes_no":
        # Whole-word matching: a plain substring test would find "no" inside
        # "not"/"know" and "true" inside "untrue".
        if re.search(r"\b(?:yes|true)\b", lowered):
            return "yes"
        if re.search(r"\b(?:no|false)\b", lowered):
            return "no"
        return "[invalid]"

    if task_type == "open_qa":
        stripped = text.strip()
        # str.split always yields at least one element, so emptiness has to
        # be checked on the stripped text itself.
        if not stripped:
            return "[invalid]"
        return stripped.split("\n")[-1].strip()

    # Arithmetic task types and every unrecognised task type share the same
    # rule: the last (optionally negative / decimal) number in the text.
    numbers = re.findall(r"(\-?\d+(?:\.\d+)?)", text)
    return numbers[-1] if numbers else "[invalid]"


def validate_answer(task_type, answer_text):
    """
    Validate answer format based on task type.

    Args:
        task_type: Type of task ('option', 'math', 'yes_or_no', 'open').
        answer_text: Answer text to be validated.

    Returns:
        True if answer is valid, False otherwise:
            - 'option': contains a standalone capital letter, optionally
              wrapped as "(A)", "A)", "A." or "A:".
            - 'math': contains at least one integer or decimal number.
            - 'yes_or_no': contains the whole word "yes" or "no"
              (case-insensitive).
            - 'open': contains any non-whitespace text.

    Raises:
        ValueError: If ``task_type`` is not one of the supported types.
    """
    ans_stripped = answer_text.strip()

    if task_type == "option":
        # Primary pattern: option letter with optional "(" / ")" / "." / ":".
        if re.search(r'(?:^|\s)(?:\(?\s*([A-Z])\s*[).:]?)(?:\s|$)', ans_stripped):
            return True
        # Fallback: a bare standalone capital letter on any single line.
        return any(
            re.search(r'(?:^|\s)([A-Z])(?:\s|$)', line.strip())
            for line in ans_stripped.split('\n')
        )

    if task_type == "math":
        return bool(re.search(r"\-?\d+(?:\.\d+)?", ans_stripped))

    if task_type == "yes_or_no":
        return bool(re.search(r'\b(yes|no)\b', ans_stripped, re.I))

    if task_type == "open":
        # Return a real bool (the contract above), not the answer string
        # itself; a whitespace-only answer carries no content and is invalid.
        return bool(ans_stripped)

    raise ValueError(f"Unknown task type: {task_type}")

def filter_invalid_answers(
    task_type: str,
    answers: list,
    probs: list,
    max_decoding_steps: int,
):
    """
    Keep only the answers that pass every format check.

    An answer is dropped when any of the following holds (checked in order):
      1. It is empty after stripping whitespace.
      2. Its token-probability list has at least ``max_decoding_steps``
         entries (the generation probably hit the decoding limit).
      3. It ends with a question mark (probably a restated question).
      4. ``validate_answer(task_type, ...)`` rejects the stripped answer.

    Args:
        task_type: Task type forwarded to ``validate_answer``.
        answers: Model-generated answer strings.
        probs: Per-answer token probability lists, indexed like ``answers``.
        max_decoding_steps: Maximum number of decoding steps of the model.

    Returns:
        A ``(valid_answers, valid_indices)`` tuple: the surviving answers
        (unstripped, in original order) and their positions in ``answers``.
    """
    kept_answers = []
    kept_indices = []

    for position, raw_answer in enumerate(answers):
        trimmed = raw_answer.strip()

        # Cheap structural checks; short-circuiting keeps the original order,
        # so probs is never indexed for an empty answer.
        rejected = (
            not trimmed
            or len(probs[position]) >= max_decoding_steps
            or trimmed.endswith("?")
        )
        # The task-specific validation only runs for structurally sound answers.
        if rejected or not validate_answer(task_type, trimmed):
            continue

        kept_answers.append(raw_answer)
        kept_indices.append(position)

    return kept_answers, kept_indices


def extract_answer_bbh(text: str, task_type: str = "default") -> str:
    """
    Extract answer from BBH dataset-style output based on task type.

    Args:
        text: Raw model output text.
        task_type: One of:
            - "option": last standalone capital option letter, optionally
              written as "(A)", "A)", "A." or "A:".
            - "bool": first whole-word "true"/"false" (case-insensitive),
              lower-cased.
            - "yes_or_no": last whole-word "yes"/"no" (case-insensitive),
              lower-cased.
            - "math": last integer or decimal number (optional leading "-").
            - "dyck": the bracket characters ``( ) [ ] { } < >`` and spaces
              contained in the text, in order, stripped at both ends.
            Any other value returns ``text`` unchanged.

    Returns:
        The extracted answer, or "[invalid]" when the requested kind of
        answer cannot be found.
    """
    if task_type == "option":
        letters = re.findall(r'(?:^|\s)(?:\(?\s*([A-Z])\s*[).:]?)(?:\s|$)', text)
        if letters:
            return letters[-1]
        # Fallback: a bare standalone capital letter, searching from the
        # last non-blank line upwards.
        for line in reversed([ln.strip() for ln in text.split('\n') if ln.strip()]):
            match = re.search(r'(?:^|\s)([A-Z])(?:\s|$)', line)
            if match:
                return match.group(1)
        return "[invalid]"

    if task_type == "bool":
        match = re.search(r'\b(true|false)\b', text, re.I)
        # A missing verdict is reported as "[invalid]" like every other
        # branch, instead of being silently counted as "false".
        # NOTE(review): this branch uses the FIRST match, while the other
        # branches use the last one -- confirm that is intended.
        return match.group(1).lower() if match else "[invalid]"

    if task_type == "yes_or_no":
        matches = re.findall(r'\b(yes|no)\b', text, re.I)
        return matches[-1].lower() if matches else "[invalid]"

    if task_type == "math":
        # Require digits after a decimal point so that a sentence-ending
        # period ("... is 14.") is not captured as part of the number.
        numbers = re.findall(r'-?\d+(?:\.\d+)?', text)
        return numbers[-1] if numbers else "[invalid]"

    if task_type == "dyck":
        # Keep all four bracket families used by Dyck-language sequences;
        # dropping "[ ]" and "< >" would corrupt the extracted sequence.
        dyck_sequence = re.sub(r'[^(){}\[\]<> ]', '', text).strip()
        return dyck_sequence if dyck_sequence else "[invalid]"

    return text

# Small demo of extract_answer_bbh
def bbh_main():
    """Print what extract_answer_bbh extracts from a few sample responses."""
    demo_cases = (
        # (label, raw response, task type)                      expected output
        ("Extracted 1:", "The final result is: 14", 'math'),      # 14
        ("Extracted 2:", "The final result is: (B)", 'option'),   # B
        ("Extracted 3:", "The final result is: True", 'bool'),    # true
        # Same math answer as case 1, but followed by a trailing newline.
        ("Extracted 4:", "The final result is: 14\n", 'math'),    # 14
    )

    for label, response, kind in demo_cases:
        print(label, extract_answer_bbh(response, kind))


# Run the extraction demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    bbh_main()
