import os
from datetime import datetime

import torch
import torch.nn as nn
import numpy as np
import gpytorch
from gp_prior import ExactGPModel, prior_sample_functions, prior_sample_functions2


# from bnn import BNN, OPTBNN
from bnn_vimc import BNN, BNNF, BNNMC
import torch.nn.functional as F
import argparse

from utils.logging import get_logger

from data import uci_woval

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt


################################### Hyper-parameters ###################################

# Adam learning rate used for the opt_bnn optimizers built later in this script.
lr_bnn = 0.01
# Only printed in the visible part of the script; not used in any loss there.
prior_coeff = 10
# Method tag; becomes a path component of the log, checkpoint and results folders.
bnn_name_string = 'VIMC'
# Dataset name handed to uci_woval(); 'protein' is a previously used alternative.
uci_dataset_name_string = 'kin8nm' #'protein'
# Seeds torch and numpy below, is passed to uci_woval(), and is part of the checkpoint file name.
random_seed = 123
# Only printed in the visible part of the script; the training loops use literal epoch counts.
max_epoch_num = 100
# Periods (in epochs) for checkpointing, CSV logging and test-set evaluation.
save_model_freq = 5000
log_model_freq = 3000
test_interval = 200
# Number of samples requested from forward_eval() when evaluating on the test set.
num_sample = 1000
# Not referenced in the visible part of the script.
epochs = 2001
# Number of Adam steps used to fit the GP prior hyper-parameters.
n_step_prior_pretraining = 100
# Not referenced in the visible part of the script.
lr_optbnn = 0.01
f_coeff = 10

# Seed both RNGs so that sampling and initialisation are repeatable.
torch.manual_seed(random_seed)
np.random.seed(random_seed)


################################### Network Architecture ###################################

# Width and depth of the BNN; hidden_dims is a list with n_hidden entries of n_units.
n_units = 10
n_hidden = 2
hidden_dims = [n_units] * n_hidden
# Activation name passed as a string to the BNN constructors.
activation_fn = 'tanh'

print("============================================================================================")
################################## set device ##################################

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    # Release cached allocator blocks before the run starts.
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")

print("============================================================================================")
###################### logging ######################

#### log files for multiple runs are NOT overwritten

# One call creates "./logs" and every missing sub-folder; exist_ok=True avoids
# the check-then-create race of os.path.exists() followed by os.makedirs().
log_dir = "./logs" + '/' + bnn_name_string + '/' + 'uci'
os.makedirs(log_dir, exist_ok=True)

#### get number of log files in log directory
# next(os.walk(...))[2] is the list of plain files directly inside log_dir.
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)

#### create new log file for each run
# The run index is the number of files already present, so earlier logs are kept.
log_f_name = log_dir + '/' + bnn_name_string + 'uci' + "_" + str(run_num) + ".csv"

print("current logging run number for " + bnn_name_string + " : ", run_num)
print("logging at : " + log_f_name)

print("============================================================================================")
################### checkpointing ###################

run_num_pretrained = 0  #### change this to prevent overwriting weights in same env_name folder

# exist_ok=True creates all missing parents in one call and avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
directory = "./pretrained" + '/' + bnn_name_string + '/' + 'uci' + '/'
os.makedirs(directory, exist_ok=True)

# Checkpoint file name encodes the method tag, the seed and the manual run index.
checkpoint_path = directory + "{}_{}_{}.pth".format(bnn_name_string, random_seed, run_num_pretrained)
print("save checkpoint path : " + checkpoint_path)

################### savefigures ###################

results_folder = "./results"
os.makedirs(results_folder, exist_ok=True)

# Folder for figures produced by this script (keeps its trailing slash).
figures_folder = results_folder + '/' + bnn_name_string + '/' + 'uci' + '/'
os.makedirs(figures_folder, exist_ok=True)

# Folder reserved for loss curves.
loss_folder = results_folder + '/' + bnn_name_string + '/' + 'losses' + '/' + 'uci'     # fvimc_loopt_multiflow
os.makedirs(loss_folder, exist_ok=True)

print("============================================================================================")
############# print all hyperparameters #############

print("learning rate: ", lr_bnn)
print("coefficient of prior regularization: ", prior_coeff)
print("random seed: ", random_seed)
print("max number of epoches: ", max_epoch_num)


print("============================================================================================")
############################## load and normalize data ##############################

# dataset = uci_woval(uci_dataset_name_string, seed=random_seed)
# train_x, test_x, train_y, test_y = dataset.x_train, dataset.x_test, dataset.y_train, dataset.y_test

train_x, test_x, train_y, test_y = uci_woval(uci_dataset_name_string, seed=random_seed)
# print(test_y)

# Problem sizes are read off the arrays returned by the loader.
training_num, input_dim = train_x.shape
test_num = test_x.shape[0]
output_dim = 1
is_continuous = True

print("training_num = ", training_num, " input_dim = ", input_dim, " output_dim = ", output_dim)


def _as_device_tensor(array):
    """Convert a numpy array into a float32 tensor living on ``device``."""
    return torch.from_numpy(array).float().to(device)


original_x_train = _as_device_tensor(train_x)
original_y_train = _as_device_tensor(train_y)
original_x_test = _as_device_tensor(test_x)
original_y_test = _as_device_tensor(test_y)



# Mean-squared-error criterion shared by all training and evaluation stages.
loss = nn.MSELoss(reduction='mean')


# mask = torch.randperm(training_num)[:test_num]
# prior_x = original_x_train[mask, :]
# prior_y = original_y_train[mask]

print("============================================================================================")
################################## define model ##################################


# Two BNNs with the same architecture. Only opt_bnn receives an optimizer in
# the visible part of the script; bnn has none (its optimizer is commented out).
bnn = BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
opt_bnn = BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)

# bnn_optimizer = torch.optim.Adam([{'params': bnn.parameters(),  'lr': lr_bnn}])
# opt_bnn_optimizer = torch.optim.RMSprop([{'params': opt_bnn.parameters(), 'lr': lr_optbnn}])
opt_bnn_optimizer = torch.optim.Adam([{'params': opt_bnn.parameters(), 'lr': lr_bnn}])

print("============================================================================================")
################################## prior pre-training ##################################

# pre-train GP prior

# NOTE(review): the GP is conditioned on, and its hyper-parameters are fitted
# to, the *test* split (original_x_test / original_y_test). Test metrics are
# later computed on the same split -- confirm this is intended.
likelihood = gpytorch.likelihoods.GaussianLikelihood()
prior = ExactGPModel(original_x_test, original_y_test, likelihood, input_dim).to(device)

# Training mode for hyper-parameter optimisation.
prior.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(prior.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior)

for i in range(n_step_prior_pretraining):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = prior(original_x_test)
    # Calc loss and backprop gradients
    loss_gp = -mll(output, original_y_test)
    loss_gp.backward()
    # print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
    #     i + 1, n_step_prior_pretraining, loss_gp.item(),
    #     prior.covar_module.base_kernel.lengthscale.item(),
    #     prior.likelihood.noise.item()
    # ))
    # Progress line: negative marginal log likelihood and current noise value.
    print('Iter %d/%d - Loss: %.3f     noise: %.3f' % (
        i + 1, n_step_prior_pretraining, loss_gp.item(),
        # prior.covar_module.base_kernel.lengthscale.item(),
        prior.likelihood.noise.item()
    ))

    optimizer.step()

# Switch to evaluation mode; the prior is only queried from here on.
prior.eval()
likelihood.eval()

# Learned likelihood noise as a Python float; later added to the predictive
# variance and used inside the NLL computation.
sigma = prior.likelihood.noise.item()


print("============================================================================================")

##########################################    check opt_bnn  #######

def sample_measurement_set(X, num_data, n=100):
    """Return up to ``n`` rows of ``X`` chosen uniformly without replacement.

    Args:
        X: 2-D tensor indexed as ``X[idx, :]``.
        num_data: Number of rows to draw from; row indices are a random
            permutation of ``range(num_data)``.
        n: Size of the measurement set. Defaults to 100, the value that was
            previously hard-coded.

    Returns:
        Tensor holding the selected rows. It has fewer than ``n`` rows when
        ``num_data`` is smaller than ``n``.
    """
    # One permutation draw from the global torch RNG.
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n)]
    return X[idx, :]

def lipf(X, Y):
    """Return the element-wise absolute difference ``|X - Y|``."""
    difference = X - Y
    return difference.abs()

# Fixed measurement set drawn once from the training inputs. It is reused
# later for the SGLD predictions, the GP linearisation and the flow training.
original_x_train_add_g3 = sample_measurement_set(X=original_x_train, num_data=training_num)

############################################
###### linear_nn flow ###############
# Hidden layer sizes of the networks used inside the flows (passed to BNNF).
hidden_dimsflow = [100]
class LinearFlow_nn(nn.Module):
    """Affine flow ``x = z * weight + net(x_in)`` with one learnable scalar scale.

    ``net`` maps the inputs ``x_in`` to an input-dependent shift; ``weight``
    is a single parameter initialised to 1.
    """

    def __init__(self, net):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias = net

    def forward(self, z, x_in):
        """Return ``(z * weight + shift, shift, log|weight|)``."""
        shift = self.bias(x_in).squeeze()
        scale = self.weight
        transformed = z * scale + shift
        # The 1e-7 keeps the logarithm finite if the scale reaches zero.
        log_det = torch.log(scale.abs() + 1e-7).sum()
        return transformed, shift, log_det

    def inverse(self, x, x_in):
        """Undo ``forward``: return ``((x - shift) / weight, log|1 / weight|)``."""
        shift = self.bias(x_in).squeeze()
        scale = self.weight
        recovered = x * 1 / scale - shift / scale
        log_det_inverse = torch.log(1 / scale.abs() + 1e-7).sum()
        return recovered, log_det_inverse

##
class NonLinearFlow_nn(nn.Module):
    """Tanh flow whose inverse is ``z = weight2 * tanh(x * weight + net1(x_in)) + net2(x_in)``.

    ``forward`` applies the matching atanh map, with its argument clamped to
    [-0.99, 0.99] so atanh stays finite.
    """

    def __init__(self, net1, net2):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias1 = net1
        self.weight2 = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias2 = net2

    def forward(self, z, x_in):
        """Return ``(x, net1(x_in), log_det)`` for samples ``z`` at inputs ``x_in``."""
        inner_shift = self.bias1(x_in).squeeze()
        outer_shift = self.bias2(x_in).squeeze()
        scaled = (z - outer_shift) / self.weight2
        bounded = torch.clamp(scaled, -0.99, 0.99)
        x = torch.atanh(bounded) / self.weight - inner_shift / self.weight
        # NOTE(review): only log|weight| is returned here, not the Jacobian
        # of the atanh map -- confirm before relying on this value.
        log_det = torch.log(self.weight.abs() + 1e-7).sum()
        return x, inner_shift, log_det

    def inverse(self, x, x_in):
        """Return ``(z, log|dz/dx|)`` with the log-determinant summed over the last axis."""
        inner_shift = self.bias1(x_in).squeeze()
        outer_shift = self.bias2(x_in).squeeze()
        squashed = torch.tanh(x * self.weight + inner_shift)
        z = self.weight2 * squashed + outer_shift
        # d/dx [w2 * tanh(w * x + b1)] = w2 * w * (1 - tanh^2)
        det = self.weight2 * self.weight * (1 - squashed ** 2)
        log_det_inverse = torch.log(det.abs() + 1e-7).sum(-1)
        return z, log_det_inverse

########## sigmoid_flow #############
class SigmoidFlow_nn(nn.Module):
    """Sigmoid flow whose inverse is ``z = weight2 * sigmoid(x * weight) + net(x_in)``.

    ``forward`` applies the matching logit map, with its argument clamped to
    [0.01, 0.99].
    """

    def __init__(self, net):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.weight2 = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias = net

    def forward(self, z, x_in):
        """Return ``(x, net(x_in), 1.)``; the last entry is a constant placeholder."""
        shift = self.bias(x_in).squeeze()

        prob = torch.clamp((z - shift) / self.weight2, min=0.01, max=0.99)
        odds = prob / (1 - prob)
        # x = torch.log(clamped_mask + 1e-7) / self.weight - nnbias1 / self.weight
        x = torch.log(odds + 1e-7) / self.weight
        log_det = 1.

        return x, shift, log_det

    def inverse(self, x, x_in):
        """Return ``(z, log|dz/dx|)`` with the log-determinant summed over the last axis."""
        shift = self.bias(x_in).squeeze()

        gate = torch.sigmoid(x * self.weight)
        z = self.weight2 * gate + shift

        # d/dx [w2 * sigmoid(w * x)] = w2 * w * s * (1 - s)
        det = self.weight2 * self.weight * gate * (1 - gate)
        log_det_inverse = torch.log(det.abs() + 1e-7).sum(-1)

        return z, log_det_inverse

########## exp_flow #############
class ExpFlow_nn(nn.Module):
    """Exponential flow whose inverse is ``z = exp(tanh(x)) * weight + net(x_in)``.

    ``forward`` applies the matching ``atanh(log(.))`` map with clamping so
    both the logarithm and atanh stay finite.
    """

    def __init__(self, net):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias = net

    def forward(self, z, x_in):
        """Return ``(x, net(x_in), log|weight|)`` for samples ``z`` at inputs ``x_in``."""
        shift = self.bias(x_in).squeeze()
        ratio = (z - shift) / self.weight
        # Keep the logarithm's argument strictly positive.
        ratio = torch.clamp(ratio, min=0.001)
        log_ratio = torch.log(ratio + 1e-7)
        # mask = torch.log(mask.abs() + 1e-7)
        bounded = torch.clamp(log_ratio, -0.99, 0.99)
        x = torch.atanh(bounded)
        log_det = torch.log(self.weight.abs() + 1e-7).sum()

        return x, shift, log_det

    def inverse(self, x, x_in):
        """Return ``(z, log|dz/dx|)`` with the log-determinant summed over the last axis."""
        shift = self.bias(x_in).squeeze()
        squashed = torch.tanh(x)
        grown = torch.exp(squashed)
        z = grown * self.weight + shift
        # d/dx [w * exp(tanh x)] = w * exp(tanh x) * (1 - tanh^2 x)
        det = self.weight * grown * (1 - squashed ** 2)
        log_det_inverse = torch.log(det.abs() + 1e-7).sum(-1)

        return z, log_det_inverse

################################## start training ##################################

# track total training time
# NOTE(review): datetime.now() returns local time although the message says GMT.
start_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)


# logging file
# NOTE(review): the handle is not closed in this section -- confirm it is
# closed at the end of the script.
log_f = open(log_f_name, "w+")
log_f.write('epoch,training_loss\n')
# scheduler = torch.optim.lr_scheduler.StepLR(bnn_optimizer, 5000, gamma=0.9, last_epoch=-1)

################# S1: FVI #################################
# Per-stage histories; the metric lists are appended to below and in later stages.
train_loss_all = []
prior_distance_all = []
wdist_all = []
likelihood_loss_all = []

fvi_mse_all = []
fmc_mse_all = []
fvi_nll_all = []
fmc_nll_all = []
app_mse_all = []
app_nll_all = []

for epoch in range(401):

    # Fresh random measurement set on which BNN and GP prior samples are compared.
    measurement_set = sample_measurement_set(X=original_x_train, num_data=training_num)
    # add_set = sample_measurement_set(original_x_test, num_data=test_num)
    # measurement_set = torch.cat((measurement_set, add_set), 0).to(device)
    gpss = prior_sample_functions(measurement_set, prior, num_sample=128).detach().float().to(device)
    gpss = gpss.squeeze()
    nnet_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device) # original_x_test
    nnet_samples = nnet_samples.squeeze()

    # Mean absolute difference between GP samples and BNN samples.
    functional_wdist = lipf(gpss, nnet_samples)
    functional_wdist = torch.mean(functional_wdist)

    ##################### variance constraint #########

    # The GP variance is a fixed target, so no gradient is tracked for it.
    with torch.no_grad():
        gp_var = prior(measurement_set).variance
        # print('gp_var.shape: ', gp_var.shape)

    opt_samples = opt_bnn.sample_functions(measurement_set, num_sample=200).float().to(device)
    opt_samples = opt_samples.squeeze()
    opt_var = torch.std(opt_samples, 0) ** 2
    # print('opt_var.shape: ', opt_var.shape)
    dif_var = (gp_var - opt_var).abs()
    # dif_var = torch.sum(dif_var)
    dif_var = torch.mean(dif_var)


    ### calculate bnn likelihood loss ###
    pred_y = opt_bnn.forward(original_x_train)

    pred_y = pred_y.squeeze().flatten()
    # print(pred_y)

    # The data-fit term is computed once and reused in the total objective.
    likelihood_loss = loss(pred_y, original_y_train)
    train_loss = likelihood_loss + 10 * functional_wdist + 1 * dif_var    # 1_10_1   1_10_10


    # optimisation
    opt_bnn_optimizer.zero_grad()
    train_loss.backward()

    opt_bnn_optimizer.step()

    # print
    print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss),
          datetime.now().replace(microsecond=0) - start_time)

    # log in logging file
    if epoch % log_model_freq == 0:
        print("--------------------------------------------------------------------------------------------")
        # Write the scalar value: the tensor repr ("tensor(..., grad_fn=...)")
        # can contain commas and would corrupt the CSV row.
        log_f.write('{},{}\n'.format(epoch, train_loss.item()))
        log_f.flush()
        print("log saved")
        print("--------------------------------------------------------------------------------------------")

    # save model weights
    if epoch % save_model_freq == 0:
        print("--------------------------------------------------------------------------------------------")
        print("saving model at : " + checkpoint_path)
        # Save the network that is actually optimised in this loop; `bnn` has
        # no optimizer and would only store its initial weights.
        opt_bnn.save(checkpoint_path)
        print("model saved")
        print("--------------------------------------------------------------------------------------------")

    if epoch % test_interval == 0:
        opt_bnn.eval()

        samples_pred_test_y = opt_bnn.forward_eval(original_x_test, num_sample).squeeze().detach()

        opt_bnn.train()

        # NOTE(review): the criterion is applied to the stacked samples, not to
        # their mean -- presumably relies on broadcasting against the targets.
        mse_fvi = loss(samples_pred_test_y, original_y_test).detach().cpu()
        # nll_test_loss = -F.nll_loss(samples_pred_test_y, original_y_test, reduction='sum')
        # print('nll_test_loss: ', nll_test_loss)
        print('mse_test_loss: ', mse_fvi)
        fvi_mse_all.append(mse_fvi)


        ################ NLL ################
        # NOTE(review): the log term uses the sample std without the noise and
        # adds log(sigma) separately, while the quadratic term uses
        # std ** 2 + sigma -- confirm this matches the intended Gaussian NLL.
        mean_y_pred = torch.mean(samples_pred_test_y, 0)
        std_y_pred = torch.std(samples_pred_test_y, 0)
        variance_y_pred = torch.std(samples_pred_test_y, 0) ** 2 + sigma
        log_std = torch.log(std_y_pred).sum()
        vec = 0.5 * (original_y_test - mean_y_pred) ** 2 / variance_y_pred
        vec = vec.sum()
        const = 0.5 * np.log(np.pi * 2) + np.log(sigma)
        nll = test_num * const + log_std + vec
        nll_fvi = nll / test_num
        nll_fvi = nll_fvi.cpu()

        print('nll_mean: ', nll_fvi)
        fvi_nll_all.append(nll_fvi)
        print("--------------------------------------------------------------------------------------------")
#######


print("--------------------------------------------------------------------------------------------")
######### S2: FMCMC #####################
# Parameters of the trained opt_bnn, grouped by the tag found in their names.
w_mu_list = []
b_mu_list = []
w_std_list = []
b_std_list = []

_param_buckets = {
    "W_mu": w_mu_list,
    'b_mu': b_mu_list,
    'W_std': w_std_list,
    'b_std': b_std_list,
}

for name, p in opt_bnn.named_parameters():
    # Every tag is tested independently, in the same order as the lists above.
    for _tag, _bucket in _param_buckets.items():
        if _tag in name:
            _bucket.append(p)


# The collected tensors are handed to BNNMC through its W_mu / b_mu / W_std / b_std arguments.
bnnmc = BNNMC(
    input_dim, output_dim, hidden_dims, activation_fn, is_continuous,
    W_mu=w_mu_list, b_mu=b_mu_list, W_std=w_std_list, b_std=b_std_list,
    scaled_variance=True,
).to(device)

# bnnmc = BNNMC(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)

bnnmc_optimizer = torch.optim.Adam(bnnmc.parameters(), lr=0.001)  # yacht-energy: 0.001, whine-protein: 0.01

def noise_loss(lr):
    """Return the sum over bnnmc parameters of ``sum(param * eps)``.

    ``eps`` is zero-mean Gaussian noise with standard deviation
    ``sqrt(2 / lr)``, drawn independently for every parameter tensor.
    """
    scale = (2 / lr) ** 0.5
    total = 0.0
    for param in bnnmc.parameters():
        zero_mean = torch.zeros(param.size()).to(device)
        draw = torch.normal(zero_mean, std=scale).to(device)
        total += torch.sum(param * draw)
    return total

def sample_measurement_set(X, num_data, n=1000):
    """Return up to ``n`` rows of ``X`` chosen uniformly without replacement.

    Args:
        X: 2-D tensor indexed as ``X[idx, :]``.
        num_data: Number of rows to draw from; row indices are a random
            permutation of ``range(num_data)``.
        n: Size of the measurement set. Defaults to 1000, the value that was
            previously hard-coded (40 was noted as an alternative).

    Returns:
        Tensor holding the selected rows. It has fewer than ``n`` rows when
        ``num_data`` is smaller than ``n``.
    """
    # One permutation draw from the global torch RNG.
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n)]
    return X[idx, :]

########### SGLD ###############
mt = 0  # running index of saved snapshots

# Snapshots are written as files inside this folder. Create it up front so the
# first save does not depend on the folder already existing.
sgld_save_dir = results_folder + '/' + bnn_name_string + '/' + 'fsgld_yacht'
os.makedirs(sgld_save_dir, exist_ok=True)

for epoch in range(401):

    pred_y = bnnmc.forward(original_x_train)
    # y_predicted, distance_prior = bnn.forward_w(train_x)

    pred_y = pred_y.squeeze().flatten()

    ###################  measurement set for d_logp
    m_set = sample_measurement_set(X=original_x_train, num_data=training_num)
    y_m_set = bnnmc.forward(m_set).squeeze().flatten()
    # print('y_m_set.shape: ', y_m_set.shape)

    # GP predictive (with observation noise) at the measurement points; used
    # as a fixed target, so no gradient is tracked.
    with torch.no_grad():
        fprior_m = likelihood(prior(m_set))
        prior_mean_m = fprior_m.mean
        prior_var_m = fprior_m.variance

    #
    # Surrogate prior term: the detached factor (f - mean) / var multiplies the
    # network output, so the gradient w.r.t. the output equals that factor.
    d_logp_m = ((y_m_set - prior_mean_m) / prior_var_m).detach()
    d_logp_m = y_m_set * d_logp_m
    d_logp_m = torch.mean(d_logp_m)
    # print('d_log_m:', d_logp_m)
    #

    # Surrogate likelihood term built the same way from the detached residual.
    t_like = pred_y * ((pred_y - original_y_train).detach())
    t_like = torch.mean(t_like)

    ####################
    train_loss_mc1 = t_like + 1 * d_logp_m  # 1e-2

    # train_loss = loss(pred_y, original_y_train) + 1e-6 * d_logp  # coeff: 1e-6 1e-3
    # Noise is injected only after the first 300 epochs.
    if epoch > 300:           # 300
        noise = noise_loss(lr=0.001)
        train_loss_mc1 = train_loss_mc1 + 1e-3 * noise

    # optimisation
    bnnmc_optimizer.zero_grad()
    train_loss_mc1.backward()
    bnnmc_optimizer.step()
    # scheduler.step()
    # print
    print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss_mc1),
          datetime.now().replace(microsecond=0) - start_time)

    if epoch % test_interval == 0:
        bnnmc.eval()

        samples_pred_test_y = bnnmc.forward_eval(original_x_test, num_sample).squeeze().detach()

        bnnmc.train()

        mse_test_loss = loss(samples_pred_test_y, original_y_test)
        # nll_test_loss = -F.nll_loss(samples_pred_test_y, original_y_test, reduction='sum')
        # print('nll_test_loss: ', nll_test_loss)
        print('mse_test_loss: ', mse_test_loss)

    ##############
    # Keep every third iterate after epoch 100 as a posterior sample.
    if epoch > 100 and (epoch % 3) == 0:   # (100, 3)
        print('save!')
        # Move to CPU before saving, then back to the training device.
        bnnmc.cpu()
        # torch.save(bnn.state_dict(), save_path + '/g3_rbf_%i.pt'%(mt))
        dd = f'{mt}'
        print(dd)
        save_path = sgld_save_dir + '/' + 'yacht_' + dd
        bnnmc.save(save_path)
        mt += 1
        bnnmc.to(device)

#####################
def testp():
    """Put bnnmc in eval mode and return its detached, squeezed test-set prediction."""
    bnnmc.eval()
    test_output = bnnmc.forward(original_x_test)
    return test_output.squeeze().detach()

def test2():
    """Put bnnmc in eval mode and return its detached, squeezed prediction on the fixed measurement set."""
    bnnmc.eval()
    mset_output = bnnmc.forward(original_x_train_add_g3)
    return mset_output.squeeze().detach()

################ load mcmc model ################
pred_list = []    # per-snapshot predictions on the test inputs
predm_list = []   # per-snapshot predictions on the fixed measurement set
num_model = 100 # 10 100 80 70

for m in range(num_model):
    dd = f'{m}'
    # Restore snapshot m written by the SGLD loop.
    bnnmc.load_state_dict(torch.load(results_folder + '/' + bnn_name_string + '/' + 'fsgld_yacht' + '/' + 'yacht_' + dd))  # sgld, sghmc
    pred = testp()
    print('pred.shape: ', pred.shape)
    pred_list.append(pred)

    predm = test2()
    predm_list.append(predm)

# Stack once after all snapshots are loaded instead of re-stacking the growing
# list on every iteration.
mset_fmc = torch.stack(predm_list)

########################
total_pred = torch.stack(pred_list)
mean_y_pred = torch.mean(total_pred, 0)
std_y_pred = torch.std(total_pred, 0)

# NOTE(review): the criterion is applied to the stacked snapshot predictions,
# not to their mean -- presumably relies on broadcasting against the targets.
mse_fmc = loss(total_pred, original_y_test).detach().cpu()
print('mse_fmc: ', mse_fmc)
fmc_mse_all.append(mse_fmc)

# NOTE(review): the quadratic term below is not divided by the variance and
# the log term uses the std; this differs from the NLL used for the FVI
# stage -- confirm which definition is intended.
variance_y_pred = torch.std(total_pred, 0) ** 2
log_std = torch.log(std_y_pred).sum()
# log_std = torch.log(variance_y_pred).sum()
vec = 0.5 * ((original_y_test - mean_y_pred) ** 2 + variance_y_pred)
vec = vec.sum()
const = 0.5 * np.log(np.pi * 2)
nll = test_num * const + log_std + vec
nll_fmc = nll / test_num
nll_fmc = nll_fmc.detach().cpu()

print('nll_fmc: ', nll_fmc)
fmc_nll_all.append(nll_fmc)


############### GP Linearization for FVI posterior ###############
# opt_bnn2 is an alias of opt_bnn (same object, not a copy). The optimizer
# built here is only used to reset gradients before each backward pass below.
opt_bnn2 = opt_bnn
opt_bnn_optimizer2 = torch.optim.Adam([{'params': opt_bnn2.parameters(), 'lr': lr_bnn}])

# presumably flattened vectors of all variational means / stds -- TODO confirm
# against BNN.sample_mu_std
w_mu, w_std = opt_bnn2.sample_mu_std()
w_mu = w_mu.detach()
w_std = w_std.detach()
print('w_std.shape: ', w_std.shape)  # [10401]
w_dim = w_std.shape[-1]
print('w_dim: ', w_std.shape[-1])
# Dense diagonal covariance: w_std ** 2 on the diagonal, zeros elsewhere.
w_cov = w_std ** 2 * torch.eye(w_dim).to(device)
print('w_cov.shape: ', w_cov.shape)  # [10401, 10401]

# Row i will hold the gradient of the mean prediction at measurement point i
# with respect to every parameter whose name contains 'mu'.
grads_x_mset = None

for x_i in original_x_train_add_g3:

    pred_y_i = opt_bnn2.forward_mu(x_i)
    pred_y_i = pred_y_i.squeeze().flatten()

    # evaluate gradient
    # Gradients are reset first so that p.grad only reflects this input.
    opt_bnn_optimizer2.zero_grad()
    pred_y_i.backward()

    # NOTE(review): assumes the concatenation order here matches the ordering
    # of w_std returned by sample_mu_std -- verify.
    grad_i = torch.cat([p.grad.view(-1) for name, p in opt_bnn2.named_parameters() if 'mu' in name])
    grad_i = grad_i[:, None]
    print('grad_i.shape: ', grad_i.shape)

    # Append the gradient as a new row.
    if grads_x_mset is None:
        grads_x_mset = torch.transpose(grad_i, 0, 1)
    else:
        grads_x_mset = torch.cat((grads_x_mset, torch.transpose(grad_i, 0, 1)), 0)

print('grads_x_mset.shape: ', grads_x_mset.shape)

# Mean of the linearised model: the mean-weight prediction on the measurement set.
prior_mean_m = opt_bnn2.forward_mu(original_x_train_add_g3).squeeze().detach()
print('prior_mean_m: ', prior_mean_m.shape)

# Covariance of the linearised model, J @ Sigma @ J^T ...
prior_var_m = grads_x_mset @ w_cov @ torch.transpose(grads_x_mset, 0, 1)
print('prior_var_m: ', prior_var_m.shape)

# ... of which only the diagonal (per-point variances) is kept.
prior_var_m = torch.diag(prior_var_m, 0)
print('prior_var_m: ', prior_var_m.shape)
##
##

######################### train LinearFlow_nn ##############

# Networks that provide the input-dependent shift inside the flows. flownet2
# is only needed by the commented-out NonLinearFlow_nn variant.
flownet = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(device)
flownet2 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(device)

flow = LinearFlow_nn(flownet).to(device)
# flow = CubeFlow_nn(flownet).to(device)
# flow = NonLinearFlow_nn(flownet, flownet2).to(device)
# flow = SigmoidFlow_nn(flownet).to(device)
# flow = ExpFlow_nn(flownet).to(device)
flow_optimizer = torch.optim.Adam(flow.parameters(), lr=0.01)


# loss_flow_all records the scalar loss of every step; the other three lists
# receive one entry each after training.
loss_flow_all = []
flow_likelihood_list = []
flow_det_list = []
flow_loss_list = []

for epoch in range(5000):
    # Map the SGLD predictions on the measurement set back through the flow.
    Z, logdet = flow.inverse(mset_fmc, original_x_train_add_g3)
    # print('Z: ', Z)
    # Debug output: prints a boolean tensor at every step.
    contains_nan = torch.isnan(Z).any()
    print(contains_nan)

    # Quadratic term of Z under N(prior_mean_m, diag(prior_var_m)) minus the
    # flow log-determinant; the first summand does not depend on the flow.
    loss_flow = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
        torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1) - logdet)

    flow_optimizer.zero_grad()
    # NOTE(review): Z is recomputed every step, so retain_graph=True looks
    # unnecessary -- confirm nothing else reuses the graph.
    loss_flow.backward(retain_graph=True)
    flow_optimizer.step()
    loss_flow_all.append(loss_flow.item())


    # print
    print("epoch : {} \t\t training loss \t\t : {}".format(epoch, loss_flow),
          datetime.now().replace(microsecond=0) - start_time)

flow.eval()
# Learned scale of the trained flow.
a = flow.weight

print('flow.weight: ', flow.weight)
# print('flow.weight2: ', flow.weight2)

# Z, logdet and loss_flow below are the values from the last training
# iteration, i.e. computed before the final optimizer step.
flow_likelihood = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
        torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1))
flow_det = torch.mean(logdet)
flow_likelihood_list.append(flow_likelihood)
flow_det_list.append(flow_det)
flow_loss_list.append(loss_flow)

####

# Plot every 50th recorded flow loss and store the figure as a PDF.
plt.figure()
loss_flow_all = np.array(loss_flow_all)
# Take the x positions from the recorded history instead of repeating the
# literal iteration count, so the plot stays valid if the loop length changes.
indices = np.arange(len(loss_flow_all))[::50]
plt.plot(indices, loss_flow_all[indices], '-ko', ms=3)
plt.ylabel(r'flow loss')
plt.xlabel('Iteration')
# A single layout pass once both axis labels are set.
plt.tight_layout()
plt.savefig(figures_folder + '/flow_loss.pdf')

####
# Draw 100 function samples from opt_bnn at the test inputs and detach them.
nnet_samples = (
    opt_bnn.sample_functions(original_x_test, num_sample=100)  # original_x_test
    .float()
    .to(device)
    .squeeze()
    .detach()
)

# fvi = flow2(gpss).detach()
# Push the samples through the trained flow; the log-determinant is discarded.
fvi, nnbias_test, _ = flow(nnet_samples, original_x_test)
fvi = fvi.detach()
print('fvi.shape: ', fvi.shape)

# Test error of the flow-transformed samples.
mse_app = loss(fvi, original_y_test).detach().cpu()
print('mse_app: ', mse_app)
app_mse_all.append(mse_app)
#
# variance_y_pred = torch.std(fvi, 0) ** 2
# log_std = torch.log(std_y_pred).sum()
# # log_std = torch.log(variance_y_pred).sum()
# vec = 0.5 * ((original_y_test - mean_y_pred) ** 2 + variance_y_pred)
# vec = vec.sum()
# const = 0.5 * np.log(np.pi * 2)
# nll = test_num * const + log_std + vec
# nll_mean = nll / test_num
#
# print('nll_fvi: ', nll_mean)

################################# FVI-MCMC LOOP #######################

# Flows learned so far and their scales, seeded with the flow trained above.
flow_list = [flow]
a_list = [a]
# logdet_list = [flogdet]
cum_det_list, flow_name_list = [], []
##
# Loss histories collected across the outer loop iterations.
all_train_losses, all_wdist_losses = [], []
all_likelihood_losses, all_m2_losses = [], []

def sample_measurement_set2(X, num_data, n=40):
    """Return ``n`` rows of ``X`` chosen uniformly at random without replacement.

    Args:
        X: 2-D tensor, indexed as ``X[idx, :]``.
        num_data: rows ``0 .. num_data - 1`` of ``X`` are eligible for selection.
        n: size of the measurement set. Defaults to 40, the value that used to
            be hard-coded. When ``num_data < n`` only ``num_data`` rows are
            returned.

    Returns:
        Tensor holding the selected rows of ``X``.
    """
    # A random permutation truncated to n entries gives n distinct row indices.
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n)]
    return X[idx, :]
##
for i in range(2):    # 13
    # One outer iteration: (S1) FVI training of opt_bnn below, followed by the
    # MCMC and flow-fitting stages later in this loop body.
    loop_count = 800 * i
    print('loop_count: ', loop_count)
    outloop = i
    ################# S1: FVI #############################
    # Per-epoch histories for this outer iteration (reset on every pass).
    # prior_distance_all is created but not appended to in this stage.
    train_loss_all = []
    prior_distance_all = []
    wdist_all = []
    likelihood_loss_all = []
    dif_var_all = []
    # A new Adam instance is built per outer iteration, so optimizer state does not carry over.
    opt_bnn_optimizer = torch.optim.Adam([{'params': opt_bnn.parameters(), 'lr': 0.01}])   #lr_bnn

    for epoch in range(401):  # 1000 4001

        # --- Data-fit term: network output on the training inputs, pushed through every flow in order.
        pred_y_vi2 = opt_bnn.forward(original_x_train)
        print('pred_y_vi2.shape: ', pred_y_vi2.shape)
        pred_y_vi2 = pred_y_vi2.t()

        for flow in flow_list:

            pred_y_vi2, _, _ = flow(pred_y_vi2, original_x_train)
            # print('pred_y_vi2: ', pred_y_vi2)

        print('pred_y_vi2.shape: ', pred_y_vi2.shape)
        pred_y_vi2 = pred_y_vi2.squeeze().flatten()

        #####
        # --- Functional distance term: 128 GP-prior samples vs. 128 flowed BNN samples
        # on the fixed measurement set original_x_train_add_g3.
        measurement_set = original_x_train_add_g3
        # measurement_set = sample_measurement_set2(X=original_x_train, num_data=training_num)
        gpss = prior_sample_functions(measurement_set, prior, num_sample=128).detach().float().to(device)
        # gpss = prior_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)  # original_x_test
        gpss = gpss.squeeze()
        nnet_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)  # original_x_test
        nnet_samples = nnet_samples.squeeze()

        for flow in flow_list:

            nnet_samples, _, _ = flow(nnet_samples, measurement_set) #original_x_train_add_g3

        print('nnet_samples.shape: ', nnet_samples.shape)
        # lipf is defined outside this section; presumably a Wasserstein-style critic
        # (the plot below labels this quantity W_1) — TODO confirm.
        functional_wdist = lipf(gpss, nnet_samples)
        functional_wdist = torch.mean(functional_wdist)

        #####
        # --- Variance-matching term: mean absolute gap between the GP prior variance and
        # the sample variance of 200 flowed BNN draws on the same measurement set.
        # measurement_set2 = sample_measurement_set2(X=original_x_train, num_data=training_num)
        measurement_set2 = original_x_train_add_g3

        with torch.no_grad():
            gp_var = prior(measurement_set2).variance
            # print('gp_var.shape: ', gp_var.shape)

        opt_samples = opt_bnn.sample_functions(measurement_set2, num_sample=200).float().to(device)
        opt_samples = opt_samples.squeeze()

        for flow in flow_list:

            opt_samples, _, _ = flow(opt_samples, measurement_set2)  # original_x_train_add_g3

        opt_var = torch.std(opt_samples, 0) ** 2
        # print('opt_var.shape: ', opt_var.shape)
        dif_var = (gp_var - opt_var).abs()
        # dif_var = torch.sum(dif_var)
        dif_var = torch.mean(dif_var)

        ######
        # if outloop < 12:
        #     train_loss_vi2 = 10 * loss(pred_y_vi2, original_y_train) + 10 * functional_wdist + dif_var
        #
        # else:
        #     train_loss_vi2 = 100 * loss(pred_y_vi2, original_y_train) + 100 * functional_wdist + 10 * dif_var

        # Total objective: data fit + 10 x functional distance + variance gap.
        train_loss_vi2 = 1 * loss(pred_y_vi2, original_y_train) + 10 * functional_wdist + 1 * dif_var    # 1_10_1
        # Same data-fit value evaluated again, only for logging.
        likelihood_loss = loss(pred_y_vi2, original_y_train)
        print('likelihood_loss: ', likelihood_loss)
        print('functional_wdist: ', functional_wdist)

        # optimisation
        opt_bnn_optimizer.zero_grad()
        train_loss_vi2.backward()
        opt_bnn_optimizer.step()

        # Record Python floats so the histories do not hold on to the autograd graph.
        train_loss_all.append(train_loss_vi2.item())
        likelihood_loss_all.append(likelihood_loss.item())
        wdist_all.append(functional_wdist.item())
        dif_var_all.append(dif_var.item())


        # print
        print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss_vi2),
              datetime.now().replace(microsecond=0) - start_time)

        # Periodic test-set evaluation (epochs 0, 200 and 400 with the current settings).
        if epoch % 200 == 0:  # 2000 500

            opt_bnn.eval()

            samples_pred_test_y = opt_bnn.forward_eval(original_x_test, num_sample).squeeze().detach()
            for flow in flow_list:

                samples_pred_test_y, _, _ = flow(samples_pred_test_y, original_x_test)

            samples_pred_test_y = samples_pred_test_y.detach()

            # mean_y_pred = pred_test_y

            mse_fvi = loss(samples_pred_test_y, original_y_test).detach().cpu()
            print('mse_fvi: ', mse_fvi)
            fvi_mse_all.append(mse_fvi)

            ################ NLL ################
            # Per-test-point mean/std over the sample axis (dim 0).
            # NOTE(review): log_std uses the sample std without `sigma`, while the quadratic
            # term divides by (sample variance + sigma) and np.log(sigma) is added through
            # `const` — confirm this is the intended predictive density.
            mean_y_pred = torch.mean(samples_pred_test_y, 0)
            std_y_pred = torch.std(samples_pred_test_y, 0)
            variance_y_pred = torch.std(samples_pred_test_y, 0) ** 2 + sigma
            log_std = torch.log(std_y_pred).sum()
            vec = 0.5 * (original_y_test - mean_y_pred) ** 2 / variance_y_pred
            vec = vec.sum()
            const = 0.5 * np.log(np.pi * 2) + np.log(sigma)
            nll = test_num * const + log_std + vec
            nll_fvi = nll / test_num
            nll_fvi = nll_fvi.cpu()

            print('nll_mean: ', nll_fvi)
            fvi_nll_all.append(nll_fvi)

            opt_bnn.train()

    # Keep this outer iteration's histories for the summary plots at the end of the script.
    all_train_losses.append(train_loss_all)
    all_wdist_losses.append(wdist_all)
    all_likelihood_losses.append(likelihood_loss_all)
    all_m2_losses.append(dif_var_all)

    ###########################
    # Plot every 10th recorded value of this outer iteration's FVI curves.
    # One loop replaces three copy-pasted figure blocks; sample indices are
    # derived from the history length instead of a hard-coded 401, and each
    # figure is closed after saving so figures do not pile up across iterations.
    for curve, y_label, file_name in (
            (train_loss_all, r'Training loss', '/train_loss.pdf'),
            (likelihood_loss_all, r'Likelihood loss', '/likelihood_loss.pdf'),
            (wdist_all, r'$W_1(p_{f}, p_{g})$', '/wdist.pdf'),
    ):
        curve = np.asarray(curve)
        indices = np.arange(len(curve))[::10]
        fig = plt.figure()
        plt.plot(indices, curve[indices], '-ko', ms=3)
        plt.ylabel(y_label)
        plt.xlabel('Iteration')
        plt.tight_layout()
        plt.savefig(figures_folder + file_name)
        plt.close(fig)

    ################## MCMC #################
    # Bucket opt_bnn's parameters by the tag contained in their name, then use
    # the buckets to initialise the BNNMC model and its Adam optimizer.
    w_mu_list, b_mu_list, w_std_list, b_std_list = [], [], [], []
    tagged_buckets = (
        ("W_mu", w_mu_list),
        ('b_mu', b_mu_list),
        ('W_std', w_std_list),
        ('b_std', b_std_list),
    )

    for param_name, param in opt_bnn.named_parameters():
        # Each tag is tested independently, mirroring four separate `if` checks.
        for tag, bucket in tagged_buckets:
            if tag in param_name:
                bucket.append(param)

    bnnmc = BNNMC(
        input_dim, output_dim, hidden_dims, activation_fn, is_continuous,
        W_mu=w_mu_list, b_mu=b_mu_list, W_std=w_std_list, b_std=b_std_list,
        scaled_variance=True,
    ).to(device)

    bnnmc_optimizer = torch.optim.Adam(bnnmc.parameters(), lr=0.01)  #  lr=0.001


    def noise_loss(lr):
        """Return the sum over ``bnnmc`` parameters of ``sum(param * eps)``.

        ``eps`` is fresh zero-mean Gaussian noise with standard deviation
        ``sqrt(2 / lr)``, drawn on ``device`` with the same shape as the parameter.
        """
        scale = (2 / lr) ** 0.5
        total = 0.0
        for param in bnnmc.parameters():
            zero_mean = torch.zeros(param.size()).to(device)
            eps = torch.normal(zero_mean, std=scale).to(device)
            total += torch.sum(param * eps)
        return total

    ######### SGLD #####
    # mt counts the checkpoints written so far and is used as the checkpoint file suffix.
    mt = 0
    for epoch in range(401):   # 500 2001
        # --- Likelihood part: bnnmc output on the training inputs, pushed through every flow in order.
        pred_y_mc1 = bnnmc.forward(original_x_train)

        pred_y_mc1 = pred_y_mc1.t()

        for flow in flow_list:

            pred_y_mc1, _, _ = flow(pred_y_mc1, original_x_train)

        pred_y_mc1 = pred_y_mc1.squeeze().flatten()

        ######
        # --- Prior part: a second forward pass evaluated on the measurement set m_set.

        m_set = original_x_train    # original_x_train: for all except protein
        # m_set = sample_measurement_set(original_x_train, num_data=training_num)    # for protein
        y_m_set = bnnmc.forward(m_set).t()

        for flow in flow_list:

            # NOTE(review): the flow is given original_x_train rather than m_set; the two are
            # the same object here but would differ if the sampled m_set above were enabled.
            y_m_set, _, _ = flow(y_m_set, original_x_train) # original_x_train, m_set

        y_m_set = y_m_set.squeeze().flatten()

        # Mean/variance of likelihood(prior(m_set)); no gradients are needed through them.
        with torch.no_grad():
            fprior_m = likelihood(prior(m_set))
            prior_mean_m = fprior_m.mean
            prior_var_m = fprior_m.variance

        # Surrogate term: the factor (y - mean) / var is detached, so the gradient of
        # mean(y * factor) with respect to y is that factor (scaled by 1/len(y)).
        d_logp_m = ((y_m_set - prior_mean_m) / prior_var_m).detach()
        # print('d_log_m:', d_logp_m)
        d_logp_m = y_m_set * d_logp_m
        d_logp_m = torch.mean(d_logp_m)

        ######

        # Same surrogate construction for the data term, with the residual detached.
        t_like = pred_y_mc1 * ((pred_y_mc1 - original_y_train).detach())
        t_like = torch.mean(t_like)

        train_loss_mc1 = t_like + 1 * d_logp_m  # 1e-6  1e-2   1e-3
        # if i > 4:
        #     train_loss_mc1 = 10 ** (i-3) * train_loss_mc1
        # if outloop < 5:
        #     if epoch > 100:  # 1500  100
        #         noise = noise_loss(lr_bnn)
        #         train_loss_mc1 = train_loss_mc1 + 1e-3 * noise
        #
        # else:
        #     if epoch > 100:  # 1500  100
        #         noise = noise_loss(lr_bnn)
        #         train_loss_mc1 = train_loss_mc1 + 1e-5 * noise

        # train_loss = loss(pred_y, original_y_train) + 1e-6 * d_logp  # coeff: 1e-6 1e-3
        # After epoch 100 a random noise term (weight 1e-3) is added to the objective.
        if epoch > 100:  #  other_100  protein_300
            noise = noise_loss(lr_bnn)
            train_loss_mc1 = train_loss_mc1 + 1e-3 * noise   #1e-5

        # optimisation
        bnnmc_optimizer.zero_grad()
        train_loss_mc1.backward()
        bnnmc_optimizer.step()
        # scheduler.step()

        # print
        print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss_mc1),
              datetime.now().replace(microsecond=0) - start_time)

        # Periodic test-set check using a single sampled function (num_sample=1).
        if epoch % 200 == 0:    # 2000   500

            bnnmc.eval()

            ##
            nnet_samples = bnnmc.sample_functions(original_x_test, num_sample=1).float().to(device)  # original_x_test
            nnet_samples = nnet_samples.squeeze().detach()

            for flow in flow_list:

                nnet_samples, _, _ = flow(nnet_samples, original_x_test)

            nnet_samples = nnet_samples.detach()
            print('fvi.shape: ', nnet_samples.shape)

            mse_test_loss = loss(nnet_samples, original_y_test)
            # nll_test_loss = -F.nll_loss(samples_pred_test_y, original_y_test, reduction='sum')
            # print('nll_test_loss: ', nll_test_loss)
            print('mse_mcmc: ', mse_test_loss)

            bnnmc.train()

        ##############
        # Checkpoint every 4th epoch (4, 8, ..., 400 -> 100 files with the current settings).
        # The model is moved to CPU for saving and back to `device` afterwards.
        if epoch > 0 and (epoch % 4) == 0:  # (0, 4)
            print('save!')
            bnnmc.cpu()
            # torch.save(bnn.state_dict(), save_path + '/g3_rbf_%i.pt'%(mt))
            dd = f'{mt}'
            print(dd)
            # NOTE(review): directory creation is commented out below; the target folder
            # presumably already exists — confirm, otherwise the save may fail.
            save_path = results_folder + '/' + bnn_name_string + '/' + 'fsgld_yacht' + '/' + 'yacht_' + dd  # sgld
            # if not os.path.exists(save_path):
            #     os.makedirs(save_path)
            bnnmc.save(save_path)
            mt += 1
            bnnmc.to(device)

    ###
    def testp():
        """Return bnnmc's test-set prediction after passing through every flow.

        Puts ``bnnmc`` in eval mode, runs it on ``original_x_test``, applies
        the flows in ``flow_list`` in order and returns the squeezed,
        detached result.
        """
        bnnmc.eval()

        out = bnnmc.forward(original_x_test).squeeze().detach()
        for stage in flow_list:
            out, _, _ = stage(out, original_x_test)

        return out.squeeze().detach()


    def test2():
        """Return bnnmc's prediction on the measurement set after every flow.

        Puts ``bnnmc`` in eval mode, runs it on ``original_x_train_add_g3``,
        applies the flows in ``flow_list`` in order and returns the detached
        result (no squeeze after the flows).
        """
        bnnmc.eval()

        out = bnnmc.forward(original_x_train_add_g3).squeeze().detach()
        for stage in flow_list:
            out, _, _ = stage(out, original_x_train_add_g3)

        return out.detach()

    ################ load mcmc model ################
    # Reload every saved checkpoint and collect its predictions on the test
    # inputs (testp) and on the measurement set (test2).
    # num_model matches the 100 checkpoints the SGLD loop writes (epochs 4..400, step 4).
    pred_list, predm_list = [], []
    num_model = 100  # 10 100 80 70
    checkpoint_stem = results_folder + '/' + bnn_name_string + '/' + 'fsgld_yacht' + '/' + 'yacht_'

    for m in range(num_model):
        dd = f'{m}'
        state = torch.load(checkpoint_stem + dd)  # sgld, sghmc
        bnnmc.load_state_dict(state)

        pred = testp()
        print('pred.shape: ', pred.shape)
        pred_list.append(pred)

        predm = test2()
        print('predm.shape: ', predm.shape)
        predm_list.append(predm)

    ########################
    # Stack the per-checkpoint predictions: dim 0 indexes the checkpoint.
    total_pred = torch.stack(pred_list)
    mset_fmc = torch.stack(predm_list)

    # Ensemble statistics over checkpoints.
    mean_y_pred = total_pred.mean(0)
    std_y_pred = total_pred.std(0)
    variance_y_pred = std_y_pred ** 2

    mse_fmc = loss(total_pred, original_y_test).detach().cpu()
    print('mse_mcmc: ', mse_fmc)
    fmc_mse_all.append(mse_fmc)

    # NOTE(review): the quadratic term ADDS the ensemble variance instead of
    # dividing by it (the FVI NLL earlier divides) — confirm which is intended.
    log_std = std_y_pred.log().sum()
    # log_std = torch.log(variance_y_pred).sum()
    vec = (0.5 * ((original_y_test - mean_y_pred) ** 2 + variance_y_pred)).sum()
    const = 0.5 * np.log(np.pi * 2)
    nll = test_num * const + log_std + vec
    nll_fmc = (nll / test_num).detach().cpu()

    print('nll_mean: ', nll_fmc)
    fmc_nll_all.append(nll_fmc)

    ############################# GP Linearization2 ###################################
    # Linearise opt_bnn around its mean parameters on the measurement set.
    # opt_bnn2 is an alias of opt_bnn (same object, not a copy); the extra
    # optimizer is only used to reset gradients between backward passes.
    opt_bnn2 = opt_bnn
    opt_bnn_optimizer2 = torch.optim.Adam([{'params': opt_bnn2.parameters(), 'lr': lr_bnn}])

    ##
    w_mu, w_std = opt_bnn2.sample_mu_std()
    w_mu = w_mu.detach()
    w_std = w_std.detach()
    print('w_std.shape: ', w_std.shape)  # [10401]
    w_dim = w_std.shape[-1]
    print('w_dim: ', w_dim)

    ###
    # Jacobian of forward_mu with respect to the 'mu' parameters, one row per
    # measurement point. Rows are collected in a list and stacked once; growing
    # the matrix with torch.cat inside the loop re-copies it on every iteration.
    grad_rows = []

    for x_i in original_x_train_add_g3:
        pred_y_i = opt_bnn2.forward_mu(x_i).squeeze().flatten()

        # evaluate gradient
        opt_bnn_optimizer2.zero_grad()
        pred_y_i.backward()

        # torch.cat copies the gradients, so the row is unaffected by later zero_grad calls.
        grad_rows.append(
            torch.cat([p.grad.view(-1) for name, p in opt_bnn2.named_parameters() if 'mu' in name]))

    grads_x_mset = torch.stack(grad_rows)
    print('grads_x_mset.shape: ', grads_x_mset.shape)  # [60, 10401]

    ##
    prior_mean_m = opt_bnn2.forward_mu(original_x_train_add_g3).squeeze().detach()

    ##
    # Diagonal of J @ diag(w_std ** 2) @ J^T, computed directly as
    # sum_j J[i, j] ** 2 * w_std[j] ** 2. This avoids building the dense
    # [w_dim, w_dim] diagonal covariance and the full product only to keep its diagonal.
    prior_var_m = torch.sum(grads_x_mset ** 2 * w_std ** 2, -1)
    #


 ############################ train flow2 ###################
    # Networks for the new flow. flownet4 is only referenced by the commented-out
    # NonLinearFlow_nn alternative below, but it is still constructed here.
    flownet3 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(
        device)
    flownet4 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(
        device)

    flow2 = LinearFlow_nn(flownet3).to(device)
    # flow2 = NonLinearFlow_nn(flownet3, flownet4).to(device)
    # flow2 = SigmoidFlow_nn(flownet3).to(device)
    # flow2 = ExpFlow_nn(flownet3).to(device)
    # flow2 = CubeFlow_nn(flownet2).to(device)
    # Only flow2's parameters are optimised; the flows already in flow_list are not passed to this optimizer.
    flow_optimizer2 = torch.optim.Adam(flow2.parameters(), lr=0.01)


    ######
    loss_flow_all = []

    for epoch in range(5000):
        # Invert the new flow on the MCMC measurement-set predictions ...
        Z, logdet = flow2.inverse(mset_fmc, original_x_train_add_g3)
        # print('Z.shape: ', Z.shape)

        # loss_flow = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
        #     torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1) - logdet)

        #
        # cum_det = torch.stack(logdet_list)
        # cum_det = torch.sum(cum_det)
        # ... then invert the previously trained flows in reverse order, collecting the
        # mean log-determinant reported by each of them.
        logdet_list = []
        flow_list_inv = flow_list[::-1]

        for flow in flow_list_inv:

            Z, logdetf = flow.inverse(Z, original_x_train_add_g3)

            logdetf = torch.mean(logdetf)
            logdet_list.append(logdetf)
        # print('logdet_list: ', logdet_list)
        cum_det = torch.stack(logdet_list)
        cum_det = torch.sum(cum_det)

        # Quadratic term of Z against the linearised prior moments (prior_mean_m, prior_var_m),
        # minus the new flow's logdet and the summed logdets of the earlier flows.
        loss_flow = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
            torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1) - logdet) - cum_det

        # print('sum_var_loss: ', torch.sum(torch.log(prior_var_m.abs())))
        # print('mean_loss: ', torch.mean(torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1)))
        # print('logdet: ', logdet)
        # print('cum_det: ', cum_det)

        # NOTE(review): Z and logdet are recomputed every epoch, so retain_graph=True may be
        # unnecessary — confirm that none of the flows caches graph state between calls.
        flow_optimizer2.zero_grad()
        loss_flow.backward(retain_graph=True)
        flow_optimizer2.step()
        loss_flow_all.append(loss_flow.item())

        # print
        print("epoch : {} \t\t training loss \t\t : {}".format(epoch, loss_flow),
              datetime.now().replace(microsecond=0) - start_time)


   ############################
    # Switch the new flow to eval mode and record its learned weight.
    flow2.eval()

    print('flow2.weight: ', flow2.weight)
    a2 = flow2.weight

    a_list.append(a2)

    # Diagnostics from the last training epoch. They are stored for reporting
    # only, so compute them without a graph and store detached values: backward
    # was called with retain_graph=True, and keeping the raw tensors in these
    # lists would keep that graph alive for the rest of the run.
    with torch.no_grad():
        flow_likelihood = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
            torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1))
        # Store the mean log-determinant, matching what is stored for the first
        # flow before this loop (torch.mean(logdet)) rather than the raw tensor.
        flow_det = torch.mean(logdet)
    flow_likelihood_list.append(flow_likelihood)
    flow_det_list.append(flow_det)
    cum_det_list.append(cum_det.detach())
    flow_loss_list.append(loss_flow.detach())


    ####

    # Plot every 50th recorded value of this flow's training loss. Indices
    # follow the recorded history rather than a hard-coded 5000, and the
    # figure is closed after saving so figures do not accumulate per iteration.
    loss_flow_all = np.array(loss_flow_all)
    indices = np.arange(len(loss_flow_all))[::50]
    fig = plt.figure()
    plt.plot(indices, loss_flow_all[indices], '-ko', ms=3)
    plt.ylabel(r'flow loss')
    plt.xlabel('Iteration')
    plt.tight_layout()
    plt.savefig(figures_folder + '/flow_loss2.pdf')
    plt.close(fig)

    # The newly trained flow becomes part of the chain used from here on.
    flow_list.append(flow2)

    # Re-evaluate: 100 function samples from opt_bnn at the test inputs,
    # pushed through every flow in order, then scored against the test targets.
    nnet_samples = (
        opt_bnn.sample_functions(original_x_test, num_sample=100)
        .float()
        .to(device)
        .squeeze()
        .detach()
    )

    for flow in flow_list:
        nnet_samples, _, _ = flow(nnet_samples, original_x_test)

    nnet_samples = nnet_samples.detach()
    print('fvi.shape: ', nnet_samples.shape)

    mse_app = loss(nnet_samples, original_y_test).detach().cpu()
    print('mse_app: ', mse_app)
    app_mse_all.append(mse_app)

#############

# Dump the collected metrics, one labelled line each.
for tag, values in (
        ('fvi_mse_all: ', fvi_mse_all),
        ('fvi_nll_all: ', fvi_nll_all),
        ('fmc_mse_all: ', fmc_mse_all),
        ('fmc_nll_all: ', fmc_nll_all),
        ('app_mse_all: ', app_mse_all),
        ('a_list: ', a_list),
):
    print(tag, values)

# One figure per (metric, outer iteration), showing every 10th recorded value.
# A single nested loop replaces four copy-pasted loops. The number of outer
# iterations and the history length are taken from the data instead of the
# hard-coded 2 and 401, and every figure is closed after saving.
for history, y_label, stem in (
        (all_train_losses, r'Training loss', '/train_loss'),
        (all_wdist_losses, r'wdist', '/wdist'),
        (all_likelihood_losses, r'likelihood_loss', '/likelihood_loss'),
        (all_m2_losses, r'dif_var', '/dif_var'),
):
    for i, curve in enumerate(history):
        curve = np.asarray(curve)
        indices = np.arange(len(curve))[::10]
        fig = plt.figure()
        plt.plot(indices, curve[indices], '-ko', ms=3, label=f'flow {i + 1}')
        plt.ylabel(y_label)
        plt.xlabel('Iteration')
        plt.tight_layout()
        plt.savefig(loss_folder + '{}{}.pdf'.format(stem, i + 1))
        plt.close(fig)



# print total training time
print("============================================================================================")
# Wall-clock summary; end_time is truncated to whole seconds.
end_time = datetime.now().replace(microsecond=0)
for caption, stamp in (
        ("Started training at (GMT) : ", start_time),
        ("Finished training at (GMT) : ", end_time),
        ("Total training time  : ", end_time - start_time),
):
    print(caption, stamp)
print("============================================================================================")



















