import os
from datetime import datetime

import torch
import torch.nn as nn
import numpy as np

from data import x3_gap_toy, sin_toy
from toy import polynomial_toy, polynomial_gap_toy, polynomial_gap_toy2, g2_toy, g3_toy

from fbnn import *
from bnn import BNN

import gpytorch
import torch.optim.lr_scheduler
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt

from utils.utils import default_plotting_new as init_plotting


################################### Hyper-parameters ###################################

# -- identification / reproducibility --
bnn_name_string = 'FWBI'
random_seed = 123

# -- optimisation --
lr_bnn = 0.01
epochs = 10_001
max_epoch_num = 100
n_step_prior_pretraining = 100

# -- loss weighting --
prior_coeff = 1
regu_coeff = 0.01

# -- sample counts (num_functions goes to bnn.fkl, num_sample to bnn.forward_eval) --
num_functions = 128
num_sample = 1_000

# -- bookkeeping periods, measured in epochs (used as `epoch % period == 0`) --
save_model_freq = 5_000
log_model_freq = 3_000
test_interval = 2_000

# Seed the global NumPy and torch RNGs so repeated runs draw the same numbers.
np.random.seed(random_seed)
torch.manual_seed(random_seed)


################################### Network Architecture ###################################

# Two hidden layers of 100 units each, tanh activations.
n_units, n_hidden = 100, 2
hidden_dims = [n_units for _ in range(n_hidden)]
activation_fn = 'tanh'


print("============================================================================================")
################################## set device ##################################

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")

print("============================================================================================")
###################### logging ######################

#### log files for multiple runs are NOT overwritten: each run uses the lowest
#### index >= (number of files already in the folder) whose file name is free

log_dir = os.path.join("./logs", bnn_name_string + 'fbnn_g2_matern')
# exist_ok avoids the check-then-create race and also creates ./logs if needed
os.makedirs(log_dir, exist_ok=True)


def _log_file_name(num):
    """Return the CSV log path for run index `num` inside `log_dir`."""
    return os.path.join(log_dir, bnn_name_string + 'fbnn_g2_matern' + "_" + str(num) + ".csv")


#### get number of log files in log directory
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)

#### create new log file for each run; the file count alone can coincide with an
#### existing file name (e.g. after an older log was deleted), and the log is
#### opened with "w+", so advance to the next index that is not taken yet
log_f_name = _log_file_name(run_num)
while os.path.exists(log_f_name):
    run_num += 1
    log_f_name = _log_file_name(run_num)

print("current logging run number for " + bnn_name_string + " : ", run_num)
print("logging at : " + log_f_name)


print("============================================================================================")
################### checkpointing ###################

# Bump this index to keep a new run from overwriting weights saved earlier in the same folder.
run_num_pretrained = 0

directory = "./pretrained"
if not os.path.exists(directory):
    os.makedirs(directory)

# Per-model sub-folder; the trailing slash lets the file name be appended directly.
directory = f"{directory}/{bnn_name_string}fbnn_g2_matern/"
if not os.path.exists(directory):
    os.makedirs(directory)

checkpoint_path = f"{directory}{bnn_name_string}_{random_seed}_{run_num_pretrained}.pth"
print("save checkpoint path : " + checkpoint_path)

################### savefigures ###################

# Plots are written to ./results/<bnn_name_string>/fbnn_g2_matern; ensure both levels exist.
results_folder = "./results"
figures_folder = '/'.join([results_folder, bnn_name_string, 'fbnn_g2_matern'])
for _folder in (results_folder, figures_folder):
    if not os.path.exists(_folder):
        os.makedirs(_folder)
del _folder


print("============================================================================================")
############# print all hyperparameters #############

print("learning rate: ", lr_bnn)
print("coefficient of prior regularization: ", prior_coeff)
print("random seed: ", random_seed)
# Report the epoch count the training loop actually iterates over (`range(epochs)`);
# `max_epoch_num` does not control the loop, so printing it was misleading.
print("number of training epochs: ", epochs)


print("============================================================================================")
############################## load and normalize data ##############################

# Instantiate the 'g2' toy dataset from a name -> constructor table.
dataset = dict(sin=sin_toy, polynomial=polynomial_toy, g2=g2_toy, g3=g3_toy, gap=polynomial_gap_toy)['g2']()
original_x_train, original_y_train = dataset.train_samples()
# Standardisation statistics come from the TRAINING split only (np.std is the population std, ddof=0).
mean_x, std_x = np.mean(original_x_train), np.std(original_x_train)
mean_y, std_y = np.mean(original_y_train), np.std(original_y_train)
train_x = (original_x_train - mean_x) / std_x
train_y = (original_y_train - mean_y) / std_y
original_x_test, original_y_test = dataset.test_samples()
# The test split is normalised with the training statistics.
test_x = (original_x_test - mean_x) / std_x
test_y = (original_y_test - mean_y) / std_y

# Log of dataset.y_std expressed in normalised-y units
# (presumably the observation-noise std of the toy problem — TODO confirm).
y_logstd = np.log(dataset.y_std / std_y)

# Dataset x-range endpoints mapped into normalised-x units.
lower_ap = (dataset.x_min - mean_x) / std_x
upper_ap = (dataset.x_max - mean_x) / std_x

#training_num, input_dim = np.shape(train_x)
# Scalar input / scalar output regression (input_dim is re-derived from a tensor shape below).
input_dim = 1 #np.array([input_dim])
output_dim = 1 #np.array([1])
is_continuous = True

# Un-normalised data as float32 tensors on `device`; the model below is trained on these.
original_x_train = torch.from_numpy(original_x_train).float().to(device)
original_y_train = torch.from_numpy(original_y_train).float().to(device)
original_x_test = torch.from_numpy(original_x_test).float().to(device)
original_y_test = torch.from_numpy(original_y_test).float().to(device)


# NOTE: every np.random.uniform call in this section draws from the global NumPy RNG seeded
# with `random_seed`. The augmentation sets are generated in a fixed order, so removing or
# reordering any of them changes the points drawn for the ones after it (including *_g2).
# Extra set 1: 40 points uniform on [-5, 5] with targets 2*sin(4x).
add_set = np. random. uniform(-5, 5, (40, 1))
add_set_y = 2 * np. sin(4*add_set)
add_set = torch.from_numpy(add_set).float().to(device)
add_set_y = torch.from_numpy(add_set_y).float().to(device)
add_set_y = torch.squeeze(add_set_y).to(device)  # (40, 1) -> (40,) to match the 1-D targets
# print('add_set_y_size', add_set_y.size())
# print('original_train_x_size', original_x_train.size())
original_x_train_add = torch.cat((original_x_train, add_set), 0).to(device)
original_y_train_add = torch.cat((original_y_train, add_set_y), 0).to(device)
# Overwrites input_dim with the column count of the concatenated (N, 1) inputs.
training_num, input_dim = original_x_train_add.shape

# assumes original_x_test is 1-D (N,), so unsqueeze(-1) gives (N, 1) — TODO confirm
original_x_train_add_test = torch.cat((original_x_train_add, original_x_test.unsqueeze(-1)), 0).to(device)
original_y_train_add_test = torch.cat((original_y_train_add, original_y_test), 0).to(device)

# Normalised splits as float32 tensors (not used by the visible training loop).
train_x = torch.from_numpy(train_x).float().to(device)
train_y = torch.from_numpy(train_y).float().to(device)
test_x = torch.from_numpy(test_x).float().to(device)
test_y = torch.from_numpy(test_y).float().to(device)

# Extra set 2 ("gap"): 40 points uniform on [-10, 10], targets sin(x) + 0.1 * x**2.
add_set_gap = np. random. uniform(-10, 10, (40, 1))
add_set_y_gap = np.sin(add_set_gap) + 0.1 * add_set_gap ** 2
add_set_gap = torch.from_numpy(add_set_gap).float().to(device)
add_set_y_gap = torch.from_numpy(add_set_y_gap).float().to(device)
add_set_y_gap = torch.squeeze(add_set_y_gap).to(device)

original_x_train_add_gap = torch.cat((original_x_train, add_set_gap), 0).to(device)
original_y_train_add_gap = torch.cat((original_y_train, add_set_y_gap), 0).to(device)

# Extra set 3 ("g2"): 30 points uniform on [-10, 10], targets sin(x) + 0.1 * x.
# original_x_train_add_g2 is the input set passed to bnn.fkl in the training loop of this script.
#####g2#########
add_set_g2 = np. random. uniform(-10, 10, (30, 1))
add_set_y_g2 = np.sin(add_set_g2) + 0.1 * add_set_g2
add_set_g2 = torch.from_numpy(add_set_g2).float().to(device)
add_set_y_g2 = torch.from_numpy(add_set_y_g2).float().to(device)
add_set_y_g2 = torch.squeeze(add_set_y_g2).to(device)

original_x_train_add_g2 = torch.cat((original_x_train, add_set_g2), 0).to(device)
original_y_train_add_g2 = torch.cat((original_y_train, add_set_y_g2), 0).to(device)


# Extra set 4 ("g3"): 40 points uniform on [-1, 1], a sum of sinusoids in x.
# NOTE(review): 3.14 is used where np.pi may have been intended — confirm.
######g3##############
add_set_g3 = np. random. uniform(-1, 1, (40, 1))
add_set_y_g3 = np.sin(3 * 3.14 * add_set_g3) + 0.3 * np.cos(9 * 3.14 * add_set_g3) + 0.5 * np.sin(7 * 3.14 * add_set_g3)
add_set_g3 = torch.from_numpy(add_set_g3).float().to(device)
add_set_y_g3 = torch.from_numpy(add_set_y_g3).float().to(device)
add_set_y_g3 = torch.squeeze(add_set_y_g3).to(device)

original_x_train_add_g3 = torch.cat((original_x_train, add_set_g3), 0).to(device)
original_y_train_add_g3 = torch.cat((original_y_train, add_set_y_g3), 0).to(device)

loss = nn.MSELoss(reduction='mean')  # squared error averaged over all elements
print("============================================================================================")
################################## define model ##################################

# Network being trained: an fBNN with the architecture configured above, plus an Adam
# optimiser holding a single parameter group at learning rate `lr_bnn`.
bnn = fBNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
bnn_optimizer = torch.optim.Adam(params=[dict(params=bnn.parameters(), lr=lr_bnn)])

print("============================================================================================")
################################## prior pre-training ##################################

# Fit the GP prior by minimising the negative marginal log likelihood.
# NOTE(review): the GP is fitted on the TEST inputs/targets (original_x_test / original_y_test);
# confirm this is the intended design rather than a train/test mix-up.
likelihood = gpytorch.likelihoods.GaussianLikelihood()
prior = ExactGPModel(original_x_test, original_y_test, likelihood, input_dim).to(device)

prior.train()
likelihood.train()

# Adam over prior.parameters(), which includes the GaussianLikelihood parameters.
optimizer = torch.optim.Adam(prior.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior)

for i in range(n_step_prior_pretraining):
    optimizer.zero_grad()
    output = prior(original_x_test)
    loss_gp = -mll(output, original_y_test)
    loss_gp.backward()
    # Report the loss and current likelihood noise before applying this step's update.
    print(f'Iter {i + 1}/{n_step_prior_pretraining} - Loss: {loss_gp.item():.3f}'
          f'     noise: {prior.likelihood.noise.item():.3f}')
    optimizer.step()

# Pre-training is done: put the GP and its likelihood into eval mode.
prior.eval()
likelihood.eval()

################################## bnn prior pre-train (disabled: entire section is commented out) ##################################

# prior_bnn = BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
# prior_bnn_optimizer = torch.optim.RMSprop([{'params': prior_bnn.parameters(), 'lr': 0.01}])
#
# # loss
# for i in range(10000):
#     pred_prior = prior_bnn.forward(original_x_test)
#     # L1regu = regu_coeff / 20 * L1regu
#     # pred_y = bnn.forward(original_x_train)
#     # y_predicted, distance_prior = bnn.forward_w(train_x)
#
#     pred_prior = pred_prior.squeeze().flatten()
#     # print(pred_y)
#     # print(train_y_tmp)
#     train_loss_prior = loss(pred_prior, original_y_test)
# # optimisation
#     prior_bnn_optimizer.zero_grad()
#     train_loss_prior.backward()
#     prior_bnn_optimizer.step()
# # print
#     print("epoch : {} \t\t training loss \t\t : {}".format(10000, train_loss_prior),
#           )
#
# with torch.no_grad():
#
#     plt.clf()
#     figure = plt.figure(figsize=(8, 5.5), facecolor='white')
#     init_plotting()
#
#     # measurement_set = sample_measurement_set(X=original_x_train_add, num_data=60)
#
#
#     bnnprior_samples = prior_bnn.sample_functions(original_x_test, num_sample=128).float().to(device)
#     # nnet_samples = nnet_samples.squeeze().t()
#     bnnprior_samples = bnnprior_samples.squeeze()
#
#     nns = torch.mean(bnnprior_samples, 0)
#     # nns = nnet_samples[:, :5]
#     plt.plot(original_x_test.squeeze().cpu(), original_y_test.cpu(), 'g', label="True function")
#     # plt.plot(original_x_test.squeeze().cpu(), gpss.cpu().numpy(), 'slateblue', label='GP prior', linewidth=2.5)
#     plt.plot(original_x_test.squeeze().cpu(), nns.cpu().numpy(), 'darkorange', label='Prior BNN distribution', linewidth=2.5)
#     # plt.scatter(measurement_set.squeeze().cpu(), gpss.cpu().numpy())
#     # plt.scatter(measurement_set.squeeze().cpu(), nns.cpu().numpy(), cmap='r')
#
#     plt.legend(loc='upper center', ncol=2, fontsize='small')
#     plt.title('Prior BNN')
#     plt.tight_layout()
#     plt.ylim([dataset.y_min, dataset.y_max])
#     plt.tight_layout()
#     plt.savefig(figures_folder + '/Prior BNN.pdf')



print("============================================================================================")
################################## start training ##################################

# track total training time; datetime.now() is naive LOCAL wall-clock time, not GMT
start_time = datetime.now().replace(microsecond=0)
print("Started training at (local time) : ", start_time)

# The CSV log is managed by `with` so it is flushed and closed even if training raises.
with open(log_f_name, "w+") as log_f:
    log_f.write('epoch,training_loss\n')

    # StepLR multiplies the learning rate by 0.9 every 5000 scheduler steps (one step per epoch).
    scheduler = torch.optim.lr_scheduler.StepLR(bnn_optimizer, 5000, gamma=0.9, last_epoch=-1)

    # x-coordinates of the test grid for plotting; constant across epochs, so computed once.
    x_test_plot = original_x_test.squeeze().cpu()

    for epoch in range(epochs):

        # loss = MSE data fit on the observed points + prior_coeff * distance_prior, where
        # distance_prior comes from bnn.fkl on the training inputs plus the extra g2 inputs
        # (presumably a functional KL-type distance to the GP prior — see fbnn.fBNN.fkl)
        pred_y = bnn.forward(original_x_train)
        distance_prior = bnn.fkl(original_x_train_add_g2, 60, prior, num_functions)

        pred_y = pred_y.squeeze().flatten()
        train_loss = loss(pred_y, original_y_train) + prior_coeff * distance_prior

        # optimisation
        bnn_optimizer.zero_grad()
        train_loss.backward()
        bnn_optimizer.step()
        scheduler.step()

        # Plain Python floats for reporting: formatting the tensors directly emits
        # "tensor(..., grad_fn=...)", whose commas corrupt the CSV log.
        train_loss_value = train_loss.item()
        distance_prior_value = float(distance_prior)
        print("epoch : {} \t\t training loss \t\t : {} \t\t distance_prior \t\t : {}".format(
                  epoch, train_loss_value, distance_prior_value),
              datetime.now().replace(microsecond=0) - start_time)

        # log in logging file
        if epoch % log_model_freq == 0:
            print("--------------------------------------------------------------------------------------------")
            log_f.write('{},{}\n'.format(epoch, train_loss_value))
            log_f.flush()
            print("log saved")
            print("--------------------------------------------------------------------------------------------")

        # save model weights
        if epoch % save_model_freq == 0:
            print("--------------------------------------------------------------------------------------------")
            print("saving model at : " + checkpoint_path)
            bnn.save(checkpoint_path)
            print("model saved")
            print("--------------------------------------------------------------------------------------------")

        if epoch % test_interval == 0:
            # Evaluation: draw `num_sample` predictive samples on the test inputs without
            # building an autograd graph, then summarise them by their mean and std over samples.
            bnn.eval()
            with torch.no_grad():
                samples_pred_test_y = bnn.forward_eval(original_x_test, num_sample).squeeze().detach()
                y_pred = torch.mean(samples_pred_test_y, 0)
                y_std = torch.std(samples_pred_test_y, 0)
                test_loss = loss(y_pred, original_y_test)
            mean_y_pred = y_pred.cpu().numpy()
            std_y_pred = y_std.cpu().numpy()
            print('test_loss: ', test_loss.item())

            bnn.train()

            figure = plt.figure(figsize=(8, 5.5), facecolor='white')
            init_plotting()

            plt.plot(x_test_plot, original_y_test.cpu(), 'g', label="True function")
            plt.plot(x_test_plot, mean_y_pred, 'slateblue', label='Mean function', linewidth=2.5)
            # Bands between k*std and (k+1)*std (k = 0..4) on both sides of the mean,
            # increasingly transparent further from the mean.
            for band in range(5):
                band_alpha = 1.0 - band * 0.15
                plt.fill_between(x_test_plot, mean_y_pred - band * std_y_pred,
                                 mean_y_pred - (band + 1) * std_y_pred, linewidth=0.0,
                                 alpha=band_alpha, color='orchid')
                plt.fill_between(x_test_plot, mean_y_pred + band * std_y_pred,
                                 mean_y_pred + (band + 1) * std_y_pred, linewidth=0.0,
                                 alpha=band_alpha, color='orchid')
            plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')

            plt.legend(loc='upper center', ncol=3, fontsize='small')
            plt.title('FBNN posterior (Matern kernel)')
            plt.ylim([dataset.y_min, dataset.y_max])
            plt.tight_layout()
            print("--------------------------------------------------------------------------------------------")
            plt.savefig(figures_folder + '/plot_epoch{}.pdf'.format(epoch))
            # Release the figure; otherwise every evaluation leaves another figure open.
            plt.close(figure)
            print("figure saved")
            print("--------------------------------------------------------------------------------------------")

# print total training time
print("============================================================================================")
end_time = datetime.now().replace(microsecond=0)
print("Started training at (local time) : ", start_time)
print("Finished training at (local time) : ", end_time)
print("Total training time  : ", end_time - start_time)
print("============================================================================================")

def test():
    """Placeholder: write the marker string ``test`` (plus newline) to stdout and return None."""
    marker = 'test'
    print(marker)
