import os
from datetime import datetime

import torch
import torch.nn as nn
import numpy as np
import gpytorch

from data import x3_gap_toy, sin_toy
from toy import polynomial_toy, polynomial_gap_toy, g2_toy, g3_toy, isotropy_toy
# from bnn import BNN, OPTBNN
# from bnnp import BNNISO

from bnn_vimc import BNN, BNNF, BNNMC
# from bnnp import BNNMC
from gp_prior import ExactGPModel, prior_sample_functions
from wasserstein_prior_match2 import MapperWasserstein, WassersteinDistance
import copy

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import pandas as pd

from utils.utils import default_plotting_new as init_plotting


################################### Hyper-parameters ###################################
# Run-wide constants. Several of them are only printed or not referenced at all by the
# code visible in this file; that is noted per constant below.

# Learning rate handed to both Adam optimizers (bnn_optimizer and opt_bnn_optimizer).
lr_bnn = 0.1
# Only echoed in the hyper-parameter printout in the code visible here.
prior_coeff = 10  # gap:2, sin:10  5
# Tag embedded in the log, checkpoint and figure paths.
bnn_name_string = 'VIMC'            # UAI_toy
# Seed applied to torch and numpy at the end of this section; also part of the checkpoint file name.
random_seed = 123
# Only echoed in the hyper-parameter printout in the code visible here.
max_epoch_num = 100
# Periods (in training iterations) for checkpointing, CSV logging and test-set plotting.
save_model_freq = 5000
log_model_freq = 3000
test_interval = 2000 # 2000
# Not referenced by the code visible in this file.
num_functions = 10
# Number of samples requested from opt_bnn.forward_eval at evaluation time.
num_sample = 1000
# Number of optimizer steps for the first GP ("prior") pre-training loop.
n_step_prior_pretraining = 100
# bnn_prior_step = 10000
# NOTE(review): the training loop and the loss plots spell out 20001 literally
# instead of reading this constant -- keep them in sync when changing it.
epochs = 20001
# Not referenced by the code visible in this file.
f_coeff = 1   # 10
# Only appears in commented-out optimizer definitions in the code visible here.
lr_optbnn = 1e-2
# 0.0012
# Seed both RNGs before any model construction or random data generation below.
torch.manual_seed(random_seed)
np.random.seed(random_seed)


################################### Network Architecture ###################################
# Width and depth of the hidden layers passed to the BNN constructor.

n_units = 100  # gap:100, sin:500
n_hidden = 2
# One entry per hidden layer, all of width n_units.
hidden_dims = [n_units] * n_hidden
# Activation name forwarded verbatim to the BNN constructor.
activation_fn = 'tanh'


print("============================================================================================")
################################## set device ##################################

# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    # Release cached allocator blocks before anything is placed on the GPU.
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")

print("============================================================================================")
###################### logging ######################

#### log files for multiple runs are NOT overwritten

# One call is enough: os.makedirs creates missing parents ("./logs" included) and, with
# exist_ok=True, is a no-op when the directory is already there. This replaces the
# check-then-create pairs, which were redundant and could race with a concurrent run.
log_dir = "./logs" + '/' + bnn_name_string + '/' + 'fvi'
os.makedirs(log_dir, exist_ok=True)

#### get number of log files in log directory
# next(os.walk(...)) yields (dirpath, dirnames, filenames) for log_dir itself only.
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)

#### create new log file for each run
# The run index is the number of files already present, so earlier logs are kept.
log_f_name = log_dir + '/' + bnn_name_string + 'fvi' + "_" + str(run_num) + ".csv"

print("current logging run number for " + bnn_name_string + " : ", run_num)
print("logging at : " + log_f_name)


print("============================================================================================")
################### checkpointing ###################

run_num_pretrained = 0  #### change this to prevent overwriting weights in same env_name folder

# Create ./pretrained/<name>/fvi (parents included); a no-op when it already exists.
directory = "./pretrained" + '/' + bnn_name_string + '/' + 'fvi'
os.makedirs(directory, exist_ok=True)

# The explicit '/' matters: without it the file name was glued onto the directory name
# ("…/fviVIMC_123_0.pth"), so checkpoints landed next to the folder created above
# instead of inside it.
checkpoint_path = directory + '/' + "{}_{}_{}.pth".format(bnn_name_string, random_seed, run_num_pretrained)
print("save checkpoint path : " + checkpoint_path)

################### savefigures ###################

# Root folder for results and the run-specific sub-folder that receives every figure.
results_folder = "./results"
figures_folder = results_folder + '/' + bnn_name_string + '/' + 'fvi_0.1_matern'

# Create each folder that is still missing, outermost first.
for _folder in (results_folder, figures_folder):
    if not os.path.exists(_folder):
        os.makedirs(_folder)


print("============================================================================================")
############# print all hyperparameters #############

# Echo the main run settings, one "label value" line per entry.
for _label, _value in (
    ("learning rate: ", lr_bnn),
    ("coefficient of prior regularization: ", prior_coeff),
    ("random seed: ", random_seed),
    ("max number of epoches: ", max_epoch_num),
):
    print(_label, _value)


print("============================================================================================")
############################## load and normalize data ##############################

# Map dataset keys to their constructors and instantiate the 'g3' entry (g3_toy).
dataset = dict(x3=x3_gap_toy, sin=sin_toy, g2=g2_toy, g3=g3_toy, gap=polynomial_gap_toy, iso=isotropy_toy)['g3']()
original_x_train, original_y_train = dataset.train_samples()
# Normalisation statistics are taken over all elements of the training arrays
# (np.mean / np.std are called without an axis argument).
mean_x, std_x = np.mean(original_x_train), np.std(original_x_train)
mean_y, std_y = np.mean(original_y_train), np.std(original_y_train)
# Standardised copies of the train/test data. Note that the tensors built at the end of
# this section come from the un-normalised original_* arrays, not from these.
train_x = (original_x_train - mean_x) / std_x
train_y = (original_y_train - mean_y) / std_y
original_x_test, original_y_test = dataset.test_samples()
# Test data is scaled with the training statistics.
test_x = (original_x_test - mean_x) / std_x
test_y = (original_y_test - mean_y) / std_y

# Log of the dataset's y_std expressed in normalised-y units.
y_logstd = np.log(dataset.y_std / std_y)

# Dataset x-range bounds expressed in normalised-x units.
lower_ap = (dataset.x_min - mean_x) / std_x
upper_ap = (dataset.x_max - mean_x) / std_x

# Unpacking into two names requires train_x to be 2-D.
training_num, input_dim = np.shape(train_x)
# The input_dim read from the data is immediately replaced by the constant 1.
input_dim = 1 #np.array([input_dim])
output_dim = 1 #np.array([1])
# Forwarded to the BNN constructor as-is.
is_continuous = True

# From here on the original_* names refer to float32 tensors on `device`
# (they were numpy arrays above).
original_x_train = torch.from_numpy(original_x_train).float().to(device)
original_y_train = torch.from_numpy(original_y_train).float().to(device)
original_x_test = torch.from_numpy(original_x_test).float().to(device)
original_y_test = torch.from_numpy(original_y_test).float().to(device)

def _augment_training_set(low, high, count, target_fn):
    """Draw extra inputs, evaluate a target on them and append both to the training data.

    Args:
        low, high: bounds of the uniform distribution the extra inputs are drawn from.
        count: number of extra inputs; they are drawn as a (count, 1) array.
        target_fn: callable applied to the numpy array of extra inputs.

    Returns:
        (extra_x, extra_y, augmented_x, augmented_y) as float tensors on `device`, where
        extra_y is squeezed and the augmented tensors are original_x_train /
        original_y_train concatenated with the extras along dim 0.
    """
    extra_x_np = np.random.uniform(low, high, (count, 1))
    extra_y_np = target_fn(extra_x_np)
    extra_x = torch.from_numpy(extra_x_np).float().to(device)
    extra_y = torch.squeeze(torch.from_numpy(extra_y_np).float().to(device)).to(device)
    augmented_x = torch.cat((original_x_train, extra_x), 0).to(device)
    augmented_y = torch.cat((original_y_train, extra_y), 0).to(device)
    return extra_x, extra_y, augmented_x, augmented_y


# The calls below must stay in this order: each one consumes numpy's seeded RNG stream,
# so reordering them would change every sampled point.

# sin-style target
add_set, add_set_y, original_x_train_add, original_y_train_add = _augment_training_set(
    -5, 5, 40, lambda x: 2 * np.sin(4 * x))

# gap target
add_set_gap, add_set_y_gap, original_x_train_add_gap, original_y_train_add_gap = _augment_training_set(
    -10, 10, 40, lambda x: np.sin(x) + 0.1 * x ** 2)

##### g2 target
add_set_g2, add_set_y_g2, original_x_train_add_g2, original_y_train_add_g2 = _augment_training_set(
    -10, 10, 30, lambda x: np.sin(x) + 0.1 * x)

##### g3 target (the augmented set the training loop samples measurement points from)
add_set_g3, add_set_y_g3, original_x_train_add_g3, original_y_train_add_g3 = _augment_training_set(
    -1, 1, 40,
    lambda x: np.sin(3 * 3.14 * x) + 0.3 * np.cos(9 * 3.14 * x) + 0.5 * np.sin(7 * 3.14 * x))

###### isotropy target
add_set_iso, add_set_y_iso, original_x_train_add_iso, original_y_train_add_iso = _augment_training_set(
    -0.5, 1.2, 40,
    lambda x: x + 0.3 * np.sin(2 * 3.14 * x) + 0.3 * np.sin(4 * 3.14 * x))


# Mean-squared-error criterion shared by the training and test loss computations.
loss = nn.MSELoss(reduction='mean')

print("============================================================================================")
################################## define model ##################################


def _make_bnn():
    """Build a BNN with the configured architecture and move it to `device`."""
    network = BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True)
    return network.to(device)


# Two identically configured networks, each with its own Adam optimizer at lr_bnn.
# Construction order is kept (bnn first) so the seeded initialisation is unchanged.
bnn = _make_bnn()
bnn_optimizer = torch.optim.Adam([{'params': bnn.parameters(), 'lr': lr_bnn}])

opt_bnn = _make_bnn()
# opt_bnn = BNNISO(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
# opt_bnn_optimizer = torch.optim.RMSprop([{'params': opt_bnn.parameters(), 'lr': lr_optbnn}])
opt_bnn_optimizer = torch.optim.Adam([{'params': opt_bnn.parameters(), 'lr': lr_bnn}])

print("============================================================================================")
################################## prior pre-training ##################################

# pre-train GP prior
# NOTE(review): this GP is built on and fitted to the *test* inputs/targets
# (original_x_test / original_y_test), not the training data -- confirm this is intended.

likelihood = gpytorch.likelihoods.GaussianLikelihood()
prior = ExactGPModel(original_x_test, original_y_test, likelihood, input_dim).to(device)

# Switch model and likelihood to training mode for hyper-parameter optimisation.
prior.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(prior.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior)

# Minimise the negative marginal log likelihood for n_step_prior_pretraining Adam steps.
for i in range(n_step_prior_pretraining):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = prior(original_x_test)
    # Calc loss and backprop gradients
    loss_gp = -mll(output, original_y_test)
    loss_gp.backward()
    # print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
    #     i + 1, n_step_prior_pretraining, loss_gp.item(),
    #     prior.covar_module.base_kernel.lengthscale.item(),
    #     prior.likelihood.noise.item()
    # ))
    # Progress line: 1-based step, total steps, current loss and likelihood noise.
    print('Iter %d/%d - Loss: %.3f     noise: %.3f' % (
        i + 1, n_step_prior_pretraining, loss_gp.item(),
        # prior.covar_module.base_kernel.lengthscale.item(),
        prior.likelihood.noise.item()
    ))

    optimizer.step()

# Back to evaluation mode once the hyper-parameters are fitted.
prior.eval()
likelihood.eval()
################################## bnn prior pre-train ##################################

# prior_bnn = BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
# prior_bnn_optimizer = torch.optim.RMSprop([{'params': prior_bnn.parameters(), 'lr': lr_optbnn}])
#
# # loss
# for i in range(10000):
#     pred_prior = prior_bnn.forward(original_x_train)
#     # L1regu = regu_coeff / 20 * L1regu
#     # pred_y = bnn.forward(original_x_train)
#     # y_predicted, distance_prior = bnn.forward_w(train_x)
#
#     pred_prior = pred_prior.squeeze().flatten()
#     # print(pred_y)
#     # print(train_y_tmp)
#     train_loss_prior = loss(pred_prior, original_y_train)
# # optimisation
#     prior_bnn_optimizer.zero_grad()
#     train_loss_prior.backward()
#     prior_bnn_optimizer.step()
# # print
#     print("epoch : {} \t\t training loss \t\t : {}".format(10000, train_loss_prior),
#           )
#
#
#
# bnnprior_samples = prior_bnn.sample_functions(original_x_test, num_sample=128).float().to(device)
# bnnprior_samples = bnnprior_samples.squeeze()
# nns_mean = torch.mean(bnnprior_samples, 0)
# nns_std = torch.std(bnnprior_samples, 0)
#
# with torch.no_grad():
#
#     plt.clf()
#     figure = plt.figure(figsize=(8, 5.5), facecolor='white')
#     init_plotting()
#     plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
#     plt.plot(original_x_test.squeeze().cpu(), nns_mean.cpu().numpy(), 'slateblue', label='BNN posterior', linewidth=2.5)
#
#     #################
#     for i in range(5):
#         plt.fill_between(original_x_test.squeeze().cpu(), nns_mean.cpu().numpy() - i * 1 * nns_std.cpu().numpy(),
#                          nns_mean.cpu().numpy() - (i + 1) * 1 * nns_std.cpu().numpy(), linewidth=0.0,
#                          alpha=1.0 - i * 0.15, color='orchid')  # wheat
#         plt.fill_between(original_x_test.squeeze().cpu(), nns_mean.cpu().numpy() + i * 1 * nns_std.cpu().numpy(),
#                          nns_mean.cpu().numpy() + (i + 1) * 1 * nns_std.cpu().numpy(), linewidth=0.0,
#                          alpha=1.0 - i * 0.15, color='orchid')
#     plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')
#
#
#
#     plt.legend(loc='upper center', ncol=3, fontsize='small')
#     plt.title('Posterior BNN')
#     plt.tight_layout()
#     plt.ylim([dataset.y_min, dataset.y_max])
#     plt.tight_layout()
#     plt.savefig(figures_folder + '/Posterior BNN.pdf')


################################## gp2 pre-training ##################################

# pre-train GP prior
# Second GP, fitted to the training data; a fresh likelihood is created for it and the
# names `likelihood`, `optimizer` and `mll` are rebound to this model's objects.

likelihood = gpytorch.likelihoods.GaussianLikelihood()
prior2 = ExactGPModel(original_x_train, original_y_train, likelihood, input_dim).to(device)

prior2.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(prior2.parameters(), lr=0.01)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior2)

# Number of Adam steps for this GP; named so the progress line reports the real total.
n_step_gp2_pretraining = 1001   # 10001

for i in range(n_step_gp2_pretraining):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = prior2(original_x_train)
    # Calc loss and backprop gradients
    loss_gp = -mll(output, original_y_train)
    loss_gp.backward()
    # Report the noise of *this* GP's likelihood. The previous version printed
    # prior.likelihood.noise (the first GP, no longer being trained) and used
    # n_step_prior_pretraining as the total, so the log showed a constant noise
    # value and counts such as "Iter 1001/100".
    print('Iter %d/%d - Loss: %.3f     noise: %.3f' % (
        i + 1, n_step_gp2_pretraining, loss_gp.item(),
        likelihood.noise.item()
    ))

    optimizer.step()

prior2.eval()
likelihood.eval()

# Plot the fitted prior2 GP on the test grid: predictive mean, five shaded bands of one
# latent standard deviation each, the true function and the training observations.
with torch.no_grad():
    plt.clf()
    figure = plt.figure(figsize=(8, 5.5), facecolor='white')
    init_plotting()
    p_gp = likelihood(prior2(original_x_test))
    f_var = prior2(original_x_test).variance
    f_std = f_var ** 0.5

    # Values reused by every plotting call below.
    x_grid = original_x_test.squeeze().cpu()
    gp_mean = p_gp.mean.cpu().numpy()
    latent_std = f_std.cpu().numpy()

    plt.plot(x_grid, original_y_test.cpu(), 'g', label="True function")
    plt.plot(x_grid, gp_mean, 'slateblue', label='GP posterior', linewidth=2.5)
    #################
    # Band i spans i..i+1 standard deviations on each side and fades with distance.
    for i in range(5):
        shade_kwargs = dict(linewidth=0.0, alpha=1.0 - i * 0.15, color='orchid')
        plt.fill_between(x_grid, gp_mean - i * 1 * latent_std,
                         gp_mean - (i + 1) * 1 * latent_std, **shade_kwargs)  # wheat
        plt.fill_between(x_grid, gp_mean + i * 1 * latent_std,
                         gp_mean + (i + 1) * 1 * latent_std, **shade_kwargs)
    plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')

    ######g3######
    # Dashed guides and labels at fixed x positions chosen for the g3 dataset.
    for x_boundary in (-0.75, -0.25, 0.25, 0.75):
        plt.axvline(x_boundary, linestyle='--', color='darkorchid')
    plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
    plt.text(0.25, 1, 'Training data range', fontsize='small')

    plt.legend(loc='upper center', ncol=3, fontsize='small')
    plt.title('GP posteior (Matern kernel)')
    plt.tight_layout()
    plt.ylim([dataset.y_min, dataset.y_max])
    plt.tight_layout()
    plt.savefig(figures_folder + '/gp2_RBF.pdf')

################################## start training ##################################


print("============================================================================================")
################################## start training ##################################

# track total training time
# datetime.now() returns naive local time, so the "(GMT)" label below is only accurate
# when the machine clock is set to GMT.
start_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)

# logging file
# Opened in "w+" mode (truncate/create) and intentionally left open: the training loop
# further down appends one CSV row per logging interval to this handle.
log_f = open(log_f_name, "w+")
log_f.write('epoch,training_loss\n')
# Step decay for bnn_optimizer: lr is multiplied by 0.9 every 5000 scheduler steps.
# NOTE(review): no scheduler.step() call is active in the code visible in this file.
scheduler = torch.optim.lr_scheduler.StepLR(bnn_optimizer, 5000, gamma=0.9, last_epoch=-1)

##########################################    check opt_bnn  #######

def sample_measurement_set(X, num_data, n=40):
    """Sample a measurement set: ``n`` distinct rows among the first ``num_data`` rows of ``X``.

    Args:
        X: 2-D tensor; rows are selected with ``X[idx, :]``.
        num_data: number of leading rows of ``X`` eligible for sampling. Must not
            exceed the number of rows in ``X``.
        n: size of the measurement set. Defaults to 40, the size that used to be
            hard-coded, so existing two-argument calls behave exactly as before.

    Returns:
        Tensor holding the selected rows, in random order. It has ``n`` rows, or
        ``num_data`` rows when ``num_data < n``.
    """
    # A random permutation of the eligible row indices; its prefix is a sample without
    # replacement. A plain int slice bound replaces the former one-element float tensor.
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n)]
    return X[idx, :]

def lipf(X, Y):
    """Return the element-wise absolute difference ``|X - Y|``."""
    difference = X - Y
    return difference.abs()

def sample_measurement_set2(X, num_data, n=60):
    """Sample a measurement set: ``n`` distinct rows among the first ``num_data`` rows of ``X``.

    Args:
        X: 2-D tensor; rows are selected with ``X[idx, :]``.
        num_data: number of leading rows of ``X`` eligible for sampling. Must not
            exceed the number of rows in ``X``.
        n: size of the measurement set. Defaults to 60, the size that used to be
            hard-coded, so existing two-argument calls behave exactly as before.

    Returns:
        Tensor holding the selected rows, in random order. It has ``n`` rows, or
        ``num_data`` rows when ``num_data < n``.
    """
    # Prefix of a random permutation = sample without replacement. A plain int slice
    # bound replaces the former one-element float tensor.
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n)]
    return X[idx, :]


def sample_measurement_set3(X, num_data, n=20):
    """Sample a measurement set: ``n`` distinct rows among the first ``num_data`` rows of ``X``.

    Args:
        X: 2-D tensor; rows are selected with ``X[idx, :]``.
        num_data: number of leading rows of ``X`` eligible for sampling. Must not
            exceed the number of rows in ``X``.
        n: size of the measurement set. Defaults to 20, the size that used to be
            hard-coded, so existing two-argument calls behave exactly as before.

    Returns:
        Tensor holding the selected rows, in random order. It has ``n`` rows, or
        ``num_data`` rows when ``num_data < n``.
    """
    # Prefix of a random permutation = sample without replacement. A plain int slice
    # bound replaces the former one-element float tensor.
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n)]
    return X[idx, :]

########## Linear flow ###############
class LinearFlow(nn.Module):
    """Affine flow ``x = z * weight + bias`` with a single shared scalar scale.

    Attributes:
        dim: dimensionality given at construction (sets the bias width).
        weight: learnable scale of shape (1,), initialised from a standard normal.
        bias: learnable shift of shape (1, dim), initialised from a standard normal.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.weight = nn.Parameter(torch.randn(1), requires_grad=True)     # torch.randn(1, dim)
        self.bias = nn.Parameter(torch.randn(1, dim),  requires_grad=True)   # torch.randn(1)

    def forward(self, z):
        """Map ``z`` to ``x = z * weight + bias`` and return ``x`` only.

        The log-determinant that used to be computed here was never returned or
        stored, so that dead computation has been removed.
        """
        return z * self.weight + self.bias

    def inverse(self, x):
        """Invert the flow.

        Returns:
            (z, log_det_inverse): ``z = x / weight - bias / weight`` and the scalar
            ``log(1 / |weight| + 1e-7)``.
            NOTE(review): the log term is not multiplied by ``dim`` -- confirm that
            callers expect the per-dimension value.
        """
        z = x * 1/self.weight - self.bias/self.weight
        log_det_inverse = torch.sum(torch.log(1/self.weight.abs() + 1e-7))

        return z, log_det_inverse

######### RealNVP ######

class RealNVP(nn.Module):
    """Stack of masked affine coupling layers.

    Args:
        nets: zero-argument factory returning a scale network ``s``.
        nett: zero-argument factory returning a translation network ``t``.
        mask: float tensor with one row per coupling layer; entries equal to 1 are
            passed through unchanged by that layer, the others are transformed.
    """

    def __init__(self, nets, nett, mask):
        super(RealNVP, self).__init__()

        # Stored as a frozen parameter so it follows the module across devices.
        self.mask = nn.Parameter(mask, requires_grad=False)
        # One (t, s) pair per mask row. This used to read an undefined global `masks`,
        # which raised NameError on construction; the `mask` argument is what is meant.
        n_layers = len(mask)
        self.t = torch.nn.ModuleList([nett() for _ in range(n_layers)])
        self.s = torch.nn.ModuleList([nets() for _ in range(n_layers)])

    def forward(self, z):    ######### x = g(z)
        """Push ``z`` through the layers in order and return ``x`` only.

        The log-determinant accumulated here before was never returned, so it is no
        longer computed.
        """
        x = z
        for i in range(len(self.t)):
            x_ = x * self.mask[i]
            s = self.s[i](x_) * (1 - self.mask[i])
            t = self.t[i](x_) * (1 - self.mask[i])
            x = x_ + (1 - self.mask[i]) * (x * torch.exp(s) + t)

        return x

    def inverse(self, x):     ########### z = f(x)
        """Undo the layers in reverse order.

        Returns:
            (z, log_det_J): the recovered latent and, per row of ``x``, the summed
            negative log-scales accumulated over all layers.
        """
        log_det_J, z = x.new_zeros(x.shape[0]), x
        for i in reversed(range(len(self.t))):
            z_ = self.mask[i] * z
            s = self.s[i](z_) * (1 - self.mask[i])
            t = self.t[i](z_) * (1 - self.mask[i])
            z = z_ + (1 - self.mask[i]) * (z - t) * torch.exp(-s)
            log_det_J -= s.sum(dim=1)

        return z, log_det_J


###### linear_nn flow ###############
hidden_dimsflow = [100]
class LinearFlow_nn(nn.Module):
    """Affine flow whose shift is produced by a network: ``x = z * weight + net(x_in)``.

    Args:
        net: module mapping the conditioning input ``x_in`` to the shift; its output
            is squeezed before use.
    """

    def __init__(self, net):
        super().__init__()
        # self.dim = dim
        # self.weight = nn.Parameter(torch.randn(1), requires_grad=True)     # torch.randn(1, dim)
        # Learnable scalar scale, initialised to 1.
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias = net   # torch.randn(1)

    def forward(self, z, x_in):
        """Return ``(x, nnbias)`` with ``nnbias = net(x_in).squeeze()`` and ``x = z * weight + nnbias``.

        The log-determinant that used to be computed here was never returned, so that
        dead computation has been removed.
        """
        nnbias = self.bias(x_in).squeeze()
        x = z * self.weight + nnbias

        return x, nnbias

    def inverse(self, x, x_in):
        """Invert the flow for the same conditioning input.

        Prints the shapes of the shift and of ``x`` as a broadcasting sanity check.

        Returns:
            (z, log_det_inverse): ``z = x / weight - nnbias / weight`` and the scalar
            ``log(1 / |weight| + 1e-7)``.
        """
        nnbias = self.bias(x_in).squeeze()
        print('nnbias.shape: ', nnbias.shape)
        print('x.shape: ', x.shape)
        z = x * 1/self.weight - nnbias/self.weight
        log_det_inverse = torch.sum(torch.log(1/self.weight.abs() + 1e-7))

        return z, log_det_inverse

###################### plannar flow ###################
hidden_dimsflow = [100]
class PlanarFlow_nn(nn.Module):
    """Flow ``x = z + w(x_in) * z + bias`` whose multiplicative term comes from a network.

    Args:
        net: module mapping the conditioning input ``x_in`` to the multiplier; its
            output is squeezed before use.
    """

    def __init__(self, net):
        super().__init__()
        # Learnable scalar shift, drawn from a standard normal.
        self.bias = nn.Parameter(torch.randn(1), requires_grad=True)
        # Despite the name, `weight` is the network producing the multiplier.
        self.weight = net

    def _multiplier(self, x_in):
        """Evaluate the multiplier network on ``x_in`` and squeeze the result."""
        return self.weight(x_in).squeeze()

    def forward(self, z, x_in):
        """Return ``(x, nnweight)`` where ``x = z + nnweight * z + bias``."""
        nnweight = self._multiplier(x_in)
        x = z + nnweight * z + self.bias
        return x, nnweight

    def inverse(self, x, x_in):
        """Invert the flow for the same conditioning input.

        Prints the shapes of the multiplier and of ``x`` as a broadcasting sanity check.

        Returns:
            (z, log_det_inverse): ``z = (x - bias) / (1 + nnweight)`` and the sum over
            elements of ``log(1 / |1 + nnweight| + 1e-7)``.
        """
        nnweight = self._multiplier(x_in)
        print('nnweight.shape: ', nnweight.shape)
        print('x.shape: ', x.shape)
        gain = 1 + nnweight
        z = (x - self.bias) / gain
        log_det_inverse = torch.sum(torch.log(1 / gain.abs() + 1e-7))
        return z, log_det_inverse

###
# track total training time
# Restart the clock right before the loop so the elapsed times printed during training
# do not include the definitions and setup above.
start_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)

# The CSV log (`log_f`, header already written) and the StepLR `scheduler` are created
# once in the "start training" section earlier in this file and are reused as-is.
# Re-opening log_f_name here only truncated the same file again and dropped the first
# handle without closing it, and a second StepLR on the same optimizer was redundant.

################# S1: FVI #################################
# Per-iteration histories; the loss-curve plots after the loop read them.
train_loss_all = []
prior_distance_all = []
wdist_all = []
likelihood_loss_all = []

for epoch in range(20001):

    # ----- functional prior-matching term -----
    # Draw measurement points from the g3-augmented inputs and compare function samples
    # of the GP (prior2, detached) with function samples of opt_bnn at those points.
    measurement_set = sample_measurement_set(X=original_x_train_add_g3, num_data=60) # sin num_data=60, gap num_data=100
    gpss = prior_sample_functions(measurement_set, prior2, num_sample=128).detach().float().to(device)
    # gpss = prior_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)  # original_x_test
    gpss = gpss.squeeze()
    nnet_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device) # original_x_test
    nnet_samples = nnet_samples.squeeze()

    # Mean absolute difference between the two sets of samples.
    functional_wdist = lipf(gpss, nnet_samples)
    functional_wdist = torch.mean(functional_wdist)

    ##################### variance constraint #########
    # Penalise the gap between the variance reported by prior2 and the empirical
    # variance of 200 opt_bnn function samples on a second measurement set.
    measurement_set2 = sample_measurement_set2(X=original_x_train_add_g3, num_data=60)

    with torch.no_grad():
        gp_var = prior2(measurement_set2).variance

    opt_samples = opt_bnn.sample_functions(measurement_set2, num_sample=200).float().to(device)
    opt_samples = opt_samples.squeeze()
    opt_var = torch.std(opt_samples, 0) ** 2
    dif_var = (gp_var - opt_var).abs()
    dif_var = torch.sum(dif_var)
    ####

    ### calculate bnn likelihood loss ###
    pred_y = opt_bnn.forward(original_x_train)
    pred_y = pred_y.squeeze().flatten()

    # The data-fit term is computed once and reused (it used to be evaluated twice).
    likelihood_loss = loss(pred_y, original_y_train)
    train_loss = likelihood_loss + 1 * functional_wdist + 1 * dif_var
    print('likelihood_loss: ', likelihood_loss)
    # print('distance_prior: ', bnn_2wassdist)
    print('w_dist: ', functional_wdist)

    # optimisation: only opt_bnn is updated in this loop.
    opt_bnn_optimizer.zero_grad()
    train_loss.backward()
    # bnn_optimizer.step()
    # scheduler.step()
    opt_bnn_optimizer.step()

    train_loss_all.append(train_loss.item())
    likelihood_loss_all.append(likelihood_loss.item())
    wdist_all.append(functional_wdist.item())

    # print
    print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss),
          datetime.now().replace(microsecond=0) - start_time)

    # log in logging file
    if epoch % log_model_freq == 0:
        print("--------------------------------------------------------------------------------------------")
        # Write the plain number: formatting the tensor itself put its repr
        # ("tensor(..., grad_fn=...)") into the CSV column.
        log_f.write('{},{}\n'.format(epoch, train_loss.item()))
        log_f.flush()
        print("log saved")
        print("--------------------------------------------------------------------------------------------")

    # save model weights
    if epoch % save_model_freq == 0:
        print("--------------------------------------------------------------------------------------------")
        print("saving model at : " + checkpoint_path)
        # Checkpoint the network that is actually being trained. This used to save
        # `bnn`, which this loop never updates.
        opt_bnn.save(checkpoint_path)
        print("model saved")
        print("--------------------------------------------------------------------------------------------")

    if epoch % test_interval == 0:

        opt_bnn.eval()

        samples_pred_test_y = opt_bnn.forward_eval(original_x_test, num_sample).squeeze().detach()

        # Predictive mean and spread across the sampled predictions.
        mean_y_pred = torch.mean(samples_pred_test_y, 0).cpu().numpy()
        std_y_pred = torch.std(samples_pred_test_y, 0).cpu().numpy()
        # NOTE(review): presumably samples_pred_test_y is (num_sample, n_test) and is
        # broadcast against original_y_test here -- confirm the shapes.
        test_loss = loss(samples_pred_test_y, original_y_test)

        opt_bnn.train()

        plt.clf()
        figure = plt.figure(figsize=(8, 5.5), facecolor='white')
        init_plotting()

        plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
        plt.plot(original_x_test.squeeze().cpu(), mean_y_pred, 'slateblue', label='Mean function', linewidth=2.5)

        # Five bands of one standard deviation each on both sides, fading outwards.
        for i in range(5):
            plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred - i * 1 * std_y_pred,
                             mean_y_pred - (i + 1) * 1 * std_y_pred, linewidth=0.0,
                             alpha=1.0 - i * 0.15, color='orchid')
            plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred + i * 1 * std_y_pred,
                             mean_y_pred + (i + 1) * 1 * std_y_pred, linewidth=0.0,
                             alpha=1.0 - i * 0.15, color='orchid')
        plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')

        ######g3######
        # Guide lines and labels at fixed x positions chosen for the g3 dataset; other
        # datasets need their own positions.
        plt.axvline(-0.75, linestyle='--', color='darkorchid')
        plt.axvline(-0.25, linestyle='--', color='darkorchid')
        plt.axvline(0.25, linestyle='--', color='darkorchid')
        plt.axvline(0.75, linestyle='--', color='darkorchid')
        plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
        plt.text(0.25, 1, 'Training data range', fontsize='small')

        plt.legend(loc='upper center', ncol=3, fontsize='small')
        plt.title('FVI posterior')
        plt.tight_layout()
        plt.ylim([dataset.y_min, dataset.y_max])
        plt.tight_layout()

        print("--------------------------------------------------------------------------------------------")
        plt.savefig(figures_folder + '/plot_epoch{}.pdf'.format(epoch))
        print("figure saved")
        print("epoch : {} \t\t test loss \t\t : {}".format(epoch, test_loss))
        print("--------------------------------------------------------------------------------------------")
#######

###########################
def _plot_history(history, y_label, file_name):
    """Plot every 200th entry of a per-iteration history and save it under figures_folder.

    Args:
        history: sequence with one value per training iteration (20001 entries expected).
        y_label: label for the y axis.
        file_name: file name, including the leading '/', appended to figures_folder.

    Returns:
        (indices, series): the sampled iteration indices and the history as an array.
    """
    plt.figure()
    sampled_at = np.arange(20001)[::200]
    series = np.array(history)
    plt.plot(sampled_at, series[sampled_at], '-ko', ms=3)
    plt.ylabel(y_label)
    plt.tight_layout()
    plt.xlabel('Iteration')
    plt.tight_layout()
    plt.savefig(figures_folder + file_name)
    return sampled_at, series


# Each history list is replaced by its array form, as before.
indices, train_loss_all = _plot_history(train_loss_all, r'Training loss', '/train_loss.pdf')

indices, likelihood_loss_all = _plot_history(likelihood_loss_all, r'Likelihood loss', '/likelihood_loss.pdf')

# plt.ylabel(r'$W_1(p_{gp}, p_{prior bnn})$')
indices, wdist_all = _plot_history(wdist_all, r'$W_1(p_{f}, p_{g})$', '/wdist.pdf')



# The CSV log handle was never closed; close it now that training is over so the last
# buffered rows reach the disk and the descriptor is released.
log_f.close()

# print total training time
print("============================================================================================")
end_time = datetime.now().replace(microsecond=0)
# Both timestamps come from datetime.now(), i.e. naive local time.
print("Started training at (GMT) : ", start_time)
print("Finished training at (GMT) : ", end_time)
print("Total training time  : ", end_time - start_time)
print("============================================================================================")
