import torch.nn as nn
import torchinfo
from model.Conv import MNISTEncoder, MNISTDecoder, CIFAREncoder, CIFARDecoder 
from model.mem_module import MemModule

class MemAE(nn.Module):
	"""Autoencoder with a memory module inserted between encoder and decoder.

	The encoder output is flattened to a vector of ``fea_dim = c * h * w``
	features per sample, passed through ``MemModule``, reshaped back to
	``(c, h, w)`` and decoded.

	Args:
		channel: number of input channels. ``1`` selects the MNIST
			encoder/decoder pair, whose encoder output is expected to be
			64 x 3 x 3. Any other value selects the CIFAR pair, whose
			encoder output is expected to be 256 x 2 x 2. The ``__main__``
			demo feeds 28x28 and 32x32 inputs respectively.
		mem_dim: forwarded to ``MemModule`` as ``mem_dim``. Presumably this
			is the number of memory slots; confirm against
			``model.mem_module``.
		shrink_thres: forwarded to ``MemModule`` as ``shrink_thres``.
	"""

	def __init__(self, channel: int, mem_dim: int = 100, shrink_thres: float = 0.0025):
		super().__init__()

		if channel == 1:
			self.enc = MNISTEncoder(channel)
			self.dec = MNISTDecoder(channel)
			self.c = 64
			self.h = self.w = 3
		else:
			self.enc = CIFAREncoder(channel)
			self.dec = CIFARDecoder(channel)
			self.c = 256
			self.h = self.w = 2
		# Size of the flattened per-sample latent vector handed to the memory module.
		self.fea_dim = self.c * self.h * self.w
		self.mem_rep = MemModule(mem_dim=mem_dim, fea_dim=self.fea_dim, shrink_thres=shrink_thres)

	def forward(self, x):
		"""Encode ``x``, pass the flattened latent through memory, then decode.

		Returns:
			dict with keys
			- ``'output'``: the decoder output.
			- ``'att'``: the ``'att'`` entry returned by ``MemModule``.
			- ``'latent'``: the memory module's ``'output'`` tensor, in its
			  flat form (expected ``(batch, fea_dim)``). This is the latent
			  *after* the memory module, not the raw encoder output.

		Raises:
			RuntimeError: if the encoder output does not contain exactly
				``fea_dim`` elements per sample.
		"""
		f = self.enc(x)
		batch = f.size(0)
		# Use an explicit batch size rather than -1. With -1, an encoder whose
		# output no longer matches the hard-coded (c, h, w) would have its batch
		# dimension silently reinterpreted. Here it raises instead. reshape
		# (unlike view) also accepts non-contiguous tensors.
		res_mem = self.mem_rep(f.reshape(batch, self.fea_dim))
		latent_vector = res_mem['output']
		att = res_mem['att']
		# Restore the spatial layout the decoder expects.
		output = self.dec(latent_vector.reshape(batch, self.c, self.h, self.w))
		return {'output': output, 'att': att, 'latent': latent_vector}

if __name__ == "__main__":
	# Print a layer-by-layer summary for each supported configuration:
	# single-channel 28x28 input, then three-channel 32x32 input (batch of 2).
	for n_channels, side in ((1, 28), (3, 32)):
		demo = MemAE(n_channels)
		report = torchinfo.summary(demo, (2, n_channels, side, side), device="cpu", verbose=0)
		print(report)
