import os
from datetime import datetime

import torch
import torch.nn as nn
import numpy as np
import gpytorch
from gp_prior import ExactGPModel, prior_sample_functions
import math

from bnn import BNN
import torch.nn.functional as F
import argparse

from utils.logging import get_logger

from data import uci_woval

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt


################################### Hyper-parameters ###################################

# Learning rate given to the Adam optimizers of the BNN and of the KRR mean.
lr_bnn = 0.001
# Weight of the prior-distance term (`w2`) in the training loss.
prior_coeff = 10
# Model tag; used to build the log / checkpoint / results paths.
bnn_name_string = 'gwi'
# Dataset name passed to `uci_woval`; also part of the output paths.
uci_dataset_name_string = 'boston' #'protein'
# Seeds torch and numpy below, is forwarded to `uci_woval`, and appears in
# the checkpoint file name.
random_seed = 123
# NOTE(review): not used by the training loop, which iterates over `epochs`.
max_epoch_num = 100
# A checkpoint is written whenever `epoch % save_model_freq == 0`.
save_model_freq = 5000
# A CSV row is written whenever `epoch % log_model_freq == 0`.
log_model_freq = 3000
# Test metrics are computed whenever `epoch % test_interval == 0`.
test_interval = 1000
# Second argument of `bnn.forward_eval` at test time.
num_sample = 1000
# Number of iterations of the main training loop.
epochs = 2001
# Number of Adam steps used to fit the GP prior's hyper-parameters.
n_step_prior_pretraining = 100

# Seed both RNGs used by this script.
torch.manual_seed(random_seed)
np.random.seed(random_seed)


################################### Network Architecture ###################################

# Width of every hidden layer.
n_units = 10
# Number of hidden layers.
n_hidden = 2
# One entry per hidden layer; passed to the BNN constructor.
hidden_dims = [n_units] * n_hidden
# Activation identifier passed to the BNN constructor.
activation_fn = 'tanh'


print("============================================================================================")
################################## set device ##################################

# Prefer the first CUDA device when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")

print("============================================================================================")
###################### logging ######################

#### log files for multiple runs are NOT overwritten

# `exist_ok=True` creates the whole directory chain in one call and avoids the
# check-then-create race of `os.path.exists` followed by `os.makedirs`.
log_dir = "./logs" + '/' + bnn_name_string + '/' + uci_dataset_name_string + '/'
os.makedirs(log_dir, exist_ok=True)

#### get number of log files in log directory
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)

#### create new log file for each run
# `os.path.join` avoids the doubled '/' that plain concatenation produced
# (`log_dir` already ends with a separator).
log_f_name = os.path.join(
    log_dir, bnn_name_string + uci_dataset_name_string + "_" + str(run_num) + ".csv")
# Counting files is not enough when earlier logs were deleted: skip forward
# until the name is really unused so an existing log is never truncated.
while os.path.exists(log_f_name):
    run_num += 1
    log_f_name = os.path.join(
        log_dir, bnn_name_string + uci_dataset_name_string + "_" + str(run_num) + ".csv")

print("current logging run number for " + bnn_name_string + " : ", run_num)
print("logging at : " + log_f_name)


print("============================================================================================")
################### checkpointing ###################

# Bump this to avoid overwriting weights previously saved for the same
# model / dataset / seed combination (it is the last field of the file name).
run_num_pretrained = 0

# `exist_ok=True` creates the whole directory chain in one call and avoids the
# check-then-create race of `os.path.exists` followed by `os.makedirs`.
directory = "./pretrained" + '/' + bnn_name_string + '/' + uci_dataset_name_string + '/'
os.makedirs(directory, exist_ok=True)

checkpoint_path = directory + "{}_{}_{}.pth".format(bnn_name_string, random_seed, run_num_pretrained)
print("save checkpoint path : " + checkpoint_path)

################### savefigures ###################

# Root folder for results, and the per-model / per-dataset sub-folder below it.
results_folder = "./results"
figures_folder = results_folder + '/' + bnn_name_string + '/' + uci_dataset_name_string + '/'
# `exist_ok=True` creates `results_folder` and the nested folder in one call
# and avoids the check-then-create race of `os.path.exists` + `os.makedirs`.
os.makedirs(figures_folder, exist_ok=True)


print("============================================================================================")
############# print all hyperparameters #############

print("learning rate: ", lr_bnn)
print("coefficient of prior regularization: ", prior_coeff)
print("random seed: ", random_seed)
# The training loop runs `range(epochs)`; `max_epoch_num` is not consulted by
# it, so report the value that actually bounds training.
print("max number of epoches: ", epochs)


print("============================================================================================")
############################## load and normalize data ##############################

# `uci_woval` returns the split as (train_x, test_x, train_y, test_y).
# NOTE(review): any normalisation presumably happens inside `uci_woval` -
# none is applied in this script; confirm against the data module.
train_x, test_x, train_y, test_y = uci_woval(uci_dataset_name_string, seed=random_seed)

# Problem sizes derived from the arrays.
training_num, input_dim = train_x.shape
test_num = test_x.shape[0]
# Size of the first random training subset (`Z`): round(sqrt(N)).
m = int(round(training_num ** 0.5))

# Scalar regression target.
output_dim = 1
is_continuous = True

print("training_num = ", training_num, " input_dim = ", input_dim, " output_dim = ", output_dim)

# float32 copies of every split on the selected device.
original_x_train, original_y_train, original_x_test, original_y_test = (
    torch.from_numpy(split).float().to(device)
    for split in (train_x, train_y, test_x, test_y)
)

print('original_y_test.shape: ', original_y_test.shape)


# Mean squared error; used for the mean regulariser and the test metric.
loss = nn.MSELoss(reduction='mean')


print("============================================================================================")
################################## define model ##################################


# Build the network first, then move it to the selected device.
bnn = BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True)
bnn = bnn.to(device)

# A single Adam parameter group holding all BNN parameters.
bnn_optimizer = torch.optim.Adam([dict(params=bnn.parameters(), lr=lr_bnn)])


print("============================================================================================")
################################## prior pre-training ##################################

# Fit the GP prior (kernel hyper-parameters and likelihood noise) by maximising
# the exact marginal log-likelihood.
#
# The fit uses TRAINING data only. Fitting on the test split, as was done
# before, leaks test targets into the kernel and the noise level that the
# model is later evaluated with.
#
# Exact GP inference scales cubically with the number of points, so large
# training sets are randomly sub-sampled. The cap of 1000 is a pragmatic
# choice, not a tuned value.
max_prior_pretraining_points = 1000
if training_num > max_prior_pretraining_points:
    prior_fit_idx = torch.randperm(training_num)[:max_prior_pretraining_points].to(device)
    prior_fit_x = original_x_train[prior_fit_idx]
    prior_fit_y = original_y_train[prior_fit_idx]
else:
    prior_fit_x, prior_fit_y = original_x_train, original_y_train

likelihood = gpytorch.likelihoods.GaussianLikelihood()
prior = ExactGPModel(prior_fit_x, prior_fit_y, likelihood, input_dim).to(device)

prior.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(prior.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior)

for i in range(n_step_prior_pretraining):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = prior(prior_fit_x)
    # Calc loss and backprop gradients
    loss_gp = -mll(output, prior_fit_y)
    loss_gp.backward()
    print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
        i + 1, n_step_prior_pretraining, loss_gp.item(),
        prior.covar_module.base_kernel.lengthscale.item(),
        prior.likelihood.noise.item()
    ))

    optimizer.step()

prior.eval()
likelihood.eval()

# The prior mean over the training inputs is taken to be zero. The previous
# code evaluated the GP at the training inputs and multiplied the result by 0;
# building the zeros directly gives the same values without that evaluation.
m_p_x = torch.zeros(training_num, device=device)
# Fitted kernel, reused for every Gram matrix below.
k = prior.covar_module
# Fitted observation-noise level of the Gaussian likelihood, as a Python float.
sigma = prior.likelihood.noise.item()


print("============================================================================================")
################################## start training ##################################

# track total training time
# NOTE(review): `datetime.now()` returns naive local time, although the
# printed label says "GMT".
start_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)


# logging file
# Opened here and written to by the training loop; "w+" truncates any
# existing file of the same name.
log_f = open(log_f_name, "w+")
log_f.write('epoch,training_loss\n')
# Step decay for the BNN optimizer: multiply the learning rate by 0.9 every
# 5000 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(bnn_optimizer, 5000, gamma=0.9, last_epoch=-1)

# Two independent random subsets of the training inputs: `m` rows for Z and
# (up to) 100 rows for Z_prime.
# NOTE(review): presumably Z are the inducing points and Z_prime the subset
# used for the eigenvalue-based trace estimate in the loss - confirm.
z_mask = torch.randperm(training_num)[:m]
z_prime_mask = torch.randperm(training_num)[:100]
Z = original_x_train[z_mask, :].to(device)
Z_prime = original_x_train[z_prime_mask, :].to(device)
eye = torch.eye(m).to(device)

# Kernel matrices between the training inputs and the two subsets; computed
# without autograd tracking.
with torch.no_grad():
    kxx = k(original_x_train).evaluate()
    print('kxx.shape: ', kxx.shape)
    kzx = k(Z, original_x_train).evaluate()
    print('kzx.shape: ', kzx.shape)
    kzpx = k(Z_prime, original_x_train).evaluate()
    kzzp = k(Z, Z_prime).evaluate()
    kzz = k(Z).evaluate()
    print('kzz.shape: ', kzz.shape)
    # Cholesky factor of  kzz + kzx kzx^T / sigma + 1e-2 * I.
    chol_L = torch.linalg.cholesky(kzz + kzx @ kzx.t() / sigma + 1e-2 * eye)
    # inv_L = torch.inverse(kzz + kzx @ kzx.t() / sigma + 1e-2 * eye)
    # L is the Cholesky factor of the INVERSE of that matrix.
    L = torch.linalg.cholesky(torch.cholesky_inverse(chol_L))
    # L = torch.linalg.cholesky(inv_L)
    # Wrapped as a Parameter, but it is not handed to any optimizer created in
    # this file, so its value stays at this initialisation.
    L = torch.nn.Parameter(L)

# Everything below is computed outside `no_grad` and depends on the Parameter
# `L`, so these tensors carry autograd history back to `L`.

# T_mat = kxz L L^T kzx
t = L.t() @ kzx
T_mat = t.t() @ t

# T_mat2 = kz'z L L^T kzx
t2 = L.t() @ kzzp
T_mat2 = t2.t() @ t

# sol = (kzz + 1e-2 * I)^{-1} kzx, via a Cholesky solve.
chol_z = torch.linalg.cholesky(kzz + 1e-2 * eye)
sol = torch.cholesky_solve(kzx, chol_z)

# rxx = kxx - kxz (kzz + 1e-2 I)^{-1} kzx + T_mat
# NOTE(review): this has the form of an SVGP-style posterior covariance on
# the training inputs - confirm against the reference derivation.
rxx = kxx - kzx.t() @ sol + T_mat

# Same construction between Z_prime (rows) and the training inputs (columns).
rzpx = kzpx - kzzp.t() @ sol + T_mat2
###########
class KRR_mean(torch.nn.Module):
    """Kernel-weighted mean function with one learnable weight per inducing row.

    Args:
        m: number of rows of the weight vector ``beta`` (shape ``(m, 1)``).
        k: kernel callable; ``k(a, b).evaluate()`` must return a dense matrix.

    ``forward`` reads the module-level tensors ``Z`` and ``m_p_x``.
    """

    def __init__(self, m, k):
        super().__init__()
        self.k = k
        # Standard-normal initialisation, trainable.
        self.beta = nn.Parameter(torch.randn(m, 1))

    def forward(self, X):
        """Return ``m_p_x + sum_i beta_i * k(Z_i, X)`` (summed over rows of ``Z``)."""
        weighted_rows = self.beta * self.k(Z, X).evaluate()
        return m_p_x + weighted_rows.sum(dim=0)

# Alternative "GWI: SVGP" mean (kernel regression on Z). It is constructed but
# not trained below; to use it, compute `m_q_x = krr(original_x_train)` and
# step `krr_optimizer` instead of the BNN optimizer.
krr = KRR_mean(m, k).to(device)
krr_optimizer = torch.optim.Adam([{'params': krr.parameters(), 'lr': lr_bnn}])

# ----------------------------------------------------------------------------
# Epoch-invariant quantities.
#
# `sigma` is the Gaussian likelihood's noise VARIANCE (it is divided into
# kzx kzx^T and added to the predictive variance elsewhere in this script), so
# the per-point Gaussian log-normaliser is 0.5 * log(2 * pi * sigma).
gauss_const = 0.5 * math.log(2. * math.pi * sigma)

# None of the kernel-side terms depends on the BNN weights, and `L` is not
# given to any optimizer, so they are computed once and without autograd. This
# also removes the need for `retain_graph=True` in the backward pass.
with torch.no_grad():
    # Prior mean as a plain constant (it may carry autograd history).
    prior_mean = m_p_x.detach()

    r_trace = torch.trace(rxx)
    k_trace = torch.trace(kxx)

    # Cross term of the kernel distance: sum of square roots of the positive
    # eigenvalues of r(Z', X) k(X, Z'). The matrix is shifted by a multiple of
    # the identity before the eigen-decomposition and the shift is removed
    # afterwards. The identity is sized from the data instead of assuming
    # exactly 100 rows in Z_prime.
    n_zp = kzpx.shape[0]
    eig_shift = 100.
    rk_hat = rzpx @ kzpx.t()
    eigs = torch.linalg.eigvals(rk_hat + eig_shift * torch.eye(n_zp, device=device))
    eigs = eigs.abs() - eig_shift
    eigs = eigs[eigs > 0]
    hard_trace = torch.sum(eigs ** 0.5)

    # Kernel part of w2. The normaliser sqrt(n_zp) * sqrt(N) equals the former
    # hard-coded 10 * sqrt(N) when Z_prime has 100 rows.
    kernel_w2 = (k_trace / training_num
                 + r_trace / training_num
                 - 2 / (n_zp ** 0.5 * training_num ** 0.5) * hard_trace)

try:
    for epoch in range(epochs):
        ############## define m_q ################
        # GWI: DNN-SVGP - variational mean = prior mean + network output.
        m_q_x = prior_mean + bnn.forward(original_x_train).squeeze()

        ############## define loss function #############
        # Expected negative Gaussian log-likelihood:
        #   N * 0.5 * log(2 pi sigma) + (||y - m_q||^2 + tr(r)) / (2 sigma)
        # The division by (2 * sigma) is parenthesised on purpose:
        # `/ 2. * sigma` would multiply by the noise variance.
        sq_err = torch.sum((original_y_train - m_q_x) ** 2)
        expected_nll = training_num * gauss_const + (sq_err + r_trace) / (2. * sigma)

        # Distance to the prior: mean part (MSE between means) + kernel part.
        reg = loss(m_q_x, prior_mean)
        w2 = reg + kernel_w2

        train_loss = expected_nll + prior_coeff * w2

        # optimisation
        bnn_optimizer.zero_grad()
        train_loss.backward()
        bnn_optimizer.step()
        # Advance the StepLR schedule attached to the BNN optimizer.
        scheduler.step()

        # print
        print("epoch : {} \t\t training loss \t\t : {} \t\t distance_prior \t\t : {}".format(
                  epoch, train_loss.item(), w2.item()),
              datetime.now().replace(microsecond=0) - start_time)

        # log in logging file
        if epoch % log_model_freq == 0:
            print("--------------------------------------------------------------------------------------------")
            log_f.write('{},{}\n'.format(epoch, train_loss.item()))
            log_f.flush()
            print("log saved")
            print("--------------------------------------------------------------------------------------------")

        # save model weights
        if epoch % save_model_freq == 0:
            print("--------------------------------------------------------------------------------------------")
            print("saving model at : " + checkpoint_path)
            bnn.save(checkpoint_path)
            print("model saved")
            print("--------------------------------------------------------------------------------------------")

        if epoch % test_interval == 0:
            bnn.eval()

            # Evaluation needs no gradients; skip building a graph for the
            # `num_sample` forward passes.
            with torch.no_grad():
                samples_pred_test_y = bnn.forward_eval(original_x_test, num_sample).squeeze().detach()
                # NOTE(review): the samples are averaged over dim 0 below, so
                # this compares every sample with the targets (broadcast)
                # rather than the predictive mean - confirm this is the
                # intended test metric.
                mse_test_loss = loss(samples_pred_test_y, original_y_test)

            print('mse_test_loss: ', mse_test_loss)

            bnn.train()

            # Gaussian predictive per test point: mean of the samples, and
            # sample variance plus observation-noise variance.
            mean_y_pred = torch.mean(samples_pred_test_y, 0)
            variance_y_pred = torch.var(samples_pred_test_y, 0) + sigma

            # NLL of N(mean, variance):
            #   0.5 * [log(2 pi) + log(var) + (y - mean)^2 / var]
            nll = 0.5 * torch.sum(
                math.log(2. * math.pi)
                + torch.log(variance_y_pred)
                + (original_y_test - mean_y_pred) ** 2 / variance_y_pred
            )
            nll_mean = nll / test_num

            print('nll_mean: ', nll_mean)
finally:
    # Close the log even when training is interrupted or raises.
    log_f.close()

# Final report: wall-clock bounds, elapsed time and the last test MSE.
print("============================================================================================")
end_time = datetime.now().replace(microsecond=0)
for _label, _value in (
    ("Started training at (GMT) : ", start_time),
    ("Finished training at (GMT) : ", end_time),
    ("Total training time  : ", end_time - start_time),
    ("Test loss  : ", mse_test_loss),
):
    print(_label, _value)
print("============================================================================================")


def test():
    """Print the literal string ``test``; placeholder with no return value."""
    message = 'test'
    print(message)










