"""
Metadata and UID utilities for structured ID management.

This module provides utilities for managing structured metadata instead of
compound string IDs. Each record has a metadata dict containing ID components
and a uid field generated from the metadata.
"""

from typing import Dict, Any, Optional
from copy import deepcopy
import json
import os
from pathlib import Path
from collections import defaultdict


def create_metadata(
    origin_problem_id: str,
    round_id: int = 0,
    parent_problem_id: Optional[str] = None,
    breakdown_id: Optional[int] = None,
    lemma_id: Optional[int] = None,
    formalization_id: Optional[int] = None,
    attempt_id: Optional[int] = None,
    iteration_id: Optional[int] = None,
    correction_round_id: Optional[int] = None,
    initial_attempt_id: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Build a metadata dictionary holding the ID components of a record.

    The keys origin_problem_id, parent_problem_id and round_id are always
    present. Every optional component is stored only when it is not None,
    so legitimate values such as 0 or -1 are kept.

    Args:
        origin_problem_id: Base problem ID (e.g., "putnam_1969_a2")
        round_id: Round number (default 0)
        parent_problem_id: Parent problem ID. When omitted it defaults to
            origin_problem_id (the round-0 case, where the parent is the
            original problem); for round 1+ callers should pass the UID of
            the failed lemma from the previous round.
        breakdown_id: Breakdown number (optional)
        lemma_id: Lemma number (optional, -1 for theorem)
        formalization_id: Formalization sample number (optional)
        attempt_id: Proof attempt number (optional)
        iteration_id: Prover iteration number (optional)
        correction_round_id: Correction round number (optional, 0 for first round)
        initial_attempt_id: Initial attempt ID that started this correction chain (optional)

    Returns:
        Metadata dictionary. Keys appear in this order: origin_problem_id,
        parent_problem_id, round_id, then any present optional components in
        the order they are listed above.
    """
    metadata: Dict[str, Any] = {
        "origin_problem_id": origin_problem_id,
        "parent_problem_id": origin_problem_id if parent_problem_id is None else parent_problem_id,
        "round_id": round_id,
    }

    # Ordered (key, value) pairs; the order fixes the key order of the result.
    optional_components = (
        ("breakdown_id", breakdown_id),
        ("lemma_id", lemma_id),
        ("formalization_id", formalization_id),
        ("attempt_id", attempt_id),
        ("iteration_id", iteration_id),
        ("correction_round_id", correction_round_id),
        ("initial_attempt_id", initial_attempt_id),
    )
    for key, value in optional_components:
        if value is not None:
            metadata[key] = value

    return metadata


def generate_uid(metadata: Dict[str, Any]) -> str:
    """
    Build a unique ID string from metadata.

    The UID is the underscore-joined sequence of:
    parent_problem_id (or origin_problem_id when there is no parent),
    "r{round_id}", then whichever of breakdown ("b"), lemma ("l", or the
    literal "theorem" when lemma_id == -1), formalization ("f"), attempt
    ("a"), iteration ("i") and correction round ("c") are present, in that
    order. initial_attempt_id is not part of the UID.

    Args:
        metadata: Metadata dictionary (must contain round_id and at least one
            of parent_problem_id / origin_problem_id)

    Returns:
        Generated UID string (e.g., "putnam_1969_a2_r0_b0_l1_f0")

    Raises:
        KeyError: If round_id or both problem-ID fields are missing.
    """
    base_id = (
        metadata["parent_problem_id"]
        if "parent_problem_id" in metadata
        else metadata["origin_problem_id"]
    )
    pieces = [base_id, f"r{metadata['round_id']}"]

    if "breakdown_id" in metadata:
        pieces.append(f"b{metadata['breakdown_id']}")

    if "lemma_id" in metadata:
        lemma = metadata["lemma_id"]
        pieces.append("theorem" if lemma == -1 else f"l{lemma}")

    # Remaining components are all simple "<prefix><value>" segments.
    trailing_fields = (
        ("formalization_id", "f"),
        ("attempt_id", "a"),
        ("iteration_id", "i"),
        ("correction_round_id", "c"),
    )
    for field, prefix in trailing_fields:
        if field in metadata:
            pieces.append(f"{prefix}{metadata[field]}")

    return "_".join(pieces)


def generate_problem_key(metadata: Dict[str, Any]) -> tuple:
    """
    Build a hashable problem key from metadata, ignoring attempt-level fields.

    The key identifies a problem instance while dropping attempt_id,
    iteration_id, correction_round_id and initial_attempt_id, so that
    different attempts at the same problem (e.g. across correction rounds)
    group together.

    A tuple of (field, value) pairs is used rather than a formatted string so
    that new metadata fields are picked up automatically.

    Args:
        metadata: Metadata dictionary

    Returns:
        Tuple of (field_name, value) pairs for all non-attempt fields, sorted by
        field name. Example: (('breakdown_id', 0), ('formalization_id', 0),
        ('lemma_id', 2), ('origin_problem_id', 'numbertheory_4x3m7y3neq2003'),
        ('parent_problem_id', 'numbertheory_4x3m7y3neq2003'), ('round_id', 0))
    """
    excluded = frozenset(
        ('attempt_id', 'iteration_id', 'correction_round_id', 'initial_attempt_id')
    )
    # Dict keys are unique, so sorting the pairs only ever compares field names.
    kept_pairs = [(field, value) for field, value in metadata.items() if field not in excluded]
    kept_pairs.sort()
    return tuple(kept_pairs)


def copy_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return an independent deep copy of a metadata dictionary.

    Nested containers are copied too, so mutating the result never affects
    the source.

    Args:
        metadata: Source metadata

    Returns:
        Deep copy of metadata
    """
    duplicate = deepcopy(metadata)
    return duplicate


def add_breakdown(metadata: Dict[str, Any], breakdown_id: int) -> Dict[str, Any]:
    """
    Return a copy of metadata with breakdown_id set.

    The input dictionary is left unmodified (it is deep-copied first).

    Args:
        metadata: Source metadata
        breakdown_id: Breakdown number to set

    Returns:
        New metadata dict with breakdown_id set (overwriting any previous value)
    """
    extended = copy_metadata(metadata)
    extended["breakdown_id"] = breakdown_id
    return extended


def add_lemma(metadata: Dict[str, Any], lemma_id: int) -> Dict[str, Any]:
    """
    Return a copy of metadata with lemma_id set.

    Use lemma_id=-1 to mark the record as a theorem. The input dictionary is
    left unmodified (it is deep-copied first).

    Args:
        metadata: Source metadata
        lemma_id: Lemma number to set (-1 for theorem)

    Returns:
        New metadata dict with lemma_id set (overwriting any previous value)
    """
    extended = copy_metadata(metadata)
    extended["lemma_id"] = lemma_id
    return extended


def add_formalization(metadata: Dict[str, Any], formalization_id: int) -> Dict[str, Any]:
    """
    Return a copy of metadata with formalization_id set.

    The input dictionary is left unmodified (it is deep-copied first).

    Args:
        metadata: Source metadata
        formalization_id: Formalization sample number to set

    Returns:
        New metadata dict with formalization_id set (overwriting any previous value)
    """
    extended = copy_metadata(metadata)
    extended["formalization_id"] = formalization_id
    return extended


def add_attempt(metadata: Dict[str, Any], attempt_id: int) -> Dict[str, Any]:
    """
    Return a copy of metadata with attempt_id set.

    The input dictionary is left unmodified (it is deep-copied first).

    Args:
        metadata: Source metadata
        attempt_id: Proof attempt number to set

    Returns:
        New metadata dict with attempt_id set (overwriting any previous value)
    """
    extended = copy_metadata(metadata)
    extended["attempt_id"] = attempt_id
    return extended


def add_iteration(metadata: Dict[str, Any], iteration_id: int) -> Dict[str, Any]:
    """
    Return a copy of metadata with iteration_id set.

    The input dictionary is left unmodified (it is deep-copied first).

    Args:
        metadata: Source metadata
        iteration_id: Iteration number to set

    Returns:
        New metadata dict with iteration_id set (overwriting any previous value)
    """
    extended = copy_metadata(metadata)
    extended["iteration_id"] = iteration_id
    return extended


def add_correction(metadata: Dict[str, Any], correction_round_id: int) -> Dict[str, Any]:
    """
    Return a copy of metadata with correction_round_id set.

    The input dictionary is left unmodified (it is deep-copied first).

    Args:
        metadata: Source metadata
        correction_round_id: Correction round number to set (0 for first round)

    Returns:
        New metadata dict with correction_round_id set (overwriting any previous value)
    """
    extended = copy_metadata(metadata)
    extended["correction_round_id"] = correction_round_id
    return extended


def remove_field(metadata: Dict[str, Any], field_name: str) -> Dict[str, Any]:
    """
    Return a copy of metadata without the given field.

    Removing a field that is not present is not an error; the copy is simply
    returned unchanged. The input dictionary is never modified.

    Args:
        metadata: Source metadata
        field_name: Name of field to remove

    Returns:
        New metadata dict without the specified field
    """
    trimmed = copy_metadata(metadata)
    if field_name in trimmed:
        del trimmed[field_name]
    return trimmed


def get_origin_problem_id(metadata: Dict[str, Any]) -> str:
    """
    Return the origin_problem_id stored in metadata.

    Args:
        metadata: Metadata dict

    Returns:
        Origin problem ID

    Raises:
        KeyError: If metadata has no "origin_problem_id" entry.
    """
    origin_id = metadata["origin_problem_id"]
    return origin_id


def get_breakdown_id(metadata: Dict[str, Any]) -> Optional[int]:
    """
    Return the breakdown_id stored in metadata, if any.

    Args:
        metadata: Metadata dict

    Returns:
        Breakdown ID, or None when the field is absent
    """
    breakdown = metadata.get("breakdown_id", None)
    return breakdown


def get_lemma_id(metadata: Dict[str, Any]) -> Optional[int]:
    """
    Return the lemma_id stored in metadata, if any.

    Args:
        metadata: Metadata dict

    Returns:
        Lemma ID (-1 means the record is a theorem), or None when absent
    """
    lemma = metadata.get("lemma_id", None)
    return lemma


def get_breakdown_key(metadata: Dict[str, Any]) -> str:
    """
    Return a unique key identifying a breakdown across all problems.

    Format: "{problem_id}_r{round_id}_b{breakdown_id}", where problem_id is
    parent_problem_id when that key is present and origin_problem_id
    otherwise (the two are equal in round 0).
    Example: "aime_1983_p1_r0_b0"

    Args:
        metadata: Metadata dict

    Returns:
        Unique breakdown key string, or "unknown" if metadata is not a dict or
        any of the three components is missing or None
    """
    if not isinstance(metadata, dict):
        return "unknown"

    components = (
        metadata.get("parent_problem_id", metadata.get("origin_problem_id")),
        metadata.get("round_id"),
        metadata.get("breakdown_id"),
    )
    if any(component is None for component in components):
        return "unknown"

    problem_id, round_id, breakdown_id = components
    return f"{problem_id}_r{round_id}_b{breakdown_id}"


def is_theorem(metadata: Dict[str, Any]) -> bool:
    """
    Tell whether metadata describes a theorem.

    Theorems are encoded as lemma_id == -1.

    Args:
        metadata: Metadata dict

    Returns:
        True if lemma_id is -1, False otherwise (including when it is absent)
    """
    lemma_id = metadata.get("lemma_id")
    return lemma_id == -1


def is_lemma(metadata: Dict[str, Any]) -> bool:
    """
    Tell whether metadata describes a lemma rather than a theorem.

    Args:
        metadata: Metadata dict

    Returns:
        True if lemma_id is present, not None, and not -1; False otherwise
    """
    lemma_id = metadata.get("lemma_id")
    if lemma_id is None:
        return False
    return lemma_id != -1


def create_problem_key(metadata: Dict[str, Any]) -> str:
    """
    Create the problem-level key string from metadata.

    The key is parent_problem_id when that key is present, otherwise
    origin_problem_id, otherwise "unknown". It is always converted with str().
    Example: "putnam_1969_a2"

    Args:
        metadata: Metadata dict

    Returns:
        Problem key string ("unknown" if metadata is not a dict)
    """
    if not isinstance(metadata, dict):
        return "unknown"
    fallback = metadata.get('origin_problem_id', 'unknown')
    return str(metadata.get('parent_problem_id', fallback))


def create_breakdown_problem_id(metadata: Dict[str, Any]) -> str:
    """
    Create the full breakdown problem_id from metadata (matches breakdown.json entries).

    This is the primary identifier used in breakdown.json and is needed to match
    prover records and other pipeline outputs with their corresponding breakdowns.

    Format: "{origin_problem_id}_r{round_id}_b{breakdown_id}"
    Example: "putnam_1969_a2_r0_b1"

    NOTE(review): unlike get_breakdown_key / generate_uid, this always uses
    origin_problem_id, never parent_problem_id. The two agree in round 0 but
    differ in round 1+; confirm which form breakdown.json actually stores.

    Args:
        metadata: Metadata dict with origin_problem_id, round_id, and breakdown_id.
            A missing or None round_id is treated as 0.

    Returns:
        Breakdown problem_id string, or "unknown" if metadata is not a dict or
        has no usable breakdown_id (missing, None, or '')
    """
    if not isinstance(metadata, dict):
        return "unknown"

    breakdown_id = metadata.get('breakdown_id')
    # Without the None check an explicit breakdown_id=None would render as "..._bNone".
    if breakdown_id is None or breakdown_id == '':
        return "unknown"

    origin_id = metadata.get('origin_problem_id', 'unknown')
    round_id = metadata.get('round_id')
    if round_id is None:
        round_id = 0

    return f"{origin_id}_r{round_id}_b{breakdown_id}"


def create_breakdown_key(metadata: Dict[str, Any]) -> str:
    """
    Create the breakdown-level key string from metadata.

    Format: "{problem_id}_b{breakdown_id}" (the "_b..." suffix only when
    breakdown_id is present). problem_id is parent_problem_id when present,
    otherwise origin_problem_id, otherwise "unknown". round_id is NOT included.
    Example: "putnam_1969_a2_b1"

    Args:
        metadata: Metadata dict

    Returns:
        Breakdown key string ("unknown" if metadata is not a dict)
    """
    if not isinstance(metadata, dict):
        return "unknown"

    key = str(metadata.get('parent_problem_id', metadata.get('origin_problem_id', 'unknown')))
    if 'breakdown_id' in metadata:
        key = f"{key}_b{metadata['breakdown_id']}"
    return key


def create_lemma_key(metadata: Dict[str, Any]) -> str:
    """
    Create the lemma-level key string from metadata.

    Format: "{problem_id}_r{round_id}_b{breakdown_id}_l{lemma_id}"
    or:     "{problem_id}_r{round_id}_b{breakdown_id}_theorem" (if lemma_id == -1)
    Each of the r/b/lemma segments is emitted only when its field is present.
    problem_id is parent_problem_id when present, otherwise origin_problem_id,
    otherwise "unknown".
    Example: "putnam_1969_a2_r0_b1_l2" or "putnam_1969_a2_r0_b1_theorem"

    Args:
        metadata: Metadata dict

    Returns:
        Lemma key string ("unknown" if metadata is not a dict)
    """
    if not isinstance(metadata, dict):
        return "unknown"

    segments = [str(metadata.get('parent_problem_id', metadata.get('origin_problem_id', 'unknown')))]
    for field, prefix in (('round_id', 'r'), ('breakdown_id', 'b')):
        if field in metadata:
            segments.append(f"{prefix}{metadata[field]}")

    if 'lemma_id' in metadata:
        lemma_id = metadata['lemma_id']
        segments.append("theorem" if lemma_id == -1 else f"l{lemma_id}")

    return "_".join(segments)

def create_formalization_key(metadata: Dict[str, Any]) -> str:
    """
    Create the formalization-level key string from metadata.

    Format: "{problem_id}_r{round_id}_b{breakdown_id}_l{lemma_id}_f{formalization_id}"
    or:     "{problem_id}_r{round_id}_b{breakdown_id}_theorem" (if lemma_id == -1)
    Each segment is emitted only when its field is present; the f-segment is
    emitted only for non-theorem lemmas. problem_id is parent_problem_id when
    present, otherwise origin_problem_id, otherwise "unknown".
    Example: "putnam_1969_a2_r0_b1_l2_f1" or "putnam_1969_a2_r0_b1_theorem"

    Args:
        metadata: Metadata dict

    Returns:
        Formalization key string ("unknown" if metadata is not a dict)
    """
    if not isinstance(metadata, dict):
        return "unknown"

    segments = [str(metadata.get('parent_problem_id', metadata.get('origin_problem_id', 'unknown')))]
    for field, prefix in (('round_id', 'r'), ('breakdown_id', 'b')):
        if field in metadata:
            segments.append(f"{prefix}{metadata[field]}")

    if 'lemma_id' not in metadata:
        return "_".join(segments)

    lemma_id = metadata['lemma_id']
    if lemma_id == -1:
        # Theorems don't have multiple formalizations, so no f-segment is added.
        segments.append("theorem")
    else:
        segments.append(f"l{lemma_id}")
        if 'formalization_id' in metadata:
            segments.append(f"f{metadata['formalization_id']}")

    return "_".join(segments)


def create_attempt_key(metadata: Dict[str, Any]) -> str:
    """
    Create the attempt-level key string from metadata (includes attempt_id).

    Format: "{problem_id}_b{breakdown_id}_l{lemma_id}_a{attempt_id}"
    ("theorem" replaces "l{lemma_id}" when lemma_id == -1). Each segment is
    emitted only when its field is present; round_id is NOT included.
    problem_id is parent_problem_id when present, otherwise origin_problem_id,
    otherwise "unknown".
    Example: "putnam_1969_a2_b1_theorem_a0"

    Because attempt_id is part of the key, each proof attempt at the same
    lemma/theorem is counted separately.

    Args:
        metadata: Metadata dict

    Returns:
        Attempt key string ("unknown" if metadata is not a dict)
    """
    if not isinstance(metadata, dict):
        return "unknown"

    # Ordered (field, formatter) pairs for the optional segments.
    formatters = (
        ('breakdown_id', lambda value: f"b{value}"),
        ('lemma_id', lambda value: "theorem" if value == -1 else f"l{value}"),
        ('attempt_id', lambda value: f"a{value}"),
    )
    segments = [str(metadata.get('parent_problem_id', metadata.get('origin_problem_id', 'unknown')))]
    segments.extend(fmt(metadata[field]) for field, fmt in formatters if field in metadata)
    return "_".join(segments)


# ============================================================================
# COST AND TOKEN CALCULATION FUNCTIONS
# ============================================================================

def extract_detailed_cost(record: Dict[str, Any]) -> Dict[str, Any]:
    """
    Sum the cost and token counts recorded on a single record.

    Both 'detailed_cost' and 'parser_detailed_cost' are read when present and
    dict-valued; anything else (missing key, None, non-dict) contributes zero.
    Missing sub-fields inside a cost dict also count as zero.

    Args:
        record: Record dict that may contain 'detailed_cost' or 'parser_detailed_cost'

    Returns:
        Dict with 'cost' (float), 'input_tokens' and 'output_tokens' (ints);
        all zero when no cost information is found
    """
    totals = {
        "cost": 0.0,
        "input_tokens": 0,
        "output_tokens": 0,
    }

    for source_key in ("detailed_cost", "parser_detailed_cost"):
        if source_key not in record:
            continue
        source = record[source_key]
        if not isinstance(source, dict):
            continue
        for field, default in (("cost", 0.0), ("input_tokens", 0), ("output_tokens", 0)):
            totals[field] += source.get(field, default)

    return totals


def _load_json_file(file_path: str) -> Any:
    """
    Safely load a JSON file.

    Args:
        file_path: Path to JSON file

    Returns:
        Parsed JSON data or empty dict/list if file doesn't exist or is invalid
    """
    try:
        if not os.path.exists(file_path):
            return None
        with open(file_path, 'r') as f:
            return json.load(f)
    except (json.JSONDecodeError, IOError):
        return None


def _add_records_to_component(component: Dict[str, Any], records: Any) -> None:
    """Add the extract_detailed_cost() totals of every record in ``records`` to ``component`` (no-op unless a list)."""
    if not isinstance(records, list):
        return
    for record in records:
        cost_info = extract_detailed_cost(record)
        component["cost"] += cost_info["cost"]
        component["input_tokens"] += cost_info["input_tokens"]
        component["output_tokens"] += cost_info["output_tokens"]


def get_run_costs(run_path: str) -> Dict[str, Any]:
    """
    Calculate total costs and tokens for an entire run.

    Only round-0 outputs are read:
    - round0/breakdown_parser/parsed_breakdown.json (breakdown + parser costs)
    - round0/prover/consolidated_full_records.json (prover costs); if that file
      is missing, unreadable, empty, or not a JSON list, every
      round0/prover/full_records/*.json file is used instead.

    Args:
        run_path: Path to run directory (e.g., /path/to/results/combined/2025/11/14/134243)

    Returns:
        Dict with structure:
        {
            "total_cost": float,
            "total_input_tokens": int,
            "total_output_tokens": int,
            "components": {
                "breakdown_parser": {"cost": float, "input_tokens": int, "output_tokens": int},
                "prover": {"cost": float, "input_tokens": int, "output_tokens": int},
            }
        }
    """
    components = {
        "breakdown_parser": {"cost": 0.0, "input_tokens": 0, "output_tokens": 0},
        "prover": {"cost": 0.0, "input_tokens": 0, "output_tokens": 0},
    }

    round0_dir = os.path.join(str(run_path), "round0")

    bp_file = os.path.join(round0_dir, "breakdown_parser", "parsed_breakdown.json")
    _add_records_to_component(components["breakdown_parser"], _load_json_file(bp_file))

    prover_dir = os.path.join(round0_dir, "prover")
    prover_data = _load_json_file(os.path.join(prover_dir, "consolidated_full_records.json"))
    if isinstance(prover_data, list) and prover_data:
        _add_records_to_component(components["prover"], prover_data)
    else:
        # Fallback: per-problem full_records when no usable consolidated file exists.
        full_records_dir = os.path.join(prover_dir, "full_records")
        if os.path.isdir(full_records_dir):
            # Sorted so float totals are summed in a reproducible order
            # (os.listdir order is arbitrary).
            for file_name in sorted(os.listdir(full_records_dir)):
                if file_name.endswith(".json"):
                    problem_data = _load_json_file(os.path.join(full_records_dir, file_name))
                    _add_records_to_component(components["prover"], problem_data)

    return {
        "total_cost": sum(component["cost"] for component in components.values()),
        "total_input_tokens": sum(component["input_tokens"] for component in components.values()),
        "total_output_tokens": sum(component["output_tokens"] for component in components.values()),
        "components": components,
    }


def get_problem_costs(run_path: str, origin_problem_id: str) -> Dict[str, Any]:
    """
    Calculate costs and tokens for a specific problem across all breakdowns and attempts.

    Only round-0 outputs are read:
    - round0/breakdown_parser/parsed_breakdown.json, restricted to records whose
      metadata.origin_problem_id matches (breakdown and parser phases);
    - round0/prover/full_records/{origin_problem_id}.json (prover phase). When
      that file is missing or not a JSON list, the matching records of
      round0/prover/consolidated_full_records.json are used instead.

    Cost dicts that are null or not dicts contribute zero, consistent with
    extract_detailed_cost.

    Args:
        run_path: Path to run directory
        origin_problem_id: Problem ID to aggregate (e.g., "mathd_algebra_209")

    Returns:
        Dict with structure:
        {
            "problem_id": str,
            "total_cost": float,
            "total_input_tokens": int,
            "total_output_tokens": int,
            "breakdowns": {
                breakdown_id: {
                    "cost": float, "input_tokens": int, "output_tokens": int,
                    "phases": {"breakdown": {...}, "parser": {...}, "prover": {...}}
                }
            }
        }
        breakdown_id is taken from each record's metadata and may be None.
    """
    result = {
        "problem_id": origin_problem_id,
        "total_cost": 0.0,
        "total_input_tokens": 0,
        "total_output_tokens": 0,
        "breakdowns": {}
    }
    breakdowns = result["breakdowns"]
    run_path = str(run_path)

    def _zero_bucket() -> Dict[str, Any]:
        # Fresh cost/token accumulator.
        return {"cost": 0.0, "input_tokens": 0, "output_tokens": 0}

    def _entry_for(breakdown_id: Any) -> Dict[str, Any]:
        # Create the per-breakdown accumulator the first time a breakdown is seen.
        if breakdown_id not in breakdowns:
            breakdowns[breakdown_id] = {
                **_zero_bucket(),
                "phases": {
                    "breakdown": _zero_bucket(),
                    "parser": _zero_bucket(),
                    "prover": _zero_bucket(),
                },
            }
        return breakdowns[breakdown_id]

    def _add(bucket: Dict[str, Any], cost_dict: Dict[str, Any]) -> None:
        # Accumulate one cost dict into a bucket; missing fields count as zero.
        bucket["cost"] += cost_dict.get("cost", 0.0)
        bucket["input_tokens"] += cost_dict.get("input_tokens", 0)
        bucket["output_tokens"] += cost_dict.get("output_tokens", 0)

    def _metadata(record: Dict[str, Any]) -> Dict[str, Any]:
        # Treat an absent or null "metadata" field as empty.
        return record.get("metadata") or {}

    # Breakdown and parser phases.
    bp_file = os.path.join(run_path, "round0", "breakdown_parser", "parsed_breakdown.json")
    bp_data = _load_json_file(bp_file)
    if isinstance(bp_data, list):
        for record in bp_data:
            metadata = _metadata(record)
            if metadata.get("origin_problem_id") != origin_problem_id:
                continue
            entry = _entry_for(metadata.get("breakdown_id"))
            # Same dict guard as extract_detailed_cost, so phases always sum to the total
            # and a null cost no longer raises AttributeError.
            if isinstance(record.get("detailed_cost"), dict):
                _add(entry["phases"]["breakdown"], record["detailed_cost"])
            if isinstance(record.get("parser_detailed_cost"), dict):
                _add(entry["phases"]["parser"], record["parser_detailed_cost"])
            _add(entry, extract_detailed_cost(record))

    # Prover phase.
    prover_dir = os.path.join(run_path, "round0", "prover")
    prover_data = _load_json_file(os.path.join(prover_dir, "full_records", f"{origin_problem_id}.json"))
    if not isinstance(prover_data, list):
        # Per-problem file unavailable: use this problem's records from the
        # run-wide consolidated file (the primary prover source in get_run_costs).
        consolidated = _load_json_file(os.path.join(prover_dir, "consolidated_full_records.json"))
        prover_data = []
        if isinstance(consolidated, list):
            prover_data = [
                record for record in consolidated
                if _metadata(record).get("origin_problem_id") == origin_problem_id
            ]
    for record in prover_data:
        entry = _entry_for(_metadata(record).get("breakdown_id"))
        cost_info = extract_detailed_cost(record)
        _add(entry["phases"]["prover"], cost_info)
        _add(entry, cost_info)

    # Problem totals are the sum of the per-breakdown totals.
    for entry in breakdowns.values():
        result["total_cost"] += entry["cost"]
        result["total_input_tokens"] += entry["input_tokens"]
        result["total_output_tokens"] += entry["output_tokens"]

    return result


def get_breakdown_costs(run_path: str, origin_problem_id: str, breakdown_id: int) -> Dict[str, Any]:
    """
    Calculate costs and tokens for a specific breakdown.

    Only round-0 outputs are read, restricted to records whose metadata has a
    matching origin_problem_id and breakdown_id:
    - round0/breakdown_parser/parsed_breakdown.json (breakdown and parser phases);
    - round0/prover/full_records/{origin_problem_id}.json (prover phase). When
      that file is missing or not a JSON list,
      round0/prover/consolidated_full_records.json is used instead.

    Cost dicts that are null or not dicts contribute zero, consistent with
    extract_detailed_cost.

    Args:
        run_path: Path to run directory
        origin_problem_id: Problem ID
        breakdown_id: Breakdown ID

    Returns:
        Dict with structure:
        {
            "breakdown_key": str,   # "{origin_problem_id}_r0_b{breakdown_id}"
            "total_cost": float,
            "total_input_tokens": int,
            "total_output_tokens": int,
            "phases": {
                "breakdown": {...},
                "parser": {...},
                "prover": {...}
            }
        }
    """
    result = {
        "breakdown_key": f"{origin_problem_id}_r0_b{breakdown_id}",
        "total_cost": 0.0,
        "total_input_tokens": 0,
        "total_output_tokens": 0,
        "phases": {
            "breakdown": {"cost": 0.0, "input_tokens": 0, "output_tokens": 0},
            "parser": {"cost": 0.0, "input_tokens": 0, "output_tokens": 0},
            "prover": {"cost": 0.0, "input_tokens": 0, "output_tokens": 0}
        }
    }
    phases = result["phases"]
    run_path = str(run_path)

    def _matches(record: Dict[str, Any]) -> bool:
        # True when the record belongs to the requested problem and breakdown;
        # an absent or null "metadata" field never matches.
        metadata = record.get("metadata") or {}
        return (metadata.get("origin_problem_id") == origin_problem_id
                and metadata.get("breakdown_id") == breakdown_id)

    def _add(bucket: Dict[str, Any], cost_dict: Dict[str, Any]) -> None:
        # Accumulate one cost dict into a phase bucket; missing fields count as zero.
        bucket["cost"] += cost_dict.get("cost", 0.0)
        bucket["input_tokens"] += cost_dict.get("input_tokens", 0)
        bucket["output_tokens"] += cost_dict.get("output_tokens", 0)

    # Breakdown and parser phases.
    bp_file = os.path.join(run_path, "round0", "breakdown_parser", "parsed_breakdown.json")
    bp_data = _load_json_file(bp_file)
    if isinstance(bp_data, list):
        for record in filter(_matches, bp_data):
            # Dict guard: a null/non-dict cost previously raised AttributeError.
            if isinstance(record.get("detailed_cost"), dict):
                _add(phases["breakdown"], record["detailed_cost"])
            if isinstance(record.get("parser_detailed_cost"), dict):
                _add(phases["parser"], record["parser_detailed_cost"])

    # Prover phase: per-problem records, else the run-wide consolidated file
    # (records are filtered by _matches either way).
    prover_dir = os.path.join(run_path, "round0", "prover")
    prover_data = _load_json_file(os.path.join(prover_dir, "full_records", f"{origin_problem_id}.json"))
    if not isinstance(prover_data, list):
        prover_data = _load_json_file(os.path.join(prover_dir, "consolidated_full_records.json"))
    if isinstance(prover_data, list):
        for record in filter(_matches, prover_data):
            _add(phases["prover"], extract_detailed_cost(record))

    # Totals are the sum of the three phases.
    for phase_data in phases.values():
        result["total_cost"] += phase_data["cost"]
        result["total_input_tokens"] += phase_data["input_tokens"]
        result["total_output_tokens"] += phase_data["output_tokens"]

    return result
