"""solver.py"""


import torch
# torch.cuda.set_device(0)
import warnings
warnings.filterwarnings("ignore")

import os, csv
from tqdm import tqdm
import visdom
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import make_grid, save_image

from utils import cuda, grid2gif
from model64 import BetaVAE_H
from dataloader import load_data
from P_PID import PIDControl
import random
import numpy as np
import matplotlib.pyplot as plt

from data_swap import return_data as supervised_dataloader
from data_swap import return_semi_data as semi_supervised_dataloader
from data_swap import return_test_data as unsupervised_dataloader
from data_swap import change_latent_space
from data_swap import return_attack_data as attack_loader


def reconstruction_loss(x, x_recon, distribution):
    """Return the reconstruction loss summed over all elements, per sample.

    The element-wise loss is summed over the whole batch and divided by the
    batch size ``x.size(0)``.

    Args:
        x: Target tensor.
        x_recon: Raw decoder output (pre-sigmoid). For ``'bernoulli'`` the
            sigmoid is folded into the BCE-with-logits loss; for
            ``'gaussian'`` it is applied explicitly before the squared error.
        distribution: ``'bernoulli'`` or ``'gaussian'``.

    Returns:
        A scalar tensor, or ``None`` when ``distribution`` is neither of the
        two supported names.
    """
    batch_size = x.size(0)
    assert batch_size != 0

    # ``reduction='sum'`` is the supported spelling of the deprecated
    # ``size_average=False``; ``torch.sigmoid`` replaces deprecated ``F.sigmoid``.
    if distribution == 'bernoulli':
        recon_loss = F.binary_cross_entropy_with_logits(x_recon, x, reduction='sum').div(batch_size)
    elif distribution == 'gaussian':
        x_recon = torch.sigmoid(x_recon)
        recon_loss = F.mse_loss(x_recon, x, reduction='sum').div(batch_size)
    else:
        recon_loss = None

    return recon_loss


def kl_divergence(mu, logvar):
    """Compute KL(N(mu, exp(logvar)) || N(0, 1)) element-wise and summarise it.

    4-D inputs are first reshaped to ``(size(0), size(1))``.

    Returns:
        A tuple ``(total_kld, dimension_wise_kld, mean_kld)``:
        ``total_kld`` sums over dim 1 and averages over dim 0 (shape ``(1,)``),
        ``dimension_wise_kld`` averages over dim 0 only, and ``mean_kld``
        averages over both dims (shape ``(1,)``).
    """
    n_samples = mu.size(0)
    assert n_samples != 0

    if mu.dim() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.dim() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    kl_elements = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())

    total_kld = kl_elements.sum(1).mean(0, True)
    dimension_wise_kld = kl_elements.mean(0)
    mean_kld = kl_elements.mean(1).mean(0, True)
    return total_kld, dimension_wise_kld, mean_kld


## For swapped structure
# def compute_kl(z1,z2,logvar1,logvar2):
#     var1 = logvar1.exp()
#     var2 = logvar2.exp()
#     return var1/var2 + torch.square(z2-z1)/var2 - 1 + logvar2 - logvar1

def compute_kl(mu1,mu2,logvar1,logvar2):
    """Element-wise KL(N(mu1, exp(logvar1)) || N(mu2, exp(logvar2)))."""
    var1 = torch.exp(logvar1)
    var2 = torch.exp(logvar2)
    variance_ratio = var1/var2
    scaled_sq_diff = torch.square(mu1 - mu2)/var2
    return -0.5*(1 + logvar1 - logvar2 - variance_ratio - scaled_sq_diff)


def reparameterize(mu,logvar):
    """Sample ``mu + sigma * eps`` with ``sigma = exp(logvar / 2)`` and ``eps ~ N(0, 1)``."""
    sigma = torch.exp(logvar / 2)
    # Fresh standard-normal noise with the same size, dtype and device as sigma.
    noise = sigma.new_empty(sigma.size()).normal_()
    return mu + sigma * noise


class DataGather(object):
    """Collects logged values in per-key lists stored in ``self.data``."""

    def __init__(self):
        self.data = self.get_empty_data_dict()

    def get_empty_data_dict(self):
        """Return a new dict mapping every supported key to an empty list."""
        keys = ('iter', 'recon_loss', 'total_kld', 'dim_wise_kld',
                'mean_kld', 'mu', 'var', 'images', 'beta')
        return {key: [] for key in keys}

    def insert(self, **kwargs):
        """Append each keyword value to the list stored under its key.

        A key that is not one of the supported keys raises ``KeyError``.
        """
        for key, value in kwargs.items():
            self.data[key].append(value)

    def flush(self):
        """Discard everything gathered so far."""
        self.data = self.get_empty_data_dict()


class Solver(object):
    def __init__(self, args):
        """Create the network, optimizer, data loaders and output directories.

        Args:
            args: Parsed command-line namespace. Every ``args.<name>`` read
                below must be present on it.

        Raises:
            NotImplementedError: if ``args.dataset`` or ``args.model`` is not
                one of the supported names.
        """
        # CUDA is used only when it was requested AND is actually available.
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        self.max_iter = args.max_iter
        self.global_iter = 0

        self.z_dim = args.z_dim
        self.beta = args.beta
        self.gamma = args.gamma
        self.C_max = args.C_max
        self.C_max_org = args.C_max
        self.C_stop_iter = args.C_stop_iter
        self.objective = args.objective
        self.model = args.model
        # Stored lowercased; the dataset-name comparisons in this class are
        # all written against lowercase names.
        self.dataset = args.dataset.lower()
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.pid_fixed = args.pid_fixed
        self.is_PID = args.is_PID
        self.step_value = args.step_val
        self.C_start = args.C_start

        ## Swap
        attack_types = ['all', 'elastic', 'g_blur', 'g_noise', 'splatter', 'sticker']

        self.warm_up = args.warmup
        self.weight = args.compare_weight
        if args.train:
            self.data_generator = semi_supervised_dataloader(args)
        self.test_loader = unsupervised_dataloader(args)
        if self.dataset in attack_types:
            self.attack_loader = attack_loader(args)
        self.threshold = args.threshold

        # Number of image channels and decoder likelihood per dataset.
        three_channel_sets = ('traffic', 'stopsign', 'attack', 'celeba',
                              '3d_shapes', '3d_chairs')
        if self.dataset == 'dsprites':
            self.nc = 1
            self.decoder_dist = 'bernoulli'
        elif self.dataset in three_channel_sets or self.dataset in attack_types:
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            raise NotImplementedError

        if args.model == 'H':
            net = BetaVAE_H
        elif args.model == 'B':
            # BetaVAE_B is not imported at module level, so it is imported
            # lazily here instead of failing with a NameError.
            # NOTE(review): assumes model64 defines BetaVAE_B -- confirm.
            try:
                from model64 import BetaVAE_B
            except ImportError as err:
                raise NotImplementedError(
                    "model 'B' was requested but model64 does not provide BetaVAE_B") from err
            net = BetaVAE_B
        else:
            raise NotImplementedError('only support model H or B')

        self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(), lr=self.lr,
                                    betas=(self.beta1, self.beta2))

        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        self.win_recon = None
        self.win_beta = None
        self.win_kld = None
        self.win_mu = None
        self.win_var = None
        if self.viz_on:
            self.viz = visdom.Visdom(port=self.viz_port)

        self.ckpt_dir = os.path.join(args.ckpt_dir, args.viz_name)
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)

        self.save_output = args.save_output
        self.output_dir = os.path.join(args.output_dir, args.viz_name)
        os.makedirs(self.output_dir, exist_ok=True)

        self.gather_step = args.gather_step
        self.display_step = args.display_step
        self.save_step = args.save_step

        self.dset_dir = args.dset_dir
        self.batch_size = args.batch_size

        self.gather = DataGather()
        
    def train(self):
        """Train on image pairs until ``self.global_iter`` reaches ``self.max_iter``.

        Each batch ``x`` from ``self.data_generator`` is split into two images
        per sample (``x[:, 0]`` and ``x[:, 1]``); both are encoded and decoded
        by ``self.net``. For objective ``'H'`` the loss is the sum of both
        beta-weighted VAE losses, plus -- once ``self.global_iter`` exceeds
        ``self.warm_up`` and more than one latent dimension has a batch-mean
        KL of at least ``self.threshold`` -- the reconstruction losses of the
        latents returned by ``change_latent_space``, weighted by
        ``self.weight``. For objective ``'B'`` the loss is the averaged
        reconstruction loss plus ``gamma * |total_kld - C|``.

        When ``self.is_PID`` is set (objective ``'H'``), ``self.beta`` is
        updated every iteration by the PID controller.

        Side effects: writes ``train.log`` and ``train.kl`` into
        ``self.ckpt_dir``, saves checkpoint ``'last'`` every
        ``self.save_step`` iterations and replaces ``self.C_max`` by a tensor.
        """
        self.net_mode(train=True)
        self.C_max = Variable(cuda(torch.FloatTensor([self.C_max]), self.use_cuda))
        out = False

        pbar = tqdm(total=self.max_iter)
        pbar.update(self.global_iter)
        ## write log to log file
        outfile = os.path.join(self.ckpt_dir, "train.log")
        kl_file = os.path.join(self.ckpt_dir, "train.kl")

        ## init PID control
        PID = PIDControl()
        Kp = 0.01
        Ki = -0.001
        Kd = 0.0
        C = 0.5
        period = 5000

        # Both log files are closed on exit, including when training raises.
        with open(outfile, "w") as fw_log, open(kl_file, "w") as fw_kl:
            fw_log.write("Kp:{0:.5f} Ki: {1:.6f} C_iter:{2:.1f} period:{3} step_val:{4:.4f}\n" \
                        .format(Kp, Ki, self.C_stop_iter, period,self.step_value))

            while not out:
                for x, label in self.data_generator:
                    self.global_iter += 1
                    pbar.update(1)

                    x1 = Variable(cuda(x[:,0,:,:,:], self.use_cuda))
                    x2 = Variable(cuda(x[:,1,:,:,:], self.use_cuda))

                    if self.nc == 3:
                        x1 = x1 / 255.
                        x2 = x2 / 255.

                    ## Swap
                    swapped = False
                    unchanged_dims = 1

                    recon_x1, mu1, logvar1, _ = self.net(x1)
                    recon_x2, mu2, logvar2, _ = self.net(x2)

                    total_kld_1, dim_wise_kld_1, mean_kld_1 = kl_divergence(mu1, logvar1)
                    total_kld_2, dim_wise_kld_2, mean_kld_2 = kl_divergence(mu2, logvar2)
                    total_kld = (total_kld_1 + total_kld_2) / 2.
                    dim_wise_kld = (dim_wise_kld_1 + dim_wise_kld_2) / 2.
                    mean_kld = (mean_kld_1 + mean_kld_2) / 2.

                    recon_loss_1 = reconstruction_loss(x1, recon_x1, self.decoder_dist)
                    recon_loss_2 = reconstruction_loss(x2, recon_x2, self.decoder_dist)
                    recon_loss = (recon_loss_1 + recon_loss_2) / 2.

                    kl_per_point = compute_kl(mu1, mu2, logvar1, logvar2) # shape = [batch, dim_z]

                    ## Filter by threshold: indices of the latent dimensions whose
                    ## batch-mean KL reaches the threshold.
                    ## (torch.arange replaces the deprecated, end-inclusive torch.range.)
                    indices = torch.arange(dim_wise_kld.shape[0], device=dim_wise_kld.device)
                    disentangled_dims = indices[dim_wise_kld >= self.threshold]

                    # Do swap if disentangled dimensions are beyond unchanged factors
                    # and the warm-up phase is over.
                    if len(disentangled_dims) > unchanged_dims and self.global_iter > self.warm_up:
                        swapped = True

                        mu1_new, logvar1_new, mu2_new, logvar2_new = change_latent_space(mu1, mu2, logvar1, logvar2, kl_per_point, disentangled_dims=disentangled_dims, unchanged_latent_indices=unchanged_dims)
                        z1_new = reparameterize(mu1_new, logvar1_new)
                        z2_new = reparameterize(mu2_new, logvar2_new)
                        recon_x1_new = self.net._decode(z1_new)
                        recon_x2_new = self.net._decode(z2_new)

                        recon_loss_1_new = reconstruction_loss(x1, recon_x1_new, self.decoder_dist)
                        recon_loss_2_new = reconstruction_loss(x2, recon_x2_new, self.decoder_dist)

                    if self.is_PID and self.objective == 'H':
                        # Raise the KL set point every `period` iterations, capped
                        # at the originally configured C_max.
                        if self.global_iter%period==0:
                            C += self.step_value
                        if C > self.C_max_org:
                            C = self.C_max_org
                        ## dynamic pid
                        self.beta, _ = PID.pid(C, total_kld.item(), Kp, Ki, Kd)

                    if self.objective == 'H':
                        loss1 = recon_loss_1 + self.beta * total_kld_1
                        loss2 = recon_loss_2 + self.beta * total_kld_2
                        if swapped:
                            loss3 = recon_loss_1_new + recon_loss_2_new
                            beta_vae_loss = loss1 + loss2 + self.weight * loss3
                        else:
                            beta_vae_loss = loss1 + loss2

                    elif self.objective == 'B':
                        ### tricks for C
                        C = torch.clamp(self.C_max/self.C_stop_iter*self.global_iter, self.C_start, self.C_max.data[0])
                        beta_vae_loss = recon_loss + self.gamma*(total_kld-C).abs()

                    self.optim.zero_grad()
                    beta_vae_loss.backward()
                    self.optim.step()

                    if self.global_iter%20 == 0:
                        ## write log to file
                        if self.objective == 'B':
                            C = C.item()
                        fw_log.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} exp_kld:{:.3f} beta:{:.4f}\n'.format(
                                    self.global_iter, recon_loss.item(), total_kld.item(), C, self.beta))
                        ## write KL to file
                        dim_kl = dim_wise_kld.data.cpu().numpy()
                        dim_kl = [str(k) for k in dim_kl]
                        fw_kl.write('total_kld:{0:.3f}\t'.format(total_kld.item()))
                        fw_kl.write('z_dim:' + ','.join(dim_kl) + '\n')

                        if self.global_iter%500 == 0:
                            fw_log.flush()
                            fw_kl.flush()

                    if self.viz_on and self.global_iter % self.gather_step==0:
                        # The reconstruction plot call is disabled, so the gathered
                        # images are dropped again right away.
                        self.gather.insert(images=x1.data)
                        self.gather.insert(images=torch.sigmoid(recon_x1).data)
                        # self.viz_reconstruction()
                        self.gather.flush()

                        self.gather.insert(images=x2.data)
                        self.gather.insert(images=torch.sigmoid(recon_x2).data)
                        # self.viz_reconstruction()
                        self.gather.flush()

                    if (self.viz_on or self.save_output) and self.global_iter%150000==0:
                        self.viz_traverse(limit=3, inter=2/3, loc=-1)

                    if self.global_iter % self.save_step == 0:
                        self.save_checkpoint('last')
                        pbar.write('Saved checkpoint(iter:{})'.format(self.global_iter))

                    if self.global_iter >= self.max_iter:
                        out = True
                        break

        pbar.write("[Training Finished]")
        pbar.close()
        

    def viz_reconstruction(self):
        """Send the gathered input/reconstruction grids to visdom, optionally saving them.

        Reads the first two entries of ``self.gather.data['images']`` (at most
        100 images each) and, when ``self.save_output`` is set, also writes
        ``recon.jpg`` under ``<output_dir>/<global_iter>/``. The network is put
        in eval mode for the duration and back in train mode afterwards.
        """
        self.net_mode(train=False)

        gathered = self.gather.data['images']
        originals = make_grid(gathered[0][:100], normalize=True)
        reconstructions = make_grid(gathered[1][:100], normalize=True)
        images = torch.stack([originals, reconstructions], dim=0).cpu()

        self.viz.images(images, env=self.viz_name+'_reconstruction',
                        opts=dict(title=str(self.global_iter)), nrow=10)

        if self.save_output:
            target_dir = os.path.join(self.output_dir, str(self.global_iter))
            os.makedirs(target_dir, exist_ok=True)
            save_image(tensor=images, fp=os.path.join(target_dir, 'recon.jpg'), pad_value=1)

        self.net_mode(train=True)
    

    '''save embedding z into csv file'''
    def save_z_embedding(self):
        """Encode the batches of ``self.attack_loader`` and write them to a CSV file.

        The file is created in ``self.dset_dir``; its name depends on
        ``self.dataset``. Each row holds the first ``self.z_dim`` encoder
        outputs (column ``z``) and the class, shape and color labels.

        Only element ``[0]`` of every batch is written, so the loader is
        presumably built with batch size 1 -- TODO confirm.

        Raises:
            ValueError: if ``self.dataset`` is neither an attack type nor
                ``'train'``.
        """
        print("write embedding now...")
        self.net_mode(train=False)
        encoder = self.net.encoder

        if self.dataset in ['all', 'elastic', 'g_blur', 'g_noise', 'splatter', 'sticker']:
            save_path = os.path.join(self.dset_dir, f'val_z_label_{self.dataset}.csv')
        elif self.dataset == 'train':
            save_path = os.path.join(self.dset_dir, 'train_z_label.csv')
        else:
            raise ValueError

        headers = ['z','class_label','shape_label', 'color_label']
        # newline='' lets the csv module control line endings itself (otherwise
        # extra blank rows appear on platforms that translate '\n').
        # no_grad: this is inference only, no autograd graph is needed.
        with open(save_path, 'w', encoding='UTF8', newline='') as f_in, torch.no_grad():
            writer = csv.DictWriter(f_in, fieldnames=headers)
            writer.writeheader()
            for x, class_label, shape_label, color_label in self.attack_loader:
                x = x.to(self.device)
                z = encoder(x)[:, :self.z_dim]
                writer.writerow({
                    'z': z.detach().cpu().numpy()[0].tolist(),
                    'class_label': class_label.numpy()[0],
                    'shape_label': shape_label.numpy()[0],
                    'color_label': color_label.numpy()[0],
                })
        print("***---save embedding to file---***")
        

    def viz_traverse(self, limit=3, inter=2/3, loc=-1):
        """Decode latent traversals around a few reference encodings.

        A random test image and dataset-specific fixed test images are
        encoded; for each encoding, one latent dimension at a time is set to
        every value of ``torch.arange(-limit, limit+0.1, inter)`` and decoded.
        The decoded images are shown in visdom when ``self.viz_on`` is set
        and written as per-step JPEGs plus one GIF per reference when
        ``self.save_output`` is set. The network is put in eval mode for the
        duration and back in train mode afterwards.

        Args:
            limit: Traversal values span roughly ``[-limit, limit]``.
            inter: Step between consecutive traversal values.
            loc: Index of the single latent dimension to traverse, or ``-1``
                to traverse all ``self.z_dim`` dimensions.
        """
        self.net_mode(train=False)
        decoder = self.net.decoder
        encoder = self.net.encoder
        interpolation = torch.arange(-limit, limit+0.1, inter)

        dataset = self.test_loader.dataset
        n_dsets = len(dataset)
        rand_idx = random.randint(1, n_dsets-1)
        name = self.dataset.lower()

        # Dataset name -> (output key, dataset index) of the fixed references.
        fixed_refs = {
            'dsprites': [('fixed_square', 87040),
                         ('fixed_ellipse', 332800),
                         ('fixed_heart', 578560)],
            'traffic': [('fixed_deer', 0),
                        ('fixed_stop', 5000),
                        ('fixed_heart', 6000)],
            '3d_shapes': [('fixed_1', 0),
                          ('fixed_2', 45),
                          ('fixed_3', 48000 * 1 + 4800 * 2)],
            '3d_chairs': [('fixed_1', 0),
                          ('fixed_2', 62),
                          ('fixed_3', 1000)],
        }
        # For these two datasets the item itself is the image and is divided
        # by 255; for every other dataset the image is element 0 of the item.
        raw_image_items = name in ('3d_shapes', '3d_chairs')

        def encode(idx):
            """Return the first z_dim encoder outputs for test item `idx`."""
            item = dataset[idx]
            if raw_image_items:
                img = (item / 255.).float()
            else:
                img = item[0]
            img = cuda(img, self.use_cuda).unsqueeze(0)
            return encoder(img)[:, :self.z_dim]

        # Latent dimensions to traverse: all of them, or only `loc`.
        rows = [row for row in range(self.z_dim) if loc == -1 or row == loc]

        gifs = []
        # Inference only: no_grad replaces the removed `volatile=True` flag.
        with torch.no_grad():
            random_img_z = encode(rand_idx)
            # Drawn for every dataset (as before) so torch's RNG stream does
            # not depend on the dataset; only the fallback branch uses it.
            random_z = cuda(torch.rand(1, self.z_dim), self.use_cuda)

            if name in fixed_refs:
                Z = {key: encode(idx) for key, idx in fixed_refs[name]}
                Z['random_img'] = random_img_z
            else:
                Z = {'fixed_img': encode(0),
                     'random_img': random_img_z,
                     'random_z': random_z}

            for key, z_ori in Z.items():
                samples = []
                for row in rows:
                    z = z_ori.clone()
                    for val in interpolation:
                        z[:, row] = val  ## row is the z latent variable
                        sample = torch.sigmoid(decoder(z)).data
                        samples.append(sample)
                        gifs.append(sample)
                samples = torch.cat(samples, dim=0).cpu()
                title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)

                if self.viz_on:
                    self.viz.images(samples, env=self.viz_name+'_traverse',
                                    opts=dict(title=title), nrow=len(interpolation))

        if self.save_output:
            output_dir = os.path.join(self.output_dir, str(self.global_iter))
            os.makedirs(output_dir, exist_ok=True)
            gifs = torch.cat(gifs)

            # (reference, latent dim, step, C, H, W) -> swap dim/step so each
            # saved frame shows all traversed dims at one step.
            # NOTE: the decoder output is assumed to be 64x64 here.
            gifs = gifs.view(len(Z), len(rows), len(interpolation), self.nc, 64, 64).transpose(1, 2)
            for i, key in enumerate(Z):
                for j in range(len(interpolation)):
                    save_image(tensor=gifs[i][j].cpu(),
                               fp=os.path.join(output_dir, '{}_{}.jpg'.format(key, j)),
                               nrow=len(rows), pad_value=1)

                grid2gif(os.path.join(output_dir, key+'*.jpg'),
                         os.path.join(output_dir, key+'.gif'), delay=10)

        self.net_mode(train=True)

    def net_mode(self, train):
        """Switch ``self.net`` to training mode (``True``) or eval mode (``False``).

        Raises:
            TypeError: if ``train`` is not a ``bool``.
        """
        if not isinstance(train, bool):
            # Raising a bare string is itself a TypeError with an unrelated
            # message; raise the intended one explicitly.
            raise TypeError('Only bool type is supported. True or False')

        if train:
            self.net.train()
        else:
            self.net.eval()

    def save_checkpoint(self, filename, silent=True):
        """Serialise iteration counter, visdom window handles, network and optimizer state.

        The checkpoint is written with ``torch.save`` to
        ``<self.ckpt_dir>/<filename>``; a confirmation line is printed unless
        ``silent`` is true.
        """
        net_part = dict(net=self.net.state_dict())
        optim_part = dict(optim=self.optim.state_dict())
        window_part = dict(recon=self.win_recon,
                           beta=self.win_beta,
                           kld=self.win_kld)
        states = dict(iter=self.global_iter,
                      win_states=window_part,
                      model_states=net_part,
                      optim_states=optim_part)

        file_path = os.path.join(self.ckpt_dir, filename)
        with open(file_path, mode='wb+') as handle:
            torch.save(states, handle)

        if silent:
            return
        print("=> saved checkpoint '{}' (iter {})".format(file_path, self.global_iter))


    def load_checkpoint(self, filename):
        """Restore training state from ``<self.ckpt_dir>/<filename>`` if it exists.

        Restores ``self.global_iter``, the visdom window handles and the
        network and optimizer state dicts. When the file does not exist a
        message is printed and nothing is changed.
        """
        file_path = os.path.join(self.ckpt_dir, filename)
        if not os.path.isfile(file_path):
            print("=> no checkpoint found at '{}'".format(file_path))
            return

        # Deserialise onto the CPU so a checkpoint written on a GPU machine can
        # be read on a CPU-only one; load_state_dict then copies the values
        # into the already-placed parameters.
        checkpoint = torch.load(file_path, map_location='cpu')
        self.global_iter = checkpoint['iter']

        win_states = checkpoint['win_states']
        self.win_recon = win_states['recon']
        self.win_kld = win_states['kld']
        # 'beta' is written by save_checkpoint; .get keeps checkpoints that
        # lack the key loadable.
        self.win_beta = win_states.get('beta')

        self.net.load_state_dict(checkpoint['model_states']['net'])
        self.optim.load_state_dict(checkpoint['optim_states']['optim'])
        print("=> loaded checkpoint '{} (iter {})'".format(file_path, self.global_iter))
        