#!/usr/bin/env python3
import torch
import numpy as np
import uimnet
from uimnet.measures.measure import Measure

class Native(Measure):
  """Native uncertainty measure.

  Delegates to the wrapped algorithm's own ``uncertainty(x)`` method when the
  algorithm defines one; otherwise produces NaN for every example, so callers
  can tell that no native measure is available.
  """

  def __init__(self, algorithm):
    super().__init__(algorithm=algorithm)

  def forward(self, loader):
    """Compute the native measure for every example yielded by ``loader``.

    Args:
      loader: iterable of batches. Each batch is indexed with ``'x'`` to get
        the input tensor, which is moved to ``self.algorithm.device``.

    Returns:
      The per-batch results from ``get_native_measure``, concatenated along
      dim 0. If ``loader`` yields no batches, an empty 1-D tensor on
      ``self.algorithm.device`` is returned instead of letting ``torch.cat``
      fail on an empty list.
    """
    device = self.algorithm.device
    measures = [self.get_native_measure(batch['x'].to(device))
                for batch in loader]
    if not measures:
      # torch.cat raises on an empty sequence; an empty loader has no measures.
      return torch.empty(0, device=device)
    return torch.cat(measures, dim=0)

  def get_native_measure(self, x):
    """Return the algorithm's native measure for the input batch ``x``.

    Args:
      x: input tensor whose first dimension is the batch size.

    Returns:
      ``self.algorithm.uncertainty(x)`` if the algorithm has an
      ``uncertainty`` attribute. Otherwise, a 1-D tensor of length
      ``x.shape[0]`` filled with NaN, on ``x``'s device.
    """
    if hasattr(self.algorithm, 'uncertainty'):
      return self.algorithm.uncertainty(x)
    # NaN only exists in floating dtypes. Filling an integer tensor (e.g. raw
    # uint8 images) with NaN raises, so fall back to the default float dtype.
    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    return x.new_full((x.shape[0],), np.nan, dtype=dtype)

if __name__ == "__main__":
  # Nothing to run: this module only provides the Native measure class.
  pass
