# utils/common.py
"""
Common utilities.

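Includes multiple-choice prompt building, response parsing and scoring helpers,
accuracy aggregation, and small model-path utilities (size-tag extraction and
file:// URI conversion).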
"""

import numpy as np
from typing import List, Dict, Any, Tuple, Optional


def parse_multi_choice_response(response: str, all_choices: Optional[List[str]] = None) -> Optional[str]:
    """
    Parse a multiple-choice response and extract the most likely option letter.

    Three patterns are tried in order of specificity. The first pattern that
    matches at least one choice determines the candidate set:

    1. Bracket form, e.g. "(A)".
    2. Space-delimited form, e.g. " A ". The response is padded with one
       space on each side, so a bare "A" also matches.
    3. Period form, e.g. "A.".

    Matching is case-sensitive and purely substring-based. When several
    choices match the winning pattern, the function returns the choice whose
    last occurrence of that same pattern appears latest in the response.

    Args:
        response: Model output text.
        all_choices: Optional list of valid choice letters, e.g. ["A", "B", "C", "D"].
            Defaults to the letters "A" through "Z".

    Returns:
        The parsed choice letter. If the response is empty or no pattern
        matches, the first entry of ``all_choices`` is returned as a
        deterministic fallback. ``None`` is returned only when
        ``all_choices`` is empty.
    """
    if all_choices is None:
        all_choices = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
    fallback = all_choices[0] if all_choices else None
    if not response:
        return fallback

    # A single strip() over the whole set removes any mix of leading/trailing
    # punctuation (e.g. "A.!"); chained per-character strips can leave some behind.
    response = response.strip(",.!?;:'")
    padded = " " + response + " "

    # Most specific pattern first. The first pattern that matches any choice
    # decides the candidate set; the same pattern is reused below to locate
    # each candidate, so the "last occurrence" comparison is meaningful.
    for pattern in ("({})", " {} ", "{}."):
        candidates = [choice for choice in all_choices if pattern.format(choice) in padded]
        if candidates:
            break
    else:
        # No pattern matched any choice.
        return fallback

    if len(candidates) == 1:
        return candidates[0]

    # Multiple matches: pick the candidate whose last occurrence is latest.
    # max() keeps the first candidate on ties, like np.argmax did.
    return max(candidates, key=lambda choice: padded.rfind(pattern.format(choice)))


def build_question_text(task_doc: Dict) -> str:
    """
    Render a task's question followed by its lettered answer options.

    The result is the question text, a newline, and then one line per option
    in the form "<letter>. <option>". Letters start at "A". With no options,
    the result is the question followed by a single trailing newline.

    Args:
        task_doc: Task dict; must contain 'candidates' (an iterable of
            options) and 'question' (a string). 'id' is used only in the
            error message.

    Returns:
        str: The formatted question text.

    Raises:
        ValueError: If 'candidates' is absent from ``task_doc``.
        KeyError: If 'question' is absent.
    """
    if "candidates" not in task_doc:
        raise ValueError(f"Task {task_doc.get('id')} is missing the 'candidates' field")

    options = task_doc["candidates"]
    header = task_doc["question"] + "\n"
    option_lines = (
        f"{chr(ord('A') + position)}. {option}"
        for position, option in enumerate(options)
    )
    return header + "\n".join(option_lines)


def judge_correct(prediction: str, task_doc: Dict) -> Tuple[bool, str]:
    """
    Decide whether a model prediction selects the task's correct option.

    The valid letters are "A", "B", ... (one per entry in
    task_doc['candidates']). The prediction is parsed into one of those
    letters with ``parse_multi_choice_response`` and compared
    case-insensitively with the letter derived from the integer index in
    task_doc['correct_choice'].

    Args:
        prediction: Model output.
        task_doc: Task dict containing 'candidates' and 'correct_choice'.
            'correct_choice' appears to be a 0-based option index; it is
            added to ord("A").

    Returns:
        (is_correct, predicted_letter). ``(False, "")`` is returned when
        either required key is missing or no letter could be parsed.
        ``(False, predicted_letter)`` is returned when 'correct_choice'
        is None.
    """
    needed_keys = ('candidates', 'correct_choice')
    if any(key not in task_doc for key in needed_keys):
        return False, ""

    option_count = len(task_doc['candidates'])
    letters = [chr(ord("A") + offset) for offset in range(option_count)]
    parsed = parse_multi_choice_response(prediction, letters)
    if parsed is None:
        return False, ""

    answer_index = task_doc.get('correct_choice')
    if answer_index is None:
        return False, parsed

    expected = chr(ord("A") + answer_index)
    return parsed.upper() == expected.upper(), parsed


def calculate_accuracy(results: List[Dict]) -> Dict[str, Any]:
    """
    Summarize how many results were judged correct.

    A result counts as correct when its 'is_correct' value is truthy. A
    missing key counts as incorrect.

    Args:
        results: List of result dicts, each optionally containing 'is_correct'.

    Returns:
        Dict with 'total_tasks' (int), 'correct_count' (int) and
        'accuracy_percent' (float, 0-100; 0.0 when ``results`` is empty).
    """
    task_total = len(results)
    hits = len([entry for entry in results if entry.get('is_correct', False)])

    percent = 0.0
    if task_total > 0:
        percent = hits / task_total * 100

    return {
        "total_tasks": task_total,
        "correct_count": hits,
        "accuracy_percent": percent,
    }


def extract_model_size_tag(model_path: str) -> str:
    """
    Pull a parameter-size tag such as '3B' / '7B' / '32B' / '72B' out of a model path.

    Trailing '/' or '\\' separators are removed, and only the final path
    component (per ``os.path.basename``) is searched. The first number
    (optionally decimal) followed by optional whitespace and 'b'/'B' that
    is not immediately followed by another digit is used. The tag is always
    normalized to an upper-case 'B' suffix.

    Args:
        model_path: Model path.

    Returns:
        str: Size tag such as '7B' or '0.5B', or 'UNK' if no tag is found.
    """
    import os
    import re

    last_component = os.path.basename(model_path.rstrip("/\\"))
    size_match = re.search(r"(\d+(?:\.\d+)?)\s*[bB](?:\D|$)", last_component)
    if size_match is None:
        return "UNK"
    return f"{size_match.group(1)}B"


def to_file_uri(path: str) -> str:
    """
    Turn a local path into a ``file://`` URI.

    Inputs that already start with "file://" are returned unchanged.
    Otherwise the path is made absolute with ``os.path.abspath`` and
    prefixed with "file://". No percent-encoding is applied, so characters
    such as spaces are kept verbatim.

    Args:
        path: File path, or an existing file:// URI.

    Returns:
        str: file:// URI.
    """
    import os

    scheme = "file://"
    if not path.startswith(scheme):
        path = scheme + os.path.abspath(path)
    return path
