import torch.nn as nn
import torch
import torch.nn.functional as F

def weight_init_xavier_uniform(submodule):
	"""Xavier-uniform initialiser for conv layers, intended for ``nn.Module.apply``.

	``Conv2d`` and ``ConvTranspose2d`` weights are drawn with
	``nn.init.xavier_uniform_`` using ``calculate_gain('leaky_relu')`` as the
	gain. When the layer has a bias, it is set to the constant 0.01. Every other
	module type is left untouched.

	Args:
		submodule: module visited by ``nn.Module.apply``; modified in place.
	"""
	# ConvTranspose2d is not a subclass of Conv2d, so both types are listed.
	if not isinstance(submodule, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
		return
	nn.init.xavier_uniform_(submodule.weight, gain=nn.init.calculate_gain('leaky_relu'))
	# Identity check: `bias != None` would go through Tensor.__ne__ rather than test for absence.
	if submodule.bias is not None:
		# constant_ fills under torch.no_grad(), avoiding the discouraged `.data` access.
		nn.init.constant_(submodule.bias, 0.01)

class MNISTEncoder(nn.Module):
	"""Three-stage strided convolutional encoder.

	Each stage is Conv2d(kernel 3, padding 1) -> BatchNorm2d(affine) ->
	LeakyReLU, with 16, 32 and 64 output channels and strides 2, 2 and 3.
	A 28x28 input therefore yields a 64x3x3 feature map.

	Args:
		channel: number of input channels.
	"""

	def __init__(self, channel):
		super().__init__()
		widths = (16, 32, 64)
		strides = (2, 2, 3)
		layers = []
		in_ch = channel
		for out_ch, step in zip(widths, strides):
			layers += [
				nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=step, padding=1),
				nn.BatchNorm2d(out_ch, affine=True),
				nn.LeakyReLU(inplace=True),
			]
			in_ch = out_ch
		self.enc = nn.Sequential(*layers)

	def forward(self, x):
		"""Encode a batch of images into feature maps."""
		return self.enc(x)




class MNISTDecoder(nn.Module):
	"""Transposed-convolution decoder mapping 64x3x3 feature maps to 28x28 images.

	Two ConvTranspose2d -> BatchNorm2d(affine) -> LeakyReLU stages upsample
	3 -> 7 -> 14. A final ConvTranspose2d then goes 14 -> 28, with no
	normalisation or activation, so the output is unbounded.

	Args:
		channel: number of output image channels.
	"""

	def __init__(self, channel):
		super().__init__()
		# (in_channels, out_channels, stride, output_padding) for the normalised stages.
		stages = ((64, 32, 3, 0), (32, 16, 2, 1))
		layers = []
		for c_in, c_out, step, extra in stages:
			layers.append(nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=step, padding=1, output_padding=extra))
			layers.append(nn.BatchNorm2d(c_out, affine=True))
			layers.append(nn.LeakyReLU(inplace=True))
		layers.append(nn.ConvTranspose2d(in_channels=16, out_channels=channel, kernel_size=3, stride=2, padding=1, output_padding=1))
		self.dec = nn.Sequential(*layers)

	def forward(self, x):
		"""Decode feature maps back to image space."""
		out = self.dec(x)
		return out


class CIFAREncoder(nn.Module):
	"""Four-stage strided convolutional encoder.

	Each stage is Conv2d(kernel 3, stride 2, padding 1) -> BatchNorm2d ->
	LeakyReLU, with 64, 128, 128 and 256 output channels. Every stage halves
	the spatial size, so a 32x32 input yields a 256x2x2 feature map.

	Args:
		channel: number of input channels.
	"""

	def __init__(self, channel):
		super().__init__()
		blocks = []
		prev = channel
		for width in (64, 128, 128, 256):
			blocks.extend((
				nn.Conv2d(in_channels=prev, out_channels=width, kernel_size=3, stride=2, padding=1),
				nn.BatchNorm2d(width),
				nn.LeakyReLU(inplace=True),
			))
			prev = width
		self.enc = nn.Sequential(*blocks)

	def forward(self, x):
		"""Encode a batch of images into feature maps."""
		return self.enc(x)

class CIFARDecoder(nn.Module):
	"""Transposed-convolution decoder mapping 256x2x2 feature maps to 32x32 images.

	Four ConvTranspose2d(kernel 3, stride 2, padding 1, output_padding 1)
	layers each double the spatial size (2 -> 4 -> 8 -> 16 -> 32). All but the
	last are followed by BatchNorm2d -> LeakyReLU. The final layer's output is
	left raw.

	Args:
		channel: number of output image channels.
	"""

	def __init__(self, channel):
		super().__init__()
		widths = [256, 128, 128, 64, channel]
		last_idx = len(widths) - 2
		blocks = []
		for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
			blocks.append(nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=2, padding=1, output_padding=1))
			if idx != last_idx:
				blocks.append(nn.BatchNorm2d(c_out))
				blocks.append(nn.LeakyReLU(inplace=True))
		self.dec = nn.Sequential(*blocks)

	def forward(self, x):
		"""Decode feature maps back to image space."""
		out = self.dec(x)
		return out


class MNISTDSADEncoder(nn.Module):
	"""LeNet-style encoder projecting images to a ``rep_dim`` vector.

	There are two stages of Conv2d(kernel 5, padding 2) ->
	BatchNorm2d(affine=False, eps=1e-4) -> LeakyReLU -> MaxPool2d(2, 2), with
	8 and 4 channels, followed by a bias-free Linear layer. ``fc1`` is sized for
	a 4x7x7 feature map, which is what a 28x28 input produces after the two
	poolings.

	Args:
		channel: number of input channels.
		rep_dim: size of the output representation.
		bias: whether the conv layers carry a bias term (``fc1`` never does).
	"""

	def __init__(self, channel, rep_dim=32, bias=False):
		super().__init__()
		self.rep_dim = rep_dim
		# Standalone pooling module; forward() relies on the pools inside self.enc.
		self.pool = nn.MaxPool2d(2, 2)

		stages = []
		for c_in, c_out in ((channel, 8), (8, 4)):
			stages += [
				nn.Conv2d(c_in, c_out, 5, padding=2, bias=bias),
				nn.BatchNorm2d(c_out, eps=1e-04, affine=False),
				nn.LeakyReLU(inplace=True),
				nn.MaxPool2d(2, 2),
			]
		self.enc = nn.Sequential(*stages)

		self.fc1 = nn.Linear(4 * 7 * 7, self.rep_dim, bias=False)

	def forward(self, x):
		"""Return a ``(batch, rep_dim)`` representation of ``x``."""
		features = self.enc(x)
		flat = features.view(features.size(0), -1)
		return self.fc1(flat)

class MNISTDSADDecoder(nn.Module):
	"""LeNet-style decoder mapping a ``rep_dim`` vector to a 28x28 image in (0, 1).

	The input is reshaped to ``(rep_dim / 16) x 4 x 4``. It then passes through
	three rounds of LeakyReLU -> 2x nearest upsampling -> ConvTranspose2d, with
	BatchNorm2d before the second and third rounds. The spatial size goes
	4 -> 8 -> 16 -> 14 -> 28. A final sigmoid is applied.

	``bn1d`` is constructed but not used by ``forward``.

	Args:
		channel: number of output image channels.
		rep_dim: size of the input representation; must be a multiple of 16
			for the reshape in ``forward`` to succeed.
		bias: whether the transposed convolutions carry a bias term.
	"""

	def __init__(self, channel, rep_dim=32, bias=False):
		super().__init__()
		self.rep_dim = rep_dim

		self.bn1d = nn.BatchNorm1d(self.rep_dim, eps=1e-04, affine=False)
		seed_ch = int(self.rep_dim / 16)
		self.deconv1 = nn.ConvTranspose2d(seed_ch, 4, 5, bias=bias, padding=2)
		self.bn1 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
		# padding=3 shrinks 16x16 to 14x14 so the final upsample lands on 28x28.
		self.deconv2 = nn.ConvTranspose2d(4, 8, 5, bias=bias, padding=3)
		self.bn2 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
		self.deconv3 = nn.ConvTranspose2d(8, channel, 5, bias=bias, padding=2)

	def forward(self, x):
		"""Decode ``(batch, rep_dim)`` codes into ``(batch, channel, 28, 28)`` images."""
		def up(t):
			return F.interpolate(F.leaky_relu(t), scale_factor=2)

		seed = x.view(x.size(0), int(self.rep_dim / 16), 4, 4)
		h = self.deconv1(up(seed))
		h = self.deconv2(up(self.bn1(h)))
		h = self.deconv3(up(self.bn2(h)))
		return torch.sigmoid(h)

class CIFARDSADEncoder(nn.Module):
	"""LeNet-style encoder projecting images to a ``rep_dim`` vector.

	There are three stages of Conv2d(kernel 5, padding 2) ->
	BatchNorm2d(affine=False, eps=1e-4) -> LeakyReLU -> MaxPool2d(2, 2), with
	32, 64 and 128 channels, followed by a bias-free Linear layer. ``fc1`` is
	sized for a 128x4x4 feature map, which is what a 32x32 input produces after
	three poolings.

	Args:
		channel: number of input channels.
		rep_dim: size of the output representation.
		bias: whether the conv layers carry a bias term (``fc1`` never does).
		weight_init: when ``== True``, apply ``weight_init_xavier_uniform`` to
			every submodule.
	"""

	def __init__(self, channel, rep_dim=128, bias=False, weight_init=False):
		super().__init__()
		self.rep_dim = rep_dim
		# Standalone pooling module; forward() relies on the pools inside self.enc.
		self.pool = nn.MaxPool2d(2, 2)

		layers = []
		for c_in, c_out in zip((channel, 32, 64), (32, 64, 128)):
			layers.extend([
				nn.Conv2d(c_in, c_out, 5, padding=2, bias=bias),
				nn.BatchNorm2d(c_out, eps=1e-04, affine=False),
				nn.LeakyReLU(inplace=True),
				nn.MaxPool2d(2, 2),
			])
		self.enc = nn.Sequential(*layers)
		self.fc1 = nn.Linear(128 * 4 * 4, self.rep_dim, bias=False)

		if weight_init == True:
			self.apply(weight_init_xavier_uniform)

	def forward(self, x):
		"""Return a ``(batch, rep_dim)`` representation of ``x``."""
		h = self.enc(x)
		h = h.view(h.size(0), -1)
		return self.fc1(h)

class CIFARDSADDecoder(nn.Module):
	"""LeNet-style decoder mapping a ``rep_dim`` vector to a 32x32 image in (0, 1).

	The input is first normalised by ``bn1d`` (BatchNorm1d, affine=False) and
	reshaped to ``(rep_dim / 16) x 4 x 4``. It then goes through
	LeakyReLU -> ``deconv1``. Three rounds of
	BatchNorm2d -> LeakyReLU -> 2x upsampling -> ConvTranspose2d follow
	(4 -> 8 -> 16 -> 32), and a final sigmoid is applied.

	Args:
		channel: number of output image channels.
		rep_dim: size of the input representation; must be a multiple of 16
			for the reshape in ``forward`` to succeed.
		bias: whether the transposed convolutions carry a bias term.
		weight_init: when ``== True``, apply ``weight_init_xavier_uniform`` to
			every submodule.
	"""

	def __init__(self, channel, rep_dim = 128, bias=False, weight_init=False):
		super().__init__()
		self.rep_dim = rep_dim
		# Not used by forward(); kept as part of the module's attributes.
		self.pool = nn.MaxPool2d(2, 2)

		self.bn1d = nn.BatchNorm1d(self.rep_dim, eps=1e-04, affine=False)
		seed_ch = int(self.rep_dim / 16)
		self.deconv1 = nn.ConvTranspose2d(seed_ch, 128, 5, bias=bias, padding=2)
		self.bn1 = nn.BatchNorm2d(128, eps=1e-04, affine=False)
		self.deconv2 = nn.ConvTranspose2d(128, 64, 5, bias=bias, padding=2)
		self.bn2 = nn.BatchNorm2d(64, eps=1e-04, affine=False)
		self.deconv3 = nn.ConvTranspose2d(64, 32, 5, bias=bias, padding=2)
		self.bn3 = nn.BatchNorm2d(32, eps=1e-04, affine=False)
		self.deconv4 = nn.ConvTranspose2d(32, channel, 5, bias=bias, padding=2)

		if weight_init == True:
			self.apply(weight_init_xavier_uniform)

	def forward(self,x):
		"""Decode ``(batch, rep_dim)`` codes into ``(batch, channel, 32, 32)`` images."""
		h = self.bn1d(x)
		h = h.view(h.size(0), int(self.rep_dim / 16), 4, 4)
		h = self.deconv1(F.leaky_relu(h))
		upsampling_steps = (
			(self.bn1, self.deconv2),
			(self.bn2, self.deconv3),
			(self.bn3, self.deconv4),
		)
		for norm, deconv in upsampling_steps:
			h = deconv(F.interpolate(F.leaky_relu(norm(h)), scale_factor=2))
		return torch.sigmoid(h)

class Discriminator(nn.Module):
	"""MLP critic over flattened feature maps, producing one sigmoid score per sample.

	The expected feature-map geometry is 64x3x3 when ``channel == 1`` and
	256x2x2 otherwise. These are the output shapes of ``MNISTEncoder`` on
	28x28 and ``CIFAREncoder`` on 32x32 inputs respectively. The head is
	Linear(fea_dim, 512) -> LeakyReLU -> Linear(512, 256) -> LeakyReLU ->
	Linear(256, 1) -> Sigmoid.

	Args:
		channel: image channel count used to pick the feature-map geometry.
	"""

	def __init__(self,channel):
		super().__init__()
		self.c, self.h, self.w = (64, 3, 3) if channel == 1 else (256, 2, 2)
		self.fea_dim = self.c * self.h * self.w

		hidden_a, hidden_b = 512, 256
		self.dis = nn.Sequential(
			nn.Linear(self.fea_dim, hidden_a),
			nn.LeakyReLU(inplace=True),
			nn.Linear(hidden_a, hidden_b),
			nn.LeakyReLU(inplace=True),
			nn.Linear(hidden_b, 1),
			nn.Sigmoid(),
		)

	def forward(self, x):
		"""Flatten all non-batch dims of ``x`` and score it."""
		flat = x.view(x.size(0), -1)
		return self.dis(flat)


class Discriminator_L(nn.Module):
	"""MLP critic over latent codes, emitting two raw scores per sample.

	Every hidden layer is Linear(bias=False) -> BatchNorm1d(affine=False,
	eps=1e-4) -> LeakyReLU. A final bias-free Linear maps to 2 outputs with no
	activation.

	Args:
		channel: ``1`` selects hidden widths (128, 64, 2 * rep_dim); any other
			value selects (128, 256, 128, rep_dim).
		rep_dim: size of the (flattened) input vector.
	"""

	def __init__(self, channel, rep_dim):
		super().__init__()
		self.rep_dim = rep_dim

		if channel == 1:
			widths = [self.rep_dim, 128, 64, self.rep_dim * 2]
		else:
			widths = [self.rep_dim, 128, 256, 128, self.rep_dim]

		layers = []
		for fan_in, fan_out in zip(widths[:-1], widths[1:]):
			layers += [
				nn.Linear(fan_in, fan_out, bias=False),
				nn.BatchNorm1d(fan_out, affine=False, eps=1e-4),
				nn.LeakyReLU(inplace=True),
			]
		layers.append(nn.Linear(widths[-1], 2, bias=False))
		self.dis = nn.Sequential(*layers)

	def forward(self, x):
		"""Flatten all non-batch dims of ``x`` and return ``(batch, 2)`` scores."""
		flat = x.view(x.size(0), -1)
		return self.dis(flat)

class Discriminator_S(nn.Module):
	"""Image-space critic: a DSAD encoder backbone, then LeakyReLU and a 2-way linear head.

	Args:
		channel: ``1`` selects ``MNISTDSADEncoder``; any other value selects
			``CIFARDSADEncoder``. The value is also passed to the backbone as
			its input channel count.
		rep_dim: backbone output size and input size of the linear head.
	"""

	def __init__(self, channel, rep_dim):
		super().__init__()
		self.channel = channel
		self.rep_dim = rep_dim

		backbone_cls = MNISTDSADEncoder if self.channel == 1 else CIFARDSADEncoder
		self.dis = backbone_cls(self.channel, self.rep_dim)
		self.linear = nn.Sequential(
			nn.LeakyReLU(inplace=True),
			nn.Linear(self.rep_dim, 2, bias=False),
		)

	def forward(self, x):
		"""Return ``(batch, 2)`` raw scores for a batch of images."""
		return self.linear(self.dis(x))


class MNIST_Discriminator_S(nn.Module):
	"""LeNet-style classifier for single-channel 28x28 images.

	There are two stages of Conv2d(kernel 5, padding 2, no bias) ->
	BatchNorm2d(affine=False, eps=1e-4) -> LeakyReLU -> MaxPool2d(2, 2). They
	are followed by Linear(4*7*7, rep_dim) -> LeakyReLU ->
	Linear(rep_dim, num_class), with no output activation.

	Args:
		rep_dim: width of the hidden linear layer.
		num_class: number of output scores.
	"""

	def __init__(self, rep_dim=32,num_class=2):
		super().__init__()
		self.rep_dim = rep_dim
		self.num_class = num_class
		self.pool = nn.MaxPool2d(2, 2)

		self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
		self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
		self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
		self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
		# 4 channels x 7 x 7 after two 2x2 poolings of a 28x28 input.
		self.fc1 = nn.Linear(4 * 7 * 7, self.rep_dim, bias=False)
		self.fc2 = nn.Linear(self.rep_dim, self.num_class, bias=False)
		# Not applied in forward(); kept for any external code that reads it.
		self.sigmoid = nn.Sigmoid()

	def forward(self, x):
		"""Return ``(batch, num_class)`` raw scores; ``x`` is reshaped to (-1, 1, 28, 28)."""
		x = x.view(-1, 1, 28, 28)
		x = self.pool(F.leaky_relu(self.bn1(self.conv1(x))))
		x = self.pool(F.leaky_relu(self.bn2(self.conv2(x))))
		x = x.view(x.size(0), -1)
		latent = F.leaky_relu(self.fc1(x))
		return self.fc2(latent)

class CIFAR10_Discriminator_S(nn.Module):
	"""LeNet-style classifier for 3-channel 32x32 images.

	There are three stages of Conv2d(kernel 5, padding 2, no bias) ->
	BatchNorm2d(affine=False, eps=1e-4) -> LeakyReLU -> MaxPool2d(2, 2), with
	32, 64 and 128 channels. They are followed by Linear(128*4*4, rep_dim) ->
	LeakyReLU -> Linear(rep_dim, num_class), with no output activation.

	Args:
		rep_dim: width of the hidden linear layer.
		num_class: number of output scores.
	"""

	def __init__(self, rep_dim=128,num_class=2):
		super().__init__()
		self.rep_dim = rep_dim
		self.num_class = num_class
		self.pool = nn.MaxPool2d(2, 2)

		self.conv1 = nn.Conv2d(3, 32, 5, bias=False, padding=2)
		self.bn2d1 = nn.BatchNorm2d(32, eps=1e-04, affine=False)
		self.conv2 = nn.Conv2d(32, 64, 5, bias=False, padding=2)
		self.bn2d2 = nn.BatchNorm2d(64, eps=1e-04, affine=False)
		self.conv3 = nn.Conv2d(64, 128, 5, bias=False, padding=2)
		self.bn2d3 = nn.BatchNorm2d(128, eps=1e-04, affine=False)
		self.fc1 = nn.Linear(128 * 4 * 4, self.rep_dim, bias=False)
		self.fc2 = nn.Linear(self.rep_dim, self.num_class, bias=False)
		# Not applied in forward().
		self.sigmoid = nn.Sigmoid()

	def forward(self, x):
		"""Return ``(batch, num_class)`` raw scores; ``x`` is reshaped to (-1, 3, 32, 32)."""
		h = x.view(-1, 3, 32, 32)
		stages = (
			(self.conv1, self.bn2d1),
			(self.conv2, self.bn2d2),
			(self.conv3, self.bn2d3),
		)
		for conv, norm in stages:
			h = self.pool(F.leaky_relu(norm(conv(h))))
		h = h.view(int(h.size(0)), -1)
		hidden = F.leaky_relu(self.fc1(h))
		return self.fc2(hidden)



