"""Shared utilities for experiment visualization scripts."""

import json
from pathlib import Path
from typing import Any


class NormalizeScore:
    """Linear rescaler mapping raw game payoffs onto a Nash-to-cooperative scale.

    Both reference payoffs are resolved once in the constructor, so every call
    to ``normalize`` is a single arithmetic expression.

    Args:
        game: Game name (e.g., "PrisonersDilemma", "PublicGoods")
        game_config: Game configuration dict. Its "kwargs" entry is read for
            PrisonersDilemma, PublicGoods, TravellersDilemma and TrustGame;
            the other games never touch it.

    Attributes:
        game: Game name
        ne_payoff: Reference payoff that normalizes to 0 (Nash Equilibrium)
        coop_payoff: Reference payoff that normalizes to 1 (cooperative outcome)

    Example:
        >>> normalizer = NormalizeScore("PrisonersDilemma", config)
        >>> normalized = normalizer.normalize(raw_score)
    """

    def __init__(self, game: str, game_config: dict):
        """Store the game name and resolve its two reference payoffs.

        Args:
            game: Game name
            game_config: Game configuration dict
        """
        self.game = game
        self.ne_payoff, self.coop_payoff = self._payoff_bounds(game, game_config)

    @staticmethod
    def _payoff_bounds(game, game_config):
        """Return ``(ne_payoff, coop_payoff)`` for the given game."""
        if game == "PrisonersDilemma":
            matrix = game_config["kwargs"]["payoff_matrix"]
            # NE: both defect; cooperative: both cooperate
            return matrix["DD"][0], matrix["CC"][1]

        if game == "PublicGoods":
            # Cooperative payoff is the multiplier; NE (no one contributes) is 1
            multiplier = game_config["kwargs"]["multiplier"]
            return 1, multiplier

        if game == "TravellersDilemma":
            lowest_claim = game_config["kwargs"]["min_claim"]
            step = game_config["kwargs"]["claim_spacing"]
            action_count = game_config["kwargs"]["num_actions"]
            # Cooperative payoff is the highest claim on the evenly spaced grid
            return lowest_claim, lowest_claim + step * (action_count - 1)

        if game == "TrustGame":
            matrix = game_config["kwargs"]["payoff_matrix"]
            # NE: both keep; cooperative: both give
            return matrix["KK"][0], matrix["GG"][0]

        if game == "StagHunt":
            # Hard-coded reference payoffs; the config is not consulted
            return 3, 5

        if game == "MatchingPennies":
            # Hard-coded reference payoffs; the config is not consulted
            return -1, 0

        # For non-social-dilemma games, use identity normalization
        return 0.0, 1.0

    def normalize(self, score: float) -> float:
        """Rescale a raw score so ``ne_payoff`` maps to 0 and ``coop_payoff`` to 1.

        The mapping is linear and not clamped, so scores outside the two
        reference payoffs produce values outside [0, 1].

        Args:
            score: Raw score from the game

        Returns:
            Normalized score, or 0.0 when both reference payoffs are equal
        """
        low, high = self.ne_payoff, self.coop_payoff
        if high == low:
            # Degenerate range: avoid division by zero
            return 0.0
        return (score - low) / (high - low)


def load_json(path: Path) -> dict[str, Any]:
    """Read the UTF-8 encoded file at ``path`` and return its parsed JSON content."""
    with open(path, encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


def discover_experiment_subfolders(experiment_dir: Path) -> list[Path]:
    """
    List the immediate subdirectories of ``experiment_dir`` that hold experiments.

    Plain files are ignored, as are the directories named "configs" and "slurm".
    Entries are returned in the order ``Path.iterdir`` yields them.
    """
    skipped_names = {"configs", "slurm"}
    subfolders = []
    for entry in experiment_dir.iterdir():
        if not entry.is_dir():
            continue
        if entry.name in skipped_names:
            continue
        subfolders.append(entry)
    return subfolders


def clean_model_name(model_name: str, max_length: int = 25) -> str:
    """
    Shorten model name for display in visualizations.

    Drops everything up to and including the last "/" (the provider prefix),
    then truncates the remainder to at most ``max_length`` characters. The
    length limit applies whether or not the name carried a provider prefix.

    Args:
        model_name: Model name, optionally prefixed with "provider/"
        max_length: Maximum number of characters in the returned name

    Returns:
        The shortened name. When truncation happens and ``max_length`` is at
        least 3, the result ends with "..." and is exactly ``max_length``
        characters long; for smaller limits the name is cut without an ellipsis.

    Examples:
        "openai/gpt-5.2(CoT)" -> "gpt-5.2(CoT)"
        "google/gemini-3-flash-preview(CoT)" -> "gemini-3-flash-preview..."
            (27 characters exceed the default limit of 25)
    """
    short_name = model_name.rsplit("/", 1)[-1]
    if len(short_name) <= max_length:
        return short_name

    if max_length < 3:
        # No room for an ellipsis; a negative slice bound would otherwise
        # produce a result longer than the requested limit.
        return short_name[:max(max_length, 0)]

    return short_name[:max_length - 3] + "..."


def display_mechanism_name(mechanism: str) -> str:
    """Translate an internal mechanism name into the label used in tables and plots.

    Only the two reputation variants are renamed; every other name is
    returned unchanged.

    Args:
        mechanism: Internal mechanism name

    Returns:
        Display name for mechanism

    Examples:
        "Reputation" -> "Reputation+"
        "ReputationFirstOrder" -> "Reputation-"
        "NoMechanism" -> "NoMechanism"
    """
    display_overrides = {
        "Reputation": "Reputation+",
        "ReputationFirstOrder": "Reputation-",
    }
    return display_overrides.get(mechanism, mechanism)


def simplify_model_name(model_name: str) -> str:
    """
    Map a full model name to its short display label.

    The "(CoT)" and "(IO)" agent-type suffixes are stripped before the lookup.
    Most known models have a single label; the Gemini entry has separate
    labels depending on whether "(CoT)" was present in the input. Names that
    are not in the table are returned exactly as given, suffix included.

    Args:
        model_name: Full model name, potentially with agent type suffix like "(CoT)"

    Returns:
        Short label for known models, otherwise the unmodified input

    Examples:
        "google/gemini-3-flash-preview(CoT)" -> "Gemini-R"
        "google/gemini-3-flash-preview" -> "Gemini-B"
        "openai/gpt-5.2(CoT)" -> "GPT-5.2"
        "openai/gpt-4o-2024-05-13(CoT)" -> "GPT-4o"
        "anthropic/claude-sonnet-4.5(CoT)" -> "Claude"
        "qwen/qwen3-30b-a3b-instruct-2507(CoT)" -> "Qwen-30b"
    """
    cot_marker = "(CoT)"
    short_labels = {
        "google/gemini-3-flash-preview": {
            "with_cot": "Gemini-R",
            "without_cot": "Gemini-B",
        },
        "openai/gpt-5.2": "GPT-5.2",
        "openai/gpt-4o-2024-05-13": "GPT-4o",
        "anthropic/claude-sonnet-4.5": "Claude",
        "qwen/qwen3-30b-a3b-instruct-2507": "Qwen-30b",
    }

    lookup_key = model_name.replace(cot_marker, "").replace("(IO)", "").strip()
    label = short_labels.get(lookup_key)

    if label is None:
        # Unknown model: hand back the original string untouched
        return model_name

    if isinstance(label, dict):
        variant = "with_cot" if cot_marker in model_name else "without_cot"
        return label[variant]

    return label


def get_num_players_from_matchup(matchup_data: dict[str, Any]) -> int:
    """
    Count the players listed in the first profile entry of the matchup data.

    Args:
        matchup_data: Loaded matchup_payoffs.json content

    Returns:
        Length of the "players" collection of ``matchup_data["profile"][0]``
    """
    first_profile = matchup_data["profile"][0]
    players = first_profile["players"]
    return len(players)


def sort_mechanisms(mechanisms: list[str]) -> list[str]:
    """
    Return the mechanisms arranged in the preferred presentation order.

    Known mechanisms are matched case-insensitively and come first, in the
    fixed order below. Unrecognized names follow, sorted by their original
    (case-sensitive) spelling.

    Args:
        mechanisms: List of mechanism names

    Returns:
        Sorted list of mechanism names
    """
    known_in_order = (
        "NoMechanism",
        "Repetition",
        "ReputationFirstOrder",
        "Reputation",
        "Disarmament",
        "Mediation",
        "Contracting",
    )
    rank_of = {}
    for position, known_name in enumerate(known_in_order):
        rank_of[known_name.lower()] = position

    def ordering(name):
        position = rank_of.get(name.lower())
        # Bucket 0: known names by rank; bucket 1: everything else by name
        return (1, name) if position is None else (0, position)

    ordered = list(mechanisms)
    ordered.sort(key=ordering)
    return ordered


def sort_games(games: list[str]) -> list[str]:
    """
    Return the games arranged in the preferred presentation order.

    Known games are matched case-insensitively and come first, in the fixed
    order below. Unrecognized names follow, sorted by their original
    (case-sensitive) spelling.

    Args:
        games: List of game names

    Returns:
        Sorted list of game names
    """
    known_in_order = (
        "PrisonersDilemma",
        "PublicGoods",
        "TravellersDilemma",
        "TrustGame",
        "StagHunt",
        "MatchingPennies",
    )
    rank_of = {}
    for position, known_name in enumerate(known_in_order):
        rank_of[known_name.lower()] = position

    def ordering(name):
        position = rank_of.get(name.lower())
        # Bucket 0: known names by rank; bucket 1: everything else by name
        return (1, name) if position is None else (0, position)

    ordered = list(games)
    ordered.sort(key=ordering)
    return ordered


def sort_models(models: list[str]) -> list[str]:
    """
    Return the models arranged in the preferred presentation order.

    Preferred order: claude, gemini-reasoning, gemini-base, gpt-5.2, gpt-4o, qwen

    Families are detected by case-insensitive substring match, checked in the
    order above. A gemini model counts as "reasoning" when its name contains
    "(cot)". Models in the same family, and all unrecognized models (placed
    last), are ordered by their full name.

    Args:
        models: List of model names (full names with provider prefixes)

    Returns:
        Sorted list of model names
    """
    trailing_families = ("gpt-5", "gpt-4", "qwen")

    def family_rank(model):
        lowered = model.lower()
        if "claude" in lowered:
            return 0
        if "gemini" in lowered:
            return 1 if "(cot)" in lowered else 2
        for rank, marker in enumerate(trailing_families, start=3):
            if marker in lowered:
                return rank
        return 6

    return sorted(models, key=lambda model: (family_rank(model), model))


# Validation utilities for multi-folder experiment processing


def validate_folder_count_consistency(
    grouped_data: dict[tuple[str, str], Any]
) -> int:
    """Check that every game+mechanism group holds the same number of folders.

    Args:
        grouped_data: Dictionary mapping (game_type, mechanism_type) to either:
            - dict with 'folders' key (plot_payoff_tensors.py style)
            - list of items (create_table.py style)

    Returns:
        The folder count shared by all groups

    Raises:
        ValueError: If there are no groups, or a group value is neither a
            dict with a 'folders' key nor a list
        AssertionError: If groups have different folder counts
    """
    if not grouped_data:
        raise ValueError("No groups to validate")

    def folders_in(group_key, group_data):
        """Return the folder count for one group, whichever layout it uses."""
        if isinstance(group_data, dict) and 'folders' in group_data:
            # plot_payoff_tensors.py style
            return len(group_data['folders'])
        if isinstance(group_data, list):
            # create_table.py style
            return len(group_data)
        raise ValueError(f"Unexpected group_data type for {group_key}: {type(group_data)}")

    count_per_group = {
        group_key: folders_in(group_key, group_data)
        for group_key, group_data in grouped_data.items()
    }
    distinct_counts = set(count_per_group.values())

    if len(distinct_counts) == 1:
        return next(iter(distinct_counts))

    # Mismatch: report which groups share each folder count
    labels_by_count = {}
    for (game_type, mechanism_type), folder_count in count_per_group.items():
        labels_by_count.setdefault(folder_count, []).append(f"{mechanism_type}_{game_type}")

    report = ["Folder count mismatch across groups:"]
    for folder_count in sorted(labels_by_count):
        report.append(f"  {folder_count} folder(s): {', '.join(labels_by_count[folder_count])}")
    report.append("All game+mechanism combinations must have the same number of folders.")

    raise AssertionError("\n".join(report))


def validate_list_consistency(
    value_lists: list[list[str]],
    identifiers: list[str],
    group_key: tuple[str, str],
    list_name: str
) -> list[str]:
    """Validate all entries have same list in same order.

    Args:
        value_lists: List of lists to validate (e.g., model lists)
        identifiers: Identifiers for each list (e.g., folder paths) for error
            messages; must contain exactly one identifier per list
        group_key: (game_type, mechanism_type) for error messages
        list_name: Descriptive name (e.g., "model list") for error messages

    Returns:
        The validated list (from first entry)

    Raises:
        ValueError: If inputs are empty, or if ``value_lists`` and
            ``identifiers`` differ in length
        AssertionError: If lists differ in content or order
    """
    if not value_lists or not identifiers:
        raise ValueError(f"Empty inputs for validation of {group_key}")

    # zip() below stops at the shorter input, so a length mismatch would
    # silently leave trailing lists unvalidated; reject it up front.
    if len(value_lists) != len(identifiers):
        raise ValueError(
            f"Got {len(value_lists)} {list_name} entries but "
            f"{len(identifiers)} identifiers for validation of {group_key}"
        )

    reference_list = value_lists[0]
    reference_id = identifiers[0]

    for i, (value_list, identifier) in enumerate(zip(value_lists[1:], identifiers[1:]), start=1):
        if value_list != reference_list:
            raise AssertionError(
                f"{list_name.capitalize()} mismatch in {group_key}:\n"
                f"  Entry 0 ({reference_id}):\n    {reference_list}\n"
                f"  Entry {i} ({identifier}):\n    {value_list}\n"
                f"  Note: Both content AND order must match."
            )

    return reference_list


def validate_dict_consistency(
    value_dicts: list[dict],
    identifiers: list[str],
    group_key: tuple[str, str],
    dict_name: str
) -> dict:
    """Validate all entries have identical dictionary content.

    Args:
        value_dicts: List of dicts to validate (e.g., game configs)
        identifiers: Identifiers for each dict (e.g., folder paths) for error
            messages; must contain exactly one identifier per dict
        group_key: (game_type, mechanism_type) for error messages
        dict_name: Descriptive name (e.g., "game config") for error messages

    Returns:
        The validated dict (from first entry)

    Raises:
        ValueError: If inputs are empty, or if ``value_dicts`` and
            ``identifiers`` differ in length
        AssertionError: If dicts differ in content
    """
    if not value_dicts or not identifiers:
        raise ValueError(f"Empty inputs for validation of {group_key}")

    # zip() below stops at the shorter input, so a length mismatch would
    # silently leave trailing dicts unvalidated; reject it up front.
    if len(value_dicts) != len(identifiers):
        raise ValueError(
            f"Got {len(value_dicts)} {dict_name} entries but "
            f"{len(identifiers)} identifiers for validation of {group_key}"
        )

    reference_dict = value_dicts[0]
    reference_id = identifiers[0]

    for i, (value_dict, identifier) in enumerate(zip(value_dicts[1:], identifiers[1:]), start=1):
        if value_dict != reference_dict:
            raise AssertionError(
                f"{dict_name.capitalize()} mismatch in {group_key}:\n"
                f"  Entry 0 ({reference_id}):\n    {reference_dict}\n"
                f"  Entry {i} ({identifier}):\n    {value_dict}"
            )

    return reference_dict
