import numpy as np

class QuickMerge:
    """Greedy online partitioning of state vectors.

    Each partition is represented by the first state that created it.  A new
    state is assigned to the nearest (L-infinity) stored representative that
    is "close" to it; if none is close, the state opens a new partition.

    Two states are close when their first two components differ by at most
    ``head_tol`` (L-infinity) AND their remaining components differ by less
    than ``tail_tol`` (L-infinity).
    NOTE(review): the split at index 2 suggests e.g. position vs. the rest of
    the state — confirm against the environment's observation layout.
    """

    # Default tolerances; class-level so instances created without
    # ``__init__`` still have them.
    _head_tol = 0.03
    _tail_tol = 0.2

    def __init__(self, env=None, *, head_tol=0.03, tail_tol=0.2):
        """Create an empty partitioner.

        Args:
            env: Stored as ``self._env``; not used by any method in this class.
            head_tol: Max L-inf distance (inclusive) on components ``[0:2]``.
            tail_tol: Max L-inf distance (exclusive) on components ``[2:]``.
        """
        self._current_label = 0
        self._partitions = {}
        self._env = env
        self._head_tol = head_tol
        self._tail_tol = tail_tol

    def most_likely(self, state):
        """Return the label of the closest compatible partition for ``state``.

        If no stored partition is close (see ``_is_close``), a new partition
        seeded with a copy of ``state`` is created and its label returned.
        Labels are consecutive integers starting at 0.
        """
        best_label = None
        best_dist = np.inf
        for label, partition in self._partitions.items():
            dist = self._distance(state, partition)
            if dist < best_dist:
                best_dist = dist
                best_label = label

        if best_label is not None:
            return best_label

        # No compatible partition: open a new one.  Store a copy so a later
        # in-place mutation of the caller's array cannot move the stored
        # representative.
        label = self._current_label
        self._partitions[label] = np.array(state, copy=True)
        self._current_label += 1
        return label

    def _distance(self, state, partition):
        """L-inf distance between ``state`` and ``partition``, or ``inf`` if
        they are not close enough to be merged.

        Uses the same acceptance rule as ``_is_close`` so that a partition
        chosen by ``most_likely`` is always reported as close by ``is_close``.
        """
        if not self._is_close(state, partition):
            return np.inf
        return np.linalg.norm(state - partition, ord=np.inf)

    def _closest(self, state):
        """Label of the partition nearest to ``state`` in L-inf distance,
        ignoring the closeness tolerances; ``None`` if there are no partitions.
        """
        best = None
        best_dist = np.inf
        for label, partition in self._partitions.items():
            dist = np.linalg.norm(state - partition, ord=np.inf)
            if dist < best_dist:
                best_dist = dist
                best = label
        return best

    def _is_close(self, state, other):
        """True iff components ``[0:2]`` differ by at most ``head_tol`` and
        components ``[2:]`` differ by less than ``tail_tol`` (both L-inf)."""
        if np.linalg.norm(state[0:2] - other[0:2], ord=np.inf) > self._head_tol:
            return False
        return np.linalg.norm(state[2:] - other[2:], ord=np.inf) < self._tail_tol

    def reduce(self, label):
        """Identity mapping on labels (no partitions are ever merged here).

        NOTE(review): presumably an interface hook shared with other
        partitioners — confirm with callers.
        """
        return label

    def is_overlap(self, *args):
        """Always ``False``: this partitioner never reports overlaps.

        NOTE(review): presumably an interface hook — confirm with callers.
        """
        return False

    def is_close(self, state, y):
        """True if partition ``y`` exists and ``state`` is close to it."""
        if y not in self._partitions:
            return False

        return self._is_close(state, self._partitions[y])