"""Shared utilities for TSP baseline implementations."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, List, Sequence

import numpy as np


@dataclass(frozen=True)
class TSPBaselineResult:
    """Standard output for a TSP baseline solver."""

    sequence: List[int]
    cost: float
    extras: dict | None = None


def ensure_distance_matrix(matrix: np.ndarray | Sequence[Sequence[float]]) -> np.ndarray:
    """Return a square numpy array copy of the provided distance matrix."""
    arr = np.asarray(matrix, dtype=float)
    if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
        raise ValueError("distance matrix must be square")
    return arr


def tour_length(sequence: Sequence[int], distance_matrix: np.ndarray) -> float:
    """Return the length of the closed tour visiting *sequence* in order.

    Adds ``distance_matrix[u, v]`` for every consecutive pair ``(u, v)``,
    including the closing edge from the last city back to the first. An empty
    sequence has length ``0.0``.
    """
    if len(sequence) == 0:
        return 0.0
    # Each city paired with its successor; the last wraps to the first.
    successors = [*sequence[1:], sequence[0]]
    total = 0.0
    for here, there in zip(sequence, successors):
        total += distance_matrix[here, there]
    return float(total)


def insertion_cost(
    sequence: Sequence[int],
    city: int,
    position: int,
    distance_matrix: np.ndarray,
) -> float:
    """Return the tour-length change from inserting *city* at *position*.

    The new city goes between ``sequence[position - 1]`` and
    ``sequence[position]``, treating the tour as a cycle. Position 0 (or
    below) uses the last city as the left neighbour, and a position at or
    past the end uses the first city as the right neighbour. The result is
    ``d[left, city] + d[city, right] - d[left, right]``. An empty sequence
    yields ``0.0``.
    """
    size = len(sequence)
    if size == 0:
        return 0.0
    left = sequence[-1] if position <= 0 else sequence[position - 1]
    right = sequence[0] if position >= size else sequence[position]
    detour = distance_matrix[left, city] + distance_matrix[city, right]
    return detour - distance_matrix[left, right]


def best_insertion_position(
    sequence: Sequence[int],
    city: int,
    distance_matrix: np.ndarray,
) -> tuple[int, float]:
    """Return ``(position, delta_cost)`` for the cheapest place to insert *city*.

    Positions ``0 .. len(sequence) - 1`` are scored with ``insertion_cost``.
    Position ``len(sequence)`` is not tried separately: on a cyclic tour it
    sits between the same two cities as position 0. Ties keep the lowest
    position, because only a strictly smaller delta replaces the incumbent.
    An empty tour yields ``(0, 0.0)``.
    """
    n = len(sequence)
    # Test the length rather than truthiness so NumPy-array tours also work;
    # ``bool(ndarray)`` raises for arrays with more than one element.
    if n == 0:
        return 0, 0.0
    best_pos = 0
    best_delta = float("inf")
    for pos in range(n):
        delta = insertion_cost(sequence, city, pos, distance_matrix)
        if delta < best_delta:
            best_delta = delta
            best_pos = pos
    return best_pos, best_delta


def coords_to_distance_matrix(coords: np.ndarray | Sequence[Sequence[float]]) -> np.ndarray:
    """Compute full pairwise Euclidean distance matrix from coordinates."""
    arr = np.asarray(coords, dtype=float)
    if arr.ndim != 2 or arr.shape[1] < 2:
        raise ValueError("coordinates must be of shape (n, 2+)")
    diff = arr[:, None, :] - arr[None, :, :]
    dist = np.sqrt(np.sum(diff * diff, axis=-1))
    return dist


def nearest_neighbor_route(distance_matrix: np.ndarray, start: int = 0) -> List[int]:
    """Build a tour greedily by always moving to the closest unvisited city.

    The tour starts at ``start % n``. Each step appends the unvisited city
    with the smallest ``dist[current, city]``. Ties go to the lowest city
    index, so the result is fully deterministic. Matrices with zero or one
    city return ``list(range(n))``.

    Args:
        distance_matrix: square matrix of pairwise distances.
        start: seed city; reduced modulo the number of cities.

    Returns:
        A list visiting every city exactly once, beginning with the seed.

    Raises:
        ValueError: if the distance matrix is not square (raised by
            ``ensure_distance_matrix``).
    """
    dist = ensure_distance_matrix(distance_matrix)
    n = dist.shape[0]
    if n <= 1:
        return list(range(n))
    visited = np.zeros(n, dtype=bool)
    current = start % n
    route = [current]
    visited[current] = True
    for _ in range(n - 1):
        # Restrict to unvisited cities explicitly (rather than masking with
        # inf) so infinite distances can never select a visited city.
        candidates = np.flatnonzero(~visited)
        # argmin returns the first minimum, i.e. the lowest-index candidate.
        nxt = int(candidates[np.argmin(dist[current, candidates])])
        route.append(nxt)
        visited[nxt] = True
        current = nxt
    return route


def stochastic_two_opt(
    tour: List[int],
    distance_matrix: np.ndarray,
    max_steps: int = 400,
    rng: np.random.Generator | None = None,
) -> List[int]:
    """Improve *tour* with up to *max_steps* randomly sampled 2-opt moves.

    Each step draws two distinct positions ``i < j``. It then scores
    replacing edges ``(tour[i-1], tour[i])`` and ``(tour[j], tour[j+1])``
    with ``(tour[i-1], tour[j])`` and ``(tour[i], tour[j+1])``, indices taken
    cyclically. The move is applied by reversing positions ``i..j``
    inclusive, and only when it shortens the tour by more than ``1e-9``.

    The score only considers the two swapped edges, so it assumes a symmetric
    distance matrix: reversing the segment also flips the direction of every
    edge inside it.

    Args:
        tour: initial tour; it is not modified.
        distance_matrix: square matrix of pairwise distances.
        max_steps: number of random moves sampled (not the number applied).
        rng: random generator. When ``None``, a generator seeded with 0 is
            used, so the default call is reproducible.

    Returns:
        A new list containing the improved tour. Tours of three or fewer
        cities come back as an unchanged copy.
    """
    if len(tour) <= 3:
        return list(tour)
    dist = ensure_distance_matrix(distance_matrix)
    if rng is None:
        rng = np.random.default_rng(0)
    best = list(tour)
    n = len(best)
    for _ in range(max_steps):
        i, j = sorted(int(k) for k in rng.choice(n, 2, replace=False))
        # Skip adjacent positions, and the (0, n-1) pair, where both "removed"
        # edges are the same edge (tour[-1], tour[0]).
        if j - i <= 1 or (i == 0 and j == n - 1):
            continue
        a, b = best[i - 1], best[i]
        c, d = best[j], best[(j + 1) % n]
        delta = dist[a, c] + dist[b, d] - dist[a, b] - dist[c, d]
        if delta < -1e-9:
            # Reversing positions i..j *inclusive* makes a adjacent to c and b
            # adjacent to d, which is exactly the move scored above. Reversing
            # only i..j-1 would apply a different, unscored move.
            best[i : j + 1] = best[i : j + 1][::-1]
    return best


