#!/usr/bin/env python3
import torch.nn.functional as F
import torch
import numpy as np
from uimnet.ensembles.base import BaseEnsemble
from pathlib import Path
from uimnet import utils
from uimnet import numerics



class Bagging(BaseEnsemble):
  """Ensemble that combines the outputs of its members at inference time.

  This class performs no training of its own: ``setup_optimizers`` is a
  no-op and ``update`` only returns NaN placeholders. Member outputs come from
  ``self.forward_members`` (inherited; not defined in this class).
  """

  def setup_optimizers(self):
    # No ensemble-level optimizer is created by this class.
    pass

  def entropy_(self, p):
    """Return the Shannon entropy (natural log, i.e. nats) of ``p`` along dim 1.

    Args:
      p: tensor of probabilities. The sum is taken over dim 1, so each slice
        along dim 1 is assumed to be a distribution over classes
        (e.g. shape (batch, classes) — confirm against callers).

    Returns:
      Tensor with dim 1 reduced, holding ``-sum(p * log p)``.

    Entries with ``p == 0`` contribute 0 (the limit of p*log p) rather than
    ``0 * -inf = nan``.
    NOTE: ``torch.log(p)`` is still evaluated everywhere, so gradients at
    ``p == 0`` are nan; this method is meant for forward-only use.
    """
    # Match p's dtype and device so torch.where never mixes dtypes
    # (a default float32 zero would clash with float16/float64 inputs on
    # PyTorch versions without type promotion in torch.where).
    zero = p.new_zeros(())
    log_p = torch.where(p == 0, zero, torch.log(p))
    return -(p * log_p).sum(1)

  def update(self, x, y, epoch=None):
    """No-op training step; returns NaN placeholders for the logged metrics."""
    return {'loss': np.nan, 'cost': np.nan}

  def _forward(self, x):
    """Return the ensemble's combined log-predictions for ``x``.

    Member logits are stacked along a new dim 1 and combined by
    ``numerics.log_marginalization_from_logits``. The result is returned on
    the device ``x`` was on when passed in.
    """
    # Remember the caller's device *before* moving x; reassigning x first
    # would make the final .to() a no-op.
    input_device = x.device
    x = x.to(self.device)
    all_logits = torch.stack(self.forward_members(x), dim=1)
    return numerics.log_marginalization_from_logits(all_logits).to(input_device)

  def uncertainty(self, x):
    """Return a per-sample disagreement score between ensemble members.

    Computes ``H[mean_m p_m] - mean_m H[p_m]``, where ``p_m`` is the softmax
    over dim 1 of member m's output and ``H`` is ``entropy_``. (This is the
    mutual-information / BALD-style decomposition with uniform member
    weights.) The result is returned on the device ``x`` was on when passed in.
    """
    input_device = x.device
    # Same device handling as _forward, so members see inputs on self.device.
    x = x.to(self.device)

    member_probs = [logits.softmax(1) for logits in self.forward_members(x)]
    mean_probs = torch.stack(member_probs).mean(0)

    mean_member_entropy = torch.stack(
        [self.entropy_(probs) for probs in member_probs]).mean(0)
    measure = self.entropy_(mean_probs) - mean_member_entropy
    return measure.to(input_device)

if __name__ == '__main__':
  # Intentionally empty: running this module directly does nothing.
  ...
