# coding: utf-8

import torch
import torch.nn.functional as F
from ignite.exceptions import NotComputableError
from ignite.metrics import Metric
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced

class CodePerplexity(Metric):
	"""Normalized perplexity of codebook usage, accumulated over batches.

	Counts how often each of the K codebook entries is selected, then reports
	exp(H(p)) / K, where p is the empirical code-usage distribution and H its
	entropy (natural log). The value lies in [1/K, 1]: 1 when all codes are
	used equally often, 1/K when a single code receives all the counts.
	"""

	def __init__(self, *args, **kwargs):
		# Give the attribute a value before the base constructor runs, in case
		# it calls reset() or otherwise touches metric state.
		self._counts = 0.0
		super().__init__(*args, **kwargs)

	@reinit__is_reduced
	def reset(self):
		"""Clear the accumulated per-code counts."""
		# Scalar placeholder; becomes a length-K tensor on the first update().
		self._counts = 0.0
		super().reset()

	@reinit__is_reduced
	def update(self, onehot):
		"""Accumulate per-code counts from one batch.

		onehot: batch_size x codebook_size x * tensor. Its values are summed
		over every axis except axis 1, giving one count per code.
		"""
		K = onehot.size(1)
		# Detach so the running counts never hold on to an autograd graph
		# (e.g. when onehot comes from a straight-through estimator); otherwise
		# the graph would grow with every batch.
		per_code = onehot.detach().float().transpose(0, 1)  # BxKx* -> KxBx*
		self._counts += per_code.reshape(K, -1).sum(dim=1)  # KxBx* -> K

	@sync_all_reduce("_counts:SUM")
	def compute(self):
		"""Return exp(entropy of the code-usage distribution) / K as a float.

		Raises:
			NotComputableError: if update() has not been called since the last
				reset, or if the accumulated counts sum to zero.
		"""
		if not torch.is_tensor(self._counts):
			raise NotComputableError(
				f"{type(self).__name__} must have at least one example before it can be computed."
			)
		K = self._counts.size(0)
		total = self._counts.sum()
		if total.item() == 0:
			# The usage distribution is undefined; dividing would produce NaN.
			raise NotComputableError(
				f"{type(self).__name__} received no code assignments; perplexity is undefined."
			)
		freqs = self._counts / total
		# Replace zero frequencies by 1 before the log so 0 * log(0) contributes 0.
		entropy = -(freqs * freqs.masked_fill(freqs == 0.0, 1.0).log()).sum()
		return entropy.exp().item() / K

class CodeCoverage(CodePerplexity):
	"""Fraction of codebook entries that received a positive count.

	Inherits CodePerplexity's per-code counting (reset/update); only the
	final reduction in compute() differs.
	"""

	@sync_all_reduce("_counts:SUM")
	def compute(self):
		"""Return (#codes with a positive count) / K as a float in [0, 1].

		Raises:
			NotComputableError: if update() has not been called since the last
				reset.
		"""
		if not torch.is_tensor(self._counts):
			# Before any update the counts are still the scalar placeholder and
			# have no codebook dimension to measure.
			raise NotComputableError(
				f"{type(self).__name__} must have at least one example before it can be computed."
			)
		K = self._counts.size(0)
		num_used = (self._counts > 0).float().sum().item()
		return num_used / K
