""" pretrain VAE for VAE-FFJORD and VAE-ACNF
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torch.optim as optim
import argparse
import matplotlib
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
matplotlib.style.use('ggplot')
import os
from vae_lib.models.VAE import VAE 


def fit(model, dataloader, criterion, device = torch.device("cpu"), optimizer=None):
    """Train ``model`` for one pass over ``dataloader`` and return the mean loss.

    :param model: module exposing ``input_size``; calling it must return a
        6-tuple whose first three items are used as the reconstruction, the
        latent mean and the latent variance
    :param dataloader: iterable of ``(x, label)`` batches that also exposes a
        ``dataset`` attribute; the labels are ignored
    :param criterion: reconstruction loss, called as ``criterion(reconstruction, x)``
    :param device: device every batch is moved to before the forward pass
    :param optimizer: optimizer used to update the model. When omitted, the
        module-level ``optimizer`` is used, which is how this function
        behaved before the parameter existed
    :return: sum of the per-batch losses divided by ``len(dataloader.dataset)``
    :raises NameError: if no optimizer is passed and the module defines none
    """
    if optimizer is None:
        # Backward compatibility: older callers rely on the script-level
        # global. Resolve it up front so a missing optimizer fails before any
        # batch is processed instead of in the middle of the loop.
        optimizer = globals().get("optimizer")
        if optimizer is None:
            raise NameError(
                "fit() needs an optimizer: pass optimizer=... or define a "
                "module-level 'optimizer'"
            )

    model.train()
    running_loss = 0.0
    for x, _ in tqdm(dataloader, total=len(dataloader)):
        x = x.to(device)
        # reshape the batch to (-1, *model.input_size)
        x = x.view(-1, *model.input_size)
        optimizer.zero_grad()
        reconstruction, mu, var, _, _, _ = model(x)
        bce_loss = criterion(reconstruction, x)
        # the third model output is treated as a variance, while final_loss
        # takes the log-variance
        loss = final_loss(bce_loss, mu, torch.log(var))
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
    # normalise by the number of samples, not the number of batches
    train_loss = running_loss/len(dataloader.dataset)
    return train_loss

def validate(model, dataloader, criterion, device = torch.device("cpu")):
    """Evaluate ``model`` on ``dataloader`` without gradients.

    Prints the per-sample total, reconstruction and KL losses and returns the
    per-sample total loss.

    :param model: module exposing ``input_size``; calling it must return a
        6-tuple whose first three items are used as the reconstruction, the
        latent mean and the latent variance
    :param dataloader: iterable of ``(x, label)`` batches that also exposes a
        ``dataset`` attribute; the labels are ignored
    :param criterion: reconstruction loss, called as ``criterion(reconstruction, x)``
    :param device: device every batch is moved to before the forward pass
    :return: sum of the per-batch losses divided by ``len(dataloader.dataset)``
    """
    model.eval()
    total_sum = 0.0
    bce_sum = 0.0
    kl_sum = 0.0
    with torch.no_grad():
        for batch, _ in tqdm(dataloader, total=len(dataloader)):
            # move to the target device and reshape to (-1, *model.input_size)
            batch = batch.to(device).view(-1, *model.input_size)
            outputs = model(batch)
            reconstruction, mu, var, _, _, _ = outputs
            bce_loss = criterion(reconstruction, batch)
            # the third model output is treated as a variance, while
            # final_loss takes the log-variance
            loss = final_loss(bce_loss, mu, torch.log(var))
            total_sum += loss.item()
            bce_sum += bce_loss.item()
            # the KL part is whatever final_loss added on top of the BCE term
            kl_sum += (loss-bce_loss).item()
            # NOTE: a disabled snippet that saved the last batch's inputs and
            # reconstructions as an image used to live here.

    # normalise by the number of samples, not the number of batches
    sample_count = len(dataloader.dataset)
    val_loss = total_sum/sample_count
    loss_bce = bce_sum/sample_count
    loss_kl = kl_sum/sample_count

    print('total val loss: {}, bce loss: {}, kl loss: {}'.format(val_loss, loss_bce, loss_kl))

    return val_loss

def final_loss(bce_loss, mu, logvar):
    """
    Combine the reconstruction loss with the KL-divergence term.

    KL-Divergence = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
    summed over every element of ``mu`` / ``logvar``.

    :param bce_loss: reconstruction loss
    :param mu: the mean from the latent vector
    :param logvar: log variance from the latent vector
    :return: ``bce_loss`` plus the KL-divergence
    """
    # per-element contribution: 1 + log(sigma^2) - mu^2 - sigma^2
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * kl_terms.sum()
    return bce_loss + kl_divergence

if __name__ == '__main__':
    # Script entry point: pretrain the VAE on MNIST and write a checkpoint
    # after every epoch.

    # construct the argument parser and parse the arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--epochs', default=100, type=int, 
                        help='number of epochs to train the VAE for')
    parser.add_argument('--gpu', default = 1, type = int)
    parser.add_argument('--z_dim', default = 4, type = int)
    args = parser.parse_args()
    

    # learning parameters
    epochs = args.epochs
    features = args.z_dim
    batch_size = 64
    lr = 0.001
    if torch.cuda.is_available():
        # Reject a GPU index that does not exist (the default of 1 is invalid
        # on single-GPU machines) with a clear message instead of failing
        # later with an obscure CUDA error.
        if not 0 <= args.gpu < torch.cuda.device_count():
            parser.error('--gpu {} is not available; found {} CUDA device(s)'.format(
                args.gpu, torch.cuda.device_count()))
        device = torch.device('cuda:' + str(args.gpu))
    else:
        device = torch.device('cpu')


    # prepare MNIST dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # train and validation data
    train_data = datasets.MNIST(
        root='../../data',
        train=True,
        download=True,
        transform=transform
    )
    val_data = datasets.MNIST(
        root='../../data',
        train=False,
        download=True,
        transform=transform
    )

    # training and validation data loaders
    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True
    )
    val_loader = DataLoader(
        val_data,
        batch_size=batch_size,
        shuffle=False
    )

    model = VAE(features, input_size = [1,  28, 28], input_type = "binary", device = device).to(device)
    # NOTE: fit() picks this optimizer up through the module-level name
    # `optimizer`, so the name must not change.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # summed (not averaged) BCE; fit/validate divide by the dataset size
    criterion = nn.BCELoss(reduction='sum')

    # the checkpoint location does not change between epochs
    train_dir = './experiments/acnf_vae'
    ckpt_path = os.path.join(train_dir, 'vae_ckpt_d{}.pth'.format(features))

    train_loss = []
    val_loss = []
    for epoch in range(epochs):
        print(f"Epoch {epoch+1} of {epochs}")
        train_epoch_loss = fit(model, train_loader, criterion, device)
        val_epoch_loss = validate(model, val_loader, criterion, device)
        train_loss.append(train_epoch_loss)
        val_loss.append(val_epoch_loss)
        print(f"Train Loss: {train_epoch_loss:.4f}")
        print(f"Val Loss: {val_epoch_loss:.4f}")

        # exist_ok avoids the check-then-create race of os.path.exists
        os.makedirs(train_dir, exist_ok=True)

        # the same file is overwritten every epoch, so it always holds the
        # most recent weights
        torch.save({
            'func_state_dict': model.state_dict(),
        }, ckpt_path)


        print('Model saved at {}'.format(ckpt_path))