import os
from datetime import datetime, timezone

import torch
import torch.nn as nn
import numpy as np

from data import x3_gap_toy, sin_toy
from toy import polynomial_toy, polynomial_gap_toy, polynomial_gap_toy2, g2_toy, g3_toy
from bnn import BNN

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt

from utils.utils import default_plotting_new as init_plotting


################################### Hyper-parameters ###################################

# Optimisation settings
lr_bnn = 1e-2               # learning rate handed to the Adam optimiser
prior_coeff = 1             # weight applied to the prior-distance term of the training loss
regu_coeff = 1e-2           # not referenced by any active statement in this script
epochs = 10_001             # number of iterations of the training loop
max_epoch_num = 100         # NOT used to bound the training loop (see `epochs`)

# Bookkeeping / evaluation settings
bnn_name_string = 'FWBI'    # tag embedded in the log, checkpoint and figure paths
random_seed = 123
save_model_freq = 5_000     # write a checkpoint every N epochs
log_model_freq = 3_000      # append a row to the CSV log every N epochs
test_interval = 2_000       # evaluate and plot every N epochs
num_sample = 1_000          # sample count passed to bnn.forward_eval at evaluation time

# Seed the torch and NumPy random number generators.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(random_seed)


################################### Network Architecture ###################################

n_units = 100                                          # width of every hidden layer
n_hidden = 2                                           # number of hidden layers
hidden_dims = [n_units for _layer in range(n_hidden)]  # one width entry per hidden layer
activation_fn = 'tanh'                                 # activation name passed to BNN


print("============================================================================================")
################################## set device ##################################

# Use the first CUDA device when one is available, otherwise run on the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if _use_cuda else 'cpu')
if _use_cuda:
    # Release any cached, unused GPU memory before training starts.
    torch.cuda.empty_cache()
_device_label = str(torch.cuda.get_device_name(device)) if _use_cuda else "cpu"
print("Device set to : " + _device_label)

print("============================================================================================")
###################### logging ######################

#### log files for multiple runs are NOT overwritten

# os.makedirs creates every missing level (including "./logs") in one call;
# exist_ok=True replaces the racy "check exists, then create" pattern.
log_dir = "/".join(["./logs", bnn_name_string, 'wbnn_g2_1'])
os.makedirs(log_dir, exist_ok=True)

#### get number of log files in log directory
# next(os.walk(...)) yields (dirpath, dirnames, filenames) for the top level only,
# so [2] is the list of files directly inside log_dir.
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)

#### create new log file for each run (its index is the number of files already present)
log_f_name = "{}/{}wbnn_g2_1_{}.csv".format(log_dir, bnn_name_string, run_num)

print("current logging run number for " + bnn_name_string + " : ", run_num)
print("logging at : " + log_f_name)


print("============================================================================================")
################### checkpointing ###################

#### change this to avoid overwriting a checkpoint saved under the same tag and seed
run_num_pretrained = 0

# Create ./pretrained/<tag>/wbnn_g2_1 (and any missing parents) in one call.
directory = "/".join(["./pretrained", bnn_name_string, 'wbnn_g2_1'])
os.makedirs(directory, exist_ok=True)

# The "/" separator is required: without it the file name was glued onto the
# directory name, so the checkpoint landed one level up, outside `directory`.
checkpoint_path = directory + "/" + "{}_{}_{}.pth".format(bnn_name_string, random_seed, run_num_pretrained)
print("save checkpoint path : " + checkpoint_path)

################### savefigures ###################

# Output folder for the evaluation plots. makedirs creates "./results" and the
# nested folders in one call; exist_ok avoids a check-then-create race.
results_folder = "./results"
figures_folder = results_folder + '/' + bnn_name_string + '/' + 'wbnn_g2_1'
os.makedirs(figures_folder, exist_ok=True)


print("============================================================================================")
############# print all hyperparameters #############

print("learning rate: ", lr_bnn)
print("coefficient of prior regularization: ", prior_coeff)
print("random seed: ", random_seed)
# Report the iteration count the training loop actually runs (`epochs`).
# `max_epoch_num` never bounds training, so printing it here was misleading.
print("number of training epochs: ", epochs)


print("============================================================================================")
############################## load and normalize data ##############################


def _standardize(values, mean, std):
    """Return ``(values - mean) / std``."""
    return (values - mean) / std


def _to_device_tensor(array):
    """Return NumPy ``array`` as a float32 torch tensor on the global ``device``."""
    return torch.from_numpy(array).float().to(device)


# Toy problems this script can run on; the 'g2' entry is selected and instantiated.
dataset = {
    'sin': sin_toy,
    'polynomial': polynomial_toy,
    'gap': polynomial_gap_toy,
    'g2': g2_toy,
    'g3': g3_toy,
}['g2']()

# Raw samples plus standardized copies, using statistics of the training split only.
original_x_train, original_y_train = dataset.train_samples()
mean_x, std_x = np.mean(original_x_train), np.std(original_x_train)
mean_y, std_y = np.mean(original_y_train), np.std(original_y_train)
train_x = _standardize(original_x_train, mean_x, std_x)
train_y = _standardize(original_y_train, mean_y, std_y)
original_x_test, original_y_test = dataset.test_samples()
test_x = _standardize(original_x_test, mean_x, std_x)
test_y = _standardize(original_y_test, mean_y, std_y)

# NOTE(review): the standardized arrays above and the three values below are not
# consumed by the training loop in this file, which works on the raw samples.
y_logstd = np.log(dataset.y_std / std_y)
lower_ap = _standardize(dataset.x_min, mean_x, std_x)
upper_ap = _standardize(dataset.x_max, mean_x, std_x)

# The tuple unpack requires train_x to be 2-D; the model dimensions are then
# hard-coded to a scalar input and a scalar output.
training_num, input_dim = np.shape(train_x)
input_dim, output_dim = 1, 1
is_continuous = True

# Move the raw (unnormalized) splits to the compute device as float32 tensors.
original_x_train, original_y_train, original_x_test, original_y_test = (
    _to_device_tensor(split)
    for split in (original_x_train, original_y_train, original_x_test, original_y_test)
)

loss = nn.MSELoss(reduction='mean')

print("============================================================================================")
################################## define model ##################################

# Bayesian network with the architecture configured above, placed on `device`.
bnn = BNN(
    input_dim,
    output_dim,
    hidden_dims,
    activation_fn,
    is_continuous,
    scaled_variance=True,
)
bnn = bnn.to(device)

# Adam over every BNN parameter, declared as one explicit parameter group.
bnn_optimizer = torch.optim.Adam([dict(params=bnn.parameters(), lr=lr_bnn)])

print("============================================================================================")
################################## start training ##################################

# Track total training time. Timestamps are taken in UTC so that the "(GMT)"
# labels printed below are accurate (a naive datetime.now() reports local time).
start_time = datetime.now(timezone.utc).replace(microsecond=0)
print("Started training at (GMT) : ", start_time)

# Multiply the learning rate by 0.9 every 5000 scheduler steps; the scheduler is
# stepped once per epoch in the loop below.
scheduler = torch.optim.lr_scheduler.StepLR(bnn_optimizer, 5000, gamma=0.9, last_epoch=-1)

# The CSV log is opened via a context manager so it is closed even if training raises.
with open(log_f_name, "w+") as log_f:
    log_f.write('epoch,training_loss\n')

    for epoch in range(epochs):

        # loss: MSE on the raw (unnormalized) training data plus the weighted
        # prior-distance term returned alongside the predictions by forward_w.
        pred_y, distance_prior = bnn.forward_w(original_x_train)
        pred_y = pred_y.squeeze().flatten()
        train_loss = loss(pred_y, original_y_train) + prior_coeff * distance_prior

        # optimisation
        bnn_optimizer.zero_grad()
        train_loss.backward()
        bnn_optimizer.step()
        scheduler.step()

        # Extract a Python float once. Formatting the tensor itself emits
        # "tensor(..., grad_fn=<...>)", whose embedded commas corrupted the CSV log.
        train_loss_value = train_loss.item()

        # print
        print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss_value),
              datetime.now(timezone.utc).replace(microsecond=0) - start_time)

        # log in logging file
        if epoch % log_model_freq == 0:
            print("--------------------------------------------------------------------------------------------")
            log_f.write('{},{}\n'.format(epoch, train_loss_value))
            log_f.flush()
            print("log saved")
            print("--------------------------------------------------------------------------------------------")

        # save model weights
        if epoch % save_model_freq == 0:
            print("--------------------------------------------------------------------------------------------")
            print("saving model at : " + checkpoint_path)
            bnn.save(checkpoint_path)
            print("model saved")
            print("--------------------------------------------------------------------------------------------")

        # evaluate on the test inputs and plot the predictive distribution
        if epoch % test_interval == 0:

            bnn.eval()
            # Evaluation needs no autograd graph; the predictions were detached anyway.
            with torch.no_grad():
                samples_pred_test_y = bnn.forward_eval(original_x_test, num_sample).squeeze().detach()
                # Statistics over dim 0 -- presumably the sample dimension of
                # forward_eval's output; TODO confirm against BNN.forward_eval.
                mean_y_pred = torch.mean(samples_pred_test_y, 0).cpu().numpy()
                std_y_pred = torch.std(samples_pred_test_y, 0).cpu().numpy()
                # NOTE(review): the target is presumably broadcast against every
                # sample, making this a sample-averaged MSE, not the MSE of the mean.
                test_loss = loss(samples_pred_test_y, original_y_test)
            bnn.train()

            # A fresh figure per evaluation; it is closed after saving so figures
            # do not pile up in memory over a long run.
            figure = plt.figure(figsize=(8, 5.5), facecolor='white')
            init_plotting()

            x_test_cpu = original_x_test.squeeze().cpu()
            plt.plot(x_test_cpu, original_y_test.cpu(), 'g', label="True function")
            plt.plot(x_test_cpu, mean_y_pred, 'slateblue', label='Mean function', linewidth=2.5)  # chocolate
            # Shade five 1-std bands on each side of the mean, fading outwards.
            for i in range(5):
                plt.fill_between(x_test_cpu, mean_y_pred - i * std_y_pred,
                                 mean_y_pred - (i + 1) * std_y_pred, linewidth=0.0,
                                 alpha=1.0 - i * 0.15, color='orchid')  # wheat
                plt.fill_between(x_test_cpu, mean_y_pred + i * std_y_pred,
                                 mean_y_pred + (i + 1) * std_y_pred, linewidth=0.0,
                                 alpha=1.0 - i * 0.15, color='orchid')
            plt.plot(original_x_train.cpu(), original_y_train.cpu(), 'ok', zorder=10, ms=6, label='Observations')

            # Training-range markers for the other toy datasets; enable the set that
            # matches the dataset selected in the data-loading section.
            ############sin#########
            # plt.axvline(-2, linestyle='--', color='darkorchid')   # darkorange
            # plt.axvline(2, linestyle='--', color='darkorchid')
            # plt.text(-1.5, -3, 'Training data range')
            #######gap
            # plt.axvline(-7.5, linestyle='--', color='darkorchid')
            # plt.axvline(-2.5, linestyle='--', color='darkorchid')
            # plt.axvline(5, linestyle='--', color='darkorchid')
            # plt.axvline(7.5, linestyle='--', color='darkorchid')
            # plt.text(-7.4, -1, 'Training data range', fontsize='small')
            # plt.text(5.1, -1, 'Training data', fontsize='xx-small')
            # plt.text(5.1, -1.8, 'range', fontsize='xx-small')
            ######g3######
            # plt.axvline(-0.75, linestyle='--', color='darkorchid')
            # plt.axvline(-0.25, linestyle='--', color='darkorchid')
            # plt.axvline(0.25, linestyle='--', color='darkorchid')
            # plt.axvline(0.75, linestyle='--', color='darkorchid')
            # plt.text(-0.75, -1.5, 'Training data range', fontsize='small')
            # plt.text(0.25, 1, 'Training data range', fontsize='small')

            plt.legend(loc='upper center', ncol=3, fontsize='small')
            plt.title('WBBB posterior')
            plt.ylim([dataset.y_min, dataset.y_max])
            # One layout pass after the limits are fixed is enough.
            plt.tight_layout()
            print("--------------------------------------------------------------------------------------------")
            plt.savefig(figures_folder + '/plot_epoch{}.pdf'.format(epoch))
            plt.close(figure)
            print("figure saved")
            print("epoch : {} \t\t test loss \t\t : {}".format(epoch, test_loss.item()))
            print("--------------------------------------------------------------------------------------------")

# print total training time
print("============================================================================================")
end_time = datetime.now(timezone.utc).replace(microsecond=0)
print("Started training at (GMT) : ", start_time)
print("Finished training at (GMT) : ", end_time)
print("Total training time  : ", end_time - start_time)
print("============================================================================================")

def test():
    """Placeholder hook: write the word ``test`` to stdout and return ``None``."""
    message = 'test'
    print(message)
