import os
from datetime import datetime

import torch
import torch.nn as nn
import numpy as np
import gpytorch

from data import x3_gap_toy, sin_toy
from toy import polynomial_toy, polynomial_gap_toy, g2_toy, g3_toy, isotropy_toy
# from bnn import BNN, OPTBNN
# from bnnp import BNNISO

from bnn_vimc import BNN, BNNF, BNNMC
# from bnnp import BNNMC
from gp_prior import ExactGPModel, prior_sample_functions
from wasserstein_prior_match2 import MapperWasserstein, WassersteinDistance
import copy

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import pandas as pd

from utils.utils import default_plotting_new as init_plotting


################################### Hyper-parameters ###################################

# Optimisation
lr_bnn = 0.01
lr_optbnn = 1e-2
# 0.0012
prior_coeff = 10  # gap:2, sin:10  5
f_coeff = 1   # 10

# Run identification, length and bookkeeping frequencies
bnn_name_string = 'VIMC'            # UAI_toy
random_seed = 123
max_epoch_num = 100
epochs = 10001
# bnn_prior_step = 10000
save_model_freq = 5000
log_model_freq = 3000
test_interval = 500 # 2000

# Sample counts / pre-training length
num_functions = 10
num_sample = 1000
n_step_prior_pretraining = 100

# Seed the global torch and NumPy RNGs with the same value.
torch.manual_seed(random_seed)
np.random.seed(random_seed)


################################### Network Architecture ###################################

n_units = 100  # gap:100, sin:500
n_hidden = 2
hidden_dims = [n_units for _ in range(n_hidden)]  # one width entry per hidden layer
activation_fn = 'tanh'


print("============================================================================================")
################################## set device ##################################

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")

print("============================================================================================")
###################### logging ######################

#### log files for multiple runs are NOT overwritten

log_dir = "./logs"
log_dir = log_dir + '/' + bnn_name_string + '/' + 'fvimc'
# exist_ok avoids the check-then-create race and also creates ./logs itself.
os.makedirs(log_dir, exist_ok=True)

#### get number of log files in log directory
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)

#### create new log file for each run
log_f_name = log_dir + '/' + bnn_name_string + 'fvimc' + "_" + str(run_num) + ".csv"

print("current logging run number for " + bnn_name_string + " : ", run_num)
print("logging at : " + log_f_name)


print("============================================================================================")
################### checkpointing ###################

run_num_pretrained = 0  #### change this to prevent overwriting weights in same env_name folder

directory = "./pretrained"
directory = directory + '/' + bnn_name_string + '/' + 'fvimc'
os.makedirs(directory, exist_ok=True)

# Join with '/' so the checkpoint file is placed *inside* `directory`.
checkpoint_path = directory + '/' + "{}_{}_{}.pth".format(bnn_name_string, random_seed, run_num_pretrained)
print("save checkpoint path : " + checkpoint_path)

################### savefigures ###################

results_folder = "./results"
os.makedirs(results_folder, exist_ok=True)

figures_folder = results_folder + '/' + bnn_name_string + '/' + 'ttt_full'    # fvimc_loopt_multiflow
os.makedirs(figures_folder, exist_ok=True)

loss_folder = results_folder + '/' + bnn_name_string + '/' + 'losses' + '/' + 'ttt_full'     # fvimc_loopt_multiflow
os.makedirs(loss_folder, exist_ok=True)


print("============================================================================================")
############# print all hyperparameters #############

print("learning rate: ", lr_bnn)
print("coefficient of prior regularization: ", prior_coeff)
print("random seed: ", random_seed)
print("max number of epoches: ", max_epoch_num)


print("============================================================================================")
############################## load and normalize data ##############################

dataset = dict(x3=x3_gap_toy, sin=sin_toy, g2=g2_toy, g3=g3_toy, gap=polynomial_gap_toy, iso=isotropy_toy)['g3']()
original_x_train, original_y_train = dataset.train_samples()
mean_x, std_x = np.mean(original_x_train), np.std(original_x_train)
mean_y, std_y = np.mean(original_y_train), np.std(original_y_train)
train_x = (original_x_train - mean_x) / std_x
train_y = (original_y_train - mean_y) / std_y
original_x_test, original_y_test = dataset.test_samples()
test_x = (original_x_test - mean_x) / std_x
test_y = (original_y_test - mean_y) / std_y

y_logstd = np.log(dataset.y_std / std_y)

lower_ap = (dataset.x_min - mean_x) / std_x
upper_ap = (dataset.x_max - mean_x) / std_x

training_num, input_dim = np.shape(train_x)
input_dim = 1 #np.array([input_dim])
output_dim = 1 #np.array([1])
is_continuous = True

original_x_train = torch.from_numpy(original_x_train).float().to(device)
original_y_train = torch.from_numpy(original_y_train).float().to(device)
original_x_test = torch.from_numpy(original_x_test).float().to(device)
original_y_test = torch.from_numpy(original_y_test).float().to(device)


def _make_extra_set(low, high, num_points, target_fn):
    """Draw extra (x, y) points from the global NumPy RNG.

    x is drawn uniformly from [low, high) with shape (num_points, 1); y is
    ``target_fn(x)`` evaluated in float64 NumPy.  Both are converted to float32
    tensors on ``device``, and y is squeezed (shape (num_points,) for num_points > 1).
    """
    x_np = np.random.uniform(low, high, (num_points, 1))
    y_np = target_fn(x_np)
    x = torch.from_numpy(x_np).float().to(device)
    y = torch.squeeze(torch.from_numpy(y_np).float().to(device))
    return x, y


# Extra points for each toy problem, appended to the training set.  The draws are
# made in this fixed order so the global NumPy RNG stream yields the same points as
# the original inline code did.
add_set, add_set_y = _make_extra_set(-5, 5, 40, lambda x: 2 * np.sin(4 * x))
original_x_train_add = torch.cat((original_x_train, add_set), 0)
original_y_train_add = torch.cat((original_y_train, add_set_y), 0)

add_set_gap, add_set_y_gap = _make_extra_set(-10, 10, 40, lambda x: np.sin(x) + 0.1 * x ** 2)
original_x_train_add_gap = torch.cat((original_x_train, add_set_gap), 0)
original_y_train_add_gap = torch.cat((original_y_train, add_set_y_gap), 0)

#####
add_set_g2, add_set_y_g2 = _make_extra_set(-10, 10, 30, lambda x: np.sin(x) + 0.1 * x)
original_x_train_add_g2 = torch.cat((original_x_train, add_set_g2), 0)
original_y_train_add_g2 = torch.cat((original_y_train, add_set_y_g2), 0)

#####
# add_set_g3 = np.linspace(-1, 1, num=40, dtype=np.float32)
# add_set_g3 = np.expand_dims(add_set_g3, axis=-1)
# NOTE: 3.14 (not np.pi) is kept as written in the original target expression.
add_set_g3, add_set_y_g3 = _make_extra_set(
    -1, 1, 40,
    lambda x: np.sin(3 * 3.14 * x) + 0.3 * np.cos(9 * 3.14 * x) + 0.5 * np.sin(7 * 3.14 * x),
)
original_x_train_add_g3 = torch.cat((original_x_train, add_set_g3), 0)
original_y_train_add_g3 = torch.cat((original_y_train, add_set_y_g3), 0)

######
add_set_iso, add_set_y_iso = _make_extra_set(
    -0.5, 1.2, 40,
    lambda x: x + 0.3 * np.sin(2 * 3.14 * x) + 0.3 * np.sin(4 * 3.14 * x),
)
original_x_train_add_iso = torch.cat((original_x_train, add_set_iso), 0)
original_y_train_add_iso = torch.cat((original_y_train, add_set_y_iso), 0)


loss = nn.MSELoss(reduction='mean')

print("============================================================================================")
################################## define model ##################################


# Two identically-configured BNNs, `bnn` and `opt_bnn`.  They are constructed in the
# same order as before (bnn first) in case construction draws from a global RNG.
bnn, opt_bnn = [
    BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
    for _ in range(2)
]
# opt_bnn = BNNISO(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)

# One Adam optimiser per network.
bnn_optimizer = torch.optim.Adam([{'params': bnn.parameters(), 'lr': lr_bnn}])
# opt_bnn_optimizer = torch.optim.RMSprop([{'params': opt_bnn.parameters(), 'lr': lr_optbnn}])
# NOTE(review): opt_bnn uses lr_bnn rather than lr_optbnn (the two are equal here).
opt_bnn_optimizer = torch.optim.Adam([{'params': opt_bnn.parameters(), 'lr': lr_bnn}])

print("============================================================================================")
################################## prior pre-training ##################################

# Pre-train the GP prior: fit ExactGPModel (with a fresh Gaussian likelihood) to the
# test inputs/targets by maximising the exact marginal log-likelihood with Adam.
likelihood = gpytorch.likelihoods.GaussianLikelihood()
prior = ExactGPModel(original_x_test, original_y_test, likelihood, input_dim).to(device)

# Adam over the model's parameters (these include the GaussianLikelihood's parameters),
# and the exact marginal log-likelihood used as the (negated) training objective.
optimizer = torch.optim.Adam(prior.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior)

prior.train()
likelihood.train()

for i in range(n_step_prior_pretraining):
    optimizer.zero_grad()
    output = prior(original_x_test)
    loss_gp = -mll(output, original_y_test)
    loss_gp.backward()
    noise_level = prior.likelihood.noise.item()
    print('Iter %d/%d - Loss: %.3f     noise: %.3f' % (i + 1, n_step_prior_pretraining, loss_gp.item(), noise_level))
    optimizer.step()

prior.eval()
likelihood.eval()
################################## bnn prior pre-train ##################################

# prior_bnn = BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
# prior_bnn_optimizer = torch.optim.RMSprop([{'params': prior_bnn.parameters(), 'lr': lr_optbnn}])
#
# # loss
# for i in range(10000):
#     pred_prior = prior_bnn.forward(original_x_train)
#     # L1regu = regu_coeff / 20 * L1regu
#     # pred_y = bnn.forward(original_x_train)
#     # y_predicted, distance_prior = bnn.forward_w(train_x)
#
#     pred_prior = pred_prior.squeeze().flatten()
#     # print(pred_y)
#     # print(train_y_tmp)
#     train_loss_prior = loss(pred_prior, original_y_train)
# # optimisation
#     prior_bnn_optimizer.zero_grad()
#     train_loss_prior.backward()
#     prior_bnn_optimizer.step()
# # print
#     print("epoch : {} \t\t training loss \t\t : {}".format(10000, train_loss_prior),
#           )
#
#
#
# bnnprior_samples = prior_bnn.sample_functions(original_x_test, num_sample=128).float().to(device)
# bnnprior_samples = bnnprior_samples.squeeze()
# nns_mean = torch.mean(bnnprior_samples, 0)
# nns_std = torch.std(bnnprior_samples, 0)
#
# with torch.no_grad():
#
#     plt.clf()
#     figure = plt.figure(figsize=(8, 5.5), facecolor='white')
#     init_plotting()
#     plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
#     plt.plot(original_x_test.squeeze().cpu(), nns_mean.cpu().numpy(), 'slateblue', label='BNN posterior', linewidth=2.5)
#
#     #################
#     for i in range(5):
#         plt.fill_between(original_x_test.squeeze().cpu(), nns_mean.cpu().numpy() - i * 1 * nns_std.cpu().numpy(),
#                          nns_mean.cpu().numpy() - (i + 1) * 1 * nns_std.cpu().numpy(), linewidth=0.0,
#                          alpha=1.0 - i * 0.15, color='orchid')  # wheat
#         plt.fill_between(original_x_test.squeeze().cpu(), nns_mean.cpu().numpy() + i * 1 * nns_std.cpu().numpy(),
#                          nns_mean.cpu().numpy() + (i + 1) * 1 * nns_std.cpu().numpy(), linewidth=0.0,
#                          alpha=1.0 - i * 0.15, color='orchid')
#     plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')
#
#
#
#     plt.legend(loc='upper center', ncol=3, fontsize='small')
#     plt.title('Posterior BNN')
#     plt.tight_layout()
#     plt.ylim([dataset.y_min, dataset.y_max])
#     plt.tight_layout()
#     plt.savefig(figures_folder + '/Posterior BNN.pdf')


################################## gp2 pre-training ##################################

# Fit a second GP (with its own fresh Gaussian likelihood) to the *training* data by
# maximising the exact marginal log-likelihood with Adam, then plot its posterior.
n_step_gp2_pretraining = 1001   # 10001

likelihood = gpytorch.likelihoods.GaussianLikelihood()
prior2 = ExactGPModel(original_x_train, original_y_train, likelihood, input_dim).to(device)

prior2.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(prior2.parameters(), lr=0.01)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior2)

for i in range(n_step_gp2_pretraining):
    optimizer.zero_grad()
    output = prior2(original_x_train)
    loss_gp = -mll(output, original_y_train)
    loss_gp.backward()
    # Report against this loop's own length and this model's own noise.  `prior`
    # (the first GP) holds a different likelihood object.
    print('Iter %d/%d - Loss: %.3f     noise: %.3f' % (
        i + 1, n_step_gp2_pretraining, loss_gp.item(),
        prior2.likelihood.noise.item()
    ))
    optimizer.step()

prior2.eval()
likelihood.eval()

with torch.no_grad():
    plt.clf()
    figure = plt.figure(figsize=(8, 5.5), facecolor='white')
    init_plotting()
    # Evaluate the posterior once.  The latent f gives the band width; the
    # likelihood-wrapped predictive gives the mean line.
    f_pred = prior2(original_x_test)
    p_gp = likelihood(f_pred)
    f_var = f_pred.variance
    f_std = f_var ** 0.5

    x_plot = original_x_test.squeeze().cpu().numpy()
    mean_np = p_gp.mean.cpu().numpy()
    std_np = f_std.cpu().numpy()

    plt.plot(x_plot, original_y_test.cpu(), 'g', label="True function")
    plt.plot(x_plot, mean_np, 'slateblue', label='GP posterior', linewidth=2.5)
    # Stacked bands between i and i+1 standard deviations, fading outward.
    for i in range(5):
        plt.fill_between(x_plot, mean_np - i * std_np,
                         mean_np - (i + 1) * std_np, linewidth=0.0,
                         alpha=1.0 - i * 0.15, color='orchid')  # wheat
        plt.fill_between(x_plot, mean_np + i * std_np,
                         mean_np + (i + 1) * std_np, linewidth=0.0,
                         alpha=1.0 - i * 0.15, color='orchid')
    plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')

    ######g3######
    plt.axvline(-0.75, linestyle='--', color='darkorchid')
    plt.axvline(-0.25, linestyle='--', color='darkorchid')
    plt.axvline(0.25, linestyle='--', color='darkorchid')
    plt.axvline(0.75, linestyle='--', color='darkorchid')
    plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
    plt.text(0.25, 1, 'Training data range', fontsize='small')

    plt.legend(loc='upper center', ncol=3, fontsize='small')
    plt.title('GP prior (RBF kernel)')
    plt.tight_layout()
    plt.ylim([dataset.y_min, dataset.y_max])
    plt.tight_layout()
    plt.savefig(figures_folder + '/gp2_RBF.pdf')

################################## start training ##################################


print("============================================================================================")
################################## start training ##################################

# Wall-clock start of training, truncated to whole seconds.
# NOTE(review): datetime.now() is naive local time, although the message says GMT.
start_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)

# CSV training log: only the header row is written here.  The handle is left open;
# presumably the training loop further down writes to (and closes) it — not visible here.
log_f = open(log_f_name, "w+")
log_f.write('epoch,training_loss\n')

# Multiply bnn_optimizer's learning rate by 0.9 every 5000 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(bnn_optimizer, step_size=5000, gamma=0.9, last_epoch=-1)

##########################################    check opt_bnn  #######

def sample_measurement_set(X, num_data, n_points=40):
    """Return ``n_points`` rows of ``X`` chosen at random without replacement.

    Args:
        X: 2-D tensor to sample rows from (indexed as ``X[idx, :]``).
        num_data: number of candidate rows; indices ``0 .. num_data - 1`` are eligible.
            Anything ``int()`` accepts works (int or 0-d tensor).
        n_points: size of the measurement set (default 40; sin_40, gap_20).  If it
            exceeds ``num_data``, all ``num_data`` rows are returned in random order.

    Returns:
        Tensor of shape ``(min(n_points, num_data), X.shape[1])``.

    The permutation is drawn on the CPU from the global torch RNG.
    """
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n_points)]
    return X[idx, :]

def lipf(X, Y):
    """Return the element-wise absolute difference ``|X - Y|``."""
    difference = X - Y
    return difference.abs()

def sample_measurement_set2(X, num_data, n_points=60):
    """Return ``n_points`` rows of ``X`` chosen at random without replacement.

    Args:
        X: 2-D tensor to sample rows from (indexed as ``X[idx, :]``).
        num_data: number of candidate rows; indices ``0 .. num_data - 1`` are eligible.
            Anything ``int()`` accepts works (int or 0-d tensor).
        n_points: size of the measurement set (default 60).  If it exceeds
            ``num_data``, all ``num_data`` rows are returned in random order.

    Returns:
        Tensor of shape ``(min(n_points, num_data), X.shape[1])``.

    The permutation is drawn on the CPU from the global torch RNG.
    """
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n_points)]
    return X[idx, :]


def sample_measurement_set3(X, num_data, n_points=20):
    """Return ``n_points`` rows of ``X`` chosen at random without replacement.

    Args:
        X: 2-D tensor to sample rows from (indexed as ``X[idx, :]``).
        num_data: number of candidate rows; indices ``0 .. num_data - 1`` are eligible.
            Anything ``int()`` accepts works (int or 0-d tensor).
        n_points: size of the measurement set (default 20).  If it exceeds
            ``num_data``, all ``num_data`` rows are returned in random order.

    Returns:
        Tensor of shape ``(min(n_points, num_data), X.shape[1])``.

    The permutation is drawn on the CPU from the global torch RNG.
    """
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n_points)]
    return X[idx, :]

########## Linear flow ###############
class LinearFlow(nn.Module):
    """Element-wise affine flow ``x = w * z + b``.

    ``w`` is a single learned scalar shared by all dimensions; ``b`` is a learned
    ``(1, dim)`` offset broadcast over the leading (batch) axis.

    Args:
        dim: size of the last axis of the inputs (length of the offset ``b``).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.weight = nn.Parameter(torch.randn(1), requires_grad=True)      # scalar slope w
        self.bias = nn.Parameter(torch.randn(1, dim),  requires_grad=True)  # per-dimension offset b

    def forward(self, z):
        """Return ``x = w * z + b``.  No log-determinant is returned by this direction."""
        return z * self.weight + self.bias

    def inverse(self, x):
        """Map ``x`` back to ``z = (x - b) / w``.

        Returns:
            ``(z, log_det_inverse)`` with ``log_det_inverse = log(1/|w| + 1e-7)``.
            NOTE(review): this is not multiplied by ``dim``, although the same scalar
            scales every dimension — confirm that is intended.
        """
        z = x * 1/self.weight - self.bias/self.weight
        log_det_inverse = torch.sum(torch.log(1/self.weight.abs() + 1e-7))
        return z, log_det_inverse

######### RealNVP ######

class RealNVP(nn.Module):
    """Stack of masked affine coupling layers (RealNVP style).

    Args:
        nets: zero-argument factory returning the scale network ``s`` for one layer.
        nett: zero-argument factory returning the translation network ``t`` for one layer.
        mask: tensor whose rows are per-layer masks; one coupling layer is built per
            row (``len(mask)`` layers).  Assuming a binary mask, entries equal to 1
            pass through unchanged and condition the transform of the entries equal to 0.
    """

    def __init__(self, nets, nett, mask):
        super(RealNVP, self).__init__()

        self.mask = nn.Parameter(mask, requires_grad=False)
        # One (s, t) pair per row of the `mask` argument.
        self.t = torch.nn.ModuleList([nett() for _ in range(len(mask))])
        self.s = torch.nn.ModuleList([nets() for _ in range(len(mask))])

    def forward(self, z):    ######### x = g(z)
        """Apply the couplings in order and return ``x`` (no log-determinant)."""
        x = z
        for i in range(len(self.t)):
            keep = 1 - self.mask[i]
            x_ = x * self.mask[i]
            s = self.s[i](x_) * keep
            t = self.t[i](x_) * keep
            x = x_ + keep * (x * torch.exp(s) + t)
        return x

    def inverse(self, x):     ########### z = f(x)
        """Undo the couplings in reverse order.

        Returns:
            ``(z, log_det_J)`` where ``log_det_J`` has shape ``(x.shape[0],)`` and is
            the sum of ``-s`` over features and layers.
        """
        log_det_J, z = x.new_zeros(x.shape[0]), x
        for i in reversed(range(len(self.t))):
            z_ = self.mask[i] * z
            s = self.s[i](z_) * (1 - self.mask[i])
            t = self.t[i](z_) * (1 - self.mask[i])
            z = z_ + (1 - self.mask[i]) * (z - t) * torch.exp(-s)
            log_det_J -= s.sum(dim=1)

        return z, log_det_J


###### linear_nn flow ###############
hidden_dimsflow = [100]


class LinearFlow_nn(nn.Module):
    """Affine flow with a learned scalar slope and an input-conditioned shift.

    ``forward``: ``x = w * z + b(x_in)``;  ``inverse``: ``z = (x - b(x_in)) / w``.

    Args:
        net: module mapping the conditioning input ``x_in`` to the shift ``b(x_in)``;
            its output is squeezed before use.
    """

    def __init__(self, net):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)  # slope w, starts at 1
        self.bias = net                                                     # shift network b(.)

    def forward(self, z, x_in):
        """Return ``(x, b(x_in), log(|w| + 1e-7))``."""
        shift = self.bias(x_in).squeeze()
        log_det = torch.log(self.weight.abs() + 1e-7).sum()
        return z * self.weight + shift, shift, log_det

    def inverse(self, x, x_in):
        """Return ``(z, log(1/|w| + 1e-7))`` with ``z = x / w - b(x_in) / w``."""
        shift = self.bias(x_in).squeeze()
        z = x / self.weight - shift / self.weight
        return z, torch.log(1 / self.weight.abs() + 1e-7).sum()

###################### planar flow ###################
hidden_dimsflow = [100]


class PlanarFlow_nn(nn.Module):
    """Input-conditioned scaling flow ``x = (1 + w(x_in)) * z + b``.

    Args:
        net: module mapping the conditioning input ``x_in`` to the per-point scale
            term ``w(x_in)``; its output is squeezed before use.
    """

    def __init__(self, net):
        super().__init__()
        self.bias = nn.Parameter(torch.randn(1), requires_grad=True)  # learned scalar offset b
        self.weight = net                                             # network producing w(x_in)

    def forward(self, z, x_in):
        """Return ``(x, w(x_in))`` with ``x = z + w(x_in) * z + b``."""
        nnweight = self.weight(x_in).squeeze()
        x = z + nnweight * z + self.bias
        return x, nnweight

    def inverse(self, x, x_in):
        """Return ``(z, log_det_inverse)`` with ``z = (x - b) / (1 + w(x_in))``.

        ``log_det_inverse = sum(log(1 / |1 + w(x_in)| + 1e-7))``.  The inverse is
        undefined where ``1 + w(x_in) == 0``.
        """
        nnweight = self.weight(x_in).squeeze()
        z = (x - self.bias) / (1 + nnweight)
        log_det_inverse = torch.sum(torch.log(1/(1 + nnweight).abs() + 1e-7))
        return z, log_det_inverse

###
###### non-linear_nn flow ###############
hidden_dimsflow = [100]  # presumably hidden-layer widths for the flows' conditioning nets; the same value is re-assigned at several points in this file
# class NonLinearFlow_nn(nn.Module):
#     def __init__(self, net1, net2):
#         super().__init__()
#         # self.dim = dim
#         # self.weight = nn.Parameter(torch.randn(1), requires_grad=True)     # torch.randn(1, dim)
#         self.weight1 = nn.Parameter(torch.tensor([1.]), requires_grad=True)
#         self.bias1 = net1  # torch.randn(1)
#         self.weight2 = nn.Parameter(torch.tensor([1.]), requires_grad=True)
#         self.bias2 = net2  # torch.randn(1)
#
#
#     def forward(self, z, x_in):
#         # x = F.linear(z, self.weight, self.bias)
#         nnbias1 = self.bias1(x_in).squeeze()
#         nnbias2 = self.bias2(x_in).squeeze()
#         x = self.weight2 * torch.tanh(z * self.weight1 + nnbias1) + nnbias2
#         log_det = torch.sum(torch.log(self.weight.abs() + 1e-7))
#
#         return x, nnbias1, log_det
#
#     def inverse(self, x, x_in):
#         # z = F.linear(x, 1/self.weight, -self.bias/self.weight)
#         nnbias1 = self.bias1(x_in).squeeze()
#         nnbias2 = self.bias2(x_in).squeeze()
#         print('nnbias1.shape: ', nnbias1)
#         print('nnbias2.shape: ', nnbias2)
#         print('x.shape: ', x.shape)
#         z = (torch.atanh((x - nnbias2) / self.weight2)) / self.weight1 - nnbias1 / self.weight1
#         print('z: ', z)
#         # log_det_inverse = torch.sum(torch.log(1/self.weight.abs() + 1e-7))
#         det = self.weight1 * self.weight2 * (1 - ((x - nnbias2) / self.weight2) ** 2)
#         log_det_inverse = torch.sum(torch.log(1 / det.abs() + 1e-7))
#         print('log_det_inverse: ', log_det_inverse)
#
#         return z, log_det_inverse

##
class NonLinearFlow_nn(nn.Module):
    """Tanh-based flow whose ``inverse`` is ``z = w2 * tanh(w * x + b1(x_in)) + b2(x_in)``.

    ``forward`` is the (clamped) algebraic inverse of that map.

    Args:
        net1: module producing the inner shift ``b1(x_in)`` (output squeezed).
        net2: module producing the outer shift ``b2(x_in)`` (output squeezed).
    """

    def __init__(self, net1, net2):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)   # inner slope w
        self.bias1 = net1                                                    # inner shift b1(.)
        self.weight2 = nn.Parameter(torch.tensor([1.]), requires_grad=True)  # outer scale w2
        self.bias2 = net2                                                    # outer shift b2(.)

    def forward(self, z, x_in):
        """Return ``(x, b1(x_in), log(|w| + 1e-7))``.

        ``x = (atanh(clamp((z - b2) / w2, -0.99, 0.99)) - b1) / w``; the clamp keeps
        atanh finite.  NOTE(review): the returned log-det only covers ``w``, not the
        ``w2`` or atanh terms.
        """
        inner_shift = self.bias1(x_in).squeeze()
        outer_shift = self.bias2(x_in).squeeze()
        ratio = torch.clamp((z - outer_shift) / self.weight2, -0.99, 0.99)
        x = torch.atanh(ratio) / self.weight - inner_shift / self.weight
        log_det = torch.log(self.weight.abs() + 1e-7).sum()
        return x, inner_shift, log_det

    def inverse(self, x, x_in):
        """Return ``(z, log_det_inverse)``.

        ``log_det_inverse`` is ``log|dz/dx| + 1e-7`` summed over the last axis, with
        ``dz/dx = w2 * w * (1 - tanh(w * x + b1)^2)``.
        """
        inner_shift = self.bias1(x_in).squeeze()
        outer_shift = self.bias2(x_in).squeeze()
        z = self.weight2 * torch.tanh(x * self.weight + inner_shift) + outer_shift
        derivative = self.weight2 * self.weight * (1 - torch.tanh(x * self.weight + inner_shift) ** 2)
        log_det_inverse = torch.log(derivative.abs() + 1e-7).sum(-1)
        return z, log_det_inverse


########## sigmoid_flow #############
class SigmoidFlow_nn(nn.Module):
    """Scaled-sigmoid flow whose ``inverse`` is ``z = w2 * sigmoid(w * x) + b(x_in)``.

    ``forward`` is the clamped logit inverse of that map.

    Args:
        net: module producing the shift ``b(x_in)``; its output is squeezed.
    """

    def __init__(self, net):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)   # slope inside the sigmoid
        self.weight2 = nn.Parameter(torch.tensor([1.]), requires_grad=True)  # output scale
        self.bias = net                                                      # input-conditioned shift

    def forward(self, z, x_in):
        """Return ``(x, b(x_in), 1.)``.

        ``x = log(p / (1 - p) + 1e-7) / w`` with ``p = clamp((z - b) / w2, 0.01, 0.99)``.
        The third element is the constant ``1.`` and does not depend on the inputs.
        """
        shift = self.bias(x_in).squeeze()
        p = torch.clamp((z - shift) / self.weight2, min=0.01, max=0.99)
        odds = p / (1 - p)
        x = torch.log(odds + 1e-7) / self.weight
        return x, shift, 1.

    def inverse(self, x, x_in):
        """Return ``(z, log_det_inverse)``.

        ``log_det_inverse`` is ``log|dz/dx| + 1e-7`` summed over the last axis, with
        ``dz/dx = w2 * w * s * (1 - s)`` and ``s = sigmoid(w * x)``.
        """
        shift = self.bias(x_in).squeeze()
        s = torch.sigmoid(x * self.weight)
        z = self.weight2 * s + shift
        jacobian = self.weight2 * self.weight * s * (1 - s)
        return z, torch.log(jacobian.abs() + 1e-7).sum(-1)

########## exp_flow #############
# class ExpFlow_nn(nn.Module):
#     def __init__(self, net):
#         super().__init__()
#         # self.dim = dim
#         # self.weight = nn.Parameter(torch.randn(1), requires_grad=True)     # torch.randn(1, dim)
#         self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
#         self.bias = net   # torch.randn(1)
#
#
#     def forward(self, z, x_in):
#         # x = F.linear(z, self.weight, self.bias)
#         nnbias = self.bias(x_in).squeeze()
#         mask = (z - nnbias) / self.weight
#         mask = torch.clamp(mask, min=0.001)
#         x = torch.log(mask.abs() + 1e-7)
#         log_det = torch.sum(torch.log(self.weight.abs() + 1e-7))
#
#         return x, nnbias, log_det
#
#     def inverse(self, x, x_in):
#         # z = F.linear(x, 1/self.weight, -self.bias/self.weight)
#         nnbias = self.bias(x_in).squeeze()
#         # print('nnbias.shape: ', nnbias.shape)
#         # print('x.shape: ', x.shape)
#         z = torch.exp(x) * self.weight + nnbias
#         # log_det_inverse = torch.sum(x) + torch.log(self.weight.abs() + 1e-7)
#         # log_det_inverse = log_det_inverse.squeeze()
#
#         det = self.weight * torch.exp(x)
#         log_det_inverse = torch.sum(torch.log(det.abs() + 1e-7), -1)
#
#         return z, log_det_inverse

########## exp_flow #############
class ExpFlow_nn(nn.Module):
    """Flow whose ``inverse`` is ``z = w * exp(tanh(x)) + b(x_in)``.

    ``forward`` is the clamped inverse of that map.

    Args:
        net: module producing the shift ``b(x_in)``; its output is squeezed.
    """

    def __init__(self, net):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)  # output scale w
        self.bias = net                                                     # input-conditioned shift

    def forward(self, z, x_in):
        """Return ``(x, b(x_in), log(|w| + 1e-7))``.

        ``x = atanh(clamp(log(max((z - b) / w, 0.001) + 1e-7), -0.99, 0.99))``; both
        clamps keep the log and atanh finite.
        """
        shift = self.bias(x_in).squeeze()
        scaled = torch.clamp((z - shift) / self.weight, min=0.001)
        log_scaled = torch.log(scaled + 1e-7)
        x = torch.atanh(torch.clamp(log_scaled, -0.99, 0.99))
        log_det = torch.log(self.weight.abs() + 1e-7).sum()
        return x, shift, log_det

    def inverse(self, x, x_in):
        """Return ``(z, log_det_inverse)``.

        ``log_det_inverse`` is ``log|dz/dx| + 1e-7`` summed over the last axis, with
        ``dz/dx = w * exp(tanh(x)) * (1 - tanh(x)^2)``.
        """
        shift = self.bias(x_in).squeeze()
        z = torch.exp(torch.tanh(x)) * self.weight + shift
        jacobian = self.weight * torch.exp(torch.tanh(x)) * (1 - torch.tanh(x) ** 2)
        return z, torch.log(jacobian.abs() + 1e-7).sum(-1)



###### cube_nn flow ###############
def cube_root(x):
    """Real cube root of a tensor, element-wise (negative inputs give negative roots)."""
    magnitude = torch.abs(x) ** (1/3)
    return torch.sign(x) * magnitude

def power_negative_base(x):
    """Return ``|x| ** (-2/3)`` element-wise for a tensor ``x``.

    The absolute value is taken first, so negative bases are allowed; a fractional
    power of a negative float would otherwise be NaN.  Multiplied by 1/3, this is
    the magnitude of the derivative of the real cube root.  Zero entries map to ``inf``.

    ``torch.abs`` always yields a tensor, so no separate non-tensor path exists.
    """
    return torch.abs(x) ** (-2/3)

# def power_negative_base(x):
    # return torch.sign(x) * torch.abs(x) ** (-2/3)

hidden_dimsflow = [100]  # presumably hidden-layer widths for the flows' conditioning nets; the same value is re-assigned at several points in this file

# class CubeFlow_nn(nn.Module):
#     def __init__(self, net):
#         super().__init__()
#         # self.dim = dim
#         # self.weight = nn.Parameter(torch.randn(1), requires_grad=True)     # torch.randn(1, dim)
#         self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
#         self.bias = net   # torch.randn(1)
#
#
#     def forward(self, z, x_in):
#         # x = F.linear(z, self.weight, self.bias)
#         nnbias = self.bias(x_in).squeeze()
#
#         # x = (z * 1/self.weight - nnbias/self.weight) ** (1/3)
#         # log_det = torch.sum(torch.log((((z * 1/self.weight - nnbias/self.weight) ** (-2/3)) / 3).abs() + 1e-7))
#
#         x = (z ** 3) * self.weight + nnbias
#         log_det = torch.sum(torch.log(((z ** 2) * self.weight * 3).abs() + 1e-7))
#
#         return x, nnbias, log_det
#
#
#     def inverse(self, x, x_in):
#         # z = F.linear(x, 1/self.weight, -self.bias/self.weight)
#         nnbias = self.bias(x_in).squeeze()
#
#         # z = (torch.exp(x)) * self.weight + nnbias
#         # z = (x ** 3) * self.weight + nnbias
#         # log_det_inverse = torch.sum(torch.log(((x ** 2) * self.weight * 3).abs() + 1e-7))
#         z1 = x * 1 / self.weight - nnbias / self.weight
#         z = cube_root(z1)
#
#         p_z = power_negative_base(z1)
#         # print('p_z: ',p_z)
#         # log_det_inverse = torch.sum(torch.log((p_z / (3)).abs() + 1e-7))
#         log_det_inverse = torch.sum(torch.log((p_z / (3 * self.weight)).abs() + 1e-7))
#         # print('log_det_inverse: ', log_det_inverse)
#         # if torch.isinf(log_det_inverse) or torch.isnan(log_det_inverse):
#         #     log_det_inverse = torch.zeros_like(log_det_inverse)
#
#         return z, log_det_inverse


########## cube_flow #############
class CubeFlow_nn(nn.Module):
    """Element-wise flow ``z = tanh(x**3) * weight + bias(x_in)``.

    ``inverse`` applies the formula above to ``x`` and returns the summed
    log-abs-Jacobian ``log|dz/dx|``. ``forward`` undoes it,
    ``x = cbrt(atanh((z - bias(x_in)) / weight))``, clamping the ``atanh``
    argument to [-0.99, 0.99] so it stays finite.

    Args:
        net: module mapping ``x_in`` to the per-point shift; its output is
            ``squeeze()``-d before use.
    """

    def __init__(self, net):
        super().__init__()
        # Scalar scale of the flow, initialised to 1.
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        # Input-dependent shift network (attribute name kept as ``bias``).
        self.bias = net

    def _log_abs_dz_dx(self, x):
        """Element-wise ``log|dz/dx|`` for ``z = tanh(x**3) * weight + shift``."""
        # d/dx tanh(x^3) = 3 x^2 (1 - tanh(x^3)^2); the 1e-7 guards log(0).
        det = 3 * self.weight * (x ** 2) * (1 - torch.tanh(x ** 3) ** 2)
        return torch.log(det.abs() + 1e-7)

    def forward(self, z, x_in):
        """Map ``z`` back to ``x``.

        Returns:
            ``(x, nnbias, log_det)`` where ``nnbias`` is the shift produced by
            ``self.bias(x_in)`` and ``log_det`` is the scalar sum of
            ``log|dx/dz|`` over all elements.
        """
        nnbias = self.bias(x_in).squeeze()
        u = (z - nnbias) / self.weight
        # keep atanh finite: its argument must stay strictly inside (-1, 1)
        u = torch.clamp(u, -0.99, 0.99)
        y = torch.atanh(u)
        # real cube root that also works for negative values
        x = torch.copysign(torch.pow(torch.abs(y), 1 / 3), y)

        # log|dx/dz| = -log|dz/dx|, evaluated at the produced x
        log_det = -torch.sum(self._log_abs_dz_dx(x))

        return x, nnbias, log_det

    def inverse(self, x, x_in):
        """Map ``x`` to ``z = tanh(x**3) * weight + bias(x_in)``.

        Returns:
            ``(z, log_det_inverse)`` with ``log|dz/dx|`` summed over the last
            dimension of ``x``.
        """
        nnbias = self.bias(x_in).squeeze()
        z = torch.tanh(x ** 3) * self.weight + nnbias
        log_det_inverse = torch.sum(self._log_abs_dz_dx(x), -1)

        return z, log_det_inverse


####
# class CubeFlow_nn(nn.Module):
#     def __init__(self, net):
#         super().__init__()
#         # self.dim = dim
#         # self.weight = nn.Parameter(torch.randn(1), requires_grad=True)     # torch.randn(1, dim)
#         self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
#         self.bias = net   # torch.randn(1)
#
#
#     def forward(self, z, x_in):
#         # x = F.linear(z, self.weight, self.bias)
#         nnbias = self.bias(x_in).squeeze()
#
#         x1 = z * 1/self.weight - nnbias/self.weight
#         x = cube_root(x1)
#         p_x = power_negative_base(x1)
#         # print('p_z: ',p_z)
#         log_det = torch.sum(torch.log((p_x / (3 * self.weight)).abs() + 1e-7))
#
#         return x, nnbias, log_det
#
#     def inverse(self, x, x_in):
#         # z = F.linear(x, 1/self.weight, -self.bias/self.weight)
#         nnbias = self.bias(x_in).squeeze()
#
#         z = (x ** 3) * self.weight + nnbias
#
#         log_det_inverse = torch.sum(torch.log(((z ** 2) * self.weight * 3).abs() + 1e-7))
#
#         print('log_det_inverse: ', log_det_inverse)
#
#         return z, log_det_inverse

# hidden_dimsflow = [100]
# class CubeFlow_nn(nn.Module):
#     def __init__(self, net):
#         super().__init__()
#         # self.dim = dim
#         # self.weight = nn.Parameter(torch.randn(1), requires_grad=True)     # torch.randn(1, dim)
#         self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
#         self.bias = net   # torch.randn(1)
#
#
#     def forward(self, z, x_in):
#         # x = F.linear(z, self.weight, self.bias)
#         nnbias = self.bias(x_in).squeeze()
#         # x = (z ** 3) * self.weight + nnbias
#         z1 = z * 1 / self.weight - nnbias / self.weight
#         x = torch.log(z1.abs() + 1e-7)
#         # log_det = torch.sum(torch.log(self.weight.abs() + 1e-7))
#         log_det = torch.sum(torch.log((1/z1).abs() + 1e-7))
#
#         return x, nnbias, log_det
#
#     def inverse(self, x, x_in):
#         # z = F.linear(x, 1/self.weight, -self.bias/self.weight)
#         nnbias = self.bias(x_in).squeeze()
#         print('nnbias.shape: ', nnbias.shape)
#         print('x.shape: ', x.shape)
#         z = (torch.exp(x)) * self.weight + nnbias
#         # det = torch.pow(z1, -2/3)
#         # log_det_inverse = torch.sum(-2/3 * torch.log(det.abs() + 1e-7) + torch.log(1/(3 * self.weight.abs()) + 1e-7))
#         log_det_inverse = torch.sum(x) + torch.log(self.weight.abs() + 1e-7)
#
#         return z, log_det_inverse



# track total training time
start_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)

# logging file: CSV with header "epoch,training_loss" (not closed in this part
# of the script)
log_f = open(log_f_name, "w+")
log_f.write('epoch,training_loss\n')
# NOTE(review): this scheduler wraps `bnn_optimizer`, while the loops below step
# `opt_bnn_optimizer` / `bnnmc_optimizer` (`bnn_optimizer.step()` is commented
# out), so it does not change the learning rate of the model being trained.
# Confirm which optimizer was intended.
scheduler = torch.optim.lr_scheduler.StepLR(bnn_optimizer, 5000, gamma=0.9, last_epoch=-1)

################# S1: FVI #################################
# Per-step objective:
#   10 * likelihood loss on the training data
# + 10 * mean functional Wasserstein distance (lipf) between GP prior2 draws
#        and opt_bnn draws on a random measurement set
# +  1 * sum |GP marginal variance - opt_bnn sample variance| on a second set
train_loss_all = []
prior_distance_all = []
wdist_all = []
likelihood_loss_all = []

for epoch in range(1001):    # 4001

    # --- functional prior matching ---
    measurement_set = sample_measurement_set(X=original_x_train_add_g3, num_data=60) # sin num_data=60, gap num_data=100
    gpss = prior_sample_functions(measurement_set, prior2, num_sample=128).detach().float().to(device)
    # gpss = prior_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)  # original_x_test
    gpss = gpss.squeeze()
    nnet_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device) # original_x_test
    nnet_samples = nnet_samples.squeeze()

    functional_wdist = lipf(gpss, nnet_samples)
    functional_wdist = torch.mean(functional_wdist)

    # --- variance constraint ---
    measurement_set2 = sample_measurement_set2(X=original_x_train_add_g3, num_data=60)

    with torch.no_grad():
        gp_var = prior2(measurement_set2).variance

    opt_samples = opt_bnn.sample_functions(measurement_set2, num_sample=200).float().to(device)
    opt_samples = opt_samples.squeeze()
    opt_var = torch.std(opt_samples, 0) ** 2
    dif_var = (gp_var - opt_var).abs()
    dif_var = torch.sum(dif_var)

    # --- likelihood on the training data (computed once, reused below) ---
    pred_y = opt_bnn.forward(original_x_train)
    pred_y = pred_y.squeeze().flatten()

    likelihood_loss = loss(pred_y, original_y_train)
    train_loss = 10 * likelihood_loss + 10 * functional_wdist + 1 * dif_var
    print('likelihood_loss: ', likelihood_loss)
    print('w_dist: ', functional_wdist)

    # optimisation
    opt_bnn_optimizer.zero_grad()
    train_loss.backward()
    # bnn_optimizer.step()
    scheduler.step()
    opt_bnn_optimizer.step()

    # print
    print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss),
          datetime.now().replace(microsecond=0) - start_time)

    # log in logging file
    if epoch % log_model_freq == 0:
        print("--------------------------------------------------------------------------------------------")
        # .item() writes a plain number; formatting the tensor itself would
        # write "tensor(..., grad_fn=<...>)", which contains commas and breaks
        # the CSV.
        log_f.write('{},{}\n'.format(epoch, train_loss.item()))
        log_f.flush()
        print("log saved")
        print("--------------------------------------------------------------------------------------------")

    # save model weights
    # NOTE(review): this checkpoints `bnn`, but the model optimised in this loop
    # is `opt_bnn` -- confirm which one should be saved.
    if epoch % save_model_freq == 0:
        print("--------------------------------------------------------------------------------------------")
        print("saving model at : " + checkpoint_path)
        bnn.save(checkpoint_path)
        print("model saved")
        print("--------------------------------------------------------------------------------------------")

    if epoch % test_interval == 0:

        opt_bnn.eval()

        samples_pred_test_y = opt_bnn.forward_eval(original_x_test, num_sample).squeeze().detach()

        mean_y_pred = torch.mean(samples_pred_test_y, 0).cpu().numpy()
        std_y_pred = torch.std(samples_pred_test_y, 0).cpu().numpy()
        test_loss = loss(samples_pred_test_y, original_y_test)

        opt_bnn.train()

        # fresh figure per evaluation; closed after saving so open figures do
        # not accumulate over the run
        figure = plt.figure(figsize=(8, 5.5), facecolor='white')
        init_plotting()

        plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
        plt.plot(original_x_test.squeeze().cpu(), mean_y_pred, 'slateblue', label='Mean function', linewidth=2.5)

        # shaded bands at 1..5 predictive standard deviations
        for i in range(5):
            plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred - i * 1 * std_y_pred,
                             mean_y_pred - (i + 1) * 1 * std_y_pred, linewidth=0.0,
                             alpha=1.0 - i * 0.15, color='orchid')
            plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred + i * 1 * std_y_pred,
                             mean_y_pred + (i + 1) * 1 * std_y_pred, linewidth=0.0,
                             alpha=1.0 - i * 0.15, color='orchid')
        plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')
        ######### sin #################
        # plt.axvline(-2, linestyle='--', color='darkorchid')   # darkorange
        # plt.axvline(2, linestyle='--', color='darkorchid')
        # plt.text(-1.5, -3, 'Training data range')
        ###### gap ###########
        # plt.axvline(-7.5, linestyle='--', color='darkorchid')
        # plt.axvline(-2.5, linestyle='--', color='darkorchid')
        # plt.axvline(5, linestyle='--', color='darkorchid')
        # plt.axvline(7.5, linestyle='--', color='darkorchid')
        # plt.text(-7.4, -1, 'Training data range', fontsize='small')
        # plt.text(5.1, -1, 'Training data', fontsize='xx-small')
        # plt.text(5.1, -1.8, 'range', fontsize='xx-small')

        ######g3######
        plt.axvline(-0.75, linestyle='--', color='darkorchid')
        plt.axvline(-0.25, linestyle='--', color='darkorchid')
        plt.axvline(0.25, linestyle='--', color='darkorchid')
        plt.axvline(0.75, linestyle='--', color='darkorchid')
        plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
        plt.text(0.25, 1, 'Training data range', fontsize='small')

        plt.legend(loc='upper center', ncol=3, fontsize='small')
        plt.title('FVIMC posterior')
        plt.tight_layout()
        plt.ylim([dataset.y_min, dataset.y_max])
        plt.tight_layout()

        print("--------------------------------------------------------------------------------------------")
        plt.savefig(figures_folder + '/plot_epoch{}.pdf'.format(epoch))
        plt.close(figure)
        print("figure saved")
        print("epoch : {} \t\t test loss \t\t : {}".format(epoch, test_loss))
        print("--------------------------------------------------------------------------------------------")
#######


print("--------------------------------------------------------------------------------------------")
######### S2: FMCMC #####################
# Initialise the MCMC network from the trained FVI network: gather its
# variational parameters by name (each list keeps named_parameters() order).
w_mu_list = [param for pname, param in opt_bnn.named_parameters() if "W_mu" in pname]
b_mu_list = [param for pname, param in opt_bnn.named_parameters() if 'b_mu' in pname]
w_std_list = [param for pname, param in opt_bnn.named_parameters() if 'W_std' in pname]
b_std_list = [param for pname, param in opt_bnn.named_parameters() if 'b_std' in pname]

bnnmc = BNNMC(
    input_dim, output_dim, hidden_dims, activation_fn, is_continuous,
    W_mu=w_mu_list, b_mu=b_mu_list, W_std=w_std_list, b_std=b_std_list,
    scaled_variance=True,
).to(device)

# Adam over the MCMC network's parameters (tried: 5e-6, 1)
bnnmc_optimizer = torch.optim.Adam(bnnmc.parameters(), lr=0.001)

def noise_loss(lr):
    """Return a random linear term whose gradient injects Langevin noise.

    For every parameter ``var`` of the global ``bnnmc``, draws
    ``sigma ~ N(0, std=sqrt(2 / lr))`` element-wise and returns the sum of
    ``torch.sum(var * sigma)``. Back-propagating this term adds ``sigma`` to
    each parameter's gradient.

    Args:
        lr: step size that scales the noise; must be positive.

    Returns:
        A scalar tensor (``0.0`` if ``bnnmc`` has no parameters).

    Raises:
        ValueError: if ``lr`` is not positive.
    """
    if lr <= 0:
        raise ValueError("lr must be positive, got {}".format(lr))
    noise_std = (2 / lr) ** 0.5
    total = 0.0
    for var in bnnmc.parameters():
        # zeros_like allocates on var's own device/dtype (no host->device copy)
        sigma = torch.normal(torch.zeros_like(var), std=noise_std)
        total += torch.sum(var * sigma)
    return total

########## SGLD #####
# Functional SG-MCMC on bnnmc. Each loss term has the form f * stop_grad(r), so
# its gradient equals that of
#   0.5 * mean((f(x_train) - y_train)**2)
# + 0.5 * mean((f(m) - mu_prior(m))**2 / var_prior(m))
# where mu_prior / var_prior come from likelihood(prior2(m)) on a random
# measurement set m. After epoch 300, a scaled noise_loss term injects noise.
mt = 0  # number of snapshots saved so far (file suffix g3_<mt>)
for epoch in range(501):    # 2001

    pred_y_mc1 = bnnmc.forward(original_x_train)
    pred_y_mc1 = pred_y_mc1.squeeze().flatten()

    ###################  measurement set for d_logp
    m_set = sample_measurement_set(X=original_x_train_add_g3, num_data=60)
    y_m_set = bnnmc.forward(m_set).squeeze().flatten()

    with torch.no_grad():
        fprior_m = likelihood(prior2(m_set))
        prior_mean_m = fprior_m.mean
        prior_var_m = fprior_m.variance

    # prior term: f(m) * stop_grad((f(m) - mu) / var)
    d_logp_m = ((y_m_set - prior_mean_m) / prior_var_m).detach()
    d_logp_m = y_m_set * d_logp_m
    d_logp_m = torch.mean(d_logp_m)

    # likelihood term: f(x) * stop_grad(f(x) - y)
    t_like = pred_y_mc1 * ((pred_y_mc1 - original_y_train).detach())
    t_like = torch.mean(t_like)

    train_loss_mc1 = t_like + 1 * d_logp_m  #  * 1e-1

    if epoch > 300:  # 1000
        noise = noise_loss(lr=0.001)
        train_loss_mc1 = train_loss_mc1 + 1e-5 * noise

    # optimisation
    bnnmc_optimizer.zero_grad()
    train_loss_mc1.backward()
    bnnmc_optimizer.step()
    # NOTE(review): `scheduler` wraps `bnn_optimizer`, not `bnnmc_optimizer`.
    scheduler.step()

    # print
    print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss_mc1),
          datetime.now().replace(microsecond=0) - start_time)

    if epoch % test_interval == 0:

        bnnmc.eval()

        samples_pred_test_y = bnnmc.forward_eval(original_x_test, num_sample).squeeze().detach()
        # only used by the commented-out measurement-point plot below; kept
        # because it calls bnnmc.forward
        y_m = bnnmc.forward(m_set).squeeze().detach().cpu().numpy()

        mean_y_pred = torch.mean(samples_pred_test_y, 0).cpu().numpy()
        std_y_pred = torch.std(samples_pred_test_y, 0).cpu().numpy()
        test_loss = loss(samples_pred_test_y, original_y_test)

        bnnmc.train()

        # fresh figure per evaluation; closed after saving
        figure = plt.figure(figsize=(8, 5.5), facecolor='white')
        init_plotting()

        plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
        plt.plot(original_x_test.squeeze().cpu(), mean_y_pred, 'slateblue', label='Mean function', linewidth=2.5)  # chocolate
        for i in range(5):
            plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred - i * 1 * std_y_pred,
                             mean_y_pred - (i + 1) * 1 * std_y_pred, linewidth=0.0,
                             alpha=1.0 - i * 0.15, color='orchid')  # wheat
            plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred + i * 1 * std_y_pred,
                             mean_y_pred + (i + 1) * 1 * std_y_pred, linewidth=0.0,
                             alpha=1.0 - i * 0.15, color='orchid')
        plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')
        # plt.plot(m_set.cpu(), y_m, 'oy', zorder=10, ms=6, label='Measurement points')

        ######g3######
        plt.axvline(-0.75, linestyle='--', color='darkorchid')
        plt.axvline(-0.25, linestyle='--', color='darkorchid')
        plt.axvline(0.25, linestyle='--', color='darkorchid')
        plt.axvline(0.75, linestyle='--', color='darkorchid')
        plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
        plt.text(0.25, 1, 'Training data range', fontsize='small')

        plt.legend(loc='upper center', ncol=3, fontsize='small')
        plt.title('MCMC posterior')
        plt.tight_layout()
        plt.ylim([dataset.y_min, dataset.y_max])
        plt.tight_layout()
        print("--------------------------------------------------------------------------------------------")
        # epoch offset continues the numbering of the 1001-epoch FVI stage
        plt.savefig(figures_folder + '/plot_epoch{}.pdf'.format(epoch+1001))   # +4001
        plt.close(figure)
        print("figure saved")
        print("epoch : {} \t\t test loss \t\t : {}".format(epoch, test_loss))
        print("--------------------------------------------------------------------------------------------")

    # save a posterior snapshot every 4th epoch after epoch 100
    # (epochs 104, 108, ..., 500 -> 100 snapshots g3_0 ... g3_99)
    if epoch > 100 and (epoch % 4) == 0:   # 0, 20
        print('save!')
        bnnmc.cpu()
        dd = f'{mt}'
        print(dd)
        save_path = results_folder + '/' + bnn_name_string + '/' + 'fsgld_toy_samples' + '/' + 'g3_' + dd  # sgld
        # the snapshot directory may not exist yet on a fresh run
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        bnnmc.save(save_path)
        mt += 1
        bnnmc.to(device)

###
def testp():
    bnnmc.eval()

    pred_test_y = bnnmc.forward(original_x_test).squeeze().detach()
    return pred_test_y

def test2():
    bnnmc.eval()

    pred_mset = bnnmc.forward(original_x_train_add_g3).squeeze().detach()
    return pred_mset


################ load mcmc model ################
# Reload the SGLD snapshots g3_0 ... g3_{num_model-1} and collect predictions
# on the test inputs (pred_list) and on original_x_train_add_g3 (predm_list).
pred_list = []
predm_list = []
num_model = 100 # 10 100 80 70 -- must not exceed the number of saved snapshots

for m in range(num_model):
    dd = f'{m}'
    bnnmc.load_state_dict(torch.load(results_folder + '/' + bnn_name_string + '/' + 'fsgld_toy_samples' + '/' + 'g3_' + dd))  # sgld, sghmc
    pred = testp()
    print('pred.shape: ', pred.shape)
    pred_list.append(pred)

    predm = test2()
    predm_list.append(predm)

# (num_model, n_points) matrix of snapshot predictions on
# original_x_train_add_g3; stacked once after the loop
mset_fmc = torch.stack(predm_list)

########################
total_pred = torch.stack(pred_list)
mean_y_pred = torch.mean(total_pred, 0).cpu().numpy()
std_y_pred = torch.std(total_pred, 0).cpu().numpy()

figure = plt.figure(figsize=(8, 5.5), facecolor='white')
init_plotting()

plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
plt.plot(original_x_test.squeeze().cpu(), mean_y_pred, 'slateblue', label='Mean function', linewidth=2.5)  # chocolate
for i in range(5):
    plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred - i * 1 * std_y_pred,
                     mean_y_pred - (i + 1) * 1 * std_y_pred, linewidth=0.0,
                     alpha=1.0 - i * 0.15, color='orchid')  # wheat
    plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred + i * 1 * std_y_pred,
                     mean_y_pred + (i + 1) * 1 * std_y_pred, linewidth=0.0,
                     alpha=1.0 - i * 0.15, color='orchid')
plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')

######g3######
plt.axvline(-0.75, linestyle='--', color='darkorchid')
plt.axvline(-0.25, linestyle='--', color='darkorchid')
plt.axvline(0.25, linestyle='--', color='darkorchid')
plt.axvline(0.75, linestyle='--', color='darkorchid')
plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
plt.text(0.25, 1, 'Training data range', fontsize='small')

plt.legend(loc='upper center', ncol=3, fontsize='small')
plt.title('fsgld posterior')
plt.tight_layout()
plt.ylim([dataset.y_min, dataset.y_max])
plt.tight_layout()
print("--------------------------------------------------------------------------------------------")
plt.savefig(figures_folder + '/' + 'fsgld_posterior.pdf')
plt.close(figure)



############### GP Linearization for FVI posterior ###############
# Linearise the FVI network around its mean weights. With G(x) the gradient of
# forward_mu(x) w.r.t. every parameter whose name contains 'mu', and the
# diagonal weight covariance w_cov = diag(w_std**2), the linearised marginal
# variance is  G(x) w_cov G(x)^T = sum_j G_j(x)**2 * w_std_j**2.
# NOTE(review): assumes sample_mu_std() orders w_std like the concatenated
# 'mu' gradients below -- confirm against the BNN implementation.
opt_bnn2 = opt_bnn  # alias of the FVI network (same object, not a copy)
# only used to clear gradients between the per-point backward passes
opt_bnn_optimizer2 = torch.optim.Adam([{'params': opt_bnn2.parameters(), 'lr': lr_bnn}])

w_mu, w_std = opt_bnn2.sample_mu_std()
w_mu = w_mu.detach()
w_std = w_std.detach()
print('w_std.shape: ', w_std.shape)  # [10401]
w_dim = w_std.shape[-1]
print('w_dim: ', w_std.shape[-1])
w_var = w_std ** 2
# Dense (w_dim x w_dim) diagonal covariance. The variances below use w_var
# directly instead of multiplying through this matrix.
w_cov = w_var * torch.eye(w_dim).to(device)
print('w_cov.shape: ', w_cov.shape)  # [10401, 10401]


def _mu_grad_rows(model, optimizer, points):
    """Return a (len(points), n_mu) matrix of per-point parameter gradients.

    Row i is d model.forward_mu(points[i]) / d theta, where theta concatenates
    (in named_parameters() order) every parameter whose name contains 'mu'.
    Runs one backward pass per point; ``optimizer`` only clears gradients.
    """
    rows = []
    for x_i in points:
        pred_y_i = model.forward_mu(x_i).squeeze().flatten()
        optimizer.zero_grad()
        pred_y_i.backward()
        rows.append(torch.cat([p.grad.view(-1) for name, p in model.named_parameters() if 'mu' in name]))
    return torch.stack(rows)


grads_x_mset = _mu_grad_rows(opt_bnn2, opt_bnn_optimizer2, original_x_train_add_g3)
print('grads_x_mset.shape: ', grads_x_mset.shape)

prior_mean_m = opt_bnn2.forward_mu(original_x_train_add_g3).squeeze().detach()
print('prior_mean_m: ', prior_mean_m.shape)

# diag(G w_cov G^T) without forming the (n, w_dim) x (w_dim, w_dim) product
prior_var_m = torch.sum(grads_x_mset ** 2 * w_var, -1)
print('prior_var_m: ', prior_var_m.shape)
##

grads_x_test = _mu_grad_rows(opt_bnn2, opt_bnn_optimizer2, original_x_test)
print('grads_x_test.shape: ', grads_x_test.shape)

prior_var_test = torch.sum(grads_x_test ** 2 * w_var, -1)
print('prior_var_test: ', prior_var_test.shape)
# NOTE(review): labelled "std" but this prints the mean linearised variance
print('gp_linear_std: ', torch.mean(prior_var_test))  # 18.0297

prior_mean_test = opt_bnn2.forward_mu(original_x_test).squeeze().detach()

with torch.no_grad():
    figure = plt.figure(figsize=(8, 5.5), facecolor='white')
    init_plotting()
    linear_gp = opt_bnn2.forward_mu(original_x_test).squeeze().cpu().numpy()
    f_var = prior_var_test
    f_std = f_var ** 0.5
    f_std_np = f_std.cpu().numpy()
    plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
    plt.plot(original_x_test.squeeze().cpu(), linear_gp, 'slateblue', label='GP posterior', linewidth=2.5)

    for i in range(5):
        plt.fill_between(original_x_test.squeeze().cpu(), linear_gp - i * 1 * f_std_np,
                         linear_gp - (i + 1) * 1 * f_std_np, linewidth=0.0,
                         alpha=1.0 - i * 0.15, color='orchid')  # wheat
        plt.fill_between(original_x_test.squeeze().cpu(), linear_gp + i * 1 * f_std_np,
                         linear_gp + (i + 1) * 1 * f_std_np, linewidth=0.0,
                         alpha=1.0 - i * 0.15, color='orchid')
    plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')

    ######g3######
    plt.axvline(-0.75, linestyle='--', color='darkorchid')
    plt.axvline(-0.25, linestyle='--', color='darkorchid')
    plt.axvline(0.25, linestyle='--', color='darkorchid')
    plt.axvline(0.75, linestyle='--', color='darkorchid')
    plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
    plt.text(0.25, 1, 'Training data range', fontsize='small')

    plt.legend(loc='upper center', ncol=3, fontsize='small')
    plt.title('GP linear approximation')
    plt.tight_layout()
    plt.ylim([dataset.y_min, dataset.y_max])
    plt.tight_layout()
    plt.savefig(figures_folder + '/gp_linear.pdf')
    plt.close(figure)


######################### train LinearFlow_nn ##############
# Fit a flow whose inverse maps the SGLD snapshot predictions (mset_fmc) onto
# the GP-linearised Gaussian N(prior_mean_m, prior_var_m), then push FVI
# samples through the flow's forward map.

flownet = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(device)
flownet2 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(device)

flow = LinearFlow_nn(flownet).to(device)
# flow = CubeFlow_nn(flownet).to(device)
# flow = NonLinearFlow_nn(flownet, flownet2).to(device)
# flow = SigmoidFlow_nn(flownet).to(device)
# flow = ExpFlow_nn(flownet).to(device)
flow_optimizer = torch.optim.Adam(flow.parameters(), lr=0.01)

n_flow_epochs = 5000  # also sets the x-range of the flow-loss plot

loss_flow_all = []
flow_likelihood_list = []
flow_det_list = []
flow_loss_list = []

for epoch in range(n_flow_epochs):
    Z, logdet = flow.inverse(mset_fmc, original_x_train_add_g3)
    contains_nan = torch.isnan(Z).any()
    print(contains_nan)

    # Gaussian negative log-likelihood of Z (up to terms constant in the flow
    # parameters -- the log-variance term is one of them) minus the
    # change-of-variables log-Jacobian.
    loss_flow = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
        torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1) - logdet)

    flow_optimizer.zero_grad()
    loss_flow.backward(retain_graph=True)
    flow_optimizer.step()
    loss_flow_all.append(loss_flow.item())

    # print
    print("epoch : {} \t\t training loss \t\t : {}".format(epoch, loss_flow),
          datetime.now().replace(microsecond=0) - start_time)

flow.eval()
a = flow.weight

print('flow.weight: ', flow.weight)

# NOTE(review): Z and logdet are from the last iteration's pass, i.e. computed
# before the final optimizer step.
flow_likelihood = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
        torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1))
flow_det = torch.mean(logdet)
flow_likelihood_list.append(flow_likelihood)
flow_det_list.append(flow_det)
flow_loss_list.append(loss_flow)

#### flow training curve (every 50th iteration)

loss_figure = plt.figure()
indices = np.arange(n_flow_epochs)[::50]
loss_flow_all = np.array(loss_flow_all)
plt.plot(indices, loss_flow_all[indices], '-ko', ms=3)
plt.ylabel(r'flow loss')
plt.tight_layout()
plt.xlabel('Iteration')
plt.tight_layout()
plt.savefig(figures_folder + '/flow_loss.pdf')
plt.close(loss_figure)

## flow applied to FVI samples on original_x_train_add_g3 (its shift is reused later)
nnet_samples = opt_bnn.sample_functions(original_x_train_add_g3, num_sample=128).float().to(device)  # original_x_test
nnet_samples = nnet_samples.squeeze().detach()

fvi, nnbias_m, flogdet = flow(nnet_samples, original_x_train_add_g3)

## flow applied to FVI samples on the test inputs

nnet_samples = opt_bnn.sample_functions(original_x_test, num_sample=100).float().to(device)  # original_x_test
nnet_samples = nnet_samples.squeeze().detach()
nnet_mean = torch.mean(nnet_samples, 0).cpu().numpy()

fvi, nnbias_test, _ = flow(nnet_samples, original_x_test)
fvi = fvi.detach()
print('fvi.shape: ', fvi.shape)

mean_y_pred = torch.mean(fvi, 0).cpu().numpy()
std_y_pred = torch.std(fvi, 0).cpu().numpy()

with torch.no_grad():
    figure = plt.figure(figsize=(8, 5.5), facecolor='white')
    init_plotting()

    plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
    plt.plot(original_x_test.squeeze().cpu(), mean_y_pred, 'slateblue', label='Mean function',
             linewidth=2.5)  # chocolate
    # plt.plot(original_x_test.squeeze().cpu(), nnet_mean, 'y', label='FVI Mean function')
    for i in range(5):
        plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred - i * 1 * std_y_pred,
                         mean_y_pred - (i + 1) * 1 * std_y_pred, linewidth=0.0,
                         alpha=1.0 - i * 0.15, color='orchid')  # wheat
        plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred + i * 1 * std_y_pred,
                         mean_y_pred + (i + 1) * 1 * std_y_pred, linewidth=0.0,
                         alpha=1.0 - i * 0.15, color='orchid')
    plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')

    ######g3######
    plt.axvline(-0.75, linestyle='--', color='darkorchid')
    plt.axvline(-0.25, linestyle='--', color='darkorchid')
    plt.axvline(0.25, linestyle='--', color='darkorchid')
    plt.axvline(0.75, linestyle='--', color='darkorchid')
    plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
    plt.text(0.25, 1, 'Training data range', fontsize='small')

    plt.legend(loc='upper center', ncol=3, fontsize='small')
    plt.title('Flow approximation')
    plt.tight_layout()
    plt.ylim([dataset.y_min, dataset.y_max])
    plt.tight_layout()
    plt.savefig(figures_folder + '/flow approximation_GPLinear_test.pdf')
    plt.close(figure)


################################# FVI-MCMC LOOP #######################

## GP mean linear transfer ##
class Mean_transfer(nn.Module):
    """Affine map ``x -> a * x + b``.

    ``a`` and ``b`` are stored as attributes; one passed as an
    ``nn.Parameter`` is registered as a parameter of this module.
    """

    def __init__(self, a, b):
        super().__init__()
        self.a, self.b = a, b

    def forward(self, x):
        scaled = x * self.a
        return scaled + self.b

class Var_transfer(nn.Module):
    """Scale a variance by the square of ``a``: ``x -> a**2 * x``."""

    def __init__(self, a):
        super().__init__()
        self.a = a

    def forward(self, x):
        return x * self.a ** 2

### log normal mean transfer
class Log_Mean_transfer(nn.Module):
    """Map ``x -> log(|(x - b) / a| + 1e-7)``.

    The 1e-7 keeps the logarithm finite when ``x == b``.
    """

    def __init__(self, a, b):
        super().__init__()
        self.a, self.b = a, b

    def forward(self, x):
        return torch.log(torch.abs((x - self.b) / self.a) + 1e-7)


class Log_Var_transfer(nn.Module):
    """Element-wise inverse variance scaling ``x -> x / a**2``.

    This is the counterpart of ``Var_transfer`` for the log-space alternative;
    it is currently only referenced in commented-out code.

    Args:
        a: scale factor whose square divides the input (must be non-zero).
    """

    def __init__(self, a):
        super().__init__()
        self.a = a

    def forward(self, x):
        """Return ``x / a**2``."""
        denom = self.a ** 2
        return x / denom


# State carried across the outer FVI / MCMC loop below.
# Each list is seeded from the flow fitted earlier in the script;
# presumably `a` is that flow's weight (later entries come from `flow2.weight`).
flow_list = [flow]
a_list = [a]
nnbias_m_list = [nnbias_m]
nnbias_test_list = [nnbias_test]
cum_det_list = []
flow_name_list = []

# One affine mean transfer per (scale, bias) pair.
# `_m` pairs with the measurement-set biases and `_t` with the test-input biases.
# Each scale also gets a variance transfer.
# The log-space variants (Log_Mean_transfer / Log_Var_transfer) are drop-in
# alternatives for these three lists.
mean_transfer_list_m = [Mean_transfer(scale, shift) for scale, shift in zip(a_list, nnbias_m_list)]
var_transfer_list = [Var_transfer(scale) for scale in a_list]
mean_transfer_list_t = [Mean_transfer(scale, shift) for scale, shift in zip(a_list, nnbias_test_list)]

# Loss histories; one inner list is appended per outer-loop iteration.
all_train_losses = []
all_wdist_losses = []
all_likelihood_losses = []
all_m2_losses = []
##
for i in range(13):    # 13
    loop_count = 1500 * i
    print('loop_count: ', loop_count)
    outloop = i
    ################# S1: FVI #############################
    train_loss_all = []
    prior_distance_all = []
    wdist_all = []
    likelihood_loss_all = []
    dif_var_all = []
    opt_bnn_optimizer = torch.optim.Adam([{'params': opt_bnn.parameters(), 'lr': 0.1}])   #lr_bnn

    for epoch in range(1001):  # 1000 4001

        pred_y_vi2 = opt_bnn.forward(original_x_train)
        print('pred_y_vi2.shape: ', pred_y_vi2.shape)
        pred_y_vi2 = pred_y_vi2.t()

        for flow in flow_list:

            pred_y_vi2, _, _ = flow(pred_y_vi2, original_x_train)
            # print('pred_y_vi2: ', pred_y_vi2)

        print('pred_y_vi2.shape: ', pred_y_vi2.shape)
        pred_y_vi2 = pred_y_vi2.squeeze().flatten()

        #####
        measurement_set = original_x_train_add_g3
        gpss = prior_sample_functions(measurement_set, prior2, num_sample=128).detach().float().to(device)
        # gpss = prior_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)  # original_x_test
        gpss = gpss.squeeze()
        nnet_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)  # original_x_test
        nnet_samples = nnet_samples.squeeze()

        for flow in flow_list:

            nnet_samples, _, _ = flow(nnet_samples, original_x_train_add_g3)

        print('nnet_samples.shape: ', nnet_samples.shape)
        functional_wdist = lipf(gpss, nnet_samples)
        functional_wdist = torch.mean(functional_wdist)

        #####
        # measurement_set2 = sample_measurement_set2(X=original_x_train_add_g3, num_data=60)
        measurement_set2 = original_x_train_add_g3

        with torch.no_grad():
            gp_var = prior2(measurement_set2).variance
            # print('gp_var.shape: ', gp_var.shape)

        opt_samples = opt_bnn.sample_functions(measurement_set2, num_sample=200).float().to(device)
        opt_samples = opt_samples.squeeze()

        for flow in flow_list:

            opt_samples, _, _ = flow(opt_samples, original_x_train_add_g3)

        opt_var = torch.std(opt_samples, 0) ** 2
        # print('opt_var.shape: ', opt_var.shape)
        dif_var = (gp_var - opt_var).abs()
        dif_var = torch.sum(dif_var)
        # dif_var = torch.mean(dif_var)

        ######
        # if outloop < 12:
        #     train_loss_vi2 = 10 * loss(pred_y_vi2, original_y_train) + 10 * functional_wdist + dif_var
        #
        # else:
        #     train_loss_vi2 = 100 * loss(pred_y_vi2, original_y_train) + 100 * functional_wdist + 10 * dif_var

        train_loss_vi2 = 100 * loss(pred_y_vi2, original_y_train) + 100 * functional_wdist + 10 * dif_var    # 10_10_1
        likelihood_loss = loss(pred_y_vi2, original_y_train)
        print('likelihood_loss: ', likelihood_loss)
        print('functional_wdist: ', functional_wdist)

        # optimisation
        opt_bnn_optimizer.zero_grad()
        train_loss_vi2.backward()
        opt_bnn_optimizer.step()

        #
        train_loss_all.append(train_loss_vi2.item())
        likelihood_loss_all.append(likelihood_loss.item())
        wdist_all.append(functional_wdist.item())
        dif_var_all.append(dif_var.item())


        # print
        print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss_vi2),
              datetime.now().replace(microsecond=0) - start_time)

        if epoch % 500 == 0:  # 2000 500

            opt_bnn.eval()

            samples_pred_test_y = opt_bnn.forward_eval(original_x_test, num_sample).squeeze().detach()
            for flow in flow_list:

                samples_pred_test_y, _, _ = flow(samples_pred_test_y, original_x_test)

            samples_pred_test_y = samples_pred_test_y.detach()

            # mean_y_pred = pred_test_y
            mean_y_pred = torch.mean(samples_pred_test_y, 0).cpu().numpy()
            std_y_pred = torch.std(samples_pred_test_y, 0).cpu().numpy()
            test_loss = loss(samples_pred_test_y, original_y_test)

            opt_bnn.train()

            plt.clf()
            figure = plt.figure(figsize=(8, 5.5), facecolor='white')
            init_plotting()

            plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
            plt.plot(original_x_test.squeeze().cpu(), mean_y_pred, 'slateblue', label='Mean function',
                     linewidth=2.5)  # chocolate
            for i in range(5):
                plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred - i * 1 * std_y_pred,
                                 mean_y_pred - (i + 1) * 1 * std_y_pred, linewidth=0.0,
                                 alpha=1.0 - i * 0.15, color='orchid')  # wheat
                plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred + i * 1 * std_y_pred,
                                 mean_y_pred + (i + 1) * 1 * std_y_pred, linewidth=0.0,
                                 alpha=1.0 - i * 0.15, color='orchid')
            plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')
            # plt.plot(original_x_train_add_g3.cpu(), original_y_train_add_g3.cpu(), 'oy', zorder=10, ms=6, label='M points')
            # plt.grid(True)
            # plt.tick_params(axis='both', bottom='off', top='off', left='off', right='off',
            #                 labelbottom='off', labeltop='off', labelleft='off', labelright='off')
            ############sin#########
            # plt.axvline(-2, linestyle='--', color='darkorchid')   # darkorange
            # plt.axvline(2, linestyle='--', color='darkorchid')
            # plt.text(-1.5, -3, 'Training data range')
            #######gap
            # plt.axvline(-7.5, linestyle='--', color='darkorchid')
            # plt.axvline(-2.5, linestyle='--', color='darkorchid')
            # plt.axvline(5, linestyle='--', color='darkorchid')
            # plt.axvline(7.5, linestyle='--', color='darkorchid')
            # plt.text(-7.4, -1, 'Training data range', fontsize='small')
            # plt.text(5.1, -1, 'Training data', fontsize='xx-small')
            # plt.text(5.1, -1.8, 'range', fontsize='xx-small')

            ######g3######
            plt.axvline(-0.75, linestyle='--', color='darkorchid')
            plt.axvline(-0.25, linestyle='--', color='darkorchid')
            plt.axvline(0.25, linestyle='--', color='darkorchid')
            plt.axvline(0.75, linestyle='--', color='darkorchid')
            plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
            plt.text(0.25, 1, 'Training data range', fontsize='small')

            plt.legend(loc='upper center', ncol=3, fontsize='small')
            plt.title('FVIMC posterior')
            plt.tight_layout()
            plt.ylim([dataset.y_min, dataset.y_max])
            plt.tight_layout()
            print("--------------------------------------------------------------------------------------------")
            plt.savefig(figures_folder + '/plot_epoch{}.pdf'.format(epoch + 1502 + loop_count))   # + 6002
            print("figure saved")
            print("epoch : {} \t\t test loss \t\t : {}".format(epoch, test_loss))
            print("--------------------------------------------------------------------------------------------")

    all_train_losses.append(train_loss_all)
    all_wdist_losses.append(wdist_all)
    all_likelihood_losses.append(likelihood_loss_all)
    all_m2_losses.append(dif_var_all)

    ###########################
    plt.figure()
    indices = np.arange(1001)[::10]
    train_loss_all = np.array(train_loss_all)
    plt.plot(indices, train_loss_all[indices], '-ko', ms=3)
    plt.ylabel(r'Training loss')
    plt.tight_layout()
    plt.xlabel('Iteration')
    plt.tight_layout()
    plt.savefig(figures_folder + '/train_loss.pdf')

    plt.figure()
    indices = np.arange(1001)[::10]
    likelihood_loss_all = np.array(likelihood_loss_all)
    plt.plot(indices, likelihood_loss_all[indices], '-ko', ms=3)
    plt.ylabel(r'Likelihood loss')
    plt.tight_layout()
    plt.xlabel('Iteration')
    plt.tight_layout()
    plt.savefig(figures_folder + '/likelihood_loss.pdf')

    plt.figure()
    indices = np.arange(1001)[::10]
    wdist_all = np.array(wdist_all)
    plt.plot(indices, wdist_all[indices], '-ko', ms=3)
    # plt.ylabel(r'$W_1(p_{gp}, p_{prior bnn})$')
    plt.ylabel(r'$W_1(p_{f}, p_{g})$')
    plt.tight_layout()
    plt.xlabel('Iteration')
    plt.tight_layout()
    plt.savefig(figures_folder + '/wdist.pdf')


    ################## MCMC #################
    w_mu_list = []
    b_mu_list = []
    w_std_list = []
    b_std_list = []

    for name, p in opt_bnn.named_parameters():
        if "W_mu" in name:
            # w_prior_mu = p
            w_mu_list.append(p)
        if 'b_mu' in name:
            # b_prior_mu = p
            b_mu_list.append(p)
        if 'W_std' in name:
            # w_prior_std = p
            w_std_list.append(p)
        if 'b_std' in name:
            # b_prior_std = p
            b_std_list.append(p)

    bnnmc = BNNMC(input_dim, output_dim, hidden_dims, activation_fn, is_continuous,
                  W_mu=w_mu_list, b_mu=b_mu_list, W_std=w_std_list, b_std=b_std_list, scaled_variance=True).to(device)

    # bnnmc = BNNMC(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)

    bnnmc_optimizer = torch.optim.Adam(bnnmc.parameters(), lr=0.1)  # 5e-6  1 lr=0.001


    def noise_loss(lr):
        """Return the SGLD noise term accumulated over every ``bnnmc`` parameter.

        Each parameter tensor gets a Gaussian draw with mean 0 and standard
        deviation ``sqrt(2 / lr)``. The term ``sum(param * draw)`` is added up
        across parameters. Its gradient with respect to each parameter is the
        draw itself, so adding this term to the loss injects that noise into
        the optimizer step.

        Args:
            lr: step size used to scale the noise standard deviation.

        Returns:
            A scalar tensor, or ``0.0`` if ``bnnmc`` has no parameters.
        """
        total = 0.0
        scale = (2 / lr) ** 0.5
        for param in bnnmc.parameters():
            zero_mean = torch.zeros(param.size()).to(device)
            draw = torch.normal(zero_mean, std=scale).to(device)
            total += torch.sum(param * draw)
        return total


    ######### SGLD #####
    mt = 0
    for epoch in range(501):   # 500 2001
        pred_y_mc1 = bnnmc.forward(original_x_train)

        pred_y_mc1 = pred_y_mc1.t()

        for flow in flow_list:

            pred_y_mc1, _, _ = flow(pred_y_mc1, original_x_train)

        pred_y_mc1 = pred_y_mc1.squeeze().flatten()

        ######

        m_set = original_x_train_add_g3
        y_m_set = bnnmc.forward(m_set).t()

        for flow in flow_list:

            y_m_set, _, _ = flow(y_m_set, original_x_train_add_g3)

        y_m_set = y_m_set.squeeze().flatten()

        with torch.no_grad():
            fprior_m = likelihood(prior2(m_set))
            prior_mean_m = fprior_m.mean
            prior_var_m = fprior_m.variance

        d_logp_m = ((y_m_set - prior_mean_m) / prior_var_m).detach()
        # print('d_log_m:', d_logp_m)
        d_logp_m = y_m_set * d_logp_m
        d_logp_m = torch.mean(d_logp_m)

        ######

        t_like = pred_y_mc1 * ((pred_y_mc1 - original_y_train).detach())
        t_like = torch.mean(t_like)

        train_loss_mc1 = t_like + 1 * d_logp_m  # 1e-6  1e-2   1e-3
        # if i > 4:
        #     train_loss_mc1 = 10 ** (i-3) * train_loss_mc1
        # if outloop < 5:
        #     if epoch > 100:  # 1500  100
        #         noise = noise_loss(lr_bnn)
        #         train_loss_mc1 = train_loss_mc1 + 1e-3 * noise
        #
        # else:
        #     if epoch > 100:  # 1500  100
        #         noise = noise_loss(lr_bnn)
        #         train_loss_mc1 = train_loss_mc1 + 1e-5 * noise

        # train_loss = loss(pred_y, original_y_train) + 1e-6 * d_logp  # coeff: 1e-6 1e-3
        if epoch > 100:  # 1500  100
            noise = noise_loss(lr_bnn)
            train_loss_mc1 = train_loss_mc1 + 1e-5 * noise   #1e-5

        # optimisation
        bnnmc_optimizer.zero_grad()
        train_loss_mc1.backward()
        bnnmc_optimizer.step()
        # scheduler.step()

        # print
        print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss_mc1),
              datetime.now().replace(microsecond=0) - start_time)

        if epoch % 500 == 0:    # 2000   500

            bnnmc.eval()

            ##
            nnet_samples = bnnmc.sample_functions(original_x_test, num_sample=1).float().to(device)  # original_x_test
            nnet_samples = nnet_samples.squeeze().detach()

            for flow in flow_list:

                nnet_samples, _, _ = flow(nnet_samples, original_x_test)

            nnet_samples = nnet_samples.detach()
            print('fvi.shape: ', nnet_samples.shape)

            mean_y_pred = torch.mean(nnet_samples.unsqueeze(0), 0).cpu().numpy()
            print('mean_y_pred.shape: ', mean_y_pred.shape)
            std_y_pred = torch.std(nnet_samples.unsqueeze(0), 0).cpu().numpy()

            bnnmc.train()

            plt.clf()
            figure = plt.figure(figsize=(8, 5.5), facecolor='white')
            init_plotting()

            plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
            plt.plot(original_x_test.squeeze().cpu(), mean_y_pred, 'slateblue', label='Mean function',
                     linewidth=2.5)  # chocolate
            for i in range(5):
                plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred - i * 1 * std_y_pred,
                                 mean_y_pred - (i + 1) * 1 * std_y_pred, linewidth=0.0,
                                 alpha=1.0 - i * 0.15, color='orchid')  # wheat
                plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred + i * 1 * std_y_pred,
                                 mean_y_pred + (i + 1) * 1 * std_y_pred, linewidth=0.0,
                                 alpha=1.0 - i * 0.15, color='orchid')
            plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')
            # plt.grid(True)


            ######g3######
            plt.axvline(-0.75, linestyle='--', color='darkorchid')
            plt.axvline(-0.25, linestyle='--', color='darkorchid')
            plt.axvline(0.25, linestyle='--', color='darkorchid')
            plt.axvline(0.75, linestyle='--', color='darkorchid')
            plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
            plt.text(0.25, 1, 'Training data range', fontsize='small')

            plt.legend(loc='upper center', ncol=3, fontsize='small')
            plt.title('MCMC posterior')
            plt.tight_layout()
            plt.ylim([dataset.y_min, dataset.y_max])
            plt.tight_layout()
            print("--------------------------------------------------------------------------------------------")
            plt.savefig(figures_folder + '/plot_epoch{}.pdf'.format(epoch + 2503 + loop_count))    # +2503
            print("figure saved")
            print("epoch : {} \t\t test loss \t\t : {}".format(epoch, test_loss))
            print("--------------------------------------------------------------------------------------------")

        ##############
        if epoch > 0 and (epoch % 5) == 0:  # (0,20)  (0, 5)
            print('save!')
            bnnmc.cpu()
            # torch.save(bnn.state_dict(), save_path + '/g3_rbf_%i.pt'%(mt))
            dd = f'{mt}'
            print(dd)
            save_path = results_folder + '/' + bnn_name_string + '/' + 'fsgld_toy_samples1' + '/' + 'g3_' + dd  # sgld
            # if not os.path.exists(save_path):
            #     os.makedirs(save_path)
            bnnmc.save(save_path)
            mt += 1
            bnnmc.to(device)


    ###
    def testp():
        """Predict on ``original_x_test`` with ``bnnmc`` and push through every flow.

        Puts ``bnnmc`` into eval mode; it is not switched back to train mode
        here. Applies each flow in ``flow_list`` in order and returns the
        squeezed, detached output.
        """
        bnnmc.eval()

        out = bnnmc.forward(original_x_test).squeeze().detach()
        for transform in flow_list:
            out, _, _ = transform(out, original_x_test)

        return out.squeeze().detach()


    def test2():
        """Predict on the measurement set ``original_x_train_add_g3`` and push through every flow.

        Puts ``bnnmc`` into eval mode; it is not switched back to train mode
        here. Applies each flow in ``flow_list`` in order and returns the
        detached output. Unlike ``testp``, no final squeeze is applied.
        """
        bnnmc.eval()

        out = bnnmc.forward(original_x_train_add_g3).squeeze().detach()
        for transform in flow_list:
            out, _, _ = transform(out, original_x_train_add_g3)

        return out.detach()

    ################ load mcmc model ################
    pred_list = []
    predm_list = []
    num_model = 100  # 10 100 80 70

    for m in range(num_model):
        dd = f'{m}'
        bnnmc.load_state_dict(torch.load(
            results_folder + '/' + bnn_name_string + '/' + 'fsgld_toy_samples1' + '/' + 'g3_' + dd))  # sgld, sghmc
        pred = testp()
        print('pred.shape: ', pred.shape)
        pred_list.append(pred)

        predm = test2()
        print('predm.shape: ', predm.shape)
        predm_list.append(predm)

    ########################
    total_pred = torch.stack(pred_list)
    mset_fmc = torch.stack(predm_list)
    mean_y_pred = torch.mean(total_pred, 0).cpu().numpy()
    std_y_pred = torch.std(total_pred, 0).cpu().numpy()

    plt.clf()
    figure = plt.figure(figsize=(8, 5.5), facecolor='white')
    init_plotting()

    plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
    plt.plot(original_x_test.squeeze().cpu(), mean_y_pred, 'slateblue', label='Mean function',
             linewidth=2.5)  # chocolate
    for i in range(5):
        plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred - i * 1 * std_y_pred,
                         mean_y_pred - (i + 1) * 1 * std_y_pred, linewidth=0.0,
                         alpha=1.0 - i * 0.15, color='orchid')  # wheat
        plt.fill_between(original_x_test.squeeze().cpu(), mean_y_pred + i * 1 * std_y_pred,
                         mean_y_pred + (i + 1) * 1 * std_y_pred, linewidth=0.0,
                         alpha=1.0 - i * 0.15, color='orchid')
    plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')
    # plt.grid(True)
    # plt.tick_params(axis='both', bottom='off', top='off', left='off', right='off',
    #                 labelbottom='off', labeltop='off', labelleft='off', labelright='off')

    ######g3######
    plt.axvline(-0.75, linestyle='--', color='darkorchid')
    plt.axvline(-0.25, linestyle='--', color='darkorchid')
    plt.axvline(0.25, linestyle='--', color='darkorchid')
    plt.axvline(0.75, linestyle='--', color='darkorchid')
    plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
    plt.text(0.25, 1, 'Training data range', fontsize='small')

    # y_range = np.linspace(-2, 2, 500)
    # plt.fill_betweenx(y_range, -0.75, -0.25, color='slategray', alpha=0.3)
    # plt.fill_betweenx(y_range, 0.25, 0.75, color='slategray', alpha=0.3)

    plt.legend(loc='upper center', ncol=3, fontsize='small')
    plt.title('fsgld posterior')
    plt.tight_layout()
    plt.ylim([dataset.y_min, dataset.y_max])
    plt.tight_layout()
    print("--------------------------------------------------------------------------------------------")
    # plt.savefig(figures_folder + '/' + 'fsgld_posterior2.pdf')
    plt.savefig(figures_folder + '/fsgld_posterior_loop{}.pdf'.format(outloop+2))



    ############################# GP Linearization2 ###################################
    opt_bnn2 = opt_bnn
    opt_bnn_optimizer2 = torch.optim.Adam([{'params': opt_bnn2.parameters(), 'lr': lr_bnn}])

    ##
    w_mu, w_std = opt_bnn2.sample_mu_std()
    w_mu = w_mu.detach()
    w_std = w_std.detach()
    print('w_std.shape: ', w_std.shape)  # [10401]
    w_dim = w_std.shape[-1]
    print('w_dim: ', w_std.shape[-1])
    w_cov = w_std ** 2 * torch.eye(w_dim).to(device)
    print('w_cov.shape: ', w_cov.shape)  # [10401, 10401]

    ###
    grads_x_mset = None

    for x_i in original_x_train_add_g3:
        pred_y_i = opt_bnn2.forward_mu(x_i)
        pred_y_i = pred_y_i.squeeze().flatten()

        # evaluate gradient
        opt_bnn_optimizer2.zero_grad()
        pred_y_i.backward()

        grad_i = torch.cat([p.grad.view(-1) for name, p in opt_bnn2.named_parameters() if 'mu' in name])
        grad_i = grad_i[:, None]

        if grads_x_mset is None:
            grads_x_mset = torch.transpose(grad_i, 0, 1)
        else:
            grads_x_mset = torch.cat((grads_x_mset, torch.transpose(grad_i, 0, 1)), 0)

    print('grads_x_mset.shape: ', grads_x_mset.shape)  # [60, 10401]

    ##
    prior_mean_m = opt_bnn2.forward_mu(original_x_train_add_g3).squeeze().detach()

    # for mean_trans in mean_transfer_list_m:
    #
    #     prior_mean_m = mean_trans(prior_mean_m)
    #
    # print('prior_mean_m: ', prior_mean_m.shape)

    ##
    prior_var_m = grads_x_mset @ w_cov @ torch.transpose(grads_x_mset, 0, 1)
    prior_var_m = torch.diag(prior_var_m, 0)
    #
    # for var_trans in var_transfer_list:
    #
    #     prior_var_m = var_trans(prior_var_m)
    #
    # print('prior_var_m: ', prior_var_m.shape)
    #
    # ###
    grads_x_test = None

    for x_i in original_x_test:
        pred_y_i = opt_bnn2.forward_mu(x_i)
        pred_y_i = pred_y_i.squeeze().flatten()

        # evaluate gradient
        opt_bnn_optimizer2.zero_grad()
        pred_y_i.backward()

        grad_i = torch.cat([p.grad.view(-1) for name, p in opt_bnn2.named_parameters() if 'mu' in name])
        grad_i = grad_i[:, None]
        # print('grad_i.shape: ', grad_i.shape)

        if grads_x_test is None:
            grads_x_test = torch.transpose(grad_i, 0, 1)
        else:
            grads_x_test = torch.cat((grads_x_test, torch.transpose(grad_i, 0, 1)), 0)

    print('grads_x_test.shape: ', grads_x_test.shape)

    ##
    prior_mean_test = opt_bnn2.forward_mu(original_x_test).squeeze().detach()

    for mean_trans in mean_transfer_list_t:

        prior_mean_test = mean_trans(prior_mean_test)

    print('prior_mean_test: ', prior_mean_test.shape)

    ##
    prior_var_test = grads_x_test @ w_cov @ torch.transpose(grads_x_test, 0, 1)
    prior_var_test = torch.diag(prior_var_test, 0)

    for var_trans in var_transfer_list:
        prior_var_test = var_trans(prior_var_test)

    print('gp_linear_std: ', torch.mean(prior_var_test))

    with torch.no_grad():
        plt.clf()
        figure = plt.figure(figsize=(8, 5.5), facecolor='white')
        init_plotting()
        linear_gp = opt_bnn2.forward_mu(original_x_test).squeeze().cpu().numpy()
        print('p_gp_mean.shape: ', linear_gp.shape)
        f_var = prior_var_test
        f_std = f_var ** 0.5
        plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
        plt.plot(original_x_test.squeeze().cpu(), linear_gp, 'y', label='vi posterior', linewidth=2.5)
        plt.plot(original_x_test.squeeze().cpu(), prior_mean_test.cpu().numpy(), 'slateblue', label='GP posterior', linewidth=2.5)
        # for i in range(5):
        #     plt.fill_between(original_x_test.squeeze().cpu(), p_gp.mean.cpu().numpy() - i * 1 * p_gp.variance.cpu().numpy(),
        #                      p_gp.mean.cpu().numpy() - (i + 1) * 1 * p_gp.variance.cpu().numpy(), linewidth=0.0,
        #                      alpha=1.0 - i * 0.15, color='orchid')  # wheat
        #     plt.fill_between(original_x_test.squeeze().cpu(), p_gp.mean.cpu().numpy() + i * 1 * p_gp.variance.cpu().numpy(),
        #                      p_gp.mean.cpu().numpy() + (i + 1) * 1 * p_gp.variance.cpu().numpy(), linewidth=0.0,
        #                      alpha=1.0 - i * 0.15, color='orchid')
        # plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')
        #################
        for i in range(5):
            plt.fill_between(original_x_test.squeeze().cpu(), prior_mean_test.cpu().numpy() - i * 1 * f_std.cpu().numpy(),
                             prior_mean_test.cpu().numpy() - (i + 1) * 1 * f_std.cpu().numpy(), linewidth=0.0,
                             alpha=1.0 - i * 0.15, color='orchid')  # wheat
            plt.fill_between(original_x_test.squeeze().cpu(), prior_mean_test.cpu().numpy() + i * 1 * f_std.cpu().numpy(),
                             prior_mean_test.cpu().numpy() + (i + 1) * 1 * f_std.cpu().numpy(), linewidth=0.0,
                             alpha=1.0 - i * 0.15, color='orchid')
        plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')

        ######g3######
        plt.axvline(-0.75, linestyle='--', color='darkorchid')
        plt.axvline(-0.25, linestyle='--', color='darkorchid')
        plt.axvline(0.25, linestyle='--', color='darkorchid')
        plt.axvline(0.75, linestyle='--', color='darkorchid')
        plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
        plt.text(0.25, 1, 'Training data range', fontsize='small')

        plt.legend(loc='upper center', ncol=3, fontsize='small')
        plt.title('GP linear approximation')
        plt.tight_layout()
        plt.ylim([dataset.y_min, dataset.y_max])
        plt.tight_layout()
        plt.savefig(figures_folder + '/gp_linear2.pdf')



    ############################ train flow2 ###################
    # flownet3 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(
    #     device)
    # flownet4 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(
    #     device)
    #
    # # flow2 = LinearFlow_nn(flownet3).to(device)
    # flow2 = NonLinearFlow_nn(flownet3, flownet4).to(device)
    # # flow2 = SigmoidFlow_nn(flownet3).to(device)
    # # flow2 = ExpFlow_nn(flownet3).to(device)
    # # flow2 = CubeFlow_nn(flownet2).to(device)
    # flow_optimizer2 = torch.optim.Adam(flow2.parameters(), lr=0.01)
    #
    # # if outloop < 11:
    # #     flownet2 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(
    # #         device)
    # #
    # #     flow2 = LinearFlow_nn(flownet2).to(device)
    # #     # flow2 = PlanarFlow_nn(flownet2).to(device)
    # #     flow_optimizer2 = torch.optim.Adam(flow2.parameters(), lr=0.01)
    # #
    # # else:
    # #     flownet2 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(
    # #         device)
    # #
    # #     flow2 = CubeFlow_nn(flownet2).to(device)
    # #     # flow2 = PlanarFlow_nn(flownet2).to(device)
    # #     flow_optimizer2 = torch.optim.Adam(flow2.parameters(), lr=0.01)
    #
    #
    # ######
    # loss_flow_all = []
    #
    # for epoch in range(5000):
    #     Z, logdet = flow2.inverse(mset_fmc, original_x_train_add_g3)
    #     # print('Z.shape: ', Z.shape)
    #
    #     # loss_flow = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
    #     #     torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1) - logdet)
    #
    #     #
    #     # cum_det = torch.stack(logdet_list)
    #     # cum_det = torch.sum(cum_det)
    #     logdet_list = []
    #     flow_list_inv = flow_list[::-1]
    #
    #     for flow in flow_list_inv:
    #
    #         Z, logdetf = flow.inverse(Z, original_x_train_add_g3)
    #
    #         logdetf = torch.mean(logdetf)
    #         logdet_list.append(logdetf)
    #     # print('logdet_list: ', logdet_list)
    #     cum_det = torch.stack(logdet_list)
    #     cum_det = torch.sum(cum_det)
    #
    #     loss_flow = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
    #         torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1) - logdet) - cum_det
    #
    #     # print('sum_var_loss: ', torch.sum(torch.log(prior_var_m.abs())))
    #     # print('mean_loss: ', torch.mean(torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1)))
    #     # print('logdet: ', logdet)
    #     # print('cum_det: ', cum_det)
    #
    #     flow_optimizer2.zero_grad()
    #     loss_flow.backward(retain_graph=True)
    #     flow_optimizer2.step()
    #     loss_flow_all.append(loss_flow.item())
    #
    #     # print
    #     print("epoch : {} \t\t training loss \t\t : {}".format(epoch, loss_flow),
    #           datetime.now().replace(microsecond=0) - start_time)

    ############## choose different flow ############

    flownet3 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(
        device)
    flownet4 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(
    device)
    flownet5 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(
        device)
    # flownet6 = BNNF(input_dim, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(
    #     device)

    model_a = LinearFlow_nn(flownet3).to(device)
    model_b = NonLinearFlow_nn(flownet4, flownet5).to(device)
    # model_c = SigmoidFlow_nn(flownet6).to(device)
    # model_c = ExpFlow_nn(flownet3).to(device)
    model_a_optimizer = torch.optim.Adam(model_a.parameters(), lr=0.01)
    model_b_optimizer = torch.optim.Adam(model_b.parameters(), lr=0.01)
    # model_c_optimizer = torch.optim.Adam(model_c.parameters(), lr=0.01)

    ######
    loss_a_all = []

    for epoch in range(5000):
        Z, logdet = model_a.inverse(mset_fmc, original_x_train_add_g3)

        logdet_list = []
        flow_list_inv = flow_list[::-1]

        for flow in flow_list_inv:
            Z, logdetf = flow.inverse(Z, original_x_train_add_g3)
            logdetf = torch.mean(logdetf)
            logdet_list.append(logdetf)
        # print('logdet_list: ', logdet_list)
        cum_det = torch.stack(logdet_list)
        cum_det = torch.sum(cum_det)

        loss_a = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
            torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1) - logdet) - cum_det



        model_a_optimizer.zero_grad()
        loss_a.backward(retain_graph=True)
        model_a_optimizer.step()
        loss_a_all.append(loss_a.item())

        # print
        print("epoch : {} \t\t training loss \t\t : {}".format(epoch, loss_a),
              datetime.now().replace(microsecond=0) - start_time)

    #########
    loss_b_all = []
    for epoch in range(5000):
        Z, logdet = model_b.inverse(mset_fmc, original_x_train_add_g3)

        logdet_list = []
        flow_list_inv = flow_list[::-1]

        for flow in flow_list_inv:
            Z, logdetf = flow.inverse(Z, original_x_train_add_g3)
            logdetf = torch.mean(logdetf)
            logdet_list.append(logdetf)
        # print('logdet_list: ', logdet_list)
        cum_det = torch.stack(logdet_list)
        cum_det = torch.sum(cum_det)

        loss_b = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
            torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1) - logdet) - cum_det


        model_b_optimizer.zero_grad()
        loss_b.backward(retain_graph=True)
        model_b_optimizer.step()
        loss_b_all.append(loss_b.item())

        print
        print("epoch : {} \t\t training loss \t\t : {}".format(epoch, loss_b),
              datetime.now().replace(microsecond=0) - start_time)

     ##########
    # loss_c_all = []
    # for epoch in range(5000):
    #     Z, logdet = model_c.inverse(mset_fmc, original_x_train_add_g3)
    #
    #     logdet_list = []
    #     flow_list_inv = flow_list[::-1]
    #
    #     for flow in flow_list_inv:
    #         Z, logdetf = flow.inverse(Z, original_x_train_add_g3)
    #         logdetf = torch.mean(logdetf)
    #         logdet_list.append(logdetf)
    #     # print('logdet_list: ', logdet_list)
    #     cum_det = torch.stack(logdet_list)
    #     cum_det = torch.sum(cum_det)
    #
    #     loss_c = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
    #         torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1) - logdet) - cum_det
    #
    #
    #     model_c_optimizer.zero_grad()
    #     loss_c.backward(retain_graph=True)
    #     model_c_optimizer.step()
    #     loss_c_all.append(loss_c.item())
    #
    #     # print
    #     print("epoch : {} \t\t training loss \t\t : {}".format(epoch, loss_c),
    #           datetime.now().replace(microsecond=0) - start_time)

    ######
    # Select whichever candidate flow ended training with the lower loss.
    # min() returns its first argument when neither is smaller, so ties go to
    # model A. Candidate C (commented out above) would need its own branch.
    min_loss = min(loss_a, loss_b)   # loss_c
    if min_loss == loss_a:
        flow2 = model_a
        loss_flow_all = loss_a_all
        print('model A is the best model')
        flow_name = 'Linear'
    # elif min_loss == loss_b:
    #     flow2 = model_b
    #     loss_flow_all = loss_b_all
    #     print('model B is the best model')
    #     flow_name = 'Tanh'
    else:
        flow2 = model_b
        loss_flow_all = loss_b_all
        print('model B is the best model')
        flow_name = 'Tanh'   # sigmoid

    # Final training loss of the selected flow. Previously `loss_flow` was not
    # updated inside this loop, so flow_loss_list recorded a stale value.
    # Detached so the stored scalar does not keep the training graph alive.
    loss_flow = min_loss.detach()
    flow_name_list.append(flow_name)

    ############################
    flow2.eval()

    print('flow2.weight: ', flow2.weight)
    a2 = flow2.weight

    # flow_list.append(flow2)
    a_list.append(a2)

    # Re-evaluate the *selected* flow (followed by the flows already in
    # flow_list, last-appended first) on the measurement points. Previously
    # Z / logdet / cum_det were whatever the model-B training loop left
    # behind, so the recorded statistics described model B even when model A
    # was selected.
    Z, logdet = flow2.inverse(mset_fmc, original_x_train_add_g3)
    logdet_list = []
    for flow in flow_list[::-1]:
        Z, logdetf = flow.inverse(Z, original_x_train_add_g3)
        logdetf = torch.mean(logdetf)
        logdet_list.append(logdetf)
    cum_det = torch.sum(torch.stack(logdet_list))

    # Same expression as the training objective, without the log-det terms.
    flow_likelihood = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
            torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1))
    flow_likelihood_list.append(flow_likelihood.detach())
    flow_det_list.append(logdet)
    cum_det_list.append(cum_det.detach())
    flow_loss_list.append(loss_flow)


    ####
    # Plot every 50th epoch of the selected flow's training loss.
    # NOTE(review): the file name is the same on every outer iteration, so only
    # the last loop's curve survives on disk -- confirm that is intended.
    fig_flow_loss = plt.figure()
    loss_flow_all = np.array(loss_flow_all)
    indices = np.arange(len(loss_flow_all))[::50]
    plt.plot(indices, loss_flow_all[indices], '-ko', ms=3)
    plt.ylabel(r'flow loss')
    plt.xlabel('Iteration')
    plt.tight_layout()
    plt.savefig(figures_folder + '/flow_loss2.pdf')
    # Release the figure; otherwise pyplot keeps one open per outer iteration.
    plt.close(fig_flow_loss)


    ### append mean_transfer_list_m
    # nnet_samples = opt_bnn.sample_functions(original_x_train_add_g3, num_sample=128).float().to(device)  # original_x_test
    # nnet_samples = nnet_samples.squeeze().detach()
    #
    # for flow in flow_list:
    #
    #     nnet_samples, _, _ = flow(nnet_samples, original_x_train_add_g3)
    #
    # fvi, nnbias_m_new, logdet_new = flow2(nnet_samples, original_x_train_add_g3)
    # nnbias_m_list.append(nnbias_m_new)
    # # logdet_list.append(logdet_new)
    #
    # new_mean_trans_m = Mean_transfer(a2, nnbias_m_new)
    # mean_transfer_list_m.append(new_mean_trans_m)
    #
    #
    # ### append mean_transfer_list_t
    # nnet_samples = opt_bnn.sample_functions(original_x_test, num_sample=100).float().to(device)  # original_x_test
    # nnet_samples = nnet_samples.squeeze().detach()
    #
    # for flow in flow_list:
    #
    #     nnet_samples, _, _ = flow(nnet_samples, original_x_test)
    #
    # fvi, nnbias_test_new, _ = flow2(nnet_samples, original_x_test)
    # nnbias_test_list.append(nnbias_test_new)
    #
    # new_mean_trans_t = Mean_transfer(a2, nnbias_test_new)
    # mean_transfer_list_t.append(new_mean_trans_t)
    #
    # ### append var_transfer_list
    # new_var_trans = Var_transfer(a2)
    # var_transfer_list.append(new_var_trans)

    # The selected flow becomes the newest entry of the flow chain.
    flow_list.append(flow2)

    ###########
    # Draw 100 function samples from opt_bnn on the test grid and push them
    # through every flow in flow_list, in list order.
    raw_samples = opt_bnn.sample_functions(original_x_test, num_sample=100)
    nnet_samples = raw_samples.float().to(device).squeeze().detach()
    for flow in flow_list:
        nnet_samples, _, _ = flow(nnet_samples, original_x_test)
        # print('nnet_samples: ', nnet_samples)
    nnet_samples = nnet_samples.detach()
    print('fvi.shape: ', nnet_samples.shape)

    # Reduce over dim 0 (presumably the sample axis) into numpy mean / std.
    mean_y_pred = nnet_samples.mean(0).cpu().numpy()
    std_y_pred = nnet_samples.std(0).cpu().numpy()

    # Plot the flow-transformed predictive mean with shaded +/- k*std bands
    # (k = 1..5, fading outward), the true function and the training
    # observations, and save one PDF per outer-loop iteration.
    # no_grad keeps the tensors handed to matplotlib free of grad tracking.
    with torch.no_grad():
        figure = plt.figure(figsize=(8, 5.5), facecolor='white')
        init_plotting()

        x_test_plot = original_x_test.squeeze().cpu()
        # plt.plot(original_x_train_add_g3.squeeze().cpu(), original_y_train_add_g3.cpu(), 'g', label="True function")
        plt.plot(x_test_plot, original_y_test.cpu(), 'g', label="True function")
        plt.plot(x_test_plot, mean_y_pred, 'slateblue', label='Mean function',
                 linewidth=2.5)  # chocolate
        for i in range(5):
            band_alpha = 1.0 - i * 0.15
            # Band between i and i+1 standard deviations, below and above the mean.
            plt.fill_between(x_test_plot, mean_y_pred - i * std_y_pred,
                             mean_y_pred - (i + 1) * std_y_pred, linewidth=0.0,
                             alpha=band_alpha, color='orchid')  # wheat
            plt.fill_between(x_test_plot, mean_y_pred + i * std_y_pred,
                             mean_y_pred + (i + 1) * std_y_pred, linewidth=0.0,
                             alpha=band_alpha, color='orchid')
        plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6,
                 label='Observations')

        ######g3######
        # Dashed lines at the edges of the two labelled training-data ranges.
        for x_edge in (-0.75, -0.25, 0.25, 0.75):
            plt.axvline(x_edge, linestyle='--', color='darkorchid')
        plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
        plt.text(0.25, 1, 'Training data range', fontsize='small')

        plt.legend(loc='upper center', ncol=3, fontsize='small')
        plt.title('Flow approximation')
        plt.ylim([dataset.y_min, dataset.y_max])
        plt.tight_layout()
        plt.savefig(figures_folder + '/Flow approximation_loop{}.pdf'.format(outloop + 2))
    # This runs once per outer iteration; close the figure so pyplot does not
    # accumulate them.
    plt.close(figure)

# Dump the per-loop records collected above, one labelled line per record.
for caption, recorded in (
    ('a_list: ', a_list),
    ('flow_likelihood: ', flow_likelihood_list),
    # ('flow_det: ', flow_det_list),
    ('cum_det_list: ', cum_det_list),
    ('flow_loss: ', flow_loss_list),
    ('flow_name: ', flow_name_list),
):
    print(caption, recorded)



###########################
# Per-flow diagnostic curves recorded earlier in the script: one PDF per flow
# and per metric, written into loss_folder.


def _save_loss_curves(curves, ylabel, file_stem, x_idx, folder):
    """Save one line plot per recorded curve.

    Args:
        curves: sequence of per-iteration records, one per flow.
        ylabel: y-axis label used for every plot.
        file_stem: output file-name prefix; the 1-based flow number and
            ``.pdf`` are appended.
        x_idx: integer iteration indices to plot; every record must be at
            least ``x_idx[-1] + 1`` entries long.
        folder: directory the PDFs are written to.
    """
    for k, curve in enumerate(curves):
        fig = plt.figure()
        values = np.array(curve)
        plt.plot(x_idx, values[x_idx], '-ko', ms=3, label=f'flow {k + 1}')
        plt.ylabel(ylabel)
        plt.xlabel('Iteration')
        plt.tight_layout()
        plt.savefig(folder + '/{}{}.pdf'.format(file_stem, k + 1))
        # Close each figure; otherwise every plot stays open in pyplot.
        plt.close(fig)


# Every 10th of the first 1001 recorded iterations.
indices = np.arange(1001)[::10]
# Iterate over however many flows were actually recorded instead of a
# hard-coded 13 (which raised IndexError with fewer and skipped any extras).
_save_loss_curves(all_train_losses, r'Training loss', 'train_loss', indices, loss_folder)
_save_loss_curves(all_wdist_losses, r'wdist', 'wdist', indices, loss_folder)
_save_loss_curves(all_likelihood_losses, r'likelihood_loss', 'likelihood_loss', indices, loss_folder)
_save_loss_curves(all_m2_losses, r'dif_var', 'dif_var', indices, loss_folder)






# print total training time
# NOTE(review): datetime.now() returns naive local time, so the "(GMT)"
# labels are only accurate when the host clock runs on GMT/UTC.
banner = "============================================================================================"
print(banner)
end_time = datetime.now().replace(microsecond=0)
elapsed = end_time - start_time
print("Started training at (GMT) : ", start_time)
print("Finished training at (GMT) : ", end_time)
print("Total training time  : ", elapsed)
print(banner)

