import os
import time
import numpy as np
import shutil
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import scipy.stats
from torchattacks import FGSM

from utils import AverageMeter, print_log
from wrapper import Wrapper
from metric import accuracy
import CVAE

def get_transformed_batch(x, transform='rotation', labels=None, source_model=None, mean=None, std=None):
    """Return three transformed copies of batch ``x`` concatenated block-wise along dim 0.

    For an input of shape (B, C, H, W) the result has shape (3B, C, H, W) and row
    ``k * B + i`` is the k-th transformed view of sample ``i``. Callers that build per-view
    targets must therefore tile the per-sample targets (``t.repeat(3)``), not interleave them.

    Args:
        x: image batch of shape (B, C, H, W).
        transform: one of
            'rotation'    -- x rotated by 90, 180 and 270 degrees in the (H, W) plane
                             (the concatenation only works when H == W);
            'adv_attack'  -- one FGSM (eps=16/255) copy against ``source_model``, followed
                             by two copies with additive standard-normal noise;
            'random_mask' -- three copies, each multiplied by an independent Bernoulli
                             mask that keeps every element with probability 0.7.
        labels: labels passed to the FGSM attack (required for 'adv_attack').
        source_model: model attacked by FGSM (required for 'adv_attack').
        mean, std: normalization given to the attack via ``set_normalization_used``
            (required for 'adv_attack').

    Raises:
        ValueError: if 'adv_attack' is requested without labels/source_model/mean/std.
        NotImplementedError: for any other ``transform`` value.
    """
    if transform == "rotation":
        views = [torch.rot90(x, k, dims=[2, 3]) for k in (1, 2, 3)]
    elif transform == "adv_attack":
        required = {'source_model': source_model, 'mean': mean, 'std': std, 'labels': labels}
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise ValueError("adv_attack transform requires: " + ", ".join(missing))
        attack = FGSM(source_model, eps=16/255)
        attack.set_normalization_used(mean, std)
        # Noise is drawn directly on x's device and in x's dtype.
        views = [attack(x, labels), x + torch.randn_like(x), x + torch.randn_like(x)]
    elif transform == "random_mask":
        # numpy RNG is kept so seeding via np.random still controls the masks.
        views = [x * torch.from_numpy(np.random.binomial(1, 0.7, x.shape)).type(x.dtype).to(x.device)
                 for _ in range(3)]
    else:
        raise NotImplementedError
    return torch.cat(views, dim=0)

def train_ssp_head(train_loader, model, log, num_epoch, lr=0.05, momentum=0.9, weight_decay=5e-4,
                   transform='rotation', mean=None, std=None):
    """Train ``model.proj_head`` on a view-to-sample matching task; the backbone is not updated.

    For each batch, get_transformed_batch builds three transformed views (3B images). The
    feature ``model(.)[1]`` of every view is compared by cosine similarity with the features of
    all B original images. Cross-entropy over those B similarities pushes each view towards
    the sample it was generated from.

    Args:
        train_loader: iterable of (images, labels) batches; both are moved to CUDA here.
        model: a Wrapper exposing ``backbone`` and ``proj_head``.
        log: log handle forwarded to print_log.
        num_epoch: number of passes over ``train_loader``.
        lr, momentum, weight_decay: SGD hyper-parameters. The backbone's parameter group
            uses lr=0.0, so SGD never changes it; only proj_head is trained.
        transform, mean, std: forwarded to get_transformed_batch (mean/std are only
            required for 'adv_attack').
    """
    assert isinstance(model, Wrapper)  # model must be a wrapper
    optimizer = optim.SGD([{'params': model.backbone.parameters(), 'lr': 0.0},
                           {'params': model.proj_head.parameters(), 'lr': lr}],
                          momentum=momentum, weight_decay=weight_decay)

    model.backbone.eval()
    model.proj_head.train()
    for epoch in range(num_epoch):
        loss_record = AverageMeter()
        acc_record = AverageMeter()
        batch_time = AverageMeter()

        start = time.time()
        for x, y in train_loader:
            optimizer.zero_grad()

            x = x.cuda()
            y = y.cuda()
            transformed_x = get_transformed_batch(x, transform=transform, labels=y,
                                                  source_model=model.backbone, mean=mean, std=std)
            transformed_x = transformed_x.cuda()  # (3B, C, H, W)

            b = x.size(0)
            nor_rep = model(x)[1]              # (B, C')
            aug_rep = model(transformed_x)[1]  # (3B, C')

            nor_rep = nor_rep.unsqueeze(2).expand(-1, -1, 3 * b).transpose(0, 2)  # (3B, C', B)
            aug_rep = aug_rep.unsqueeze(2).expand(-1, -1, b)                      # (3B, C', B)
            simi = F.cosine_similarity(aug_rep, nor_rep, dim=1)                   # (3B, B)
            # get_transformed_batch concatenates the three views block-wise, so row k*B + i
            # is a view of sample i: the target is 0..B-1 tiled three times, not interleaved.
            target = torch.arange(b, device=simi.device).repeat(3)  # (3B)
            loss = F.cross_entropy(simi, target)

            loss.backward()
            optimizer.step()

            batch_acc = accuracy(simi, target, topk=(1,))[0]
            loss_record.update(loss.item(), 3 * b)
            acc_record.update(batch_acc.item(), 3 * b)
            batch_time.update(time.time() - start)
            start = time.time()

        print_log('teacher_ssp_train_epoch: {:}/{:}\t batch_time: {:.3f}\t ssp_loss: {:.3f}\t ssp_acc: {:.3f}\t'.format(
            epoch, num_epoch, batch_time.avg, loss_record.avg, acc_record.avg), log)
        
def _register_conv_hook(backbone, conv_idx, hook):
    """Register ``hook`` as a forward hook on the ``conv_idx``-th nn.Conv2d of ``backbone``.

    Conv2d layers are counted from 0 in ``backbone.modules()`` order. Returns the hook
    handle, or None when ``conv_idx`` matches no Conv2d (including ``conv_idx is None``).
    """
    convs = (m for m in backbone.modules() if isinstance(m, nn.Conv2d))
    for layer_idx, conv in enumerate(convs):
        if layer_idx == conv_idx:
            return conv.register_forward_hook(hook)
    return None


def _pairwise_cosine(aug_feat, nor_feat):
    """Return the (N_aug, N_nor) matrix of cosine similarities between rows of the two inputs."""
    n_aug, n_nor = aug_feat.size(0), nor_feat.size(0)
    nor = nor_feat.unsqueeze(2).expand(-1, -1, n_aug).transpose(0, 2)  # (N_aug, C', N_nor)
    aug = aug_feat.unsqueeze(2).expand(-1, -1, n_nor)                   # (N_aug, C', N_nor)
    return F.cosine_similarity(aug, nor, dim=1)


def _distill_index(scores, labels, wrong_keep_ratio):
    """Select rows of ``scores`` to distill from, returned in ascending row order.

    A row is "correct" when ``labels[row]`` has the highest score in that row. All correct
    rows are kept, plus ``int(#wrong * wrong_keep_ratio)`` wrong rows, preferring those
    whose label is ranked closest to the top.
    """
    order = torch.argsort(scores, dim=1, descending=True)
    label_rank = torch.argmax(torch.eq(order, labels.unsqueeze(1)).long(), dim=1)  # 0 == top-1
    wrong_num = torch.nonzero(label_rank, as_tuple=True)[0].numel()
    keep = label_rank.numel() - wrong_num + int(wrong_num * wrong_keep_ratio)
    return torch.sort(torch.argsort(label_rank)[:keep])[0]


def _evaluate_student(s_model, val_loader):
    """Return (top-1 accuracy, mean cross-entropy) of ``s_model`` over ``val_loader``.

    Evaluation runs in eval mode under torch.no_grad(). The model is put back into train
    mode afterwards, because the caller keeps training it.
    """
    acc_record = AverageMeter()
    loss_record = AverageMeter()
    s_model.eval()
    try:
        with torch.no_grad():
            for x, target in val_loader:
                x = x.cuda()
                target = target.cuda()
                output, _, _ = s_model(x)
                loss = F.cross_entropy(output, target)
                batch_acc = accuracy(output, target, topk=(1,))[0]
                loss_record.update(loss.item(), x.size(0))
                acc_record.update(batch_acc.item(), x.size(0))
    finally:
        s_model.train()
    return acc_record.avg, loss_record.avg


def train_knowledge_distillation(train_loader, val_loader, pruning_manager, log, num_epoch, 
                                 lr=0.05, momentum=0.9, weight_decay=5e-4, kd_T=4.0, tf_T=4.0, ss_T=0.5, 
                                 ce_weight=0.1, kd_weight=0.9, tf_weight=2.7, ss_weight=10.0, 
                                 reg_weight=0.0, reg_layer_idx=None, 
                                 arch=None, prefix=None,
                                 transform='rotation', mean=None, std=None):
    """Distill the unpruned teacher into the pruned student with an SSKD-style objective.

    The teacher is ``pruning_manager.original_model`` and the student is
    ``pruning_manager.model``. Both are Wrappers whose forward returns (logits, feature, _).
    Each step minimizes

        ce_weight  * CE(student logits, labels)
      + kd_weight  * KL(teacher || student) on the original images      (temperature kd_T)
      + tf_weight  * KL(teacher || student) on the transformed images   (temperature tf_T)
      + ss_weight  * KL(teacher || student) over view-to-sample cosine
                     similarity distributions                           (temperature ss_T)
      + reg_weight * Var(d sum(true-class logits) / d hooked feature)   (only if reg_weight > 0)

    For the transformed-image and similarity terms, rows where the teacher is wrong are
    partly dropped. All of them are kept for the tf term; 75% of them (best ranked first)
    are kept for the ss term. ``pruning_manager.do_grad_mask()`` runs between backward and
    the optimizer step. The LR follows MultiStepLR (milestones 150/180/210, gamma 0.1),
    stepped once per epoch. The student is validated in eval mode every 10 epochs.

    Args:
        train_loader, val_loader: iterables of (images, labels) batches; moved to CUDA here.
        pruning_manager: provides ``original_model``, ``model`` and ``do_grad_mask()``.
        log: log handle forwarded to print_log.
        num_epoch: number of training epochs.
        lr, momentum, weight_decay: SGD hyper-parameters for the student.
        kd_T, tf_T, ss_T: softmax temperatures of the three KL terms.
        ce_weight, kd_weight, tf_weight, ss_weight, reg_weight: loss weights. Gradients are
            clipped to norm 10 when ``reg_weight >= 10000``.
        reg_layer_idx: index (from 0, in ``backbone.modules()`` order) of the student Conv2d
            whose output is hooked for the gradient-variance penalty.
        arch, prefix: currently unused; kept for interface compatibility.
        transform, mean, std: forwarded to get_transformed_batch.
    """
    t_model = pruning_manager.original_model
    s_model = pruning_manager.model
    assert isinstance(t_model, Wrapper)
    assert isinstance(s_model, Wrapper)

    optimizer = optim.SGD(s_model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    scheduler = MultiStepLR(optimizer, milestones=[150, 180, 210], gamma=0.1)
    val_freq = 10

    # Output of the hooked student conv layer from the most recent student forward pass.
    inter_feat = None

    def hook(module, feat_in, feat_out):
        nonlocal inter_feat
        inter_feat = feat_out

    # Registered once for the whole run (not once per batch, which piled up hooks) and
    # removed when training ends.
    handle = _register_conv_hook(s_model.backbone, reg_layer_idx, hook)

    t_model.eval()
    s_model.train()
    try:
        for epoch in range(num_epoch):
            loss1_record = AverageMeter()
            loss2_record = AverageMeter()
            loss3_record = AverageMeter()
            loss4_record = AverageMeter()
            reg_loss_record = AverageMeter()
            cls_acc_record = AverageMeter()
            ssp_acc_record = AverageMeter()
            batch_time = AverageMeter()

            start = time.time()
            for x, target in train_loader:
                optimizer.zero_grad()

                x = x.cuda()
                target = target.cuda()
                transformed_x = get_transformed_batch(x, transform=transform, labels=target,
                                                      source_model=t_model.backbone,
                                                      mean=mean, std=std)
                transformed_x = transformed_x.cuda()  # (3B, C, H, W)
                b = x.size(0)

                inter_feat = None  # so the check below only sees this batch's feature
                nor_output, s_nor_feat, _ = s_model(x)
                if reg_weight > 0.0:
                    assert inter_feat is not None, "reg_weight > 0 requires reg_layer_idx to select a Conv2d"
                nor_inter_feat = inter_feat
                aug_output, s_aug_feat, _ = s_model(transformed_x)

                log_nor_output = F.log_softmax(nor_output / kd_T, dim=1)
                log_aug_output = F.log_softmax(aug_output / tf_T, dim=1)
                with torch.no_grad():
                    t_nor_output, t_nor_feat, _ = t_model(x)
                    t_aug_output, t_aug_feat, _ = t_model(transformed_x)
                    nor_knowledge = F.softmax(t_nor_output / kd_T, dim=1)  # (B, num_classes)
                    aug_knowledge = F.softmax(t_aug_output / tf_T, dim=1)  # (3B, num_classes)

                # Row k*B + i of transformed_x is a view of sample i (get_transformed_batch
                # concatenates views block-wise), so per-sample targets are tiled with repeat.
                distill_index_tf = _distill_index(aug_knowledge, target.repeat(3), 1.0)  # ratio_tf

                s_simi = _pairwise_cosine(s_aug_feat, s_nor_feat)           # (3B, B)
                t_simi = _pairwise_cosine(t_aug_feat, t_nor_feat).detach()  # (3B, B)
                aug_target = torch.arange(b, device=s_simi.device).repeat(3)  # source sample of each view
                distill_index_ss = _distill_index(t_simi, aug_target, 0.75)  # ratio_ss

                log_simi = F.log_softmax(s_simi / ss_T, dim=1)
                simi_knowledge = F.softmax(t_simi / ss_T, dim=1)

                loss1 = F.cross_entropy(nor_output, target)
                loss2 = F.kl_div(log_nor_output, nor_knowledge, reduction='batchmean') * kd_T * kd_T
                loss3 = F.kl_div(log_aug_output[distill_index_tf], aug_knowledge[distill_index_tf],
                                 reduction='batchmean') * tf_T * tf_T
                loss4 = F.kl_div(log_simi[distill_index_ss], simi_knowledge[distill_index_ss],
                                 reduction='batchmean') * ss_T * ss_T
                loss = ce_weight * loss1 + kd_weight * loss2 + tf_weight * loss3 + ss_weight * loss4

                if reg_weight > 0:
                    # Gradient of the summed true-class logits w.r.t. the hooked feature map.
                    # create_graph=True lets the variance penalty itself be back-propagated.
                    one_hot_target = F.one_hot(target, num_classes=nor_output.size(1))
                    nor_grad = torch.autograd.grad(torch.sum(nor_output * one_hot_target), nor_inter_feat,
                                                   retain_graph=True, create_graph=True)[0]
                    grad_var = torch.var(nor_grad)
                    reg_loss_record.update(grad_var.item())
                    loss = loss + reg_weight * grad_var

                loss.backward()
                if reg_weight >= 10000:
                    # only very large reg_weight values trigger gradient clipping
                    torch.nn.utils.clip_grad_norm_(s_model.parameters(), 10.0)
                pruning_manager.do_grad_mask()
                optimizer.step()

                cls_batch_acc = accuracy(nor_output, target, topk=(1,))[0]
                ssp_batch_acc = accuracy(s_simi, aug_target, topk=(1,))[0]
                loss1_record.update(loss1.item(), b)
                loss2_record.update(loss2.item(), b)
                loss3_record.update(loss3.item(), len(distill_index_tf))
                loss4_record.update(loss4.item(), len(distill_index_ss))
                cls_acc_record.update(cls_batch_acc.item(), b)
                ssp_acc_record.update(ssp_batch_acc.item(), 3 * b)
                batch_time.update(time.time() - start)
                start = time.time()

            if handle is None:
                print_log('student_train_epoch:{:}/{:}\t run_time:{:.3f}\t ce_loss:{:.3f}\t kd_loss:{:.3f}\t tf_loss:{:.3f}\t ss_loss:{:.3f}\t cls_acc:{:.3f}\t ssp_acc:{:.3f}'.format(
                    epoch, num_epoch, batch_time.avg, loss1_record.avg, loss2_record.avg, loss3_record.avg, loss4_record.avg, cls_acc_record.avg, ssp_acc_record.avg), log)
            else:
                print_log('student_train_epoch:{:}/{:}\t run_time:{:.3f}\t ce_loss:{:.3f}\t kd_loss:{:.3f}\t tf_loss:{:.3f}\t ss_loss:{:.3f}\t reg_loss:{}\t cls_acc:{:.3f}\t ssp_acc:{:.3f}'.format(
                    epoch, num_epoch, batch_time.avg, loss1_record.avg, loss2_record.avg, loss3_record.avg, loss4_record.avg, reg_loss_record.avg, cls_acc_record.avg, ssp_acc_record.avg), log)

            if epoch % val_freq == 0:
                val_acc, val_loss = _evaluate_student(s_model, val_loader)
                print_log('student_test_Epoch:{:}/{:}\t cls_acc:{:.3f}\t ce_loss:{:.3f}'.format(
                    epoch, num_epoch, val_acc, val_loss), log)

            scheduler.step()
    finally:
        if handle is not None:
            handle.remove()

def kl_loss_fn(z, mean, std):
    """Single-sample Monte-Carlo estimate of KL(N(mean, std**2) || N(0, 1)).

    Evaluates log q(z) - log p(z), where q is the diagonal Gaussian with the given
    ``mean``/``std`` and p is the standard normal. The result is summed over the last
    dimension and averaged over all leading ones. ``z`` is presumably a sample drawn
    from q (TODO confirm against the CVAE caller).
    """
    standard_normal = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std))
    posterior = torch.distributions.Normal(mean, std)
    log_ratio = posterior.log_prob(z) - standard_normal.log_prob(z)
    return log_ratio.sum(-1).mean()

def gaussian_likelihood(inputs, outputs):
    """Log-likelihood of ``inputs`` under a unit-variance Gaussian centred at ``outputs``.

    Per-element log densities are summed over the last dimension and averaged over the rest.
    """
    log_density = torch.distributions.Normal(outputs, 1.0).log_prob(inputs)
    return log_density.sum(dim=-1).mean()

def train_cvae(train_loader, cvae_model, t_model, log, latent_len=1024, num_epoch=600, lr=1e-3, clip=None,
               img_ch = 3, img_h = 128, img_w = 128,
               normalize_mean = (0.485, 0.456, 0.406), normalize_std = (0.229, 0.224, 0.225), figure_name = 'cvae.png'):
    """Train a CVAE.CVAE to reconstruct images conditioned on the teacher's predicted class.

    Each batch is un-normalized with ``normalize_mean``/``normalize_std`` and resized to
    (img_h, img_w). It is then reconstructed by ``cvae_model``, conditioned on the one-hot
    (1000-way) argmax of ``t_model``'s logits. The loss is
    MSE(flattened image, reconstruction) + kl_loss_fn(z, z_mu, z_sigma).

    After epochs 0, 9, 19, ..., 99 and then every 100th epoch, a 3-row figure is written to
    ``./figure/epoch{epoch}_{figure_name}``. The rows are originals, reconstructions, and
    decodings of N(0, I) latents with the ground-truth labels. The directory is created if
    missing.

    Autograd anomaly detection is enabled only for the duration of this call; the previous
    global setting is restored on exit.

    Args:
        train_loader: iterable of (normalized images, labels) batches.
        cvae_model: the CVAE.CVAE being trained.
        t_model: teacher Wrapper used (in eval mode, without gradients) for the condition.
        log: log handle forwarded to print_log.
        latent_len: latent size used when sampling for the figure.
        num_epoch, lr: number of epochs and Adam learning rate.
        clip: optional max gradient norm.
        img_ch, img_h, img_w: decoder output shape used to reshape reconstructions/samples.
        normalize_mean, normalize_std: per-channel normalization to undo.
        figure_name: suffix of the saved figure files.
    """
    assert isinstance(t_model, Wrapper)
    assert isinstance(cvae_model, CVAE.CVAE)
    optimizer = optim.Adam(cvae_model.parameters(), lr=lr)
    mse_loss_fn = nn.MSELoss(reduction='mean')
    resize = transforms.Resize((img_h, img_w))

    def denormalize(images):
        # Undo per-channel normalization in place: x * std + mean.
        for i in range(len(normalize_mean)):
            images[:, i, :, :] = images[:, i, :, :] * normalize_std[i] + normalize_mean[i]
        return images

    def to_hwc(images):
        # (N, C, H, W) tensor -> (N, H, W, C) numpy array clipped to [0, 1] for imshow.
        return np.transpose(np.clip(images.detach().cpu().numpy(), 0, 1), (0, 2, 3, 1))

    cvae_model.train()
    t_model.eval()
    with torch.autograd.set_detect_anomaly(True):
        for epoch in range(num_epoch):
            cvae_model.train()
            ce_loss_record = AverageMeter()
            kl_loss_record = AverageMeter()
            batch_time = AverageMeter()

            start = time.time()
            for x, _ in train_loader:
                optimizer.zero_grad()

                x = x.cuda()
                denormalize_x = resize(denormalize(x.clone()))
                x_train_vector = denormalize_x.clone().reshape((denormalize_x.shape[0], -1))
                with torch.no_grad():  # the teacher only supplies the condition label
                    target_pred, _, _ = t_model(x)
                cond_vector = F.one_hot(torch.argmax(target_pred, dim=1), num_classes=1000).float().cuda()

                x_decoded, z_mu, z_sigma, z = cvae_model(denormalize_x, cond_vector)
                ce_loss = mse_loss_fn(x_train_vector, x_decoded)
                kl_loss = kl_loss_fn(z, z_mu, z_sigma)

                loss = ce_loss + kl_loss
                loss.backward()
                if clip is not None:
                    torch.nn.utils.clip_grad_norm_(cvae_model.parameters(), clip)
                optimizer.step()

                ce_loss_record.update(ce_loss.item())
                kl_loss_record.update(kl_loss.item())
                batch_time.update(time.time() - start)
                start = time.time()

            print_log('cvae_train_epoch:{:}/{:}\t run_time:{:.3f}\t ce_loss:{:.3f}\t kl_loss:{:.3f}'.format(
                epoch, num_epoch, batch_time.avg, ce_loss_record.avg, kl_loss_record.avg), log)

            if (epoch + 1) % 100 == 0 or epoch == 0 or ((epoch + 1) % 10 == 0 and epoch < 100):
                # evaluation and visualization
                cvae_model.eval()
                # The last batch is never picked (presumably because it may be short),
                # unless it is the only one.
                sample_batch_idx = random.randint(0, max(0, len(train_loader) - 2))
                sample_batch, sample_target = None, None
                for idx, (batch, batch_target) in enumerate(train_loader):
                    if idx == sample_batch_idx:
                        sample_batch, sample_target = batch, batch_target
                        break
                if sample_batch is None:  # empty loader: nothing to visualize
                    continue

                with torch.no_grad():
                    sample_x = sample_batch[:10].clone().cuda()
                    n_show = sample_x.size(0)
                    sample_pred, _, _ = t_model(sample_x)
                    sample_x = resize(denormalize(sample_x))
                    sample_ycond = F.one_hot(torch.argmax(sample_pred, dim=1), num_classes=1000).float().cuda()
                    sample_x_decoded = cvae_model(sample_x, sample_ycond)[0]

                    # z ~ N(0, I) via the inverse normal CDF, decoded with the ground-truth labels.
                    z_random = torch.from_numpy(scipy.stats.norm.ppf(np.random.rand(n_show, latent_len))).float()
                    sample_y_vector = F.one_hot(sample_target[:n_show], num_classes=1000).float()
                    sample_x_encoded = torch.cat((z_random, sample_y_vector), dim=1).cuda()  # (n, latent_len + 1000)
                    sample_x_generated = cvae_model.decoder(sample_x_encoded)

                rows = (to_hwc(sample_x),
                        to_hwc(sample_x_decoded.reshape((n_show, img_ch, img_h, img_w))),
                        to_hwc(sample_x_generated.reshape((n_show, img_ch, img_h, img_w))))

                os.makedirs("./figure", exist_ok=True)
                plt.figure(figsize=(20, 6))
                # row 0: originals, row 1: reconstructions, row 2: generated samples
                for row_idx, images in enumerate(rows):
                    for idx in range(n_show):
                        plt.subplot(3, 10, row_idx * 10 + idx + 1)
                        plt.axis('off')
                        plt.imshow(images[idx])
                plt.suptitle("first row: original images, second row: reconstructed images, third row: generated images")
                plt.savefig(f"./figure/epoch{epoch}_{figure_name}")
                plt.close()

def train_cvae_2(train_loader, cvae_model, t_model, log, 
                 latent_len=1024, num_epoch=600, lr=1e-3, clip=None,
               img_ch = 3, img_h = 128, img_w = 128, use_gt_labels=False,
               normalize_mean = (0.485, 0.456, 0.406), normalize_std = (0.229, 0.224, 0.225), figure_name = 'cvae.png'):
    """Train a CVAE.ConditionalVAE on un-normalized, resized images.

    Each batch is un-normalized with ``normalize_mean``/``normalize_std`` and resized to
    (img_h, img_w). The condition is the one-hot (1000-way) ground-truth label when
    ``use_gt_labels`` is True, otherwise the argmax of ``t_model``'s logits. The loss is
    ``cvae_model.loss_function(..., M_N=0.0010)['loss']``.

    After epochs 0, 9, 19, ..., 99 and then every 100th epoch:
      - the reconstruction MSE of a sample batch is logged. Reconstructions are conditioned
        on teacher predictions.
      - a 3-row figure is written to ``./figure/epoch{epoch}_{figure_name}``. The rows are
        originals, reconstructions, and decodings of N(0, I) latents with the ground-truth
        labels. The directory is created if missing.

    Autograd anomaly detection is enabled only for the duration of this call; the previous
    global setting is restored on exit.

    Args:
        train_loader: iterable of (normalized images, labels) batches.
        cvae_model: the CVAE.ConditionalVAE being trained.
        t_model: teacher Wrapper used (in eval mode, without gradients) for conditions.
        log: log handle forwarded to print_log.
        latent_len: latent size used when sampling for the figure.
        num_epoch, lr: number of epochs and Adam learning rate.
        clip: optional max gradient norm.
        img_ch, img_h, img_w: target image size (img_h, img_w used for resizing).
        use_gt_labels: condition training on ground-truth labels instead of teacher predictions.
        normalize_mean, normalize_std: per-channel normalization to undo.
        figure_name: suffix of the saved figure files.
    """
    assert isinstance(t_model, Wrapper)
    assert isinstance(cvae_model, CVAE.ConditionalVAE)
    optimizer = optim.Adam(cvae_model.parameters(), lr=lr)
    resize = transforms.Resize((img_h, img_w))

    def denormalize(images):
        # Undo per-channel normalization in place: x * std + mean.
        for i in range(len(normalize_mean)):
            images[:, i, :, :] = images[:, i, :, :] * normalize_std[i] + normalize_mean[i]
        return images

    def to_hwc(images):
        # (N, C, H, W) tensor -> (N, H, W, C) numpy array clipped to [0, 1] for imshow.
        return np.transpose(np.clip(images.detach().cpu().numpy(), 0, 1), (0, 2, 3, 1))

    cvae_model.train()
    t_model.eval()
    with torch.autograd.set_detect_anomaly(True):
        for epoch in range(num_epoch):
            cvae_model.train()
            rec_loss_record = AverageMeter()
            kl_loss_record = AverageMeter()
            batch_time = AverageMeter()

            start = time.time()
            for x, target in train_loader:
                optimizer.zero_grad()

                x = x.cuda()
                denormalize_x = resize(denormalize(x.clone()))
                if use_gt_labels:
                    cond_label = target.cuda()
                else:
                    with torch.no_grad():  # the teacher only supplies the condition label
                        target_pred, _, _ = t_model(x)
                    cond_label = torch.argmax(target_pred, dim=1)
                cond_vector = F.one_hot(cond_label, num_classes=1000).float().cuda()

                x_decoded, x_original, z_mu, z_log_var = cvae_model(denormalize_x, labels=cond_vector)
                losses = cvae_model.loss_function(x_decoded, x_original, z_mu, z_log_var, M_N=0.0010)
                loss = losses['loss']
                loss.backward()
                if clip is not None:
                    torch.nn.utils.clip_grad_norm_(cvae_model.parameters(), clip)
                optimizer.step()

                rec_loss_record.update(losses['Reconstruction_Loss'].item())
                kl_loss_record.update(losses['KLD'].item())
                batch_time.update(time.time() - start)
                start = time.time()

            print_log('cvae_train_epoch:{:}/{:}\t run_time:{:.3f}\t rec_loss:{:.3f}\t kl_loss:{:.3f}'.format(
                epoch, num_epoch, batch_time.avg, rec_loss_record.avg, kl_loss_record.avg), log)

            if (epoch + 1) % 100 == 0 or epoch == 0 or ((epoch + 1) % 10 == 0 and epoch < 100):
                # evaluation and visualization
                cvae_model.eval()
                # The last batch is never picked (presumably because it may be short),
                # unless it is the only one.
                sample_batch_idx = random.randint(0, max(0, len(train_loader) - 2))
                sample_batch, sample_target = None, None
                for idx, (batch, batch_target) in enumerate(train_loader):
                    if idx == sample_batch_idx:
                        sample_batch, sample_target = batch, batch_target
                        break
                if sample_batch is None:  # empty loader: nothing to visualize
                    continue

                with torch.no_grad():
                    ## reconstruction (conditioned on the teacher's predictions)
                    sample_x = sample_batch[:10].clone().cuda()
                    n_show = sample_x.size(0)
                    sample_pred, _, _ = t_model(sample_x)
                    sample_x = resize(denormalize(sample_x))
                    sample_ycond = F.one_hot(torch.argmax(sample_pred, dim=1), num_classes=1000).float().cuda()
                    sample_x_decoded = cvae_model(sample_x, labels=sample_ycond)[0]
                    rec_mse = F.mse_loss(sample_x, sample_x_decoded).item()

                    ## generation: z ~ N(0, I) via the inverse normal CDF, ground-truth labels
                    z_random = torch.from_numpy(scipy.stats.norm.ppf(np.random.rand(n_show, latent_len))).float().cuda()
                    sample_y_vector = F.one_hot(sample_target[:n_show], num_classes=1000).float().cuda()
                    sample_x_encoded = torch.cat((z_random, sample_y_vector), dim=1)
                    sample_x_generated = cvae_model.decode(sample_x_encoded)

                print_log('cvae_eval_epoch:{:}/{:}\t sample_rec_mse:{:.6f}'.format(
                    epoch, num_epoch, rec_mse), log)

                rows = (to_hwc(sample_x), to_hwc(sample_x_decoded), to_hwc(sample_x_generated))

                os.makedirs("./figure", exist_ok=True)
                plt.figure(figsize=(20, 6))
                # row 0: originals, row 1: reconstructions, row 2: generated samples
                for row_idx, images in enumerate(rows):
                    for idx in range(n_show):
                        plt.subplot(3, 10, row_idx * 10 + idx + 1)
                        plt.axis('off')
                        plt.imshow(images[idx])
                plt.suptitle("first row: original images, second row: reconstructed images, third row: generated images")
                plt.savefig(f"./figure/epoch{epoch}_{figure_name}")
                plt.close()