import torch.nn as nn
import torchinfo
from model.Conv import *

class DSAD(nn.Module):
	"""Wrap a channel-dependent encoder (DSAD; presumably Deep SAD).

	The encoder is chosen once, at construction time, from ``channel``:

	* ``channel == 1``: ``MNISTDSADEncoder(channel, rep_dim)``
	* any other value: ``CIFARDSADEncoder(channel, rep_dim, weight_init=True)``

	Args:
		channel: number of input channels; also forwarded to the encoder.
		rep_dim: forwarded unchanged to the encoder (presumably the size of
			the latent representation; confirm in ``model.Conv``).
	"""

	def __init__(self, channel, rep_dim):
		super().__init__()
		# Only a single-channel input gets the MNIST encoder. Every other
		# channel count falls through to the CIFAR encoder.
		self.enc = (
			MNISTDSADEncoder(channel, rep_dim)
			if channel == 1
			else CIFARDSADEncoder(channel, rep_dim, weight_init=True)
		)

	def forward(self, x):
		"""Return the selected encoder's output for ``x``."""
		return self.enc(x)

if __name__ == "__main__":
	# Smoke test: print a torchinfo layer summary for both encoder variants.
	# DSAD requires rep_dim. The original calls omitted it and raised TypeError.
	# NOTE(review): 32 / 128 are illustrative values; adjust them to the
	# project's configured representation size.
	# Input sizes are presumably (batch, channels, height, width); dim 1 must
	# match the `channel` argument given to DSAD.
	model = DSAD(1, rep_dim=32)
	print(torchinfo.summary(model, (2, 1, 28, 28), device="cpu", verbose=0))
	model = DSAD(3, rep_dim=128)
	print(torchinfo.summary(model, (2, 3, 32, 32), device="cpu", verbose=0))
