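"""
Train a Gaussian Bayesian neural network (BNN) on MNIST by maximising the
variational lower bound (ELBO), using a factorised Normal variational
posterior and a standard-normal prior over the weights.
"""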
import argparse
import contextlib
import os
import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

from lib.bnn_models import GaussianLinear, MnistBNN, BNNVarEstimator
from lib.constants import Q_SAMPLES_BNN

# Run everything in double precision: tensors created below without an
# explicit dtype are float64.
torch.set_default_dtype(torch.float64)
# Allowed values for --reparam and --fwd respectively.
_reparam_types = ["rt", "lrt", "r2g2"]
_fwd_types = ["rt", "lrt"]

# parsing arguments
parser = argparse.ArgumentParser("MNIST BNN demo")

parser.add_argument("--seed", type=int, default=1, help="random seed")
parser.add_argument("--lr", type=float, default=0.0001, help="Adam learning rate")
parser.add_argument("--num_epochs", type=int, default=20, help="number of training epochs")
parser.add_argument(
    "--batch_size_train",
    type=int,
    default=100,
    help="minibatch size for training and gradient-variance estimation",
)
parser.add_argument(
    "--batch_size_test",
    type=int,
    default=10000,
    help="batch size used when evaluating NLL and accuracy",
)
parser.add_argument(
    "--print_every",
    type=int,
    default=10,
    help="evaluate NLL and accuracy every this many epochs",
)
parser.add_argument(
    "--reparam",
    type=str,
    default="r2g2",
    choices=_reparam_types,
    help="reparameterisation passed to MnistBNN",
)
parser.add_argument(
    "--fwd",
    type=str,
    default="lrt",
    choices=_fwd_types,
    help="forward mode passed to MnistBNN; only used with --reparam r2g2",
)
# Metric logging is ON by default. The historical switch to turn it OFF is
# spelled --log_metrics (store_false), which reads backwards; it is kept for
# backward compatibility next to the clearer --no_log_metrics spelling.
parser.add_argument(
    "--log_metrics",
    "--no_log_metrics",
    dest="log_metrics",
    action="store_false",
    help="disable writing metric files under bnn_results/ (logging is on by default)",
)

args = parser.parse_args()

# read args
seed = args.seed
lr = args.lr
num_epochs = args.num_epochs
batch_size_train = args.batch_size_train
batch_size_test = args.batch_size_test
print_every = args.print_every
reparam = args.reparam
fwd = args.fwd
log_metrics = args.log_metrics

# --fwd only applies to the r2g2 reparameterisation; blank it otherwise.
if reparam != "r2g2":
    fwd = ""

# data parameters: both MNIST splits are downloaded to data/ when missing and
# images are converted to tensors with ToTensor().
data_train_kwargs = {
    "root": "data/",
    "train": True,
    "download": True,
    "transform": transforms.ToTensor(),
}
data_test_kwargs = {
    "root": "data/",
    "train": False,
    "download": True,
    "transform": transforms.ToTensor(),
}

# load pytorch dataset
data_train = datasets.MNIST(**data_train_kwargs)
data_test = datasets.MNIST(**data_test_kwargs)

# setting seed for numpy, torch (CPU) and CUDA
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Deterministic cuDNN kernels only give reproducible runs if autotuning is
# off: with benchmark=True cuDNN times several algorithms and may select a
# different (possibly non-deterministic) one on each run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# network architecture (reparam/fwd come from the command line)
model = MnistBNN(reparam=reparam, fwd=fwd)
cuda = torch.cuda.is_available()

# Per-loader settings. Only the training loader shuffles; the
# gradient-variance and logging loaders iterate in a fixed order.
train_kwargs = dict(batch_size=batch_size_train, shuffle=True)
grad_var_kwargs = dict(batch_size=batch_size_train, shuffle=False)
logs_kwargs = dict(batch_size=batch_size_test, shuffle=False)

if cuda:
    # GPU run: move the parameters over and give every loader worker
    # processes plus pinned host memory.
    model = model.cuda()
    cuda_kwargs = dict(num_workers=8, pin_memory=True)
    train_kwargs.update(**cuda_kwargs)
    grad_var_kwargs.update(**cuda_kwargs)
    logs_kwargs.update(**cuda_kwargs)

# The training split backs three loaders (shuffled training, fixed-order
# gradient variance, large-batch logging); the test split backs one.
train_loader = DataLoader(data_train, **train_kwargs)
grad_var_loader = DataLoader(data_train, **grad_var_kwargs)
logs_test_loader = DataLoader(data_test, **logs_kwargs)
logs_train_loader = DataLoader(data_train, **logs_kwargs)

# Losses act on log-probabilities: one sums over the batch, the other keeps
# one value per example.
categorical_loss = torch.nn.NLLLoss(reduction="sum")
categorical_loss_test = torch.nn.NLLLoss(reduction="none")
log_softmax = torch.nn.LogSoftmax(dim=-1)
model_parameters = [*model.parameters()]
opt = torch.optim.Adam(params=model_parameters, lr=lr)


def train(epoch):
    """Run one epoch of variational (ELBO) training over ``train_loader``.

    For each minibatch the loss is the summed NLL rescaled to the size of the
    whole training set, plus the KL divergence between every GaussianLinear
    layer's Normal variational distribution and a standard-normal prior.
    One optimiser step is taken per minibatch.

    Args:
        epoch: epoch index; only used for the progress-bar label.
    """
    model.train()
    dataset_size = len(train_loader.dataset)
    for X_train, Y_train in tqdm(train_loader, desc=f"epoch {epoch} minibatches"):
        if cuda:
            X_train, Y_train = X_train.cuda(), Y_train.cuda()
        X_train = X_train.type(torch.double)
        Y_train = Y_train.flatten().type(torch.long)

        opt.zero_grad()
        logp_pred = log_softmax(model(X_train))
        # Rescale the minibatch NLL sum to the full-dataset scale of the ELBO.
        loss_train = categorical_loss(logp_pred, Y_train)
        loss_train = loss_train * (dataset_size / X_train.shape[0])

        # KL(q || N(0, 1)) summed over all GaussianLinear layers. The
        # accumulator takes the loss's dtype/device (float64) so the KL is not
        # silently downcast, as a float32 seed tensor would do.
        kl_term = loss_train.new_zeros(())
        for module in model.net.modules():
            if isinstance(module, GaussianLinear):
                # Column 0 holds the means, column 1 the standard deviations.
                params_all = module.get_params().to(loss_train.device)
                q_mean, q_stddev = params_all[:, 0], params_all[:, 1]
                q = torch.distributions.Normal(q_mean, q_stddev)
                p = torch.distributions.Normal(
                    torch.zeros_like(q_mean), torch.ones_like(q_stddev)
                )
                kl_term = kl_term + torch.distributions.kl.kl_divergence(q, p).sum()

        vi_loss = loss_train + kl_term
        vi_loss.backward()
        opt.step()


def print_accuracy(epoch, data_loader, acc_out=None):
    """Print (and optionally log) the sample-averaged classification accuracy.

    With the model in eval mode, the whole of ``data_loader`` is passed
    through the network ``Q_SAMPLES_BNN`` times. The reported percentage is
    correct predictions over ``Q_SAMPLES_BNN * len(data_loader.dataset)``
    predictions.

    Args:
        epoch: epoch index used in the progress bar and the log line.
        data_loader: loader yielding ``(inputs, labels)`` batches.
        acc_out: optional writable text file; gets an ``"<epoch> <acc>"`` line.
    """
    model.eval()

    n_predictions = Q_SAMPLES_BNN * len(data_loader.dataset)
    hits = 0.0

    for _pass in tqdm(range(Q_SAMPLES_BNN), desc=f"epoch {epoch} minibatches"):
        for inputs, labels in data_loader:
            if cuda:
                inputs, labels = inputs.cuda(), labels.cuda()
            inputs = inputs.type(torch.double)
            labels = labels.flatten().type(torch.long)

            with torch.no_grad():
                predicted = log_softmax(model(inputs)).argmax(dim=-1)
                hits += (predicted == labels).sum().item()

    accuracy = hits / n_predictions
    print(f"acc: {100 * accuracy:2.2f}")

    if acc_out is not None:
        acc_out.write('{} {:2.2f}\n'.format(epoch, 100 * accuracy))
        acc_out.flush()



def print_gradient_variance(epoch, train_grad_var_out=None, num_samples=100):
    """Estimate and print the variance of the ELBO gradient on a fixed minibatch.

    The first minibatch of ``grad_var_loader`` is fetched once. The full ELBO
    loss (rescaled NLL plus KL to an N(0, 1) prior) is then back-propagated
    ``num_samples`` times on it. After each backward pass, two
    ``BNNVarEstimator`` objects (created with ``marker="top"`` and
    ``marker="bottom"``) are updated with the model; presumably they read the
    parameter gradients of different layers -- confirm in ``lib.bnn_models``.
    No optimiser step is taken.

    Args:
        epoch: epoch index written to the log line.
        train_grad_var_out: optional writable text file; gets an
            ``"<epoch> <var_top> <var_bottom>"`` line.
        num_samples: number of gradient evaluations (default 100).
    """
    model.train()
    var_estimator_top = BNNVarEstimator(model, marker="top")
    var_estimator_bottom = BNNVarEstimator(model, marker="bottom")

    X, Y = next(iter(grad_var_loader))
    if cuda:
        X, Y = X.cuda(), Y.cuda()
    X, Y = X.type(torch.double), Y.flatten().type(torch.long)
    # Same full-dataset rescaling of the minibatch NLL as in training.
    nll_scale = len(train_loader.dataset) / X.shape[0]

    for _ in range(num_samples):
        opt.zero_grad()
        logp_pred = log_softmax(model(X))
        loss_train = categorical_loss(logp_pred, Y) * nll_scale

        # KL(q || N(0, 1)) accumulated in the loss's dtype/device (float64),
        # not in a float32 seed tensor that would downcast it.
        kl_term = loss_train.new_zeros(())
        for module in model.net.modules():
            if isinstance(module, GaussianLinear):
                # Column 0 holds the means, column 1 the standard deviations.
                params_all = module.get_params().to(loss_train.device)
                q_mean, q_stddev = params_all[:, 0], params_all[:, 1]
                q = torch.distributions.Normal(q_mean, q_stddev)
                p = torch.distributions.Normal(
                    torch.zeros_like(q_mean), torch.ones_like(q_stddev)
                )
                kl_term = kl_term + torch.distributions.kl.kl_divergence(q, p).sum()

        vi_loss = loss_train + kl_term
        vi_loss.backward()
        var_estimator_top.update(model)
        var_estimator_bottom.update(model)

    var_top = var_estimator_top.get_var()
    var_bottom = var_estimator_bottom.get_var()
    print('Variance top {}'.format(var_top))
    print('Variance bottom {}'.format(var_bottom))
    if train_grad_var_out is not None:
        train_grad_var_out.write('{} {} {} \n'.format(epoch, var_top, var_bottom))
        train_grad_var_out.flush()


def print_nll(epoch, data_loader, nll_out=None):
    """Print (and optionally log) the Monte-Carlo predictive NLL of a dataset.

    For every example n, the predictive likelihood is estimated from
    ``Q = Q_SAMPLES_BNN`` forward passes of the model (in eval mode):
    ``p(y_n | x_n) ~= (1/Q) * sum_q p_q(y_n | x_n)``. The reported value is the
    negative log of that estimate, summed over the dataset::

        nll = sum_n [ log Q - logsumexp_q log p_q(y_n | x_n) ]

    The ``logsumexp`` must be taken per example. Subtracting ``N * log Q``
    while taking ``logsumexp`` over whole-dataset sums, as the earlier version
    did, inflates the result by roughly ``(N - 1) * log Q`` whenever Q > 1.
    The accumulation is done in Python floats (double precision).

    Args:
        epoch: epoch index used in the progress bar and the log line.
        data_loader: loader yielding ``(inputs, labels)`` batches.
        nll_out: optional writable text file; gets an ``"<epoch> <nll>"`` line.
    """
    model.eval()

    log_num_samples = np.log(Q_SAMPLES_BNN)
    nll_total = 0.0

    with torch.no_grad():
        for X, Y in tqdm(data_loader, desc=f"epoch {epoch} minibatches"):
            if cuda:
                X, Y = X.cuda(), Y.cuda()
            X, Y = X.type(torch.double), Y.flatten().type(torch.long)

            # Shape (Q_SAMPLES_BNN, batch): per-example log-likelihoods, one
            # row per forward pass over the same minibatch.
            log_lik = torch.stack(
                [
                    -categorical_loss_test(log_softmax(model(X)), Y)
                    for _ in range(Q_SAMPLES_BNN)
                ]
            )
            # -log((1/Q) * sum_q p_q) per example, summed over the batch.
            batch_nll = log_num_samples - torch.logsumexp(log_lik, dim=0)
            nll_total += batch_nll.sum().item()

    print(f"nll: {nll_total}")

    if nll_out is not None:
        nll_out.write('{} {} \n'.format(epoch, nll_total))
        nll_out.flush()


def run():
    """Train for ``num_epochs`` epochs, reporting metrics along the way.

    After every epoch the gradient variance is printed. Every ``print_every``
    epochs the test/train NLL and accuracy are printed too. When
    ``log_metrics`` is true, the same numbers are written to per-seed text
    files under ``bnn_results/mnist/<reparam>/``, with an extra ``<fwd>/``
    level for r2g2. The files are closed when training finishes or raises.
    """
    t0 = time.time()
    # ExitStack guarantees every opened metric file is closed, even if
    # training raises part-way through.
    with contextlib.ExitStack() as stack:
        if log_metrics:
            save_dir = f"bnn_results/mnist/{reparam}/"
            if reparam == "r2g2":
                save_dir = f"bnn_results/mnist/{reparam}/{fwd}/"
            os.makedirs(save_dir, exist_ok=True)

            def open_metric_file(metric):
                path = save_dir + 'seed{}_{}.txt'.format(seed, metric)
                return stack.enter_context(open(path, 'w'))

            test_acc_out = open_metric_file('test_acc')
            train_acc_out = open_metric_file('train_acc')
            train_grad_var_out = open_metric_file('train_grad_var')
            test_nll_out = open_metric_file('test_nll')
            train_nll_out = open_metric_file('train_nll')
        else:
            test_acc_out = train_acc_out = train_grad_var_out = None
            test_nll_out = train_nll_out = None

        for epoch in range(1, num_epochs + 1):
            train(epoch=epoch)
            print_gradient_variance(epoch=epoch, train_grad_var_out=train_grad_var_out)

            # printing step
            if epoch % print_every == 0:
                print_nll(epoch=epoch, data_loader=logs_test_loader, nll_out=test_nll_out)
                print_nll(epoch=epoch, data_loader=logs_train_loader, nll_out=train_nll_out)
                print_accuracy(epoch=epoch, data_loader=logs_test_loader, acc_out=test_acc_out)
                print_accuracy(epoch=epoch, data_loader=logs_train_loader, acc_out=train_acc_out)

    elapsed_time = time.time() - t0
    print(f"elapsed_time: {elapsed_time}")


if __name__ == "__main__":
    # Start training only when executed as a script.
    run()
