"""Uniform sampling method."""

from sampling_methods import base_sampler
import numpy as np


class UniformSampling(base_sampler.SamplingMethod):
  """Uniform sampling method.

  Gives each of the ``n`` samples an independent score drawn from [0, 1),
  so selecting by score produces a uniformly random subset.
  """

  def __init__(self, n, debug_info=False, rng=None):
    """Initializes the sampler.

    Args:
      n: Forwarded to the base class as ``n``; used here as the length of
        the score array.
      debug_info: Forwarded unchanged to the base class.
      rng: Optional random source exposing ``uniform(low, high, size)``,
        e.g. ``np.random.default_rng(seed)`` or ``np.random.RandomState(seed)``,
        for reproducible scores. When None (the default) the global
        ``np.random`` state is used, exactly as before.
    """
    super().__init__(n=n, debug_info=debug_info)
    self._rng = rng

  def get_scores(self, already_selected_indices, label_budget):
    """Gets scores of the test data for sampling.

    Args:
      already_selected_indices: Sized collection (NumPy array, list, ...) of
        indices that were already selected; only its length is used.
      label_budget: Number of additional samples to select.

    Returns:
      Array of length ``self.n``. If the already-selected count plus
      ``label_budget`` reaches ``self.n``, an all-zeros float32 array is
      returned (every remaining sample will be picked, so scores carry no
      information). Otherwise, float64 scores drawn uniformly from [0, 1).
    """
    # len() matches .shape[0] for arrays but also accepts plain sequences.
    if len(already_selected_indices) + label_budget >= self.n:
      return np.zeros(self.n, dtype=np.float32)
    # getattr keeps instances created without `_rng` (e.g. unpickled from an
    # older version) working; None falls back to the global NumPy state so
    # unseeded callers see the same random stream as before.
    rng = getattr(self, "_rng", None)
    if rng is None:
      rng = np.random
    return rng.uniform(low=0.0, high=1.0, size=(self.n,))
