from typing import Callable, List
import numpy as np

def evaluate(solve_tsp: Callable[[np.ndarray], List[int]], dataset: dict) -> float:
    """Score a TSP solver on a single dataset instance.

    Args:
        solve_tsp: Solver under evaluation. It is called once with the
            distance matrix and is expected to return a tour as a sequence
            of city indices.
        dataset: Mapping that must contain a ``"distances"`` entry holding
            the distance matrix handed to the solver.

    Returns:
        Total length of the closed tour returned by the solver, as a Python
        ``float``. Returns ``float('inf')`` when the tour is not a
        permutation of ``0 .. len(distances) - 1``, or when any step raises
        (missing key, solver error, or scoring error).
    """
    try:
        distances = dataset["distances"]
        path = solve_tsp(distances)

        if not _is_valid_path(path, len(distances)):
            return float('inf')
        # float() so integer-valued matrices still yield the annotated type
        # (numpy integer scalars are not instances of float).
        return float(_calculate_path_length(distances, path))
    except Exception:
        # Deliberate best-effort boundary: a solver that crashes, or whose
        # output cannot be scored, receives the worst possible score instead
        # of aborting the whole evaluation run.
        return float('inf')

def _is_valid_path(path: List[int], n: int) -> bool:
    if len(path) != n:
        return False
    if len(set(path)) != n:
        return False
    return sorted(path) == list(range(n))


def _calculate_path_length(distances: np.ndarray, path: List[int]) -> float:
    total = 0
    n = len(path)
    for i in range(n):
        total += distances[path[i], path[(i + 1) % n]]
    return total 