import os
from datetime import datetime

import torch
import torch.nn as nn
import numpy as np
import gpytorch

from data import x3_gap_toy, sin_toy
from toy import polynomial_toy, polynomial_gap_toy, g2_toy, g3_toy, isotropy_toy
from bnn import BNN, OPTBNN
from gp_prior import ExactGPModel, prior_sample_functions
from wasserstein_prior_match2 import MapperWasserstein, WassersteinDistance
import copy

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt

from utils.utils import default_plotting_new as init_plotting


# ------------------------------- Hyper-parameters -------------------------------

# Posterior-BNN optimisation and the weight of the parameter-space prior distance term.
lr_bnn = 0.01
prior_coeff = 1  # gap:2, sin:10
bnn_name_string = 'FWBI'  # tag used to build the log / checkpoint / figure paths
random_seed = 123

# max_epoch_num is only printed; the training loop runs for `epochs` iterations.
max_epoch_num = 100

# Bookkeeping intervals, measured in epochs.
save_model_freq = 5000
log_model_freq = 3000
test_interval = 2000

num_functions = 10  # not referenced elsewhere in the visible script
num_sample = 1000  # predictive samples drawn by bnn.forward_eval at test time
n_step_prior_pretraining = 100  # Adam steps used to fit the GP prior
# bnn_prior_step = 10000
epochs = 10001
f_coeff = 10  # weight of the functional distance term in the objective
lr_optbnn = 1e-2  # learning rate of opt_bnn; previously tried: 0.0012

# Seed both RNG sources so initialisation and sampling are reproducible.
torch.manual_seed(random_seed)
np.random.seed(random_seed)


# ------------------------------- Network Architecture -------------------------------

n_units = 100  # gap:100, sin:500
n_hidden = 2
hidden_dims = [n_units for _ in range(n_hidden)]
activation_fn = 'tanh'


print("============================================================================================")
################################## set device ##################################

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()  # free unused cached GPU memory before training starts
    _device_label = str(torch.cuda.get_device_name(device))
else:
    device = torch.device('cpu')
    _device_label = "cpu"
print("Device set to : " + _device_label)

print("============================================================================================")
###################### logging ######################

#### log files for multiple runs are NOT overwritten: the run number is the count of
#### files already present in the log directory, and it becomes part of the file name.

log_dir = "./logs"
os.makedirs(log_dir, exist_ok=True)

log_dir = log_dir + '/' + bnn_name_string + '/' + 'iso_1_1e-2_10_rbf'
os.makedirs(log_dir, exist_ok=True)

#### get number of log files in log directory
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)

#### create new log file for each run
log_f_name = log_dir + '/' + bnn_name_string + 'iso_1_1e-2_10_rbf' + "_" + str(run_num) + ".csv"

print("current logging run number for " + bnn_name_string + " : ", run_num)
print("logging at : " + log_f_name)


print("============================================================================================")
################### checkpointing ###################

# Checkpoints are NOT versioned by run: they use this fixed index, so a new run
# overwrites the previous checkpoint unless this value is changed.
run_num_pretrained = 0

directory = "./pretrained"
os.makedirs(directory, exist_ok=True)

directory = directory + '/' + bnn_name_string + '/' + 'iso_1_1e-2_10_rbf'
os.makedirs(directory, exist_ok=True)

# The '/' separator is required; without it the checkpoint file would be created next to
# `directory` (with the run tag glued onto its name) instead of inside it.
checkpoint_path = directory + '/' + "{}_{}_{}.pth".format(bnn_name_string, random_seed, run_num_pretrained)
print("save checkpoint path : " + checkpoint_path)

################### savefigures ###################

results_folder = "./results"
os.makedirs(results_folder, exist_ok=True)

figures_folder = results_folder + '/' + bnn_name_string + '/' + 'iso_1_1e-2_10_rbf'
os.makedirs(figures_folder, exist_ok=True)


print("============================================================================================")
############# print all hyperparameters #############

print("learning rate: ", lr_bnn)
print("coefficient of prior regularization: ", prior_coeff)
print("random seed: ", random_seed)
# The training loop iterates `epochs` times (max_epoch_num is never used by it),
# so report the value that actually governs training.
print("max number of epoches: ", epochs)
print("learning rate (opt_bnn): ", lr_optbnn)
print("coefficient of functional regularization: ", f_coeff)


print("============================================================================================")
############################## load and normalize data ##############################

# Registry of toy problems; this run uses the 'iso' dataset.
_toy_datasets = dict(x3=x3_gap_toy, sin=sin_toy, g2=g2_toy, g3=g3_toy, gap=polynomial_gap_toy, iso=isotropy_toy)
dataset = _toy_datasets['iso']()

# Standardisation statistics come from the training split only.
original_x_train, original_y_train = dataset.train_samples()
mean_x, std_x = np.mean(original_x_train), np.std(original_x_train)
mean_y, std_y = np.mean(original_y_train), np.std(original_y_train)
train_x, train_y = (original_x_train - mean_x) / std_x, (original_y_train - mean_y) / std_y
original_x_test, original_y_test = dataset.test_samples()
test_x, test_y = (original_x_test - mean_x) / std_x, (original_y_test - mean_y) / std_y

# dataset.y_std expressed (as a log) in standardised-y units.
y_logstd = np.log(dataset.y_std / std_y)

# Dataset input bounds mapped into standardised-x units.
lower_ap = (dataset.x_min - mean_x) / std_x
upper_ap = (dataset.x_max - mean_x) / std_x

training_num, input_dim = np.shape(train_x)
# The models are built for scalar inputs/outputs, overriding the shape-derived input_dim.
input_dim = 1
output_dim = 1
is_continuous = True

# The rest of the script works on the raw (un-standardised) data as float tensors.
original_x_train, original_y_train, original_x_test, original_y_test = [
    torch.from_numpy(arr).float().to(device)
    for arr in (original_x_train, original_y_train, original_x_test, original_y_test)
]


def _with_extra_points(low, high, count, target_fn):
    """Draw `count` uniform inputs from np.random, label them with `target_fn`, and append them
    to the raw training set.

    Returns ``(extra_x, extra_y, x_train_plus_extra, y_train_plus_extra)`` as float tensors on
    `device`; extra_y is squeezed before concatenation.
    """
    extra_x_np = np.random.uniform(low, high, (count, 1))
    extra_y_np = target_fn(extra_x_np)
    extra_x = torch.from_numpy(extra_x_np).float().to(device)
    extra_y = torch.squeeze(torch.from_numpy(extra_y_np).float().to(device)).to(device)
    x_all = torch.cat((original_x_train, extra_x), 0).to(device)
    y_all = torch.cat((original_y_train, extra_y), 0).to(device)
    return extra_x, extra_y, x_all, y_all


# Candidate measurement-set pools, one per toy problem. The calls stay in this order
# because each draw advances numpy's global RNG. 3.14 approximates pi in the targets.
add_set, add_set_y, original_x_train_add, original_y_train_add = _with_extra_points(
    -5, 5, 40, lambda x: 2 * np.sin(4 * x))

add_set_gap, add_set_y_gap, original_x_train_add_gap, original_y_train_add_gap = _with_extra_points(
    -10, 10, 40, lambda x: np.sin(x) + 0.1 * x ** 2)

add_set_g2, add_set_y_g2, original_x_train_add_g2, original_y_train_add_g2 = _with_extra_points(
    -10, 10, 30, lambda x: np.sin(x) + 0.1 * x)

add_set_g3, add_set_y_g3, original_x_train_add_g3, original_y_train_add_g3 = _with_extra_points(
    -1, 1, 40,
    lambda x: np.sin(3 * 3.14 * x) + 0.3 * np.cos(9 * 3.14 * x) + 0.5 * np.sin(7 * 3.14 * x))

add_set_iso, add_set_y_iso, original_x_train_add_iso, original_y_train_add_iso = _with_extra_points(
    -0.5, 1.2, 40,
    lambda x: x + 0.3 * np.sin(2 * 3.14 * x) + 0.3 * np.sin(4 * 3.14 * x))


# Mean-squared error shared by the likelihood term and the test loss.
loss = nn.MSELoss(reduction='mean')

print("============================================================================================")
################################## define model ##################################

# Both networks share one architecture. bnn is still constructed first, so any random
# parameter initialisation draws from the RNG in the same order.
_bnn_config = (input_dim, output_dim, hidden_dims, activation_fn, is_continuous)

# Posterior BNN, optimised with Adam.
bnn = BNN(*_bnn_config, scaled_variance=True).to(device)
bnn_optimizer = torch.optim.Adam([{'params': bnn.parameters(), 'lr': lr_bnn}])

# Bridging BNN (plotted as the "Bridging distribution"), optimised with RMSprop.
opt_bnn = BNN(*_bnn_config, scaled_variance=True).to(device)
opt_bnn_optimizer = torch.optim.RMSprop([{'params': opt_bnn.parameters(), 'lr': lr_optbnn}])

print("============================================================================================")
################################## prior pre-training ##################################

# Fit the GP prior's hyper-parameters by minimising the negative exact marginal log likelihood.
# NOTE(review): the GP is conditioned on and fitted to the *test* inputs/targets
# (original_x_test / original_y_test) rather than the training split — confirm this is intended.
likelihood = gpytorch.likelihoods.GaussianLikelihood()
prior = ExactGPModel(original_x_test, original_y_test, likelihood, input_dim).to(device)

for _gp_part in (prior, likelihood):
    _gp_part.train()

# Adam over every GP parameter; this includes the GaussianLikelihood parameters.
optimizer = torch.optim.Adam(prior.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior)

for i in range(n_step_prior_pretraining):
    optimizer.zero_grad()
    output = prior(original_x_test)
    loss_gp = -mll(output, original_y_test)
    loss_gp.backward()
    print('Iter {:d}/{:d} - Loss: {:.3f}     noise: {:.3f}'.format(
        i + 1, n_step_prior_pretraining, loss_gp.item(), prior.likelihood.noise.item()))
    optimizer.step()

# From here on the GP is only sampled (prior_sample_functions), so switch to eval mode.
for _gp_part in (prior, likelihood):
    _gp_part.eval()
################################## bnn prior pre-train ##################################

# prior_bnn = BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
# prior_bnn_optimizer = torch.optim.RMSprop([{'params': prior_bnn.parameters(), 'lr': lr_optbnn}])
#
# # loss
# for i in range(10000):
#     pred_prior = prior_bnn.forward(original_x_test)
#     # L1regu = regu_coeff / 20 * L1regu
#     # pred_y = bnn.forward(original_x_train)
#     # y_predicted, distance_prior = bnn.forward_w(train_x)
#
#     pred_prior = pred_prior.squeeze().flatten()
#     # print(pred_y)
#     # print(train_y_tmp)
#     train_loss_prior = loss(pred_prior, original_y_test)
# # optimisation
#     prior_bnn_optimizer.zero_grad()
#     train_loss_prior.backward()
#     prior_bnn_optimizer.step()
# # print
#     print("epoch : {} \t\t training loss \t\t : {}".format(10000, train_loss_prior),
#           )
#
# with torch.no_grad():
#
#     plt.clf()
#     figure = plt.figure(figsize=(8, 5.5), facecolor='white')
#     init_plotting()
#
#     # measurement_set = sample_measurement_set(X=original_x_train_add, num_data=60)
#
#
#     bnnprior_samples = prior_bnn.sample_functions(original_x_test, num_sample=128).float().to(device)
#     # nnet_samples = nnet_samples.squeeze().t()
#     bnnprior_samples = bnnprior_samples.squeeze()
#
#     nns = torch.mean(bnnprior_samples, 0)
#     # nns = nnet_samples[:, :5]
#     plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
#     # plt.plot(original_x_test.squeeze().cpu(), gpss.cpu().numpy(), 'slateblue', label='GP prior', linewidth=2.5)
#     plt.plot(original_x_test.squeeze().cpu(), nns.cpu().numpy(), 'darkorange', label='Prior BNN distribution', linewidth=2.5)
#     # plt.scatter(measurement_set.squeeze().cpu(), gpss.cpu().numpy())
#     # plt.scatter(measurement_set.squeeze().cpu(), nns.cpu().numpy(), cmap='r')
#
#     plt.legend(loc='upper center', ncol=2, fontsize='small')
#     plt.title('Prior BNN')
#     plt.tight_layout()
#     plt.ylim([dataset.y_min, dataset.y_max])
#     plt.tight_layout()
#     plt.savefig(figures_folder + '/Prior BNN.pdf')





print("============================================================================================")
################################## start training ##################################

# Wall-clock reference for the elapsed-time printouts, with sub-second precision dropped.
# NOTE(review): datetime.now() returns naive *local* time, so the "(GMT)" label is only
# accurate when the machine clock is set to GMT/UTC.
start_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)

# CSV training log: the header row is written now; rows are appended inside the training loop.
log_f = open(log_f_name, "w+")
log_f.write('epoch,training_loss\n')

# Multiply the posterior BNN's learning rate by 0.9 every 5000 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    bnn_optimizer,
    step_size=5000,
    gamma=0.9,
    last_epoch=-1,
)

##########################################    check opt_bnn  #######

def sample_measurement_set(X, num_data, n=40):
    """Draw a random measurement set from the first `num_data` rows of `X`.

    Args:
        X: tensor indexed as ``X[idx, :]``; each row is a candidate measurement point.
        num_data: number of leading rows of `X` eligible for sampling.
        n: size of the measurement set (default 40, as used for sin; gap used 20).

    Returns:
        ``min(n, num_data)`` distinct rows of `X`, in random order.

    Raises:
        ValueError: if `num_data` exceeds the number of rows of `X`. Without this check,
            indexing would fail only on the random draws that happen to pick a missing row.
    """
    num_data = int(num_data)
    if num_data > X.shape[0]:
        raise ValueError(
            "num_data={} exceeds the {} rows available in X".format(num_data, X.shape[0]))
    # The prefix of a random permutation gives distinct indices (sampling without replacement).
    perm = torch.randperm(num_data)
    idx = perm[:int(n)]
    measurement_set = X[idx, :]
    return measurement_set

def lipf(X, Y):
    """Return the element-wise absolute difference ``|X - Y|``, with the usual tensor broadcasting."""
    difference = X - Y
    return torch.abs(difference)

####################################################
# Per-epoch histories, plotted once training has finished.
train_loss_all = []
prior_distance_all = []
wdist_all = []
likelihood_loss_all = []

# The test inputs never change, so convert them for matplotlib once.
x_test_plot = original_x_test.squeeze().cpu()

try:
    for epoch in range(epochs):

        # ---- functional term: GP prior samples vs. bridging-BNN (opt_bnn) samples ----
        # NOTE(review): num_data=60 is hard-coded; it must not exceed the rows of
        # original_x_train_add_iso (training points + 40 extra iso points) — confirm.
        measurement_set = sample_measurement_set(X=original_x_train_add_iso, num_data=60)  # sin num_data=60, gap num_data=100
        # GP samples are detached, so this term only sends gradients into opt_bnn.
        gpss = prior_sample_functions(measurement_set, prior, num_sample=128).detach().float().to(device)
        gpss = gpss.squeeze()
        nnet_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)
        nnet_samples = nnet_samples.squeeze()

        # Mean absolute difference between index-paired samples.
        # NOTE(review): logged as "w_dist", but the samples are paired by index with no optimal
        # coupling, so this is not an exact Wasserstein-1 estimate — confirm it is the intended estimator.
        functional_wdist = torch.mean(lipf(gpss, nnet_samples))

        # ---- parameter-space term: per-layer distance between bnn and opt_bnn parameters ----
        # Collect opt_bnn's mean/std parameters in registration order.
        opt_bnn_params = list(opt_bnn.named_parameters())
        w_prior_mu_list = [p for name, p in opt_bnn_params if "W_mu" in name]
        b_prior_mu_list = [p for name, p in opt_bnn_params if 'b_mu' in name]
        w_prior_std_list = [p for name, p in opt_bnn_params if 'W_std' in name]
        b_prior_std_list = [p for name, p in opt_bnn_params if 'b_std' in name]

        # NOTE(review): assumes the BNN has exactly these three layers (input / mid / output),
        # in the same order as named_parameters() yields them — confirm if n_hidden changes.
        input_layer_wd = bnn.input_layer.para_wd(w_prior_mu_list[0], b_prior_mu_list[0], w_prior_std_list[0],
                                                 b_prior_std_list[0])
        mid_layer_wd = bnn.mid_layer.para_wd(w_prior_mu_list[1], b_prior_mu_list[1], w_prior_std_list[1],
                                             b_prior_std_list[1])
        output_layer_wd = bnn.output_layer.para_wd(w_prior_mu_list[2], b_prior_mu_list[2], w_prior_std_list[2],
                                                   b_prior_std_list[2])

        bnn_2wassdist = (input_layer_wd + mid_layer_wd + output_layer_wd) / 3

        # ---- data-fit term on the raw training set ----
        pred_y = bnn.forward(original_x_train)
        pred_y = pred_y.squeeze().flatten()

        # Computed once and reused, so the logged likelihood is exactly the term in the objective.
        likelihood_loss = loss(pred_y, original_y_train)
        train_loss = likelihood_loss + prior_coeff * bnn_2wassdist + f_coeff * functional_wdist
        print('likelihood_loss: ', likelihood_loss)
        print('distance_prior: ', bnn_2wassdist)
        print('w_dist: ', functional_wdist)

        # ---- joint optimisation step for bnn (with LR schedule) and opt_bnn ----
        bnn_optimizer.zero_grad()
        opt_bnn_optimizer.zero_grad()
        train_loss.backward()
        bnn_optimizer.step()
        scheduler.step()
        opt_bnn_optimizer.step()

        train_loss_all.append(train_loss.item())
        likelihood_loss_all.append(likelihood_loss.item())
        prior_distance_all.append(bnn_2wassdist.item())
        wdist_all.append(functional_wdist.item())

        # print
        print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss),
              datetime.now().replace(microsecond=0) - start_time)

        # Log a plain float: the tensor repr ("tensor(..., grad_fn=<...>)") contains commas
        # and would corrupt the CSV row.
        if epoch % log_model_freq == 0:
            print("--------------------------------------------------------------------------------------------")
            log_f.write('{},{}\n'.format(epoch, train_loss.item()))
            log_f.flush()
            print("log saved")
            print("--------------------------------------------------------------------------------------------")

        # save model weights
        if epoch % save_model_freq == 0:
            print("--------------------------------------------------------------------------------------------")
            print("saving model at : " + checkpoint_path)
            bnn.save(checkpoint_path)
            print("model saved")
            print("--------------------------------------------------------------------------------------------")

        if epoch % test_interval == 0:

            bnn.eval()
            # Evaluation only: no autograd graph is needed for the predictive samples.
            with torch.no_grad():
                samples_pred_test_y = bnn.forward_eval(original_x_test, num_sample).squeeze().detach()
                mean_y_pred = torch.mean(samples_pred_test_y, 0).cpu().numpy()
                std_y_pred = torch.std(samples_pred_test_y, 0).cpu().numpy()
                # MSE of every predictive sample against the true function. Expanding the target
                # explicitly gives the same value as implicit broadcasting, without MSELoss's
                # size-mismatch warning.
                test_loss = loss(samples_pred_test_y, original_y_test.expand_as(samples_pred_test_y))
            bnn.train()

            figure = plt.figure(figsize=(8, 5.5), facecolor='white')
            init_plotting()

            plt.plot(x_test_plot, original_y_test.cpu(), 'g', label="True function")
            plt.plot(x_test_plot, mean_y_pred, 'slateblue', label='Mean function', linewidth=2.5)

            # Five nested +/- std bands around the mean, fading with distance from it.
            for i in range(5):
                plt.fill_between(x_test_plot, mean_y_pred - i * std_y_pred,
                                 mean_y_pred - (i + 1) * std_y_pred, linewidth=0.0,
                                 alpha=1.0 - i * 0.15, color='orchid')
                plt.fill_between(x_test_plot, mean_y_pred + i * std_y_pred,
                                 mean_y_pred + (i + 1) * std_y_pred, linewidth=0.0,
                                 alpha=1.0 - i * 0.15, color='orchid')
            plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')
            ######### sin #################
            # plt.axvline(-2, linestyle='--', color='darkorchid')   # darkorange
            # plt.axvline(2, linestyle='--', color='darkorchid')
            # plt.text(-1.5, -3, 'Training data range')
            ###### gap ###########
            # plt.axvline(-7.5, linestyle='--', color='darkorchid')
            # plt.axvline(-2.5, linestyle='--', color='darkorchid')
            # plt.axvline(5, linestyle='--', color='darkorchid')
            # plt.axvline(7.5, linestyle='--', color='darkorchid')
            # plt.text(-7.4, -1, 'Training data range', fontsize='small')
            # plt.text(5.1, -1, 'Training data', fontsize='xx-small')
            # plt.text(5.1, -1.8, 'range', fontsize='xx-small')
            ######g3######
            # plt.axvline(-0.75, linestyle='--', color='darkorchid')
            # plt.axvline(-0.25, linestyle='--', color='darkorchid')
            # plt.axvline(0.25, linestyle='--', color='darkorchid')
            # plt.axvline(0.75, linestyle='--', color='darkorchid')
            # plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
            # plt.text(0.25, 1, 'Training data range', fontsize='small')

            plt.legend(loc='upper center', ncol=3, fontsize='small')
            plt.title('FWBI posterior (RBF kernel)')
            plt.tight_layout()
            plt.ylim([dataset.y_min, dataset.y_max])
            plt.tight_layout()

            print("--------------------------------------------------------------------------------------------")
            plt.savefig(figures_folder + '/plot_epoch{}.pdf'.format(epoch))
            # Release the figure; otherwise one open figure accumulates per evaluation.
            plt.close(figure)
            print("figure saved")
            print("epoch : {} \t\t test loss \t\t : {}".format(epoch, test_loss))
            print("--------------------------------------------------------------------------------------------")
finally:
    # Close (and flush) the CSV log even if training raises part-way through.
    log_f.close()

def _save_history_plot(history, ylabel, filename, stride=100):
    """Plot every `stride`-th entry of a per-epoch history and save it as figures_folder/filename.

    The x positions come from the length of the recorded history, not from a fixed epoch
    count, so the plot stays valid whatever `epochs` is set to (including 0 epochs).
    """
    history = np.asarray(history)
    indices = np.arange(len(history))[::stride]
    fig = plt.figure()
    plt.plot(indices, history[indices], '-ko', ms=3)
    plt.ylabel(ylabel)
    plt.tight_layout()
    plt.xlabel('Iteration')
    plt.tight_layout()
    plt.savefig(figures_folder + '/' + filename)
    plt.close(fig)


# Histories are converted to numpy arrays at module level, as before.
train_loss_all = np.array(train_loss_all)
likelihood_loss_all = np.array(likelihood_loss_all)
prior_distance_all = np.array(prior_distance_all)
wdist_all = np.array(wdist_all)

_save_history_plot(train_loss_all, r'Training loss', 'train_loss.pdf')
_save_history_plot(likelihood_loss_all, r'Likelihood loss', 'likelihood_loss.pdf')
# alternative label: r'$W_2(p_{prior bnn}, p_{bnn})$'
_save_history_plot(prior_distance_all, r'$W_2(q_{f}, p_{g})$', 'prior_distance.pdf')
# alternative label: r'$W_1(p_{gp}, p_{prior bnn})$'
_save_history_plot(wdist_all, r'$W_1(p_{f}, p_{g})$', 'wdist.pdf')

with torch.no_grad():
    # Compare the GP prior with the bridging BNN (opt_bnn) on the test inputs. Each curve is
    # the mean over 128 sampled functions; the true function is drawn for reference.
    plt.clf()
    figure = plt.figure(figsize=(8, 5.5), facecolor='white')
    init_plotting()

    gpss = torch.mean(
        prior_sample_functions(original_x_test, prior, num_sample=128).detach().float().to(device).squeeze(),
        0,
    )
    nnet_samples = opt_bnn.sample_functions(original_x_test, num_sample=128).float().to(device).squeeze()
    nns = torch.mean(nnet_samples, 0)

    plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
    plt.plot(original_x_test.squeeze().cpu(), gpss.cpu().numpy(), 'slateblue',
             linewidth=2.5, label='GP prior')
    plt.plot(original_x_test.squeeze().cpu(), nns.cpu().numpy(), 'darkorange',
             linewidth=2.5, label='Bridging distribution')

    plt.legend(loc='upper center', ncol=3, fontsize='small')
    plt.title('GP prior and Bridging distribution (RBF kernel)')
    plt.tight_layout()
    plt.ylim([dataset.y_min, dataset.y_max])
    plt.tight_layout()
    plt.savefig(figures_folder + '/gpp-pg.pdf')

# Report total wall-clock time. end_time is taken after the post-training plots,
# so the total also includes plotting time.
print("============================================================================================")
end_time = datetime.now().replace(microsecond=0)
for _label, _stamp in (("Started training at (GMT) : ", start_time),
                       ("Finished training at (GMT) : ", end_time),
                       ("Total training time  : ", end_time - start_time)):
    print(_label, _stamp)
print("============================================================================================")


def test():
    """Placeholder; not called anywhere in the visible script."""
    # NOTE(review): in Python 3 a bare `print` only evaluates the built-in function object and
    # outputs nothing (in Python 2 it printed an empty line) — this body currently has no effect.
    print