from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
#import torchnet
import numpy as np
import plots_pdf
import plots_ae
from scipy import io, stats
import torch.optim.lr_scheduler
import pandas as pd

parser = argparse.ArgumentParser(description='VAE density estimation example')
# NOTE(review): args.batch_size is only read by test(); the training
# DataLoader uses its own hard-coded batch size.
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=500, metavar='N',
                    help='number of epochs to train (default: 500)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
# Both intervals are checked once per epoch in train(): (epoch + 1) % interval.
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many epochs to wait before logging training status')
parser.add_argument('--save-interval', type=int, default=100, metavar='N',
                    help='how many epochs to wait before saving plots and statistics')
parser.add_argument("--dataset", default="mix2_pn",
                    help="name of dataset (mix_pn, mix2_pn, mix3_pn, ramp3_pn, norm3_pn, "
                         "mix2, mix3, ramp3, norm3; any other name loads toy2)")
parser.add_argument("--datapath", default="./",
                    help="path of dataset dir")

# Restrict the loss variants to the ones loss_function implements, so a typo
# fails at parse time instead of as an unbound variable mid-training.
parser.add_argument("--loss1", default="mse", choices=['mse', 'logmse', 'sqrtmse'],
                    help="reconstruction loss variant for x_hat vs x (default: mse)")
parser.add_argument("--loss2", default="mse", choices=['mse', 'sl1', 'sl2', 'sl1b', 'sl2b'],
                    help="per-sample weighting of squared errors (default: mse = unweighted)")
# NOTE(review): lambda1 only appears in the results directory name and the
# stats file header; the training objective uses lambda2 only.
parser.add_argument('--lambda1', type=float, default=3000,
                    help='lambda1, recorded in the output names (default: 3000)')
parser.add_argument('--lambda2', type=float, default=1000,
                    help='weight of the reconstruction term in the loss (default: 1000)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if args.cuda else "cpu")
print('cuda' if args.cuda else 'cpu')

# DataLoader options for CUDA runs (not passed to the training loader).
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

import os

bottom = 3  # size of the VAE latent (bottleneck) layer

# Dataset name -> (input columns read per row, hidden layer 1, hidden layer 2).
# Each name is loaded from <datapath>/<name>.csv; any other name falls back
# to the toy2 file below.
_DATASET_ARCH = {
    'ramp3_pn': (16, 128, 64),
    'norm3_pn': (16, 128, 64),
    'mix3_pn': (16, 128, 64),
    'mix_pn': (16, 128, 64),
    'mix2_pn': (16, 128, 64),
    'mix2': (3, 24, 12),
    'mix3': (3, 24, 12),
    'ramp3': (3, 24, 12),
    'norm3': (3, 24, 12),
}

if args.dataset in _DATASET_ARCH:
    datadim, layer1, layer2 = _DATASET_ARCH[args.dataset]
    print('open ' + args.dataset)
    # os.path.join also works when --datapath has no trailing separator.
    data_path = os.path.join(args.datapath, args.dataset + '.csv')
else:
    datadim, layer1, layer2 = 16, 128, 64
    print('open toy2')
    # NOTE(review): this fallback ignores --datapath; confirm that is intended.
    data_path = '../../dataset/toy2.csv'

# Each CSV row: the first `datadim` columns are the inputs; the target value
# (unpacked as `realpdf` in train()) is column `datadim`, or the
# second-to-last column for toy7/toy8.
target_col = -2 if args.dataset in ('toy7', 'toy8') else datadim
X = []
y = []
label = []
with open(data_path) as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank (e.g. trailing) lines instead of float('') errors
        fields = line.strip('\n').split(',')
        X.append([float(v) for v in fields[:datadim]])
        y.append(float(fields[target_col]))
        label.append(1)  # every sample gets label 1
if not X:
    raise ValueError('no data rows found in %s' % data_path)
log_step = 1

X = np.array(X, dtype=np.float32)
y = np.array(y, dtype=np.float32)
label = np.array(label, dtype=np.int32)

# torch.tensor copies the arrays (float32 preserved) in one call.
tensor_X = torch.tensor(X)
tensor_y = torch.tensor(y)

train_dataset = torch.utils.data.TensorDataset(tensor_X, tensor_y)
# NOTE(review): batch size is fixed at 1024 here; --batch-size and the CUDA
# `kwargs` are not used by this loader.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1024, shuffle=True)


class VAE(nn.Module):
    """Fully-connected VAE with tanh hidden layers and a Gaussian latent.

    Encoder: input_dim -> hidden1 -> hidden2 -> (mu, logvar), each latent_dim wide.
    Decoder: latent_dim -> hidden2 -> hidden1 -> input_dim (linear output).

    Every size defaults to the script-level ``datadim``, ``layer1``,
    ``layer2`` and ``bottom``, so ``VAE()`` uses the configuration chosen at
    data-loading time. Pass sizes explicitly to build a model without those
    globals.
    """

    def __init__(self, input_dim=None, hidden1=None, hidden2=None, latent_dim=None):
        super(VAE, self).__init__()
        # Resolve defaults lazily: the globals are only read when needed.
        self.input_dim = datadim if input_dim is None else input_dim
        self.latent_dim = bottom if latent_dim is None else latent_dim
        hidden1 = layer1 if hidden1 is None else hidden1
        hidden2 = layer2 if hidden2 is None else hidden2

        # Layer creation order is unchanged, so seeded initialization matches.
        self.fc1 = nn.Linear(self.input_dim, hidden1)
        self.fc1_2 = nn.Linear(hidden1, hidden2)

        self.fc21 = nn.Linear(hidden2, self.latent_dim)  # mu
        self.fc22 = nn.Linear(hidden2, self.latent_dim)  # log-variance

        self.fc3 = nn.Linear(self.latent_dim, hidden2)
        self.fc3_2 = nn.Linear(hidden2, hidden1)

        self.fc4 = nn.Linear(hidden1, self.input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened inputs."""
        h = torch.tanh(self.fc1(x))
        h = torch.tanh(self.fc1_2(h))
        return self.fc21(h), self.fc22(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to input space."""
        h = torch.tanh(self.fc3(z))
        h = torch.tanh(self.fc3_2(h))
        return self.fc4(h)

    def forward(self, x):
        """Encode ``x`` and return reconstructions plus latent diagnostics.

        ``x`` is flattened to ``(-1, input_dim)``. Returns a 9-tuple:
            x_hat      decode(mu)
            x_breve    decode(z), z drawn with the reparameterization trick
            mu, logvar encoder outputs
            z          the sampled latent
            diffnorm   forward-difference estimate (step 0.01) of
                       ||d decode / d z_j||^2 at mu, one column per latent dim
            diffnorm2  diffnorm * exp(logvar), i.e. scaled by sigma_j^2
            xp, xm     decode(mu + sigma) and decode(mu - sigma)
        """
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        # decode(mu) is loop-invariant: compute it once and reuse it both as
        # the finite-difference base point and as the returned x_hat.
        x_hat = self.decode(mu)

        dz_value = 0.01
        diffnorm = torch.zeros_like(mu)
        for dimnt in range(self.latent_dim):
            dz = mu.new_zeros(self.latent_dim)
            dz[dimnt] = dz_value
            diffs = x_hat - self.decode(mu + dz)
            diffnorm[:, dimnt] = torch.norm(diffs, dim=1).pow(2) / (dz_value ** 2)

        diffnorm2 = diffnorm * logvar.exp()
        std = torch.exp(0.5 * logvar)
        return (x_hat, self.decode(z), mu, logvar, z, diffnorm, diffnorm2,
                self.decode(mu + std), self.decode(mu - std))

# Build the network on the selected device, then attach Adam (lr 1e-3) and a
# StepLR schedule that multiplies the learning rate by 0.1 every 200
# scheduler steps.
model = VAE().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=200, gamma=0.1)

def loss_function(x_hat, x_breve, x, mu, logvar, xp, xm,
                  loss1=None, loss2=None, lambda2=None):
    """Compute the training objective and diagnostic terms for one batch.

    The optimized objective is ``KLD + lambda2 * scale``. ``KLD`` is the
    batch-mean KL divergence of N(mu, exp(logvar)) from N(0, I). ``scale`` is
    the batch mean of the weighted squared error between ``x`` and
    ``x_breve``. The ``loss1`` term (x_hat vs x) is returned but does not
    enter the objective.

    Per-sample squared errors are multiplied by a weight selected by
    ``loss2``: ``mse`` -> 1; ``sl1``/``sl2``/``sl1b``/``sl2b`` -> functions
    of ||x||^2.

    Args:
        x_hat: decode(mu). x_breve: decode(z). x: input batch; sums and norms
            run over dim=1 (presumably (batch, features)).
        mu, logvar: encoder outputs.
        xp, xm: decode(mu + sigma), decode(mu - sigma).
        loss1, loss2, lambda2: override ``args.loss1`` / ``args.loss2`` /
            ``args.lambda2`` when given; by default they are read from ``args``.

    Returns:
        8-tuple ``(loss, rec, kld, scale, sqrt_weight, elbo, trans_loss,
        coding_loss)``. ``sqrt_weight`` and ``elbo`` are per-sample; the
        others are scalars.

    Raises:
        ValueError: if ``loss1`` or ``loss2`` is not a known variant.
    """
    loss1 = args.loss1 if loss1 is None else loss1
    loss2 = args.loss2 if loss2 is None else loss2
    lambda2 = args.lambda2 if lambda2 is None else lambda2

    # Reconstruction term between x_hat and x (reported only).
    if loss1 == 'logmse':
        BCE = torch.log(F.mse_loss(x_hat, x) + 1e-10)
    elif loss1 == 'mse':
        BCE = F.mse_loss(x_hat, x)
    elif loss1 == 'sqrtmse':
        BCE = torch.sqrt(F.mse_loss(x_hat, x) + 1e-10)
    else:
        raise ValueError('unknown loss1 variant: %r' % (loss1,))

    # Per-sample weight applied to every squared error below.
    sq_norm_x = torch.norm(x, dim=1) ** 2
    if loss2 == 'mse':
        DetRiemanTensor = torch.ones_like(sq_norm_x)
    elif loss2 == 'sl1':
        DetRiemanTensor = 2.0 / 3.0 + (2.0 / 21.0) * sq_norm_x
    elif loss2 == 'sl2':
        DetRiemanTensor = 1.0 / (2.0 / 3.0 + (2.0 / 21.0) * sq_norm_x)
    elif loss2 == 'sl1b':
        DetRiemanTensor = 3.0 / 4.0 + (6.0 / 21.0) / 4.0 * sq_norm_x
    elif loss2 == 'sl2b':
        DetRiemanTensor = 1.0 / (3.0 / 4.0 + (6.0 / 21.0) / 4.0 * sq_norm_x)
    else:
        raise ValueError('unknown loss2 variant: %r' % (loss2,))

    # Weighted error of decode(z); its batch mean is the objective's data term.
    xmse = torch.sum((x - x_breve) ** 2, dim=1) * DetRiemanTensor
    scale = torch.mean(xmse)

    # Same weighting for the average error of decode(mu +/- sigma).
    xmse_av = 0.5 * (torch.sum((x - xp) ** 2, dim=1) + torch.sum((x - xm) ** 2, dim=1))
    xmse_av = xmse_av * DetRiemanTensor
    scale2 = torch.mean(xmse_av)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), per sample
    KLD_data = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    KLD = torch.mean(KLD_data)

    trans_loss = torch.mean(torch.sum((x - x_hat) ** 2, dim=1) * DetRiemanTensor)
    coding_loss = torch.mean(torch.sum((x_breve - x_hat) ** 2, dim=1) * DetRiemanTensor)

    # NOTE(review): the per-sample ELBO adds the batch-mean scale2 rather than
    # the per-sample xmse_av; confirm this is intended.
    elbo = KLD_data + lambda2 * scale2
    return (KLD + lambda2 * scale, BCE, KLD, scale, torch.sqrt(DetRiemanTensor),
            elbo, trans_loss, coding_loss)

import os

# One results directory per configuration:
# results/VAE30_<dataset>_<lambda1>_<lambda2>_<loss1>_<loss2>,
# with both lambdas rendered via format(..., '06.4f').
savedir = os.path.join('results', '_'.join([
    "VAE30_" + args.dataset,
    format(args.lambda1, '06.4f'),
    format(args.lambda2, '06.4f'),
    args.loss1,
    args.loss2,
]))
if not os.path.exists(savedir):
    os.makedirs(savedir)

# Cap on how many samples are sliced into each plots_pdf.plot2D call.
plot_num = 10000

def _batch_snapshot(realpdf, z, mu, logvar, diffnorm, diffnorm2, SqDetRT, ELBO):
    """Return the per-sample numpy arrays kept from one batch on a save epoch.

    ``diffnorm`` and ``diffnorm2`` are multiplied per sample by ``SqDetRT**2``.
    ``elbo`` is ``exp(-ELBO)``; ``elbo2`` is that times ``SqDetRT**3``.
    """
    with torch.no_grad():  # diagnostics only; no autograd graph needed
        sq_det = SqDetRT.detach().cpu().numpy()
        elbo = np.exp(-ELBO.detach().cpu().numpy())
        weight = SqDetRT[:, None] ** 2
        return {
            'z': z.detach().cpu().numpy(),
            'mu': mu.detach().cpu().numpy(),
            'diffnorm': (diffnorm * weight).detach().cpu().numpy(),
            'diffnorm2': (diffnorm2 * weight).detach().cpu().numpy(),
            'realpdf': realpdf.detach().cpu().numpy(),
            'sq_det': sq_det,
            'elbo': elbo,
            'elbo2': elbo * sq_det ** 3,
            'stds': torch.exp(0.5 * logvar).detach().cpu().numpy(),
        }


def _latent_densities(mu_results, stds, sq_det):
    """Per-sample density proxies from the encoder outputs.

    Returns ``(pdfs, pdfs_mode, pdfs_mode2)`` where
        pdfs       = prod_j N(mu_j; 0, 1)
        pdfs_mode2 = prod_j N(mu_j; 0, 1) * sigma_j
        pdfs_mode  = pdfs_mode2 * sq_det**3
    The products are accumulated as sums of logs, one latent column at a time.
    """
    prior = stats.norm.pdf(mu_results, loc=0, scale=1)
    log_pdf = np.zeros(mu_results.shape[0])
    log_pdf_m = np.zeros(mu_results.shape[0])
    for d in range(mu_results.shape[1]):
        log_pdf += np.log(prior[:, d])
        log_pdf_m += np.log(prior[:, d] * stds[:, d])
    pdfs = np.exp(log_pdf)
    pdfs_mode2 = np.exp(log_pdf_m)
    pdfs_mode = pdfs_mode2 * (sq_det ** 3)
    return pdfs, pdfs_mode, pdfs_mode2


def _stat_row(label, values):
    """Format one stats-CSV row: ``label`` then the first three (z1..z3) values."""
    return '%s, %s, %s, %s, \n' % (label, str(values[0]), str(values[1]), str(values[2]))


def _write_stats(stats_file, results, correlations, last_losses):
    """Write latent-geometry statistics, plot correlations and losses to ``stats_file``."""
    diffnorm_results = results['diffnorm']
    stds = results['stds']
    with open(stats_file, 'w') as fstats:
        fstats.write('VAE30,' + args.dataset + ',' + args.loss1 + ',' + args.loss2 + ','
                     + format(args.lambda1, '06.4f') + ',' + format(args.lambda2, '06.4f') + ',\n')
        fstats.write(', z1, z2, z3,\n')

        fstats.write(_stat_row("D'(z_j) mean", np.mean(diffnorm_results, axis=0)))
        fstats.write(_stat_row("D'(z_j) std", np.std(diffnorm_results, axis=0)))

        scaled = results['diffnorm2'] * (args.lambda2 * 2)  # scaled by 2/beta = 2*lambda2
        fstats.write(_stat_row("sigma_j^2*D'(z_j)*2/beta mean", np.mean(scaled, axis=0)))
        fstats.write(_stat_row("sigma_j^2*D'(z_j)*2/beta std", np.std(scaled, axis=0)))

        inv_var = np.square(np.reciprocal(stds))
        var_mean = np.mean(inv_var, axis=0)
        var_std = np.std(inv_var, axis=0)
        var_min = np.min(var_mean)
        fstats.write(_stat_row('sigma^(-2) mean', var_mean))
        fstats.write(_stat_row('sigma^(-2) std', var_std))
        fstats.write(_stat_row('sigma^(-2) mean(norm)', var_mean / var_min))
        fstats.write(_stat_row('sigma^(-2) std(norm)', var_std / var_min))

        mod_mean = var_mean - (0.3989 * np.mean(np.sign(results['mu']) * np.reciprocal(stds), axis=0)) ** 2
        fstats.write(_stat_row('sigma^(-2) mod mean', mod_mean))
        fstats.write(_stat_row('sigma^(-2) mod mean(norm)', mod_mean / np.min(mod_mean)))

        fstats.write('Correlation, %s, %s, %s, %s, %s, \n' % tuple(str(c) for c in correlations))
        fstats.write('Loss, {:.6f}, {:.6f}, {:.6f}, {:.6f},  {:.6f}\n'.format(*last_losses))


def _save_report(epoch_no, results, last_losses):
    """Print latent summaries, draw the P(x) scatter plots and write the stats CSV.

    The values returned by ``plots_pdf.plot2D`` are written to the
    "Correlation" row of ``stats_<dataset>_<epoch_no>.csv`` in ``savedir``.
    """
    z_results = results['z']
    stds = results['stds']
    realpdf_results = results['realpdf']

    print(z_results.shape)
    z_mu = np.mean(z_results, axis=0)
    z_std = np.std(z_results, axis=0)
    print('z_mu', z_mu.shape, z_mu)
    print('z_std', z_std.shape, z_std)

    pdfs, pdfs_mode, pdfs_mode2 = _latent_densities(results['mu'], stds, results['sq_det'])
    print('pdfs_mode', pdfs_mode[:10])
    print('std', stds[0][:10])
    print(pdfs.shape)

    n = plot_num
    tag = str(epoch_no)
    series = [
        (pdfs, '_P(x)_vs_P(mu)'),
        (pdfs_mode, '_P(x)_vs_(sqrt(A)_sigma_P(mu))'),
        (pdfs_mode2, '_P(x)_vs_(sigma_P(mu))'),
        (results['elbo'], '_P(x)_vs_exp(ELBO)'),
        (results['elbo2'], '_P(x)_vs_(sqrt(A)_exp(ELBO))'),
    ]
    correlations = [plots_pdf.plot2D(savedir, realpdf_results[:n], values[:n], tag + suffix, args.dataset)
                    for values, suffix in series]

    stats_file = os.path.join(savedir, 'stats_%s_%s.csv' % (args.dataset, tag))
    _write_stats(stats_file, results, correlations, last_losses)


def train(epoch):
    """Train the global ``model`` on ``train_loader``; ``epoch`` is the total epoch count.

    Every ``args.log_interval`` epochs the last batch's losses are printed.
    Every ``args.save_interval`` epochs the per-sample encoder outputs of that
    epoch are gathered and passed to ``_save_report`` (prints, plots, stats
    CSV in ``savedir``). Finally the mean per-batch loss over the run is printed.
    """
    model.train()
    train_loss = 0.0
    num_steps = 0
    for e in range(epoch):
        saving = (e + 1) % args.save_interval == 0
        snapshots = []
        for data, realpdf in train_loader:
            data = data.to(device)
            optimizer.zero_grad()

            x_hat, x_breve, mu, logvar, z, diffnorm, diffnorm2, xp, xm = model(data)
            loss, _rec, kld, scale, SqDetRT, ELBO, TransLoss, CodeLoss = loss_function(
                x_hat, x_breve, data, mu, logvar, xp, xm)

            loss.backward()
            train_loss += loss.item()
            num_steps += 1
            optimizer.step()

            if saving:
                snapshots.append(_batch_snapshot(realpdf, z, mu, logvar, diffnorm, diffnorm2,
                                                 SqDetRT, ELBO))

        # The LR scheduler must step after optimizer.step(); stepping it first
        # skips the first value of the schedule (PyTorch >= 1.1 warns about it).
        scheduler.step()

        if (e + 1) % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f},KLD: {:.6f},TransLoss: {:.6f}, CodeLoss: {:.6f}, RecLoss: {:.6f}'.format(
                epoch, e + 1, epoch,
                100. * e / epoch,
                loss.item(),
                kld.item(),
                TransLoss.item(),
                CodeLoss.item(),
                scale.item()
            ))

        if saving:
            # Concatenate once per save epoch instead of growing arrays per batch.
            results = {key: np.concatenate([snap[key] for snap in snapshots])
                       for key in snapshots[0]}
            last_losses = (loss.item(), kld.item(), TransLoss.item(),
                           CodeLoss.item(), scale.item())
            _save_report(e + 1, results, last_losses)

    # train_loss sums per-batch *mean* losses, so normalise by the step count.
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / max(num_steps, 1)))


def test(epoch):
    """Evaluate on ``test_loader`` and save an input/reconstruction image grid.

    NOTE(review): this looks like leftover MNIST-example code and cannot run
    against the current model as written -- TODO rewrite or delete:
      * ``test_loader`` is not defined in the visible source;
      * ``model(data)`` returns 9 values (see ``VAE.forward``), not 3;
      * ``loss_function`` requires 7 inputs and returns a tuple, so the
        4-argument call and ``.item()`` below would fail;
      * the ``(args.batch_size, 1, 28, 28)`` reshape assumes 28x28 images.
    No live call to ``test()`` appears in the visible source.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            # NOTE(review): unpacking 3 values from a 9-tuple raises ValueError.
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                # Save up to 8 inputs next to their reconstructions as one image.
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

if __name__ == "__main__":
    # A single train() call runs all epochs; it handles its own logging,
    # learning-rate scheduling and periodic reports.
    train(args.epochs)
