"""Nearest insertion heuristic for TSP."""

from __future__ import annotations

from typing import List

import numpy as np

from ..utils import (
    TSPBaselineResult,
    best_insertion_position,
    ensure_distance_matrix,
    tour_length,
)


def solve(distance_matrix: np.ndarray) -> TSPBaselineResult:
    """Build a TSP tour with the nearest insertion heuristic.

    The input is first passed through ``ensure_distance_matrix``. Instances
    with at most two cities are returned in index order. Otherwise the tour
    is seeded with the pair returned by ``_initial_edge(dist, minimize=True)``
    and grown one city at a time: the unvisited city whose smallest
    ``dist[city, t]`` over current tour members ``t`` is lowest is chosen
    next and inserted at the index reported by ``best_insertion_position``.

    Args:
        distance_matrix: Pairwise distances between cities.

    Returns:
        A ``TSPBaselineResult`` carrying the visiting sequence and its cost
        as computed by ``tour_length``. For instances with more than two
        cities, ``extras`` is ``{"method": "NI"}``.
    """
    dist = ensure_distance_matrix(distance_matrix)
    num_cities = dist.shape[0]

    # Trivial instances: nothing to insert, visit the cities in index order.
    if num_cities <= 2:
        sequence = list(range(num_cities))
        return TSPBaselineResult(sequence=sequence, cost=tour_length(sequence, dist))

    first, second = _initial_edge(dist, minimize=True)
    tour: List[int] = [first, second]
    unvisited = [
        city for city in range(num_cities) if city != first and city != second
    ]

    def gap_to_tour(city: int) -> float:
        # Smallest distance from ``city`` to any city currently in the tour.
        return float(dist[city, tour].min())

    while unvisited:
        # Ties go to the earliest entry of ``unvisited`` (``min`` keeps the first).
        candidate = min(unvisited, key=gap_to_tour)
        # NOTE(review): presumably the cheapest insertion index — confirm in utils.
        slot, _ = best_insertion_position(tour, candidate, dist)
        tour.insert(slot, candidate)
        unvisited.remove(candidate)

    total = tour_length(tour, dist)
    return TSPBaselineResult(sequence=tour, cost=total, extras={"method": "NI"})


def _initial_edge(dist: np.ndarray, minimize: bool) -> tuple[int, int]:
    n = dist.shape[0]
    mask = dist + np.eye(n) * (np.inf if minimize else -np.inf)
    index = np.argmin(mask) if minimize else np.argmax(mask)
    i, j = divmod(int(index), n)
    return i, j


