import os
import sys
from datetime import datetime

import torch
import torch.nn as nn
import numpy as np

# from bnn import BNN
from bnn_vimc import BNN, BNNF, BNNMC
import gpytorch
from gpytorch.likelihoods import DirichletClassificationLikelihood
from gp_prior import ExactGPModel, prior_sample_functions, DirichletGPModel

import argparse

from utils.logging import get_logger

from data import uci_woval
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets
from sklearn.metrics import roc_auc_score
from sklearn import metrics
import pandas as pd

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt


################################### Hyper-parameters ###################################

# Learning rate of the Adam optimisers built for the variational BNN; it is
# also the value handed to noise_loss() in the MCMC training stage.
lr_bnn = 1e-2
# Only echoed in the hyper-parameter printout in the visible part of this file.
prior_coeff = 10
# Tag used to name the log, checkpoint and results sub-directories.
bnn_name_string = 'VIMC'
# Seed for the torch and numpy RNGs; also embedded in the checkpoint file name.
random_seed = 123
# Only echoed in the printout; the visible training loops iterate range(11).
max_epoch_num = 100
# Epoch intervals at which the stage-1 loop saves a checkpoint / writes a log row.
save_model_freq = 5000
log_model_freq = 3000
# NOTE(review): the next three values are not referenced in the visible part
# of this file — confirm whether they are still needed.
test_interval = 2000
num_sample = 1000
epochs = 601    # 102
# Number of optimiser steps used to fit each GP prior.
n_step_prior_pretraining = 500
# NOTE(review): lr_optbnn only appears in a commented-out optimiser line and
# f_coeff is unused in the visible code (train() uses a literal 10).
lr_optbnn = 0.01
f_coeff = 10

# Seed both RNG sources before any data sampling or model construction.
torch.manual_seed(random_seed)
np.random.seed(random_seed)


################################### Network Architecture ###################################

# Width and depth of the BNN; hidden_dims is the per-layer width list passed
# to the BNN / BNNMC constructors.
n_units = 50   # 800
n_hidden = 2
hidden_dims = [n_units] * n_hidden
activation_fn = 'tanh'


print("============================================================================================")
################################## set device ##################################

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")

print("============================================================================================")
###################### logging ######################

#### log files for multiple runs are NOT overwritten: the run number is the
#### count of files already present in the log directory.

log_dir = "./logs" + '/' + bnn_name_string + '/'
# exist_ok=True replaces the check-then-create pattern (no race between the
# existence test and the creation) and also creates "./logs" when missing.
os.makedirs(log_dir, exist_ok=True)

#### get number of log files in log directory
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)

#### create new log file for each run
# log_dir already ends with '/', so no additional separator is inserted.
log_f_name = log_dir + bnn_name_string + "_" + str(run_num) + ".csv"

print("current logging run number for " + bnn_name_string + " : ", run_num)
print("logging at : " + log_f_name)


print("============================================================================================")
################### checkpointing ###################

run_num_pretrained = 0  #### change this to prevent overwriting weights in same env_name folder

directory = "./pretrained" + '/' + bnn_name_string + '/'
os.makedirs(directory, exist_ok=True)

checkpoint_path = directory + "{}_{}_{}.pth".format(bnn_name_string, random_seed, run_num_pretrained)
print("save checkpoint path : " + checkpoint_path)

################### savefigures ###################

results_folder = "./results"
figures_folder = results_folder + '/' + bnn_name_string + '/'
# Creates "./results" as well as the per-model sub-directory.
os.makedirs(figures_folder, exist_ok=True)


print("============================================================================================")
############# print all hyperparameters #############

print("learning rate: ", lr_bnn)
print("coefficient of prior regularization: ", prior_coeff)
print("random seed: ", random_seed)
print("max number of epoches: ", max_epoch_num)


print("============================================================================================")
############################## load and normalize data ##############################

# Transform for the single-channel datasets: tensor conversion followed by a
# fixed rescaling.
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * 255. / 126.),  # divide as in paper
        ])

# Transform for the three-channel datasets: light augmentation, then
# per-channel normalisation with mean 0.5 and std 0.5.
transform2 = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

train_data = datasets.MNIST(
            root='data',
            train=True,
            download=False,
            transform=transform)

test_data = datasets.MNIST(
            root='data',
            train=False,
            download=False,
            transform=transform)


# NOTE(review): this is the only dataset rooted at 'data2' — confirm that the
# FashionMNIST training files really live in a different directory.
train_data_fm = datasets.FashionMNIST(
            root='data2',
            train=True,
            download=False,
            transform=transform)

test_data_fm = datasets.FashionMNIST(
            root='data',
            train=False,
            download=False,
            transform=transform)

train_data_cf = datasets.CIFAR10(
            root='data',
            train=True,
            download=False,
            transform=transform2)

test_data_cf = datasets.CIFAR10(
            root='data',
            train=False,
            download=False,
            transform=transform2)
# NOTE(review): the only dataset created with download=True; it is not used in
# the visible part of this file.
train_data_sv = datasets.SVHN(
            root='data',
            split='train',
            download=True,
            transform=transform2)

# Sizes and shapes are taken from the FashionMNIST splits, which feed the
# training / validation / test loaders below.
training_num = len(train_data_fm)
test_num = len(test_data_fm)
input_dim = 28 * 28
# input_dim = 32 * 32 * 3
output_dim = 10
is_continuous = True
batch_size = 125
eval_batch_size = 1000

# obtain training indices that will be used for validation
valid_size = 1 / 6
indices = list(range(training_num))
split = int(valid_size * training_num)
train_idx, valid_idx = indices[split:], indices[:split]
# The first 10000 non-validation indices form the "in-distribution" subset;
# training itself uses a separate 10000-index slice.
oob_in_idx = train_idx[:10000]
train_idx = train_idx[40000:50000]

# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
oob_in_sampler = SubsetRandomSampler(oob_in_idx)

# miniset of training data and test data
mini_train_size = 1 / 12
mini_train_index = int(mini_train_size * training_num)
mini_train_idx = indices[:mini_train_index]
mini_train_sampler = SubsetRandomSampler(mini_train_idx)

indices_test = list(range(test_num))
mini_test_size = 1 / 2
mini_test_index = int(mini_test_size * test_num)
# Slice the *test* index list. The previous code sliced the training list,
# which only worked because the leading indices of both lists coincide.
mini_test_idx = indices_test[:mini_test_index]
mini_test_sampler = SubsetRandomSampler(mini_test_idx)

######### define data loader ############

train_loader = torch.utils.data.DataLoader(
        train_data_fm,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=1)
valid_loader = torch.utils.data.DataLoader(
    train_data_fm,
    batch_size=batch_size,
    sampler=valid_sampler,
    num_workers=1)
test_loader = torch.utils.data.DataLoader(
        test_data_fm,
        batch_size=eval_batch_size,
        num_workers=1)
# Loader with a batch size of 10000, used to pull FashionMNIST test images
# for the measurement set and the GP pretraining.
pre_loader = torch.utils.data.DataLoader(
        test_data_fm,
        batch_size=10000,
        num_workers=1)

# "Out-of-distribution" loader: MNIST images drawn at the validation indices.
oob_out_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    sampler=valid_sampler,
    num_workers=1)
oob_in_loader = torch.utils.data.DataLoader(
    train_data_fm,
    batch_size=batch_size,
    sampler=oob_in_sampler,
    num_workers=1)

mini_train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    sampler=mini_train_sampler,
    num_workers=1)

mini_test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=eval_batch_size,
    sampler=mini_test_sampler,
    num_workers=1)

print("training_num = ", training_num, " input_dim = ", input_dim, " output_dim = ", output_dim)



def sample_measurement_set2(X, num_data, n=60):
    """Draw a random measurement set of up to ``n`` rows from ``X``.

    Args:
        X: Two-dimensional tensor whose rows are candidate points.
        num_data: Number of leading rows of ``X`` that are eligible; a random
            permutation of ``range(int(num_data))`` is drawn.
        n: Number of rows to return. Defaults to 60, the size that used to be
            hard-coded here.

    Returns:
        ``X[idx, :]`` where ``idx`` holds the first ``n`` entries of the
        permutation. When ``num_data < n`` only ``num_data`` rows come back.
    """
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n)]
    return X[idx, :]

# Iterate the pre-training loader to completion; train_x / train_y hold the
# last batch it yields (moved to the selected device).
for batch_id, (data, target) in enumerate(pre_loader):
    train_x, train_y = data.to(device), target.to(device)

# Fixed measurement set of 60 randomly chosen inputs, flattened to 784
# features per row; it is reused later for the MCMC predictions and the
# linearisation step.
original_x_train_add_g3 = sample_measurement_set2(train_x, num_data=10000)
original_x_train_add_g3 = torch.reshape(original_x_train_add_g3, (60, 784))


print("============================================================================================")
################################## define model ##################################

# Variational BNN trained in stage 1 (see train()).
# bnn = BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
opt_bnn = BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)

# Adam over all parameters of opt_bnn, using lr_bnn (not lr_optbnn).
# bnn_optimizer = torch.optim.Adam([{'params': bnn.parameters(),  'lr': lr_bnn}])
# opt_bnn_optimizer = torch.optim.RMSprop([{'params': opt_bnn.parameters(), 'lr': lr_optbnn}])
opt_bnn_optimizer = torch.optim.Adam([{'params': opt_bnn.parameters(), 'lr': lr_bnn}])

print("============================================================================================")
################################## pretrain GP ##################################

# Fit the first GP prior (`prior` / `likelihood`), which train() samples from
# in stage 1.

# Iterate the loader to completion; train_x / train_y hold its last batch.
for batch_id, (data, target) in enumerate(pre_loader):
    train_x, train_y = data.to(device), target.to(device)

# Flatten every image to a 784-vector, then keep a random subset of 1000
# points as the GP training set.
train_x = torch.reshape(train_x, (10000, 784))
# train_x = torch.reshape(train_x, (10000, 3072))
mask = torch.randperm(10000)[:1000]
train_x = train_x[mask, :]
train_y = train_y[mask]
# The Dirichlet likelihood turns the class labels into regression targets
# (`transformed_targets`), which the GP model is then fitted to.
likelihood = DirichletClassificationLikelihood(train_y, learn_additional_noise=True)
prior = DirichletGPModel(train_x, likelihood.transformed_targets, likelihood, num_classes=likelihood.num_classes).to(device)
print('likelihood.num_classes: ', likelihood.num_classes)
prior.train()
likelihood.train()

# Use the adam optimizer
# NOTE(review): whether prior.parameters() also covers the likelihood's
# parameters depends on how DirichletGPModel registers it — confirm.
optimizer = torch.optim.Adam(prior.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior)

for i in range(n_step_prior_pretraining):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = prior(train_x)
    # Calc loss and backprop gradients: negative marginal log likelihood,
    # summed over whatever dimensions mll returns.
    loss_gp = -mll(output, likelihood.transformed_targets).sum()
    loss_gp.backward()
    # Progress line; the loss shown is the one computed before this step's
    # parameter update. Lengthscale and noise are reported as means.
    print('Iter %d/%d - Loss: %.3f lengthscale: %.3f  noise: %.3f' % (
        i + 1, n_step_prior_pretraining, loss_gp.item(),
        prior.covar_module.base_kernel.lengthscale.mean().item(),
        prior.likelihood.second_noise_covar.noise.mean().item()
    ))

    optimizer.step()

# Switch to evaluation mode for use as a fixed prior.
prior.eval()
likelihood.eval()

print("============================================================================================")

################################## pretrain GP2 ##################################


for batch_id, (data, target) in enumerate(pre_loader):
    train_x, train_y = data.to(device), target.to(device)

# print('train_x.shape: ', train_x.shape)
# print('train_y.shape: ', train_y.shape)
# train_x = train_x.squeeze()

train_x = torch.reshape(train_x, (10000, 784))
# train_x = torch.reshape(train_x, (10000, 3072))
mask = torch.randperm(10000)[:1000]
train_x = train_x[mask, :]
train_y = train_y[mask]
likelihood2 = DirichletClassificationLikelihood(train_y, learn_additional_noise=True)
prior2 = DirichletGPModel(train_x, likelihood2.transformed_targets, likelihood2, num_classes=likelihood2.num_classes).to(device)
print('likelihood.num_classes: ', likelihood2.num_classes)
prior2.train()
likelihood2.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(prior2.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood2, prior2)

for i in range(n_step_prior_pretraining):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = prior2(train_x)
    # pred_y = likelihood(prior(train_x)).mean
    # print('pred_y.shape: ', pred_y.shape)
    # Calc loss and backprop gradients
    # print('train_x.shape: ', train_x.shape)
    # print('train_y.shape: ', train_y.shape)
    # print('likelihood.transformed_targets.shape: ', likelihood.transformed_targets.shape)
    loss_gp = -mll(output, likelihood2.transformed_targets).sum()
    loss_gp.backward()
    # print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
    #     i + 1, n_step_prior_pretraining, loss_gp.item(),
    #     prior.covar_module.base_kernel.lengthscale.item(),
    #     prior.likelihood.noise.item()
    # ))
    print('Iter %d/%d - Loss: %.3f lengthscale: %.3f  noise: %.3f' % (
        i + 1, n_step_prior_pretraining, loss_gp.item(),
        prior2.covar_module.base_kernel.lengthscale.mean().item(),
        prior2.likelihood.second_noise_covar.noise.mean().item()
    ))

    optimizer.step()

prior2.eval()
likelihood2.eval()

print("============================================================================================")


############################################
###### linear_nn flow ###############
hidden_dimsflow = [100]
class LinearFlow_nn(nn.Module):
    """Affine map ``x = z * weight + net(x_in)`` with a learnable scalar scale.

    ``net`` is any module mapping ``x_in`` to an input-dependent shift; the
    scale ``weight`` is a single learnable parameter initialised to 1.
    """

    def __init__(self, net):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias = net

    def forward(self, z, x_in):
        """Return ``(z * weight + shift, shift, sum(log(|weight| + 1e-7)))``."""
        shift = self.bias(x_in).squeeze()
        scale = self.weight
        mapped = z * scale + shift
        log_abs_scale = torch.log(scale.abs() + 1e-7)
        return mapped, shift, torch.sum(log_abs_scale)

    def inverse(self, x, x_in):
        """Undo :meth:`forward`; return ``(z, sum(log(1/|weight| + 1e-7)))``."""
        shift = self.bias(x_in).squeeze()
        scale = self.weight
        z = x * 1 / scale - shift / scale
        inv_abs_scale = 1 / scale.abs()
        return z, torch.sum(torch.log(inv_abs_scale + 1e-7))

##
class NonLinearFlow_nn(nn.Module):
    """Tanh-based map with two learnable scalar scales and two shift networks.

    ``inverse`` computes ``z = weight2 * tanh(x * weight + net1(x_in)) + net2(x_in)``
    and ``forward`` applies the matching atanh expression, clamping its
    argument to ``[-0.99, 0.99]``.
    """

    def __init__(self, net1, net2):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias1 = net1
        self.weight2 = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias2 = net2

    def forward(self, z, x_in):
        """Return ``(x, net1 shift, sum(log(|weight| + 1e-7)))``.

        NOTE(review): the returned log-det only involves ``weight`` — confirm
        this is the intended quantity for the non-linear map.
        """
        inner_shift = self.bias1(x_in).squeeze()
        outer_shift = self.bias2(x_in).squeeze()
        # Keep the atanh argument strictly inside (-1, 1).
        squashed = torch.clamp((z - outer_shift) / self.weight2, -0.99, 0.99)
        x = torch.atanh(squashed) / self.weight - inner_shift / self.weight
        log_det = torch.sum(torch.log(self.weight.abs() + 1e-7))
        return x, inner_shift, log_det

    def inverse(self, x, x_in):
        """Return ``(z, log|dz/dx|)`` with the log term summed over the last dim."""
        inner_shift = self.bias1(x_in).squeeze()
        outer_shift = self.bias2(x_in).squeeze()
        activated = torch.tanh(x * self.weight + inner_shift)
        z = self.weight2 * activated + outer_shift
        # Element-wise derivative of the tanh map: scales times (1 - tanh^2).
        det = self.weight2 * self.weight * (1 - activated ** 2)
        log_det_inverse = torch.sum(torch.log(det.abs() + 1e-7), -1)
        return z, log_det_inverse

########## sigmoid_flow #############
class SigmoidFlow_nn(nn.Module):
    """Sigmoid-based map with two learnable scalar scales and one shift network.

    ``inverse`` computes ``z = weight2 * sigmoid(x * weight) + net(x_in)`` and
    ``forward`` applies the matching logit expression, clamping its argument
    to ``[0.01, 0.99]``.
    """

    def __init__(self, net):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.weight2 = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias = net

    def forward(self, z, x_in):
        """Return ``(x, shift, 1.)``; the third element is a constant float."""
        shift = self.bias(x_in).squeeze()
        # Keep the probability away from 0 and 1 so the odds stay finite.
        prob = torch.clamp((z - shift) / self.weight2, min=0.01, max=0.99)
        odds = prob / (1 - prob)
        x = torch.log(odds + 1e-7) / self.weight
        return x, shift, 1.

    def inverse(self, x, x_in):
        """Return ``(z, log|dz/dx|)`` with the log term summed over the last dim."""
        shift = self.bias(x_in).squeeze()
        gate = torch.nn.functional.sigmoid(x * self.weight)
        z = self.weight2 * gate + shift
        # Element-wise derivative of the scaled sigmoid.
        det = self.weight2 * self.weight * gate * (1 - gate)
        log_det_inverse = torch.sum(torch.log(det.abs() + 1e-7), -1)
        return z, log_det_inverse

########## exp_flow #############
class ExpFlow_nn(nn.Module):
    """Exp/tanh-based map with one learnable scalar scale and one shift network.

    ``inverse`` computes ``z = exp(tanh(x)) * weight + net(x_in)`` and
    ``forward`` applies the matching log/atanh expression with clamping.
    """

    def __init__(self, net):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias = net

    def forward(self, z, x_in):
        """Return ``(x, shift, sum(log(|weight| + 1e-7)))``."""
        shift = self.bias(x_in).squeeze()
        # Clamp from below so the logarithm receives a positive argument.
        ratio = torch.clamp((z - shift) / self.weight, min=0.001)
        log_ratio = torch.log(ratio + 1e-7)
        # Keep the atanh argument strictly inside (-1, 1).
        x = torch.atanh(torch.clamp(log_ratio, -0.99, 0.99))
        log_det = torch.sum(torch.log(self.weight.abs() + 1e-7))
        return x, shift, log_det

    def inverse(self, x, x_in):
        """Return ``(z, log|dz/dx|)`` with the log term summed over the last dim."""
        shift = self.bias(x_in).squeeze()
        squashed = torch.tanh(x)
        grown = torch.exp(squashed)
        z = grown * self.weight + shift
        # Chain rule: d/dx exp(tanh(x)) = exp(tanh(x)) * (1 - tanh(x)^2).
        det = self.weight * grown * (1 - squashed ** 2)
        log_det_inverse = torch.sum(torch.log(det.abs() + 1e-7), -1)
        return z, log_det_inverse



######### def train and evalution #############################
def sample_measurement_set(X, num_data, n=40):
    """Draw a random measurement set of up to ``n`` rows from ``X``.

    Args:
        X: Two-dimensional tensor whose rows are candidate points.
        num_data: Number of leading rows of ``X`` that are eligible; a random
            permutation of ``range(int(num_data))`` is drawn.
        n: Number of rows to return. Defaults to 40, the size that used to be
            hard-coded here.

    Returns:
        ``X[idx, :]`` where ``idx`` holds the first ``n`` entries of the
        permutation. When ``num_data < n`` only ``num_data`` rows come back.
    """
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n)]
    return X[idx, :]

def lipf(X, Y):
    """Return the element-wise absolute difference ``|X - Y|``."""
    difference = X - Y
    return difference.abs()

def train(loader):
    """Run one epoch of stage-1 training of ``opt_bnn`` against the GP ``prior``.

    For every mini-batch the loss is the summed negative log-likelihood of the
    batch, plus 10 times the mean absolute difference between GP samples and
    BNN function samples on a random measurement set, plus the mean absolute
    difference between the GP variance and the BNN sample variance on that
    set. Uses the module-level ``opt_bnn``, ``opt_bnn_optimizer``, ``prior``
    and ``device``.

    Args:
        loader: Iterable yielding ``(data, target)`` batches. Batches may have
            any size; each sample is flattened to a vector.

    Returns:
        Tuple ``(loss_sum, likelihood_loss, functional_wdist, dif_var)``:
        the total loss averaged over the batches of ``loader`` (detached from
        the autograd graph), followed by the likelihood, sample-distance and
        variance terms of the *last* batch.

    NOTE(review): evaluate() leaves the model in eval mode and this function
    does not switch it back — confirm whether that matters for ``BNN``.
    """
    loss_sum = 0

    for batch_id, (data, target) in enumerate(loader):

        data, target = data.to(device), target.to(device)
        # Flatten using the actual batch size so a short final batch works too.
        n_batch = data.shape[0]
        data = torch.reshape(data, (n_batch, -1))

        ### functional distance between GP prior samples and BNN samples ###
        measurement_set = sample_measurement_set(X=data, num_data=n_batch)
        gpss = prior(measurement_set).sample(torch.Size((128,)))
        # Swap the last axis with axis 1 to line the GP samples up with the
        # squeezed BNN samples below.
        gpss = torch.transpose(gpss, -1, 1)
        nnet_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)
        nnet_samples = nnet_samples.squeeze()

        functional_wdist = torch.mean(lipf(gpss, nnet_samples))

        ##################### variance constraint #########

        # The GP variance is a fixed target, so no gradient is tracked for it.
        with torch.no_grad():
            gp_var = prior(measurement_set).variance.t()

        opt_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)
        opt_samples = opt_samples.squeeze()
        opt_var = torch.std(opt_samples, 0) ** 2
        dif_var = torch.mean((gp_var - opt_var).abs())

        ### calculate bnn likelihood loss ###
        # log_softmax is the numerically stable form of log(softmax(.)): it
        # cannot produce -inf when a class probability underflows.
        log_prob = F.log_softmax(opt_bnn.forward(data), dim=1)
        likelihood_loss = F.nll_loss(log_prob, target, reduction='sum')
        train_loss = likelihood_loss + 10 * functional_wdist + 1 * dif_var

        # Detach so the running average does not keep every batch's autograd
        # graph reachable.
        loss_sum += train_loss.detach() / len(loader)

        # optimisation
        opt_bnn_optimizer.zero_grad()
        train_loss.backward()
        opt_bnn_optimizer.step()

    return loss_sum, likelihood_loss, functional_wdist, dif_var


def evaluate(model, loader, samples=1):
    """Count correct predictions of ``model`` over ``loader``.

    The model is put in eval mode (and left there). For each batch the
    softmax outputs of ``samples`` forward passes are summed and the arg-max
    class is compared with the target.

    Args:
        model: Module whose ``forward`` maps flattened inputs to class logits.
        loader: Iterable yielding ``(data, target)`` batches of any size.
        samples: Number of forward passes whose softmax outputs are summed.

    Returns:
        The number of correct predictions averaged over the batches of
        ``loader`` (a count per batch, not a fraction).
    """
    acc_sum = 0
    model.eval()
    # Evaluation needs no gradients; this avoids building autograd graphs.
    with torch.no_grad():
        for idx, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)
            # Flatten using the actual batch size so any batch size works.
            data = torch.reshape(data, (data.shape[0], -1))
            output = F.softmax(model.forward(data), dim=1)
            for _ in range(samples - 1):
                # dim is given explicitly; the implicit form is deprecated.
                output = output + F.softmax(model.forward(data), dim=1)
            predict = output.max(1)[1]
            acc_sum += predict.eq(target).cpu().sum().item()
    return acc_sum / len(loader)


################################## start training ##################################

# track total training time
# track total training time
# NOTE(review): the message below says GMT, but datetime.now() without a tz
# argument returns local time.
start_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)

# logging file
# NOTE(review): log_f is not closed in the visible part of this file.
log_f = open(log_f_name, "w+")
log_f.write('epoch,training_loss\n')

################# S1: FVI #################################

# Per-epoch histories of the values returned by train().
train_loss_all = []
dif_var_all = []
wdist_all = []
likelihood_loss_all = []

for epoch in range(11):
    train_loss, likelihood, w1, difv = train(train_loader)
    # evaluate() returns the mean number of correct predictions per batch, so
    # dividing by the evaluation batch size gives the accuracy.
    test_acc = evaluate(opt_bnn, test_loader, samples=10)
    test_err = round((1 - test_acc / eval_batch_size), 4)

    train_loss_all.append(train_loss.item())
    likelihood_loss_all.append(likelihood.item())
    dif_var_all.append(difv.item())
    wdist_all.append(w1.item())

    # print progress together with the elapsed wall-clock time
    print("epoch : {} \t\t training loss \t\t : {} \t\t test_error \t\t : {}".format(epoch, train_loss,
                                                                                         test_err),
          datetime.now().replace(microsecond=0) - start_time)


    # log in logging file
    # With log_model_freq larger than the number of epochs, only epoch 0 is logged.
    if epoch % log_model_freq == 0:
        print("--------------------------------------------------------------------------------------------")
        log_f.write('{},{}\n'.format(epoch, train_loss))
        log_f.flush()
        print("log saved")
        print("--------------------------------------------------------------------------------------------")

    # save model weights
    # With save_model_freq larger than the number of epochs, only epoch 0 is saved.
    if epoch % save_model_freq == 0:
        print("--------------------------------------------------------------------------------------------")
        print("saving model at : " + checkpoint_path)
        opt_bnn.save(checkpoint_path)
        print("model saved")
        print("--------------------------------------------------------------------------------------------")

print("--------------------------------------------------------------------------------------------")
######### S2: FMCMC #####################

# Collect the variational parameters of the stage-1 BNN, grouped by the
# substring found in each parameter name.
w_mu_list = []
b_mu_list = []
w_std_list = []
b_std_list = []

for name, p in opt_bnn.named_parameters():
    if "W_mu" in name:
        # w_prior_mu = p
        w_mu_list.append(p)
    if 'b_mu' in name:
        # b_prior_mu = p
        b_mu_list.append(p)
    if 'W_std' in name:
        # w_prior_std = p
        w_std_list.append(p)
    if 'b_std' in name:
        # b_prior_std = p
        b_std_list.append(p)


# Build the MCMC network from the collected parameter lists.
# NOTE(review): how BNNMC uses W_mu / b_mu / W_std / b_std is defined in
# bnn_vimc and not visible here.
bnnmc = BNNMC(input_dim, output_dim, hidden_dims, activation_fn, is_continuous,
              W_mu=w_mu_list, b_mu=b_mu_list, W_std=w_std_list, b_std=b_std_list, scaled_variance=True).to(device)

# bnnmc = BNNMC(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)

# Adam with lr 0.1; the StepLR scheduler multiplies the rate by 0.99 every 20
# scheduler steps (train_mc() steps it once per batch).
bnnmc_optimizer = torch.optim.Adam(bnnmc.parameters(), lr=0.1)  # yacht-energy: 0.001, whine-protein: 0.01
scheduler = torch.optim.lr_scheduler.StepLR(bnnmc_optimizer, 20, gamma=0.99, last_epoch=-1) #0.999  (20_0.99)

# NOTE(review): criterion is only referenced from a commented-out line in the
# visible part of this file.
criterion = nn.CrossEntropyLoss()

def noise_loss(lr):
    """Return the sum over ``bnnmc`` parameters of ``sum(param * noise)``.

    The noise is drawn fresh for each parameter tensor from a zero-mean
    normal distribution with standard deviation ``sqrt(2 / lr)``.
    """
    noise_std = (2/lr)**0.5
    total = 0.0
    for var in bnnmc.parameters():
        zero_mean = torch.zeros(var.size()).to(device)
        perturbation = torch.normal(zero_mean, std=noise_std).to(device)
        total += torch.sum(var * perturbation)
    return total


def train_mc(loader, epoch):
    """Run one epoch of stage-2 training of ``bnnmc``.

    The per-batch loss is the summed negative log-likelihood plus ``1e-2``
    times a functional-prior term evaluated on a random measurement set, and,
    once ``epoch > 70``, ``1e-8`` times :func:`noise_loss`. Uses the
    module-level ``bnnmc``, ``bnnmc_optimizer``, ``scheduler``, ``prior2``,
    ``likelihood2``, ``lr_bnn`` and ``device``. The scheduler is stepped once
    per batch.

    Args:
        loader: Iterable yielding ``(data, target)`` batches of any size.
        epoch: Zero-based epoch index; only used to switch on the noise term.

    Returns:
        The total loss averaged over the batches of ``loader`` (detached from
        the autograd graph).
    """
    loss_sum = 0

    for batch_id, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        # Flatten using the actual batch size so a short final batch works too.
        n_batch = data.shape[0]
        data = torch.reshape(data, (n_batch, -1))

        # log_softmax is the numerically stable form of log(softmax(.)).
        log_prob = F.log_softmax(bnnmc.forward(data), dim=1)
        t_like = F.nll_loss(log_prob, target, reduction='sum')  # 'sum' 'mean'

        ###################  measurement set for d_logp
        m_set = sample_measurement_set(X=data, num_data=n_batch)
        y_m_set = bnnmc.forward(m_set).squeeze()
        mset = torch.reshape(m_set, (m_set.shape[0], -1))

        # The GP prior is fixed here, so its moments carry no gradient.
        with torch.no_grad():
            fprior_m = likelihood2(prior2(mset))
            prior_mean_m = fprior_m.mean.t()
            prior_var_m = fprior_m.variance.t()

        # The standardised residual is detached and multiplied by the network
        # output, so the gradient of this term w.r.t. the output equals the
        # residual itself.
        d_logp_m = ((y_m_set - prior_mean_m)/prior_var_m).detach()
        d_logp_m = torch.mean(y_m_set * d_logp_m)
        train_loss = t_like + 1e-2 * d_logp_m  # 1e-3   1e-2

        if epoch > 70:    # 30
            # NOTE(review): the noise scale uses lr_bnn, while bnnmc_optimizer
            # is created with its own learning rate — confirm this is intended.
            noise = noise_loss(lr_bnn)
            train_loss = train_loss + 1e-8 * noise   # 1e-8

        # Detach so the running average does not keep every batch's autograd
        # graph reachable.
        loss_sum += train_loss.detach() / len(loader)

        # optimisation
        bnnmc_optimizer.zero_grad()
        train_loss.backward()
        bnnmc_optimizer.step()
        scheduler.step()

    return loss_sum


#######
# Running index of the MCMC checkpoints written so far.
mt = 0

for epoch in range(11):
    train_loss = train_mc(train_loader, epoch)
    # evaluate() returns the mean number of correct predictions per batch, so
    # dividing by the evaluation batch size gives the accuracy.
    test_acc = evaluate(bnnmc, test_loader, samples=1)
    test_err = round((1 - test_acc / eval_batch_size), 4)

    # print progress together with the elapsed wall-clock time
    print("epoch : {} \t\t training loss \t\t : {} \t\t test_error \t\t : {}".format(epoch, train_loss,
                                                                                         test_err),
          datetime.now().replace(microsecond=0) - start_time)

    ##############
    # Every 10th epoch (except epoch 0) store a checkpoint of the sampler.
    if epoch > 0 and (epoch % 10) == 0:  # (100, 3)
        print('save!')
        # Move to CPU for saving, then back to the training device below.
        bnnmc.cpu()
        dd = f'{mt}'
        print(dd)
        save_path = results_folder + '/' + bnn_name_string + '/' + 'fsgld_mnist_t' + '/' + 'mnist_' + dd
        # save_path is a file; make sure its parent directory exists, since
        # nothing else in this file creates the 'fsgld_mnist_t' folder.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        bnnmc.save(save_path)
        mt += 1
        bnnmc.to(device)

#####################
def testp():
    """Return detached, squeezed ``bnnmc`` outputs for ``original_x_test``.

    Puts ``bnnmc`` in eval mode first.

    NOTE(review): ``original_x_test`` is not defined in the visible part of
    this file — confirm it exists before calling this function.
    """
    bnnmc.eval()
    raw_output = bnnmc.forward(original_x_test)
    return raw_output.squeeze().detach()

def test2():
    """Switch ``bnnmc`` to eval mode and return its squeezed, detached output on ``original_x_train_add_g3``."""
    bnnmc.eval()
    mset_outputs = bnnmc.forward(original_x_train_add_g3)
    return mset_outputs.squeeze().detach()


################ load mcmc model ################
# Reload each saved SGLD checkpoint into `bnnmc` and collect its predictions
# on the measurement set (via test2()).
pred_list = []
predm_list = []
num_model = 10 # 10 100 80 70

# NOTE(review): the training loop above (range(11), save when epoch % 10 == 0)
# writes a single checkpoint per run, while 10 are read here — this relies on
# checkpoints mnist_1..mnist_9 already existing on disk; confirm.
for m in range(num_model):
    dd = f'{m}'
    bnnmc.load_state_dict(torch.load(results_folder + '/' + bnn_name_string + '/' + 'fsgld_mnist_t' + '/' + 'mnist_' + dd))  # sgld, sghmc

    predm = test2()
    predm_list.append(predm)

# Stack once, after all checkpoints have been evaluated. Previously the stack
# (and the shape print) were rebuilt on every iteration of the loop.
mset_fmc = torch.stack(predm_list)
print('mset_fmc.shape: ', mset_fmc.shape)   # [num_model, 60, 10]


############### GP Linearization for FVI posterior ###############
# Linearise opt_bnn around its mean weights on the measurement set:
#   prior_mean_m = forward_mu(X)                      (flattened to one row)
#   prior_var_m  = diag(J diag(w_std^2) J^T)          (one entry per output)
# where row (i, j) of J is d forward_mu(x_i)[j] / d(mean parameters).
opt_bnn2 = opt_bnn
# This optimizer is only used to reset .grad between backward passes.
opt_bnn_optimizer2 = torch.optim.Adam([{'params': opt_bnn2.parameters(), 'lr': lr_bnn}])

w_mu, w_std = opt_bnn2.sample_mu_std()
w_mu = w_mu.detach()
w_std = w_std.detach()
print('w_std.shape: ', w_std.shape)  # [10401]
w_dim = w_std.shape[-1]
print('w_dim: ', w_std.shape[-1])
# Per-weight variance as a flat vector. The covariance is diagonal, so the
# dense (w_dim x w_dim) matrix is never materialised.
w_var = (w_std ** 2).reshape(-1)

grad_rows = []

for x_i in original_x_train_add_g3:

    pred_y_i = opt_bnn2.forward_mu(x_i)
    pred_y_i = pred_y_i.squeeze().flatten()

    for j in range(pred_y_i.numel()):
        # backward() ACCUMULATES into .grad, so the gradients must be reset
        # before every output component. Resetting only once per input made
        # row j the sum of the gradients of outputs 0..j.
        opt_bnn_optimizer2.zero_grad()
        pred_y_i[j].backward(retain_graph=True)

        # torch.cat copies, so the row is unaffected by later resets.
        grad_rows.append(
            torch.cat([p.grad.view(-1) for name, p in opt_bnn2.named_parameters() if 'mu' in name]))

# One row per (input, output) pair, in input-major order.
grads_x_mset = torch.stack(grad_rows)
print('grads_x_mset.shape: ', grads_x_mset.shape)   # [600, ...]

prior_mean_m = opt_bnn2.forward_mu(original_x_train_add_g3).squeeze().detach()
print('prior_mean_m: ', prior_mean_m.shape)    # [60, 10]
prior_mean_m = torch.reshape(prior_mean_m, (1, -1))
print('prior_mean_m: ', prior_mean_m.shape)

# diag(J diag(w_var) J^T)[r] == sum_k J[r, k]^2 * w_var[k]; this equals the
# diagonal of the full matrix product without forming either large matrix.
prior_var_m = torch.sum(grads_x_mset ** 2 * w_var, dim=1)
print('prior_var_m: ', prior_var_m.shape)     # [600]
##


######################### train LinearFlow_nn ##############

# Two BNNF networks are built (in this order); only `flownet` is wired into
# the active LinearFlow_nn. `flownet2` is what the NonLinearFlow_nn
# alternative would take as its second network.
flownet = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous,
               scaled_variance=True).to(device)
flownet2 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous,
                scaled_variance=True).to(device)

# Alternatives tried here: CubeFlow_nn(flownet), NonLinearFlow_nn(flownet, flownet2),
# SigmoidFlow_nn(flownet), ExpFlow_nn(flownet).
flow = LinearFlow_nn(flownet).to(device)
flow_optimizer = torch.optim.Adam(flow.parameters(), lr=0.01)

# History containers; only loss_flow_all is filled by the loop below.
loss_flow_all = []
flow_likelihood_list = []
flow_det_list = []
flow_loss_list = []

for epoch in range(5000):
    # Map the stacked MCMC predictions back through the flow.
    Z, logdet = flow.inverse(mset_fmc, original_x_train_add_g3)
    print('Z.SHAPE: ', Z.shape)   # [50, 60, 10]
    contains_nan = torch.isnan(Z).any()
    print(contains_nan)

    Z = torch.reshape(Z, (10, 600))

    # Loss = sum(log|prior_var_m|) + mean(per-row scaled squared deviation - logdet).
    log_var_term = torch.sum(torch.log(prior_var_m.abs()))
    scaled_sq_dev = torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1)
    loss_flow = log_var_term + torch.mean(scaled_sq_dev - logdet)

    flow_optimizer.zero_grad()
    loss_flow.backward(retain_graph=True)
    flow_optimizer.step()
    loss_flow_all.append(loss_flow.item())

    progress_msg = "epoch : {} \t\t training loss \t\t : {}".format(epoch, loss_flow)
    print(progress_msg, datetime.now().replace(microsecond=0) - start_time)

flow.eval()
a = flow.weight

print('flow.weight: ', flow.weight)

##
# Test error of opt_bnn pushed through the freshly trained flow.
acc_sum = 0

# Evaluation only: run without autograd bookkeeping. The results were already
# detached before use, so the computed numbers are unchanged; this merely
# avoids building a graph for every test batch.
with torch.no_grad():
    for idx, (data, target) in enumerate(test_loader):
        data, target = data.to(device), target.to(device)

        nnet_samples = opt_bnn.sample_functions(data, num_sample=10).float().to(device)  # original_x_test
        nnet_samples = nnet_samples.squeeze().detach()

        fvi, nnbias_test, _ = flow(nnet_samples, data)
        fvi = fvi.detach()
        # Sum over the leading (function-sample) dimension before the softmax.
        fvi = torch.sum(fvi, 0)
        print('fvi.shape: ', fvi.shape)

        output = F.softmax(fvi, dim=1)

        predict = output.data.max(1)[1]
        acc = predict.eq(target.data).cpu().sum().item()
        acc_sum += acc

# Mean number of correct predictions per batch, turned into an error rate by
# dividing by eval_batch_size.
test_acc = acc_sum / len(test_loader)
test_err_flow = round((1 - test_acc / eval_batch_size), 4)

print('test_err_flow: ', test_err_flow)


################################# FVI-MCMC LOOP #######################

# State shared across the outer FVI-MCMC iterations.
# flow_list holds the flows applied (in order) on top of the network outputs;
# it starts with the flow trained above and grows by one per outer iteration.
flow_list = [flow]
# a_list collects the `.weight` of each trained flow; it is printed after the loop.
a_list = [a]
# logdet_list = [flogdet]
cum_det_list = []
flow_name_list = []
##
# History containers (not appended to anywhere in the visible part of the script).
all_train_losses = []
all_wdist_losses = []
all_likelihood_losses = []
all_m2_losses = []

###
# Outer loop: each iteration runs S1 (FVI training of opt_bnn), S2 (SGLD
# sampling with bnnmc), then fits a new flow and appends it to flow_list.
for i in range(2):    # 13
    loop_count = 200 * i
    print('loop_count: ', loop_count)
    outloop = i
    ################# S1: FVI #############################
    # Per-iteration history containers, reset on every outer iteration.
    train_loss_all = []
    prior_distance_all = []
    wdist_all = []
    likelihood_loss_all = []
    dif_var_all = []
    # A fresh Adam optimizer for opt_bnn is created on every outer iteration.
    opt_bnn_optimizer = torch.optim.Adam([{'params': opt_bnn.parameters(), 'lr': 0.01}])   #lr_bnn

    # S1 training: each batch minimises
    #   10 * NLL(sum) + 10 * functional distance to the prior + |prior var - model var|
    # and each epoch ends with an evaluation pass over test_loader.
    for epoch in range(101):  # 1000 4001
        loss_sum = 0

        for batch_id, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            # Batches arrive as [125, 1, 28, 28] and are flattened to [125, 784].
            data = torch.reshape(data, (125, 784))

            measurement_set = sample_measurement_set(X=data, num_data=125)   # [40]

            # 128 function draws from the prior and from the (flowed) network
            # at the measurement points.
            gpss = prior(measurement_set).sample(torch.Size((128,)))
            gpss = torch.transpose(gpss, -1, 1)      # [128, 40, 10]
            nnet_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)
            nnet_samples = nnet_samples.squeeze()    # [128, 40, 10]
            for flow in flow_list:
                nnet_samples, _, _ = flow(nnet_samples, measurement_set)

            functional_wdist = lipf(gpss, nnet_samples)
            functional_wdist = torch.mean(functional_wdist)

            ##################### variance constraint #########
            # Prior variance is a fixed target, so it is computed without grad.
            with torch.no_grad():
                gp_var = prior(measurement_set).variance
                gp_var = gp_var.t()                  # [40, 10]

            # A second, independent set of network draws for the variance term.
            opt_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)
            opt_samples = opt_samples.squeeze()
            for flow in flow_list:
                opt_samples, _, _ = flow(opt_samples, measurement_set)

            opt_var = torch.std(opt_samples, 0) ** 2  # [40, 10]
            dif_var = (gp_var - opt_var).abs()
            dif_var = torch.mean(dif_var)

            ### calculate bnn likelihood loss ###
            pred_y_vi2 = opt_bnn.forward(data)
            for flow in flow_list:
                pred_y_vi2, _, _ = flow(pred_y_vi2, data)

            # log_softmax instead of log(softmax(.)): the two are mathematically
            # equal, but the latter yields -inf (and NaN gradients) as soon as
            # a class probability underflows to zero.
            output = F.log_softmax(pred_y_vi2, dim=1)

            # The NLL is computed once and reused in the total loss.
            likelihood_loss = F.nll_loss(output, target, reduction='sum')
            train_loss = 10 * likelihood_loss + 10 * functional_wdist + 1 * dif_var   # 10_10_1

            # Accumulate a detached copy so the running sum used for logging
            # does not hold on to every batch's autograd history.
            loss_sum += train_loss.detach() / len(train_loader)

            # optimisation
            opt_bnn_optimizer.zero_grad()
            train_loss.backward()
            opt_bnn_optimizer.step()

        train_loss_vi = loss_sum

        ######### test acc vi ###########
        acc_sum = 0

        opt_bnn.eval()

        # Evaluation only: no autograd graph is needed (outputs were already
        # detached before use, so results are unchanged).
        with torch.no_grad():
            for idx, (data, target) in enumerate(test_loader):
                data, target = data.to(device), target.to(device)

                nnet_samples = opt_bnn.sample_functions(data, num_sample=10).float().to(device)  # original_x_test
                nnet_samples = nnet_samples.squeeze().detach()   # [10 ,1000, 10]

                for flow in flow_list:
                    nnet_samples, _, _ = flow(nnet_samples, data)

                nnet_samples = nnet_samples.detach()
                # Sum over the function-sample dimension -> [1000, 10].
                nnet_samples = torch.sum(nnet_samples, 0)

                output = F.softmax(nnet_samples, dim=1)

                predict = output.data.max(1)[1]
                acc = predict.eq(target.data).cpu().sum().item()
                acc_sum += acc

        test_acc = acc_sum / len(test_loader)
        test_err_vi = round((1 - test_acc / eval_batch_size), 4)

        opt_bnn.train()

        # print
        print("epoch : {} \t\t training loss \t\t : {} \t\t test_error \t\t : {}".format(epoch, train_loss_vi,
                                                                                         test_err_vi),
              datetime.now().replace(microsecond=0) - start_time)

    print("============================================================================================")

    #####
    ######### S2: FMCMC #####################

    # Split opt_bnn's parameters into four groups by substring of their name.
    # The checks are independent: a name matching several tags lands in each
    # matching list.
    w_mu_list, b_mu_list, w_std_list, b_std_list = [], [], [], []
    param_buckets = (
        ("W_mu", w_mu_list),
        ('b_mu', b_mu_list),
        ('W_std', w_std_list),
        ('b_std', b_std_list),
    )
    for name, p in opt_bnn.named_parameters():
        for tag, bucket in param_buckets:
            if tag in name:
                bucket.append(p)

    # The sampler network is constructed from the collected opt_bnn parameters.
    bnnmc = BNNMC(
        input_dim, output_dim, hidden_dims, activation_fn, is_continuous,
        W_mu=w_mu_list, b_mu=b_mu_list, W_std=w_std_list, b_std=b_std_list,
        scaled_variance=True,
    ).to(device)

    # lr notes from earlier experiments — yacht-energy: 0.001, whine-protein: 0.01
    bnnmc_optimizer = torch.optim.Adam(bnnmc.parameters(), lr=0.1)
    # Learning rate is multiplied by 0.99 every 50 scheduler steps.  # 0.999  (20_0.99)
    scheduler = torch.optim.lr_scheduler.StepLR(
        bnnmc_optimizer, step_size=50, gamma=0.99, last_epoch=-1)

    criterion = nn.CrossEntropyLoss()


    def noise_loss(lr):
        """Return the sum over all ``bnnmc`` parameters of ``sum(param * noise)``.

        For each parameter a zero-mean Gaussian tensor of the same size is
        drawn with standard deviation ``sqrt(2 / lr)``.
        """
        std = (2 / lr) ** 0.5
        total = 0.0
        for var in bnnmc.parameters():
            zero_mean = torch.zeros(var.size()).to(device)
            gaussian = torch.normal(zero_mean, std=std).to(device)
            total += torch.sum(var * gaussian)
        return total

    ######### SGLD #####
    # Checkpoint counter: checkpoints written below are named mnist_<mt>.
    mt = 0
    for epoch in range(101):
        loss_sum = 0

        for batch_id, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            # Batches arrive as [125, 1, 28, 28] and are flattened to [125, 784].
            data = torch.reshape(data, (125, 784))

            pred_y_mc = bnnmc.forward(data)

            for flow in flow_list:
                pred_y_mc, _, _ = flow(pred_y_mc, data)

            # Cross-entropy on the flowed outputs.
            t_like = criterion(pred_y_mc, target)

            ###################  measurement set for d_logp
            m_set = sample_measurement_set(X=data, num_data=125)
            y_m_set = bnnmc.forward(m_set).squeeze()
            for flow in flow_list:
                y_m_set, _, _ = flow(y_m_set, m_set)

            mset = torch.reshape(m_set, (40, 784))
            # mset = torch.reshape(m_set, (40, 3072))

            # Prior mean/variance at the measurement points are fixed targets.
            with torch.no_grad():
                fprior_m = likelihood2(prior2(mset))
                prior_mean_m = fprior_m.mean
                prior_mean_m = prior_mean_m.t()
                prior_var_m = fprior_m.variance
                prior_var_m = prior_var_m.t()

            # (y - mean) / var is detached and used as a fixed coefficient, so
            # the gradient of this term w.r.t. y_m_set is that coefficient
            # (up to the mean's 1/N factor).
            d_logp_m = ((y_m_set - prior_mean_m) / prior_var_m).detach()
            d_logp_m = y_m_set * d_logp_m
            d_logp_m = torch.mean(d_logp_m)
            train_loss = t_like + 1 * d_logp_m  # 1e-2   1

            # The noise term is only added after epoch 70.
            # NOTE(review): noise_loss is called with lr_bnn while the
            # optimizer above was created with lr=0.1 — confirm intended.
            if epoch > 70:  # 30
                noise = noise_loss(lr_bnn)
                train_loss = train_loss + 1e-8 * noise  # 1e-8

            # Accumulate a detached copy so the running sum used for logging
            # does not hold on to every batch's autograd history.
            loss_sum += train_loss.detach() / len(train_loader)
            # optimisation
            bnnmc_optimizer.zero_grad()
            train_loss.backward()
            bnnmc_optimizer.step()
            # The LR scheduler is stepped once per batch.
            scheduler.step()

        train_loss_mc = loss_sum

        ######## test acc mcmc ######
        acc_sum = 0

        bnnmc.eval()

        # Evaluation only: no autograd graph is needed (outputs were already
        # detached before use, so results are unchanged).
        with torch.no_grad():
            for idx, (data, target) in enumerate(test_loader):
                data, target = data.to(device), target.to(device)

                nnet_samples = bnnmc.sample_functions(data, num_sample=1).float().to(device)  # original_x_test
                nnet_samples = nnet_samples.squeeze().detach()  # [1000, 10]

                for flow in flow_list:
                    nnet_samples, _, _ = flow(nnet_samples, data)

                nnet_samples = nnet_samples.detach()

                output = F.softmax(nnet_samples, dim=1)

                predict = output.data.max(1)[1]
                acc = predict.eq(target.data).cpu().sum().item()
                acc_sum += acc

        test_acc = acc_sum / len(test_loader)
        test_err_mc = round((1 - test_acc / eval_batch_size), 4)

        bnnmc.train()

        # print
        print("epoch : {} \t\t training loss \t\t : {} \t\t test_error \t\t : {}".format(epoch, train_loss_mc,
                                                                                         test_err_mc),
              datetime.now().replace(microsecond=0) - start_time)

        ##############
        # From epoch 51 onwards every epoch writes a checkpoint (50 in total
        # for range(101)); the model is moved to CPU for saving and back after.
        if epoch > 50 and (epoch % 1) == 0:  # (100, 3)
            print('save!')
            bnnmc.cpu()
            dd = f'{mt}'
            print(dd)
            save_path = results_folder + '/' + bnn_name_string + '/' + 'fsgld_mnist_t' + '/' + 'mnist_' + dd
            bnnmc.save(save_path)
            mt += 1
            bnnmc.to(device)

    print("============================================================================================")


    ######
    def test2():
        """Return ``bnnmc``'s output on ``original_x_train_add_g3`` after applying every flow in ``flow_list``.

        The network is put in eval mode; the result is detached.
        """
        bnnmc.eval()

        flowed = bnnmc.forward(original_x_train_add_g3).squeeze().detach()
        # Apply the flows in list order, keeping only the first return value.
        for flow in flow_list:
            flowed, _, _ = flow(flowed, original_x_train_add_g3)

        return flowed.detach()  # [60 10]

    ################ load mcmc model ################
    # Reload each SGLD checkpoint written above and collect its flowed
    # predictions on the measurement set.
    predm_list = []
    num_model = 50  # 10 100 80 70

    for m in range(num_model):
        dd = str(m)
        checkpoint_path = '/'.join([results_folder, bnn_name_string, 'fsgld_mnist_t', 'mnist_' + dd])
        bnnmc.load_state_dict(torch.load(checkpoint_path))  # sgld, sghmc

        predm = test2()
        print('predm.shape: ', predm.shape)
        predm_list.append(predm)

    # One slice per checkpoint along a new leading dimension.
    mset_fmc = torch.stack(predm_list)   # [50, 60, 10]

    ############### GP Linearization2 ###############
    # Linearise opt_bnn around its mean weights on the measurement set:
    #   prior_mean_m = forward_mu(X)                      (flattened to one row)
    #   prior_var_m  = diag(J diag(w_std^2) J^T)          (one entry per output)
    # where row (i, j) of J is d forward_mu(x_i)[j] / d(mean parameters).
    opt_bnn2 = opt_bnn
    # This optimizer is only used to reset .grad between backward passes.
    opt_bnn_optimizer2 = torch.optim.Adam([{'params': opt_bnn2.parameters(), 'lr': lr_bnn}])

    w_mu, w_std = opt_bnn2.sample_mu_std()
    w_mu = w_mu.detach()
    w_std = w_std.detach()
    print('w_std.shape: ', w_std.shape)  # [10401]
    w_dim = w_std.shape[-1]
    print('w_dim: ', w_std.shape[-1])
    # Per-weight variance as a flat vector. The covariance is diagonal, so the
    # dense (w_dim x w_dim) matrix is never materialised.
    w_var = (w_std ** 2).reshape(-1)

    grad_rows = []

    for x_i in original_x_train_add_g3:

        pred_y_i = opt_bnn2.forward_mu(x_i)
        pred_y_i = pred_y_i.squeeze().flatten()

        for j in range(pred_y_i.numel()):
            # backward() ACCUMULATES into .grad, so the gradients must be
            # reset before every output component. Resetting only once per
            # input made row j the sum of the gradients of outputs 0..j.
            opt_bnn_optimizer2.zero_grad()
            pred_y_i[j].backward(retain_graph=True)

            # torch.cat copies, so the row is unaffected by later resets.
            grad_rows.append(
                torch.cat([p.grad.view(-1) for name, p in opt_bnn2.named_parameters() if 'mu' in name]))

    # One row per (input, output) pair, in input-major order.
    grads_x_mset = torch.stack(grad_rows)
    print('grads_x_mset.shape: ', grads_x_mset.shape)  # [600, ...]

    prior_mean_m = opt_bnn2.forward_mu(original_x_train_add_g3).squeeze().detach()
    print('prior_mean_m: ', prior_mean_m.shape)  # [60, 10]
    prior_mean_m = torch.reshape(prior_mean_m, (1, -1))
    print('prior_mean_m: ', prior_mean_m.shape)

    # diag(J diag(w_var) J^T)[r] == sum_k J[r, k]^2 * w_var[k]; this equals
    # the diagonal of the full matrix product without forming either large
    # matrix.
    prior_var_m = torch.sum(grads_x_mset ** 2 * w_var, dim=1)
    print('prior_var_m: ', prior_var_m.shape)  # [600]



    ############################ train flow2 ###################
    # Two BNNF networks are built (in this order); only `flownet3` is wired
    # into the active LinearFlow_nn. `flownet4` is what the NonLinearFlow_nn
    # alternative would take as its second network.
    flownet3 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous,
                    scaled_variance=True).to(device)
    flownet4 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous,
                    scaled_variance=True).to(device)

    # Alternatives tried here: NonLinearFlow_nn(flownet3, flownet4),
    # SigmoidFlow_nn(flownet3), ExpFlow_nn(flownet3), CubeFlow_nn(flownet2).
    flow2 = LinearFlow_nn(flownet3).to(device)
    flow_optimizer2 = torch.optim.Adam(flow2.parameters(), lr=0.01)

    ######
    loss_flow_all = []

    # The already-trained flows are inverted newest-first after flow2.
    flow_list_inv = flow_list[::-1]

    for epoch in range(5000):
        # Undo the new flow first ...
        Z, logdet = flow2.inverse(mset_fmc, original_x_train_add_g3)

        # ... then every earlier flow, collecting the mean log-determinant of each.
        logdet_list = []
        for flow in flow_list_inv:
            Z, logdetf = flow.inverse(Z, original_x_train_add_g3)
            logdetf = torch.mean(logdetf)
            logdet_list.append(logdetf)
        cum_det = torch.sum(torch.stack(logdet_list))

        Z = torch.reshape(Z, (50, 600))

        # Loss = sum(log|prior_var_m|) + mean(per-row scaled squared deviation - logdet) - cum_det.
        log_var_term = torch.sum(torch.log(prior_var_m.abs()))
        scaled_sq_dev = torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1)
        loss_flow = log_var_term + torch.mean(scaled_sq_dev - logdet) - cum_det

        flow_optimizer2.zero_grad()
        loss_flow.backward(retain_graph=True)
        flow_optimizer2.step()
        loss_flow_all.append(loss_flow.item())

        progress_msg = "epoch : {} \t\t training loss \t\t : {}".format(epoch, loss_flow)
        print(progress_msg, datetime.now().replace(microsecond=0) - start_time)

    ############################
    # Freeze the newly trained flow and record its weight.
    flow2.eval()

    print('flow2.weight: ', flow2.weight)
    a2 = flow2.weight
    a_list.append(a2)

    # Plot every 50th recorded flow loss. The file name is fixed, so each
    # outer iteration overwrites the previous plot.
    fig = plt.figure()
    loss_flow_all = np.array(loss_flow_all)
    # Derive the x-range from the recorded history instead of repeating the
    # training-loop length as a literal.
    indices = np.arange(len(loss_flow_all))[::50]
    plt.plot(indices, loss_flow_all[indices], '-ko', ms=3)
    plt.ylabel(r'flow loss')
    plt.xlabel('Iteration')
    plt.tight_layout()
    plt.savefig(figures_folder + '/flow_loss2.pdf')
    # Close the figure so pyplot does not keep one open figure per outer
    # iteration for the rest of the run.
    plt.close(fig)

    # From here on the new flow is part of the composition used everywhere.
    flow_list.append(flow2)

    ########### approximate flow test err ################
    # Test error of opt_bnn pushed through the full flow composition
    # (including the flow just appended).
    acc_sum = 0

    # Evaluation only: run without autograd bookkeeping. The results were
    # already detached before use, so the computed numbers are unchanged.
    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)

            nnet_samples = opt_bnn.sample_functions(data, num_sample=10).float().to(device)  # original_x_test
            nnet_samples = nnet_samples.squeeze().detach()

            for flow in flow_list:
                nnet_samples, _, _ = flow(nnet_samples, data)

            nnet_samples = nnet_samples.detach()
            # Sum over the function-sample dimension -> [1000, 10].
            nnet_samples = torch.sum(nnet_samples, 0)

            output = F.softmax(nnet_samples, dim=1)

            predict = output.data.max(1)[1]
            acc = predict.eq(target.data).cpu().sum().item()
            acc_sum += acc

    test_acc = acc_sum / len(test_loader)
    test_err_flow = round((1 - test_acc / eval_batch_size), 4)

    print('test_err_flow: ', test_err_flow)

print('a_list: ', a_list)

######## OOD_test ###############
# Maximum entropy of a distribution over output_dim classes; used to
# normalise the entropy scores to [0, 1].
log_upper_bound = np.log(output_dim)


def _predictive_entropies(loader):
    """Return the per-example predictive entropy for every batch of ``loader``.

    Each batch is reshaped to (125, 784), passed through ``opt_bnn.forward``
    and every flow in ``flow_list`` except the last one, and the entropy of
    the resulting softmax distribution is computed per row. The per-batch
    results are concatenated into a single 1-D tensor.
    """
    entropies = []
    for data, _target in loader:
        data = data.to(device)
        data = torch.reshape(data, (125, 784))
        # data = torch.reshape(data, (125, 3072))
        with torch.no_grad():
            logits = opt_bnn.forward(data)
            for prev_flow in flow_list[:-1]:
                logits, _, _ = prev_flow(logits, data)

            # Entropy from log_softmax: log-probabilities stay finite even
            # when a probability underflows to 0, so 0 * log(0) no longer
            # produces NaN and wipes out the whole row's entropy.
            log_p = F.log_softmax(logits, dim=1)
            e = -torch.sum(log_p.exp() * log_p, dim=1)
            # Still guard against NaNs coming from the logits themselves.
            e = torch.nan_to_num(e, nan=0.0)
        entropies.append(e)
    return torch.cat(entropies, dim=0)


# In-distribution examples are labelled 0, out-of-distribution examples 1.
entropies_in_sample = _predictive_entropies(oob_in_loader)
true_lables_in_sample = torch.zeros_like(entropies_in_sample)

entropies_out_sample = _predictive_entropies(oob_out_loader)
true_lables_out_sample = torch.ones_like(entropies_out_sample)

# Normalised entropy is the OOD score: higher means "more likely OOD".
in_sample_preds = entropies_in_sample/log_upper_bound
out_sample_preds = entropies_out_sample/log_upper_bound
preds = torch.cat([in_sample_preds, out_sample_preds], dim=0).cpu().numpy()
truths = torch.cat([true_lables_in_sample, true_lables_out_sample], dim=0).cpu().numpy()
AUC = roc_auc_score(truths, preds)
print('auc: ', AUC)






# fpr, tpr, thresholds = metrics.roc_curve(truths, preds)
# fpr_ifbnn_cf = pd.DataFrame.from_dict(fpr)
# fpr_ifbnn_cf.to_csv(results_folder + '/' + bnn_name_string + '/' + 'fpr_ifbnn_cf.csv')
# tpr_ifbnn_cf = pd.DataFrame.from_dict(tpr)
# tpr_ifbnn_cf.to_csv(results_folder + '/' + bnn_name_string + '/' + 'tpr_ifbnn_cf.csv')

# roc_auc = metrics.auc(fpr, tpr)
# display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name='ifbnn')
# display.plot()
# plt.savefig(figures_folder + '/oob_ifbnn_cf.pdf')

# Close the log handle. NOTE(review): log_f is presumably the log file opened
# in the logging section near the top of the script — confirm.
log_f.close()

# plt.figure()
# plt.plot(train_loss_all, 'r-')
# plt.title('train loss')
# plt.savefig(figures_folder + '/train_loss.pdf')
#
# plt.figure()
# plt.plot(likelihood_loss_all, 'r-')
# plt.title('likelihood loss')
# plt.savefig(figures_folder + '/likelihood_loss.pdf')
#
# plt.figure()
# plt.plot(prior_distance_all, 'r-')
# plt.title('prior_distance')
# plt.savefig(figures_folder + '/prior_distance.pdf')
#
# plt.figure()
# plt.plot(wdist_all, 'r-')
# plt.title('wdist')
# plt.savefig(figures_folder + '/wdist.pdf')

# print total training time
# Final summary: wall-clock timings and the last S1 (FVI) test error.
banner = "============================================================================================"
print(banner)
end_time = datetime.now().replace(microsecond=0)
summary_rows = (
    ("Started training at (GMT) : ", start_time),
    ("Finished training at (GMT) : ", end_time),
    ("Total training time  : ", end_time - start_time),
    ("Test err  : ", test_err_vi),
)
for label, value in summary_rows:
    print(label, value)
print(banner)


