import os
from datetime import datetime

import torch
import torch.nn as nn
import numpy as np
import gpytorch
from gp_prior import ExactGPModel, prior_sample_functions, prior_sample_functions2


from bnn import BNN, OPTBNN
import torch.nn.functional as F
import argparse

from utils.logging import get_logger

from data import uci_woval

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt


################################### Hyper-parameters ###################################

# --- optimisation ---
lr_bnn = 0.001          # Adam learning rate for the posterior network `bnn`
lr_optbnn = 0.01        # RMSprop learning rate for the prior network `opt_bnn`
prior_coeff = 10        # weight of the parameter-space distance term in the training loss
f_coeff = 10            # weight of the functional distance term in the training loss

# --- run identification / data ---
bnn_name_string = 'UAI_toy'
uci_dataset_name_string = 'protein'
random_seed = 123

# --- schedule ---
epochs = 2001                   # number of training iterations actually run
max_epoch_num = 100             # not used by the training loop
n_step_prior_pretraining = 100  # Adam steps used to fit the GP prior
save_model_freq = 5000          # checkpoint every this many epochs
log_model_freq = 3000           # CSV log row every this many epochs
test_interval = 1000            # evaluate on the test split every this many epochs
num_sample = 1000               # number of predictive samples drawn at test time

# Seed both RNGs for reproducibility.
torch.manual_seed(random_seed)
np.random.seed(random_seed)


################################### Network Architecture ###################################

# Both networks use `n_hidden` hidden layers of `n_units` units with this activation.
n_units = 10
n_hidden = 2
activation_fn = 'tanh'
hidden_dims = [n_units for _ in range(n_hidden)]


print("============================================================================================")
################################## set device ##################################

# set device to cpu or cuda
device = torch.device('cpu')
if (torch.cuda.is_available()):
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    print("Device set to : cpu")

print("============================================================================================")
###################### logging ######################

#### log files for multiple runs are NOT overwritten: each run gets the next index

# os.makedirs(..., exist_ok=True) creates any missing parent directories and,
# unlike an exists()-then-makedirs() pair, does not fail when another run
# creates the same directory in between.
log_dir = "./logs" + '/' + bnn_name_string + '/' + 'fwbi_uci'
os.makedirs(log_dir, exist_ok=True)

#### the run index is the number of regular files already in log_dir
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)

#### create new log file for each run
log_f_name = log_dir + '/' + bnn_name_string + 'fwbi_uci' + "_" + str(run_num) + ".csv"

print("current logging run number for " + bnn_name_string + " : ", run_num)
print("logging at : " + log_f_name)


print("============================================================================================")
################### checkpointing ###################

run_num_pretrained = 0  #### change this to prevent overwriting weights in the same folder

directory = "./pretrained" + '/' + bnn_name_string + '/' + 'fwbi_uci' + '/'
os.makedirs(directory, exist_ok=True)

checkpoint_path = directory + "{}_{}_{}.pth".format(bnn_name_string, random_seed, run_num_pretrained)
print("save checkpoint path : " + checkpoint_path)

################### savefigures ###################

results_folder = "./results"
figures_folder = results_folder + '/' + bnn_name_string + '/' + 'fwbi_uci' + '/'
os.makedirs(figures_folder, exist_ok=True)


print("============================================================================================")
############# print all hyperparameters #############

print("learning rate: ", lr_bnn)
print("coefficient of prior regularization: ", prior_coeff)
print("random seed: ", random_seed)
print("max number of epoches: ", max_epoch_num)


print("============================================================================================")
############################## load and normalize data ##############################

# dataset = uci_woval(uci_dataset_name_string, seed=random_seed)
# train_x, test_x, train_y, test_y = dataset.x_train, dataset.x_test, dataset.y_train, dataset.y_test

train_x, test_x, train_y, test_y = uci_woval(uci_dataset_name_string, seed=random_seed)
# print(test_y)

training_num, input_dim = train_x.shape
test_num = test_x.shape[0]
output_dim = 1
is_continuous = True

print("training_num = ", training_num, " input_dim = ", input_dim, " output_dim = ", output_dim)

original_x_train = torch.from_numpy(train_x).float().to(device)
original_y_train = torch.from_numpy(train_y).float().to(device)
original_x_test = torch.from_numpy(test_x).float().to(device)
original_y_test = torch.from_numpy(test_y).float().to(device)



loss = nn.MSELoss(reduction='mean')


# mask = torch.randperm(training_num)[:test_num]
# prior_x = original_x_train[mask, :]
# prior_y = original_y_train[mask]

print("============================================================================================")
################################## define model ##################################


bnn = BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
opt_bnn = BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)

bnn_optimizer = torch.optim.Adam([{'params': bnn.parameters(),  'lr': lr_bnn}])
opt_bnn_optimizer = torch.optim.RMSprop([{'params': opt_bnn.parameters(), 'lr': lr_optbnn}])

print("============================================================================================")
################################## prior pre-training ##################################

# pre-train GP prior

likelihood = gpytorch.likelihoods.GaussianLikelihood()
prior = ExactGPModel(original_x_test, original_y_test, likelihood, input_dim).to(device)

prior.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(prior.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior)

for i in range(n_step_prior_pretraining):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = prior(original_x_test)
    # Calc loss and backprop gradients
    loss_gp = -mll(output, original_y_test)
    loss_gp.backward()
    # print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
    #     i + 1, n_step_prior_pretraining, loss_gp.item(),
    #     prior.covar_module.base_kernel.lengthscale.item(),
    #     prior.likelihood.noise.item()
    # ))
    print('Iter %d/%d - Loss: %.3f     noise: %.3f' % (
        i + 1, n_step_prior_pretraining, loss_gp.item(),
        # prior.covar_module.base_kernel.lengthscale.item(),
        prior.likelihood.noise.item()
    ))

    optimizer.step()

prior.eval()
likelihood.eval()

sigma = prior.likelihood.noise.item()


print("============================================================================================")
################################## start training ##################################

# track total training time
start_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)


# logging file
log_f = open(log_f_name, "w+")
log_f.write('epoch,training_loss\n')
scheduler = torch.optim.lr_scheduler.StepLR(bnn_optimizer, 5000, gamma=0.9, last_epoch=-1)

# define lipschitz function and sample function
##########################################    check opt_bnn  #######

def sample_measurement_set(X, num_data, n=40):
    """Draw a random measurement set of distinct rows from ``X``.

    Args:
        X: tensor indexed as ``X[idx, :]``; each row is one candidate input point.
        num_data: number of rows of ``X`` to sample from (normally ``X.shape[0]``).
        n: requested measurement-set size (default 40, the previously hard-coded
            value). If ``n`` exceeds ``num_data``, all ``num_data`` rows are
            returned in shuffled order.

    Returns:
        The selected rows, ``X[idx, :]`` with ``min(n, num_data)`` indices.
    """
    # A random permutation samples without replacement, so no row repeats.
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n)]
    return X[idx, :]

def lipf(X, Y):
    """Return the element-wise absolute difference ``|X - Y|`` of two tensors."""
    return torch.abs(X - Y)

####################################################
# Per-epoch training histories (one value appended per epoch).
train_loss_all = []
prior_distance_all = []
wdist_all = []
likelihood_loss_all = []
w1_bnn_optbnn = []  # not filled by this loop; only the commented-out plotting code refers to it

try:
    for epoch in range(epochs):
        # log_model_freq / save_model_freq / test_interval may exceed `epochs`
        # (with the current settings the log and save checks only fire at epoch 0),
        # so also log, checkpoint and test at the final epoch. This ensures the
        # saved model and the reported test metrics come from the trained network.
        is_last_epoch = epoch == epochs - 1

        ##################### functional prior matching #########
        # Random subset of training inputs on which function samples are compared.
        measurement_set = sample_measurement_set(X=original_x_train, num_data=training_num)
        # GP samples are fixed targets (detached); opt_bnn samples carry gradients.
        gpss = prior_sample_functions(measurement_set, prior, num_sample=128).detach().float().to(device)
        gpss = gpss.squeeze()
        nnet_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)
        nnet_samples = nnet_samples.squeeze()

        # Mean absolute difference between the GP and opt_bnn function samples.
        functional_wdist = torch.mean(lipf(gpss, nnet_samples))

        ##################### variance constraint #########
        with torch.no_grad():
            gp_var = prior(measurement_set).variance

        opt_samples = opt_bnn.sample_functions(measurement_set, num_sample=200).float().to(device)
        opt_samples = opt_samples.squeeze()
        # torch.var equals torch.std(...) ** 2, but its gradient stays finite when
        # the sample variance is zero (the derivative of sqrt is infinite at 0).
        opt_var = torch.var(opt_samples, 0)
        dif_var = torch.mean((gp_var - opt_var).abs())

        ##################### parameter-space distance to the prior #########
        # opt_bnn's per-layer means/stds serve as the prior parameters for bnn.
        named_params = list(opt_bnn.named_parameters())
        w_prior_mu_list = [p for name, p in named_params if 'W_mu' in name]
        b_prior_mu_list = [p for name, p in named_params if 'b_mu' in name]
        w_prior_std_list = [p for name, p in named_params if 'W_std' in name]
        b_prior_std_list = [p for name, p in named_params if 'b_std' in name]

        # assumes exactly three layers (input, mid, output), i.e. n_hidden == 2 -- TODO confirm if that changes
        input_layer_wd = bnn.input_layer.para_wd(w_prior_mu_list[0], b_prior_mu_list[0],
                                                 w_prior_std_list[0], b_prior_std_list[0])
        mid_layer_wd = bnn.mid_layer.para_wd(w_prior_mu_list[1], b_prior_mu_list[1],
                                             w_prior_std_list[1], b_prior_std_list[1])
        output_layer_wd = bnn.output_layer.para_wd(w_prior_mu_list[2], b_prior_mu_list[2],
                                                   w_prior_std_list[2], b_prior_std_list[2])
        bnn_2wassdist = (input_layer_wd + mid_layer_wd + output_layer_wd) / 3

        ### bnn likelihood loss (computed once, reused in the total loss) ###
        pred_y = bnn.forward(original_x_train).squeeze().flatten()
        likelihood_loss = loss(pred_y, original_y_train)
        train_loss = (likelihood_loss
                      + prior_coeff * bnn_2wassdist
                      + f_coeff * functional_wdist
                      + dif_var)

        # One backward pass produces gradients for both bnn and opt_bnn.
        bnn_optimizer.zero_grad()
        opt_bnn_optimizer.zero_grad()
        train_loss.backward()
        bnn_optimizer.step()
        scheduler.step()
        opt_bnn_optimizer.step()

        train_loss_value = train_loss.item()
        train_loss_all.append(train_loss_value)
        likelihood_loss_all.append(likelihood_loss.item())
        prior_distance_all.append(bnn_2wassdist.item())
        wdist_all.append(functional_wdist.item())

        print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss_value),
              datetime.now().replace(microsecond=0) - start_time)

        # log in logging file
        if epoch % log_model_freq == 0 or is_last_epoch:
            print("--------------------------------------------------------------------------------------------")
            # Write the plain float: formatting the tensor itself produces
            # "tensor(..., grad_fn=<...>)", whose commas break the CSV.
            log_f.write('{},{}\n'.format(epoch, train_loss_value))
            log_f.flush()
            print("log saved")
            print("--------------------------------------------------------------------------------------------")

        # save model weights
        if epoch % save_model_freq == 0 or is_last_epoch:
            print("--------------------------------------------------------------------------------------------")
            print("saving model at : " + checkpoint_path)
            bnn.save(checkpoint_path)
            print("model saved")
            print("--------------------------------------------------------------------------------------------")

        if epoch % test_interval == 0 or is_last_epoch:
            bnn.eval()
            with torch.no_grad():
                samples_pred_test_y = bnn.forward_eval(original_x_test, num_sample).squeeze().detach()
            bnn.train()

            # NOTE(review): samples_pred_test_y is presumably (num_sample, test_num);
            # MSELoss broadcasts the targets, so this averages the squared error
            # over every sample rather than scoring the predictive mean.
            mse_test_loss = loss(samples_pred_test_y, original_y_test)
            print('mse_test_loss: ', mse_test_loss)

            ################ Gaussian predictive NLL ################
            # Predictive distribution per test point:
            #   N(mean over samples, sample variance + sigma), sigma = GP noise variance.
            # Per-point NLL = 0.5 * (log(2*pi*var) + (y - mean)^2 / var).
            # The same variance appears in both terms, which the earlier formula
            # (log of the sample std plus log(sigma)) did not respect.
            # assumes original_y_test is 1-D with test_num entries -- TODO confirm
            mean_y_pred = torch.mean(samples_pred_test_y, 0)
            variance_y_pred = torch.var(samples_pred_test_y, 0) + sigma
            nll_per_point = 0.5 * (torch.log(2 * np.pi * variance_y_pred)
                                   + (original_y_test - mean_y_pred) ** 2 / variance_y_pred)
            nll_mean = nll_per_point.mean()

            print('nll_mean: ', nll_mean)
finally:
    # Close the CSV log even if training raises.
    log_f.close()

# plt.figure()
# indices = np.arange(2001)[::20]
# train_loss_all = np.array(train_loss_all)
# plt.plot(indices, train_loss_all[indices], '-ko', ms=3)
# plt.ylabel(r'Training loss')
# plt.tight_layout()
# plt.xlabel('Iteration')
# plt.tight_layout()
# plt.savefig(figures_folder + '/train_loss.pdf')
#
# plt.figure()
# indices = np.arange(2001)[::20]
# likelihood_loss_all = np.array(likelihood_loss_all)
# plt.plot(indices, likelihood_loss_all[indices], '-ko', ms=3)
# plt.ylabel(r'Likelihood loss')
# plt.tight_layout()
# plt.xlabel('Iteration')
# plt.tight_layout()
# plt.savefig(figures_folder + '/likelihood_loss.pdf')
#
# plt.figure()
# indices = np.arange(2001)[::20]
# prior_distance_all = np.array(prior_distance_all)
# plt.plot(indices, prior_distance_all[indices], '-ko', ms=3)
# # plt.ylabel(r'$W_2(p_{prior bnn}, p_{bnn})$')
# plt.ylabel(r'$W_2(q_{f}, p_{g})$')
# plt.tight_layout()
# plt.xlabel('Iteration')
# plt.tight_layout()
# plt.savefig(figures_folder + '/prior_distance.pdf')
#
# plt.figure()
# indices = np.arange(2001)[::20]
# wdist_all = np.array(wdist_all)
# plt.plot(indices, wdist_all[indices], '-ko', ms=3)
# # plt.ylabel(r'$W_1(p_{gp}, p_{prior bnn})$')
# plt.ylabel(r'$W_1(p_{f}, p_{g})$')
# plt.tight_layout()
# plt.xlabel('Iteration')
# plt.tight_layout()
# plt.savefig(figures_folder + '/wdist.pdf')


###################################
# plt.figure()
# plt.plot(train_loss_all, 'r-')
# plt.title('train loss')
# plt.savefig(figures_folder + '/train_loss.pdf')
#
# plt.figure()
# plt.plot(likelihood_loss_all, 'r-')
# plt.title('likelihood loss')
# plt.savefig(figures_folder + '/likelihood_loss.pdf')
#
# plt.figure()
# plt.plot(prior_distance_all, 'r-')
# plt.title('prior_distance')
# plt.savefig(figures_folder + '/prior_distance.pdf')
#
# plt.figure()
# plt.plot(wdist_all, 'r-')
# plt.title('wdist')
# plt.savefig(figures_folder + '/wdist.pdf')

# plt.clf()
# figure = plt.figure(figsize=(8, 5.5), facecolor='white')
# # init_plotting()
# # plt.figure()
# index = np.arange(2001)
# index = list(index)
# plt.scatter(wdist_all, w1_bnn_optbnn, c=index)
# plt.colorbar(label='epoch')
# plt.ylabel(r'w1_bnn_optbnn')
# plt.tight_layout()
# plt.xlabel('w_2_para')
# plt.tight_layout()
# # plt.legend(loc='upper center', ncol=3, fontsize='small')
# plt.savefig(figures_folder + '/function_parameter_dist.pdf')

############## calibration curve ############
# print('mean_y_pred: ', mean_y_pred)
# print('oy: ',oy)
# pcdf = np.mean(mean_y_pred.reshape <= oy.reshape(-1, 1), axis=1)
# print('pcdf.shape: ', pcdf.shape)
# ecdf = np.zeros(len(pcdf))
# for i, p in enumerate(pcdf):
#     ecdf[i] = np.sum(pcdf <= p)/len(pcdf)
#
# print('ecdf.shape: ', ecdf.shape)
# print(pcdf)
# print(ecdf)
#
# plt.clf()
# figure = plt.figure(figsize=(8, 5.5), facecolor='white')
# # init_plotting()
#
# plt.scatter(pcdf, ecdf)
# plt.ylabel('Empirical')
# plt.tight_layout()
# plt.xlabel('Predicted')
# plt.tight_layout()
# plt.savefig(figures_folder + '/calibration_curve.pdf')


# Final summary: wall-clock timings and the last computed test MSE.
print("============================================================================================")
end_time = datetime.now().replace(microsecond=0)
for label, value in (("Started training at (GMT) : ", start_time),
                     ("Finished training at (GMT) : ", end_time),
                     ("Total training time  : ", end_time - start_time),
                     ("Test loss  : ", mse_test_loss)):
    print(label, value)
print("============================================================================================")

