import torch.nn as nn
import torchinfo
from model.Conv import *

class Autoencoder(nn.Module):
	"""Convolutional autoencoder with an MNIST- or CIFAR-style backbone.

	``channel == 1`` selects ``MNISTEncoder``/``MNISTDecoder``; any other
	value selects ``CIFAREncoder``/``CIFARDecoder``.

	Attributes:
		enc: Encoder sub-module.
		dec: Decoder sub-module.
		c, h, w: Hard-coded sizes (64, 3, 3 for MNIST; 256, 2, 2 for CIFAR).
			NOTE(review): presumably the encoder's output channels/height/
			width; they are not derived from ``enc``, so confirm they stay in
			sync with ``model.Conv``.
		fea_dim: ``c * h * w``, the expected flattened latent size.
	"""

	def __init__(self, channel):
		super().__init__()

		if channel == 1:
			self.enc = MNISTEncoder(channel)
			self.dec = MNISTDecoder(channel)
			self.c = 64
			self.h = self.w = 3
		else:
			self.enc = CIFAREncoder(channel)
			self.dec = CIFARDecoder(channel)
			self.c = 256
			self.h = self.w = 2
		self.fea_dim = self.c * self.h * self.w

	def forward(self, x):
		"""Encode ``x`` and reconstruct it.

		Args:
			x: Input batch; its first dimension is used as the batch size.

		Returns:
			dict with ``'output'`` (decoder output) and ``'latent'`` (the
			encoder output flattened to shape ``(x.size(0), -1)``).
		"""
		latent_vector = self.enc(x)
		# The decoder consumes the unflattened encoder output.
		output = self.dec(latent_vector)
		# reshape instead of view: view raises on a non-contiguous encoder
		# output, while reshape still returns a zero-copy view when possible.
		latent_vector = latent_vector.reshape(x.size(0), -1)
		return {'output': output, 'latent': latent_vector}


class DSADAutoencoder(nn.Module):
	"""Autoencoder assembled from the DSAD encoder/decoder variants.

	Args:
		channel: Number of input channels. ``1`` selects the MNIST DSAD
			encoder/decoder; any other value selects the CIFAR DSAD ones.
		rep_dim: Representation size, passed to both encoder and decoder.
		weight_init: Forwarded to the CIFAR encoder/decoder only; it is
			not used when ``channel == 1``.
	"""

	def __init__(self, channel, rep_dim, weight_init=False):
		super().__init__()

		if channel != 1:
			self.enc = CIFARDSADEncoder(channel, rep_dim, weight_init=weight_init)
			self.dec = CIFARDSADDecoder(channel, rep_dim, weight_init=weight_init)
			return

		self.enc = MNISTDSADEncoder(channel, rep_dim)
		self.dec = MNISTDSADDecoder(channel, rep_dim)

	def forward(self, x):
		"""Return the reconstruction and the (unflattened) encoder output of ``x``."""
		z = self.enc(x)
		return {'output': self.dec(z), 'latent': z}

if __name__ == "__main__":
	# Print a torchinfo summary for each backbone with a batch of two inputs:
	# 1x28x28 for the MNIST variant and 3x32x32 for the CIFAR variant.
	for channel, input_size in ((1, (2, 1, 28, 28)), (3, (2, 3, 32, 32))):
		model = Autoencoder(channel)
		print(torchinfo.summary(model, input_size, device="cpu", verbose=0))
