"""
Three-layer MLP VAE trained with variational lower bound for MNIST/Omniglot/F-MNIST
"""
import argparse
import contextlib
import os
import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

from lib.vae_models import ThreeLayerVAE
from lib.constants import CUDA
from lib.utils import load_omniglot, OmniglotDataset


# Use float64 as torch's default floating-point dtype for everything created below.
torch.set_default_dtype(torch.float64)
_datasets = ["mnist", "fmnist", "omniglot"]
_reparam_types = ["rt", "r2g2"]

# ---- command-line arguments -------------------------------------------------
# Pass the text as ``description``: the first positional argument of
# ArgumentParser is ``prog`` and would replace the script name in the usage line.
parser = argparse.ArgumentParser(
    description="MNIST/FMNIST/Omniglot three-layer VAE demo"
)

parser.add_argument(
    "--seed", type=int, default=1, help="random seed for numpy and torch"
)
parser.add_argument(
    "--dataset",
    type=str,
    default="mnist",
    choices=_datasets,
    help="dataset to train on",
)
parser.add_argument("--lr", type=float, default=0.0003, help="Adam learning rate")
parser.add_argument(
    "--steps", type=int, default=100000, help="total number of optimisation steps"
)
parser.add_argument(
    "--batch_size_train", type=int, default=80, help="minibatch size for training"
)
parser.add_argument(
    "--batch_size_test",
    type=int,
    default=100,
    help="batch size used when computing ELBOs",
)
parser.add_argument(
    "--print_every",
    type=int,
    default=10000,
    help="number of optimisation steps between ELBO evaluations",
)
# NOTE(review): dim_latent is parsed but never passed to ThreeLayerVAE in this
# script, so it currently has no effect. It was declared type=str, which made a
# CLI-supplied value a string while the default was an int; it is now an int.
parser.add_argument(
    "--dim_latent",
    type=int,
    default=50,
    help="latent dimensionality (currently not used by the model)",
)
parser.add_argument(
    "--reparam",
    type=str,
    default="r2g2",
    choices=_reparam_types,
    help="reparameterisation variant passed to ThreeLayerVAE",
)
# store_false: logging is ON by default; passing --log_metrics turns it OFF.
parser.add_argument(
    "--log_metrics",
    action="store_false",
    help="pass this flag to DISABLE writing ELBO log files (enabled by default)",
)

args = parser.parse_args()

# ---- unpack arguments into module-level names used by the rest of the script
seed = args.seed
dataset = args.dataset
lr = args.lr
steps = args.steps
batch_size_train = args.batch_size_train
batch_size_test = args.batch_size_test
print_every = args.print_every
dim_latent = args.dim_latent
reparam = args.reparam
log_metrics = args.log_metrics

# ---- datasets ---------------------------------------------------------------
# Build ``data_train`` / ``data_test`` for the selected dataset. MNIST and
# Fashion-MNIST come from torchvision (downloaded into data/ if missing);
# Omniglot goes through the project's own loader.
if dataset == "omniglot":
    X_train, Y_train, X_test, Y_test = load_omniglot()
    data_train, data_test = (
        OmniglotDataset(X_train, Y_train),
        OmniglotDataset(X_test, Y_test),
    )

elif dataset in ("mnist", "fmnist"):
    data_train_kwargs = dict(
        root="data/",
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
    data_test_kwargs = dict(
        root="data/",
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    )

    if dataset == "mnist":
        data_train, data_test = (
            datasets.MNIST(**data_train_kwargs),
            datasets.MNIST(**data_test_kwargs),
        )
    else:
        data_train, data_test = (
            datasets.FashionMNIST(**data_train_kwargs),
            datasets.FashionMNIST(**data_test_kwargs),
        )

else:
    # Should be unreachable from the CLI (--dataset is restricted to
    # _datasets); kept as a safety net.
    raise NotImplementedError

# ---- reproducibility --------------------------------------------------------
# Seed the numpy and torch RNGs, and make cuDNN behave deterministically.
# cudnn.benchmark must be False here: with benchmark=True, cuDNN auto-tunes its
# algorithm choice per run, which can differ between runs and defeats
# cudnn.deterministic = True.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ---- model ------------------------------------------------------------------
model = ThreeLayerVAE(reparam=reparam)

# ---- data loaders -----------------------------------------------------------
# Training batches are shuffled; the gradient-variance and ELBO loaders keep a
# fixed order.
train_kwargs = dict(batch_size=batch_size_train, shuffle=True)
grad_var_kwargs = dict(batch_size=batch_size_train, shuffle=False)
elbo_kwargs = dict(batch_size=batch_size_test, shuffle=False)

if CUDA:
    model = model.cuda()
    # On GPU, every loader also uses 8 worker processes and pinned memory.
    cuda_kwargs = dict(num_workers=8, pin_memory=True)
    train_kwargs, grad_var_kwargs, elbo_kwargs = (
        {**loader_cfg, **cuda_kwargs}
        for loader_cfg in (train_kwargs, grad_var_kwargs, elbo_kwargs)
    )

train_loader = DataLoader(data_train, **train_kwargs)
# NOTE(review): grad_var_loader is built but not used anywhere in this script.
grad_var_loader = DataLoader(data_train, **grad_var_kwargs)
test_elbo_loader = DataLoader(data_test, **elbo_kwargs)
train_elbo_loader = DataLoader(data_train, **elbo_kwargs)

# ---- losses and optimiser ---------------------------------------------------
# Summed BCE is used for training steps; the unreduced (element-wise) BCE is
# what the ELBO evaluation receives.
bce_loss_train = torch.nn.BCELoss(reduction="sum")
bce_loss_test = torch.nn.BCELoss(reduction="none")
model_parameters = [*model.parameters()]
opt = torch.optim.Adam(model_parameters, lr=lr)


def run():
    """Train ``model`` for exactly ``steps`` optimisation steps, evaluating periodically.

    Training runs in chunks of at most ``print_every`` minibatches drawn from
    ``train_loader``; the loader is restarted whenever it is exhausted. After
    each chunk, ``model.print_elbo`` is called first on the test split and then
    on the train split, with the cumulative step count. If ``steps`` is not a
    multiple of ``print_every``, the final chunk is shorter, so no steps are
    silently dropped.

    When ``log_metrics`` is true, the ELBO outputs are written to
    ``vae_results/3/<dataset>/<reparam>/seed<seed>_{test,train}_elbo.txt``.
    These files are closed when training ends, whether normally or by an
    exception. Otherwise ``print_elbo`` receives ``None`` as ``elbo_out``.

    All inputs are module-level globals set up by this script: CLI arguments,
    model, optimiser, data loaders and losses.

    Raises:
        ValueError: if ``print_every`` is not positive.
    """
    if print_every <= 0:
        raise ValueError(f"print_every must be positive, got {print_every}")

    t0 = time.time()
    with contextlib.ExitStack() as stack:
        if log_metrics:
            save_dir = f"vae_results/3/{dataset}/{reparam}/"
            os.makedirs(save_dir, exist_ok=True)  # no exists/create race
            # enter_context guarantees both files are closed (and flushed).
            test_elbo_out = stack.enter_context(
                open(os.path.join(save_dir, f"seed{seed}_test_elbo.txt"), "w")
            )
            train_elbo_out = stack.enter_context(
                open(os.path.join(save_dir, f"seed{seed}_train_elbo.txt"), "w")
            )
        else:
            test_elbo_out = None
            train_elbo_out = None

        train_generator = iter(train_loader)
        step = 0
        while step < steps:
            # Up to print_every optimisation steps before the next evaluation;
            # the last chunk is shortened to hit exactly `steps`.
            n_chunk = min(print_every, steps - step)
            for _ in tqdm(range(n_chunk), desc="optimisation steps"):
                # Next minibatch; restart the loader at the end of an epoch.
                try:
                    X_batch, _labels = next(train_generator)
                except StopIteration:
                    train_generator = iter(train_loader)
                    X_batch, _labels = next(train_generator)
                model.train_one_step(X_train=X_batch, nll_loss=bce_loss_train, opt=opt)
            step += n_chunk

            print("computing elbos and var")
            model.print_elbo(step=step, elbo_loader=test_elbo_loader, nll_loss=bce_loss_test, elbo_out=test_elbo_out)
            model.print_elbo(step=step, elbo_loader=train_elbo_loader, nll_loss=bce_loss_test, elbo_out=train_elbo_out)

    elapsed_time = time.time() - t0
    print(f"elapsed_time: {elapsed_time}")


if __name__ == "__main__":
    # Start training only when executed directly. Note that argument parsing,
    # data loading and model construction above already run at import time.
    run()
