"""Confidence sampling method."""

from typing import List

from sampling_methods import base_sampler
from utils import tf_util
import numpy as np
import tensorflow as tf


class ConfidenceSampling(base_sampler.SamplingMethod):
  """Confidence sampling method.

  Scores each test example by the maximum entry of the ensemble's output for
  that example (presumably the top-class probability when the ensemble output
  is a probability vector -- confirm against
  `tf_util.get_ensemble_model_output`). How the scores are turned into a
  selection (e.g. lowest- vs. highest-score first) is decided by the base
  `SamplingMethod`, not by this class.
  """

  def __init__(
      self,
      ensemble_models,
      n,
      target_test_ds,
      ensemble_method = 'soft',
      debug_info = False,
  ):
    """Initializes the sampler.

    Args:
      ensemble_models: models whose outputs are combined by
        `tf_util.get_ensemble_model_output`.
      n: number of test examples; forwarded to the base `SamplingMethod` and
        used as the length of the all-zero score vector in `get_scores`
        (presumably equal to the number of examples in `target_test_ds` --
        TODO confirm).
      target_test_ds: iterable of `(batch_x, batch_y)` pairs; only `batch_x`
        is used.
      ensemble_method: how ensemble outputs are combined; forwarded to
        `tf_util.get_ensemble_model_output`.
      debug_info: forwarded to the base `SamplingMethod`.
    """
    super().__init__(n=n, debug_info=debug_info)
    self.ds = target_test_ds
    self.ensemble_models = ensemble_models
    self.ensemble_method = ensemble_method

  def get_scores(
      self,
      already_selected_indices,
      label_budget
  ):
    """Gets scores of the test data for sampling.

    Args:
      already_selected_indices: indices of examples already selected; only its
        length is used, so any sized sequence (NumPy array, list, ...) works.
      label_budget: number of additional examples to be selected.

    Returns:
      A 1-D array with one score per test example: the maximum of the
      ensemble output along axis 1. If the already-selected examples plus the
      budget cover all `self.n` examples, a float32 array of `self.n` zeros is
      returned instead and the models are not run.

    Raises:
      ValueError: if scores must be computed but `self.ds` yields no batches.
    """
    if len(already_selected_indices) + label_budget >= self.n:
      # Scores are useless in this case,
      # since all remaining samples will be selected.
      return np.zeros(self.n, dtype=np.float32)
    # Keep whole per-batch arrays and join them once, instead of splitting
    # every batch into per-example rows and re-stacking them afterwards.
    batch_outputs = [
        tf_util.get_ensemble_model_output(
            self.ensemble_models, batch_x, self.ensemble_method
        ).numpy()
        for batch_x, _ in self.ds
    ]
    if not batch_outputs:
      # Without this check, np.max(axis=1) on an empty 1-D array fails with
      # an opaque axis error.
      raise ValueError(
          'target_test_ds yielded no batches; cannot compute confidence '
          'scores.'
      )
    outputs = np.concatenate(batch_outputs, axis=0)
    return np.max(outputs, axis=1)
