import numpy as np
import heapq

class DistanceMatrixSampler:
    """Draw batches of mutually similar points from a pairwise score matrix.

    Despite the parameter name, the matrix is treated as a *similarity*
    matrix: larger entries mean two points are more alike (both the pair
    search and the neighbour selection take maxima). Each call to
    :meth:`sample` starts from one endpoint of the highest-scoring unvisited
    pair and gathers that point's highest-scoring unvisited neighbours.
    Every point is handed out by at most one call to :meth:`sample`.
    """

    def __init__(self, distance_matrix):
        """Store a private copy of the square score matrix.

        Args:
            distance_matrix: Square 2-D array-like of pairwise scores, where
                ``distance_matrix[i, j]`` scores points ``i`` and ``j``
                (higher means more similar).

        Raises:
            ValueError: If ``distance_matrix`` is not a square 2-D array.
        """
        # np.array copies by default, so later caller mutations cannot leak
        # in; it also accepts nested lists, not only ndarrays.
        matrix = np.array(distance_matrix)
        if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
            raise ValueError("distance_matrix must be a square 2-D array.")
        self.distance_matrix = matrix
        self.visited = set()  # indices already returned by sample()
        self.size = matrix.shape[0]

    def find_most_similar_points(self):
        """Return the highest-scoring pair of distinct unvisited points.

        Diagonal entries and every row/column of a visited point are
        excluded. Ties are broken by taking the first pair in row-major order.

        Returns:
            tuple[int, int]: Indices ``(point_a, point_b)`` of the pair.

        Raises:
            ValueError: If fewer than two unvisited points remain, so no
                valid pair exists.
        """
        if self.size - len(self.visited) < 2:
            raise ValueError("Fewer than two unvisited points remain.")
        mask = np.ones(self.distance_matrix.shape, dtype=bool)
        np.fill_diagonal(mask, False)
        visited = list(self.visited)
        mask[visited, :] = False
        mask[:, visited] = False
        # Fill excluded cells with -inf instead of multiplying by the mask:
        # zeroing them would let an excluded cell win whenever every valid
        # score is negative.
        candidates = np.where(mask, self.distance_matrix, -np.inf)
        point_a, point_b = np.unravel_index(np.argmax(candidates), candidates.shape)
        return int(point_a), int(point_b)

    def sample(self, n_samples=50):
        """Return ``n_samples`` unvisited points and mark them visited.

        The batch starts with a randomly chosen endpoint of the most similar
        unvisited pair (see :meth:`find_most_similar_points`). It is followed
        by the ``n_samples - 1`` unvisited points that score highest against
        that start point, in descending score order.

        Args:
            n_samples: Number of points to draw. ``0`` returns an empty list
                without changing any state.

        Returns:
            list[int]: Sampled indices, start point first.

        Raises:
            ValueError: If ``n_samples`` is negative or exceeds the number of
                unvisited points.
        """
        if n_samples < 0:
            raise ValueError("n_samples must be non-negative.")
        if n_samples > self.size - len(self.visited):
            raise ValueError("Not enough unvisited points left to sample.")
        if n_samples == 0:
            return []

        unvisited = [i for i in range(self.size) if i not in self.visited]
        if len(unvisited) == 1:
            # No pair exists; the lone remaining point is the only choice.
            start_point = unvisited[0]
        else:
            point_a, point_b = self.find_most_similar_points()
            start_point = int(np.random.choice([point_a, point_b]))
        self.visited.add(start_point)

        scores = [
            (i, self.distance_matrix[start_point, i])
            for i in unvisited
            if i != start_point
        ]
        # heapq.nlargest is stable, so equal scores keep ascending index order.
        closest_points = heapq.nlargest(n_samples - 1, scores, key=lambda x: x[1])
        neighbours = [point for point, _ in closest_points]
        self.visited.update(neighbours)
        return [start_point] + neighbours

    def is_done_sampling(self):
        """Return True once every point has been handed out by sample()."""
        return len(self.visited) == self.size