from __future__ import annotations
import argparse
import re
import json
import os
import signal
import glob

import networkx as nx
import subprocess
import sys
from tqdm import tqdm
from datetime import datetime
from functools import wraps
from typing import Any, Callable, Dict, List, TypeVar, Union, Optional

from scripts.utils.assorted import load_benchmark_names
from scripts.utils.decode_graph import decode_graph_from_text
from scripts.generate_prompts import SIZE_PATTERNS
from scripts.utils.metadata import (
    collect_response_metadata,
    extract_metadata_from_filename,
)


# Generic result type of whatever function the timeout decorator wraps
T = TypeVar("T")


def timeout(
    seconds: int,
) -> Callable[[Callable[..., T]], Callable[..., Union[T, None]]]:
    """
    Build a decorator that bounds a function's run time with SIGALRM.

    Parameters:
    - seconds: Number of seconds passed to signal.alarm before the call starts

    Returns:
    - A decorator. The decorated function returns the wrapped function's result,
      or None (after printing a warning) when a TimeoutError is raised while the
      wrapped function runs. Any other exception propagates unchanged.
    """

    def decorator(func: Callable[..., T]) -> Callable[..., Union[T, None]]:
        def raise_on_alarm(signum, frame):
            # Signal handler: turns the alarm into an exception inside the call
            raise TimeoutError(
                f"Function '{func.__name__}' timed out after {seconds} seconds"
            )

        @wraps(func)
        def wrapper(*args, **kwargs) -> Union[T, None]:
            # Install our handler, remembering whatever was registered before
            saved_handler = signal.signal(signal.SIGALRM, raise_on_alarm)
            signal.alarm(seconds)

            try:
                return func(*args, **kwargs)
            except TimeoutError as exc:
                print(f"⚠️ {str(exc)}")
                return None
            finally:
                # Always disarm the alarm and put the earlier handler back
                signal.alarm(0)
                signal.signal(signal.SIGALRM, saved_handler)

        return wrapper

    return decorator


def _comparable_edges(graph: nx.Graph) -> set:
    """
    Return the graph's edges as a set suitable for equality checks.

    For undirected graphs the two endpoints are stored as a frozenset, because
    the edge view may report the same edge as (u, v) in one graph and (v, u) in
    another depending on the order in which nodes were added. Directed graphs
    keep their ordered tuples. Any extra tuple items (e.g. multigraph keys) are
    preserved after the endpoints.
    """
    if graph.is_directed():
        return set(graph.edges)
    return {(frozenset(edge[:2]), *edge[2:]) for edge in graph.edges}


@timeout(30)  # Set 30-second timeout for graph comparison
def compare_graphs(g1: nx.Graph, g2: nx.Graph, mode: str = "label_consistent") -> bool:
    """
    Compares two graphs using either isomorphism or label consistency.

    Parameters:
    - g1, g2: NetworkX graphs to compare
    - mode: Comparison mode ("isomorphic" or "label_consistent")
        * "isomorphic": structure must match up to relabelling, with matching
          node "color" attributes
        * "label_consistent": identical node labels, identical edges (endpoint
          order is ignored for undirected graphs) and identical node colors

    Returns:
    - bool or None: True if graphs match according to mode, False if not (or if
      NetworkX raised an error), None if the timeout decorator cut the call short

    Raises:
    - ValueError: If mode is not one of the supported values
    """
    try:
        if mode == "isomorphic":
            return nx.is_isomorphic(
                g1, g2, node_match=lambda n1, n2: n1.get("color") == n2.get("color")
            )
        elif mode == "label_consistent":
            # Check nodes
            if set(g1.nodes) != set(g2.nodes):
                return False

            # Check edges (orientation-insensitive for undirected graphs)
            if _comparable_edges(g1) != _comparable_edges(g2):
                return False

            # Check node attributes (colors)
            for node in g1.nodes:
                if g1.nodes[node].get("color") != g2.nodes[node].get("color"):
                    return False

            return True
        else:
            raise ValueError(f"Unknown comparison mode: {mode}")
    except nx.NetworkXException as e:
        print(f"Error in graph comparison: {e}")
        return False


def extract_answer_from_response(response_text: str) -> str:
    """
    Pull the model's answer out of a response that uses <answer></answer> tags.

    The first tag pair wins; matching is case-insensitive, tolerates whitespace
    inside the tag delimiters and spans multiple lines. When the response has no
    answer tags at all, the stripped response text itself is returned.

    Parameters:
    - response_text: Raw response text

    Returns:
    - Extracted answer string (stripped of surrounding whitespace)

    Raises:
    - ValueError: If the tags are empty, unbalanced or malformed, if the response
      is blank, or if any other error occurs during extraction. The message is
      always prefixed with "Failed to extract answer: ".
    """
    try:
        tag_match = re.search(
            r"<\s*answer\s*>(.*?)<\s*/\s*answer\s*>",
            response_text,
            re.DOTALL | re.IGNORECASE,
        )

        if tag_match is not None:
            inner = tag_match.group(1).strip()
            if inner:
                return inner
            raise ValueError(
                "Found <answer></answer> tags but they contain no content."
            )

        # No complete tag pair: look for stray tags to give a precise diagnosis
        lowered = response_text.lower()
        has_opening = "<answer" in lowered
        has_closing = "</answer>" in lowered

        if has_opening and not has_closing:
            raise ValueError(
                "Found opening <answer> tag but missing closing </answer> tag."
            )
        if has_opening:
            raise ValueError(
                "Found <answer> tags but could not extract content. Check for malformed XML."
            )
        if has_closing:
            raise ValueError(
                "Found closing </answer> tag but missing opening <answer> tag."
            )

        # Compatibility fallback: an untagged response is treated as the answer
        bare_text = response_text.strip()
        if not bare_text:
            raise ValueError(
                "No <answer></answer> XML tags found and no content to extract. "
                "Expected format: <answer>content</answer>"
            )
        return bare_text

    except Exception as e:
        raise ValueError(f"Failed to extract answer: {e}") from e


# Question types whose answer is a single non-negative integer
_NUMERIC_QUESTION_TYPES = frozenset(
    {
        "node_count",
        "edge_count",
        "blue_node_count",
        "colored_node_count",
        "max_degree",
        "min_degree",
        "component_count",
    }
)

# Question types whose answer is yes/no
_YES_NO_QUESTION_TYPES = frozenset({"is_connected", "is_tree", "has_cycles"})

# Whole words that count as an affirmative answer
_AFFIRMATIVE_TOKENS = frozenset({"yes", "true", "1", "correct"})


def _is_affirmative(text: str) -> bool:
    """Return True if the lower-cased text contains an affirmative word as a whole token."""
    # Token matching (not substring matching) keeps "incorrect", "10" or "eyes"
    # from being mistaken for "correct", "1" or "yes".
    return any(
        token in _AFFIRMATIVE_TOKENS for token in re.findall(r"[a-z0-9]+", text)
    )


def compare_answers(expected: str, actual: str, question_type: str) -> bool:
    """
    Compare expected and actual answers based on question type.

    Both answers are stripped and lower-cased first. Then:
    - numeric question types compare the first integer in the expected answer
      with the last integer in the actual answer (the model's final number); if
      either side contains no digits the exact-string comparison is used
    - yes/no question types compare whether each side contains an affirmative
      word ("yes", "true", "1", "correct") as a whole token; anything without
      such a word is treated as "no"
    - "full_output" always yields False because graphs are compared elsewhere
    - any other type falls back to exact string equality

    Parameters:
    - expected: Expected answer
    - actual: Actual model response
    - question_type: Type of question

    Returns:
    - True if answers match, False otherwise
    """
    # Clean up both answers
    expected_clean = str(expected).strip().lower()
    actual_clean = str(actual).strip().lower()

    if question_type in _NUMERIC_QUESTION_TYPES:
        try:
            expected_nums = re.findall(r"\d+", expected_clean)
            actual_nums = re.findall(r"\d+", actual_clean)
            if expected_nums and actual_nums:
                # Expected first number vs. the model's final (last) number
                return int(expected_nums[0]) == int(actual_nums[-1])
        except (ValueError, IndexError):
            # Fall through to the exact-string comparison below
            pass

    elif question_type in _YES_NO_QUESTION_TYPES:
        return _is_affirmative(expected_clean) == _is_affirmative(actual_clean)

    elif question_type == "full_output":
        # Full graph outputs are judged by graph comparison, never by text
        return False

    # Default: exact string match
    return expected_clean == actual_clean


def get_ground_truth_answer(
    ground_truth_path: str, question_type: str, target: str
) -> str:
    """
    Derive the reference answer for a question from the benchmark graph files.

    Parameters:
    - ground_truth_path: Path to ground truth file
    - question_type: Type of question
    - target: 'input' to analyse the input graph (found by swapping "/output/"
      for "/input/" in the path); any other value analyses the ground truth
      graph itself

    Returns:
    - For "full_output": the ground truth path unchanged (the graph comparison
      happens elsewhere). Otherwise the computed answer as a string.

    Raises:
    - ValueError: If the graph cannot be decoded, the computation fails, or the
      question type is unknown
    """
    if question_type == "full_output":
        return ground_truth_path

    try:
        # Pick which graph file the question is about, then decode it
        if target == "input":
            source_path = ground_truth_path.replace("/output/", "/input/")
        else:
            source_path = ground_truth_path
        graph = decode_graph_from_text(source_path, response=False)

        # One lazily evaluated solver per supported question type
        solvers = {
            "node_count": lambda: str(len(graph.nodes)),
            "edge_count": lambda: str(len(graph.edges)),
            "blue_node_count": lambda: str(
                sum(
                    1
                    for _, attrs in graph.nodes(data=True)
                    if attrs.get("color") == "blue"
                )
            ),
            "colored_node_count": lambda: str(
                sum(
                    1
                    for _, attrs in graph.nodes(data=True)
                    if attrs.get("color", "grey") != "grey"
                )
            ),
            "is_connected": lambda: "yes" if nx.is_connected(graph) else "no",
            "is_tree": lambda: "yes" if nx.is_tree(graph) else "no",
            "has_cycles": lambda: "no" if nx.is_forest(graph) else "yes",
            "max_degree": lambda: (
                str(max(dict(graph.degree()).values())) if graph.nodes else "0"
            ),
            "min_degree": lambda: (
                str(min(dict(graph.degree()).values())) if graph.nodes else "0"
            ),
            "component_count": lambda: str(nx.number_connected_components(graph)),
        }

        solve = solvers.get(question_type)
        if solve is None:
            raise ValueError(f"Unknown question type: {question_type}")
        return solve()

    except Exception as e:
        raise ValueError(
            f"Failed to compute ground truth for {question_type}: {e}"
        ) from e


def find_consolidated_token_files() -> Dict[str, List[str]]:
    """
    Locate consolidated token files under llm-inference/results (relative to
    the current working directory) and group them by model.

    The model key is the file name without its "_tokens.json" suffix, cut at the
    first "_results" if present, with underscores turned into hyphens and
    lower-cased.

    Returns:
    - Dict mapping model names to lists of their consolidated token file paths
    """
    suffix = "_tokens.json"
    files_by_model: Dict[str, List[str]] = {}

    for path in glob.glob("llm-inference/results/*_tokens.json"):
        name = os.path.basename(path)
        if not name.endswith(suffix):
            continue

        stem = name[: -len(suffix)]
        # partition() yields the whole stem when "_results" does not occur
        model_key = stem.partition("_results")[0].replace("_", "-").lower()

        if model_key not in files_by_model:
            files_by_model[model_key] = []
        files_by_model[model_key].append(path)

    return files_by_model


def load_consolidated_token_data(file_path: str) -> Optional[Dict[str, Any]]:
    """
    Read one consolidated token file as JSON.

    Returns the parsed content, or None (after printing a warning) when the
    file cannot be opened or does not contain valid JSON.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
    except (json.JSONDecodeError, IOError, OSError) as e:
        print(f"⚠️ Could not load consolidated token data from {file_path}: {e}")
        return None
    return payload


def get_response_id_from_path(response_path: str) -> Optional[str]:
    """
    Build the identifier used to look a response up in consolidated token data.

    The identifier combines the two path components that follow a "datasets"
    directory (benchmark and graph type) with fields parsed from the file name.

    Returns:
    - The identifier string, or None when the file name yields no metadata, the
      path has no usable "datasets/<benchmark>/<graph_type>" part, or a
      ValueError/IndexError/KeyError/OSError occurs (a warning is printed then)
    """
    try:
        metadata = extract_metadata_from_filename(os.path.basename(response_path))
        if not metadata:
            return None

        # Benchmark and graph type are the two segments right after "datasets"
        segments = response_path.split(os.sep)
        if len(segments) < 3 or "datasets" not in segments:
            return None
        anchor = segments.index("datasets")
        if len(segments) <= anchor + 2:
            return None
        benchmark, graph_type = segments[anchor + 1 : anchor + 3]

        # NOTE(review): this reads metadata['num_pairs'] while other code in this
        # file reads metadata.get("n_pairs") — confirm which key the filename
        # parser actually produces; a missing key lands in the KeyError handler.
        return (
            f"{benchmark}_{graph_type}_"
            f"{metadata['encoding']}-{metadata['size_pattern']}-"
            f"{metadata['system_prompt']}-{metadata['num_pairs']}-"
            f"{metadata['question_type']}-{metadata['target']}"
        )

    except (ValueError, IndexError, KeyError, OSError) as e:
        print(f"⚠️ Could not extract response ID from {response_path}: {e}")
        return None


def build_consolidated_token_cache() -> Dict[str, Dict]:
    """
    Load every consolidated token file and index the data by model name.

    All files of one model are merged into a single "per_response_tokens"
    mapping; files that fail to load or lack that key are skipped. For each
    model with at least one token entry the cache holds two keys:
    - "<model>": {"per_response_tokens": merged entries}
    - "<model>_file_path": list of the files that contributed
    """
    token_cache: Dict[str, Dict] = {}

    for model_name, file_paths in find_consolidated_token_files().items():
        combined_tokens: Dict[str, Any] = {}
        contributing_paths: List[str] = []

        for path in file_paths:
            data = load_consolidated_token_data(path)
            if data and "per_response_tokens" in data:
                # Entries from later files override same-id entries from earlier ones
                combined_tokens.update(data["per_response_tokens"])
                contributing_paths.append(path)

        if not combined_tokens:
            continue

        token_cache[model_name] = {"per_response_tokens": combined_tokens}
        token_cache[f"{model_name}_file_path"] = contributing_paths

    return token_cache


def load_token_data(
    response_path: str, consolidated_token_cache: Optional[Dict[str, Dict]] = None
) -> Optional[Dict[str, Any]]:
    """
    Load token usage for a response, preferring consolidated data over the
    per-response token file.

    Lookup order:
    1. The consolidated cache, keyed by the model parsed from the file name
       (tried verbatim, then lower-cased with "_" replaced by "-", which is how
       the consolidated cache builder in this module normalises model names).
    2. A sibling file named like the response with its extension replaced by
       "_tokens.json" (e.g. "resp.txt" -> "resp_tokens.json").

    Parameters:
    - response_path: Path to the response file
    - consolidated_token_cache: Optional cache of consolidated token data by model

    Returns:
    - A dict with the token usage data plus a "source" key ("consolidated" or
      "individual"), or None if no usable data is available
    """

    # Method 1: Try consolidated token data (preferred)
    if consolidated_token_cache:
        metadata = extract_metadata_from_filename(os.path.basename(response_path))
        model_name = metadata.get("model") if metadata else None

        if model_name:
            consolidated_data = consolidated_token_cache.get(model_name)
            if consolidated_data is None and isinstance(model_name, str):
                # Fall back to the normalised spelling used for cache keys
                consolidated_data = consolidated_token_cache.get(
                    model_name.replace("_", "-").lower()
                )

            if consolidated_data is not None:
                response_id = get_response_id_from_path(response_path)

                if (
                    response_id
                    and isinstance(consolidated_data, dict)
                    and "per_response_tokens" in consolidated_data
                ):
                    token_data = consolidated_data["per_response_tokens"].get(
                        response_id
                    )
                    if token_data:
                        # Copy so the cached entry is not mutated
                        token_data = token_data.copy()
                        token_data["source"] = "consolidated"
                        return token_data

    # Method 2: Try individual token file (fallback)
    # Only the file extension is swapped; ".txt" elsewhere in the path is left
    # alone, and a path without ".txt" never resolves to the response itself.
    path_root, _extension = os.path.splitext(response_path)
    individual_token_path = f"{path_root}_tokens.json"

    if os.path.exists(individual_token_path):
        try:
            with open(individual_token_path, "r", encoding="utf-8") as f:
                token_data = json.load(f)
        except (json.JSONDecodeError, IOError, OSError) as e:
            print(
                f"⚠️ Could not load individual token data from {individual_token_path}: {e}"
            )
        else:
            if isinstance(token_data, dict):
                token_data["source"] = "individual"
                return token_data
            print(
                f"⚠️ Unexpected token data format in {individual_token_path}: expected a JSON object"
            )

    return None


def _response_prompt_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Select the prompt-related filename fields that are stored with each response."""
    return {
        "size_pattern": metadata.get("size_pattern"),
        "system_prompt": metadata.get("system_prompt"),
        "n_pairs": metadata.get("n_pairs"),
        "question_type": metadata.get("question_type"),
        "target": metadata.get("target"),
    }


def _evaluation_record(
    benchmark: str,
    graph_type: str,
    size_category: str,
    response_path: str,
    ground_truth_path: str,
    metadata: Dict[str, Any],
    comparison_mode: str,
    is_correct: bool,
    details: Dict[str, Any],
    token_usage: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
    """Assemble the result dictionary shared by the success and failure paths."""
    return {
        "benchmark": benchmark,
        "graph_type": graph_type,
        "size": size_category,
        "encoding": metadata.get("encoding"),
        "size_pattern": metadata.get("size_pattern"),
        "system_prompt": metadata.get("system_prompt"),
        "n_pairs": metadata.get("n_pairs"),
        "model": metadata.get("model"),
        "question_type": metadata.get("question_type"),
        "target": metadata.get("target"),
        "comparison_mode": comparison_mode,
        "correct": is_correct,
        "details": details,
        "token_usage": token_usage,
        "timestamp": datetime.now().isoformat(),
        "response_path": response_path,
        "ground_truth_path": ground_truth_path,
    }


def evaluate_single_response(
    benchmark: str,
    graph_type: str,
    size_category: str,
    response_path: str,
    ground_truth_path: str,
    metadata: Dict[str, Any],
    failure_tracker: Dict[str, List],
    comparison_mode: str = "label_consistent",
    consolidated_token_cache: Optional[Dict[str, Dict]] = None,
) -> Dict[str, Any]:
    """
    Evaluates a single AI-generated response by comparing it to ground truth.

    For "full_output" questions the response and ground truth are decoded into
    graphs and compared (node/edge counts first, then compare_graphs). For all
    other question types the answer is extracted from the response text and
    compared with an answer computed from the benchmark graph. In both cases the
    response metadata (including token usage when available) is saved through
    collect_response_metadata.

    Parameters:
    - benchmark (str): Name of the benchmark task.
    - graph_type (str): Type of graph used.
    - size_category (str): Size category ('small', 'medium', 'large').
    - response_path (str): File path to the model's response.
    - ground_truth_path (str): File path to the expected correct output.
    - metadata (dict): Additional metadata from filename.
    - failure_tracker (dict): Dictionary of lists, updated in place, that tracks
      the different types of failures
    - comparison_mode (str): Mode for graph comparison ('isomorphic' or 'label_consistent')
    - consolidated_token_cache (dict): Optional cache of consolidated token data,
      passed through to load_token_data

    Returns:
    - dict: A dictionary containing evaluation results and metadata. When a
      FileNotFoundError, ValueError or NetworkX error occurs, "correct" is False
      and "details" holds the error message instead of graph metadata.
    """
    question_type = metadata.get("question_type", "full_output")
    target = metadata.get("target", "output")
    benchmark_path = f"datasets/{benchmark}/{graph_type}"

    # Set once a failure has been recorded under "parse_errors" and re-raised, so
    # the outer handler does not record the same failure again as "other_errors".
    parse_error_tracked = False

    try:
        # Check if both files exist
        if not os.path.exists(response_path):
            failure_tracker["missing_responses"].append(response_path)
            raise FileNotFoundError(f"Response file not found: {response_path}")
        if not os.path.exists(ground_truth_path):
            failure_tracker["missing_ground_truth"].append(ground_truth_path)
            raise FileNotFoundError(f"Ground truth file not found: {ground_truth_path}")

        # Load response
        try:
            with open(response_path, "r", encoding="utf-8") as f:
                response_content = f.read()
        except Exception as e:
            raise ValueError(f"Failed to read response file: {e}") from e

        # Load token usage data with consolidated cache
        token_usage = load_token_data(response_path, consolidated_token_cache)

        parse_error = None
        is_correct = False
        structural_match = True  # Default for question-based responses

        response_graph_metadata = {}
        ground_truth_graph_metadata = {}

        if question_type == "full_output":
            # Handle full graph output
            try:
                response_graph = decode_graph_from_text(
                    response_path, encoding_type=metadata.get("encoding"), response=True
                )
                response_graph_metadata = generate_graph_metadata(response_graph)
            except ValueError as e:
                parse_error = str(e)
                failure_tracker["parse_errors"].append(
                    {"file": response_path, "error": parse_error}
                )
                parse_error_tracked = True
                raise

            # Load ground truth graph
            ground_truth_graph = decode_graph_from_text(
                ground_truth_path,
                encoding_type=metadata.get("encoding"),
                response=False,
            )
            ground_truth_graph_metadata = generate_graph_metadata(ground_truth_graph)

            # Cheap structural check first: differing sizes can never match
            same_size = len(response_graph.nodes) == len(
                ground_truth_graph.nodes
            ) and len(response_graph.edges) == len(ground_truth_graph.edges)

            if not same_size:
                is_correct = False
                structural_match = False
            else:
                structural_match = True
                is_correct_result = compare_graphs(
                    response_graph, ground_truth_graph, mode=comparison_mode
                )
                if is_correct_result is None:
                    # None signals that the comparison hit its timeout
                    failure_tracker["comparison_timeouts"].append(
                        {
                            "benchmark": benchmark,
                            "graph_type": graph_type,
                            "size": size_category,
                            "response_path": response_path,
                            "mode": comparison_mode,
                        }
                    )
                    is_correct = False
                    tqdm.write(
                        f"⚠️ Graph comparison ({comparison_mode} mode) timed out for {response_path}"
                    )
                else:
                    is_correct = is_correct_result
        else:
            # Handle question-based response
            try:
                # Extract answer from response using XML format
                actual_answer = extract_answer_from_response(response_content)

                # Get ground truth answer
                expected_answer = get_ground_truth_answer(
                    ground_truth_path, question_type, target
                )

                # Compare answers
                is_correct = compare_answers(
                    expected_answer, actual_answer, question_type
                )

                # For question-based responses, we can't generate full graph metadata
                response_graph_metadata = {"answer": actual_answer}
                ground_truth_graph_metadata = {"answer": expected_answer}

            except (ValueError, IndexError) as e:
                # Recorded as a parse error; evaluation continues as "incorrect"
                parse_error = str(e)
                failure_tracker["parse_errors"].append(
                    {"file": response_path, "error": parse_error}
                )
                is_correct = False

        # Collect and save response metadata with token usage
        collect_response_metadata(
            benchmark_path=benchmark_path,
            response_path=response_path,
            encoding=metadata.get("encoding"),
            ground_truth_path=ground_truth_path,
            is_correct=is_correct,
            prompt_metadata=_response_prompt_metadata(metadata),
            comparison_mode=comparison_mode,
            parse_error=parse_error,
            token_usage=token_usage,
        )

        return _evaluation_record(
            benchmark,
            graph_type,
            size_category,
            response_path,
            ground_truth_path,
            metadata,
            comparison_mode,
            is_correct=is_correct,
            details={
                "structural_match": structural_match,
                "response_metadata": response_graph_metadata,
                "ground_truth_metadata": ground_truth_graph_metadata,
            },
            token_usage=token_usage,
        )

    except (FileNotFoundError, ValueError, nx.NetworkXException) as e:
        error_message = str(e)
        tqdm.write(f"❌ Failed evaluating {response_path}: {error_message}")

        # Track non-specific errors only: skip failures that already have their
        # own category (missing files, parse errors).
        already_categorised = parse_error_tracked or any(
            error_type in error_message.lower()
            for error_type in ["file not found", "no <answer></answer> xml tags found"]
        )
        if not already_categorised:
            failure_tracker["other_errors"].append(
                {"file": response_path, "error": error_message}
            )

        # Load token usage even for failed responses
        token_usage = load_token_data(response_path, consolidated_token_cache)

        # Collect metadata even for failed responses
        collect_response_metadata(
            benchmark_path=benchmark_path,
            response_path=response_path,
            encoding=metadata.get("encoding"),
            ground_truth_path=ground_truth_path,
            is_correct=False,
            prompt_metadata=_response_prompt_metadata(metadata),
            comparison_mode=comparison_mode,
            parse_error=error_message,
            token_usage=token_usage,
        )

        return _evaluation_record(
            benchmark,
            graph_type,
            size_category,
            response_path,
            ground_truth_path,
            metadata,
            comparison_mode,
            is_correct=False,
            details={"error": error_message},
            token_usage=token_usage,
        )


def generate_graph_metadata(graph: nx.Graph) -> Dict[str, Any]:
    """
    Summarise a graph as node count, edge count and a per-color node tally.

    Nodes without a "color" attribute are counted as "grey". Returns an empty
    dict if a KeyError or TypeError occurs while building the summary.
    """
    try:
        node_colors = [
            attrs.get("color", "grey") for _, attrs in graph.nodes(data=True)
        ]

        color_distribution = {}
        for color in node_colors:
            color_distribution.setdefault(color, 0)
            color_distribution[color] += 1

        summary = {
            "node_count": len(graph.nodes),
            "edge_count": len(graph.edges),
            "color_distribution": color_distribution,
        }
    except (KeyError, TypeError):
        return {}
    return summary


def categorize_parse_errors(parse_errors: List[Dict]) -> Dict[str, int]:
    """
    Categorizes parse errors by error type and counts occurrences.

    Known messages are mapped to a short category label; any other message is
    grouped by its first 50 characters (with "..." appended when truncated).

    Parameters:
    - parse_errors: List of dictionaries containing error information; each
      must have an "error" key holding the message

    Returns:
    - Dict mapping error types to their counts
    """
    # (substring to look for, category label) — checked in order, case-sensitive
    known_categories = (
        ("invalid literal for int()", "invalid literal for int() with base 10"),
        ("No <answer></answer> XML tags found", "missing <answer> XML tags"),
        ("Found opening <answer> tag but missing closing", "unclosed <answer> tag"),
        (
            "Found closing </answer> tag but missing opening",
            "missing opening <answer> tag",
        ),
    )

    error_types: Dict[str, int] = {}

    for error in parse_errors:
        error_msg = error["error"]

        error_type = next(
            (label for needle, label in known_categories if needle in error_msg),
            None,
        )
        if error_type is None:
            # The needle must be lower-case too, otherwise it can never match
            # the lower-cased message.
            if "malformed xml" in error_msg.lower():
                error_type = "malformed XML tags"
            else:
                # Use the first 50 chars as the error type
                error_type = error_msg[:50] + ("..." if len(error_msg) > 50 else "")

        error_types[error_type] = error_types.get(error_type, 0) + 1

    return error_types


def analyze_failures_by_model(failure_list: List, key: str = "file") -> Dict[str, int]:
    """
    Analyzes which models had failures.

    The model is taken from the metadata parsed out of each ".txt" file name.
    When the file name cannot be parsed (the parser raises or returns nothing),
    the last "_"-separated part of the name is used instead, provided the name
    has at least three parts. Entries whose file name does not end in ".txt"
    are ignored.

    Parameters:
    - failure_list: List of failure items (could be dicts or strings)
    - key: The key to use if items are dictionaries

    Returns:
    - Dict mapping model names to error counts
    """
    model_counts: Dict[str, int] = {}

    for item in failure_list:
        # Handle both string paths and dictionaries
        file_path = item[key] if isinstance(item, dict) else item

        try:
            filename = os.path.basename(file_path)
            if not filename.endswith(".txt"):
                continue

            # Extract model using our metadata extraction function
            try:
                metadata = extract_metadata_from_filename(filename)
            except (ValueError, KeyError, IndexError):
                metadata = None

            if metadata is not None:
                model_name = metadata.get("model", "unknown")
            else:
                # Unparseable name (exception or no metadata): fall back to the
                # trailing "_"-separated part of the file name
                parts = filename.split("_")
                if len(parts) < 3:
                    continue
                model_name = parts[-1].replace(".txt", "")

            model_counts[model_name] = model_counts.get(model_name, 0) + 1
        except (ValueError, KeyError, IndexError, OSError):
            continue

    return model_counts


def generate_failure_summary(
    failure_tracker: Dict[str, List], show_all_errors: bool = False
) -> str:
    """
    Render the collected failures as a human-readable, multi-section report.

    Sections appear in a fixed order: parse errors, comparison timeouts,
    missing responses, missing ground truth files and other errors. Each
    section either lists its findings or states that nothing was detected.

    Parameters:
    - failure_tracker: Dictionary containing different types of failures; must
      provide the keys "parse_errors", "comparison_timeouts",
      "missing_responses", "missing_ground_truth" and "other_errors"
    - show_all_errors: If True, list every file with a parse error; otherwise show only a few examples

    Returns:
    - Formatted string with the failure summary
    """

    def count_by(labels):
        # Occurrence count per label, in first-seen order
        tally = {}
        for label in labels:
            tally[label] = tally.get(label, 0) + 1
        return tally

    divider = "=" * 80
    report = [divider, "EVALUATION FAILURE SUMMARY", divider]

    # --- Parse errors ---
    parse_failures = failure_tracker["parse_errors"]
    if not parse_failures:
        report.append("\n✅ No parse errors detected.")
    else:
        by_type = categorize_parse_errors(parse_failures)
        by_model = analyze_failures_by_model(parse_failures)

        report.append(
            f"\n🔍 PARSE ERRORS: {len(parse_failures)} responses could not be parsed"
        )
        report.append("Error types:")
        report.extend(
            f"  - {error_type}: {count} occurrences"
            for error_type, count in by_type.items()
        )

        report.append("Models with parse errors:")
        report.extend(
            f"  - {model}: {count} occurrences" for model, count in by_model.items()
        )

        if show_all_errors:
            report.append("\nAll parse errors:")
            report.extend(
                f"  - {err['file']}: {err['error']}" for err in parse_failures
            )
        else:
            report.append("\nExample parse errors:")
            for position, err in enumerate(parse_failures[:3], 1):
                report.append(f"  {position}. {err['file']}")
                report.append(f"     Error: {err['error']}")

    # --- Comparison timeouts ---
    timed_out = failure_tracker["comparison_timeouts"]
    if not timed_out:
        report.append("\n✅ No comparison timeouts detected.")
    else:
        report.append(f"\n⏱️ COMPARISON TIMEOUTS: {len(timed_out)} checks timed out")

        per_category = count_by(
            f"{entry['benchmark']}/{entry['graph_type']}/{entry['size']}"
            for entry in timed_out
        )
        per_mode = count_by(entry.get("mode", "isomorphic") for entry in timed_out)

        report.append("Timeout distribution:")
        report.extend(
            f"  - {category}: {count} timeouts"
            for category, count in per_category.items()
        )

        report.append("\nTimeouts by comparison mode:")
        report.extend(
            f"  - {mode}: {count} timeouts" for mode, count in per_mode.items()
        )

        affected_models = analyze_failures_by_model(timed_out, key="response_path")
        if affected_models:
            report.append("\nModels with comparison timeouts:")
            report.extend(
                f"  - {model}: {count} timeouts"
                for model, count in affected_models.items()
            )

    # --- Missing response files ---
    absent_responses = failure_tracker["missing_responses"]
    if not absent_responses:
        report.append("\n✅ No missing response files detected.")
    else:
        report.append(
            f"\n📄 MISSING RESPONSES: {len(absent_responses)} response files not found"
        )
        report.append("Example missing responses:")
        report.extend(f"  - {path}" for path in absent_responses[:3])

    # --- Missing ground truth files ---
    absent_ground_truth = failure_tracker["missing_ground_truth"]
    if not absent_ground_truth:
        report.append("\n✅ No missing ground truth files detected.")
    else:
        report.append(
            f"\n📄 MISSING GROUND TRUTH: {len(absent_ground_truth)} ground truth files not found"
        )
        report.append("Example missing ground truth files:")
        report.extend(f"  - {path}" for path in absent_ground_truth[:3])

    # --- Other errors ---
    misc_errors = failure_tracker["other_errors"]
    if not misc_errors:
        report.append("\n✅ No other errors detected.")
    else:
        report.append(f"\n❓ OTHER ERRORS: {len(misc_errors)} other errors occurred")

        # Group by the first 50 characters of each message
        grouped = count_by(
            entry["error"][:50] + ("..." if len(entry["error"]) > 50 else "")
            for entry in misc_errors
        )

        report.append("Error types:")
        report.extend(
            f"  - {error_type}: {count} occurrences"
            for error_type, count in grouped.items()
        )

    report.append(divider)

    return "\n".join(report)


def _resolve_test_size(size_pattern: str) -> Optional[str]:
    """
    Derive the test size (the last size of a pattern) as a string.

    A dash-separated pattern made only of integers (e.g. "5-10-15") yields
    its last number. Anything else is looked up in SIZE_PATTERNS, whose
    entry's last element is used.

    Parameters:
    - size_pattern: Size pattern taken from the response filename metadata

    Returns:
    - The test size as a string, or None when the pattern cannot be resolved
    """
    if "-" in size_pattern:
        try:
            return str([int(part) for part in size_pattern.split("-")][-1])
        except ValueError:
            # Not purely numeric; fall through to the SIZE_PATTERNS lookup.
            pass
    if size_pattern in SIZE_PATTERNS:
        return str(SIZE_PATTERNS[size_pattern][-1])
    return None


def _load_existing_results(
    output_file: str, append: bool, overwrite: bool
) -> Optional[List[Any]]:
    """
    Decide which results the run starts from, based on the output file.

    Parameters:
    - output_file: Path of the evaluation results file
    - append: Whether existing results should be kept and extended
    - overwrite: Whether an existing file may be replaced (wins over append)

    Returns:
    - A list of previously stored results (possibly empty), or None when the
      file exists and neither append nor overwrite was requested
    """
    if not os.path.exists(output_file):
        return []
    if not append and not overwrite:
        print(f"⚠️ File {output_file} already exists. Use --append or --overwrite.")
        return None
    if overwrite:
        print(f"🗑️ Overwriting existing file: {output_file}")
        return []
    try:
        with open(output_file, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except json.JSONDecodeError:
        print(f"⚠️ Could not parse existing file {output_file}. Starting fresh.")
        return []
    if not isinstance(loaded, list):
        # Results are appended to and iterated as a list below; any other
        # JSON value is handled like an unparseable file.
        print(
            f"⚠️ Existing file {output_file} does not contain a list of results. Starting fresh."
        )
        return []
    print(f"📂 Loaded {len(loaded)} existing evaluation results")
    return loaded


def _save_evaluation_results(output_file: str, results: List[Any]) -> None:
    """
    Write results as indented JSON, replacing output_file in one step.

    The data is dumped to a sibling temporary file first so that an
    interrupted or failing dump never truncates the previous results file.
    """
    tmp_path = f"{output_file}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    os.replace(tmp_path, output_file)


def evaluate_responses(
    output_file: Optional[str] = None,
    append: bool = False,
    overwrite: bool = False,
    repetition: int = 1,
    comparison_mode: str = "isomorphic",
    skip_encodings: Optional[List[str]] = None,
    models: Optional[List[str]] = None,
    system_prompts: Optional[List[str]] = None,
    patterns: Optional[List[str]] = None,
    show_all_errors: bool = False,
    skip_parse_errors: bool = False,
) -> Optional[str]:
    """
    Evaluate every matching response file and store the results as JSON.

    Response files are discovered under datasets/<benchmark>/<graph_type>/responses,
    filtered by the given criteria, evaluated with evaluate_single_response_fast,
    and written to output_file. Responses whose path is already present in a
    loaded results file are skipped.

    Parameters:
    - output_file: Results file; auto-detected (with append) or auto-named when None
    - append: Extend an existing results file instead of refusing to run
    - overwrite: Discard an existing results file (takes precedence over append)
    - repetition: Stored under "repetition" in each task's metadata
    - comparison_mode: Graph comparison mode forwarded to the evaluator
    - skip_encodings: Encodings to exclude from "adjacency" and "incident"
    - models: If given, only responses from these models are evaluated
    - system_prompts: If given, only these system prompts are evaluated
    - patterns: If given, only these size patterns are evaluated
    - show_all_errors: Forwarded to generate_failure_summary
    - skip_parse_errors: Drop results whose details contain an "error" entry

    Returns:
    - Path of the results file, or None when the file already exists and
      neither append nor overwrite was requested
    """

    # Smart file handling: find existing file if append is requested but no output file specified
    if append and output_file is None:
        existing_files = glob.glob("evaluation_data/evaluation_results_*.json")
        if existing_files:
            # Most recent file according to os.path.getctime
            output_file = max(existing_files, key=os.path.getctime)
            print(f"📁 Auto-detected existing evaluation file: {output_file}")
        else:
            print("⚠️ No existing evaluation files found, creating new one")
            append = False

    if output_file is None:
        output_file = f"evaluation_data/evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

    # A bare filename has an empty dirname, which os.makedirs rejects.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Resolve the existing-file policy before any expensive loading so that a
    # refused run returns immediately.
    evaluation_results = _load_existing_results(output_file, append, overwrite)
    if evaluation_results is None:
        return None

    # Load consolidated token cache
    print("📊 Loading consolidated token data...")
    consolidated_token_cache = build_consolidated_token_cache()

    if consolidated_token_cache:
        models_with_tokens = [
            k for k in consolidated_token_cache.keys() if not k.endswith("_file_path")
        ]
        print(
            f"✅ Found consolidated token data for {len(models_with_tokens)} models: {', '.join(models_with_tokens)}"
        )

    # Initialize failure tracker
    failure_tracker = {
        "parse_errors": [],
        "comparison_timeouts": [],
        "missing_responses": [],
        "missing_ground_truth": [],
        "other_errors": [],
    }

    # Build set of already-evaluated response paths for deduplication
    already_evaluated = set()
    if evaluation_results:
        already_evaluated = {
            entry["response_path"]
            for entry in evaluation_results
            if isinstance(entry, dict) and "response_path" in entry
        }
        print(
            f"🔍 Found {len(already_evaluated)} already-evaluated responses that will be skipped"
        )

    benchmarks = load_benchmark_names()
    new_results_count = 0
    skipped_count = 0

    # Define valid encodings
    skipped_encodings = skip_encodings or []
    valid_encodings = [
        encoding
        for encoding in ("adjacency", "incident")
        if encoding not in skipped_encodings
    ]

    # Build task list
    tasks = []
    for benchmark in benchmarks:
        benchmark_dir = f"datasets/{benchmark}"
        if not os.path.exists(benchmark_dir):
            continue
        for graph_type in os.listdir(benchmark_dir):
            graph_type_dir = os.path.join(benchmark_dir, graph_type)
            if not os.path.isdir(graph_type_dir):
                continue
            response_dir = os.path.join(graph_type_dir, "responses")
            if not os.path.exists(response_dir):
                continue
            for filename in os.listdir(response_dir):
                if not filename.endswith(".txt"):
                    continue
                metadata = extract_metadata_from_filename(filename)
                if metadata is None:
                    continue
                if models and metadata["model"] not in models:
                    continue
                if system_prompts and metadata["system_prompt"] not in system_prompts:
                    continue
                if patterns and metadata["size_pattern"] not in patterns:
                    continue
                if metadata["encoding"] not in valid_encodings:
                    continue
                metadata["repetition"] = repetition
                tasks.append((benchmark, graph_type, response_dir, filename, metadata))

    print(f"📊 Found {len(tasks)} total response files to process")

    # Track token statistics
    token_stats = {
        "responses_with_token_data": 0,
        "from_consolidated": 0,
        "from_individual": 0,
        "no_token_data": 0,
    }

    # Process tasks with deduplication and progress tracking
    pbar = tqdm(tasks, desc="Evaluating responses", unit="response")
    for benchmark, graph_type, response_dir, filename, metadata in pbar:
        try:
            # Determine test size and paths
            size_pattern = metadata["size_pattern"]
            test_size = _resolve_test_size(size_pattern)
            if test_size is None:
                tqdm.write(
                    f"⚠️ Cannot determine test size from pattern: {size_pattern} in {filename}"
                )
                continue

            response_path = os.path.join(response_dir, filename)

            # Check if this response has already been evaluated
            if response_path in already_evaluated:
                skipped_count += 1
                pbar.set_description(
                    f"Skipping evaluated: {os.path.basename(filename)}"
                )
                continue

            # Ground truth lives next to the responses directory, under
            # textual/input or textual/output depending on the target.
            graph_type_dir = os.path.dirname(response_dir)
            target = metadata.get("target", "output")
            truth_side = "input" if target == "input" else "output"
            ground_truth_path = f"{graph_type_dir}/textual/{truth_side}/{test_size}/{metadata['encoding']}1.txt"

            pbar.set_description(f"Evaluating: {os.path.basename(filename)}")

            # NOTE: Using simplified evaluation that doesn't save individual metadata
            # to avoid the slow file I/O that was causing crashes
            result = evaluate_single_response_fast(
                benchmark,
                graph_type,
                test_size,
                response_path,
                ground_truth_path,
                metadata,
                failure_tracker,
                comparison_mode,
                consolidated_token_cache,
            )

            # Track token statistics
            if result.get("token_usage"):
                token_stats["responses_with_token_data"] += 1
                if result["token_usage"].get("source") == "consolidated":
                    token_stats["from_consolidated"] += 1
                elif result["token_usage"].get("source") == "individual":
                    token_stats["from_individual"] += 1
            else:
                token_stats["no_token_data"] += 1

            if skip_parse_errors and result.get("details", {}).get("error"):
                continue

            evaluation_results.append(result)
            new_results_count += 1

            # Save periodically to avoid losing work
            if new_results_count % 1000 == 0:
                _save_evaluation_results(output_file, evaluation_results)
                pbar.set_description(f"Saved checkpoint ({new_results_count} new)")

        except (ValueError, IndexError, OSError) as e:
            tqdm.write(f"⚠️ Failed to parse or evaluate {filename}: {e}")
            continue

    # Final save
    _save_evaluation_results(output_file, evaluation_results)

    print(f"✅ Added {new_results_count} new evaluations to {output_file}")
    if skipped_count > 0:
        print(f"⏩ Skipped {skipped_count} already-evaluated responses")
    print(f"📊 Total evaluations in file: {len(evaluation_results)}")

    # Print token statistics
    print("\n📊 Token Data Statistics:")
    print(f"   Responses with token data: {token_stats['responses_with_token_data']}")
    print(f"   From consolidated files: {token_stats['from_consolidated']}")
    print(f"   From individual files: {token_stats['from_individual']}")
    print(f"   No token data: {token_stats['no_token_data']}")

    # Generate failure summary
    failure_summary = generate_failure_summary(
        failure_tracker, show_all_errors=show_all_errors
    )
    print(failure_summary)

    return output_file


def evaluate_single_response_fast(
    benchmark: str,
    graph_type: str,
    size_category: str,
    response_path: str,
    ground_truth_path: str,
    metadata: Dict[str, Any],
    failure_tracker: Dict[str, List],
    comparison_mode: str = "label_consistent",
    consolidated_token_cache: Optional[Dict[str, Dict]] = None,
) -> Dict[str, Any]:
    """
    Fast version of evaluate_single_response that doesn't save individual metadata files.
    This avoids the slow file I/O that was causing crashes.

    Parameters:
    - benchmark: Benchmark name, copied into the result
    - graph_type: Graph type name, copied into the result
    - size_category: Test size, stored under "size" in the result
    - response_path: Path of the model response file
    - ground_truth_path: Path of the ground truth file
    - metadata: Filename metadata; "question_type" (default "full_output"),
      "target" (default "output") and "encoding" steer the evaluation
    - failure_tracker: Mutated in place; failures are appended to its
      "missing_responses", "missing_ground_truth", "parse_errors",
      "comparison_timeouts" or "other_errors" lists
    - comparison_mode: Mode passed to compare_graphs
    - consolidated_token_cache: Passed to load_token_data

    Returns:
    - A result dict. "correct" is False whenever evaluation failed, and
      details["error"] carries the message for missing files, read/decode
      failures and parse errors.
    """
    question_type = metadata.get("question_type", "full_output")
    target = metadata.get("target", "output")

    def build_result(
        correct: Any, details: Dict[str, Any], token_usage: Any
    ) -> Dict[str, Any]:
        # Single place that fixes the result layout for success and failure.
        return {
            "benchmark": benchmark,
            "graph_type": graph_type,
            "size": size_category,
            "encoding": metadata.get("encoding"),
            "size_pattern": metadata.get("size_pattern"),
            "system_prompt": metadata.get("system_prompt"),
            "n_pairs": metadata.get("n_pairs"),
            "model": metadata.get("model"),
            "question_type": metadata.get("question_type"),
            "target": metadata.get("target"),
            "comparison_mode": comparison_mode,
            "correct": correct,
            "details": details,
            "token_usage": token_usage,
            "timestamp": datetime.now().isoformat(),
            "response_path": response_path,
            "ground_truth_path": ground_truth_path,
        }

    # Set once a failure has been filed under a specific tracker category, so
    # the handler below only files the remaining ones under "other_errors".
    failure_recorded = False

    try:
        # Check if both files exist
        if not os.path.exists(response_path):
            failure_tracker["missing_responses"].append(response_path)
            failure_recorded = True
            raise FileNotFoundError(f"Response file not found: {response_path}")
        if not os.path.exists(ground_truth_path):
            failure_tracker["missing_ground_truth"].append(ground_truth_path)
            failure_recorded = True
            raise FileNotFoundError(f"Ground truth file not found: {ground_truth_path}")

        # Load response
        try:
            with open(response_path, "r", encoding="utf-8") as f:
                response_content = f.read()
        except (OSError, UnicodeError) as e:
            raise ValueError(f"Failed to read response file: {e}") from e

        # Load token usage data with consolidated cache
        token_usage = load_token_data(response_path, consolidated_token_cache)

        parse_error = None
        is_correct = False
        structural_match = True

        response_graph_metadata = {}
        ground_truth_graph_metadata = {}

        if question_type == "full_output":
            # Handle full graph output
            try:
                response_graph = decode_graph_from_text(
                    response_path, encoding_type=metadata.get("encoding"), response=True
                )
                response_graph_metadata = generate_graph_metadata(response_graph)
            except ValueError as e:
                parse_error = str(e)
                failure_tracker["parse_errors"].append(
                    {"file": response_path, "error": parse_error}
                )
                failure_recorded = True
                raise

            # Load ground truth graph
            ground_truth_graph = decode_graph_from_text(
                ground_truth_path,
                encoding_type=metadata.get("encoding"),
                response=False,
            )
            ground_truth_graph_metadata = generate_graph_metadata(ground_truth_graph)

            # Cheap size check first: graphs with different node or edge
            # counts are reported as a structural mismatch without running
            # compare_graphs.
            structural_match = len(response_graph.nodes) == len(
                ground_truth_graph.nodes
            ) and len(response_graph.edges) == len(ground_truth_graph.edges)

            if structural_match:
                comparison = compare_graphs(
                    response_graph, ground_truth_graph, mode=comparison_mode
                )
                if comparison is None:
                    # compare_graphs returns None when its timeout fires
                    failure_tracker["comparison_timeouts"].append(
                        {
                            "benchmark": benchmark,
                            "graph_type": graph_type,
                            "size": size_category,
                            "response_path": response_path,
                            "mode": comparison_mode,
                        }
                    )
                    is_correct = False
                else:
                    is_correct = comparison
        else:
            # Handle question-based response
            try:
                actual_answer = extract_answer_from_response(response_content)
                expected_answer = get_ground_truth_answer(
                    ground_truth_path, question_type, target
                )
                is_correct = compare_answers(
                    expected_answer, actual_answer, question_type
                )
                response_graph_metadata = {"answer": actual_answer}
                ground_truth_graph_metadata = {"answer": expected_answer}

            except (ValueError, IndexError) as e:
                parse_error = str(e)
                failure_tracker["parse_errors"].append(
                    {"file": response_path, "error": parse_error}
                )
                is_correct = False

        # Return result without saving individual metadata (for speed)
        details = {
            "structural_match": structural_match,
            "response_metadata": response_graph_metadata,
            "ground_truth_metadata": ground_truth_graph_metadata,
        }
        if parse_error is not None:
            # Expose question-path parse failures the same way full-output
            # failures are exposed, so callers checking details["error"]
            # treat both alike.
            details["error"] = parse_error

        return build_result(is_correct, details, token_usage)

    except (FileNotFoundError, ValueError, nx.NetworkXException) as e:
        error_message = str(e)

        if not failure_recorded:
            # E.g. unreadable response or a ground-truth decode failure.
            failure_tracker.setdefault("other_errors", []).append(
                {"file": response_path, "error": error_message}
            )

        # Load token usage even for failed responses
        token_usage = load_token_data(response_path, consolidated_token_cache)

        return build_result(False, {"error": error_message}, token_usage)


# Command-line entry point: either render visualizations from an existing
# evaluation file (--visualize) or run the evaluation and then render
# visualizations from the file it produced.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate model responses and generate visualizations"
    )

    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output file for evaluation results (default: auto-generated)",
    )
    parser.add_argument(
        "--append",
        action="store_true",
        help="Append to existing evaluation file if it exists",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing evaluation file if it exists",
    )
    parser.add_argument(
        "--repetition",
        type=int,
        default=1,
        help="Repetition number for this evaluation (for experiments with multiple runs)",
    )
    parser.add_argument(
        "--visualize",
        type=str,
        default=None,
        help="Generate visualizations from an existing evaluation file",
    )
    parser.add_argument(
        "--vis_dir",
        type=str,
        default="evaluation_data/visualizations",
        help="Directory to save visualizations",
    )
    parser.add_argument(
        "--show-all-errors",
        action="store_true",
        help="After evaluation, list every file with parse errors",
    )
    parser.add_argument(
        "--skip-parse-errors",
        action="store_true",
        help="Skip responses that fail to parse instead of counting them as incorrect",
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        help="Filter results to include only specified models",
    )
    parser.add_argument(
        "--system_prompts",
        nargs="+",
        type=str,
        help="Filter results to include only specified system prompts",
    )
    parser.add_argument(
        "--patterns",
        nargs="+",
        type=str,
        help="Filter results to include only specified size patterns",
    )
    parser.add_argument(
        "--comparison-mode",
        type=str,
        choices=["isomorphic", "label_consistent"],
        default="label_consistent",
        help="Mode for graph comparison (default: label_consistent)",
    )
    parser.add_argument(
        "--skip-encodings",
        nargs="+",
        default=["expert"],
        help="List of encoding types to skip (default: expert)",
    )

    args = parser.parse_args()

    # Check if we're only generating visualizations
    if args.visualize:
        print(f"Generating visualizations from {args.visualize}...")
        # Use the new visualization system instead. sys.executable keeps the
        # child on the interpreter running this script, matching the
        # post-evaluation call below.
        subprocess.run(
            [
                sys.executable,
                "-m",
                "scripts.visualization.main",
                args.visualize,
                "--output-dir",
                args.vis_dir,
            ],
            # A failing visualization run is not turned into an exception here.
            check=False,
        )
    else:
        # Evaluate responses and save results
        evaluation_file = evaluate_responses(
            output_file=args.output,
            append=args.append,
            overwrite=args.overwrite,
            repetition=args.repetition,
            comparison_mode=args.comparison_mode,
            skip_encodings=args.skip_encodings,
            models=args.models,
            system_prompts=args.system_prompts,
            patterns=args.patterns,
            show_all_errors=args.show_all_errors,
            skip_parse_errors=args.skip_parse_errors,
        )

        # evaluate_responses returns None when it refused to touch an
        # existing file; there is nothing to visualize in that case.
        if evaluation_file:
            # If only showing parse errors, stop after printing summary
            if args.show_all_errors:
                sys.exit(0)

            # Otherwise, generate visualizations
            print("\n🎨 Generating visualizations...")
            subprocess.run(
                [
                    sys.executable,
                    "-m",
                    "scripts.visualization.main",
                    evaluation_file,
                    "--verbose",
                ],
                check=False,
            )