# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for VAEBM. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------
'''Energy networks'''

from . import utils
import math
import torch

from torch import nn
from torch.nn import functional as F
from models.neural_operations import Conv2D
from models.norms import get_act, get_norm


def Lip_swish(x):
	"""Swish activation ``x * sigmoid(x)`` divided by the constant 1.1.

	NOTE(review): the 1.1 divisor presumably keeps the activation (roughly)
	1-Lipschitz, as in the "LipSwish" activation; it is not used elsewhere in
	this file.
	"""
	swish = x * torch.sigmoid(x)
	return swish / 1.1


def get_timestep_embedding(
	timesteps: torch.Tensor,
	embedding_dim: int,
	flip_sin_to_cos: bool = False,
	downscale_freq_shift: float = 1,
	scale: float = 1,
	max_period: int = 10000,
):
	"""Create sinusoidal timestep embeddings (as in Denoising Diffusion Probabilistic Models).

	:param timesteps: 1-D tensor with one index per batch element; values may be fractional.
	:param embedding_dim: width of each output embedding.
	:param flip_sin_to_cos: if True, the cosine half comes before the sine half.
	:param downscale_freq_shift: subtracted from ``embedding_dim // 2`` in the
		denominator of the frequency exponent.
	:param scale: factor applied to the phase before taking sin/cos.
	:param max_period: controls the minimum frequency of the embeddings.
	:return: an [N x embedding_dim] tensor; when ``embedding_dim`` is odd, the last
		column is zero padding.
	"""
	assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

	half = embedding_dim // 2
	steps = torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
	# Geometric frequency ladder from 1 down towards 1 / max_period.
	freqs = torch.exp(-math.log(max_period) * steps / (half - downscale_freq_shift))

	phase = scale * (timesteps[:, None].float() * freqs[None, :])

	sin_part = torch.sin(phase)
	cos_part = torch.cos(phase)
	if flip_sin_to_cos:
		emb = torch.cat([cos_part, sin_part], dim=-1)
	else:
		emb = torch.cat([sin_part, cos_part], dim=-1)

	# Odd widths get one trailing zero column.
	if embedding_dim % 2 == 1:
		emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
	return emb


class ResBlock(nn.Module):
	"""Pre-activation residual block: ``nonl -> conv3x3 -> nonl -> conv3x3`` plus a skip path.

	The skip path is the identity when channel counts match, otherwise a bias-free
	1x1 ``Conv2D`` projection. When ``downsample`` is set, the summed output is
	2x2 average-pooled.

	Args:
		in_channel: channels of the incoming feature map.
		out_channel: channels produced by the block.
		downsample: whether to average-pool the output by a factor of 2.
		data_init: forwarded to every ``Conv2D``.
		nonl: activation name passed to ``get_act``.
	"""

	def __init__(self, in_channel, out_channel, downsample=False, data_init=True, nonl=None):
		super().__init__()

		self.nonl = get_act(nonl)

		conv3x3_kwargs = dict(padding=1, bias=True, data_init=data_init)
		self.conv1 = Conv2D(in_channel, out_channel, 3, **conv3x3_kwargs)
		self.conv2 = Conv2D(out_channel, out_channel, 3, **conv3x3_kwargs)

		# A projection is only needed when the residual must change channel count.
		if in_channel == out_channel:
			self.skip = None
		else:
			self.skip = nn.Sequential(
				Conv2D(in_channel, out_channel, 1, bias=False, data_init=data_init))

		self.downsample = downsample

	def forward(self, input, temb=None):
		"""Apply the block to ``input``; ``temb`` is accepted for API compatibility but unused."""
		h = self.conv1(self.nonl(input))
		h = self.conv2(self.nonl(h))

		residual = input if self.skip is None else self.skip(input)
		h = h + residual

		return F.avg_pool2d(h, 2) if self.downsample else h


class EBM_WideResNet(nn.Module):
	"""Wide-ResNet style energy network mapping an image batch to one scalar per sample.

	The architecture is selected by ``dataset``:

	* ``'CELEBA'``: 5 stages of 6 ResBlocks with widths 128, 256, 256, 256, 512
	  (each times ``widen_factor``); the first four stages end with a 2x downsample.
	* ``'CIFAR10'``: 4 stages of 8 ResBlocks with widths 128, 256, 256, 256
	  (each times ``widen_factor``); the first three stages end with a 2x downsample.

	Args:
		nc: number of input image channels.
		widen_factor: integer multiplier applied to every channel width.
		data_init: forwarded to every ``Conv2D``.
		dataset: ``'CELEBA'`` or ``'CIFAR10'``.
		nonl: activation name passed to ``get_act``.

	Raises:
		NotImplementedError: if ``dataset`` is not one of the supported names.
	"""

	# Stage specs per dataset: (base_width, num_blocks, downsample_after_last_block).
	# Widths are multiplied by ``widen_factor``. Building the blocks from this table
	# yields the same block order (and thus the same ``blocks.N`` state-dict keys)
	# as the previously hand-written lists.
	_STAGES = {
		'CELEBA': ((128, 6, True), (256, 6, True), (256, 6, True), (256, 6, True), (512, 6, False)),
		'CIFAR10': ((128, 8, True), (256, 8, True), (256, 8, True), (256, 8, False)),
	}

	def __init__(self, nc=3, widen_factor=1, data_init=True, dataset=None, nonl=None):
		super().__init__()

		self.nonl = get_act(nonl)

		if dataset not in self._STAGES:
			raise NotImplementedError(f'EBM_WideResNet: unsupported dataset {dataset!r}')

		width = widen_factor * 128
		# Stem conv: lift the image to the first stage's width. This must track
		# widen_factor (the CELEBA path used to hard-code 128 here, which broke for
		# widen_factor != 1).
		self.conv1 = Conv2D(nc, width, 3, padding=1, bias=True, data_init=data_init)

		blocks = []
		for base_width, num_blocks, downsample_at_end in self._STAGES[dataset]:
			out_width = widen_factor * base_width
			for idx in range(num_blocks):
				is_last = idx == num_blocks - 1
				blocks.append(ResBlock(width, out_width, nonl=nonl, data_init=data_init,
									   downsample=downsample_at_end and is_last))
				width = out_width
		self.blocks = nn.ModuleList(blocks)

		# Head input equals the final stage width (was hard-coded to 512 for CELEBA).
		self.linear = nn.Linear(width, 1)

	def forward(self, input, temb=None):
		"""Return the energy of each sample in ``input`` as a (batch, 1) tensor.

		``temb`` is passed to every ResBlock. Timestep conditioning is not wired in,
		so it does not currently affect the output.
		"""
		out = self.conv1(input)
		for block in self.blocks:
			out = block(out, temb)

		out = self.nonl(out)
		# Sum-pool over all spatial positions -> (batch, channels).
		out = out.view(out.shape[0], out.shape[1], -1).sum(2)
		return self.linear(out)
	

@utils.register_model(name='vaebmwrn')
class MyVAEBM(nn.Module):
	"""Energy model registered as ``'vaebmwrn'``, wrapping an ``EBM_WideResNet``.

	``forward`` reshapes flattened inputs to images and scores them.
	``spectral_norm_parallel`` returns a batched power-iteration estimate of the
	summed spectral norms of all ``Conv2D`` weights.
	"""

	def __init__(self, config):
		super(MyVAEBM, self).__init__()

		self.config = config
		self.data_dim = config.data.channels * config.data.image_size * config.data.image_size

		# Pass the configured channel count: forward() reshapes inputs to
		# config.data.channels, so the stem conv must expect that many channels
		# (previously it silently defaulted to nc=3).
		self.ebm = EBM_WideResNet(
			nc=config.data.channels,
			dataset=config.data.dataset,
			nonl=config.model.nonl,
			widen_factor=config.model.widen_factor,
		)

		for m in self.modules():
			if isinstance(m, nn.Linear):
				nn.init.normal_(m.weight, 0., .01)
				if m.bias is not None:
					nn.init.zeros_(m.bias)
			elif isinstance(m, nn.Conv2d):
				nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear')
				if m.bias is not None:
					nn.init.zeros_(m.bias)

		# Plain list (not a ModuleList): these layers are already registered under self.ebm.
		self.all_conv_layers = [m for m in self.modules() if isinstance(m, Conv2D)]

		# Power-iteration state for spectral_norm_parallel, keyed by weight-matrix shape.
		self.sr_u = {}
		self.sr_v = {}
		self.num_power_iter = 4

	def forward(self, input, temb=None):
		"""Score a batch of flattened images; returns the output of ``self.ebm``.

		When ``config.training.augment_t`` is set, the last feature of ``input`` is
		split off as a timestep value. It is not forwarded to ``self.ebm``.
		"""
		if self.config.training.augment_t:
			input, temb = torch.split(input, [self.data_dim, 1], dim=-1)

		input = input.reshape(-1, self.config.data.channels, self.config.data.image_size, self.config.data.image_size)
		return self.ebm(input)

	def spectral_norm_parallel(self):
		"""Return the sum of power-iteration spectral-norm estimates of all conv weights.

		Reads ``weight_normalized`` from each ``Conv2D``. That attribute is presumably
		populated by the layer's forward pass, so call this after the forward pass
		in each iteration. Layers whose flattened weights share a shape are
		processed together with one batched matmul.
		"""
		# Group (out_channels, -1) weight matrices by shape.
		weights = {}
		for layer in self.all_conv_layers:
			weight = layer.weight_normalized
			weight_mat = weight.view(weight.size(0), -1)
			weights.setdefault(weight_mat.shape, []).append(weight_mat)

		loss = 0
		for shape, mats in weights.items():
			w = torch.stack(mats, dim=0)  # (num_w, row, col)
			with torch.no_grad():
				num_iter = self.num_power_iter
				if shape not in self.sr_u:
					num_w, row, col = w.shape
					# Sample on the CPU as before, then move to the weights' device
					# (previously hard-coded to .cuda(), which failed on CPU and could
					# target the wrong GPU).
					self.sr_u[shape] = F.normalize(torch.ones(num_w, row).normal_(0, 1).to(w.device), dim=1, eps=1e-3)
					self.sr_v[shape] = F.normalize(torch.ones(num_w, col).normal_(0, 1).to(w.device), dim=1, eps=1e-3)
					# Run more iterations the first time, to converge from the random start.
					num_iter = 10 * self.num_power_iter

				for _ in range(num_iter):
					# The spectral norm of W equals u^T W v, where u and v are the first left and
					# right singular vectors; power iteration approximates them.
					self.sr_v[shape] = F.normalize(torch.matmul(self.sr_u[shape].unsqueeze(1), w).squeeze(1),
												   dim=1, eps=1e-3)  # bx1xr * bxrxc --> bx1xc --> bxc
					self.sr_u[shape] = F.normalize(torch.matmul(w, self.sr_v[shape].unsqueeze(2)).squeeze(2),
												   dim=1, eps=1e-3)  # bxrxc * bxcx1 --> bxrx1  --> bxr

			# Gradients flow through w only; u and v are treated as constants.
			sigma = torch.matmul(self.sr_u[shape].unsqueeze(1), torch.matmul(w, self.sr_v[shape].unsqueeze(2)))
			loss += torch.sum(sigma)

		return loss