"""
Gaussian BNN trained with variational lower bound for CIFAR10
"""
import argparse
import os
import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

from lib.bnn_models import (
    CifarBNN,
    GaussianConv2d,
    GaussianLinear,
    BNNVarEstimator,
)
from lib.constants import Q_SAMPLES_BNN

# Floating-point tensors created without an explicit dtype default to float64.
torch.set_default_dtype(torch.float64)
_reparam_types = ["rt", "r2g2"]

# parsing arguments
# ArgumentDefaultsHelpFormatter appends "(default: ...)" to every help string,
# which makes the default-on behaviour of --log_metrics visible in --help.
parser = argparse.ArgumentParser(
    "CIFAR10 BNN demo",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--seed", type=int, default=1, help="seed for the numpy and torch RNGs"
)
parser.add_argument(
    "--lr", type=float, default=0.0001, help="Adam learning rate"
)
parser.add_argument(
    "--num_epochs", type=int, default=50, help="number of training epochs"
)
parser.add_argument(
    "--batch_size_train",
    type=int,
    default=100,
    help="minibatch size for training and gradient-variance estimation",
)
parser.add_argument(
    "--batch_size_test",
    type=int,
    default=10000,
    help="minibatch size for the accuracy / NLL evaluation loaders",
)
parser.add_argument(
    "--print_every",
    type=int,
    default=10,
    help="evaluate and report metrics every this many epochs",
)
parser.add_argument(
    "--reparam",
    type=str,
    default="r2g2",
    choices=_reparam_types,
    help="reparameterisation passed to CifarBNN",
)
# NOTE: this is a store_false flag, so metric logging is ON unless the flag is
# given. The semantics are kept as-is for backward compatibility with existing
# launch commands.
parser.add_argument(
    "--log_metrics",
    action="store_false",
    help="metric files are written by default; pass this flag to DISABLE them",
)

args = parser.parse_args()

# read args
seed = args.seed
lr = args.lr
num_epochs = args.num_epochs
batch_size_train = args.batch_size_train
batch_size_test = args.batch_size_test
print_every = args.print_every
reparam = args.reparam
log_metrics = args.log_metrics

# data parameters
# Per-channel (mean, std) pairs handed to transforms.Normalize; shared by the
# training and test pipelines.
_norm_mean = (0.4914, 0.4822, 0.4465)
_norm_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop after 4-pixel padding and a random
# horizontal flip, followed by tensor conversion and normalisation.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]
)
# Test pipeline: no augmentation, only tensor conversion and normalisation.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]
)

# Both splits live under the same root and are downloaded on demand.
_common_data_kwargs = dict(root="data/", download=True)
data_train_kwargs = dict(
    _common_data_kwargs, train=True, transform=transform_train
)
data_test_kwargs = dict(
    _common_data_kwargs, train=False, transform=transform_test
)

# load pytorch dataset
data_train = datasets.CIFAR10(**data_train_kwargs)
data_test = datasets.CIFAR10(**data_test_kwargs)

# setting seed
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Reproducibility: request deterministic cuDNN kernels AND switch off the
# cuDNN autotuner. With benchmark enabled cuDNN may select different
# convolution algorithms from one run to the next, which undermines the
# fixed seed set above.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# network architecture
model = CifarBNN(reparam=reparam)

# When a GPU is available the model is moved onto it and every loader gets
# worker processes plus pinned host memory; otherwise loaders use defaults.
cuda = torch.cuda.is_available()
_loader_extras = {}
if cuda:
    model = model.cuda()
    cuda_kwargs = {"num_workers": 8, "pin_memory": True}
    _loader_extras.update(cuda_kwargs)

# setup for data loaders
# - train:    shuffled training minibatches
# - grad_var: unshuffled training minibatches (gradient-variance probe)
# - logs:     unshuffled evaluation minibatches
train_kwargs = {
    "batch_size": batch_size_train,
    "shuffle": True,
    **_loader_extras,
}
grad_var_kwargs = {
    "batch_size": batch_size_train,
    "shuffle": False,
    **_loader_extras,
}
logs_kwargs = {
    "batch_size": batch_size_test,
    "shuffle": False,
    **_loader_extras,
}

train_loader = DataLoader(data_train, **train_kwargs)
grad_var_loader = DataLoader(data_train, **grad_var_kwargs)
logs_test_loader = DataLoader(data_test, **logs_kwargs)
logs_train_loader = DataLoader(data_train, **logs_kwargs)

# optimiser settings
# The model output is turned into log-probabilities over the last dimension
# and scored with NLL: summed over the batch, or kept per example.
log_softmax = torch.nn.LogSoftmax(dim=-1)
categorical_loss = torch.nn.NLLLoss(reduction="sum")
categorical_loss_test = torch.nn.NLLLoss(reduction="none")
model_parameters = list(model.parameters())
opt = torch.optim.Adam(model_parameters, lr=lr)


def train(epoch):
    """Run one epoch of variational training over ``train_loader``.

    For every minibatch the objective is

        (dataset_size / batch_size) * summed NLL  +  KL(q || p)

    where ``q`` is the Gaussian built from each ``GaussianConv2d`` /
    ``GaussianLinear`` module's ``get_params()`` (column 0 used as mean,
    column 1 as standard deviation) and ``p`` is a standard normal prior.
    The KL term is accumulated in the same dtype and on the same device as
    the data loss, so it is not silently truncated to float32.

    Args:
        epoch: Epoch number; only used to label the progress bar.
    """
    model.train()
    dataset_size = len(train_loader.dataset)

    for X_train, Y_train in tqdm(
        train_loader, desc=f"epoch {epoch} minibatches"
    ):
        if cuda:
            X_train, Y_train = X_train.cuda(), Y_train.cuda()
        X_train = X_train.type(torch.double)
        Y_train = Y_train.flatten().type(torch.long)

        opt.zero_grad()
        logp_pred = log_softmax(model(X_train))
        # Summed minibatch NLL rescaled to the size of the full dataset.
        loss_train = categorical_loss(logp_pred, Y_train)
        loss_train = loss_train * (dataset_size / X_train.shape[0])

        # calculate kl components
        device = loss_train.device
        kl_term = torch.zeros((), dtype=loss_train.dtype, device=device)
        for module in model.net.modules():
            if not isinstance(module, (GaussianConv2d, GaussianLinear)):
                continue
            params_all = module.get_params()
            q_mean = params_all[:, 0].to(device)
            q_stddev = params_all[:, 1].to(device)
            # Standard normal prior, created directly on the right device.
            prior_mean = torch.zeros(
                params_all.shape[0], dtype=q_mean.dtype, device=device
            )
            prior_stddev = torch.ones(
                params_all.shape[0], dtype=q_stddev.dtype, device=device
            )

            q = torch.distributions.Normal(q_mean, q_stddev)
            p = torch.distributions.Normal(prior_mean, prior_stddev)
            kl_term = kl_term + torch.distributions.kl.kl_divergence(q, p).sum()

        vi_loss = loss_train + kl_term

        vi_loss.backward()
        opt.step()


def print_accuracy(epoch, data_loader, acc_out=None):
    """Print, and optionally log, classification accuracy on ``data_loader``.

    The loader is swept ``Q_SAMPLES_BNN`` times. In every sweep the argmax of
    the model's log-probabilities is compared with the labels, and the hits
    of all sweeps are pooled, so the reported percentage is
    ``hits / (Q_SAMPLES_BNN * len(dataset))``.

    Args:
        epoch: Epoch number; labels the progress bar and the log line.
        data_loader: Loader yielding ``(inputs, labels)`` batches.
        acc_out: Optional writable text file; when given, a line
            ``"<epoch> <accuracy>"`` is appended and flushed.
    """
    model.eval()

    hits = 0.0
    attempts = Q_SAMPLES_BNN * len(data_loader.dataset)

    sweeps = tqdm(range(Q_SAMPLES_BNN), desc=f"epoch {epoch} minibatches")
    for _ in sweeps:
        for inputs, labels in data_loader:
            if cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()
            inputs = inputs.type(torch.double)
            labels = labels.flatten().type(torch.long)

            with torch.no_grad():
                # One predicted class per example in the batch.
                predicted = log_softmax(model(inputs)).argmax(dim=-1)
                hits += torch.sum(predicted == labels).item()

    fraction_correct = hits / attempts
    print(f"acc: {100 * fraction_correct:2.2f}")

    if acc_out is None:
        return
    acc_out.write('{} {:2.2f}\n'.format(epoch, 100 * fraction_correct))
    acc_out.flush()


def print_gradient_variance(epoch, train_grad_var_out=None, num_draws=100):
    """Estimate and report gradient variance on one fixed training minibatch.

    The first batch of ``grad_var_loader`` is fetched once. The variational
    objective (dataset-rescaled summed NLL plus KL to a standard normal
    prior) is then evaluated and back-propagated ``num_draws`` times on that
    same batch, and after each backward pass two ``BNNVarEstimator`` objects
    (markers ``"top"`` and ``"bottom"``) are updated from the model. Their
    ``get_var()`` results are printed and optionally logged. No optimiser
    step is taken, so the parameters are left unchanged.

    The KL term is accumulated in the same dtype and on the same device as
    the data loss, so it is not silently truncated to float32.

    Args:
        epoch: Epoch number written as the first column of the log line.
        train_grad_var_out: Optional writable text file; when given, a line
            ``"<epoch> <var_top> <var_bottom> "`` is appended and flushed.
        num_draws: Number of repeated forward/backward passes used for the
            estimate (defaults to the previously hard-coded 100).
    """
    model.train()
    var_estimator_top = BNNVarEstimator(model, marker="top")
    var_estimator_bottom = BNNVarEstimator(model, marker="bottom")

    X, Y = next(iter(grad_var_loader))
    if cuda:
        X, Y = X.cuda(), Y.cuda()
    X, Y = X.type(torch.double), Y.flatten().type(torch.long)

    # Loop-invariant factor that rescales the minibatch NLL to the dataset.
    nll_scale = len(train_loader.dataset) / X.shape[0]

    for _ in range(num_draws):
        opt.zero_grad()
        logp_pred = log_softmax(model(X))
        loss_train = categorical_loss(logp_pred, Y) * nll_scale

        # calculate kl components
        device = loss_train.device
        kl_term = torch.zeros((), dtype=loss_train.dtype, device=device)
        for module in model.net.modules():
            if not isinstance(module, (GaussianConv2d, GaussianLinear)):
                continue
            params_all = module.get_params()
            q_mean = params_all[:, 0].to(device)
            q_stddev = params_all[:, 1].to(device)
            prior_mean = torch.zeros(
                params_all.shape[0], dtype=q_mean.dtype, device=device
            )
            prior_stddev = torch.ones(
                params_all.shape[0], dtype=q_stddev.dtype, device=device
            )

            q = torch.distributions.Normal(q_mean, q_stddev)
            p = torch.distributions.Normal(prior_mean, prior_stddev)
            kl_term = kl_term + torch.distributions.kl.kl_divergence(q, p).sum()

        vi_loss = loss_train + kl_term
        vi_loss.backward()
        var_estimator_top.update(model)
        var_estimator_bottom.update(model)
        # Drop references to this pass's graph before building the next one.
        del vi_loss, loss_train, logp_pred, kl_term

    var_top = var_estimator_top.get_var()
    var_bottom = var_estimator_bottom.get_var()
    print('Variance top {}'.format(var_top))
    print('Variance bottom {}'.format(var_bottom))
    if train_grad_var_out is not None:
        train_grad_var_out.write('{} {} {} \n'.format(epoch, var_top, var_bottom))
        train_grad_var_out.flush()


def print_nll(epoch, data_loader, nll_out=None):
    """Print, and optionally log, the Monte Carlo predictive NLL.

    For every example ``n`` the predictive probability is estimated by
    averaging over ``Q_SAMPLES_BNN`` forward passes, and the reported value
    is the sum over the dataset:

        NLL = sum_n [ log(S) - logsumexp_s( -nll_{s,n} ) ],   S = Q_SAMPLES_BNN

    The ``log(S)`` normalisation is applied once per example, so it has to be
    paired with a per-example logsumexp. Pairing ``N * log(S)`` with a
    logsumexp over dataset-summed losses would inflate the result by
    ``(N - 1) * log(S)`` whenever the passes agree. All arithmetic is done in
    the dtype of the model output rather than float32.

    Assumes each call to ``model(X)`` yields one stochastic prediction of
    shape ``(batch, classes)`` (TODO confirm against ``CifarBNN``).

    Args:
        epoch: Epoch number; labels the progress bar and the log line.
        data_loader: Loader yielding ``(inputs, labels)`` batches.
        nll_out: Optional writable text file; when given, a line
            ``"<epoch> <nll> "`` is appended and flushed.
    """
    model.eval()

    log_num_samples = float(np.log(Q_SAMPLES_BNN))
    batch_totals = []

    for X, Y in tqdm(data_loader, desc=f"epoch {epoch} minibatches"):
        if cuda:
            X, Y = X.cuda(), Y.cuda()
        X, Y = X.type(torch.double), Y.flatten().type(torch.long)

        with torch.no_grad():
            # (Q_SAMPLES_BNN, batch): per-example NLL of every forward pass.
            draws = torch.stack(
                [
                    categorical_loss_test(log_softmax(model(X)), Y)
                    for _ in range(Q_SAMPLES_BNN)
                ],
                dim=0,
            )
            # log of the sample-averaged predictive probability per example.
            log_pred = torch.logsumexp(-draws, dim=0) - log_num_samples
        batch_totals.append(-log_pred.sum())

    if batch_totals:
        nll_loss_total = torch.stack(batch_totals).sum()
    else:
        # Empty loader: nothing to score.
        nll_loss_total = torch.zeros(())
    print(f"nll: {nll_loss_total}")

    if nll_out is not None:
        nll_out.write('{} {} \n'.format(epoch, nll_loss_total.item()))
        nll_out.flush()


def run():
    """Train for ``num_epochs`` epochs and periodically report metrics.

    When ``log_metrics`` is true, five text files are created under
    ``bnn_results/cifar10/<reparam>/`` (test/train accuracy, training
    gradient variance, test/train NLL), each prefixed with ``seed<seed>_``.
    Every ``print_every`` epochs the NLL, accuracy and gradient variance are
    computed and written to those files (or only printed when logging is
    disabled). The files are closed on exit, including when training raises.
    The total wall-clock time is printed at the end.
    """
    t0 = time.time()

    metric_names = (
        "test_acc",
        "train_acc",
        "train_grad_var",
        "test_nll",
        "train_nll",
    )
    # None means "print only" for the corresponding metric.
    outputs = dict.fromkeys(metric_names)

    try:
        if log_metrics:
            save_dir = f"bnn_results/cifar10/{reparam}/"
            # exist_ok avoids the check-then-create race of os.path.exists.
            os.makedirs(save_dir, exist_ok=True)
            for metric in metric_names:
                outputs[metric] = open(
                    save_dir + 'seed{}_{}.txt'.format(seed, metric), 'w'
                )

        for epoch in range(1, num_epochs + 1):
            train(epoch=epoch)

            # printing step (print_every == 0 disables reporting instead of
            # raising ZeroDivisionError)
            if print_every and epoch % print_every == 0:
                print_nll(
                    epoch=epoch,
                    data_loader=logs_test_loader,
                    nll_out=outputs["test_nll"],
                )
                print_nll(
                    epoch=epoch,
                    data_loader=logs_train_loader,
                    nll_out=outputs["train_nll"],
                )
                print_accuracy(
                    epoch=epoch,
                    data_loader=logs_test_loader,
                    acc_out=outputs["test_acc"],
                )
                print_accuracy(
                    epoch=epoch,
                    data_loader=logs_train_loader,
                    acc_out=outputs["train_acc"],
                )
                print_gradient_variance(
                    epoch=epoch,
                    train_grad_var_out=outputs["train_grad_var"],
                )
    finally:
        # Release every log file that was actually opened.
        for handle in outputs.values():
            if handle is not None:
                handle.close()

    elapsed_time = time.time() - t0
    print(f"elapsed_time: {elapsed_time}")


# Script entry point: the training loop starts only when the file is executed
# directly. Argument parsing, dataset loading and model construction above
# are module-level statements and therefore still run on import.
if __name__ == '__main__':
    run()
