"""Utilities for analyzing robot structures produced by the builder."""

from __future__ import annotations

from collections import defaultdict, deque
from typing import Any, Dict, Iterable, List, Sequence, Tuple

import numpy as np


def _get_rigid_segment_ids(segment_id: np.ndarray) -> np.ndarray:
    """Return sorted unique rigid segment ids (exclude 0 for non-rigid)."""
    return np.sort(np.unique(segment_id[segment_id > 0]))


def _build_adjacency(
    connections: Sequence[Dict[str, object]], segment_ids: Iterable[int]
) -> Dict[int, set]:
    """Create an undirected adjacency map from connection components."""
    adjacency: Dict[int, set] = {sid: set() for sid in segment_ids}
    for connection in connections:
        if "components" not in connection:
            continue
        a, b = connection["components"]
        adjacency.setdefault(a, set()).add(b)
        adjacency.setdefault(b, set()).add(a)
    return adjacency


def rigid_segment_mass_stats(
    robot_structure: Dict[str, object],
) -> Tuple[float, float]:
    """Compute mean and std of rigid segment masses (voxel counts).

    A segment's mass is the number of voxels labelled with its id. Voxels
    labelled ``<= 0`` are non-rigid and ignored.

    Args:
        robot_structure: Dict containing at least ``segment_id`` as a 3D ndarray.

    Returns:
        Tuple of (mean_mass, std_mass) using the population std (``ddof=0``).
        Returns (0.0, 0.0) when there are no rigid segments.
    """
    segment_id = np.asarray(robot_structure["segment_id"])
    # Single pass over the rigid voxels: ``return_counts`` yields each id's
    # voxel count directly, instead of rescanning the grid once per segment.
    _, counts = np.unique(segment_id[segment_id > 0], return_counts=True)
    if counts.size == 0:
        return 0.0, 0.0

    masses = counts.astype(np.float64)
    return float(masses.mean()), float(masses.std())


def rigid_segment_bounding_box_stats(
    robot_structure: Dict[str, object],
) -> Tuple[float, float]:
    """Compute mean/std of oriented bounding box scores for rigid segments.

    Each segment is fit with a PCA-oriented bounding box. For a single scalar that
    reflects both size and anisotropy, we use:
    score = (Lx * Ly * Lz) ** (1/3) * (Lmax / Lmin),
    where Lx, Ly, Lz are side lengths of the oriented box. This mixes the geometric
    mean size (cube-root of volume) with elongation (max/min ratio).

    Args:
        robot_structure: Dict containing ``segment_id`` as a 3D ndarray.

    Returns:
        Tuple of (mean_bbox_sum, std_bbox_sum). Returns (0.0, 0.0) when no segments.
    """
    segment_id = np.asarray(robot_structure["segment_id"])
    segment_ids = _get_rigid_segment_ids(segment_id)
    if segment_ids.size == 0:
        return 0.0, 0.0

    bbox_scores: List[float] = []
    for sid in segment_ids:
        coords = np.argwhere(segment_id == sid)
        if coords.shape[0] == 1:
            lengths = np.array([1.0, 1.0, 1.0], dtype=np.float64)
        else:
            centered = coords - coords.mean(axis=0, keepdims=True)
            # Principal axes via SVD for stability.
            _, _, vh = np.linalg.svd(centered, full_matrices=False)
            rotated = centered @ vh.T
            mins = rotated.min(axis=0)
            maxs = rotated.max(axis=0)
            lengths = np.maximum(maxs - mins, 1.0)
        geom_mean = float(np.prod(lengths) ** (1.0 / 3.0))
        elongation = float(lengths.max() / lengths.min())
        bbox_scores.append(geom_mean * elongation)
    bbox_arr = np.array(bbox_scores, dtype=np.float64)
    return float(bbox_arr.mean()), float(bbox_arr.std())


def rigid_segment_connectivity_stats(
    robot_structure: Dict[str, object],
) -> Tuple[float, float]:
    """Summarise how many neighbours each rigid segment is joined to.

    A segment's connectivity is its degree in the undirected graph built from
    ``robot_structure["connections"]``. A missing key means no connections.

    Args:
        robot_structure: Dict with a ``segment_id`` ndarray and, optionally,
            a ``connections`` list.

    Returns:
        ``(mean_connectivity, std_connectivity)`` over the rigid segments, or
        ``(0.0, 0.0)`` when the robot has no rigid segments.
    """
    labels = np.asarray(robot_structure["segment_id"])
    rigid_ids = _get_rigid_segment_ids(labels)
    if not rigid_ids.size:
        return 0.0, 0.0

    graph = _build_adjacency(robot_structure.get("connections", []), rigid_ids)
    degree_arr = np.fromiter(
        (len(graph.get(rid, ())) for rid in rigid_ids),
        dtype=np.float64,
        count=rigid_ids.size,
    )
    return float(np.mean(degree_arr)), float(np.std(degree_arr))


def rigid_segment_longest_path_stats(
    robot_structure: Dict[str, object],
) -> Tuple[float, float]:
    """Compute mean/std of per-segment chain lengths (eccentricity + 1).

    For each rigid segment, a breadth-first search over the connection graph
    finds the hop distance (shortest path) to every segment reachable from
    it. The metric is the largest of those distances, i.e. the segment's
    eccentricity within its connected component, plus one. The extra one
    counts segments along the chain rather than joints, so an isolated
    segment scores 1.

    When the connection graph is a tree, shortest paths are the unique paths,
    so this equals the longest chain starting at the segment. If the graph
    contains cycles, it is the longest *shortest* path, not the longest
    simple path.

    Args:
        robot_structure: Dict with ``segment_id`` ndarray and ``connections`` list.

    Returns:
        Tuple of (mean_longest_path, std_longest_path). (0.0, 0.0) if no segments.
    """
    segment_id = np.asarray(robot_structure["segment_id"])
    segment_ids = _get_rigid_segment_ids(segment_id)
    if segment_ids.size == 0:
        return 0.0, 0.0

    adjacency = _build_adjacency(robot_structure.get("connections", []), segment_ids)

    def eccentricity(node: int) -> int:
        """Largest BFS hop distance from ``node`` to any reachable segment."""
        visited = {node}
        queue: deque[Tuple[int, int]] = deque([(node, 0)])
        max_dist = 0
        while queue:
            current, dist = queue.popleft()
            max_dist = max(max_dist, dist)
            for neighbor in adjacency.get(current, ()):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, dist + 1))
        return max_dist

    # Add one so the metric counts the starting segment itself.
    longest_arr = np.array(
        [eccentricity(sid) + 1 for sid in segment_ids], dtype=np.float64
    )
    return float(longest_arr.mean()), float(longest_arr.std())


def rigid_segment_count(
    robot_structure: Dict[str, object],
) -> int:
    """Return the number of distinct rigid segments (ids > 0) in one robot."""
    rigid_ids = _get_rigid_segment_ids(np.asarray(robot_structure["segment_id"]))
    return int(len(rigid_ids))


def population_segment_count_stats(
    robot_structures: List[Any],
) -> Tuple[float, float]:
    """Mean/std of rigid segment counts across a population."""
    counts = [rigid_segment_count(structure) for structure in robot_structures]
    if not counts:
        return 0.0, 0.0
    counts_arr = np.array(counts, dtype=np.float64)
    return float(counts_arr.mean()), float(counts_arr.std())


def population_mass_stats(
    robot_structures: List[Any],
) -> Tuple[float, float]:
    """Mean/std of per-robot mean segment mass across a population.

    Each robot contributes the mean of its rigid segment masses, as returned
    by ``rigid_segment_mass_stats``. The within-robot std is not used.

    Args:
        robot_structures: Robot structure dicts, each accepted by
            ``rigid_segment_mass_stats``.

    Returns:
        ``(mean, std)`` of the per-robot means, or ``(0.0, 0.0)`` for an empty
        population.
    """
    per_robot_means = [
        rigid_segment_mass_stats(structure)[0] for structure in robot_structures
    ]
    if not per_robot_means:
        return 0.0, 0.0
    return float(np.mean(per_robot_means)), float(np.std(per_robot_means))


def population_bounding_box_stats(
    robot_structures: List[Any],
) -> Tuple[float, float]:
    """Mean/std of per-robot mean bounding-box score across a population.

    Each robot contributes the mean of its segments' oriented bounding-box
    scores, as returned by ``rigid_segment_bounding_box_stats``. The
    within-robot std is not used.

    Args:
        robot_structures: Robot structure dicts, each accepted by
            ``rigid_segment_bounding_box_stats``.

    Returns:
        ``(mean, std)`` of the per-robot means, or ``(0.0, 0.0)`` for an empty
        population.
    """
    per_robot_means = [
        rigid_segment_bounding_box_stats(structure)[0]
        for structure in robot_structures
    ]
    if not per_robot_means:
        return 0.0, 0.0
    return float(np.mean(per_robot_means)), float(np.std(per_robot_means))


def population_connectivity_stats(
    robot_structures: List[Any],
) -> Tuple[float, float]:
    """Mean/std of per-robot mean segment connectivity across a population.

    Each robot contributes the mean degree of its rigid segments, as returned
    by ``rigid_segment_connectivity_stats``. The within-robot std is not used.

    Args:
        robot_structures: Robot structure dicts, each accepted by
            ``rigid_segment_connectivity_stats``.

    Returns:
        ``(mean, std)`` of the per-robot means, or ``(0.0, 0.0)`` for an empty
        population.
    """
    per_robot_means = [
        rigid_segment_connectivity_stats(structure)[0]
        for structure in robot_structures
    ]
    if not per_robot_means:
        return 0.0, 0.0
    return float(np.mean(per_robot_means)), float(np.std(per_robot_means))


def population_longest_path_stats(
    robot_structures: List[Any],
) -> Tuple[float, float]:
    """Mean/std of per-robot mean chain length across a population.

    Each robot contributes the mean of its per-segment chain lengths, as
    returned by ``rigid_segment_longest_path_stats``. The within-robot std is
    not used.

    Args:
        robot_structures: Robot structure dicts, each accepted by
            ``rigid_segment_longest_path_stats``.

    Returns:
        ``(mean, std)`` of the per-robot means, or ``(0.0, 0.0)`` for an empty
        population.
    """
    per_robot_means = [
        rigid_segment_longest_path_stats(structure)[0]
        for structure in robot_structures
    ]
    if not per_robot_means:
        return 0.0, 0.0
    return float(np.mean(per_robot_means)), float(np.std(per_robot_means))


def diversity_cv(
    mean_mu_mass: float,
    std_mu_mass: float,
    mean_mu_bones: float,
    std_mu_bones: float,
    mean_mu_score: float,
    std_mu_score: float,
    mean_mu_conn: float,
    std_mu_conn: float,
    mean_mu_chain: float,
    std_mu_chain: float,
    eps: float = 1e-8,
) -> float:
    """Coefficient-of-variation–based diversity score over population metrics.

    Each metric's coefficient of variation is ``std / (mean + eps)``. The
    score is the Euclidean (L2) norm of the five CVs.

    Args:
        mean_mu_mass: Mean mass across the population.
        std_mu_mass: Std of mass across the population.
        mean_mu_bones: Mean rigid-segment count across the population.
        std_mu_bones: Std of rigid-segment count across the population.
        mean_mu_score: Mean bounding-box score across the population.
        std_mu_score: Std of bounding-box score across the population.
        mean_mu_conn: Mean connectivity across the population.
        std_mu_conn: Std of connectivity across the population.
        mean_mu_chain: Mean longest path across the population.
        std_mu_chain: Std of longest path across the population.
        eps: Small constant added to each mean to avoid division by zero.
            With ``eps > 0``, a metric whose mean and std are both 0
            contributes a CV of 0.

    Returns:
        Diversity score as the L2 norm of per-metric coefficients of variation.
    """
    cv_mass = std_mu_mass / (mean_mu_mass + eps)
    cv_bones = std_mu_bones / (mean_mu_bones + eps)
    cv_score = std_mu_score / (mean_mu_score + eps)
    cv_conn = std_mu_conn / (mean_mu_conn + eps)
    cv_chain = std_mu_chain / (mean_mu_chain + eps)
    return float(
        np.sqrt(cv_mass**2 + cv_bones**2 + cv_score**2 + cv_conn**2 + cv_chain**2)
    )
