# 00 -> MIG-GPU-0b63aaba-6ele-a51f-97e5-c8bb77699c04/1/0
# 01 -> MIG-GPU-0b63aaba-6ele-a51f-97e5-c8bb77699c04/2/0
# 10 -> MIG-GPU-1c0365e0-78b1-1672-907a-68efbe86c467/1/0
# 11 -> MIG-GPU-1c0365e0-78b1-1672-907a-68efbe86c467/2/0
# 20 -> MIG-GPU-d0cce0e7-c51a-ba4c-2728-7b6d73beacc8/1/0
# 21 -> MIG-GPU-d0cce0e7-c51a-ba4c-2728-7b6d73beacc8/2/0
# 40 -> MIG-GPU-612f1f55-f57a-e899-0561-2ded7ff24ee7/1/0
# 41 -> MIG-GPU-612f1f55-f57a-e899-0561-2ded7ff24ee7/2/0

import os
# Pin this process to one MIG slice (alias "00" in the table at the top of the file) unless the
# caller already selected devices. CUDA reads this variable when it initializes, so it is set
# before torch is imported.
# NOTE(review): '6ele' is not valid hex for a GPU UUID (possibly '6e1e'); confirm with `nvidia-smi -L`.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', "MIG-GPU-0b63aaba-6ele-a51f-97e5-c8bb77699c04/1/0")

import argparse
import os

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets  # , transforms
from torchvision import transforms as tv_transforms
from torchvision.utils import save_image
from tqdm import tqdm
import torchvision
from oodd.datasets import transforms
import numpy as np
import matplotlib.pyplot as plt
import random
from sklearn.manifold import TSNE
from model import ConvVAE
from lstm_mdn import MixtureDensityRNN as qz_rnn

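""" Train a convolutional (Sigma-)VAE on an in-distribution (ID) dataset and analyze how its ELBO, KL
and latent codes differ between ID and out-of-distribution (OOD) test data. The training loop was adapted from:
https://github.com/pytorch/examples/blob/master/vae/main.py """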

## Arguments
def _str2bool(value):
    """Convert a command-line string to a bool for argparse.

    ``type=bool`` must not be used: ``bool('False')`` is ``True`` because every non-empty string
    is truthy, so ``--flag False`` would silently turn the flag on.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='VAE MNIST Example')
parser.add_argument('--analyze_mode', type=_str2bool, default=True, help='whether analyze the model.')
parser.add_argument('--add_noise', type=_str2bool, default=False,
                    help='add standard-normal noise to the inputs')
parser.add_argument('--use_entropy', type=_str2bool, default=False, help='verify the entropy')
parser.add_argument('--kl_q_q', type=_str2bool, default=False, help='verify the q_id(z)')
parser.add_argument('--qq_sigma', type=float, default=0.01, help='')
parser.add_argument('--neigh', type=_str2bool, default=False, help='')
parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--id', type=str, default='CIFAR', choices=['FMNIST', 'MNIST', 'CIFAR'],
                    help='in-distribution dataset: FMNIST, MNIST, CIFAR')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--z_dim', type=int, default=200, metavar='N',
                    help='dimension of the latent variable')
parser.add_argument('--beta', type=float, default=1.0,
                    help='beta * KL')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model', type=str, default='discrete_logistic', metavar='N',
                    help='which model to use: bce_vae, mse_vae,  gaussian_vae, or sigma_vae or optimal_sigma_vae, discrete_logistic')
parser.add_argument('--log_dir', type=str, default='./results', metavar='N')
args = parser.parse_args()

# The analysis pass inspects one image per batch, so analyze mode always uses single-sample batches.
args.batch_size = 1 if args.analyze_mode else args.batch_size

## Cuda
if args.no_cuda:
    args.cuda = False
else:
    args.cuda = torch.cuda.is_available()
device = torch.device('cuda') if args.cuda else torch.device('cpu')

## Dataset
# Every dataset is resized to 28x28; the image-saving code below reshapes reconstructions and
# samples to (-1, 28, 28).
TRANSFORM_BINARIZE = tv_transforms.Compose(
    [
        tv_transforms.Resize((28, 28)),
        tv_transforms.ToTensor(),
        transforms.Binarize(resample=True),
    ]
)
TRANSFORM_non_BINARIZE = tv_transforms.Compose(
    [
        tv_transforms.Resize((28, 28)),
        tv_transforms.ToTensor(),
    ]
)
# Transform used for all datasets below; swap in TRANSFORM_BINARIZE for binarized inputs.
TRANSFORM_ = TRANSFORM_non_BINARIZE

# Each in-distribution dataset (args.id) is paired with a fixed OOD test set.
if args.id == 'FMNIST':
    train_dataset = datasets.FashionMNIST('./data', train=True, download=True,
                                          transform=TRANSFORM_)
    test_dataset = datasets.FashionMNIST('./data', train=False,
                                         transform=TRANSFORM_)
    ood_test_dataset = datasets.MNIST('./data', train=False, download=True,
                                      transform=TRANSFORM_)
elif args.id == 'MNIST':
    train_dataset = datasets.MNIST('./data', train=True, download=True,
                                   transform=TRANSFORM_)
    test_dataset = datasets.MNIST('./data', train=False,
                                  transform=TRANSFORM_)
    ood_test_dataset = datasets.FashionMNIST('./data', train=False, download=True,
                                             transform=TRANSFORM_)
elif args.id == 'CIFAR':
    train_dataset = datasets.CIFAR10('../../data', train=True, transform=TRANSFORM_, download=True)
    test_dataset = datasets.CIFAR10('../../data', train=False, transform=TRANSFORM_, download=True)
    ood_test_dataset = datasets.SVHN('../../data', split='test', transform=TRANSFORM_, download=True)
else:
    # Fail here with a clear message instead of a NameError when the loaders are built.
    raise ValueError(f"Unsupported --id {args.id!r}: expected 'FMNIST', 'MNIST' or 'CIFAR'.")
# --- data loading --- #


# Worker processes and pinned host memory are only requested when training on the GPU.
if args.cuda:
    kwargs = {'num_workers': 10, 'pin_memory': True}
else:
    kwargs = {}
# All three loaders share batch size and shuffling; analyze_z relies on the shuffle to draw a
# random subset of each split.
train_loader, test_loader, ood_test_loader = (
    torch.utils.data.DataLoader(split, batch_size=args.batch_size, shuffle=True, **kwargs)
    for split in (train_dataset, test_dataset, ood_test_dataset)
)

## Logging
os.makedirs(f'vae_logs/{args.log_dir}_id_{args.id}', exist_ok=True)
summary_writer = SummaryWriter(log_dir='vae_logs/' + args.log_dir, purge_step=0)

## Build Model
# Image channels per dataset id: grayscale (Fashion-)MNIST, RGB CIFAR/SVHN.
_IN_CHANNELS = {'FMNIST': 1, 'MNIST': 1, 'CIFAR': 3, 'SVHN': 3}
if args.id not in _IN_CHANNELS:
    # Without this check `model` would be left undefined and fail later with a NameError.
    raise ValueError(f'No ConvVAE configuration for --id {args.id!r}; '
                     f'expected one of {sorted(_IN_CHANNELS)}.')
model = ConvVAE(device, _IN_CHANNELS[args.id], args).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def train(epoch):
    """Run one training epoch over ``train_loader`` and log its statistics.

    The optimized objective is ``rec + args.beta * kl`` from ``model.loss_function``. Progress is
    printed every ``args.log_interval`` batches. At the end of the epoch, the per-sample averages
    over the whole epoch of the loss, reconstruction term and KL term are printed and written
    to TensorBoard.

    Args:
        epoch: 1-based epoch index, used for printing and as the TensorBoard step.
    """
    model.train()
    train_loss = 0.
    train_rec = 0.
    train_kl = 0.
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        if args.add_noise:
            # Corrupt the inputs with additive standard-normal noise.
            data = torch.randn_like(data) + data
        optimizer.zero_grad()

        # Run VAE
        recon_batch, mu, logvar = model(data)
        # Compute loss; both terms are presumably summed over the batch, since they are divided
        # by the batch/dataset size for reporting.
        rec, kl = model.loss_function(recon_batch, data, mu, logvar)

        total_loss = rec + args.beta * kl
        total_loss.backward()
        train_loss += total_loss.item()
        train_rec += rec.item()
        train_kl += kl.item()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            # float() works whether log_sigma is a Python float or a 0-dim tensor/Parameter
            # (format specs like ':f' are rejected by nn.Parameter.__format__).
            print('ID {} Train Epoch: {} [{}/{} ({:.0f}%)]\trec: {:.6f}\tKL: {:.6f}\tlog_sigma: {:f} z_dim: {} beta:{}'.format(
                args.id, epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                rec.item() / len(data),
                kl.item() / len(data),
                float(model.log_sigma), args.z_dim, args.beta))

    n_samples = len(train_loader.dataset)
    train_loss /= n_samples
    print('====> ID {} Epoch: {} Average loss: {:.4f}'.format(args.id,
        epoch, train_loss))
    summary_writer.add_scalar('train/elbo', train_loss, epoch)
    # Epoch averages rather than the noisy value of the last batch.
    summary_writer.add_scalar('train/rec', train_rec / n_samples, epoch)
    summary_writer.add_scalar('train/kld', train_kl / n_samples, epoch)
    summary_writer.add_scalar('train/log_sigma', float(model.log_sigma), epoch)


def _evaluate_loader(loader, split_tag, epoch):
    """Return the per-sample average of ``rec + kl`` and of ``kl`` over *loader*.

    Also saves the first batch's inputs next to their reconstructions as
    ``vae_logs/<log_dir>_id_<id>/<split_tag>_zdim_..._reconstruction_<epoch>.png``.
    Must be called under ``torch.no_grad()`` with the model in eval mode.
    """
    total_loss, total_kl = 0., 0.
    for i, (data, _) in enumerate(tqdm(loader)):
        data = data.to(device)
        if args.add_noise:
            data = torch.randn_like(data) + data
        recon_batch, mu, logvar = model(data)
        # Pass the second value from posthoc VAE
        rec, kl = model.loss_function(recon_batch, data, mu, logvar)
        # Accumulate Python floats so no GPU tensors are kept alive across batches.
        total_loss += (rec + kl).item()
        total_kl += kl.item()
        if i == 0:
            n = min(data.size(0), 8)
            if args.model == 'discrete_logistic':
                _, data_likelihood = model.likelihood(recon_batch)
                recon_batch = data_likelihood.mean
            # Reshape with the actual batch size: it can be smaller than args.batch_size.
            comparison = torch.cat([data[:n], recon_batch.view(data.size(0), -1, 28, 28)[:n]])
            save_image(comparison.cpu(), 'vae_logs/{}_id_{}/{}_zdim_{}_beta{}_{}_{}_reconstruction_{}.png'.format(
                args.log_dir, args.id, split_tag, args.z_dim, args.beta, args.model, args.add_noise, str(epoch)),
                       nrow=n)
    n_samples = len(loader.dataset)
    return total_loss / n_samples, total_kl / n_samples


def test(epoch):
    """Evaluate on the ID and OOD test sets; print and log per-sample loss and KL.

    Args:
        epoch: epoch index used in image filenames and as the TensorBoard step.
    """
    model.eval()
    with torch.no_grad():
        test_loss, test_kl = _evaluate_loader(test_loader, 'id', epoch)
        ood_test_loss, ood_test_kl = _evaluate_loader(ood_test_loader, 'ood', epoch)
    print('====> z_dim: {} ID_{} Test set loss: {:.4f} kl {}; OOD Test set loss: {:.4f} kl{} | beta:{}'.format(
        args.z_dim, args.id, test_loss, test_kl, ood_test_loss, ood_test_kl, args.beta))
    summary_writer.add_scalar('test/elbo', test_loss, epoch)
    summary_writer.add_scalar('test/kld', test_kl, epoch)
    summary_writer.add_scalar('ood_test/elbo', ood_test_loss, epoch)
    summary_writer.add_scalar('ood_test/kld', ood_test_kl, epoch)

# Per-dataset baselines for the SVD "image complexity" correction used with --use_entropy:
# (mean absolute rank-k SVD reconstruction error to reach, reference rank n_id).
_SVD_ENTROPY_BASELINES = {
    'CIFAR': (0.00100864, 20),
    'FMNIST': (6.13280957e-05, 18),
}
# analyze_z scores at most this many batches from each loader.
_ANALYZE_MAX_BATCHES = 1001


def _svd_rank_needed(image, tol, max_rank=50):
    """Return the smallest rank k in [1, max_rank] whose rank-k SVD reconstruction of the 2-D
    array *image* has mean absolute error <= *tol*; returns *max_rank* if no smaller rank does."""
    u, s, vt = np.linalg.svd(image, full_matrices=False)
    for k in range(1, max_rank + 1):
        if k == max_rank:
            return k
        recon = (u[:, :k] * s[:k]) @ vt[:k, :]
        if np.mean(np.abs(recon - image)) <= tol:
            return k
    return max_rank  # only reached when max_rank < 1


def _complexity_weight(steps, n_id):
    """Fraction of the loss removed by the SVD correction: rises as steps/n_id up to 1 at
    steps == n_id, then falls as (2*n_id - steps)/n_id (negative beyond 2*n_id)."""
    if steps <= n_id:
        return steps / n_id
    return (2 * n_id - steps) / n_id


def _score_loader(model, loader, args, q_id_z=None, svd_baseline=None,
                  max_batches=_ANALYZE_MAX_BATCHES):
    """Score up to *max_batches* batches of *loader* with the VAE.

    Returns:
        (mus, logpx, kls, svd_steps, loss_sum, n_samples): per-batch numpy arrays of the posterior
        means, the negated loss and the KL term; the SVD rank found per batch (empty when
        *svd_baseline* is None); the summed loss and the number of samples seen.
    """
    mus, logpx, kls, svd_steps = [], [], [], []
    loss_sum, n_samples = 0.0, 0
    for batch_idx, (data, _) in enumerate(loader):
        data = data.to(device)
        if args.add_noise:
            data = torch.randn_like(data) + data
        recon_batch, mu, logvar = model(data)
        rec, kl = model.loss_function(recon_batch, data, mu, logvar)
        if q_id_z is not None:
            # Replace the KL term with the KL against the fitted q_id(z) model.
            kl = model.kl_with_qz(mu, logvar, q_id_z)
        loss = rec + kl
        kls.append(kl.cpu().numpy())

        if svd_baseline is not None:
            tol, n_id = svd_baseline
            # Only the first channel of the first image in the batch is measured.
            steps = _svd_rank_needed(data[0, 0].cpu().numpy(), tol)
            svd_steps.append(steps)
            loss = loss - loss * _complexity_weight(steps, n_id)

        loss_sum += loss.sum().item()
        n_samples += data.size(0)
        mus.append(mu.cpu().numpy())
        logpx.append(-loss.cpu().numpy())
        if batch_idx + 1 >= max_batches:
            break
    return mus, logpx, kls, svd_steps, loss_sum, n_samples


def _mean_svd_decay(loader, max_rank=30):
    """Average, over the first image (first channel) of every batch in *loader*, the mean absolute
    error of its rank-k SVD reconstruction for k = 1..max_rank; entry k-1 holds rank k."""
    decay = np.zeros(max_rank)
    n_images = 0
    for data, _ in loader:
        image = data[0, 0].cpu().numpy()
        u, s, vt = np.linalg.svd(image, full_matrices=False)
        for k in range(1, max_rank + 1):
            recon = (u[:, :k] * s[:k]) @ vt[:k, :]
            decay[k - 1] += np.mean(np.abs(recon - image))
        n_images += 1
    return decay / max(n_images, 1)


def analyze_z(model, train_loader, test_loader, ood_test_loader, args):
    """Score train / ID-test / OOD-test data with the trained VAE for OOD analysis.

    For up to ``_ANALYZE_MAX_BATCHES`` batches of each loader, collects the posterior means ``mu``,
    the negated loss ``-(rec + kl)`` from ``model.loss_function`` and the KL term.

    * ``args.kl_q_q``: the KL term is replaced by ``model.kl_with_qz`` against a MixtureDensityRNN
      q_id(z) loaded from ``./qz_fitting_results/<id>/...``.
    * ``args.use_entropy``: the ID and OOD test losses are rescaled by an SVD-based image
      complexity weight, and the mean SVD reconstruction-error curve of the whole training set
      is printed.

    Returns:
        (train_mus, id_mus, ood_mus, train_logpx, id_logpx, ood_logpx, train_kl, id_kl, ood_kl);
        the ``*_mus`` arrays have shape (N, z), the others (N, 1).

    Raises:
        ValueError: if ``args.use_entropy`` is set for a dataset without an SVD baseline.
    """
    model.eval()
    q_id_z = None
    if args.kl_q_q:
        q_id_z = qz_rnn(time_range=args.z_dim - 1).to(device)
        q_id_z.load_state_dict(torch.load(
            f'./qz_fitting_results/{args.id}/Zdim_{args.z_dim}_beta_{args.beta}_{args.model}_{args.add_noise}_klqq_True'
            f'_train_muz_lstm.pth', map_location=device))

    svd_baseline = None
    if args.use_entropy:
        if args.id not in _SVD_ENTROPY_BASELINES:
            raise ValueError(f'--use_entropy has no SVD baseline for --id {args.id!r}; '
                             f'known: {sorted(_SVD_ENTROPY_BASELINES)}')
        svd_baseline = _SVD_ENTROPY_BASELINES[args.id]

    with torch.no_grad():
        # No complexity correction on the training split.
        train_mus, train_logpx, train_kl, _, train_loss_sum, train_n = _score_loader(
            model, train_loader, args, q_id_z)
        if train_n:
            print(f'avg loss {train_loss_sum / train_n}')

        id_mus, id_logpx, id_kl, id_steps, _, _ = _score_loader(
            model, test_loader, args, q_id_z, svd_baseline)
        print(f'id n step : {sum(id_steps) / max(len(id_mus), 1)}')

        ood_mus, ood_logpx, ood_kl, ood_steps, _, _ = _score_loader(
            model, ood_test_loader, args, q_id_z, svd_baseline)
        print(f'ood n step : {sum(ood_steps) / max(len(ood_mus), 1)}')

    z_dim = ood_mus[-1].shape[-1]
    train_mus = np.stack(train_mus).reshape(-1, z_dim)
    id_mus = np.stack(id_mus).reshape(-1, z_dim)
    ood_mus = np.stack(ood_mus).reshape(-1, z_dim)

    train_logpx = np.stack(train_logpx).reshape(-1, 1)
    id_logpx = np.stack(id_logpx).reshape(-1, 1)
    ood_logpx = np.stack(ood_logpx).reshape(-1, 1)

    train_kl = np.stack(train_kl).reshape(-1, 1)
    id_kl = np.stack(id_kl).reshape(-1, 1)
    ood_kl = np.stack(ood_kl).reshape(-1, 1)

    if args.use_entropy:
        train_decay = _mean_svd_decay(train_loader)
        print(f'train decay: {train_decay}')

    return train_mus, id_mus, ood_mus, train_logpx, id_logpx, ood_logpx, train_kl, id_kl, ood_kl


# Name of the OOD test set paired with each in-distribution dataset.
_OOD_NAMES = {'FMNIST': 'MNIST', 'MNIST': 'FashionMNIST', 'CIFAR': 'SVHN'}
# Maximum number of latent codes per group drawn for the latent-space plots.
_TSNE_POINTS = 1000


def _run_dir():
    """Directory holding this run's checkpoints, images and analysis arrays.

    With the default --log_dir './results' this is 'vae_logs/./results_id_<id>', i.e. the
    same directory as 'vae_logs/results_id_<id>'.
    """
    return f'vae_logs/{args.log_dir}_id_{args.id}'


def _plot_id_vs_ood(id_values, ood_values, title):
    """Show overlaid 100-bin histograms of an ID and an OOD score."""
    plt.hist(id_values, bins=100, facecolor="deepskyblue", alpha=0.5, label=f'{args.id} test (ID)')
    plt.hist(ood_values, bins=100, facecolor="orangered", alpha=0.5, label='(OOD)')
    plt.legend()
    plt.title(title)
    plt.show()
    plt.clf()


def _plot_latents(id_mus, ood_mus):
    """Scatter ID/OOD posterior means: a t-SNE together with N(0, I) prior samples when
    z_dim > 2, otherwise the raw 2-D codes."""
    # random.sample raises if asked for more items than exist, so cap at the available count.
    id_sub = np.array(random.sample(list(id_mus), min(_TSNE_POINTS, len(id_mus))))
    ood_sub = np.array(random.sample(list(ood_mus), min(_TSNE_POINTS, len(ood_mus))))
    prior = np.random.normal(size=(_TSNE_POINTS, args.z_dim))
    if args.z_dim > 2:
        groups = (id_sub, ood_sub, prior)
        z_embedded = TSNE(n_components=2).fit_transform(np.concatenate(groups, 0))
        styles = (
            ('blue', 0.7, f'{args.id} test (ID)'),
            ('firebrick', 0.7, f'{_OOD_NAMES.get(args.id, "OOD")} test (OOD)'),
            ('cornflowerblue', 0.4, r'Prior $\bf{z}\sim\mathcal{N}(\bf{0}, \bf{I})$'),
        )
        start = 0
        for group, (color, alpha, label) in zip(groups, styles):
            stop = start + len(group)  # exclusive end: every point of the group is drawn
            plt.scatter(z_embedded[start:stop, 0], z_embedded[start:stop, 1], c=color, s=10,
                        alpha=alpha, label=label)
            start = stop
        plt.xticks([])
        plt.yticks([])
        plt.legend()
        plt.tight_layout()
        plt.savefig(f'./tsne_{args.id}_prior.pdf')
        plt.savefig(f'./tsne_{args.id}_prior.png')
        plt.show()
        plt.clf()
    else:
        plt.scatter(id_sub[:, 0], id_sub[:, 1], s=10, alpha=0.5, label=f'{args.id} test (ID)')
        plt.title('location of ID data latent codes in prior distribution space')
        plt.show()
        plt.clf()

        plt.scatter(id_sub[:, 0], id_sub[:, 1], s=10, alpha=0.5, label=f'{args.id} test (ID)')
        plt.scatter(ood_sub[:, 0], ood_sub[:, 1], s=10, alpha=0.5, label='(OOD)')
        plt.title('location of ID & OOD latent codes in prior distribution space')
        plt.legend()
        plt.show()


def _run_analysis():
    """Load the epoch-``args.epochs`` checkpoint, score data with analyze_z, save and plot results."""
    print('Start analyzing...')
    run_dir = _run_dir()
    checkpoint = (f'{run_dir}/Zdim_{args.z_dim}_beta_{args.beta}_{args.model}_{args.add_noise}'
                  f'_checkpoint_{args.epochs}.pt')
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    (train_mus, id_mus, ood_mus, train_logpx, id_logpx, ood_logpx,
     train_kl, id_kl, ood_kl) = analyze_z(model, train_loader, test_loader, ood_test_loader, args)

    # save the results
    prefix = (f'{run_dir}/Zdim_{args.z_dim}_beta_{args.beta}_{args.model}_ent_{args.use_entropy}'
              f'_klqq_{args.kl_q_q}_{args.qq_sigma}')
    for suffix, array in (('train_muz', train_mus), ('id_muz', id_mus), ('ood_muz', ood_mus),
                          ('train_logpx', train_logpx), ('id_logpx', id_logpx),
                          ('ood_logpx', ood_logpx)):
        np.save(f'{prefix}_{suffix}.npy', array)

    print(f'id logpx mean: {np.mean(id_logpx)}')
    print(f'ood logpx mean: {np.mean(ood_logpx)}')
    _plot_id_vs_ood(id_logpx, ood_logpx, 'log p(x) estimated by ELBO')
    _plot_id_vs_ood(id_kl, ood_kl, 'KL')
    _plot_latents(id_mus, ood_mus)


def _run_training():
    """Train for ``args.epochs`` epochs, saving samples each epoch and checkpoints every 10
    epochs and at the final epoch."""
    run_dir = _run_dir()
    tag = f'Zdim_{args.z_dim}_beta_{args.beta}_{args.model}_{args.add_noise}'
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test(epoch)
        with torch.no_grad():
            sample = model.sample(64).cpu()
            save_image(sample.view(64, -1, 28, 28), f'{run_dir}/{tag}_sample_{epoch}.png')
        summary_writer.flush()
        if epoch % 10 == 0 or epoch == args.epochs:
            torch.save(model.state_dict(), f'{run_dir}/{tag}_checkpoint_{epoch}.pt')


if __name__ == "__main__":
    if args.analyze_mode:
        _run_analysis()
    else:
        _run_training()

