import os
from datetime import datetime

import torch
import torch.nn as nn
import numpy as np
import gpytorch

from data import uci_woval
from toy import x3_gap_toy, sin_toy
from toy import polynomial_toy, polynomial_gap_toy, polynomial_gap_toy2, g2_toy, g3_toy

from ifbde_std_softplus import *

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt

from utils.utils import default_plotting_new as init_plotting

# from utils.utils import device

from torch.func import functional_call, vmap, grad

from sqrtm import sqrtm

# Pick the compute device once for the whole script: the first CUDA GPU when
# one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print(f"Device set to : {torch.cuda.get_device_name(device)}")
else:
    device = torch.device('cpu')
    print("Device set to : cpu")


# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------

# Adam learning rates: ensemble (BDE) training and GP-prior pre-training.
lr_bnn, lr_gp0 = 0.01, 0.1

prior_coeff = 10    # weight of the W2 regulariser in the total training loss
random_seed = 127

# Epoch periods for checkpointing, CSV logging and test-set evaluation.
save_model_freq, log_model_freq, test_interval = 5000, 3000, 1000
num_sample = 1000   # sample count passed to bde.forward_eval at test time
epochs = 2001

n_step_prior_pretraining = 100  # Adam steps used to fit the GP prior

PRIOR_PRETRAIN = True  # fit the GP prior; `prior`, `likelihood`, `sigma` are used later
ADD_W2 = True          # add the W2 regulariser to the training loss

# ---------------------------------------------------------------------------
# Network architecture
# ---------------------------------------------------------------------------
n_units, n_hidden = 10, 2
hidden_dims = [n_units for _ in range(n_hidden)]
activation_fn = 'tanh'

# Dataset and run name; the name is reused to build output folder/file names.
datasetstring = 'concrete'  # previously also tried: 'protein'
bnn_name_string = f'BDENTK-uci-{activation_fn}-{datasetstring}-{prior_coeff}-'

# Seed every RNG in use so that runs are repeatable.
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
np.random.seed(random_seed)
torch.backends.cudnn.deterministic = True
# Anomaly detection slows training down, but it makes a backward pass that
# produces NaN raise an error; the optimisation step in the training loop
# catches such errors and skips the epoch.
torch.autograd.set_detect_anomaly(True)

print("============================================================================================")
# (Device selection is performed once, right after the imports.)
print("============================================================================================")


# ---------------------------------------------------------------------------
# Output folders: ./results/<bnn_name_string>/sin_tt_0.1_mset/
# ---------------------------------------------------------------------------
results_folder = "./results"
figures_folder = f"{results_folder}/{bnn_name_string}/sin_tt_0.1_mset/"
for _folder in (results_folder, figures_folder):
    if not os.path.exists(_folder):
        os.makedirs(_folder)


# ---------------------------------------------------------------------------
# Logging: one CSV per run, stored inside the figures folder
# ---------------------------------------------------------------------------
run_num = 0
log_f_name = f"{figures_folder}{bnn_name_string}_sin_tt_0.1_mset_{run_num}.csv"

print(f"current logging run number for {bnn_name_string} : ", run_num)
print(f"logging at : {log_f_name}")


print("============================================================================================")
# ---------------------------------------------------------------------------
# Checkpointing (bump run_num_pretrained to avoid overwriting earlier weights)
# ---------------------------------------------------------------------------
run_num_pretrained = 0
checkpoint_path = f"{figures_folder}{bnn_name_string}_{random_seed}_{run_num_pretrained}.pth"
print(f"save checkpoint path : {checkpoint_path}")

print("============================================================================================")
# Echo the main hyper-parameters.
print("learning rate: ", lr_bnn)
print("coefficient of prior regularization: ", prior_coeff)
print("random seed: ", random_seed)


print("============================================================================================")
# ---------------------------------------------------------------------------
# Load the UCI split and move it to `device` as float32 tensors
# ---------------------------------------------------------------------------
train_x, test_x, train_y, test_y = uci_woval(datasetstring, seed=random_seed)

training_num, input_dim = train_x.shape
test_num = test_x.shape[0]
output_dim, is_continuous = 1, True  # single continuous regression target

print("training_num = ", training_num, " input_dim = ", input_dim, " output_dim = ", output_dim)

original_x_train, original_y_train, original_x_test, original_y_test = (
    torch.from_numpy(arr).float().to(device)
    for arr in (train_x, train_y, test_x, test_y)
)

# Train and test rows stacked along dim 0; the training loop draws its
# measurement set from this pool.
original_x_train_test = torch.concat((original_x_train, original_x_test))
original_y_train_test = torch.concat((original_y_train, original_y_test))
num_train_test = original_y_train_test.shape[0]

print("num_train_test = ", num_train_test)


print("============================================================================================")
# ---------------------------------------------------------------------------
# Model: ensemble of 5 networks, optimised jointly with Adam
# ---------------------------------------------------------------------------
bde = ifBDE(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, num_ensemble=5).to(device)
bde_optimizer = torch.optim.Adam(bde.parameters(), lr=lr_bnn)

# Width of the per-sample gradient buffer built in the training loop
# (presumably the parameter count of one ensemble member -- TODO confirm).
w_size = bde.inn_parameter_num
# loss = nn.MSELoss(reduction='mean')

print("============================================================================================")
################################## prior pre-training ##################################

# TODO: pre-train BNN prior

if PRIOR_PRETRAIN:

    # Data the GP prior is fitted to. Alternatives tried earlier: the train
    # split, or a random subset of train+test.
    # NOTE(review): this fits the GP -- including the noise `sigma` that is
    # later used in the test NLL -- on the *test* inputs and targets, which
    # leaks test labels into training and evaluation. Confirm this is intended.
    data_x_for_prior = original_x_test
    data_y_for_prior = original_y_test

    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    prior = ExactGPModel(data_x_for_prior, data_y_for_prior, likelihood, input_dim).to(device)

    prior.train()
    likelihood.train()

    # Use the adam optimizer
    gp_optimizer = torch.optim.Adam(prior.parameters(), lr=lr_gp0)  # Includes GaussianLikelihood parameters

    # "Loss" for GPs - the negative exact marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior)

    # Pre-bind so the `del output` below is safe even when no step runs.
    output = None
    for i in range(n_step_prior_pretraining):
        # Gradients are cleared once per step (the old code zeroed them twice).
        gp_optimizer.zero_grad()
        output = prior(data_x_for_prior)
        loss_gp = -mll(output, data_y_for_prior)
        loss_gp.backward()
        print('Iter %d/%d - Loss: %.3f     noise: %.3f' % (
            i + 1, n_step_prior_pretraining, loss_gp.item(),
            prior.likelihood.noise.item()
        ))
        gp_optimizer.step()

    prior.eval()
    likelihood.eval()
    # Fitted observation noise of the GaussianLikelihood (gpytorch's `noise`
    # is a variance); it is added to the ensemble variance in the test NLL.
    sigma = prior.likelihood.noise.item()

    # Drop the last forward output and the optimiser state.
    del output
    del gp_optimizer


print("============================================================================================")
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------

# Wall-clock start, truncated to whole seconds for the elapsed-time prints.
# NOTE(review): datetime.now() is local time, although the message says GMT.
start_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)

# CSV training log: a header now, one row every `log_model_freq` epochs,
# closed after the training loop.
log_f = open(log_f_name, "w+")
log_f.write('epoch,training_loss\n')

def compute_loss(params, buffers, sample, i):
    """Return the squeezed output of ensemble member ``i`` for one input row.

    Despite the name this is a network prediction, not a loss. The training
    loop differentiates it with ``torch.func.grad`` and maps it with ``vmap``
    over ``sample`` only, producing one parameter-gradient row per input.

    Args:
        params: name -> tensor mapping for ``bde.ensemble_list[i]``'s parameters.
        buffers: name -> tensor mapping for that member's buffers.
        sample: a single input, without a batch dimension.
        i: index of the ensemble member to evaluate.
    """
    member = bde.ensemble_list[i]
    batch_of_one = sample.unsqueeze(0)  # functional_call expects a batch axis
    prediction = functional_call(member, (params, buffers), (batch_of_one,))
    return prediction.squeeze()


def _batch_matrix_pow(matrix: torch.Tensor, p: float) -> torch.Tensor:
    r"""
    Power of a matrix using Eigen Decomposition.
    Args:
        matrix: matrix
        p: power
    Returns:
        Power of a matrix
    """
    tmp = torch.diag_embed(0.01 * torch.rand(matrix.shape[0], matrix.shape[1])).to(device)
    vals, vecs = torch.linalg.eigh(matrix + tmp)
    # vals = torch.view_as_complex(vals.contiguous())
    vals_pow = vals.clamp(min=eps).pow(p)
    vals_pow = vals_pow + 0.0001
    matrix_pow = torch.bmm(vecs, torch.bmm(torch.diag_embed(vals_pow), torch.inverse(vecs)))

    return matrix_pow


def _matrix_pow2(matrix: torch.Tensor, p: float) -> torch.Tensor:
    r"""
    Power of a matrix using Eigen Decomposition.
    Args:
        matrix: matrix
        p: power
    Returns:
        Power of a matrix
    """
    vals, vecs = torch.linalg.eigh(matrix + 0.1 * torch.diag(torch.rand(matrix.shape[0], 1)).to(device))
    vals_pow = vals.clamp(min=eps).pow(p)

    # matrix_pow = torch.matmul(vecs, torch.matmul(torch.diag(vals_pow), torch.inverse(vecs)))
    # inv_matrix_pow = torch.matmul(vecs, torch.matmul(torch.diag(1 / vals_pow + 0.0001), torch.inverse(vecs)))

    matrix_pow = vecs @ torch.diag(vals_pow) @ torch.inverse(vecs)
    inv_matrix_pow = vecs @ torch.diag(1/vals_pow + 0.0001) @ torch.inverse(vecs)

    if torch.isnan(matrix_pow).sum() > 0 or torch.isnan(inv_matrix_pow).sum() > 0:
        print()

    return matrix_pow, inv_matrix_pow

def barycenter(K_posterior_list, K_posterior, iter_num=10):
    r"""
    Fixed-point iteration for the covariance of the 2-Wasserstein barycenter
    of Gaussians, using uniform weights.

    Each step applies

        K <- K^{-1/2} ( (1/m) * sum_i (K^{1/2} K_i K^{1/2})^{1/2} )^2 K^{-1/2}

    where ``m = K_posterior_list.shape[0]``. The matrix powers come from
    ``_matrix_pow2`` and ``_batch_matrix_pow``, which add small random jitter,
    so the result is approximate.

    Args:
        K_posterior_list: stack of m covariance matrices, shape ``(m, n, n)``.
        K_posterior: initial guess, shape ``(n, n)``.
        iter_num: number of fixed-point steps (default 10, the previous
            hard-coded value).
    Returns:
        The symmetrised ``(n, n)`` barycenter covariance.
    """
    num_components = K_posterior_list.shape[0]
    for it in range(iter_num):
        K_half, inv_K_half = _matrix_pow2(K_posterior, 0.5)

        # Broadcast K^{1/2} across the stack without materialising m copies.
        K_half_batch = K_half.unsqueeze(0).expand(num_components, -1, -1)
        roots = _batch_matrix_pow(K_half_batch @ K_posterior_list @ K_half_batch, 0.5)

        # Uniform barycentric weights 1/m. With a plain sum, the fixed point is
        # m**2 times the barycenter (25x for 5 members): if every K_i == K*,
        # any K = c*K* is mapped to m**2 * K*.
        mean_root = roots.mean(dim=0)
        K_posterior = inv_K_half @ (mean_root @ mean_root) @ inv_K_half

        if torch.isnan(K_posterior).any() or torch.isinf(K_posterior).any():
            print("barycenter: non-finite covariance at iteration", it)

    return (K_posterior + K_posterior.t()) / 2


# Mean-squared error. Used for the per-member training likelihood term and,
# through its square root, for the test RMSE.
loss = nn.MSELoss(reduction='mean')

w2_all = []           # W2 regulariser value per epoch (Python floats)
likelihood_loss = []  # ensemble-averaged training MSE per epoch
all_loss = []         # total objective of every epoch that reached the optimiser
num_exception = 0     # count of numerical failures that were skipped

for epoch in range(epochs):

    x_mset_num = 40

    # 1. sample measurement set
    x_mset = sample_measurement_set(original_x_train_test, num_data=x_mset_num)  # original_x_train_add

    # 2. Random training mini-batch, capped by the training-set size. It is
    #    prepended to the measurement set, so rows [:x_num] of x_mset are
    #    exactly this batch.
    x_num = min(100, training_num)
    idx = torch.randperm(training_num)
    original_x_train_tmp = original_x_train[idx[:x_num], :]
    original_y_train_tmp = original_y_train[idx[:x_num]]

    x_mset = torch.cat((original_x_train_tmp, x_mset), 0)
    num_xmset = x_mset.shape[0]

    # 3. likelihood term: MSE on the mini-batch, averaged over ensemble members
    bde.train()

    train_loss = 0
    likehd_loss = 0

    pred_y = bde.forward(original_x_train_tmp)

    for i in range(bde.num_ensemble):
        likehd_loss = likehd_loss + loss(pred_y[i, :].squeeze().view(-1), original_y_train_tmp)

    likehd_loss = likehd_loss / bde.num_ensemble
    likelihood_loss.append(likehd_loss.item())
    train_loss = train_loss + likehd_loss

    # 4. evaluate w2
    if ADD_W2:
        # Module-level clamp floor, also read by the matrix-power helpers.
        eps = 1e-5

        # GP predictive marginal on the measurement set. The GP is frozen and
        # nothing here depends on the ensemble parameters, so no autograd
        # graph is needed. Without no_grad, every backward also accumulated
        # .grad on the GP hyper-parameters.
        # NOTE(review): an ExactGP in eval mode returns its posterior
        # conditioned on the data it was built with, not the GP prior.
        with torch.no_grad():
            prior_marginal = likelihood(prior(x_mset))
            mean_prior = prior_marginal.mean
            K_prior = prior_marginal.covariance_matrix

        # Ensemble outputs on the measurement set; the first x_num rows belong
        # to the training mini-batch.
        pred_y_mset, Wb_mset = bde.forward_wb(x_mset)
        pred_y_train = pred_y_mset[:, :x_num, :]

        grads_x_mset = torch.zeros(num_xmset, w_size).to(device)
        mean_posterior_list = torch.zeros(bde.num_ensemble, num_xmset).to(device)
        K_posterior_list = torch.zeros(bde.num_ensemble, num_xmset, num_xmset).to(device)

        # Per-sample gradient of compute_loss w.r.t. the member parameters.
        # Built once because it does not depend on the member index.
        ft_compute_sample_grad = vmap(grad(compute_loss), in_dims=(None, None, 0, None))

        for i in range(bde.num_ensemble):
            params = {k: v.detach() for k, v in bde.ensemble_list[i].named_parameters()}
            buffers = {k: v.detach() for k, v in bde.ensemble_list[i].named_buffers()}
            ft_per_sample_grads = ft_compute_sample_grad(params, buffers, x_mset, i)

            # Flatten each parameter's per-sample gradient and place them side
            # by side: row j becomes the output's gradient for x_mset[j].
            pre_idx = 0
            for ft_grad in ft_per_sample_grads.values():
                ft_grad = ft_grad.reshape(num_xmset, -1)
                width = ft_grad.shape[1]
                grads_x_mset[:, pre_idx:(pre_idx + width)] = ft_grad
                pre_idx = pre_idx + width

            grads_x_train = grads_x_mset[:x_num, :]

            # Kernel from all gradient columns except the trailing
            # last_layer_wb_size ones (presumably the last layer -- confirm).
            K_prior_i = grads_x_mset[:, :-bde.last_layer_wb_size] @ grads_x_mset[:, :-bde.last_layer_wb_size].t()

            # Gradient inner-product (NTK) blocks.
            Theta_mset_mset = grads_x_mset @ grads_x_mset.t()
            Theta_mset_train = grads_x_mset @ grads_x_train.t()
            Theta_train_train = grads_x_train @ grads_x_train.t()

            try:
                inv_Theta_train_train = torch.pinverse(Theta_train_train)
            except Exception as e:
                print(e)  # log stack trace
                num_exception = num_exception + 1
                continue

            mean_posterior_list[i, :] = Theta_mset_train @ inv_Theta_train_train @ pred_y_train[i, :, :].squeeze()
            tmp = Theta_mset_mset - Theta_mset_train @ inv_Theta_train_train @ Theta_mset_train.t() - K_prior_i
            K_posterior_list[i, :, :] = (tmp + tmp.t()) / 2

        # Ensemble means; the covariance is then refined towards the W2 barycenter.
        mean_posterior = torch.mean(mean_posterior_list, dim=0)
        K_posterior = torch.mean(K_posterior_list, dim=0)
        K_posterior = (K_posterior + K_posterior.t()) / 2

        try:
            K_posterior = barycenter(K_posterior_list, K_posterior)
        except Exception as e:
            print(e)  # log stack trace
            num_exception = num_exception + 1
            continue

        # Surrogate for tr((K_prior K_posterior)^{1/2}).
        # NOTE(review): relu and ** 0.5 act element-wise and trace reads only
        # the diagonal. This is therefore sum_j sqrt(relu((K_prior K_post)_jj)
        # + noise + 1e-4), not the trace of a matrix square root.
        tmp2 = K_prior @ K_posterior
        tmp2 = torch.nn.functional.relu(tmp2) + 1e-4 * torch.randn(tmp2.shape[0], tmp2.shape[0]).abs().to(device)
        tmp2 = (tmp2 + 1e-4) ** 0.5
        hard_trace = torch.trace(tmp2)

        # Gaussian squared-W2 form: ||m1 - m2||^2 + tr(K1 + K2) - 2 * cross trace.
        diff = mean_posterior - mean_prior
        mean_term = torch.sum(diff ** 2)
        trace_term = torch.trace(K_prior + K_posterior)

        wdist_objective = mean_term + trace_term - 2 * hard_trace
        wdist_objective = wdist_objective.abs()
        wdist_objective = torch.mean(wdist_objective ** 0.5) / num_xmset

        w2_all.append(wdist_objective.item())
    else:
        wdist_objective = torch.zeros(1).to(device)
        # Store a float, not a (possibly CUDA) tensor, so np.array(w2_all) works.
        w2_all.append(wdist_objective.item())

    if not torch.isfinite(wdist_objective).all():
        num_exception = num_exception + 1
        continue

    train_loss = train_loss + prior_coeff * wdist_objective
    all_loss.append(train_loss.item())

    print("epoch : {} \t training loss \t : {} \t likelihood loss \t : {} \t wdist loss \t : {} \t".format(
        epoch, all_loss[-1], likelihood_loss[-1], w2_all[-1]),
        datetime.now().replace(microsecond=0) - start_time)

    # 5. optimisation
    try:
        bde_optimizer.zero_grad()
        train_loss.backward()  # obtain p.grad
        bde_optimizer.step()
    except Exception as e:
        print(e)  # log stack trace
        num_exception = num_exception + 1
        continue

    # 8. log in logging file
    if epoch % log_model_freq == 0:
        print("--------------------------------------------------------------------------------------------")
        # Write the scalar value. Formatting the tensor itself wrote
        # "tensor(..., grad_fn=...)" into the CSV.
        log_f.write('{},{}\n'.format(epoch, all_loss[-1]))
        log_f.flush()
        print("log saved")
        print("--------------------------------------------------------------------------------------------")

    # 9. save model weights
    if epoch % save_model_freq == 0:
        print("--------------------------------------------------------------------------------------------")
        print("saving model at : " + checkpoint_path)
        bde.save(checkpoint_path)
        print("model saved")
        print("--------------------------------------------------------------------------------------------")

    # 10. test bnn
    if epoch % test_interval == 0:

        bde.eval()

        samples_pred_test_y = bde.forward_eval(original_x_test, num_sample).squeeze().detach()
        # Reduce over dim 0, assumed to index the num_sample draws -- TODO confirm.
        mean_y_pred = torch.mean(samples_pred_test_y, 0)
        std_y_pred = torch.std(samples_pred_test_y, 0)

        rmse_test_loss = torch.sqrt(loss(mean_y_pred, original_y_test))
        print('rmse_test_loss: ', rmse_test_loss)

        # Mean Gaussian predictive NLL:
        #   0.5 * log(2*pi*var) + 0.5 * (y - mu)^2 / var,  var = std^2 + sigma
        # where sigma is the GP likelihood noise variance. The previous
        # formula added log(std) and log(sigma) separately instead of
        # 0.5 * log(var).
        variance_y_pred = std_y_pred ** 2 + sigma
        nll = (0.5 * torch.log(2 * np.pi * variance_y_pred)
               + 0.5 * (original_y_test - mean_y_pred) ** 2 / variance_y_pred)
        nll_mean = nll.mean()

        print('nll_mean: ', nll_mean)

log_f.close()
print("============================================================================================")


# final test
bde.eval()

samples_pred_test_y = bde.forward_eval(original_x_test, num_sample).squeeze().detach()
# Predictive mean / std across dim 0, assumed to index the num_sample draws.
mean_y_pred = torch.mean(samples_pred_test_y, 0)
std_y_pred = torch.std(samples_pred_test_y, 0)

y_pred = mean_y_pred

rmse_test_loss = torch.sqrt(loss(y_pred, original_y_test))
print('rmse_test_loss: ', rmse_test_loss)

# Mean Gaussian predictive NLL:
#   0.5 * log(2*pi*var) + 0.5 * (y - mu)^2 / var,  var = std^2 + sigma
# where sigma is the GP likelihood noise variance. The previous formula added
# log(std) and log(sigma) separately instead of 0.5 * log(var).
variance_y_pred = std_y_pred ** 2 + sigma
nll = (0.5 * torch.log(2 * np.pi * variance_y_pred)
       + 0.5 * (original_y_test - mean_y_pred) ** 2 / variance_y_pred)
nll_mean = nll.mean()

print('nll_mean: ', nll_mean)
print("============================================================================================")
print("============================================================================================")


# ---------------------------------------------------------------------------
# Training curves, one PDF each. Each figure is closed after saving.
# all_loss only has entries for epochs that reached the optimiser, so its
# length can differ from the other two curves.
# ---------------------------------------------------------------------------
plt.figure()
loss_all = np.array(all_loss)
plt.plot(np.arange(len(loss_all)), loss_all, '-ko', ms=3)
plt.ylabel(r'All loss')
plt.xlabel('Iteration')
plt.tight_layout()
plt.savefig(os.path.join(figures_folder, 'all_loss.pdf'))
plt.close()

plt.figure()
# float() accepts both Python floats and one-element tensors (CPU or CUDA).
w2_all = np.array([float(v) for v in w2_all])
plt.plot(w2_all, 'r-')
plt.title('w2_all')
plt.savefig(os.path.join(figures_folder, 'w2_all.pdf'))
plt.close()

plt.figure()
likelihood_loss = np.array(likelihood_loss)
plt.plot(likelihood_loss, 'r-')
plt.title('likelihood_loss')
plt.savefig(os.path.join(figures_folder, 'likelihood_loss.pdf'))
plt.close()

# print total training time
print("============================================================================================")
end_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)
print("Finished training at (GMT) : ", end_time)
print("Total training time  : ", end_time - start_time)
print("Number of exceptions : ", num_exception)
print("============================================================================================")

