import os
import sys
from datetime import datetime

import torch
import torch.nn as nn
import numpy as np

from bnn import BNN
import gpytorch
from gpytorch.likelihoods import DirichletClassificationLikelihood
from gp_prior import ExactGPModel, prior_sample_functions, DirichletGPModel

import argparse

from utils.logging import get_logger

from data import uci_woval
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets
from sklearn.metrics import roc_auc_score
from sklearn import metrics
import pandas as pd

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt


################################### Hyper-parameters ###################################

# --- run identification / reproducibility ---
bnn_name_string = 'image_ifbnn'   # names the ./logs, ./pretrained and ./results sub-folders
random_seed = 123                 # seeds torch and numpy below; also part of the checkpoint file name

# --- optimisation ---
lr_bnn = 0.01                     # Adam learning rate for `bnn`
lr_optbnn = 0.01                  # RMSprop learning rate for `opt_bnn`
prior_coeff = 10                  # weight of the `w2` term in the training loss
f_coeff = 10                      # unused by the active code
epochs = 601                      # number of epochs the training loop runs
max_epoch_num = 100               # not used by the training loop (which runs for `epochs`)
n_step_prior_pretraining = 500    # Adam steps used to fit the GP prior hyper-parameters

# --- book-keeping (frequencies are in epochs) ---
save_model_freq = 5000
log_model_freq = 3000
test_interval = 2000              # unused by the active code
num_sample = 1000                 # unused by the active code

torch.manual_seed(random_seed)
np.random.seed(random_seed)


################################### Network Architecture ###################################

n_units = 800                                      # width of every hidden layer
n_hidden = 2                                       # number of hidden layers
hidden_dims = [n_units for _ in range(n_hidden)]   # one entry per hidden layer
activation_fn = 'tanh'

print("============================================================================================")
################################## set device ##################################

# Use the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()  # free cached, unoccupied GPU memory before allocating models/data
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")

print("============================================================================================")
###################### logging ######################

#### log files for multiple runs are NOT overwritten

# makedirs(exist_ok=True) creates ./logs and the per-model sub-folder in one call and, unlike an
# exists()-then-create check, cannot fail if another run creates the folder concurrently.
log_dir = "./logs/" + bnn_name_string + '/'
os.makedirs(log_dir, exist_ok=True)

#### get number of log files in log directory
current_num_files = next(os.walk(log_dir))[2]  # names of the regular files already in log_dir
run_num = len(current_num_files)

#### create new log file for each run (log_dir already ends with '/')
log_f_name = log_dir + bnn_name_string + "_" + str(run_num) + ".csv"

print("current logging run number for " + bnn_name_string + " : ", run_num)
print("logging at : " + log_f_name)


print("============================================================================================")
################### checkpointing ###################

run_num_pretrained = 0  #### change this to prevent overwriting weights in this model's checkpoint folder

# makedirs(exist_ok=True) creates parent folders too and is safe if the folder already exists.
directory = "./pretrained/" + bnn_name_string + '/'
os.makedirs(directory, exist_ok=True)

checkpoint_path = directory + "{}_{}_{}.pth".format(bnn_name_string, random_seed, run_num_pretrained)
print("save checkpoint path : " + checkpoint_path)

################### savefigures ###################

results_folder = "./results"
figures_folder = results_folder + '/' + bnn_name_string + '/'
os.makedirs(figures_folder, exist_ok=True)


print("============================================================================================")
############# print all hyperparameters #############

print("learning rate: ", lr_bnn)
print("coefficient of prior regularization: ", prior_coeff)
print("random seed: ", random_seed)
# The training loop iterates `range(epochs)`; `max_epoch_num` is not used by it, so report `epochs`.
print("number of epochs: ", epochs)


print("============================================================================================")
############################## load and normalize data ##############################

# MNIST / FashionMNIST: ToTensor gives pixels in [0, 1], then rescale by 255 / 126.
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * 255. / 126.),  # divide as in paper
        ])

# CIFAR10 training transform: random flip / rotation augmentation, then map each channel to [-1, 1].
transform2 = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

# Evaluation transform: the same normalisation as `transform2` but WITHOUT random augmentation.
# Used for the test split (test error + GP pre-training), validation and both OOD sets, so reported
# metrics are computed on unperturbed images and are reproducible.
transform_eval = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

train_data = datasets.MNIST(
            root='data',
            train=True,
            download=False,
            transform=transform)

test_data = datasets.MNIST(
            root='data',
            train=False,
            download=False,
            transform=transform)

# NOTE(review): the FashionMNIST train split is read from 'data2' but the test split from 'data';
# confirm both roots are intended.
train_data_fm = datasets.FashionMNIST(
            root='data2',
            train=True,
            download=False,
            transform=transform)

test_data_fm = datasets.FashionMNIST(
            root='data',
            train=False,
            download=False,
            transform=transform)

# Augmented CIFAR10 training images: only used by train_loader.
train_data_cf = datasets.CIFAR10(
            root='data',
            train=True,
            download=False,
            transform=transform2)

# The same CIFAR10 training images without augmentation, for validation and in-distribution OOD scoring.
train_data_cf_eval = datasets.CIFAR10(
            root='data',
            train=True,
            download=False,
            transform=transform_eval)

test_data_cf = datasets.CIFAR10(
            root='data',
            train=False,
            download=False,
            transform=transform_eval)

# SVHN serves as the out-of-distribution set in the OOD evaluation.
train_data_sv = datasets.SVHN(
            root='data',
            split='train',
            download=False,
            transform=transform_eval)

training_num = len(train_data_cf)
test_num = len(test_data_cf)
input_dim = 32 * 32 * 3  # flattened CIFAR10 / SVHN image (MNIST would be 28 * 28)
output_dim = 10
is_continuous = True
batch_size = 125
eval_batch_size = 1000
m = int(round(batch_size ** 0.5))  # number of inducing points drawn per mini-batch in train()

# Index split of the CIFAR10 training set. E.g. with 50 000 images: validation = [0, 10000),
# in-distribution OOD set = [10000, 20000), training subset = [40000, 50000).
valid_size = 1 / 5
indices = list(range(training_num))
split = int(valid_size * training_num)
train_idx, valid_idx = indices[split:], indices[:split]
oob_in_idx = train_idx[:10000]
train_idx = train_idx[30000:40000]

# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
oob_in_sampler = SubsetRandomSampler(oob_in_idx)

# miniset of training data and test data
# NOTE(review): these mini sets are sized from the CIFAR10 splits but the mini loaders below read
# the MNIST datasets; they are not used elsewhere in this script.
mini_train_size = 1 / 12
mini_train_index = int(mini_train_size * training_num)
mini_train_idx = indices[:mini_train_index]
mini_train_sampler = SubsetRandomSampler(mini_train_idx)

indices_test = list(range(test_num))
mini_test_size = 1 / 2
mini_test_index = int(mini_test_size * test_num)
mini_test_idx = indices_test[:mini_test_index]
mini_test_sampler = SubsetRandomSampler(mini_test_idx)

######### define data loader ############

train_loader = torch.utils.data.DataLoader(
        train_data_cf,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=1)
valid_loader = torch.utils.data.DataLoader(
    train_data_cf_eval,
    batch_size=batch_size,
    sampler=valid_sampler,
    num_workers=1)
test_loader = torch.utils.data.DataLoader(
        test_data_cf,
        batch_size=eval_batch_size,
        num_workers=1)
# One batch holding the whole CIFAR10 test split; used to pre-train the GP prior.
# NOTE(review): fitting the prior on the test split leaks test data into training — confirm intended.
pre_loader = torch.utils.data.DataLoader(
        test_data_cf,
        batch_size=10000,
        num_workers=1)

# The OOD loaders share oob_in_sampler, i.e. they draw the same index set from SVHN and CIFAR10.
oob_out_loader = torch.utils.data.DataLoader(
    train_data_sv,
    batch_size=batch_size,
    sampler=oob_in_sampler,
    num_workers=1)
oob_in_loader = torch.utils.data.DataLoader(
    train_data_cf_eval,
    batch_size=batch_size,
    sampler=oob_in_sampler,
    num_workers=1)

mini_train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    sampler=mini_train_sampler,
    num_workers=1)

mini_test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=eval_batch_size,
    sampler=mini_test_sampler,
    num_workers=1)

print("training_num = ", training_num, " input_dim = ", input_dim, " output_dim = ", output_dim)



loss = nn.MSELoss()  # mean-reduced squared error (the default); train() uses it for the `reg` term


print("============================================================================================")
################################## define model ##################################

# Two identically configured BNNs, constructed in this order: `bnn` is the network trained below;
# `opt_bnn` and its RMSprop optimiser are not used by the active code of this script.
bnn, opt_bnn = (
    BNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
    for _ in range(2)
)

bnn_optimizer = torch.optim.Adam([dict(params=bnn.parameters(), lr=lr_bnn)])
opt_bnn_optimizer = torch.optim.RMSprop([dict(params=opt_bnn.parameters(), lr=lr_optbnn)])

print("============================================================================================")
################################## pretrain GP ##################################

# Fit the Dirichlet-GP prior's hyper-parameters on a random subset of one pre_loader batch.
# NOTE(review): pre_loader reads the CIFAR10 test split, which is also used to report test error.
n_gp_points = 1000  # size of the random subset used for exact-GP hyper-parameter fitting

train_x, train_y = next(iter(pre_loader))
train_x, train_y = train_x.to(device), train_y.to(device)
train_x = train_x.reshape(train_x.shape[0], -1)  # flatten each image to a vector
mask = torch.randperm(train_x.shape[0])[:n_gp_points]
train_x = train_x[mask, :]
train_y = train_y[mask]

likelihood = DirichletClassificationLikelihood(train_y, learn_additional_noise=True)
prior = DirichletGPModel(train_x, likelihood.transformed_targets, likelihood, num_classes=likelihood.num_classes).to(device)
print('likelihood.num_classes: ', likelihood.num_classes)
prior.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(prior.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior)

for i in range(n_step_prior_pretraining):
    optimizer.zero_grad()
    output = prior(train_x)
    # negative marginal log-likelihood, summed over the class outputs
    loss_gp = -mll(output, likelihood.transformed_targets).sum()
    loss_gp.backward()
    print('Iter %d/%d - Loss: %.3f lengthscale: %.3f  noise: %.3f' % (
        i + 1, n_step_prior_pretraining, loss_gp.item(),
        prior.covar_module.base_kernel.lengthscale.mean().item(),
        prior.likelihood.second_noise_covar.noise.mean().item()
    ))
    optimizer.step()

prior.eval()
likelihood.eval()

# Covariance module of the fitted prior; train() evaluates it on mini-batches
# (presumably batched over the num_classes outputs — train() transposes its last two dims).
k = prior.covar_module
sigma = 0.03  # observation-noise parameter used by train()

print("============================================================================================")
######### def train and evalution #############################

def train(loader):
    """Run one training epoch of the global `bnn` over `loader`.

    Per mini-batch:
      * draw up to ``m`` inducing inputs Z and up to 30 inputs Z' from the flattened batch X;
      * evaluate the pretrained GP covariance ``k`` on X, Z, Z' and build, in closed form, the
        residual covariances r(X, X) and r(Z', X);
      * ``data_fit`` combines the categorical NLL of the BNN outputs with trace(r(X, X));
      * ``w2`` combines an MSE between the BNN outputs and the (zeroed) prior mean, the traces of
        k(X, X) and r(X, X), and a spectral ("hard") trace term;
      * take one ``bnn_optimizer`` step on ``data_fit + prior_coeff * w2``.

    Args:
        loader: iterable of ``(images, labels)`` mini-batches.

    Returns:
        float: the training loss averaged over the batches of ``loader``.
    """
    loss_sum = 0.
    n_batches = len(loader)
    n_zp_max = 30    # size of the Z' sample used by the spectral trace term
    eig_shift = 30.  # diagonal shift added before eigvals and subtracted from the moduli afterwards

    for data, target in loader:
        data, target = data.to(device), target.to(device)
        data = data.reshape(data.shape[0], -1)  # flatten each image to a vector
        n = data.shape[0]                       # actual batch size (last batch may be smaller)
        n_z = min(m, n)
        n_zp = min(n_zp_max, n)

        z_mask = torch.randperm(n)[:n_z]
        z_prime_mask = torch.randperm(n)[:n_zp]
        Z = data[z_mask, :]
        Z_prime = data[z_prime_mask, :]
        eye = torch.eye(n_z, device=device)

        # NOTE(review): everything in this block is computed without gradient, so r_trace, k_trace
        # and hard_trace change the reported loss but not the BNN update; only `vec` and `reg`
        # drive training. Confirm this is intended.
        with torch.no_grad():
            kxx = k(data).evaluate()
            kzx = k(Z, data).evaluate()
            kzpx = k(Z_prime, data).evaluate()
            kzzp = k(Z, Z_prime).evaluate()
            kzz = k(Z).evaluate()
            kzx_t = kzx.transpose(-2, -1)

            # L L^T = (k(Z,Z) + k(Z,X) k(X,Z) / sigma + 1e-2 I)^{-1}
            chol_L = torch.linalg.cholesky(kzz + kzx @ kzx_t / sigma + 1e-2 * eye)
            L = torch.linalg.cholesky(torch.cholesky_inverse(chol_L))
            L_t = L.transpose(-2, -1)

            t = L_t @ kzx
            T_mat = t.transpose(-2, -1) @ t      # k(X,Z) L L^T k(Z,X)
            t2 = L_t @ kzzp
            T_mat2 = t2.transpose(-2, -1) @ t    # k(Z',Z) L L^T k(Z,X)

            chol_z = torch.linalg.cholesky(kzz + 1e-2 * eye)
            sol = torch.cholesky_solve(kzx, chol_z)  # (k(Z,Z) + 1e-2 I)^{-1} k(Z,X)

            rxx = kxx - kzx_t @ sol + T_mat
            rzpx = kzpx - kzzp.transpose(-2, -1) @ sol + T_mat2

        f_q = bnn.forward(data).squeeze()
        prior_marginal = prior(data)
        m_p_x = (0. * prior_marginal.mean).t()  # prior mean is deliberately zeroed out
        m_q_x = m_p_x + f_q

        const = 0.5 * np.log(np.pi * 2) + np.log(sigma)
        # log_softmax is the stable form of torch.log(F.softmax(...)): a class probability that
        # underflows to 0 no longer produces log(0) = -inf and an infinite / NaN loss.
        vec = F.nll_loss(F.log_softmax(m_q_x, dim=1), target, reduction='sum')
        r_trace = rxx.diagonal(offset=0, dim1=-2, dim2=-1).sum()
        k_trace = kxx.diagonal(offset=0, dim1=-2, dim2=-1).sum()

        # NOTE(review): `(vec + r_trace) / 2. * sigma` parses as ((vec + r_trace) / 2) * sigma,
        # while a Gaussian log-density divides by 2 * sigma (or 2 * sigma ** 2). Kept unchanged;
        # confirm the intended scaling.
        data_fit = n * const + (vec + r_trace) / 2. * sigma
        reg = loss(m_q_x, m_p_x)

        # spectral ("hard") trace: sum of square roots of the positive shifted eigenvalue moduli
        # of r(Z', X) k(X, Z')
        big_eye = eig_shift * torch.eye(n_zp, device=device)
        rk_hat = rzpx @ kzpx.transpose(-2, -1)
        eigs = torch.linalg.eigvals(rk_hat + big_eye).abs() - big_eye.diag()
        hard_trace = torch.sum(eigs[eigs > 0] ** 0.5)

        w2 = reg + k_trace / n + r_trace / n - 2 / ((n_zp * n) ** 0.5) * hard_trace
        train_loss = data_fit + prior_coeff * w2

        bnn_optimizer.zero_grad()
        # retain_graph kept from the original; presumably `prior(data)` may reuse GP prediction
        # caches across iterations — confirm before removing.
        train_loss.backward(retain_graph=True)
        bnn_optimizer.step()
        # .item(): accumulating the tensor itself kept every batch's retained graph alive and made
        # this function return a tensor carrying grad history.
        loss_sum += train_loss.item() / n_batches
    return loss_sum

def evaluate(model, loader, samples=1):
    """Return the mean number of correct predictions per batch of `loader`.

    The class scores of each input are the sum of ``max(samples, 1)`` softmax outputs of
    ``model.forward``. Summing several passes only matters if the forward pass is stochastic.
    The arg-max of the summed scores is the prediction.

    Args:
        model: network whose ``forward`` maps flattened inputs to class logits.
        loader: iterable of ``(images, labels)`` batches.
        samples: number of forward passes averaged per batch.

    Returns:
        float: total correct predictions divided by the number of batches (divide by the batch
        size to obtain an accuracy).
    """
    acc_sum = 0
    # Evaluation only: skip building autograd graphs for every forward pass.
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            data = data.reshape(data.shape[0], -1)  # flatten, as train() and the OOD code do
            probs = F.softmax(model.forward(data), dim=1)
            for _ in range(samples - 1):
                probs = probs + F.softmax(model.forward(data), dim=1)

            predict = probs.argmax(dim=1)
            acc_sum += predict.eq(target).sum().item()
    return acc_sum / len(loader)


################################## start training ##################################

# track total training time
# NOTE(review): datetime.now() is local time, so the "(GMT)" label is only accurate on a UTC host.
start_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)

# logging file: CSV with one row per logged epoch (closed at the end of the script)
log_f = open(log_f_name, "w+")
log_f.write('epoch,training_loss\n')

train_loss_all = []
prior_distance_all = []
wdist_all = []
likelihood_loss_all = []
w1_bnn_optbnn = []
for epoch in range(epochs):
    # float() so the printed / logged loss is a plain number rather than a tensor repr
    train_loss = float(train(train_loader))
    # evaluate() returns the mean number of correct predictions per test batch
    test_acc = evaluate(bnn, test_loader, samples=10)
    test_err = round((1 - test_acc / eval_batch_size), 4)

    print("epoch : {} \t\t training loss \t\t : {} \t\t test_error \t\t : {}".format(epoch, train_loss,
                                                                                         test_err),
          datetime.now().replace(microsecond=0) - start_time)

    # When the frequencies exceed `epochs`, the modulo tests only fire at epoch 0; always also log
    # and checkpoint the final epoch so the trained weights are not lost.
    is_last_epoch = epoch == epochs - 1

    # log in logging file
    if epoch % log_model_freq == 0 or is_last_epoch:
        print("--------------------------------------------------------------------------------------------")
        log_f.write('{},{}\n'.format(epoch, train_loss))
        log_f.flush()
        print("log saved")
        print("--------------------------------------------------------------------------------------------")

    # save model weights
    if epoch % save_model_freq == 0 or is_last_epoch:
        print("--------------------------------------------------------------------------------------------")
        print("saving model at : " + checkpoint_path)
        bnn.save(checkpoint_path)
        print("model saved")
        print("--------------------------------------------------------------------------------------------")

######## OOD_test ###############
# Score every input by the predictive entropy of `bnn`, normalised by log(output_dim) (the entropy
# of a uniform distribution over output_dim classes). In-distribution (CIFAR10) inputs are labelled
# 0 and out-of-distribution (SVHN) inputs 1, so a higher AUC means OOD inputs get higher entropy.
log_upper_bound = np.log(output_dim)


def _predictive_entropies(loader):
    """Return the predictive entropy of `bnn` for every input of `loader` as one 1-D tensor."""
    entropies = []
    with torch.no_grad():
        for data, _ in loader:
            data = data.to(device).reshape(data.shape[0], -1)  # flatten each image to a vector
            log_p = F.log_softmax(bnn.forward(data), dim=1)
            # -sum p log p computed from log-probabilities: a probability that underflows to 0
            # contributes 0 * finite = 0, whereas torch.log(softmax) gives 0 * -inf = NaN.
            e = -(log_p.exp() * log_p).sum(dim=1)
            entropies.append(torch.nan_to_num(e, nan=0.0))  # defensive, e.g. against NaN logits
    return torch.cat(entropies, dim=0)


entropies_in_sample = _predictive_entropies(oob_in_loader)
true_lables_in_sample = torch.zeros_like(entropies_in_sample)

entropies_out_sample = _predictive_entropies(oob_out_loader)
true_lables_out_sample = torch.ones_like(entropies_out_sample)

in_sample_preds = entropies_in_sample / log_upper_bound
out_sample_preds = entropies_out_sample / log_upper_bound
preds = torch.cat([in_sample_preds, out_sample_preds], dim=0).cpu().numpy()
truths = torch.cat([true_lables_in_sample, true_lables_out_sample], dim=0).cpu().numpy()
AUC = roc_auc_score(truths, preds)
print('auc: ', AUC)

# Optional ROC export (disabled):
# fpr, tpr, thresholds = metrics.roc_curve(truths, preds)
# fpr_gwi_fm = pd.DataFrame.from_dict(fpr)
# fpr_gwi_fm.to_csv(results_folder + '/' + bnn_name_string + '/' + 'fpr_gwi_fm.csv')
# tpr_gwi_fm = pd.DataFrame.from_dict(tpr)
# tpr_gwi_fm.to_csv(results_folder + '/' + bnn_name_string + '/' + 'tpr_gwi_fm.csv')
#
# roc_auc = metrics.auc(fpr, tpr)
# display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name='gwi')
# display.plot()
# plt.savefig(figures_folder + '/oob_gwi_fm.pdf')

# The training log is complete; release the file handle.
log_f.close()

# Summary of run timing and the last epoch's test error.
# NOTE(review): datetime.now() is local time, so the "(GMT)" labels are only accurate on a UTC host.
print("============================================================================================")
end_time = datetime.now().replace(microsecond=0)
print("\n".join("{} {}".format(label, value) for label, value in (
    ("Started training at (GMT) : ", start_time),
    ("Finished training at (GMT) : ", end_time),
    ("Total training time  : ", end_time - start_time),
    ("Test err  : ", test_err),
)))
print("============================================================================================")
