import torch
import torch.nn as nn
import torch.nn.functional as F
import torchinfo
from model.Conv import *


class MNISTVAEEncoder(nn.Module):
	"""Convolutional VAE encoder producing ``(mu, logvar)``.

	Two conv stages (5x5 conv -> BatchNorm without affine params ->
	LeakyReLU -> 2x2 max-pool) map ``channel`` input planes to 8 and then 4
	feature maps. The flattened result feeds two independent linear heads.
	The ``4 * 7 * 7`` head width corresponds to a 28x28 spatial input
	(two 2x poolings -> 7x7).

	Args:
		channel: Number of input channels.
		rep_dim: Output width of each linear head.
	"""
	def __init__(self, channel, rep_dim):
		super().__init__()
		self.rep_dim = rep_dim
		# Not used by forward (the pooling layers live inside self.enc);
		# kept so the registered module tree stays the same.
		self.pool = nn.MaxPool2d(2, 2)

		layers = []
		for in_ch, out_ch in ((channel, 8), (8, 4)):
			layers.extend([
				nn.Conv2d(in_ch, out_ch, 5, padding=2),
				nn.BatchNorm2d(out_ch, eps=1e-04, affine=False),
				nn.LeakyReLU(inplace=True),
				nn.MaxPool2d(2, 2),
			])
		# Sequential indices 0..7 follow the original layer order, so
		# state_dict keys (enc.0.*, enc.1.*, enc.4.*, enc.5.*) are unchanged.
		self.enc = nn.Sequential(*layers)

		flat_dim = 4 * 7 * 7
		self.fc1 = nn.Linear(flat_dim, self.rep_dim)
		self.fc2 = nn.Linear(flat_dim, self.rep_dim)

	def forward(self, x):
		"""Encode ``x`` and return ``(mu, logvar)``, each (batch, rep_dim)."""
		features = self.enc(x)
		flat = features.view(features.size(0), -1)
		return self.fc1(flat), self.fc2(flat)

class CIFARVAEEncoder(nn.Module):
	"""Convolutional VAE encoder producing ``(mu, logvar)``.

	Three conv stages (5x5 conv -> BatchNorm without affine params ->
	LeakyReLU -> 2x2 max-pool) map ``channel`` input planes to 32, 64 and
	then 128 feature maps. The flattened result feeds two independent linear
	heads. The ``128 * 4 * 4`` head width corresponds to a 32x32 spatial
	input (three 2x poolings -> 4x4).

	Args:
		channel: Number of input channels.
		rep_dim: Output width of each linear head (default 32).
	"""
	def __init__(self, channel, rep_dim=32):
		super().__init__()
		self.rep_dim = rep_dim

		widths = (channel, 32, 64, 128)
		layers = []
		for in_ch, out_ch in zip(widths[:-1], widths[1:]):
			layers.extend([
				nn.Conv2d(in_ch, out_ch, 5, padding=2),
				nn.BatchNorm2d(out_ch, eps=1e-04, affine=False),
				nn.LeakyReLU(inplace=True),
				nn.MaxPool2d(2, 2),
			])
		# Sequential indices 0..11 follow the original layer order, so
		# state_dict keys (enc.0.*, enc.4.*, enc.8.*, ...) are unchanged.
		self.enc = nn.Sequential(*layers)

		# Linear heads on the flattened feature map.
		flat_dim = 128 * 4 * 4
		self.fc1 = nn.Linear(flat_dim, self.rep_dim)
		self.fc2 = nn.Linear(flat_dim, self.rep_dim)

	def forward(self, x):
		"""Encode ``x`` and return ``(mu, logvar)``, each (batch, rep_dim)."""
		features = self.enc(x)
		flat = features.view(features.size(0), -1)
		return self.fc1(flat), self.fc2(flat)



class Estimation(nn.Module):
	"""Small MLP mapping a feature vector to soft mixture memberships.

	Layers: Linear(rep_dim -> 10) -> LeakyReLU(0.2) -> Dropout(0.5) ->
	Linear(10 -> n_gmm) -> Softmax over dim 1, so for a 2-D
	(batch, rep_dim) input each output row sums to 1.

	Args:
		rep_dim: Width of the input feature vector.
		n_gmm: Number of output columns (mixture components); default 2.
	"""
	def __init__(self, rep_dim, n_gmm=2):
		super().__init__()
		hidden = 10
		layers = [
			nn.Linear(rep_dim, hidden),
			nn.LeakyReLU(0.2, inplace=True),
			nn.Dropout(p=0.5),
			nn.Linear(hidden, n_gmm),
			nn.Softmax(dim=1),
		]
		self.est = nn.Sequential(*layers)

	def forward(self, x):
		"""Return membership probabilities for ``x`` (softmax over dim 1)."""
		gamma = self.est(x)
		return gamma

class RVAEBFA(nn.Module):
	"""VAE with an auxiliary soft-membership estimation head.

	The encoder yields ``(mu, logvar)``. A latent ``z`` is sampled with the
	reparameterisation trick and decoded to ``x_hat``. Two per-sample
	features comparing the flattened input and reconstruction are computed
	and each is standardised with its own batch statistics:

	* cosine similarity
	* summed squared error

	They are concatenated with ``z`` as ``[(1 - u) * z, u * cos, u * rec]``
	and passed to :class:`Estimation` to obtain ``gamma``.

	Args:
		channel: Number of input channels. ``1`` selects
			``MNISTVAEEncoder`` / ``MNISTDSADDecoder``; any other value
			selects ``CIFARVAEEncoder`` / ``CIFARDSADDecoder``.
		rep_dim: Latent dimensionality (default 32).
		n_gmm: Number of components output by the estimation head
			(default 4, the previously hard-coded value).
		u: Mixing weight. ``z`` is scaled by ``1 - u`` and the two
			reconstruction features by ``u`` (default 0.5, the previously
			hard-coded value).
	"""
	def __init__(self, channel, rep_dim=32, n_gmm=4, u=0.5):
		super(RVAEBFA, self).__init__()

		self.rep_dim = rep_dim
		self.n_gmm = n_gmm
		self.u = u

		if channel == 1:
			self.enc = MNISTVAEEncoder(channel, rep_dim)
			self.dec = MNISTDSADDecoder(channel, rep_dim, bias=True)
		else:
			self.enc = CIFARVAEEncoder(channel, rep_dim)
			self.dec = CIFARDSADDecoder(channel, rep_dim, bias=True)

		# Estimation input = z (rep_dim) + cosine feature (1) + reconstruction-error feature (1).
		self.est = Estimation(self.rep_dim + 2, n_gmm=self.n_gmm)

	def reparameterize(self, mu, logvar):
		"""Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
		std = torch.exp(0.5 * logvar)
		eps = torch.randn_like(std)
		return eps * std + mu

	@staticmethod
	def _batch_standardize(col, eps=1e-5):
		"""Standardise a (batch, C) tensor with its own per-column batch statistics.

		The result is identical to feeding ``col`` through a freshly built
		``BatchNorm1d`` in training mode (unit weight, zero bias), which is
		what this model previously did on every forward. This version does
		not allocate a module per call and is not pinned to the default CUDA
		device, so it runs on CPU and on any GPU. As with ``BatchNorm1d`` in
		training mode, a batch of one sample raises an error.
		"""
		return F.batch_norm(col, None, None, training=True, eps=eps)

	def forward(self, x):
		"""Run encode -> sample -> decode -> estimation.

		Returns:
			dict with keys:

			* ``'mu'``, ``'logvar'``: encoder outputs.
			* ``'latent'``: the sampled ``z``.
			* ``'x_hat'``: the reconstruction, flattened to ``(batch, -1)``.
			* ``'gmm_input'``: ``[(1-u)*z, u*cos, u*rec]`` concatenated on
			  dim 1.
			* ``'gamma'``: the estimation head's output.
		"""
		mu, logvar = self.enc(x)
		z = self.reparameterize(mu, logvar)
		x_hat = self.dec(z)

		# Compare input and reconstruction as flat per-sample vectors;
		# reshape (unlike view) also accepts non-contiguous tensors.
		x = x.reshape(x.size(0), -1)
		x_hat = x_hat.reshape(x_hat.size(0), -1)

		cos = self._batch_standardize(F.cosine_similarity(x, x_hat).unsqueeze(-1))
		rec = self._batch_standardize(((x - x_hat) ** 2).sum(dim=1, keepdim=True))

		gmm_input = torch.cat([(1 - self.u) * z, self.u * cos, self.u * rec], dim=1)
		gamma = self.est(gmm_input)

		return {'mu': mu,
				"logvar": logvar,
				"latent": z,
				"x_hat": x_hat,
				"gmm_input": gmm_input,
				"gamma": gamma
				}

