# Argument names whose values are scored by case-insensitive exact comparison.
ARGS_FOR_EXACT_MATCH = [
    'position',
    'color',
    'image',
    'k',
    'top1',
]

# Argument names holding free text; scored with a text-similarity metric.
ARGS_FOR_TEXT_SIM_MATCH = [
    'attribute',
    'text',
    'query',
    'command',
    'annotation',
    'keywords',
    'instruction',
]

# Argument names holding arithmetic expressions; scored by evaluated result.
ARGS_FOR_CALCULATION = ['expression']

# Argument names holding bounding boxes; scored by IoU.
ARGS_FOR_BBOX_CHECK = ['bbox']

def exact_match(arg_value_pred, arg_value_gt):
    """
    Compare two values as case-insensitive, whitespace-trimmed strings.

    Both values are converted with str() first, so non-string inputs
    (numbers, None, ...) are compared by their string form.

    Returns:
        float: 1.0 when the normalized strings are equal, otherwise 0.0.
    """
    normalized_pred = str(arg_value_pred).lower().strip()
    normalized_gt = str(arg_value_gt).lower().strip()
    return 1.0 if normalized_pred == normalized_gt else 0.0



"""
Simple utility functions for converting comma-separated numbers in expressions.
"""
import re

def clean_expression(expr):
    """
    Strip thousands separators from numbers inside an expression string.

    Only digit groups shaped like "1,000" or "1,234,567" (one to three
    leading digits followed by one or more ",ddd" groups) are rewritten;
    everything else in the expression is left untouched.

    Examples:
        "29,000 + 32,900" -> "29000 + 32900"
        "1,234,567 * 2" -> "1234567 * 2"

    Args:
        expr (str): Mathematical expression

    Returns:
        str: The expression with commas removed from grouped numbers
    """
    def _drop_commas(match):
        # The whole grouped number is capture group 1.
        return match.group(1).replace(',', '')

    return re.sub(r'\b(\d{1,3}(?:,\d{3})+)\b', _drop_commas, expr)


import math
from difflib import SequenceMatcher

# bert-score is an optional dependency. Guard the import so that a missing
# package switches the text-similarity helpers to their lexical fallback
# (via BERTSCORE_AVAILABLE) instead of failing when this module is imported.
try:
    from bert_score import score as bert_score, BERTScorer
    BERTSCORE_AVAILABLE = True
except ImportError:
    bert_score = None
    BERTScorer = None
    BERTSCORE_AVAILABLE = False

# Lazily created scorer instance, shared across calls so the model is
# loaded at most once per process.
_bert_scorer = None

def get_bert_scorer():
    """
    Return the shared BERTScorer, creating it on first use.

    Returns:
        The cached scorer instance, or None when construction fails
        (the failure is printed and nothing is cached, so a later call
        will try again).
    """
    global _bert_scorer

    # Fast path: scorer was already built by an earlier call.
    if _bert_scorer is not None:
        return _bert_scorer

    try:
        _bert_scorer = BERTScorer(model_type="roberta-large", lang="en", rescale_with_baseline=True)
    except Exception as e:
        print(f"Failed to initialize BERTScorer: {e}")
        return None
    return _bert_scorer

def check_bbox_correctness(bbox_pred, bbox_gt):
    """
    Score a predicted box against the ground truth with Intersection-over-Union.

    Args:
        bbox_pred: Predicted box as (x1, y1, x2, y2).
        bbox_gt: Ground-truth box as (x1, y1, x2, y2).

    Returns:
        float: IoU of the two boxes. 0.0 is returned when the union area is
        zero or when the inputs cannot be processed (wrong length, wrong
        types, ...).
    """
    try:
        px1, py1, px2, py2 = bbox_pred
        gx1, gy1, gx2, gy2 = bbox_gt

        # Overlap rectangle; negative extents mean the boxes do not intersect.
        overlap_w = max(0, min(px2, gx2) - max(px1, gx1))
        overlap_h = max(0, min(py2, gy2) - max(py1, gy1))
        overlap_area = overlap_w * overlap_h

        # Individual areas, clamped so inverted boxes count as empty.
        pred_area = max(0, px2 - px1) * max(0, py2 - py1)
        gt_area = max(0, gx2 - gx1) * max(0, gy2 - gy1)

        total_area = pred_area + gt_area - overlap_area
        return overlap_area / total_area if total_area != 0 else 0.0
    except Exception:
        return 0.0

def _expression_source(value):
    """Return numeric argument values as text so they can be evaluated."""
    # bool is excluded on purpose: it is an int subclass but not a number
    # a caller would pass as an arithmetic expression.
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return repr(value)
    return value


def check_expression_correctness(expression_pred, expression_gt, tol=1e-6):
    """
    Evaluate both expressions and compare their results.

    Thousands separators are stripped with clean_expression() before
    evaluation. Plain numbers (int/float) are accepted as well as strings,
    so an argument given as the number 42 compares equal to the string "42".

    Args:
        expression_pred: Predicted expression (str, int or float).
        expression_gt: Ground-truth expression (str, int or float).
        tol (float): Absolute tolerance used when both results are numeric.

    Returns:
        float: 1.0 if the results match (within tol for numbers, by equality
        otherwise), else 0.0. Any evaluation error yields 0.0.
    """
    # NOTE(security): eval() runs on text that may come from a model.
    # Removing __builtins__ blocks casual misuse but is NOT a sandbox;
    # do not feed this function input from an untrusted party without
    # switching to a real expression parser.
    env = {"__builtins__": None}
    try:
        cleaned_pred = clean_expression(_expression_source(expression_pred))
        cleaned_gt = clean_expression(_expression_source(expression_gt))
        val_pred = eval(cleaned_pred, env, {})
        val_gt = eval(cleaned_gt, env, {})
    except Exception:
        # Covers non-evaluable input of any kind (syntax errors, unknown
        # names, unsupported types); all of them simply score zero.
        return 0.0

    # Numeric results are compared with an absolute tolerance.
    if isinstance(val_pred, (int, float)) and isinstance(val_gt, (int, float)):
        return 1.0 if abs(val_pred - val_gt) <= tol else 0.0
    # Anything else falls back to direct equality.
    return 1.0 if val_pred == val_gt else 0.0

def jaccard_similarity(s1, s2):
    """
    Jaccard similarity of the lower-cased, whitespace-split word sets.

    Word order and repetition do not affect the score.

    Returns:
        float: |A & B| / |A | B|, or 1.0 when neither string has any words.
    """
    vocab_a = set(s1.lower().split())
    vocab_b = set(s2.lower().split())

    all_words = vocab_a | vocab_b
    if not all_words:
        # Both inputs are empty (or whitespace only): treat as identical.
        return 1.0

    shared_words = vocab_a & vocab_b
    return len(shared_words) / len(all_words)

def token_overlap_similarity(s1, s2):
    """
    F1 score over the multiset of lower-cased, whitespace-split tokens.

    Unlike a set-based measure, repeated tokens count: each token in s1 can
    be matched by at most one occurrence of the same token in s2.

    Returns:
        float: 1.0 when both strings have no tokens, 0.0 when exactly one
        has none or when nothing overlaps, otherwise the harmonic mean of
        precision (overlap / tokens in s1) and recall (overlap / tokens
        in s2).
    """
    from collections import Counter

    tokens1 = s1.lower().split()
    tokens2 = s2.lower().split()

    if not tokens1 and not tokens2:
        return 1.0
    if not tokens1 or not tokens2:
        return 0.0

    # Multiset intersection keeps min(count1, count2) per token, which is
    # exactly the number of one-to-one token matches, in linear time.
    overlap = sum((Counter(tokens1) & Counter(tokens2)).values())
    if overlap == 0:
        return 0.0

    precision = overlap / len(tokens1)
    recall = overlap / len(tokens2)
    return 2 * (precision * recall) / (precision + recall)

def bert_score_similarity(s1, s2):
    """
    Calculate BERTScore F1 between two strings, clamped to [0, 1].

    Falls back to combined_text_similarity() when BERTScore is not
    available, when the scorer cannot be created, or when scoring raises.

    Args:
        s1 (str): Candidate text.
        s2 (str): Reference text.

    Returns:
        float: Similarity in [0, 1].
    """
    if not BERTSCORE_AVAILABLE:
        # Fallback to combined similarity
        return combined_text_similarity(s1, s2)

    try:
        # Use the shared scorer so the model is not reloaded on every call.
        scorer = get_bert_scorer()
        if scorer is None:
            return combined_text_similarity(s1, s2)

        # BERTScorer.score returns P, R, F1 tensors; only F1 is used.
        _, _, f1 = scorer.score([s1], [s2])
        raw_score = f1.item()
    except Exception as e:
        print(f"BERTScore failed: {e}, falling back to combined similarity")
        # Fallback if BERTScore fails
        return combined_text_similarity(s1, s2)

    # The scorer is built with rescale_with_baseline=True, and rescaled
    # scores are not bounded to [0, 1] (dissimilar pairs can go negative).
    # Clamp so callers that average scores get the documented range.
    return min(1.0, max(0.0, raw_score))

def combined_text_similarity(s1, s2):
    """
    Blend three lexical metrics into one similarity score.

    The inputs are lower-cased and stripped; identical strings score 1.0
    directly. Otherwise the result is a weighted sum of the character-level
    SequenceMatcher ratio (0.3), the word-set Jaccard similarity (0.35) and
    the token-overlap F1 (0.35).
    """
    left = s1.lower().strip()
    right = s2.lower().strip()

    # Identical after normalization: no need to compute any metric.
    if left == right:
        return 1.0

    # Word-level metrics get slightly more weight than the character-level one.
    sequence_weight = 0.3
    jaccard_weight = 0.35
    token_weight = 0.35

    sequence_sim = SequenceMatcher(None, left, right).ratio()
    jaccard_sim = jaccard_similarity(left, right)
    token_sim = token_overlap_similarity(left, right)

    return (sequence_weight * sequence_sim + jaccard_weight * jaccard_sim + token_weight * token_sim)

def text_sim_match(arg_value_pred, arg_value_gt):
    """
    Score how similar two argument values are as text.

    Both values are converted to stripped strings. BERTScore is used when
    BERTSCORE_AVAILABLE is set; otherwise the combined lexical similarity
    is used.
    Returns a float in [0,1], with 1.0 meaning exact match.
    """
    pred_text = str(arg_value_pred).strip()
    gt_text = str(arg_value_gt).strip()

    # Prefer the semantic metric; the lexical blend is the fallback.
    similarity_fn = bert_score_similarity if BERTSCORE_AVAILABLE else combined_text_similarity
    return similarity_fn(pred_text, gt_text)



### Eval functions ###
def eval_argument_values(arg_type, arg_value_pred, arg_value_gt):
    """
    Score a predicted argument value with the metric registered for its name.

    Args:
        arg_type: Argument name, looked up in the ARGS_FOR_* lists.
        arg_value_pred: Predicted value.
        arg_value_gt: Ground-truth value.

    Returns:
        float: Match quality in [0,1].

    Raises:
        ValueError: If arg_type is not listed in any ARGS_FOR_* list.
    """
    # Checked in order; the first list containing arg_type picks the scorer.
    scorers_by_group = (
        (ARGS_FOR_EXACT_MATCH, exact_match),
        (ARGS_FOR_TEXT_SIM_MATCH, text_sim_match),
        (ARGS_FOR_CALCULATION, check_expression_correctness),
        (ARGS_FOR_BBOX_CHECK, check_bbox_correctness),
    )
    for known_names, scorer in scorers_by_group:
        if arg_type in known_names:
            return scorer(arg_value_pred, arg_value_gt)

    raise ValueError(f"Unknown argument type: {arg_type}")

def eval_argument_acc_type_and_value(pred_arguments, gt_arguments):
    """
    Evaluate argument types and values.

    Args:
        pred_arguments (dict): Predicted arguments
        gt_arguments (dict): Ground truth arguments

    Returns:
        dict: Dictionary containing evaluation metrics (all four keys are
        always present):
            - type_accuracy: Fraction of argument names (union of both
              sides) that appear in both prediction and ground truth
            - value_scores: Value match score per argument. Arguments that
              are predicted but not in the ground truth get 0.0; arguments
              missing from the prediction are omitted.
            - avg_value_score: Mean value score over arguments present on
              both sides (0.0 if there are none; 1.0 if both sides are
              empty)
            - overall_score: Mean of type_accuracy and avg_value_score
    """
    if not isinstance(pred_arguments, dict) or not isinstance(gt_arguments, dict):
        return {
            "type_accuracy": 0.0,
            "value_scores": {},
            "avg_value_score": 0.0,
            "overall_score": 0.0
        }

    # Every argument name seen on either side.
    all_arg_names = set(pred_arguments.keys()) | set(gt_arguments.keys())

    if not all_arg_names:
        # Both empty is a perfect match.
        return {
            "type_accuracy": 1.0,
            "value_scores": {},
            "avg_value_score": 1.0,
            "overall_score": 1.0
        }

    type_matches = 0
    value_scores = {}
    total_value_score = 0.0
    valid_value_comparisons = 0

    for arg_name in all_arg_names:
        pred_present = arg_name in pred_arguments
        gt_present = arg_name in gt_arguments

        if pred_present and gt_present:
            # Name matches on both sides, so the value can be scored.
            type_matches += 1
            pred_value = pred_arguments[arg_name]
            gt_value = gt_arguments[arg_name]
            try:
                value_score = eval_argument_values(arg_name, pred_value, gt_value)
            except ValueError:
                # Unknown argument type, fall back to exact match
                value_score = exact_match(str(pred_value), str(gt_value))
            value_scores[arg_name] = value_score
            total_value_score += value_score
            valid_value_comparisons += 1
        elif pred_present:
            # Predicted but not in GT. It is recorded as 0.0 but not counted
            # in avg_value_score; the mismatch already lowers type_accuracy.
            value_scores[arg_name] = 0.0
        # If only GT is present there is no predicted value to record.

    type_accuracy = type_matches / len(all_arg_names)
    avg_value_score = total_value_score / valid_value_comparisons if valid_value_comparisons > 0 else 0.0

    # Equal weight to type and value accuracy.
    overall_score = (type_accuracy + avg_value_score) / 2.0

    return {
        "type_accuracy": type_accuracy,
        "value_scores": value_scores,
        "avg_value_score": avg_value_score,
        "overall_score": overall_score
    }

def _zero_arguments_eval():
    """Return a fresh all-zero arguments result for calls that are not scored."""
    return {
        "type_accuracy": 0.0,
        "value_scores": {},
        "avg_value_score": 0.0,
        "overall_score": 0.0
    }


def eval_tool(pred_tool, gt_tool):
    """
    Overall evaluation of the tool object comparing predicted vs ground truth.

    Args:
        pred_tool (str or dict): Predicted tool in JSON string format or dict
        gt_tool (str or dict): Ground truth tool in JSON string format or dict

    Returns:
        dict: Dictionary containing evaluation metrics:
            - tool_name_match: 1.0 if tool names match (case-insensitive,
              whitespace-trimmed), 0.0 otherwise
            - arguments_eval: Result from eval_argument_acc_type_and_value,
              or all zeros when the tool names differ
            - overall_score: 0.5 * tool_name_match + 0.5 * arguments score

        Input that cannot be parsed as JSON, or that does not parse to a
        JSON object, scores 0.0 everywhere instead of raising.
    """
    import json

    # Parse tool strings to dictionaries if needed
    try:
        pred_tool_dict = json.loads(pred_tool.strip()) if isinstance(pred_tool, str) else pred_tool
        gt_tool_dict = json.loads(gt_tool.strip()) if isinstance(gt_tool, str) else gt_tool
    except (json.JSONDecodeError, TypeError):
        pred_tool_dict = gt_tool_dict = None

    # Valid JSON is not necessarily an object ("[1, 2]", "null", ...), and
    # callers may pass None; only dicts can carry a name and arguments.
    if not isinstance(pred_tool_dict, dict) or not isinstance(gt_tool_dict, dict):
        return {
            "tool_name_match": 0.0,
            "arguments_eval": _zero_arguments_eval(),
            "overall_score": 0.0
        }

    # Extract tool names and arguments
    pred_name = pred_tool_dict.get("name", "")
    gt_name = gt_tool_dict.get("name", "")
    pred_arguments = pred_tool_dict.get("arguments", {})
    gt_arguments = gt_tool_dict.get("arguments", {})

    # A non-string name (e.g. JSON null or a number) can never match; checking
    # the type first avoids calling str methods on it.
    names_are_text = isinstance(pred_name, str) and isinstance(gt_name, str)
    if names_are_text and pred_name.lower().strip() == gt_name.lower().strip():
        tool_name_match = 1.0
        arguments_eval = eval_argument_acc_type_and_value(pred_arguments, gt_arguments)
    else:
        # If tool names don't match, arguments evaluation is irrelevant
        tool_name_match = 0.0
        arguments_eval = _zero_arguments_eval()

    # Calculate overall score
    overall_score = (0.5 * tool_name_match + 0.5 * arguments_eval["overall_score"])

    return {
        "tool_name_match": tool_name_match,
        "arguments_eval": arguments_eval,
        "overall_score": overall_score
    }


from typing import Union

def extract_json_answer(text: str) -> Union[str, None]:
    """
    Extract the first brace-balanced JSON object from text.

    Processing steps:
      1. Everything up to the last "</think>" marker is dropped.
      2. If a markdown code fence is found and it has a closing "```" line,
         only the fenced content is searched.
      3. The substring from the first "{" to its matching "}" is returned.
         Braces inside double-quoted JSON strings (including strings with
         escaped quotes) are ignored while matching.

    Args:
        text (str): Raw response text.

    Returns:
        The extracted substring (not validated as JSON), or None when text
        is not a string, contains no "{", or the braces never balance.
    """
    if not isinstance(text, str):
        return None

    # remove <think> block
    if "</think>" in text:
        text = text.split("</think>")[-1].strip()

    # Remove outer markdown code fence if present
    lines = text.splitlines()
    fence_start = next(
        (i for i, line in enumerate(lines) if line.strip().startswith("```")),
        None,
    )
    if fence_start is not None:
        # Only narrow the text when a closing fence actually exists.
        for j in range(fence_start + 1, len(lines)):
            if lines[j].strip() == "```":
                text = "\n".join(lines[fence_start + 1:j])
                break

    start = text.find("{")
    if start == -1:
        return None

    depth = 0
    in_string = False
    escaped = False
    for idx in range(start, len(text)):
        ch = text[idx]
        if in_string:
            # Inside a string literal braces are data, not structure.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:idx + 1]

    # Unmatched '{' in text
    return None



def run_eval(pred_str, gt_str):
    """
    Run evaluation between predicted and ground truth strings.

    The prediction is searched for a JSON object with extract_json_answer();
    the ground truth is used as given (whitespace-stripped when it is a
    string, passed through unchanged when it is already a dict).

    Args:
        pred_str (str): Predicted response string containing JSON tool call
        gt_str (str or dict): Ground truth JSON tool call

    Returns:
        dict: Evaluation results from eval_tool function. When the prediction
        has no extractable JSON, or the ground truth is missing or empty, an
        all-zero result with an additional "error" message is returned.
    """
    def _failure(message):
        # Same shape as an eval_tool result, plus the reason for the failure.
        return {
            "error": message,
            "tool_name_match": 0.0,
            "arguments_eval": {
                "type_accuracy": 0.0,
                "value_scores": {},
                "avg_value_score": 0.0,
                "overall_score": 0.0
            },
            "overall_score": 0.0
        }

    # Non-string predictions cannot contain a JSON answer.
    pred_json = extract_json_answer(pred_str) if isinstance(pred_str, str) else None
    if pred_json is None:
        return _failure("Failed to extract JSON from predicted string")

    # Only strings are stripped; eval_tool accepts a dict ground truth as is.
    gt_json = gt_str.strip() if isinstance(gt_str, str) else gt_str
    if gt_json is None or gt_json == "":
        return _failure("Failed to extract JSON from ground truth string")

    # Evaluate the extracted tools
    return eval_tool(pred_json, gt_json)


###
def unit_tests():
    """
    Print example results for the evaluation helpers.

    This is a manual smoke test: every case is printed for inspection and
    nothing is asserted. It exercises eval_argument_acc_type_and_value,
    text_sim_match and eval_tool in that order.
    """
    # Example usage
    print("Testing eval_argument_acc_type_and_value function:")
    
    # Test case 1: Perfect match
    pred_args = {"expression": "29,900 + 32,900", "position": "center", "text": "Hello World"}
    gt_args = {"expression": "29900 + 32900", "position": "center", "text": "Hello World"}
    
    result = eval_argument_acc_type_and_value(pred_args, gt_args)
    print(f"Test 1 - Perfect match: {result}")
    
    # Test case 2: Partial match with missing argument
    pred_args = {"expression": "100 + 200", "position": "left"}
    gt_args = {"expression": "100 + 200", "position": "center", "text": "Missing"}
    
    result = eval_argument_acc_type_and_value(pred_args, gt_args)
    print(f"Test 2 - Partial match: {result}")
    
    # Test case 3: Text similarity
    pred_args = {"text": "hello world", "attribute": "color red"}
    gt_args = {"text": "Hello World!", "attribute": "red color"}
    
    result = eval_argument_acc_type_and_value(pred_args, gt_args)
    print(f"Test 3 - Text similarity: {result}")
    
    # Test case 4: Additional text similarity tests
    print("\nAdditional text similarity tests:")
    
    test_cases = [
        ("red color", "color red"),
        ("big blue car", "blue big car"),
        ("machine learning model", "ML model"),
        ("Hello World", "hello world"),
        ("completely different", "totally unrelated"),
    ]
    
    for s1, s2 in test_cases:
        score = text_sim_match(s1, s2)
        print(f"'{s1}' vs '{s2}': {score:.4f}")
        
    # Report which text-similarity backend is in use
    print(f"\nBERTScore available: {BERTSCORE_AVAILABLE}")
    if not BERTSCORE_AVAILABLE:
        print("Install bert-score for better text similarity: pip install bert-score")
    
    # Test eval_tool function
    print("\n" + "="*50)
    print("Testing eval_tool function:")
    
    # Test case 1: Perfect match
    pred_tool_1 = '{"name": "Calculator", "arguments": {"expression": "29,000 + 32,900"}}'
    gt_tool_1 = '{"name": "Calculator", "arguments": {"expression": "29000 + 32900"}}'
    
    result = eval_tool(pred_tool_1, gt_tool_1)
    print(f"Test 1 - Perfect match: {result}")
    
    # Test case 2: Wrong tool name
    pred_tool_2 = '{"name": "Math", "arguments": {"expression": "100 + 200"}}'
    gt_tool_2 = '{"name": "Calculator", "arguments": {"expression": "100 + 200"}}'
    
    result = eval_tool(pred_tool_2, gt_tool_2)
    print(f"Test 2 - Wrong tool name: {result}")
    
    # Test case 3: Missing argument
    pred_tool_3 = '{"name": "TextProcessor", "arguments": {"text": "hello"}}'
    gt_tool_3 = '{"name": "TextProcessor", "arguments": {"text": "hello", "command": "uppercase"}}'
    
    result = eval_tool(pred_tool_3, gt_tool_3)
    print(f"Test 3 - Missing argument: {result}")
    
    # Test case 4: Dictionary input instead of string
    pred_tool_4 = {"name": "ImageProcessor", "arguments": {"position": "center", "color": "red"}}
    gt_tool_4 = {"name": "ImageProcessor", "arguments": {"position": "center", "color": "red"}}
    
    result = eval_tool(pred_tool_4, gt_tool_4)
    print(f"Test 4 - Dict input perfect match: {result}")
    
    # Test case 5: Invalid JSON (the closing braces are missing on purpose)
    pred_tool_5 = '{"name": "Calculator", "arguments": {"expression": "1+1"'  # Invalid JSON
    gt_tool_5 = '{"name": "Calculator", "arguments": {"expression": "1+1"}}'
    
    result = eval_tool(pred_tool_5, gt_tool_5)
    print(f"Test 5 - Invalid JSON: {result}")