import os
from datetime import datetime
import sys
import math
import pandas as pd
from sklearn import preprocessing
from torch.autograd import Variable

import torch
import torch.nn as nn
import numpy as np
import gpytorch
from gp_prior import ExactGPModel, prior_sample_functions, prior_sample_functions2


from bnn import BNN, OPTBNN
import torch.nn.functional as F
import argparse

from utils.logging import get_logger

from data import uci_woval

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from utils.utils import default_plotting_new as init_plotting


################################### Hyper-parameters ###################################

# --- optimisation ---
lr_bnn = 0.01                   # Adam learning rate of the agent's BNN
lr_optbnn = 0.01                # RMSprop learning rate of the learnable prior BNN (opt_bnn)
prior_coeff = 1                 # weight of the prior-distance term in the training loss
f_coeff = 10                    # weight of the function-space distance term (ifbnn agent)

# --- run identification / reproducibility ---
bnn_name_string = 'UAI_toy'
random_seed = 123

# --- schedule and bookkeeping ---
epochs = 10001                  # iterations of the main loop (one mushroom each)
n_step_prior_pretraining = 100  # Adam steps used to fit the GP prior
max_epoch_num = 100             # not read by the visible training loop
save_model_freq = 5000
log_model_freq = 3000
test_interval = 1000
num_sample = 1000

# Seed the global NumPy and PyTorch generators.
np.random.seed(random_seed)
torch.manual_seed(random_seed)


################################### Network Architecture ###################################

# Hidden layer widths: n_hidden layers of n_units each, with tanh activations.
n_units = 100
n_hidden = 2
hidden_dims = [n_units for _ in range(n_hidden)]
activation_fn = 'tanh'
print("============================================================================================")
################################## set device ##################################

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")
print("============================================================================================")
###################### logging ######################

#### log files for multiple runs are NOT overwritten

# Keep the trailing '/' so file names can be appended directly below.
log_dir = "./logs" + '/' + bnn_name_string + '/' + 'FWBI' + '/'
# exist_ok=True avoids the exists()/makedirs() race and also creates "./logs".
os.makedirs(log_dir, exist_ok=True)

#### get number of log files in log directory
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)

#### create new log file for each run
# Counting files is only a starting guess: if an older log was deleted, that
# index may already be taken, so advance to the first unused one.
log_f_template = log_dir + bnn_name_string + 'FWBI' + "_{}.csv"
log_f_name = log_f_template.format(run_num)
while os.path.exists(log_f_name):
    run_num += 1
    log_f_name = log_f_template.format(run_num)

print("current logging run number for " + bnn_name_string + " : ", run_num)
print("logging at : " + log_f_name)

print("============================================================================================")
################### checkpointing ###################

run_num_pretrained = 0  #### change this to prevent overwriting weights in same env_name folder

# Trailing '/' is kept so `directory + <file name>` below is a valid path.
directory = "./pretrained" + '/' + bnn_name_string + '/' + 'FWBI' + '/'
# exist_ok=True avoids the exists()/makedirs() race and creates "./pretrained" too.
os.makedirs(directory, exist_ok=True)

checkpoint_path = directory + "{}_{}_{}.pth".format(bnn_name_string, random_seed, run_num_pretrained)
print("save checkpoint path : " + checkpoint_path)

################### savefigures ###################

results_folder = "./results"
figures_folder = results_folder + '/' + bnn_name_string + '/' + 'FWBI' + '/'
# Also creates results_folder and results_folder/bnn_name_string; the main
# loop writes its regret CSV into the latter.
os.makedirs(figures_folder, exist_ok=True)


print("============================================================================================")
############# print all hyperparameters #############

print("learning rate: ", lr_bnn)
print("learning rate of the prior BNN: ", lr_optbnn)
print("coefficient of prior regularization: ", prior_coeff)
print("coefficient of functional regularization: ", f_coeff)
print("random seed: ", random_seed)
# The main loop runs `for epoch in range(epochs)`; `max_epoch_num` is not used
# by it, so reporting that value as the epoch count was misleading.
print("number of training epochs: ", epochs)

print("============================================================================================")
############################## load and normalize data ##############################

current_dir = os.getcwd()
print(current_dir)
root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_path)
data_path = os.path.join(root_path, 'data', 'uci', 'mushroom.data')
# The file used to be read only relative to the working directory, and
# `data_path` was computed but never used. Keep the working-directory location
# as the first choice and fall back to `data_path`.
cwd_data_path = os.path.join(current_dir, 'data', 'uci', 'mushroom.data')
df = pd.read_csv(cwd_data_path if os.path.exists(cwd_data_path) else data_path, header=None)

# Name the columns (first column is the class label, the rest are categorical features)
df.columns = ['class', 'cap-shape','cap-surface','cap-color','bruises','odor','gill-attachment',
         'gill-spacing','gill-size','gill-color','stalk-shape','stalk-root',
         'stalk-surf-above-ring','stalk-surf-below-ring','stalk-color-above-ring','stalk-color-below-ring',
         'veil-type','veil-color','ring-number','ring-type','spore-color','population','habitat']

# Split features (all columns after 'class') from the label
X = df.iloc[:, 1:]
Y = df['class']

# Binary label consumed downstream as `edible`: 1 = edible ('e'), 0 = poisonous ('p').
# A LabelEncoder sorts its classes (['e', 'p']) and would map edible -> 0 and
# poisonous -> 1, i.e. the opposite of how the reward functions read this value.
unknown_labels = set(Y.unique()) - {'e', 'p'}
if unknown_labels:
    raise ValueError("unexpected class labels in mushroom data: {}".format(sorted(map(str, unknown_labels))))
y = (Y == 'e').to_numpy().astype(np.int64)

# Integer-encode every categorical feature column (refitting one LabelEncoder per column)
le = preprocessing.LabelEncoder()
x_tmp = pd.DataFrame({colname: le.fit_transform(X[colname]) for colname in X.columns}, index=X.index)

# One-hot encode the integer codes (117 binary features for the standard UCI file)
oh = preprocessing.OneHotEncoder()
oh.fit(x_tmp)
x = oh.transform(x_tmp).toarray()

training_num, input_dim = x.shape
output_dim = 1
is_continuous = True
print("training_num = ", training_num, " input_dim = ", input_dim, " output_dim = ", output_dim)

original_x_train = torch.from_numpy(x).float().to(device)
original_y_train = torch.from_numpy(y).float().to(device)


# Shared mean-squared-error criterion (element-wise mean, the MSELoss default)
# used by the agents' loss_step methods.
loss = nn.MSELoss()

# Mini-batch size referenced by the commented-out GWI agent.
batch_num = 64
# m = int(round(batch_num ** 0.5))


def init_buffer_gp(num_points=1000):
    """Build a training set for pre-training the GP prior.

    Draws ``num_points`` mushroom indices uniformly with replacement. Each
    context is paired with the action [1, 0] ("eat") when ``y[i] == 1`` (the
    label this file treats as edible) and [0, 1] ("reject") otherwise. The
    target is ``oracle_reward`` of that label.

    Args:
        num_points: number of samples to draw (positive). The default 1000
            matches the value previously hard-coded here.

    Returns:
        ``(train_x, train_y)`` as float32 tensors on ``device``. Each row of
        ``train_x`` is a context concatenated with its 2-element action;
        ``train_y`` holds the matching oracle rewards.

    Raises:
        ValueError: if ``num_points`` is smaller than 1.
    """
    if num_points < 1:
        raise ValueError("num_points must be positive, got {}".format(num_points))
    bufferx = []
    buffery = []
    for i in np.random.choice(range(len(x)), num_points):
        edible = y[i]
        action = np.array([1, 0] if edible == 1 else [0, 1])
        bufferx.append(np.concatenate((x[i], action)))
        buffery.append(oracle_reward(edible))
    # Stack in NumPy first: torch.Tensor(list_of_ndarrays) is very slow.
    train_x = torch.from_numpy(np.stack(bufferx).astype(np.float32)).to(device)
    train_y = torch.from_numpy(np.asarray(buffery, dtype=np.float32)).to(device)
    return train_x, train_y


print("============================================================================================")
################################## define model ##################################


# bnn = fBNN(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
#
# bnn_optimizer = torch.optim.Adam([{'params': bnn.parameters(),  'lr': lr_bnn}])


print("============================================================================================")

############ define reward function ####################
def get_reward(eaten, edible, poison_prob=0.4, reward=5, penalty=-35):
    """Return the agent's (stochastic) reward for one eat/reject decision.

    Rejecting always yields 0. Eating an edible mushroom yields ``reward``.
    Eating a poisonous one yields ``reward`` with probability
    ``1 - poison_prob`` and ``penalty`` otherwise.

    Args:
        eaten: truthy if the agent ate the mushroom. The agents pass the
            result of a tensor comparison here, so a one-element tensor is
            accepted.
        edible: truthy if the mushroom is edible.
        poison_prob: probability that eating a poisonous mushroom is penalised
            (default 0.4, the previously hard-coded value).
        reward: reward for a harmless meal (default 5).
        penalty: reward for a penalised poisonous meal (default -35).
    """
    if not eaten:
        return 0
    if edible:
        return reward
    # Poisonous and eaten: usually harmless, occasionally very costly.
    return reward if np.random.rand() > poison_prob else penalty


def oracle_reward(edible, reward=5):
    """Reward of the oracle policy, which eats exactly the edible mushrooms."""
    return reward * edible

################ define mushroom net######################
def Var(x, dtype=torch.FloatTensor):
    """Convert NumPy array ``x`` to a tensor of type ``dtype`` on ``device``.

    The former ``torch.autograd.Variable`` wrapper is omitted because it is
    deprecated; plain tensors have carried autograd state since PyTorch 0.4.
    """
    return torch.from_numpy(x).type(dtype).to(device)
class MushroomNet:
    """Base contextual-bandit agent for the mushroom task.

    Subclasses assign ``self.net`` and implement ``loss_step(x, y)``. Each
    ``update*`` call does three things:

    - lets the agent act greedily on one mushroom;
    - appends the (context + one-hot action, reward) pair to
      ``bufferX``/``bufferY`` and extends ``cum_regrets`` with the regret
      against ``oracle_reward``;
    - trains ``self.net`` on rows resampled from that buffer.
    """

    # Buffer rows used per update, and mini-batch size used by `update`/`update_gp`.
    pool_size = 4096
    batch_size = 64

    def __init__(self, n_weight_sampling=2):
        self.n_weight_sampling = n_weight_sampling
        self.epsilon = 0
        self.net = None
        self.loss, self.optimizer = None, None
        self.cum_regrets = [0]
        self.bufferX, self.bufferY = [], []  # the agent's own decisions and rewards
        self.bufferx, self.buffery = [], []  # random-action samples from init_buffer

    def init_buffer(self):
        """Append ``pool_size`` random-action samples to ``bufferx``/``buffery``.

        Each sample is a uniformly drawn mushroom with an eat/reject action
        chosen by a fair coin, rewarded by ``get_reward``.

        Returns:
            The first ``pool_size`` stored samples as float32 tensors on ``device``.
        """
        for i in np.random.choice(range(len(x)), self.pool_size):
            eat = np.random.rand() > 0.5
            action = np.array([1, 0] if eat else [0, 1])
            self.bufferx.append(np.concatenate((x[i], action)))
            self.buffery.append(get_reward(eat, y[i]))
        return self._to_tensors(self.bufferx[:self.pool_size], self.buffery[:self.pool_size])

    @staticmethod
    def _to_tensors(contexts, values):
        """Stack buffer rows into float32 tensors on ``device``.

        Going through ``np.stack`` avoids the very slow
        ``torch.Tensor(list_of_ndarrays)`` construction path.
        """
        context_t = torch.from_numpy(np.stack(contexts).astype(np.float32)).to(device)
        value_t = torch.from_numpy(np.asarray(values, dtype=np.float32)).to(device)
        return context_t, value_t

    @staticmethod
    def _resampled_indices(n, pool_size):
        """Return buffer indices ``0..n-1``, tiled so there are at least ``pool_size``.

        For ``n >= pool_size`` this is ``range(n)``. Otherwise it is the list
        ``range(n)`` repeated ``pool_size // n + 1`` times.
        """
        if n < 1:
            raise ValueError("cannot resample from an empty buffer")
        if n >= pool_size:
            return range(n)
        return (pool_size // n + 1) * list(range(n))

    def _act_and_record(self, context, edible, r_eat, r_reject):
        """Pick eat/reject greedily, store the outcome and update cumulative regret."""
        eaten = r_eat > r_reject
        # if np.random.rand() < self.epsilon:
        #     eaten = (np.random.rand() < .5)
        agent_reward = get_reward(eaten, edible)

        action = np.array([1, 0] if eaten else [0, 1])
        self.bufferX.append(np.concatenate((context, action)))
        self.bufferY.append(agent_reward)

        regret = oracle_reward(edible) - agent_reward
        self.cum_regrets.append(self.cum_regrets[-1] + regret)

    # Use NN to decide next action
    def try_(self, mushroom):
        """Score both actions for ``mushroom`` with ``self.net.forward`` and act greedily."""
        context, edible = x[mushroom], y[mushroom]
        try_eat = Var(np.concatenate((context, [1, 0])))
        try_reject = Var(np.concatenate((context, [0, 1])))

        with torch.no_grad():
            r_eat = self.net.forward(try_eat)
            r_reject = self.net.forward(try_reject)

        self._act_and_record(context, edible, r_eat, r_reject)

    def try_2(self, mushroom):
        """Like ``try_`` but for GP agents: batched inputs, eval mode, predictive mean.

        Reads ``self.likelihood``, which only GP subclasses define.
        """
        context, edible = x[mushroom], y[mushroom]
        # Add a leading batch dimension: shape (1, len(context) + 2).
        try_eat = Var(np.concatenate((context, [1, 0])))[None, :]
        try_reject = Var(np.concatenate((context, [0, 1])))[None, :]
        print('try_eat.shape: ', try_eat.shape)

        self.net.eval()
        self.likelihood.eval()
        with torch.no_grad():
            r_eat = self.net.forward(try_eat).mean
            r_reject = self.net.forward(try_reject).mean

        self._act_and_record(context, edible, r_eat, r_reject)

    def _sample_pool(self):
        """Pick ``pool_size`` buffer rows in random order and return them as tensors.

        The rows are the most recent ``pool_size`` entries when the buffer is
        large enough; otherwise the buffer repeats to fill the pool.
        """
        n = len(self.bufferX)
        print('len(bufferX): ', n)
        idx_pool = self._resampled_indices(n, self.pool_size)[-self.pool_size:]
        idx_pool = np.random.permutation(idx_pool)
        return self._to_tensors([self.bufferX[i] for i in idx_pool],
                                [self.bufferY[i] for i in idx_pool])

    # Feed next mushroom
    def update(self, mushroom):
        """Act on ``mushroom`` via ``try_``, then take one loss_step per mini-batch of the pool."""
        self.try_(mushroom)
        context_pool, value_pool = self._sample_pool()
        bs = self.batch_size
        for start in range(0, self.pool_size, bs):
            self.loss_step(context_pool[start:start + bs], value_pool[start:start + bs])

    def update2(self, mushroom):
        """Act via ``try_``, then take one full-batch loss_step on the shuffled, tiled buffer."""
        self.try_(mushroom)
        idx_pool = self._resampled_indices(len(self.bufferX), self.pool_size)
        idx_pool = np.random.permutation(idx_pool)
        context_pool, value_pool = self._to_tensors([self.bufferX[i] for i in idx_pool],
                                                    [self.bufferY[i] for i in idx_pool])
        self.loss_step(context_pool, value_pool)

    def update_gp(self, mushroom):
        """Act via ``try_2``, then re-condition the GP on each mini-batch before its loss_step."""
        self.try_2(mushroom)
        context_pool, value_pool = self._sample_pool()
        bs = self.batch_size
        for start in range(0, self.pool_size, bs):
            batch_x = context_pool[start:start + bs]
            batch_y = value_pool[start:start + bs]
            self.net.set_train_data(batch_x, batch_y, strict=False)
            self.loss_step(batch_x, batch_y)

class BBB_MNet_kl(MushroomNet):
    """Mushroom agent whose BNN is regularised by the distance returned by ``forward_kl``.

    NOTE(review): ``forward_kl`` presumably returns a KL-to-prior term; confirm in bnn.py.
    """

    def __init__(self, lr):
        super().__init__()
        n_in = input_dim + 2  # context features followed by the 2-element one-hot action
        self.net = BNN(n_in, output_dim, hidden_dims, activation_fn, is_continuous,
                       scaled_variance=True).to(device)
        self.mse = lambda a, b: .5 * ((a - b) ** 2).sum()
        self.lr = lr
        param_groups = [{'params': self.net.parameters(), 'lr': self.lr}]
        self.optimizer = torch.optim.Adam(param_groups)

    def loss_step(self, x, y, n_samples=2):
        """Take one Adam step on ``loss(pred, y) + prior_coeff * distance``.

        Prints progress using the module-level ``epoch`` and ``start_time``.
        """
        prediction, prior_distance = self.net.forward_kl(x)
        prediction = prediction.squeeze().flatten()
        objective = loss(prediction, y) + prior_coeff * prior_distance

        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        message = "epoch : {} \t\t training loss_kl \t\t : {}".format(epoch, objective)
        print(message, datetime.now().replace(microsecond=0) - start_time)

class BBB_MNet_w(MushroomNet):
    """Mushroom agent whose BNN is regularised by the distance returned by ``forward_w``.

    NOTE(review): ``forward_w`` presumably returns a weight-space Wasserstein term; confirm in bnn.py.
    """

    def __init__(self, lr):
        super().__init__()
        n_in = input_dim + 2  # context features followed by the 2-element one-hot action
        self.net = BNN(n_in, output_dim, hidden_dims, activation_fn, is_continuous,
                       scaled_variance=True).to(device)
        self.mse = lambda a, b: .5 * ((a - b) ** 2).sum()
        self.lr = lr
        param_groups = [{'params': self.net.parameters(), 'lr': self.lr}]
        self.optimizer = torch.optim.Adam(param_groups)

    def loss_step(self, x, y, n_samples=2):
        """Take one Adam step on ``loss(pred, y) + prior_coeff * distance``.

        Prints progress using the module-level ``epoch`` and ``start_time``.
        """
        prediction, prior_distance = self.net.forward_w(x)
        prediction = prediction.squeeze().flatten()
        objective = loss(prediction, y) + prior_coeff * prior_distance

        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        message = "epoch : {} \t\t training loss_w \t\t : {}".format(epoch, objective)
        print(message, datetime.now().replace(microsecond=0) - start_time)
def sample_measurement_set(X, num_data, n_points=40):
    """Return ``n_points`` distinct rows of ``X``, chosen uniformly at random.

    Args:
        X: tensor of at least two dimensions whose rows are candidate inputs.
        num_data: number of leading rows of ``X`` eligible for selection.
        n_points: size of the measurement set. It used to be hard-coded as
            a float tensor ``torch.Tensor([40])`` used as a slice bound. If
            ``n_points`` exceeds ``num_data``, all eligible rows are returned in
            random order.

    Returns:
        ``X[idx, :]`` for ``min(n_points, num_data)`` random distinct indices.
    """
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n_points)]
    return X[idx, :]

def lipf(X, Y):
    """Element-wise absolute difference ``|X - Y|`` of two tensors (usual broadcasting)."""
    return torch.abs(X - Y)

class BBB_MNet_ifbnn(MushroomNet):
    """Mushroom agent with a learnable functional prior (labelled 'FWBI' in the plots).

    Two BNNs are kept:

    - ``self.net`` is the agent's model; ``MushroomNet.try_`` uses it to decide
      actions.
    - ``self.opt_bnn`` acts as a learnable prior.

    Each ``loss_step`` fits ``self.net`` to the observed rewards. It pulls
    ``opt_bnn``'s function samples towards the global GP ``prior`` and adds a
    per-layer ``para_wd`` distance between ``self.net`` and ``opt_bnn``.
    """
    def __init__(self, lr):
        super().__init__()
        # Input is the context followed by the 2-element one-hot action.
        self.net = BNN(input_dim+2, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
        # Not used by loss_step below.
        self.mse = lambda x, y: .5*((x-y)**2).sum()
        self.lr = lr
        self.lr_optbnn = lr_optbnn
        # Learnable prior network with the same architecture as self.net.
        self.opt_bnn = BNN(input_dim+2, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)
        self.net_optimizer = torch.optim.Adam([{'params': self.net.parameters(),  'lr': self.lr}])
        self.opt_bnn_optimizer = torch.optim.RMSprop([{'params': self.opt_bnn.parameters(), 'lr': self.lr_optbnn}])

    def loss_step(self, x, y):
        """One joint optimisation step for ``self.net`` and ``self.opt_bnn``.

        Loss = MSE(net(x), y)
               + prior_coeff * mean per-layer para_wd(net, opt_bnn)
               + f_coeff * mean |GP sample - opt_bnn sample| on a measurement set
               + 1e-3 * mean |GP variance - opt_bnn sample variance|.

        Args:
            x: mini-batch of inputs. ``update`` passes 64 rows, matching
               ``num_data=64`` below.
            y: matching reward targets.

        Prints progress using the module-level ``epoch`` and ``start_time``.
        """

        # Random subset of the mini-batch rows on which the function-space terms are evaluated.
        measurement_set = sample_measurement_set(X=x, num_data=64)

        # 128 function samples from the GP prior; detached, so the GP is not trained here.
        gpss = prior_sample_functions(measurement_set, prior, num_sample=128).detach().float().to(device)
        gpss = gpss.squeeze()
        nnet_samples = self.opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)  # original_x_test
        nnet_samples = nnet_samples.squeeze()

        # Mean element-wise absolute difference between paired samples.
        # NOTE(review): samples are paired by index, not sorted; an empirical 1-D
        # Wasserstein distance would compare sorted samples per input - confirm intent.
        functional_wdist = lipf(gpss, nnet_samples)
        functional_wdist = torch.mean(functional_wdist)

        ##################### variance constraint #########

        # GP marginal variance at the measurement points (no gradient).
        with torch.no_grad():
            gp_var = prior(measurement_set).variance
            # print('gp_var.shape: ', gp_var.shape)

        # A second, independent draw from opt_bnn (nnet_samples is not reused).
        opt_samples = self.opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)
        opt_samples = opt_samples.squeeze()
        # Unbiased sample variance over dim 0 - assumes dim 0 indexes samples; TODO confirm
        # the layout returned by sample_functions.
        opt_var = torch.std(opt_samples, 0) ** 2
        # print('opt_var.shape: ', opt_var.shape)
        dif_var = (gp_var - opt_var).abs()
        # dif_var = torch.sum(dif_var)
        dif_var = torch.mean(dif_var)

        # Collect opt_bnn's mean/std parameters by name, in named_parameters() order.
        # The [0], [1], [2] indices below assume that order is input, mid, output layer.
        w_prior_mu_list = []
        b_prior_mu_list = []
        w_prior_std_list = []
        b_prior_std_list = []
        for name, p in self.opt_bnn.named_parameters():
            if "W_mu" in name:
                # w_prior_mu = p
                w_prior_mu_list.append(p)
            if 'b_mu' in name:
                # b_prior_mu = p
                b_prior_mu_list.append(p)
            if 'W_std' in name:
                # w_prior_std = p
                w_prior_std_list.append(p)
            if 'b_std' in name:
                # b_prior_std = p
                b_prior_std_list.append(p)

        # print(w_prior_mu_list, b_prior_mu_list,  w_prior_std_list, b_prior_std_list)

        # Per-layer distance between self.net's layer and opt_bnn's parameters
        # (presumably a 2-Wasserstein distance between Gaussian weights - confirm in bnn.py).
        # The opt_bnn parameters are not detached here, so unless para_wd detaches them
        # internally this term also updates opt_bnn.
        input_layer_wd = self.net.input_layer.para_wd(w_prior_mu_list[0], b_prior_mu_list[0], w_prior_std_list[0],
                                                 b_prior_std_list[0])
        mid_layer_wd = self.net.mid_layer.para_wd(w_prior_mu_list[1], b_prior_mu_list[1], w_prior_std_list[1],
                                             b_prior_std_list[1])
        output_layer_wd = self.net.output_layer.para_wd(w_prior_mu_list[2], b_prior_mu_list[2], w_prior_std_list[2],
                                                   b_prior_std_list[2])

        bnn_2wassdist = (input_layer_wd + mid_layer_wd + output_layer_wd) / 3
        pred_y = self.net.forward(x)
        pred_y = pred_y.squeeze().flatten()
        train_loss = loss(pred_y, y) + prior_coeff * bnn_2wassdist + f_coeff * functional_wdist + 1e-3 * dif_var
        # One backward pass; both optimizers then step on the same loss.
        self.net_optimizer.zero_grad()
        self.opt_bnn_optimizer.zero_grad()
        train_loss.backward()
        self.net_optimizer.step()
        self.opt_bnn_optimizer.step()
        # print
        print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss),
              datetime.now().replace(microsecond=0) - start_time)

class GP_MNet(MushroomNet):
    """Mushroom agent backed by an exact GP, trained by maximising the marginal likelihood.

    Meant to be driven by ``MushroomNet.update_gp``. That method calls
    ``try_2``, which reads ``self.likelihood``, and re-conditions the GP on
    each mini-batch before ``loss_step``.
    """

    def __init__(self, lr, train_x, train_y):
        super().__init__()
        self.train_x = train_x
        self.train_y = train_y
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()
        self.net = ExactGPModel(self.train_x, self.train_y, self.likelihood, input_dim + 2).to(device)
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.net)
        self.lr = lr
        self.optimizer = torch.optim.Adam([{'params': self.net.parameters(), 'lr': self.lr}])
        # Optimisation steps taken so far (used only for the progress print).
        self.n_loss_steps = 0

    def loss_step(self, x, y):
        """Take one Adam step on the negative exact marginal log-likelihood of ``(x, y)``."""
        self.net.train()
        self.likelihood.train()
        self.optimizer.zero_grad()
        output = self.net(x)

        loss_gp = -self.mll(output, y)
        loss_gp.backward()
        self.n_loss_steps += 1
        # This used to print the module-global loop index `i` against
        # n_step_prior_pretraining. That value is stale and unrelated to this
        # model's training, and it raises NameError if no global `i` exists.
        print('Iter %d - Loss: %.3f     noise: %.3f' % (
            self.n_loss_steps, loss_gp.item(),
            self.net.likelihood.noise.item()
        ))
        self.optimizer.step()


# class BBB_MNet_gwi(MushroomNet):
#     def __init__(self, lr):
#         super().__init__()
#         self.net = BNN(input_dim + 2, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(
#             device)
#         self.lr = lr
#         self.optimizer = torch.optim.Adam([{'params': self.net.parameters(), 'lr': self.lr}])
#
#     def loss_step(self, x, y):
#         z_mask = torch.randperm(batch_num)[:m]
#         z_prime_mask = torch.randperm(batch_num)[:30]
#         Z = x[z_mask, :].to(device)
#         Z_prime = x[z_prime_mask, :].to(device)
#         eye = torch.eye(m).to(device)
#
#         with torch.no_grad():
#             kxx = k(x).evaluate()
#             # print('kxx.shape: ', kxx.shape)
#             kzx = k(Z, x).evaluate()
#             # print('kzx.shape: ', kzx.shape)
#             kzpx = k(Z_prime, x).evaluate()
#             kzzp = k(Z, Z_prime).evaluate()
#             kzz = k(Z).evaluate()
#             # print('kzz.shape: ', kzz.shape)
#             chol_L = torch.linalg.cholesky(kzz + kzx @ kzx.t() / sigma + 1e-2 * eye)
#             # inv_L = torch.inverse(kzz + kzx @ kzx.t() / sigma + 1e-2 * eye)
#             L = torch.linalg.cholesky(torch.cholesky_inverse(chol_L))
#             # L = torch.linalg.cholesky(inv_L)
#             L = torch.nn.Parameter(L)
#
#         t = L.t() @ kzx
#         T_mat = t.t() @ t
#
#         t2 = L.t() @ kzzp
#         T_mat2 = t2.t() @ t
#
#         chol_z = torch.linalg.cholesky(kzz + 1e-2 * eye)
#         sol = torch.cholesky_solve(kzx, chol_z)
#
#         rxx = kxx - kzx.t() @ sol + T_mat
#
#         rzpx = kzpx - kzzp.t() @ sol + T_mat2
#
#         #####
#         bnn_loss = self.net.forward(x)
#         # print('bnn.shape: ', bnn_loss.shape)
#         prior_marginal = prior(x)
#         m_p_x = 0. * prior_marginal.mean
#         m_q_x = m_p_x + bnn_loss.squeeze()
#
#         const = 0.5 * np.log(np.pi * 2) + np.log(sigma)
#         pred_y = m_q_x
#         vec = (y - pred_y) ** 2
#         vec = torch.sum(vec)
#         r_trace = torch.trace(rxx)
#         k_trace = torch.trace(kxx)
#
#         likelihood = batch_num * const + (vec + r_trace) / 2. * sigma
#         reg = loss(m_q_x, m_p_x)
#
#         # calculate hard trace
#         big_eye = 30. * torch.eye(30).to(device)
#
#         rk_hat = rzpx @ kzpx.t()
#         # print('rk_hat.shape: ', rk_hat.shape)
#         eigs = torch.linalg.eigvals(rk_hat + big_eye)
#         eigs = eigs.abs()
#         eigs = eigs - big_eye.diag()
#         eigs = eigs[eigs > 0]
#         hard_trace = torch.sum(eigs ** 0.5)
#
#         w2 = reg + k_trace / batch_num + r_trace / batch_num - 2 / ((30 * batch_num) ** 0.5) * hard_trace
#
#         train_loss = likelihood + prior_coeff * w2
#
#         # optimisation
#         self.optimizer.zero_grad()
#         # krr_optimizer.zero_grad()
#         train_loss.backward(retain_graph=True)
#         self.optimizer.step()
#
#         print("epoch : {} \t\t training loss \t\t : {} \t\t distance_prior \t\t : {}".format(epoch, train_loss,
#                                                                                              w2),
#               datetime.now().replace(microsecond=0) - start_time)
################################## start training ##################################

# track total training time (datetime.now() is naive local time, not GMT)
start_time = datetime.now().replace(microsecond=0)
print("Started training at (local time) : ", start_time)

# logging file: write the CSV header and close the handle immediately; it was
# previously left open for the whole run and never closed.
with open(log_f_name, "w+") as log_f:
    log_f.write('epoch,training_loss\n')
# scheduler = torch.optim.lr_scheduler.StepLR(bnn_optimizer, 5000, gamma=0.9, last_epoch=-1)

############## pre-train GP prior ##############################################

# pre-train GP prior on oracle-labelled (context, action) pairs from init_buffer_gp
train_x_gp, train_y_gp = init_buffer_gp()
# gpprior = GP_MNet(lr=0.1, train_x=train_x_gp, train_y=train_y_gp)
# for i in range(n_step_prior_pretraining):
#     mushroom = np.random.randint(len(x))
#     gpprior.update_gp(mushroom)
#     if (i + 1) % 10 == 0:
#         print('cum_regrets: ', gpprior.cum_regrets[-1])
#
# gpprior.net.eval()
# gpprior.likelihood.eval()
#########################################
likelihood = gpytorch.likelihoods.GaussianLikelihood()
prior = ExactGPModel(train_x_gp, train_y_gp, likelihood, input_dim+2).to(device)

prior.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(prior.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior)

for i in range(n_step_prior_pretraining):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = prior(train_x_gp)
    # Calc loss and backprop gradients
    loss_gp = -mll(output, train_y_gp)
    loss_gp.backward()
    # `.mean()` keeps this print working when the kernel has one lengthscale per
    # input dimension (ARD; ExactGPModel receives input_dim+2, possibly as
    # ard_num_dims - confirm). There `.item()` on the full tensor raises; for a
    # single shared lengthscale the printed value is unchanged.
    print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
        i + 1, n_step_prior_pretraining, loss_gp.item(),
        prior.covar_module.base_kernel.lengthscale.mean().item(),
        prior.likelihood.noise.item()
    ))

    optimizer.step()

# Switch to evaluation mode: the agents query `prior` for samples and variances.
prior.eval()
likelihood.eval()

# k = prior.covar_module
# sigma = prior.likelihood.noise.item()
###################################################################

# net_gwi = BBB_MNet_gwi(lr_bnn)
net = BBB_MNet_ifbnn(lr_bnn)
# net_kl = BBB_MNet_kl(lr_bnn)
# net_w = BBB_MNet_w(lr_bnn)

# Cumulative regrets are rewritten every 10 steps so partial results survive an interrupted run.
regrets_csv_path = results_folder + '/' + bnn_name_string + '/' + 'context_ifbnn_1e-3_1t' + 'mushroom_regrets0.4.csv'

# The loop variable must stay the module-level name `epoch`: the agents'
# loss_step methods read it for their progress print.
for epoch in range(epochs):
    mushroom = np.random.randint(len(x))
    net.update(mushroom)
    if (epoch + 1) % 10 == 0:
        print('cum_regrets: ', net.cum_regrets[-1])
        # Use a dedicated name instead of rebinding `df`, which holds the mushroom dataset.
        regrets_df = pd.DataFrame.from_dict(net.cum_regrets)
        regrets_df.to_csv(regrets_csv_path)
#
#         # df = pd.DataFrame.from_dict(net.cum_regrets)
#         # df.to_csv(results_folder + '/' + bnn_name_string + '/' + 'context_kl' + 'mushroom_regrets.csv')
#         print('cum_regrets_kl: ', net_kl.cum_regrets[-1])
#         print('cum_regrets_w: ', net_w.cum_regrets[-1])
#         df = pd.DataFrame.from_dict(net_kl.cum_regrets)
#         df.to_csv(results_folder + '/' + bnn_name_string + '/' + 'context_kl_0.4_laurt1' + 'mushroom_regrets.csv')
#         df2 = pd.DataFrame.from_dict(net_w.cum_regrets)
#         df2.to_csv(results_folder + '/' + bnn_name_string + '/' + 'context_w_0.4_laurt1' + 'mushroom_regrets.csv')
# #         df3 = pd.DataFrame.from_dict(net_gwi.cum_regrets)
# #         df3.to_csv(results_folder + '/' + bnn_name_string + '/' + 'context_gwi' + 'mushroom_regrets.csv')
# # #
#
#
# plt.clf()
# figure = plt.figure(figsize=(8, 5.5), facecolor='white')
# init_plotting()
# # # plt.figure()
# indices = np.arange(10001)[::100]
# # cum_regrets_ifbnn = np.array(net.cum_regrets)
# cum_regrets_kl = np.array(net_kl.cum_regrets)
# cum_regrets_w = np.array(net_w.cum_regrets)
# # cum_regrets_gwi = np.array(net_gwi.cum_regrets)
# # plt.plot(indices, cum_regrets_ifbnn[indices], '-go', ms=3, label='ifbnn')
# plt.plot(indices, cum_regrets_kl[indices], '-ro', ms=3, label='kl')
# plt.plot(indices, cum_regrets_w[indices], '-ko', ms=3, label='w')
# # plt.plot(indices, cum_regrets_gwi[indices], '-ko', ms=3, label='gwi')
# plt.ylabel(r'cum_regrets')
# plt.tight_layout()
# plt.xlabel('Iteration')
# plt.tight_layout()
# plt.legend(loc='upper center', ncol=3, fontsize='small')
# plt.savefig(figures_folder + '/klwbnn_0.4_laurt1.pdf')
#
#
#
#
# regret_kl = pd.read_csv(results_folder + '/' + bnn_name_string + '/' + 'context_kl_0.6' + 'mushroom_regrets.csv') #0.4_laurt1
# regret_kl = regret_kl.to_numpy()
# regret_w = pd.read_csv(results_folder + '/' + bnn_name_string + '/' + 'context_w_0.6' + 'mushroom_regrets.csv') #0.4_laurt1
# regret_w = regret_w.to_numpy()
# regret_ifbnn = pd.read_csv(results_folder + '/' + bnn_name_string + '/' + 'context_ifbnn_0.6' + 'mushroom_regrets.csv') #0.4_laurt1
# regret_ifbnn = regret_ifbnn.to_numpy()
# regret_fbnn = pd.read_csv(results_folder + '/' + bnn_name_string + '/' + 'context_fbnn_0.6' + 'mushroom_regrets.csv')
# regret_fbnn = regret_fbnn.to_numpy()
# regret_gwi = pd.read_csv(results_folder + '/' + bnn_name_string + '/' + 'context_gwi_0.6' + 'mushroom_regrets.csv')
# regret_gwi = regret_gwi.to_numpy()
# plt.clf()
# figure = plt.figure(figsize=(8, 5.5), facecolor='white')
# init_plotting()
# # plt.figure()
# indices = np.arange(10001)[::100]
# cum_regrets_kl = regret_kl[:, 1]
# cum_regrets_w = regret_w[:, 1]
# cum_regrets_ifbnn = regret_ifbnn[:, 1]
# cum_regrets_fbnn = regret_fbnn[:, 1]
# cum_regrets_gwi = regret_gwi[:, 1]
# # plt.plot(indices, cum_regrets_ifbnn[indices], '-go', ms=3, label='ifbnn')
# # plt.plot(indices, cum_regrets_kl[indices], '-ro', ms=3, label='kl')
# # plt.plot(indices, cum_regrets_w[indices], '-ko', ms=3, label='w')
# # plt.plot(indices, cum_regrets_fbnn[indices], '-bo', ms=3, label='fbnn')
# # plt.plot(indices, cum_regrets_gwi[indices], '-yo', ms=3, label='gwi')
# #
# plt.plot(indices, cum_regrets_ifbnn[indices], color='deeppink', label='FWBI')
# plt.plot(indices, cum_regrets_kl[indices], color='darkorange', label='KLBBB')
# plt.plot(indices, cum_regrets_w[indices], color='navy', label='WBBB')
# plt.plot(indices, cum_regrets_fbnn[indices], color='green', label='FBNN')
# plt.plot(indices, cum_regrets_gwi[indices], color='aqua', label='GWI')
# plt.ylabel(r'cum_regrets')
# plt.tight_layout()
# plt.xlabel('Iteration')
# plt.tight_layout()
# plt.grid()
# # plt.legend(loc='upper center', ncol=5, fontsize='small')
# plt.legend()
# plt.savefig(figures_folder + '/0.6_ifbnn_kl_w_fbnn_gwi.pdf')
#


# print total training time (timestamps come from naive datetime.now(), i.e. local time, not GMT)
print("============================================================================================")
end_time = datetime.now().replace(microsecond=0)
print("Started training at (local time) : ", start_time)
print("Finished training at (local time) : ", end_time)
print("Total training time  : ", end_time - start_time)
print("============================================================================================")