import os
from datetime import datetime
import sys
import math
import pandas as pd
from sklearn import preprocessing
from torch.autograd import Variable

import torch
import torch.nn as nn
import numpy as np
import gpytorch
from gp_prior import ExactGPModel, prior_sample_functions, prior_sample_functions2


# from bnn import BNN, OPTBNN
from bnn_vimc import BNN, BNNF, BNNMC

import torch.nn.functional as F
import argparse

from utils.logging import get_logger

from data import uci_woval

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from utils.utils import default_plotting_new as init_plotting


################################### Hyper-parameters ###################################
# Module-level experiment configuration.
# NOTE(review): several of these names (max_epoch_num, save_model_freq, log_model_freq,
# test_interval, num_sample, epochs, lr_optbnn) are not referenced by the training code
# visible in this file; the loops below use literal epoch counts and batch sizes instead.

lr_bnn = 0.01   # Adam lr for opt_bnn; also the `lr` passed to noise_loss in the SGLD phase
prior_coeff = 1   # only printed below
bnn_name_string = 'VIMC'   # UAI_toy -- used as a path component for logs, checkpoints and results
random_seed = 123   # 123 -- seeds both torch and numpy below
max_epoch_num = 100
save_model_freq = 5000
log_model_freq = 3000
test_interval = 1000
num_sample = 1000
epochs = 2001    #10001
n_step_prior_pretraining = 100   # Adam steps used to fit the GP prior's hyperparameters
lr_optbnn = 0.01
f_coeff = 10   # weight of the functional-distance term in the FVI training loss

# Seed both RNGs used by this script: torch (model init, torch.randperm, torch.normal)
# and numpy (np.random.* draws for mushrooms, buffers and rewards).
torch.manual_seed(random_seed)
np.random.seed(random_seed)


################################### Network Architecture ###################################
# Shared by BNN (FVI phase) and BNNMC (MCMC phase): two hidden layers of 100 tanh units.

n_units = 100
n_hidden = 2
hidden_dims = [n_units] * n_hidden
activation_fn = 'tanh'

print("============================================================================================")
################################## set device ##################################

# set device to cpu, or to the first CUDA device when one is available
device = torch.device('cpu')
if (torch.cuda.is_available()):
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    print("Device set to : cpu")

print("============================================================================================")
###################### logging ######################

#### log files for multiple runs are NOT overwritten
# (each run uses an index equal to the number of files already in log_dir)

log_dir = "./logs"
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

log_dir = log_dir + '/' + bnn_name_string + '/' + 'FWBI' + '/'
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

#### get number of log files in log directory
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)

#### create new log file for each run
# log_dir already ends with '/', so this path contains '//' (harmless on POSIX).
log_f_name = log_dir + '/' + bnn_name_string + 'FWBI' + "_" + str(run_num) + ".csv"

print("current logging run number for " + bnn_name_string + " : ", run_num)
print("logging at : " + log_f_name)


print("============================================================================================")
################### checkpointing ###################

run_num_pretrained = 0  #### change this to prevent overwriting weights in same env_name folder

directory = "./pretrained"
if not os.path.exists(directory):
    os.makedirs(directory)

directory = directory + '/' + bnn_name_string + '/' + 'FWBI' + '/'
if not os.path.exists(directory):
    os.makedirs(directory)

# NOTE(review): checkpoint_path is only printed in the visible code; SGLD snapshots are
# written under results_folder instead.
checkpoint_path = directory + "{}_{}_{}.pth".format(bnn_name_string, random_seed, run_num_pretrained)
print("save checkpoint path : " + checkpoint_path)

################### savefigures ###################

results_folder = "./results"
if not os.path.exists(results_folder):
    os.makedirs(results_folder)

# Creating figures_folder also creates results/{bnn_name_string}/, which the regret CSVs
# below are written into. figures_folder itself is not used in the visible code.
figures_folder = results_folder + '/' + bnn_name_string + '/' + 'FWBI' + '/'
if not os.path.exists(figures_folder):
    os.makedirs(figures_folder)


print("============================================================================================")
############# print all hyperparameters #############

print("learning rate: ", lr_bnn)
print("coefficient of prior regularization: ", prior_coeff)
print("random seed: ", random_seed)
print("max number of epoches: ", max_epoch_num)


print("============================================================================================")
############################## load and normalize data ##############################
# Load the UCI mushroom data, integer-code the class column into `y`, and one-hot
# encode all categorical feature columns into the dense array `x`.

# training_num, input_dim = train_x.shape
# test_num = test_x.shape[0]
# output_dim = 1
# is_continuous = True
#
# print("training_num = ", training_num, " input_dim = ", input_dim, " output_dim = ", output_dim)
#
# original_x_train = torch.from_numpy(train_x).float().to(device)
# original_y_train = torch.from_numpy(train_y).float().to(device)
# original_x_test = torch.from_numpy(test_x).float().to(device)
# original_y_test = torch.from_numpy(test_y).float().to(device)
current_dir = os.getcwd()
print(current_dir)
# Parent of this script's directory, appended to sys.path (presumably so project
# packages resolve -- confirm).
root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_path)
# NOTE(review): `data_path` is built from the script location but never used; the CSV is
# read relative to the current working directory, so the script must be started from the
# directory that contains data/uci/mushroom.data.
data_path = os.path.join(root_path, 'data', 'uci', 'mushroom.data')
df = pd.read_csv(os.getcwd() + '/data/uci/mushroom.data', header=None)

# Name the 23 columns: the class label followed by 22 categorical features
df.columns = ['class', 'cap-shape','cap-surface','cap-color','bruises','odor','gill-attachment',
         'gill-spacing','gill-size','gill-color','stalk-shape','stalk-root',
         'stalk-surf-above-ring','stalk-surf-below-ring','stalk-color-above-ring','stalk-color-below-ring',
         'veil-type','veil-color','ring-number','ring-type','spore-color','population','habitat']

# Features: every column except 'class'
X = pd.DataFrame(df, columns=df.columns[1:len(df.columns)], index=df.index)
# Put the class values (0th column) into Y
Y = df['class']

# Encode the class strings as integer codes (LabelEncoder assigns codes in sorted order of
# the distinct raw values; this is NOT one-hot).
# NOTE(review): the bandit code below treats y == 1 as "edible"; confirm that matches the
# class strings in the data file.
le = preprocessing.LabelEncoder()
le.fit(Y)
y = le.transform(Y)

# Integer-coded copy of the features. It is seeded with the first column so it shares X's
# index; the loop below then (re)writes every feature column with its integer codes.
x_tmp = pd.DataFrame(X, columns=[X.columns[0]])

# Encode each feature column into x_tmp (the same encoder object is re-fit per column)
for colname in X.columns:
    le.fit(X[colname])
    #print(colname, le.classes_)
    x_tmp[colname] = le.transform(X[colname])

# One-hot encode every integer-coded column into a dense float array (one row per mushroom).
# The original note said 8124 mushrooms x 117 features; the actual sizes come from the data
# file and are printed below as training_num / input_dim.
oh = preprocessing.OneHotEncoder()
oh.fit(x_tmp)
x = oh.transform(x_tmp).toarray()

training_num, input_dim = x.shape
output_dim = 1
is_continuous = True   # passed to the BNN constructors; presumably selects a regression output -- confirm in bnn_vimc
print("training_num = ", training_num, " input_dim = ", input_dim, " output_dim = ", output_dim)

# NOTE(review): these full-data tensors are not used by the visible training code, which
# builds its own tensors from the replay buffers.
original_x_train = torch.from_numpy(x).float().to(device)
original_y_train = torch.from_numpy(y).float().to(device)


# Mean squared error between predicted and observed rewards (data-fit term of the FVI loss).
loss = nn.MSELoss(reduction='mean')

batch_num = 64   # NOTE(review): not referenced below; the training loops hard-code 64
# m = int(round(batch_num ** 0.5))
# m = int(round(batch_num ** 0.5))


def init_buffer_gp(num_samples=1000):
    """Build an oracle-labelled (context + action, reward) dataset for the GP prior.

    Draws ``num_samples`` row indices uniformly with replacement from the global
    one-hot feature matrix ``x`` and labels ``y``. For each drawn row, the "oracle"
    action is appended to the context: ``[1, 0]`` (eat) when ``y == 1``, otherwise
    ``[0, 1]`` (reject). The target is ``oracle_reward(y)``.

    Args:
        num_samples: number of (input, target) pairs to draw. The default of 1000
            is the value that used to be hard-coded.

    Returns:
        ``(train_x, train_y)`` float32 tensors on ``device``. ``train_x`` has one
        row per sample (context features followed by the 2-element action) and
        ``train_y`` holds the matching oracle rewards.
    """
    bufferx = []
    buffery = []
    for i in np.random.choice(len(x), num_samples):
        edible = y[i]
        action = np.array([1, 0] if edible == 1 else [0, 1])
        bufferx.append(np.concatenate((x[i], action)))
        buffery.append(oracle_reward(edible))
    # Convert each list to one contiguous array first: building a tensor from a Python
    # list of ndarrays is very slow (and PyTorch warns about it).
    train_x = torch.from_numpy(np.asarray(bufferx, dtype=np.float32)).to(device)
    train_y = torch.from_numpy(np.asarray(buffery, dtype=np.float32)).to(device)
    return train_x, train_y


print("============================================================================================")
################################## define model ##################################

# Variational BNN trained in the FVI phases. Its input is the one-hot context plus a
# 2-element one-hot action ([1, 0] = eat, [0, 1] = reject), hence input_dim + 2.
opt_bnn = BNN(input_dim+2, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)

opt_bnn_optimizer = torch.optim.Adam([{'params': opt_bnn.parameters(), 'lr': lr_bnn}])

print("============================================================================================")

############ define reward function####################
def get_reward(eaten, edible):
    """Return the reward the agent receives for one mushroom decision.

    Rejecting any mushroom yields 0 and eating an edible one yields 5. Eating a
    poisonous one is a gamble: 5 with probability 0.4 (``np.random.rand() > 0.6``),
    otherwise -35.

    Args:
        eaten: truthy when the agent chose to eat.
        edible: truthy when the mushroom is edible.
    """
    if not eaten:
        return 0
    if edible:
        return 5
    # Poisonous mushroom was eaten: usually heavily penalised, occasionally harmless.
    return -35 if np.random.rand() <= 0.6 else 5

def oracle_reward(edible):
    """Return the best achievable reward for a mushroom: ``5 * edible``.

    With a 0/1 label this is 5 for edible mushrooms (eat) and 0 otherwise (reject).
    """
    best = edible * 5
    return best

################ define mushroom net######################
def Var(x, dtype=torch.FloatTensor):
    """Convert NumPy array ``x`` to a tensor of legacy type ``dtype`` on ``device``.

    This used to be a lambda wrapping ``torch.autograd.Variable``. That wrapper has
    been a deprecated no-op since PyTorch 0.4 (it just returns a tensor), so the
    result here is the same tensor value without the deprecated API.
    """
    return torch.from_numpy(x).type(dtype).to(device)


def sample_measurement_set(X, num_data, n=40):
    """Return ``n`` rows of ``X`` drawn uniformly at random without replacement.

    Args:
        X: 2-D tensor to sample rows from.
        num_data: number of candidate rows (the first ``num_data`` rows of ``X``).
        n: measurement-set size (default 40, the previously hard-coded value).
            If ``num_data < n``, all ``num_data`` rows are returned in random order.

    Returns:
        Tensor of shape ``(min(n, num_data), X.shape[1])``.
    """
    # The first n entries of a random permutation form a uniform sample
    # without replacement.
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n)]
    return X[idx, :]

def lipf(X, Y):
    """Element-wise absolute difference ``|X - Y|`` (broadcasting as usual)."""
    diff = X - Y
    return diff.abs()

############## pre-train GP prior ##############################################

# pre-train GP prior
# Fit a GP to oracle-labelled (context + action, reward) pairs; it is used as the
# functional prior ("prior") by the BNN training phases below.
train_x_gp, train_y_gp = init_buffer_gp()
#

#########################################
likelihood = gpytorch.likelihoods.GaussianLikelihood()
prior = ExactGPModel(train_x_gp, train_y_gp, likelihood, input_dim+2).to(device)

# Training mode for hyperparameter fitting.
prior.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(prior.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, prior)

for i in range(n_step_prior_pretraining):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = prior(train_x_gp)
    # Calc loss and backprop gradients
    loss_gp = -mll(output, train_y_gp)
    loss_gp.backward()
    # NOTE(review): `.item()` on the lengthscale assumes a single (non-ARD) lengthscale;
    # confirm what ExactGPModel builds from its input_dim+2 argument.
    print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
        i + 1, n_step_prior_pretraining, loss_gp.item(),
        prior.covar_module.base_kernel.lengthscale.item(),
        prior.likelihood.noise.item()
    ))
    # print('Iter %d/%d - Loss: %.3f     noise: %.3f' % (
    #     i + 1, n_step_prior_pretraining, loss_gp.item(),
    #     # prior.covar_module.base_kernel.lengthscale.item(),
    #     prior.likelihood.noise.item()
    # ))

    optimizer.step()

# Switch to evaluation mode for the sampling/variance queries below. Presumably (gpytorch
# ExactGP semantics) prior(...) now returns the fitted GP's predictive conditioned on
# train_x_gp/train_y_gp -- confirm.
prior.eval()
likelihood.eval()

print("============================================================================================")

############################################
###### linear_nn flow ###############
# Hidden-layer widths of the BNNF shift networks used by the flows (see flow training).
hidden_dimsflow = [100]
class LinearFlow_nn(nn.Module):
    """Conditional affine flow ``x = weight * z + net(x_in)`` with a scalar scale.

    The scale ``weight`` is a single learnable scalar starting at 1 (identity map);
    the shift is produced by ``net`` from the conditioning inputs ``x_in`` and is
    ``.squeeze()``-d before use. The network is stored under the attribute name
    ``bias``.

    Args:
        net: module mapping ``x_in`` to the shift.
    """

    def __init__(self, net):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]))
        self.bias = net

    def forward(self, z, x_in):
        """Push ``z`` forward; return ``(x, shift, log|weight|)``."""
        shift = self.bias(x_in).squeeze()
        out = z * self.weight + shift
        logabsdet = torch.sum(torch.log(self.weight.abs() + 1e-7))
        return out, shift, logabsdet

    def inverse(self, x, x_in):
        """Pull ``x`` back to ``z``; return ``(z, log|1 / weight|)``."""
        shift = self.bias(x_in).squeeze()
        recovered = x * 1/self.weight - shift/self.weight
        logabsdet_inv = torch.sum(torch.log(1/self.weight.abs() + 1e-7))
        return recovered, logabsdet_inv

##
class NonLinearFlow_nn(nn.Module):
    """Conditional tanh flow with two learned scalar scales and two shift networks.

    ``inverse`` maps ``x -> z = weight2 * tanh(weight * x + net1(x_in)) + net2(x_in)``.
    ``forward`` is its inverse, ``x = (atanh((z - net2(x_in)) / weight2) - net1(x_in)) / weight``,
    where the atanh argument is clamped to ``[-0.99, 0.99]`` to stay finite.

    Both directions return a log-|Jacobian| summed over the last dimension. As a
    result, ``forward``'s log-det is (up to the clamp and the 1e-7 stabiliser) the
    negative of ``inverse``'s at corresponding points. Previously ``forward`` returned
    ``log|weight|``, which is the log-det of a linear map, not of this transform.

    Args:
        net1: module producing the inner shift from ``x_in`` (output ``.squeeze()``-d).
        net2: module producing the outer shift from ``x_in`` (output ``.squeeze()``-d).
    """

    def __init__(self, net1, net2):
        super().__init__()
        # Registration order (weight, bias1, weight2, bias2) kept for state_dict compatibility.
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias1 = net1
        self.weight2 = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias2 = net2

    def forward(self, z, x_in):
        """Map ``z`` to ``x``; return ``(x, net1(x_in), log|dx/dz|)``."""
        nnbias1 = self.bias1(x_in).squeeze()
        nnbias2 = self.bias2(x_in).squeeze()
        mask = (z - nnbias2) / self.weight2
        # Keep atanh finite.
        clamped_mask = torch.clamp(mask, -0.99, 0.99)
        x = torch.atanh(clamped_mask) / self.weight - nnbias1 / self.weight
        # dx/dz = 1 / (weight * weight2 * (1 - tanh^2)), with tanh(...) == clamped_mask.
        det = self.weight2 * self.weight * (1 - clamped_mask ** 2)
        log_det = -torch.sum(torch.log(det.abs() + 1e-7), -1)

        return x, nnbias1, log_det

    def inverse(self, x, x_in):
        """Map ``x`` to ``z``; return ``(z, log|dz/dx|)`` summed over the last dim."""
        nnbias1 = self.bias1(x_in).squeeze()
        nnbias2 = self.bias2(x_in).squeeze()

        z = self.weight2 * torch.tanh(x * self.weight + nnbias1) + nnbias2

        # dz/dx = weight2 * weight * (1 - tanh^2(weight * x + net1(x_in)))
        det = self.weight2 * self.weight * (1 - torch.tanh(x * self.weight + nnbias1) ** 2)
        log_det_inverse = torch.sum(torch.log(det.abs() + 1e-7), -1)

        return z, log_det_inverse

########## sigmoid_flow #############
class SigmoidFlow_nn(nn.Module):
    """Conditional sigmoid flow with two learned scalar scales and a shift network.

    ``inverse`` maps ``x -> z = weight2 * sigmoid(weight * x) + net(x_in)``.
    ``forward`` is its inverse, ``x = logit((z - net(x_in)) / weight2) / weight``,
    where the sigmoid value is clamped to ``[0.01, 0.99]`` to keep the logit finite.

    Both directions return a log-|Jacobian| summed over the last dimension. As a
    result, ``forward``'s log-det is (up to the clamp and the 1e-7 stabiliser) the
    negative of ``inverse``'s at corresponding points. Previously ``forward`` returned
    the placeholder constant ``1.``.

    Args:
        net: module producing the shift from ``x_in`` (output ``.squeeze()``-d).
    """

    def __init__(self, net):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.weight2 = nn.Parameter(torch.tensor([1.]), requires_grad=True)
        self.bias = net

    def forward(self, z, x_in):
        """Map ``z`` to ``x``; return ``(x, net(x_in), log|dx/dz|)``."""
        nnbias = self.bias(x_in).squeeze()

        mask = (z - nnbias) / self.weight2
        # Keep the logit finite.
        mask = torch.clamp(mask, min=0.01, max=0.99)
        clamped_mask = mask / (1 - mask)
        x = torch.log(clamped_mask + 1e-7) / self.weight
        # dx/dz = 1 / (weight2 * weight * s * (1 - s)) with s == sigmoid(weight * x) == mask.
        det = self.weight2 * self.weight * mask * (1 - mask)
        log_det = -torch.sum(torch.log(det.abs() + 1e-7), -1)

        return x, nnbias, log_det

    def inverse(self, x, x_in):
        """Map ``x`` to ``z``; return ``(z, log|dz/dx|)`` summed over the last dim."""
        nnbias = self.bias(x_in).squeeze()

        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid (same values).
        mask = torch.sigmoid(x * self.weight)
        z = self.weight2 * mask + nnbias

        det = self.weight2 * self.weight * mask * (1 - mask)
        log_det_inverse = torch.sum(torch.log(det.abs() + 1e-7), -1)

        return z, log_det_inverse
######################

################################## start training ##################################

# track total training time
start_time = datetime.now().replace(microsecond=0)
# NOTE(review): datetime.now() is local time, so the "(GMT)" label is only accurate when
# the machine clock is on GMT/UTC.
print("Started training at (GMT) : ", start_time)


# logging file
# NOTE(review): only the CSV header is written and the file is not closed in the visible
# code; per-step losses are printed to stdout instead.
log_f = open(log_f_name, "w+")
log_f.write('epoch,training_loss\n')
# scheduler = torch.optim.lr_scheduler.StepLR(bnn_optimizer, 5000, gamma=0.9, last_epoch=-1)

################# S1: FVI #################################
# Phase 1: contextual-bandit interaction with the FVI BNN. Each epoch presents one random
# mushroom, acts greedily on opt_bnn's predicted rewards, records the outcome and regret,
# then takes 64 Adam steps (4096 replayed samples in mini-batches of 64).
bufferX = []
bufferY = []
cum_regrets = [0]

for epoch in range(1001):

    mushroom = np.random.randint(len(x))
    context, edible = x[mushroom], y[mushroom]
    try_eat = Var(np.concatenate((context, [1, 0])))
    try_reject = Var(np.concatenate((context, [0, 1])))

    # Calculate rewards using model
    with torch.no_grad():
        # r_eat = sum([self.net.forward(try_eat) for _ in range(samples)]).item()
        # r_reject = sum([self.net.forward(try_reject) for _ in range(samples)]).item()
        r_eat = opt_bnn.forward(try_eat)
        r_reject = opt_bnn.forward(try_reject)

    # Greedy action: eat iff the predicted reward for eating is higher. There is no
    # epsilon-greedy step here; any exploration comes from randomness in opt_bnn.forward,
    # if it samples weights (confirm in bnn_vimc).
    eaten = r_eat > r_reject

    agent_reward = get_reward(eaten, edible)

    # Get rewards and update buffer
    action = np.array([1, 0] if eaten else [0, 1])
    bufferX.append(np.concatenate((context, action)))
    bufferY.append(agent_reward)

    # Calculate regret
    oracle = oracle_reward(edible)
    regret = oracle - agent_reward
    cum_regrets.append(cum_regrets[-1] + regret)

    ### Feed next mushroom
    # Replay pool of exactly 4096 indices, shuffled: the most recent 4096 entries once the
    # buffer is that large, otherwise the whole buffer tiled enough times to reach 4096.
    l = len(bufferX)
    print('len(bufferX): ', l)
    idx_pool = range(l) if l >= 4096 else ((int(4096 // l) + 1) *
                                           list(range(l)))
    idx_pool = np.random.permutation(idx_pool[-4096:])
    context_pool = torch.Tensor([bufferX[i] for i in idx_pool]).to(device)
    value_pool = torch.Tensor([bufferY[i] for i in idx_pool]).to(device)

    # 64 mini-batches of 64 samples each.
    for i in range(0, 4096, 64):
        x_ = context_pool[i:i + 64]
        y_ = value_pool[i:i + 64]

        # A random subset of the batch inputs is the measurement set for the function-space terms.
        measurement_set = sample_measurement_set(X=x_, num_data=64)

        # 128 function draws each from the GP prior and from the BNN at the measurement points.
        gpss = prior_sample_functions(measurement_set, prior, num_sample=128).detach().float().to(device)
        gpss = gpss.squeeze()
        nnet_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)  # original_x_test
        nnet_samples = nnet_samples.squeeze()

        # "Functional distance": mean element-wise |GP - BNN| over the (unsorted) paired samples.
        functional_wdist = lipf(gpss, nnet_samples)
        functional_wdist = torch.mean(functional_wdist)

        ##################### variance constraint #########
        # Penalise the gap between the GP's variance and the BNN's sample variance at each
        # measurement point. NOTE(review): assumes sample_functions stacks samples along
        # dim 0 -- confirm.

        with torch.no_grad():
            gp_var = prior(measurement_set).variance
            # print('gp_var.shape: ', gp_var.shape)

        opt_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)
        opt_samples = opt_samples.squeeze()
        opt_var = torch.std(opt_samples, 0) ** 2
        # print('opt_var.shape: ', opt_var.shape)
        dif_var = (gp_var - opt_var).abs()
        # dif_var = torch.sum(dif_var)
        dif_var = torch.mean(dif_var)
        ######

        # Total loss: MSE data fit + f_coeff * functional distance + 1e-2 * variance gap.
        pred_y = opt_bnn.forward(x_)
        pred_y = pred_y.squeeze().flatten()
        train_loss = loss(pred_y, y_) + f_coeff * functional_wdist + 1e-2 * dif_var

        opt_bnn_optimizer.zero_grad()

        train_loss.backward()

        opt_bnn_optimizer.step()

        # print
        print("epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss),
              datetime.now().replace(microsecond=0) - start_time)

    # Every 10 epochs, overwrite the CSV with the full cumulative-regret history.
    # NOTE: this rebinds `df`, shadowing the mushroom DataFrame loaded earlier.
    if (epoch + 1) % 10 == 0:
        # print('cum_regrets: ', net_gwi.cum_regrets[-1])
        print('cum_regrets: ', cum_regrets[-1])
        df = pd.DataFrame.from_dict(cum_regrets)
        df.to_csv(results_folder + '/' + bnn_name_string + '/' + 'context_fvimc3' + 'mushroom_regrets0.6.csv')
        # df.to_csv(results_folder + '/' + bnn_name_string + '/' + 'context_fvi_1e-3' + 'mushroom_regrets0.6_3.csv')
print("--------------------------------------------------------------------------------------------")
######### S2: FMCMC #####################
# Phase 2: function-space SGLD-style sampling with BNNMC. bufferX_MC / bufferY_MC hold only
# the interactions collected in this phase.
bufferX_MC = []
bufferY_MC = []

# opt_bnn's variational means/stds, collected by parameter-name substring. These are the
# live nn.Parameter objects (not copies). How BNNMC uses them (initialisation vs. prior)
# is defined in bnn_vimc -- confirm there.
w_mu_list = []
b_mu_list = []
w_std_list = []
b_std_list = []

for name, p in opt_bnn.named_parameters():
    if "W_mu" in name:
        # w_prior_mu = p
        w_mu_list.append(p)
    if 'b_mu' in name:
        # b_prior_mu = p
        b_mu_list.append(p)
    if 'W_std' in name:
        # w_prior_std = p
        w_std_list.append(p)
    if 'b_std' in name:
        # b_prior_std = p
        b_std_list.append(p)


bnnmc = BNNMC(input_dim+2, output_dim, hidden_dims, activation_fn, is_continuous,
              W_mu=w_mu_list, b_mu=b_mu_list, W_std=w_std_list, b_std=b_std_list, scaled_variance=True).to(device)

# bnnmc = BNNMC(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)

bnnmc_optimizer = torch.optim.Adam(bnnmc.parameters(), lr=0.01)  # yacht-energy: 0.001, whine-protein: 0.01

def noise_loss(lr):
    """Return a random linear functional of the global ``bnnmc``'s parameters.

    For each parameter tensor ``p``, draws ``eps`` element-wise from a normal
    distribution with mean 0 and std ``(2 / lr) ** 0.5`` and accumulates
    ``sum(p * eps)``. The gradient of the result with respect to ``p`` is
    therefore ``eps``, which is how Langevin-style noise is injected into the
    optimiser step.
    """
    scale = (2/lr)**0.5
    total = 0.0
    for param in bnnmc.parameters():
        zeros = torch.zeros(param.size()).to(device)
        eps = torch.normal(zeros, std=scale).to(device)
        total = total + torch.sum(param * eps)
    return total

########### SGLD ###############
# Each epoch: one greedy bandit step with bnnmc, then 64 mini-batch updates on a 4096-sample
# replay pool drawn from the MC-phase buffer only. After a 500-epoch burn-in, noise is
# injected and a snapshot is saved every 5 epochs (epochs 505, 510, ..., 1000 -> 100 files).

mt = 0   # index of the next snapshot file
for epoch in range(1001):

    mushroom = np.random.randint(len(x))
    context, edible = x[mushroom], y[mushroom]
    try_eat = Var(np.concatenate((context, [1, 0])))
    try_reject = Var(np.concatenate((context, [0, 1])))

    # Calculate rewards using model
    with torch.no_grad():
        # r_eat = sum([self.net.forward(try_eat) for _ in range(samples)]).item()
        # r_reject = sum([self.net.forward(try_reject) for _ in range(samples)]).item()
        r_eat = bnnmc.forward(try_eat)
        r_reject = bnnmc.forward(try_reject)

    # Greedy action: eat iff the predicted reward for eating is higher (no epsilon-greedy step).
    eaten = r_eat > r_reject

    agent_reward = get_reward(eaten, edible)

    # Get rewards and update buffer (both the shared FVI buffer and the MC-only buffer)
    action = np.array([1, 0] if eaten else [0, 1])
    bufferX.append(np.concatenate((context, action)))
    bufferY.append(agent_reward)

    bufferX_MC.append(np.concatenate((context, action)))
    bufferY_MC.append(agent_reward)

    # Calculate regret
    # NOTE(review): regret is not tracked in this phase (lines below are disabled), so
    # cum_regrets gets no entries for these 1001 steps.
    # oracle = oracle_reward(edible)
    # regret = oracle - agent_reward
    # cum_regrets.append(cum_regrets[-1] + regret)

    ### Feed next mushroom
    # Same 4096-index replay pool as in S1, but over bufferX_MC (the printed label says bufferX).
    l = len(bufferX_MC)
    print('len(bufferX): ', l)
    idx_pool = range(l) if l >= 4096 else ((int(4096 // l) + 1) *
                                           list(range(l)))
    idx_pool = np.random.permutation(idx_pool[-4096:])
    context_pool = torch.Tensor([bufferX_MC[i] for i in idx_pool]).to(device)
    value_pool = torch.Tensor([bufferY_MC[i] for i in idx_pool]).to(device)

    for i in range(0, 4096, 64):
        x_ = context_pool[i:i + 64]
        # print('x_.shape', x_.shape)
        y_ = value_pool[i:i + 64]
        # print('y_.shape', y_.shape)

        pred_y = bnnmc.forward(x_)
        pred_y = pred_y.squeeze().flatten()

        # Surrogate whose gradient equals that of mean(0.5 * (pred_y - y_)**2): the residual
        # is detached, so d/dtheta = (pred_y - y_) * d pred_y / dtheta.
        t_like = pred_y * ((pred_y - y_).detach())
        t_like = torch.mean(t_like)

        # Function-space prior term on a random measurement subset of the batch. Its gradient
        # is (f - mu) / var * df/dtheta, i.e. minus the gradient of the Gaussian log-density
        # N(f; mu, var), with mu/var taken from likelihood(prior(m_set)).
        m_set = sample_measurement_set(X=x_, num_data=64)
        y_m_set = bnnmc.forward(m_set).squeeze().flatten()
        # print('y_m_set.shape: ', y_m_set.shape)

        with torch.no_grad():
            fprior_m = likelihood(prior(m_set))
            prior_mean_m = fprior_m.mean
            prior_var_m = fprior_m.variance

        d_logp_m = ((y_m_set - prior_mean_m) / prior_var_m).detach()
        d_logp_m = y_m_set * d_logp_m
        d_logp_m = torch.mean(d_logp_m)

        train_loss = t_like + 1 * d_logp_m

        # After burn-in, add a random linear term whose gradient is noise with std
        # 1e-3 * sqrt(2 / lr_bnn). NOTE(review): the step is taken by Adam, which rescales
        # this noise per parameter; confirm this is the intended sampler.
        if epoch > 500:
            noise = noise_loss(lr_bnn)
            train_loss = train_loss + 1e-3 * noise

        bnnmc_optimizer.zero_grad()
        train_loss.backward()
        bnnmc_optimizer.step()



        # print
        print("epoch : {} \t\t training loss \t\t : {}".format(epoch+1001, train_loss),
              datetime.now().replace(microsecond=0) - start_time)

    # if (epoch + 1) % 10 == 0:
    #     # print('cum_regrets: ', net_gwi.cum_regrets[-1])
    #     print('cum_regrets: ', cum_regrets[-1])
    #     df = pd.DataFrame.from_dict(cum_regrets)
    #     df.to_csv(results_folder + '/' + bnn_name_string + '/' + 'context_fvimc' + 'mushroom_regrets0.4.csv')

    ##############
    # Snapshot every 5 epochs after burn-in; the model is moved to CPU for saving and back.
    # NOTE(review): the 'fsgld_context3' directory is not created here (makedirs is
    # commented out); it must already exist unless bnnmc.save creates it.
    if epoch > 500 and (epoch % 5) == 0:  # (100, 3)
        print('save!')
        bnnmc.cpu()
        # torch.save(bnn.state_dict(), save_path + '/g3_rbf_%i.pt'%(mt))
        dd = f'{mt}'
        print(dd)
        save_path = results_folder + '/' + bnn_name_string + '/' + 'fsgld_context3' + '/' + 'context_' + dd
        # if not os.path.exists(save_path):
        #     os.makedirs(save_path)
        bnnmc.save(save_path)
        mt += 1
        bnnmc.to(device)

#####################
def sample_measurement_set2(X, num_data, n=100):
    """Return ``n`` rows of ``X`` drawn uniformly at random without replacement.

    Args:
        X: 2-D tensor to sample rows from.
        num_data: number of candidate rows (the first ``num_data`` rows of ``X``).
        n: measurement-set size (default 100, the previously hard-coded value).
            If ``num_data < n``, all ``num_data`` rows are returned in random order.

    Returns:
        Tensor of shape ``(min(n, num_data), X.shape[1])``.
    """
    # The first n entries of a random permutation form a uniform sample
    # without replacement.
    perm = torch.randperm(int(num_data))
    idx = perm[:int(n)]
    return X[idx, :]

# All inputs (context + action) collected during the SGLD phase, as one float tensor.
l = len(bufferX_MC)
idx_pool = range(l)
o_x = torch.Tensor([bufferX_MC[i] for i in idx_pool]).to(device)
# print('o_x.shape: ', o_x.shape)

# Fixed measurement set of (up to) 100 random rows of o_x, shared by the MCMC-snapshot
# predictions, the linearised FVI posterior and the flow training below.
original_x_train_add_g3 = sample_measurement_set2(X=o_x, num_data=o_x.shape[0])

#####################
def test2():
    """Predict with the current ``bnnmc`` on the fixed measurement set.

    Puts ``bnnmc`` in eval mode (it is left there) and returns its squeezed,
    detached outputs on ``original_x_train_add_g3``.
    """
    bnnmc.eval()
    # Inference only: avoid building an autograd graph for every loaded snapshot.
    with torch.no_grad():
        pred_mset = bnnmc.forward(original_x_train_add_g3).squeeze()

    return pred_mset.detach()

################ load mcmc model ################
# Reload each SGLD snapshot saved above (files context_0 ... context_{num_model-1}) and
# stack their predictions on the measurement set into mset_fmc (one row per snapshot).
# num_model must not exceed the number of snapshots written by the SGLD loop.
predm_list = []
num_model = 100 # 10 100 80 70

for m in range(num_model):
    dd = f'{m}'
    bnnmc.load_state_dict(torch.load(results_folder + '/' + bnn_name_string + '/' + 'fsgld_context3' + '/' + 'context_' + dd))  # sgld, sghmc

    predm = test2()
    predm_list.append(predm)
# Stack once after the loop; re-stacking on every iteration was quadratic in num_model.
mset_fmc = torch.stack(predm_list)
print('mset_fmc.shape: ', mset_fmc.shape)


############### GP Linearization for FVI posterior ###############
# Linearised (first-order Taylor) Gaussian approximation of the FVI posterior at the
# measurement set: mean = forward_mu(x) (presumably the network at the mean weights --
# confirm), variance = diag(J Sigma J^T) with Sigma = diag(w_std**2) and J the Jacobian of
# the output w.r.t. the 'mu' parameters. NOTE: this rebinds prior_mean_m / prior_var_m
# from the SGLD loop.
# opt_bnn2 is the same module object as opt_bnn (an alias, not a copy).
opt_bnn2 = opt_bnn
# Only used here to zero gradients between the per-point backward passes.
opt_bnn_optimizer2 = torch.optim.Adam([{'params': opt_bnn2.parameters(), 'lr': lr_bnn}])

w_mu, w_std = opt_bnn2.sample_mu_std()
w_mu = w_mu.detach()
w_std = w_std.detach()
# NOTE(review): the shape comments below look like they come from a different model size;
# the actual w_dim depends on input_dim and hidden_dims.
print('w_std.shape: ', w_std.shape)  # [10401]
w_dim = w_std.shape[-1]
print('w_dim: ', w_std.shape[-1])
# NOTE(review): this materialises a dense w_dim x w_dim matrix only to take
# diag(J Sigma J^T) below; the same diagonal equals (grads_x_mset**2 * w_std**2).sum(-1)
# and would avoid the O(w_dim**2) memory.
w_cov = w_std ** 2 * torch.eye(w_dim).to(device)
print('w_cov.shape: ', w_cov.shape)  # [10401, 10401]

grads_x_mset = None

# One backward pass per measurement point gives that point's row of the Jacobian.
for x_i in original_x_train_add_g3:

    pred_y_i = opt_bnn2.forward_mu(x_i)
    pred_y_i = pred_y_i.squeeze().flatten()

    # evaluate gradient
    opt_bnn_optimizer2.zero_grad()
    pred_y_i.backward()

    # NOTE(review): assumes the order of the 'mu' parameters in named_parameters() matches
    # the layout of w_std returned by sample_mu_std(); confirm in bnn_vimc.
    grad_i = torch.cat([p.grad.view(-1) for name, p in opt_bnn2.named_parameters() if 'mu' in name])
    grad_i = grad_i[:, None]
    print('grad_i.shape: ', grad_i.shape)

    # Append as a new row (repeated torch.cat is quadratic in the number of points).
    if grads_x_mset is None:
        grads_x_mset = torch.transpose(grad_i, 0, 1)
    else:
        grads_x_mset = torch.cat((grads_x_mset, torch.transpose(grad_i, 0, 1)), 0)

print('grads_x_mset.shape: ', grads_x_mset.shape)

prior_mean_m = opt_bnn2.forward_mu(original_x_train_add_g3).squeeze().detach()
print('prior_mean_m: ', prior_mean_m.shape)

prior_var_m = grads_x_mset @ w_cov @ torch.transpose(grads_x_mset, 0, 1)
print('prior_var_m: ', prior_var_m.shape)

# Keep only the per-point marginal variances.
prior_var_m = torch.diag(prior_var_m, 0)
print('prior_var_m: ', prior_var_m.shape)
##

######################### train LinerFolw_nn ##############
# Fit a conditional flow so that the MCMC snapshot predictions (mset_fmc), mapped through
# flow.inverse, match the linearised FVI Gaussian N(prior_mean_m, prior_var_m).

flownet = BNNF(input_dim+2, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(device)
# flownet2 is only needed by the (disabled) NonLinearFlow_nn option below.
flownet2 = BNNF(input_dim+2, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(device)

flow = LinearFlow_nn(flownet).to(device)
# flow = CubeFlow_nn(flownet).to(device)
# flow = NonLinearFlow_nn(flownet, flownet2).to(device)
# flow = SigmoidFlow_nn(flownet).to(device)
# flow = ExpFlow_nn(flownet).to(device)
flow_optimizer = torch.optim.Adam(flow.parameters(), lr=0.01)


loss_flow_all = []
# NOTE(review): the three lists below are not filled in the visible code.
flow_likelihood_list = []
flow_det_list = []
flow_loss_list = []

for epoch in range(5000):
    Z, logdet = flow.inverse(mset_fmc, original_x_train_add_g3)
    # print('Z: ', Z)
    # Prints a boolean NaN check on every epoch.
    contains_nan = torch.isnan(Z).any()
    print(contains_nan)

    # Gaussian negative log-likelihood of Z under N(prior_mean_m, prior_var_m), up to
    # constants (summed over points, averaged over snapshots), minus the inverse log-det.
    # The log-variance term does not depend on the flow parameters.
    loss_flow = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(
        torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1) - logdet)

    flow_optimizer.zero_grad()
    # NOTE(review): retain_graph=True does not appear to be needed (Z is recomputed every
    # epoch) and keeps the graph alive longer.
    loss_flow.backward(retain_graph=True)
    flow_optimizer.step()
    loss_flow_all.append(loss_flow.item())


    # print
    print("epoch : {} \t\t training loss \t\t : {}".format(epoch, loss_flow),
          datetime.now().replace(microsecond=0) - start_time)

flow.eval()
# `a` references the flow's scale Parameter itself (not a copy).
a = flow.weight
print('flow.weight: ', flow.weight)

################################# FVI-MCMC LOOP #######################
# State for the alternating FVI / MCMC rounds that follow: every flow in flow_list is
# applied, in order, to the BNN outputs in those rounds.

flow_list = [flow]
a_list = [a]
# logdet_list = [flogdet]
# NOTE(review): the lists below are not used in the visible part of the loop.
cum_det_list = []
flow_name_list = []

##
all_train_losses = []
all_wdist_losses = []
all_likelihood_losses = []
all_m2_losses = []
##
for i in range(1):    # 13
    loop_count = 4000 * i
    print('loop_count: ', loop_count)
    outloop = i
    ################# S1: FVI #############################
    train_loss_all = []
    prior_distance_all = []
    wdist_all = []
    likelihood_loss_all = []
    dif_var_all = []
    opt_bnn_optimizer = torch.optim.Adam([{'params': opt_bnn.parameters(), 'lr': 0.01}])   #lr_bnn

    for epoch in range(2001):

        mushroom = np.random.randint(len(x))
        context, edible = x[mushroom], y[mushroom]
        try_eat = Var(np.concatenate((context, [1, 0])))
        try_reject = Var(np.concatenate((context, [0, 1])))

        # Calculate rewards using model
        with torch.no_grad():
            # r_eat = sum([self.net.forward(try_eat) for _ in range(samples)]).item()
            # r_reject = sum([self.net.forward(try_reject) for _ in range(samples)]).item()
            r_eat = opt_bnn.forward(try_eat)
            for flow in flow_list:
                r_eat, _, _ = flow(r_eat, try_eat)

            r_reject = opt_bnn.forward(try_reject)
            for flow in flow_list:
                r_reject, _, _ = flow(r_reject, try_reject)

        # Take random action for epsilon greedy agents, calculate agent's reward
        eaten = r_eat > r_reject

        agent_reward = get_reward(eaten, edible)

        # Get rewards and update buffer
        action = np.array([1, 0] if eaten else [0, 1])
        bufferX.append(np.concatenate((context, action)))
        bufferY.append(agent_reward)

        # Calculate regret
        oracle = oracle_reward(edible)
        regret = oracle - agent_reward
        cum_regrets.append(cum_regrets[-1] + regret)

        ### Feed next mushroom
        l = len(bufferX)
        print('len(bufferX): ', l)
        idx_pool = range(l) if l >= 4096 else ((int(4096 // l) + 1) *
                                               list(range(l)))
        idx_pool = np.random.permutation(idx_pool[-4096:])
        context_pool = torch.Tensor([bufferX[i] for i in idx_pool]).to(device)
        value_pool = torch.Tensor([bufferY[i] for i in idx_pool]).to(device)

        for i in range(0, 4096, 64):
            x_ = context_pool[i:i + 64]
            y_ = value_pool[i:i + 64]

            measurement_set = sample_measurement_set(X=x_, num_data=64)

            gpss = prior_sample_functions(measurement_set, prior, num_sample=128).detach().float().to(device)
            gpss = gpss.squeeze()
            nnet_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(
                device)  # original_x_test
            nnet_samples = nnet_samples.squeeze()

            for flow in flow_list:

                nnet_samples, _, _ = flow(nnet_samples, measurement_set)

            functional_wdist = lipf(gpss, nnet_samples)
            functional_wdist = torch.mean(functional_wdist)

            ##################### variance constraint #########

            with torch.no_grad():
                gp_var = prior(measurement_set).variance
                # print('gp_var.shape: ', gp_var.shape)

            opt_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)
            opt_samples = opt_samples.squeeze()

            for flow in flow_list:

                opt_samples, _, _ = flow(opt_samples, measurement_set)

            opt_var = torch.std(opt_samples, 0) ** 2
            # print('opt_var.shape: ', opt_var.shape)
            dif_var = (gp_var - opt_var).abs()
            # dif_var = torch.sum(dif_var)
            dif_var = torch.mean(dif_var)
            ######

            pred_y = opt_bnn.forward(x_).t()

            for flow in flow_list:

                pred_y, _, _ = flow(pred_y, x_)

            pred_y = pred_y.squeeze().flatten()
            train_loss = loss(pred_y, y_) + f_coeff * functional_wdist + 1e-2 * dif_var

            opt_bnn_optimizer.zero_grad()

            train_loss.backward()

            opt_bnn_optimizer.step()

            # print
            print("epoch : {} \t\t training loss \t\t : {}".format(epoch + 2002 + loop_count, train_loss),
                  datetime.now().replace(microsecond=0) - start_time)

        if (epoch + 1) % 10 == 0:
            # print('cum_regrets: ', net_gwi.cum_regrets[-1])
            print('cum_regrets: ', cum_regrets[-1])
            df = pd.DataFrame.from_dict(cum_regrets)
            df.to_csv(results_folder + '/' + bnn_name_string + '/' + 'context_fvimc3' + 'mushroom_regrets0.6.csv')

    ###########

    ################## MCMC #################
    w_mu_list = []
    b_mu_list = []
    w_std_list = []
    b_std_list = []

    for name, p in opt_bnn.named_parameters():
        if "W_mu" in name:
            # w_prior_mu = p
            w_mu_list.append(p)
        if 'b_mu' in name:
            # b_prior_mu = p
            b_mu_list.append(p)
        if 'W_std' in name:
            # w_prior_std = p
            w_std_list.append(p)
        if 'b_std' in name:
            # b_prior_std = p
            b_std_list.append(p)

    bnnmc = BNNMC(input_dim+2, output_dim, hidden_dims, activation_fn, is_continuous,
                  W_mu=w_mu_list, b_mu=b_mu_list, W_std=w_std_list, b_std=b_std_list, scaled_variance=True).to(device)

    # bnnmc = BNNMC(input_dim, output_dim, hidden_dims, activation_fn, is_continuous, scaled_variance=True).to(device)

    bnnmc_optimizer = torch.optim.Adam(bnnmc.parameters(), lr=0.01)  #  lr=0.001

    def noise_loss(lr):
        """Return the SGLD noise term for the current ``bnnmc`` parameters.

        For each parameter tensor ``var`` a fresh Gaussian ``sigma`` with
        standard deviation ``sqrt(2 / lr)`` is drawn, and ``sum(var * sigma)``
        is accumulated. The gradient of this term w.r.t. ``var`` is ``sigma``,
        so adding a scaled copy to the loss before ``backward()`` adds scaled
        Gaussian noise to every parameter gradient.

        Args:
            lr: step size used to scale the noise standard deviation.

        Returns:
            A scalar tensor, or ``0.0`` if ``bnnmc`` has no parameters.
        """
        noise_std = (2 / lr) ** 0.5
        total = 0.0
        for var in bnnmc.parameters():
            # randn_like samples directly on var's device. This avoids building a
            # zero "mean" tensor on the CPU and copying it to the device for every
            # parameter on every call. The sample never requires grad, so it
            # acts as a constant.
            sigma = torch.randn_like(var) * noise_std
            total = total + torch.sum(var * sigma)
        return total

    ######### SGLD #####
    # SGLD phase. Each epoch:
    #   1. act greedily with bnnmc (+ flows) on one random mushroom;
    #   2. append the transition to both replay buffers;
    #   3. make one pass of 64-sized minibatches over 4096 transitions
    #      replayed from bufferX_MC / bufferY_MC.
    # From epoch 1001 on, a weight snapshot is written every 10 epochs; mt
    # counts the snapshots written so far.
    mt = 0
    for epoch in range(2001):

        mushroom = np.random.randint(len(x))
        context, edible = x[mushroom], y[mushroom]
        # Candidate inputs: context + one-hot action ([1, 0] = eat, [0, 1] = reject).
        try_eat = Var(np.concatenate((context, [1, 0])))
        try_reject = Var(np.concatenate((context, [0, 1])))

        # Calculate rewards using model
        with torch.no_grad():
            r_eat = bnnmc.forward(try_eat)
            for flow in flow_list:
                r_eat, _, _ = flow(r_eat, try_eat)

            r_reject = bnnmc.forward(try_reject)
            for flow in flow_list:
                r_reject, _, _ = flow(r_reject, try_reject)

        # Greedy action (no exploration here): eat iff the predicted eat reward is larger.
        eaten = r_eat > r_reject

        agent_reward = get_reward(eaten, edible)

        # Record the transition in the shared buffer and in the MCMC buffer.
        action = np.array([1, 0] if eaten else [0, 1])
        bufferX.append(np.concatenate((context, action)))
        bufferY.append(agent_reward)

        bufferX_MC.append(np.concatenate((context, action)))
        bufferY_MC.append(agent_reward)

        # Regret is not accumulated in this phase.

        ### Feed next mushroom
        l = len(bufferX_MC)
        print('len(bufferX): ', l)
        # If the buffer is smaller than 4096, tile the indices until there are
        # at least 4096. Keep the last 4096 and shuffle them.
        idx_pool = range(l) if l >= 4096 else ((int(4096 // l) + 1) *
                                               list(range(l)))
        idx_pool = np.random.permutation(idx_pool[-4096:])
        context_pool = torch.Tensor([bufferX_MC[i] for i in idx_pool]).to(device)
        value_pool = torch.Tensor([bufferY_MC[i] for i in idx_pool]).to(device)

        for i in range(0, 4096, 64):
            x_ = context_pool[i:i + 64]
            y_ = value_pool[i:i + 64]

            pred_y = bnnmc.forward(x_).t()

            for flow in flow_list:

                pred_y, _, _ = flow(pred_y, x_)

            pred_y = pred_y.squeeze().flatten()

            # Surrogate whose parameter gradient equals that of
            # 0.5 * mean((pred_y - y_) ** 2), because the residual is detached.
            t_like = pred_y * ((pred_y - y_).detach())
            t_like = torch.mean(t_like)

            m_set = sample_measurement_set(X=x_, num_data=64)
            y_m_set = bnnmc.forward(m_set).t()

            for flow in flow_list:

                y_m_set, _, _ = flow(y_m_set, m_set)

            with torch.no_grad():
                fprior_m = likelihood(prior(m_set))
                prior_mean_m = fprior_m.mean
                prior_var_m = fprior_m.variance

            # Same detached-factor trick for the functional prior term. Its
            # gradient follows (y - mean) / var, the negative gradient of a
            # Gaussian log-density with that mean and variance.
            d_logp_m = ((y_m_set - prior_mean_m) / prior_var_m).detach()
            d_logp_m = y_m_set * d_logp_m
            d_logp_m = torch.mean(d_logp_m)

            train_loss = t_like + 1 * d_logp_m

            # The noise term is only added from epoch 501 on.
            if epoch > 500:
                noise = noise_loss(lr_bnn)
                train_loss = train_loss + 1e-3 * noise

            bnnmc_optimizer.zero_grad()
            train_loss.backward()
            bnnmc_optimizer.step()

            # print
            print("epoch : {} \t\t training loss \t\t : {}".format(epoch + 4003 + loop_count, train_loss),
                  datetime.now().replace(microsecond=0) - start_time)

        ##############
        if epoch > 1000 and (epoch % 10) == 0:  # (100, 3)
            print('save!')
            bnnmc.cpu()
            dd = f'{mt}'
            print(dd)
            save_path = results_folder + '/' + bnn_name_string + '/' + 'fsgld_context3' + '/' + 'context_' + dd
            # save_path names a file (it is read back with torch.load later).
            # Create its parent directory, otherwise the first save fails if
            # the directory does not exist yet.
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            bnnmc.save(save_path)
            mt += 1
            bnnmc.to(device)

    #################

    def test2():
        """Predict with the current ``bnnmc`` weights on ``original_x_train_add_g3``.

        Switches ``bnnmc`` to eval mode (and leaves it there). Runs it on
        ``original_x_train_add_g3``, squeezes the output, and passes the result
        through every flow in ``flow_list`` in order.

        Returns:
            A detached tensor with the transformed predictions.
        """
        bnnmc.eval()
        # Only the values are needed (the result is detached), so do not build an
        # autograd graph through bnnmc and the flows for every snapshot.
        with torch.no_grad():
            pred_mset = bnnmc.forward(original_x_train_add_g3).squeeze()
            for flow in flow_list:
                pred_mset, _, _ = flow(pred_mset, original_x_train_add_g3)

        return pred_mset.detach()

    ################ load mcmc model ################
    # Reload each SGLD snapshot written by the loop above (files context_0,
    # context_1, ...) and stack their predictions on original_x_train_add_g3
    # into mset_fmc, one row per snapshot.
    pred_list = []
    num_model = 100  # 10 100 80 70
    snapshot_prefix = results_folder + '/' + bnn_name_string + '/' + 'fsgld_context3' + '/' + 'context_'

    predm_list = []
    for m in range(num_model):
        dd = f'{m}'
        snapshot_state = torch.load(snapshot_prefix + dd)  # sgld, sghmc
        bnnmc.load_state_dict(snapshot_state)
        predm = test2()
        predm_list.append(predm)

    mset_fmc = torch.stack(predm_list)

    ############################# GP Linearization2 ###################################
    # Linearise the mean network around its current variational means. The
    # result is the Gaussian reference used to fit the next flow:
    #   prior_mean_m[i] = forward_mu(x_i)
    #   prior_var_m[i]  = g_i^T diag(w_std ** 2) g_i,
    # where g_i is the gradient of forward_mu(x_i) w.r.t. the '*mu*' parameters.
    # NOTE: opt_bnn2 is the same object as opt_bnn, not a copy.
    opt_bnn2 = opt_bnn
    # Used only to clear gradients between the per-point backward passes.
    opt_bnn_optimizer2 = torch.optim.Adam([{'params': opt_bnn2.parameters(), 'lr': lr_bnn}])

    ##
    w_mu, w_std = opt_bnn2.sample_mu_std()
    w_mu = w_mu.detach()
    w_std = w_std.detach()
    print('w_std.shape: ', w_std.shape)  # [10401]
    w_dim = w_std.shape[-1]
    print('w_dim: ', w_std.shape[-1])
    # Diagonal of the (diagonal) weight covariance, kept as a vector. Only the
    # diagonal of G @ diag(w_var) @ G^T is needed, so a dense w_dim x w_dim
    # matrix (~10401^2 floats) is never built.
    w_var = (w_std ** 2).reshape(-1)

    ### One gradient row per input point.
    grad_rows = []
    for x_i in original_x_train_add_g3:
        pred_y_i = opt_bnn2.forward_mu(x_i)
        pred_y_i = pred_y_i.squeeze().flatten()

        # evaluate gradient
        opt_bnn_optimizer2.zero_grad()
        pred_y_i.backward()

        grad_rows.append(torch.cat([p.grad.view(-1) for name, p in opt_bnn2.named_parameters() if 'mu' in name]))

    # Do not leave the last point's gradient behind on opt_bnn's parameters.
    opt_bnn_optimizer2.zero_grad()

    # Stack once rather than re-concatenating inside the loop (quadratic copying).
    grads_x_mset = torch.stack(grad_rows)
    print('grads_x_mset.shape: ', grads_x_mset.shape)  # [60, 10401]

    ##
    prior_mean_m = opt_bnn2.forward_mu(original_x_train_add_g3).squeeze().detach()

    ## diag(G diag(w_var) G^T)[i] = sum_j G[i, j]^2 * w_var[j]
    prior_var_m = (grads_x_mset ** 2) @ w_var

 ############################ train flow2 ###################
    # Conditioning networks for the next flow, built in this order. Only
    # flownet3 is used by LinearFlow_nn. flownet4 belongs to the
    # NonLinearFlow_nn alternative but is still constructed (its construction
    # may draw from the RNG).
    flownet3, flownet4 = (
        BNNF(input_dim+2, output_dim, hidden_dimsflow, activation_fn, is_continuous, scaled_variance=True).to(device)
        for _ in range(2)
    )

    # Alternatives tried: NonLinearFlow_nn(flownet3, flownet4), SigmoidFlow_nn(flownet3),
    # ExpFlow_nn(flownet3), CubeFlow_nn(flownet2).
    flow2 = LinearFlow_nn(flownet3).to(device)
    flow_optimizer2 = torch.optim.Adam(flow2.parameters(), lr=0.01)

    # Maximum-likelihood fit of flow2:
    #   1. pull the stacked SGLD predictions (mset_fmc) back through flow2;
    #   2. pull them through the existing flows, last-applied first;
    #   3. score the result under N(prior_mean_m, prior_var_m), including the
    #      log-det terms.
    # flow_list does not change inside the loop, so its reversal is computed once.
    flow_list_inv = flow_list[::-1]
    loss_flow_all = []

    for epoch in range(5000):
        Z, logdet = flow2.inverse(mset_fmc, original_x_train_add_g3)

        per_flow_logdet = []
        for flow in flow_list_inv:
            Z, logdetf = flow.inverse(Z, original_x_train_add_g3)
            per_flow_logdet.append(torch.mean(logdetf))
        cum_det = torch.sum(torch.stack(per_flow_logdet))

        gauss_term = torch.sum(0.5 * (Z - prior_mean_m) ** 2 / prior_var_m, -1) - logdet
        loss_flow = torch.sum(torch.log(prior_var_m.abs())) + torch.mean(gauss_term) - cum_det

        flow_optimizer2.zero_grad()
        loss_flow.backward(retain_graph=True)
        flow_optimizer2.step()
        loss_flow_all.append(loss_flow.item())

        print("epoch : {} \t\t training loss \t\t : {}".format(epoch, loss_flow),
              datetime.now().replace(microsecond=0) - start_time)

    ############################
    flow2.eval()

    print('flow2.weight: ', flow2.weight)
    a2 = flow2.weight

    # Register the fitted flow so it is applied after the existing ones
    # wherever flow_list is iterated.
    a_list.append(a2)
    flow_list.append(flow2)


###########
# Final stage: keep training opt_bnn (composed with every flow in flow_list)
# on the mushroom bandit for 4001 epochs. Each epoch does the following:
#   - act greedily on one random mushroom;
#   - record the transition and the cumulative regret;
#   - make one pass of 64-sized minibatches over 4096 replayed transitions.
# NOTE: 'lr': 0.01 is hard-coded here (same value as lr_bnn at the top of the file).
opt_bnn_optimizer = torch.optim.Adam([{'params': opt_bnn.parameters(), 'lr': 0.01}])

for epoch in range(4001):

    # Candidate inputs: context + one-hot action ([1, 0] = eat, [0, 1] = reject).
    # NOTE(review): Var is defined elsewhere; presumably it wraps the array as a tensor on device.
    mushroom = np.random.randint(len(x))
    context, edible = x[mushroom], y[mushroom]
    try_eat = Var(np.concatenate((context, [1, 0])))
    try_reject = Var(np.concatenate((context, [0, 1])))

    # Calculate rewards using model
    with torch.no_grad():
        # r_eat = sum([self.net.forward(try_eat) for _ in range(samples)]).item()
        # r_reject = sum([self.net.forward(try_reject) for _ in range(samples)]).item()

        r_eat = opt_bnn.forward(try_eat)
        for flow in flow_list:
            r_eat, _, _ = flow(r_eat, try_eat)

        r_reject = opt_bnn.forward(try_reject)
        for flow in flow_list:
            r_reject, _, _ = flow(r_reject, try_reject)

    # Greedy action (no random exploration here): eat iff the predicted eat
    # reward is larger; then compute the agent's reward.
    eaten = r_eat > r_reject

    agent_reward = get_reward(eaten, edible)

    # Get rewards and update buffer
    action = np.array([1, 0] if eaten else [0, 1])
    bufferX.append(np.concatenate((context, action)))
    bufferY.append(agent_reward)

    # Calculate regret: oracle reward minus the agent's reward, accumulated.
    oracle = oracle_reward(edible)
    regret = oracle - agent_reward
    cum_regrets.append(cum_regrets[-1] + regret)

    ### Feed next mushroom
    # Tile the indices to at least 4096 when the buffer is smaller, keep the
    # last 4096 (the most recent transitions once l >= 4096) and shuffle.
    l = len(bufferX)
    print('len(bufferX): ', l)
    idx_pool = range(l) if l >= 4096 else ((int(4096 // l) + 1) *
                                           list(range(l)))
    idx_pool = np.random.permutation(idx_pool[-4096:])
    context_pool = torch.Tensor([bufferX[i] for i in idx_pool]).to(device)
    value_pool = torch.Tensor([bufferY[i] for i in idx_pool]).to(device)

    for i in range(0, 4096, 64):
        x_ = context_pool[i:i + 64]
        y_ = value_pool[i:i + 64]

        measurement_set = sample_measurement_set(X=x_, num_data=64)

        # 128 function samples from the GP prior and 128 from the BNN (the
        # latter pushed through the flows); lipf compares the two sets and its
        # mean is used as functional_wdist.
        gpss = prior_sample_functions(measurement_set, prior, num_sample=128).detach().float().to(device)
        gpss = gpss.squeeze()
        nnet_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)  # original_x_test
        nnet_samples = nnet_samples.squeeze()

        for flow in flow_list:

            nnet_samples, _, _ = flow(nnet_samples, measurement_set)


        functional_wdist = lipf(gpss, nnet_samples)
        functional_wdist = torch.mean(functional_wdist)

        ##################### variance constraint #########
        # Penalise the mean absolute gap between the GP prior variance and the
        # empirical variance of a second, independent set of 128 BNN samples
        # (variance taken over dim 0, presumably the sample axis).

        with torch.no_grad():
            gp_var = prior(measurement_set).variance
            # print('gp_var.shape: ', gp_var.shape)

        opt_samples = opt_bnn.sample_functions(measurement_set, num_sample=128).float().to(device)
        opt_samples = opt_samples.squeeze()

        for flow in flow_list:

            opt_samples, _, _ = flow(opt_samples, measurement_set)

        opt_var = torch.std(opt_samples, 0) ** 2
        # print('opt_var.shape: ', opt_var.shape)
        dif_var = (gp_var - opt_var).abs()
        # dif_var = torch.sum(dif_var)
        dif_var = torch.mean(dif_var)
        ######

        pred_y = opt_bnn.forward(x_).t()

        for flow in flow_list:

            pred_y, _, _ = flow(pred_y, x_)

        pred_y = pred_y.squeeze().flatten()

        # Objective = data-fit loss + f_coeff * functional_wdist + 1e-2 * variance gap.
        train_loss = loss(pred_y, y_) + f_coeff * functional_wdist + 1e-2 * dif_var

        opt_bnn_optimizer.zero_grad()

        train_loss.backward()

        opt_bnn_optimizer.step()

        # print (the printed "epoch" is offset by 6003; it is only a label)
        print("epoch : {} \t\t training loss \t\t : {}".format(epoch+6003, train_loss),
              datetime.now().replace(microsecond=0) - start_time)

    # Every 10 epochs: report the cumulative regret and overwrite the CSV with
    # the whole regret history.
    if (epoch + 1) % 10 == 0:
        # print('cum_regrets: ', net_gwi.cum_regrets[-1])
        print('cum_regrets: ', cum_regrets[-1])
        df = pd.DataFrame.from_dict(cum_regrets)
        df.to_csv(results_folder + '/' + bnn_name_string + '/' + 'context_fvimc3' + 'mushroom_regrets0.6.csv')

##############

print('a_list: ', a_list)

# print total training time
print("============================================================================================")
end_time = datetime.now().replace(microsecond=0)
# end_time comes from naive datetime.now(), i.e. local time. start_time is used
# the same way elsewhere (subtracted from datetime.now() for elapsed time), so
# label both as local time rather than GMT.
print("Started training at (local time) : ", start_time)
print("Finished training at (local time) : ", end_time)
print("Total training time  : ", end_time - start_time)
print("============================================================================================")




























