"""
Token analysis utilities for counting and categorizing tokens in prompts and responses.
Updated to handle XML-based response format with <thinking></thinking> and <answer></answer> tags.
"""

import re
from typing import Any, Dict, Optional


def extract_reasoning_tokens(response_text: str) -> Dict[str, Any]:
    """
    Split a response into its <thinking> and <answer> sections and estimate token counts.

    Parameters:
    - response_text: Full response text. None is treated like an empty response.

    Returns:
    - Dict with:
      - reasoning_tokens: estimated tokens inside the <thinking> section
      - answer_tokens: estimated tokens inside the <answer> section, or of the
        whole response when neither section is found
      - total_response_tokens: estimated tokens of the full response text
        (tags included)
      - has_thinking_section: True when a non-empty <thinking> section was found
      - has_answer_section: True when a non-empty <answer> section was found, or
        when the whole response was used as the answer
      - parsing_error: only present when extraction raised; holds the error text
    """
    # A missing response would otherwise crash on .strip() below; handle it
    # exactly like an empty string.
    if response_text is None:
        response_text = ""

    try:
        thinking_text = extract_xml_section(response_text, "thinking")
        answer_text = extract_xml_section(response_text, "answer")

        has_thinking_section = thinking_text is not None
        has_answer_section = answer_text is not None

        if not has_thinking_section and not has_answer_section:
            # No usable XML sections at all: the entire response is the answer.
            # The flag is set to True for compatibility with consumers that
            # expect an answer section to exist.
            answer_text = response_text.strip()
            has_answer_section = True

        # Missing sections count as empty text; estimates are a rough
        # approximation (~4 chars per token).
        return {
            "reasoning_tokens": estimate_token_count(thinking_text or ""),
            "answer_tokens": estimate_token_count(answer_text or ""),
            "total_response_tokens": estimate_token_count(response_text),
            "has_thinking_section": has_thinking_section,
            "has_answer_section": has_answer_section,
        }

    except (ValueError, re.error) as e:
        # Fallback: attribute everything to the answer if parsing fails completely
        total_tokens = estimate_token_count(response_text)
        return {
            "reasoning_tokens": 0,
            "answer_tokens": total_tokens,
            "total_response_tokens": total_tokens,
            "has_thinking_section": False,
            "has_answer_section": False,
            "parsing_error": str(e),
        }


def extract_xml_section(text: str, tag_name: str) -> Optional[str]:
    """
    Extract the content of the first <tag_name>...</tag_name> pair in text.

    Matching is case insensitive, spans multiple lines, and tolerates
    whitespace inside the tag brackets (e.g. "< answer >" or "</ answer >").
    Tags carrying attributes are not matched.

    Parameters:
    - text: Text to search in
    - tag_name: Name of the XML tag (without brackets)

    Returns:
    - The stripped content between the tags, or None when text or tag_name is
      empty, no complete tag pair is present, or the content is empty or
      whitespace only
    """
    if not text or not tag_name:
        return None

    # Escape the tag name so characters such as "." are matched literally.
    tag = re.escape(tag_name)
    # Non-greedy body so the first closing tag ends the section.
    pattern = rf"<\s*{tag}\s*>(.*?)<\s*/\s*{tag}\s*>"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    if match is None:
        return None

    # An empty section is reported the same way as a missing one.
    return match.group(1).strip() or None


def estimate_token_count(text: str) -> int:
    """
    Approximate the number of tokens in text at roughly 4 characters per token.

    Intended as a fallback for when the API does not report exact counts.
    Runs of whitespace are collapsed to single spaces before measuring.

    Parameters:
    - text: Text to count tokens for

    Returns:
    - 0 for empty or whitespace-only text, otherwise an estimate of at least 1
    """
    has_content = bool(text) and bool(text.strip())
    if not has_content:
        return 0

    # Collapse every whitespace run so padding and newlines do not inflate the count
    collapsed = " ".join(text.split())

    # ~4 characters per token for English text; approximate, but reasonable
    # for comparison purposes. Any non-empty text counts as at least one token.
    approx = len(collapsed) // 4
    return approx if approx > 0 else 1


def create_token_usage_dict(
    input_tokens: Optional[int] = None,
    output_tokens: Optional[int] = None,
    total_tokens: Optional[int] = None,
    reasoning_tokens: Optional[int] = None,
    answer_tokens: Optional[int] = None,
    api_reported: Optional[Dict[str, int]] = None,
    estimation_method: str = "api",
) -> Dict[str, Any]:
    """
    Create a standardized token usage dictionary.

    Parameters:
    - input_tokens: Number of tokens in the input prompt
    - output_tokens: Number of tokens in the response
    - total_tokens: Total tokens (input + output); stored as given, not derived
    - reasoning_tokens: Tokens in <thinking> section
    - answer_tokens: Tokens in <answer> section
    - api_reported: Raw token data from API
    - estimation_method: How tokens were counted ("api", "estimated", "mixed")

    Returns:
    - Standardized token usage dictionary. Always contains the five count
      fields, "estimation_method" and "timestamp" (None). Optional keys:
      - "api_reported": only when api_reported is non-empty
      - "reasoning_ratio": reasoning / (reasoning + answer), only when both
        counts are given
      - "io_ratio": output / input, only when both counts are given
    """
    token_usage: Dict[str, Any] = {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": total_tokens,
        "reasoning_tokens": reasoning_tokens,
        "answer_tokens": answer_tokens,
        "estimation_method": estimation_method,
        "timestamp": None,  # Will be set by caller
    }

    # Include raw API data if available (None and empty dicts are skipped)
    if api_reported:
        token_usage["api_reported"] = api_reported

    # Derived metrics; max(1, ...) guards against division by zero
    if reasoning_tokens is not None and answer_tokens is not None:
        section_total = max(1, reasoning_tokens + answer_tokens)
        token_usage["reasoning_ratio"] = reasoning_tokens / section_total

    if input_tokens is not None and output_tokens is not None:
        token_usage["io_ratio"] = output_tokens / max(1, input_tokens)

    return token_usage


def analyze_response_tokens(
    prompt_text: str,
    response_text: str,
    api_usage: Optional[Dict[str, int]] = None,
) -> Dict[str, Any]:
    """
    Complete token analysis for a prompt-response pair.

    API-reported counts are preferred; any count the API did not report is
    estimated from the text instead.

    Parameters:
    - prompt_text: The input prompt
    - response_text: The model's response
    - api_usage: Token usage data from API (if available). The keys read are
      "prompt_tokens", "completion_tokens" and "total_tokens".

    Returns:
    - Complete token usage dictionary as built by create_token_usage_dict.
      Its "estimation_method" is "api" when both input and output counts came
      from the API, "estimated" when neither did, and "mixed" otherwise.
    """
    # Extract reasoning breakdown from response
    response_breakdown = extract_reasoning_tokens(response_text)

    # Treat a missing or empty usage payload as "nothing reported"
    reported = api_usage or {}
    input_from_api = "prompt_tokens" in reported
    output_from_api = "completion_tokens" in reported

    if input_from_api:
        input_tokens = reported["prompt_tokens"]
    else:
        input_tokens = estimate_token_count(prompt_text)

    if output_from_api:
        output_tokens = reported["completion_tokens"]
    else:
        output_tokens = response_breakdown["total_response_tokens"]

    if input_from_api and output_from_api:
        estimation_method = "api"
    elif input_from_api or output_from_api:
        estimation_method = "mixed"
    else:
        estimation_method = "estimated"

    # Prefer the API's own total; otherwise sum the two counts chosen above
    if "total_tokens" in reported:
        total_tokens = reported["total_tokens"]
    else:
        total_tokens = input_tokens + output_tokens

    return create_token_usage_dict(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        reasoning_tokens=response_breakdown["reasoning_tokens"],
        answer_tokens=response_breakdown["answer_tokens"],
        api_reported=api_usage,
        estimation_method=estimation_method,
    )


def validate_response_format(response_text: str) -> Dict[str, Any]:
    """
    Validate that a response follows the expected XML format.

    Parameters:
    - response_text: The model's response text. None is treated like an
      empty response.

    Returns:
    - Dict with:
      - is_valid: True only when both a <thinking> and an <answer> section
        with non-empty content were extracted
      - has_thinking / has_answer: per-section extraction result
      - issues: human-readable descriptions of the problems found
      - suggestions: fixes for missing sections (partial-tag problems add an
        issue but no suggestion)
    """
    # A missing response would otherwise crash on the tag checks below.
    if response_text is None:
        response_text = ""

    has_thinking = extract_xml_section(response_text, "thinking") is not None
    has_answer = extract_xml_section(response_text, "answer") is not None

    # Check for common formatting issues
    issues = []
    suggestions = []

    if not has_thinking and not has_answer:
        issues.append("No XML tags found")
        suggestions.append(
            "Response should contain <thinking></thinking> and <answer></answer> tags"
        )
    elif not has_thinking:
        issues.append("Missing <thinking> section")
        suggestions.append("Add <thinking></thinking> tags around your reasoning")
    elif not has_answer:
        issues.append("Missing <answer> section")
        suggestions.append("Add <answer></answer> tags around your final answer")

    # Check for partial tags (common error). Detection uses the same
    # whitespace-tolerant, case-insensitive tag shapes as extract_xml_section,
    # so a section that was extracted successfully is never reported as unclosed.
    tag_names = ("thinking", "answer")
    has_open = {
        tag: re.search(rf"<\s*{tag}", response_text, re.IGNORECASE) is not None
        for tag in tag_names
    }
    has_close = {
        tag: re.search(rf"<\s*/\s*{tag}\s*>", response_text, re.IGNORECASE)
        is not None
        for tag in tag_names
    }

    # Report all unclosed tags first, then all stray closing tags.
    for tag in tag_names:
        if has_open[tag] and not has_close[tag]:
            issues.append(f"Unclosed <{tag}> tag")
    for tag in tag_names:
        if has_close[tag] and not has_open[tag]:
            issues.append(f"Closing </{tag}> tag without opening tag")

    return {
        "is_valid": has_thinking and has_answer,
        "has_thinking": has_thinking,
        "has_answer": has_answer,
        "issues": issues,
        "suggestions": suggestions,
    }
