# coding: utf-8

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class _Base(nn.Module):
	"""
	Shared machinery for the codebook quantizers below.

	The learnable codebook ``raw_codebook`` has shape codebook_size x channels x 1 x 1,
	so a 1x1 ``F.conv2d`` with it yields the dot product between every feature
	vector and every codeword. A forward pre-hook (``format_codebook``) exposes the
	codebook used in the current forward pass as ``self.codebook`` (L2-normalized
	when normalize=True). Subclasses implement ``lookup``.

	Args:
		channels: feature size along dim 1 of the input.
		codebook_size: number of codewords.
		metric: 'dot' or 'euclid'; stored for subclasses to choose how to score
			feature/codeword pairs.
		normalize: L2-normalize features and codewords (not allowed with 'euclid').
		init_scale: positive scale of the random codebook init when normalize=False;
			defaults to 1/sqrt(channels). Stored as ``self.init_scale`` for subclasses.
		random_leak: in training mode, blend each quantized value with the input
			feature using per-element weights drawn from U[0, 1).
	"""
	def __init__(self, channels, codebook_size, metric='dot', normalize=False, init_scale=None, random_leak=False):
		super().__init__()
		assert metric in ['dot','euclid'], 'metric must be dot or euclid.'
		assert (metric!='euclid') or (not normalize), 'metric=euclid cannot be normalized.'
		self.metric = metric
		raw_codebook = torch.randn(codebook_size,channels,1,1)
		if init_scale is None:
			init_scale = 1/math.sqrt(channels)
		assert init_scale>0, 'init_scale must be positive.'
		self.init_scale = init_scale
		self.normalize = normalize
		if normalize:
			# Start on the unit sphere; format_codebook re-normalizes on every forward anyway.
			raw_codebook = F.normalize(raw_codebook, p=2.0, dim=1)
		else:
			raw_codebook = init_scale*raw_codebook
		self.register_parameter('raw_codebook', nn.Parameter(raw_codebook, requires_grad=True))
		self.random_leak = random_leak
		self.register_forward_pre_hook(self.format_codebook)

	def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
		"""
		Load parameters, tolerating checkpoints written by older versions of this class.

		Older versions registered ``codebook`` as a duplicate alias of ``raw_codebook``
		after the first forward pass (normalize=False), so their checkpoints carry an
		extra '<prefix>codebook' entry. Drop it so strict loading succeeds.
		"""
		state_dict.pop(prefix+'codebook', None)
		super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

	@staticmethod
	def format_codebook(module, inputs):
		"""
		Forward pre-hook: publish the codebook for this forward pass as ``module.codebook``.

		The published value is always a plain tensor (never an nn.Parameter) that
		stays differentiable w.r.t. ``raw_codebook``. Assigning the Parameter itself
		would make nn.Module.__setattr__ register a second parameter named
		'codebook', which then shows up in state_dict() after the first forward.
		"""
		# A module unpickled from an older version may still hold that stale
		# alias; remove it, otherwise assigning a plain tensor below would raise.
		module._parameters.pop('codebook', None)
		codebook = module.raw_codebook
		if module.normalize:
			codebook = F.normalize(codebook, p=2.0, dim=1)
		else:
			codebook = codebook.view_as(codebook) # Plain-tensor alias, gradients still reach raw_codebook.
		module.codebook = codebook

	def quantize(self, feature):
		"""
		Look up codes and map them back to feature space.

		quantized[b,:,h,w] = sum_k code[b,k,h,w] * codebook[k], computed as a 1x1
		convolution with the transposed codebook.

		Returns:
			(quantized, code, reg_loss, stats), where the last three come from ``lookup``.
		"""
		code,reg_loss,stats = self.lookup(feature)
		quantized = F.conv2d(code, self.codebook.transpose(0,1))
		return quantized,code,reg_loss,stats

	def lookup(self, feature):
		"""
		Map ``feature`` to a onehot-like "code" over the codebook (dim 1).

		Must be implemented by subclasses; must return ``(code, reg_loss, stats)``
		where reg_loss and stats are dicts.
		"""
		raise NotImplementedError('Subclasses of _Base must implement lookup().')

	@staticmethod
	def hard_lookup(similarity):
		"""
		Lookup the argmax entry along dim 1 and return it as a onehot vector.

		similarity: batch_size x codebook_size x * scores.
		Returns:
			code: onehot tensor with the same shape and dtype as ``similarity``.
			max_similarity: batch_size x (product of trailing dims); the winning scores,
				still attached to the autograd graph of ``similarity``.
		"""
		B,K,*_ = similarity.size()
		max_similarity,code = similarity.reshape(B,K,-1).max(dim=1)
		# Match the input dtype so downstream F.conv2d works for half/double models too.
		code = F.one_hot(code, num_classes=K).transpose(1,2).reshape_as(similarity).to(similarity.dtype)
		return code,max_similarity

	@staticmethod
	def code_coverage(hard_code):
		"""
		Codebook-usage statistics of a onehot code.

		hard_code: batch_size x codebook_size x * (expected onehot along dim 1).
		Returns:
			code_freq: codebook_size; per-sample usage frequency of each code,
				averaged over the batch.
			code_entropy: scalar; entropy (nats) of each sample's usage
				frequencies, averaged over the batch.
			code_coverage: scalar; number of distinct codes used per sample,
				averaged over the batch.
		"""
		B,K,*_ = hard_code.size()
		code_freq = hard_code.reshape(B,K,-1).mean(dim=-1)
		# Unused codes contribute 0: log(1)=0 avoids the 0*log(0)=nan term.
		code_entropy = -(code_freq*code_freq.masked_fill(code_freq==0.0, 1.0).log()).sum(dim=1).mean(dim=0)
		code_coverage = (code_freq>0.0).float().sum(dim=1).mean(dim=0)
		code_freq = code_freq.mean(dim=0)
		return code_freq,code_entropy,code_coverage

	def forward(self, feature):
		"""
		Quantize ``feature``.

		feature: batch_size x channels x height x width (expected 4-D, since the
			codebook is applied with F.conv2d).
		Returns:
			(quantized, code, reg_loss, stats) as produced by ``self.quantize``.
		"""
		if self.normalize:
			feature = F.normalize(feature, p=2.0, dim=1)

		quantized,code,reg_loss,stats = self.quantize(feature)
		if self.training and self.random_leak:
			# Per-element random mix of the (possibly normalized) input and its quantization.
			leak = torch.rand_like(feature)
			quantized = leak*feature + (1.0-leak)*quantized
		return quantized,code,reg_loss,stats

class SoftmaxVAE(_Base):
	"""
	Quantizer whose code is a softmax over codeword similarities (metric='dot' only).

	In training mode the code is ``soft_lookup(logits)`` (a plain softmax here;
	subclasses may override it), so gradients reach features and codebook. In
	eval mode the code is the onehot argmax, i.e. hard quantization.

	Args (in addition to those of _Base):
		regularization: None, 'knn_l2', 'knn_ce', or 'global_perplexity'.
		num_neighbors: number of nearest samples per codeword used by the knn_*
			regularizers (forced to 0 when regularization is None).
	"""
	def __init__(self, *args, regularization=None, num_neighbors=8, **kwargs):
		super().__init__(*args, **kwargs)
		assert self.metric!='euclid', 'metric=euclid is not supported.'
		assert regularization in [None,'knn_l2','knn_ce','global_perplexity'], \
			'regularization must be None, knn_l2, knn_ce, or global_perplexity.'
		self.regularization = '' if regularization is None else regularization
		self.num_neighbors = 0 if regularization is None else num_neighbors
		if self.normalize:
			# Learnable inverse temperature (initialized to init_scale) rescaling the
			# cosine similarities, which lie in [-1, 1], before the softmax.
			self.register_parameter('log_inv_temperature', nn.Parameter(torch.tensor(self.init_scale).log()))

	def soft_lookup(self, logits):
		"""Softmax over the codebook dimension (dim 1)."""
		return F.softmax(logits, dim=1)

	def lookup(self, feature):
		"""
		Compute the code for ``feature`` plus regularization losses and usage stats.

		Returns:
			code: soft code while training, onehot argmax otherwise.
			reg_loss: dict with at most one entry, keyed by the regularizer
				('knn_l2'/'knn_ce', or 'neg_global_perplexity').
			stats: dict with entropy, code_freq, code_entropy, code_coverage.
		"""
		logits = F.conv2d(feature, self.codebook)
		if self.normalize:
			logits = logits*self.log_inv_temperature.exp()

		hard_code,_ = self.hard_lookup(logits)
		code = self.soft_lookup(logits) if self.training else hard_code

		# NOTE(review): in eval mode ``code`` is the onehot argmax, so the 'entropy'
		# stat is then -log softmax at the argmax rather than the softmax entropy.
		probs_flatten,log_probs_flatten,entropy = self._flatten_probs(logits, code)
		code_freq,code_entropy,code_coverage = self.code_coverage(hard_code)
		stats = dict(entropy=entropy,code_freq=code_freq,code_entropy=code_entropy,code_coverage=code_coverage)

		reg_loss = dict()
		if 'knn' in self.regularization:
			knn_loss = self._knn_loss(probs_flatten,log_probs_flatten)
			reg_loss[self.regularization] = knn_loss
		elif self.regularization=='global_perplexity':
			# Dividing by codebook_size maps perplexity in [1, K] to [1/K, 1].
			global_perplexity = self._global_entropy(probs_flatten,log_probs_flatten).exp()\
									/ probs_flatten.size(0)
			reg_loss['neg_global_perplexity'] = -global_perplexity
		return code,reg_loss,stats

	@staticmethod
	def _flatten_probs(logits, probs):
		"""
		Move the codebook dimension first and flatten the rest.

		logits,probs: batch_size x codebook_size x height x width
		Returns:
			probs_flatten: codebook_size x N (N = batch_size*height*width).
			log_probs_flatten: codebook_size x N; log-softmax of ``logits`` over the codebook.
			entropy: N; -(probs*log_probs) summed over the codebook (the true softmax
				entropy only when ``probs`` is the softmax of ``logits``).
		"""
		K = logits.size(1)
		probs_flatten = probs.transpose(0,1).reshape(K,-1)
		log_probs = F.log_softmax(logits.transpose(0,1), dim=0)
		# reshape, not view: the input was transposed and may be non-contiguous.
		log_probs_flatten = log_probs.reshape(K,-1)
		entropy = -(probs_flatten*log_probs_flatten).sum(dim=0)
		return probs_flatten,log_probs_flatten,entropy

	@staticmethod
	def _mean_log_probs(log_probs_flatten):
		"""
		log of the per-codeword probability averaged over samples, computed stably.

		log_probs_flatten: codebook_size x others
		"""
		log_N = math.log(log_probs_flatten.size(1))
		log_mean_probs = log_probs_flatten.logsumexp(dim=1) - log_N
		return log_mean_probs

	def _knn_loss(self, probs_flatten,log_probs_flatten):
		"""
		For each codeword, average its num_neighbors smallest per-sample distances.

		Distance to codeword k is ||p - onehot_k||^2 = ||p||^2 - 2 p_k + 1 ('knn_l2')
		or -log p_k ('knn_ce', cross entropy with the onehot). The loss is small when
		every codeword is the (soft) choice of at least num_neighbors samples.

		(log_)probs_flatten: codebook_size x others
		"""
		if 'l2' in self.regularization:
			loss = 1.0 + probs_flatten.pow(2).sum(dim=0,keepdim=True) -2*probs_flatten
		elif 'ce' in self.regularization:
			loss = -log_probs_flatten
		# Clamp k so small inputs (fewer samples than num_neighbors) don't crash topk.
		k = min(self.num_neighbors, loss.size(1))
		loss = loss.topk(k=k,dim=1,largest=False)[0].mean()
		return loss

	def _global_entropy(self, probs_flatten, log_probs_flatten):
		"""
		Entropy of the codebook-usage distribution averaged over all samples.

		probs_flatten,log_probs_flatten: codebook_size x others
		"""
		mean_probs = probs_flatten.mean(dim=1)
		log_mean_probs = self._mean_log_probs(log_probs_flatten)
		global_entropy = -(mean_probs*log_mean_probs).sum()
		return global_entropy
	

class GumbelVAE(SoftmaxVAE):
	"""
	SoftmaxVAE whose training-time code is a Gumbel-softmax sample.

	With hard=True the sample is onehot in the forward pass (see F.gumbel_softmax).
	The temperature is 1 unless min_gumbel_temperature, max_gumbel_temperature and
	gumbel_temperature_anneal_rate are ALL given; then
		tau = max(max_gumbel_temperature * anneal_rate**count, min_gumbel_temperature),
	where count is incremented by a full backward hook each time gradients are
	computed for this module.
	"""
	def __init__(self, *args, hard=False,
					gumbel_temperature_anneal_rate=None,
					min_gumbel_temperature=None,
					max_gumbel_temperature=None,
					**kwargs):
		super().__init__(*args, **kwargs)
		self.hard = hard
		schedule = (min_gumbel_temperature, max_gumbel_temperature, gumbel_temperature_anneal_rate)
		# NOTE: attribute is spelled 'aneal' (sic); kept as-is since other code may read it.
		self.aneal_gumbel_temperature = all(value is not None for value in schedule)
		if not self.aneal_gumbel_temperature:
			return
		schedule_buffers = (
			('min_gumbel_temperature', min_gumbel_temperature),
			('max_gumbel_temperature', max_gumbel_temperature),
			('gumbel_temperature_anneal_rate', gumbel_temperature_anneal_rate),
			('gumbel_temperature_anneal_count', 0),
		)
		for buffer_name,initial_value in schedule_buffers:
			self.register_buffer(buffer_name, torch.tensor(initial_value))
		self.register_full_backward_hook(self.update_gumbel_anneal_count)

	@staticmethod
	def update_gumbel_anneal_count(module, grad_input, grad_output):
		"""Full backward hook: advance the annealing step counter by one."""
		module.gumbel_temperature_anneal_count.add_(1)

	def soft_lookup(self, logits):
		"""Draw a Gumbel-softmax sample over the codebook dimension (dim 1)."""
		if not self.aneal_gumbel_temperature:
			return F.gumbel_softmax(logits, dim=1, hard=self.hard, tau=1)
		decayed = self.max_gumbel_temperature*(
						self.gumbel_temperature_anneal_rate**self.gumbel_temperature_anneal_count)
		tau = torch.maximum(decayed, self.min_gumbel_temperature)
		return F.gumbel_softmax(logits, dim=1, hard=self.hard, tau=tau)
	

class VQVAE(_Base):
	"""
	Hard vector quantizer: every feature vector is replaced by its best-matching codeword.

	Matching maximizes 2<f,c> - |f|^2 - |c|^2 (= -|f-c|^2) for metric='euclid', and
	2<f,c> - 2 for metric='dot' (equal to -|f-c|^2 when f and c are unit-norm).

	Args (in addition to those of _Base):
		grad_estimator: how gradients reach ``feature`` through the hard lookup:
			'straight-through' (identity) or 'rotate' (a rotation plus rescaling
			that maps the feature onto its codeword; see ``_rotate``).
	"""
	def __init__(self, *args, grad_estimator='straight-through', **kwargs):
		super().__init__(*args, **kwargs)
		assert grad_estimator in ['straight-through', 'rotate'], 'grad_estimator must be straight-through or rotate.'
		self.grad_estimator = grad_estimator

	def quantize(self, feature):
		"""
		Quantize, add the codebook-side loss, and attach the gradient estimator.

		The added loss pulls codewords toward the detached features:
		'code_location_loss_l2' (MSE averaged over all elements) for metric='euclid',
		otherwise 'code_location_loss_dot' (negative dot product summed over
		channels, averaged over the remaining dims).
		"""
		quantized,code,reg_loss,stats = super().quantize(feature)
		target = feature.detach()
		if self.metric=='euclid':
			reg_loss['code_location_loss_l2'] = F.mse_loss(quantized,target,reduction='mean')
		else:
			reg_loss['code_location_loss_dot'] = -(quantized*target).sum(dim=1).mean()
		if self.grad_estimator=='straight-through':
			# Forward value: quantized; backward: identity gradient to feature.
			quantized = quantized + feature - target
		elif self.grad_estimator=='rotate':
			quantized = self._rotate(quantized, feature)
		return quantized,code,reg_loss,stats

	def _rotate(self, quantized, feature):
		"""
		Return scale * R @ feature, linear in ``feature`` with R and scale held constant.

		With unit directions c (feature) and q (codeword), r = normalize(c+q) and
		R = I - 2 r r^T + 2 q c^T maps c onto q; scale = |q|/|feature| (1.0 when
		normalize=True). R and scale are built from detached tensors, so the forward
		value is the codeword (up to rounding) and gradients reach ``feature`` only
		through this linear map.
		"""
		q_fixed = quantized.detach()
		c_fixed = feature.detach()
		if self.normalize:
			q_unit,c_unit,scale = q_fixed,c_fixed,1.0
		else:
			q_norm = torch.linalg.vector_norm(q_fixed, ord=2, dim=1, keepdim=True)
			c_norm = torch.linalg.vector_norm(c_fixed, ord=2, dim=1, keepdim=True)
			q_unit = q_fixed/q_norm.clamp_min(1e-12)
			c_unit = c_fixed/c_norm.clamp_min(1e-12)
			scale = q_norm/c_norm
		r = F.normalize(c_unit+q_unit, p=2.0, dim=1)
		reflected = feature - 2*r*(r*feature).sum(dim=1,keepdim=True)
		return scale*(reflected + 2*q_unit*(c_unit*feature).sum(dim=1,keepdim=True))

	def lookup(self, feature):
		"""
		Pick the best codeword per position and compute the commitment loss.

		Scores use a detached codebook, so the commitment loss (negated best score,
		averaged) sends gradients to ``feature`` only.

		Returns:
			(code, reg_loss, stats): onehot code, {'commitment_loss_l2' or
			'commitment_loss_dot': loss}, and code-usage statistics.
		"""
		frozen_codebook = self.codebook.detach()
		scores = F.conv2d(feature, frozen_codebook)
		if self.metric=='euclid':
			feature_sq_norm = feature.pow(2).sum(dim=1, keepdim=True)
			codeword_sq_norm = frozen_codebook.pow(2).sum(dim=1, keepdim=False)
			scores = 2*scores - feature_sq_norm - codeword_sq_norm
		else:
			scores = 2*scores - 2.0
		code,best_scores = self.hard_lookup(scores)
		loss_key = 'commitment_loss_l2' if self.metric=='euclid' else 'commitment_loss_dot'
		reg_loss = {loss_key: -best_scores.mean()}
		code_freq,code_entropy,code_coverage = self.code_coverage(code)
		stats = dict(code_freq=code_freq,code_entropy=code_entropy,code_coverage=code_coverage)
		return code,reg_loss,stats

class Dummy(nn.Module):
	"""
	Identity stand-in for a quantizer: no quantization, e.g. as a reconstruction topline.

	Returns the same 4-tuple layout as the quantizers, with None for code,
	reg_loss and stats.
	"""
	def forward(self, x):
		return (x,) + (None,)*3