import os
from datetime import datetime
import sys
import math
import pandas as pd
from sklearn import preprocessing
from torch.autograd import Variable

import torch
import torch.nn as nn
import numpy as np
# import gpytorch
# from gp_prior import ExactGPModel, prior_sample_functions


from bnn import BNN
import torch.nn.functional as F
import argparse

from utils.logging import get_logger

from data import uci_woval

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from utils.utils import default_plotting_new as init_plotting


################################### Hyper-parameters ###################################

# Learning rate given to every Adam optimiser built in this script (the
# stand-alone `bnn_optimizer` and both bandit agents).
lr_bnn = 0.01
# Multiplier on the prior-distance term that `loss_step` adds to the data loss.
prior_coeff = 1
# Tag embedded in the log / checkpoint / figure directory and file names.
bnn_name_string = 'bnn'
# Seed applied to torch and numpy below; also part of the checkpoint file name.
random_seed = 123
# NOTE(review): this value does not bound the training loop in this file;
# the loop iterates over `epochs`.
max_epoch_num = 100
# NOTE(review): the next four values are not read anywhere in the visible part
# of this file -- presumably leftovers from a sibling experiment; confirm.
save_model_freq = 5000
log_model_freq = 3000
test_interval = 1000
num_sample = 1000
# Number of bandit interaction steps run by the main training loop.
epochs = 10001
# NOTE(review): the next three values are likewise unused in the visible code.
n_step_prior_pretraining = 100
lr_optbnn = 0.01
f_coeff = 10

# Seed both random number generators: torch and numpy's global generator
# (`np.random.*` is what the reward function and replay shuffling call).
torch.manual_seed(random_seed)
np.random.seed(random_seed)


################################### Network Architecture ###################################

# Value repeated for every entry of `hidden_dims`.
n_units = 100
# Number of entries in `hidden_dims`.
n_hidden = 2
# Passed to the BNN constructor; presumably one width per hidden layer --
# confirm against bnn.py.
hidden_dims = [n_units] * n_hidden
# Activation name passed through to BNN; how the string is interpreted is
# decided there (not visible in this file).
activation_fn = 'tanh'


print("============================================================================================")
################################## set device ##################################

# Pick the compute device once; everything below moves tensors/models to it.
if not torch.cuda.is_available():
    # No usable GPU: stay on the CPU.
    device = torch.device('cpu')
    print("Device set to : cpu")
else:
    # Use the first CUDA device and release cached allocator blocks first.
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))

print("============================================================================================")
###################### logging ######################

#### log files for multiple runs are NOT overwritten

# Per-experiment log directory. The trailing '/' is kept so the path strings
# stay exactly as before for anything that concatenates onto `log_dir`.
log_dir = "./logs"
log_dir = log_dir + '/' + bnn_name_string + '/' + 'context_mushroom' + '/'
# exist_ok=True creates any missing parent directories and tolerates an
# already existing directory, so no separate (racy) existence check is needed.
os.makedirs(log_dir, exist_ok=True)

#### get number of log files in log directory
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)

#### create new log file for each run
# The file count is only a starting guess: if an earlier log was deleted, the
# count can equal the index of a log that still exists, and opening it with
# "w+" would truncate it. Advance until the name is actually free.
while True:
    log_f_name = log_dir + '/' + bnn_name_string + 'context_mushroom' + "_" + str(run_num) + ".csv"
    if not os.path.exists(log_f_name):
        break
    run_num += 1

print("current logging run number for " + bnn_name_string + " : ", run_num)
print("logging at : " + log_f_name)


print("============================================================================================")
################### checkpointing ###################

run_num_pretrained = 0  #### change this to prevent overwriting weights in same env_name folder

# Checkpoint directory for this experiment; the trailing '/' is relied upon by
# the plain string concatenation that builds `checkpoint_path` below.
directory = "./pretrained"
directory = directory + '/' + bnn_name_string + '/' + 'context_mushroom' + '/'
# Creates missing parents and is a no-op when the directory already exists,
# replacing the previous check-then-create sequence.
os.makedirs(directory, exist_ok=True)

# File name pattern: <model tag>_<seed>_<pretrained run index>.pth
checkpoint_path = directory + "{}_{}_{}.pth".format(bnn_name_string, random_seed, run_num_pretrained)
print("save checkpoint path : " + checkpoint_path)

################### savefigures ###################

# Root folder for experiment outputs and the per-experiment figure folder
# beneath it. `figures_folder` keeps its trailing '/' so existing string
# concatenations onto it produce the same paths as before.
results_folder = "./results"
figures_folder = results_folder + '/' + bnn_name_string + '/' + 'context_mushroom' + '/'
# One call creates `results_folder` and everything below it; exist_ok=True
# makes re-runs harmless without a separate existence check.
os.makedirs(figures_folder, exist_ok=True)


print("============================================================================================")
############# print all hyperparameters #############

# Echo the settings that actually drive this run so the console output can be
# matched to its configuration.
print("learning rate: ", lr_bnn)
print("coefficient of prior regularization: ", prior_coeff)
print("random seed: ", random_seed)
# Report `epochs`, the value the training loop iterates over. The previous
# output showed `max_epoch_num`, which the loop never reads.
print("max number of epoches: ", epochs)


print("============================================================================================")
############################## load and normalize data ##############################

# training_num, input_dim = train_x.shape
# test_num = test_x.shape[0]
# output_dim = 1
# is_continuous = True
#
# print("training_num = ", training_num, " input_dim = ", input_dim, " output_dim = ", output_dim)
#
# original_x_train = torch.from_numpy(train_x).float().to(device)
# original_y_train = torch.from_numpy(train_y).float().to(device)
# original_x_test = torch.from_numpy(test_x).float().to(device)
# original_y_test = torch.from_numpy(test_y).float().to(device)
# Show the working directory: the CSV below is located relative to it.
current_dir = os.getcwd()
print(current_dir)
# Directory two levels above this file, appended to the import search path.
# NOTE(review): this runs after the project imports at the top of the file,
# so it cannot influence them.
root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_path)
# NOTE(review): `data_path` is computed but never used; the read below takes
# the file from the current working directory instead. Confirm which location
# is intended.
data_path = os.path.join(root_path, 'data', 'uci', 'mushroom.data')
# The file has no header row (header=None); column names are assigned afterwards.
df = pd.read_csv(os.getcwd() + '/data/uci/mushroom.data', header=None)

# Name the columns of the header-less file: the class label first, followed by
# the categorical attributes.
df.columns = ['class', 'cap-shape','cap-surface','cap-color','bruises','odor','gill-attachment',
         'gill-spacing','gill-size','gill-color','stalk-shape','stalk-root',
         'stalk-surf-above-ring','stalk-surf-below-ring','stalk-color-above-ring','stalk-color-below-ring',
         'veil-type','veil-color','ring-number','ring-type','spore-color','population','habitat']

# Split context from label
X = pd.DataFrame(df, columns=df.columns[1:len(df.columns)], index=df.index)
# Put the class values (0th column) into Y
Y = df['class']

# Build the binary target so that 1 means "edible".
# The rest of the script reads `y[i]` as an edible flag (reward 5 for eating,
# oracle reward 5 * y). LabelEncoder assigns codes in sorted class order, so
# the raw codes would give 'e' -> 0 and 'p' -> 1, i.e. the inverse of that flag.
# The target is therefore derived from the label value, not from the code.
le = preprocessing.LabelEncoder()
le.fit(Y)
edible_labels = [c for c in le.classes_ if str(c).strip().lower().startswith('e')]
if len(edible_labels) != 1:
    raise ValueError(
        "cannot identify the edible class among labels {!r}".format(list(le.classes_)))
y = np.asarray(Y == edible_labels[0]).astype(np.int64)

# Integer-code every categorical feature column with its own encoder (so `le`
# stays fitted on the class labels), keeping the original column order.
x_tmp = pd.DataFrame(
    {colname: preprocessing.LabelEncoder().fit_transform(X[colname]) for colname in X.columns},
    index=X.index,
)

# One-hot encode the integer codes into a dense array, one row per mushroom
# (8124 rows x 117 columns according to the original author's note).
oh = preprocessing.OneHotEncoder()
oh.fit(x_tmp)
x = oh.transform(x_tmp).toarray()

# Dataset dimensions come from the one-hot matrix: rows are samples, columns
# are encoded features. The network predicts one scalar per input.
training_num, input_dim = x.shape
output_dim, is_continuous = 1, True
print("training_num = ", training_num, " input_dim = ", input_dim, " output_dim = ", output_dim)

# float32 copies of the features and targets on the selected device.
original_x_train, original_y_train = (
    torch.from_numpy(array).float().to(device) for array in (x, y)
)


# Data-fit term used by the agents' loss_step: mean squared error.
loss = nn.MSELoss(reduction='mean')


print("============================================================================================")
################################## define model ##################################


# Stand-alone network over the raw one-hot features (no action bits), built
# first and then moved to the selected device.
bnn = BNN(
    input_dim,
    output_dim,
    hidden_dims,
    activation_fn,
    is_continuous,
    scaled_variance=True,
)
bnn = bnn.to(device)

# Adam over all of its parameters, configured as a single parameter group.
bnn_optimizer = torch.optim.Adam(
    [{'params': bnn.parameters(), 'lr': lr_bnn}]
)


print("============================================================================================")

############ define reward function####################
def get_reward(eaten, edible):
    """Return the agent's reward for one mushroom.

    Args:
        eaten: truthy when the agent chose to eat the mushroom.
        edible: truthy when the mushroom is edible.

    Returns:
        0 when the mushroom is not eaten, 5 when an edible mushroom is eaten,
        and for an eaten non-edible mushroom either 5 or -35, decided by one
        draw from numpy's global random generator.
    """
    # Rejecting a mushroom never earns or costs anything.
    if not eaten:
        return 0
    # Eating an edible mushroom always pays off.
    if edible:
        return 5
    # Eating a non-edible one is a gamble on a single uniform draw.
    lucky = np.random.rand() > 0.5
    if lucky:
        return 5
    return -35

def oracle_reward(edible):
    """Return the reference reward used for regret: 5 times the edible flag."""
    reference = 5 * edible
    return reference

################ define mushroom net######################
def Var(x, dtype=torch.FloatTensor):
    """Wrap a numpy array as a torch Variable of type `dtype` on `device`.

    Args:
        x: numpy array to convert.
        dtype: torch tensor type to cast to (FloatTensor by default).
    """
    as_tensor = torch.from_numpy(x).type(dtype)
    return Variable(as_tensor).to(device)
class MushroomNet():
    """Base bandit agent: acts on a mushroom, records the outcome, trains.

    The agent reads the module-level arrays `x` (encoded contexts) and `y`
    (per-mushroom flag that this class treats as "edible"). Every stored
    sample is the context concatenated with a 2-element action encoding:
    [1, 0] for "eat" and [0, 1] for "reject".

    Subclasses must set `self.net` and `self.optimizer` and override
    `loss_step`.
    """

    # Number of most recent samples replayed by `update`, and the size of the
    # mini-batches that window is cut into.
    REPLAY_WINDOW = 4096
    BATCH_SIZE = 64

    def __init__(self, n_weight_sampling=2):
        # NOTE(review): `n_weight_sampling` and `epsilon` are stored but not
        # read by the code in this class (the sampling and epsilon-greedy
        # variants are disabled).
        self.n_weight_sampling = n_weight_sampling
        self.epsilon = 0
        self.net = None
        self.loss, self.optimizer = None, None
        # Running total of regret; starts with a 0 entry before any step.
        self.cum_regrets = [0]
        # Replay buffer: contexts-with-action and the rewards observed for them.
        self.bufferX, self.bufferY = [], []

    def init_buffer(self):
        """Fill the buffer with REPLAY_WINDOW uniformly random interactions."""
        for i in np.random.choice(range(len(x)), self.REPLAY_WINDOW):
            eat = np.random.rand() > 0.5
            action = [1, 0] if eat else [0, 1]
            self.bufferX.append(np.concatenate((x[i], action)))
            self.bufferY.append(get_reward(eat, y[i]))

    # Use NN to decide next action
    def try_(self, mushroom):
        """Act on mushroom index `mushroom`, store the sample, track regret."""
        context, edible = x[mushroom], y[mushroom]
        try_eat = Var(np.concatenate((context, [1, 0])))
        try_reject = Var(np.concatenate((context, [0, 1])))

        # Score both actions with the model; no gradients are needed here.
        with torch.no_grad():
            r_eat = self.net.forward(try_eat)
            r_reject = self.net.forward(try_reject)

        # Reduce the comparison to a plain bool so the reward function and the
        # action encoding below work with a Python value, not a device tensor.
        eaten = bool(r_eat > r_reject)
        agent_reward = get_reward(eaten, edible)

        # Get rewards and update buffer
        action = np.array([1, 0] if eaten else [0, 1])
        self.bufferX.append(np.concatenate((context, action)))
        self.bufferY.append(agent_reward)

        # Calculate regret
        regret = oracle_reward(edible) - agent_reward
        self.cum_regrets.append(self.cum_regrets[-1] + regret)

    def _buffer_tensors(self, idx_pool):
        """Return (contexts, rewards) float32 tensors for the buffer indices.

        The rows are first packed into one numpy array: building a tensor
        directly from a Python list of arrays converts element by element
        and is far slower.
        """
        contexts = np.array([self.bufferX[i] for i in idx_pool], dtype=np.float32)
        rewards = np.array([self.bufferY[i] for i in idx_pool], dtype=np.float32)
        return (torch.from_numpy(contexts).to(device),
                torch.from_numpy(rewards).to(device))

    # Feed next mushroom
    def update(self, mushroom):
        """Act on `mushroom`, then train on the shuffled replay window."""
        self.try_(mushroom)
        n_stored = len(self.bufferX)
        print('len(bufferX): ', n_stored)
        window = self.REPLAY_WINDOW
        if n_stored >= window:
            # Enough history: replay only the newest `window` samples.
            recent = range(n_stored - window, n_stored)
        else:
            # Too little history: repeat the stored indices until the window
            # is full, then keep its tail.
            recent = (list(range(n_stored)) * (window // n_stored + 1))[-window:]
        idx_pool = np.random.permutation(recent)
        context_pool, value_pool = self._buffer_tensors(idx_pool)
        batch = self.BATCH_SIZE
        for start in range(0, window, batch):
            # The batch index is passed as the third positional argument,
            # exactly as before.
            self.loss_step(context_pool[start:start + batch],
                           value_pool[start:start + batch],
                           start // batch)

    def update2(self, mushroom):
        """Act on `mushroom`, then take one step on the whole shuffled buffer."""
        self.try_(mushroom)
        idx_pool = np.random.permutation(range(len(self.bufferX)))
        context_pool, value_pool = self._buffer_tensors(idx_pool)
        self.loss_step(context_pool, value_pool)

    def loss_step(self, x, y, n_samples=2):
        """Run one optimisation step on a batch; subclasses must override."""
        raise NotImplementedError

class BBB_MNet_kl(MushroomNet):
    """Bandit agent whose training step uses the network's `forward_kl`."""

    def __init__(self, lr):
        """Build the network and its Adam optimiser with learning rate `lr`."""
        super().__init__()
        self.lr = lr

        # Two extra inputs carry the action encoding appended to each context.
        network = BNN(input_dim + 2, output_dim, hidden_dims, activation_fn,
                      is_continuous, scaled_variance=True)
        self.net = network.to(device)

        def half_squared_error(x, y):
            # Half the summed squared difference; kept as `self.mse`.
            return .5 * ((x - y) ** 2).sum()

        self.mse = half_squared_error
        self.optimizer = torch.optim.Adam(
            [{'params': self.net.parameters(), 'lr': self.lr}]
        )

    def loss_step(self, x, y, n_samples = 2):
        """Take one optimiser step on batch (`x`, `y`) and print the loss.

        `n_samples` is accepted but not used. The progress line reads the
        module-level `epoch` and `start_time`, so both must exist before
        this is called.
        """
        prediction, distance_prior = self.net.forward_kl(x)
        # Collapse the prediction to 1-D so it lines up with the target vector.
        prediction = torch.flatten(torch.squeeze(prediction))

        # Data-fit term plus the weighted distance-to-prior term.
        data_term = loss(prediction, y)
        train_loss = data_term + prior_coeff * distance_prior

        self.optimizer.zero_grad()
        train_loss.backward()
        self.optimizer.step()

        # Progress line: current epoch, loss, and time elapsed since start.
        message = "epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss)
        elapsed = datetime.now().replace(microsecond=0) - start_time
        print(message, elapsed)


class BBB_MNet_w(MushroomNet):
    """Bandit agent whose training step uses the network's `forward_w`."""

    def __init__(self, lr):
        """Build the network and its Adam optimiser with learning rate `lr`."""
        super().__init__()
        self.lr = lr

        # Input width is the context size plus the 2-element action encoding.
        network = BNN(input_dim + 2, output_dim, hidden_dims, activation_fn,
                      is_continuous, scaled_variance=True)
        self.net = network.to(device)

        def half_squared_error(x, y):
            # Half the summed squared difference; kept as `self.mse`.
            return .5 * ((x - y) ** 2).sum()

        self.mse = half_squared_error
        self.optimizer = torch.optim.Adam(
            [{'params': self.net.parameters(), 'lr': self.lr}]
        )

    def loss_step(self, x, y, n_samples = 2):
        """Take one optimiser step on batch (`x`, `y`) and print the loss.

        `n_samples` is accepted but not used. The progress line reads the
        module-level `epoch` and `start_time`, so both must exist before
        this is called.
        """
        prediction, distance_prior = self.net.forward_w(x)
        # Collapse the prediction to 1-D so it lines up with the target vector.
        prediction = torch.flatten(torch.squeeze(prediction))

        # Data-fit term plus the weighted distance-to-prior term.
        data_term = loss(prediction, y)
        train_loss = data_term + prior_coeff * distance_prior

        self.optimizer.zero_grad()
        train_loss.backward()
        self.optimizer.step()

        # Progress line: current epoch, loss, and time elapsed since start.
        message = "epoch : {} \t\t training loss \t\t : {}".format(epoch, train_loss)
        elapsed = datetime.now().replace(microsecond=0) - start_time
        print(message, elapsed)
################################## start training ##################################

# track total training time
# NOTE(review): datetime.now() returns local wall-clock time, although the
# printed label says GMT.
start_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)

# logging file
# Opened through a context manager so the handle is flushed and closed when
# training finishes or aborts; previously it was never closed. Only the header
# row is written to it in this script.
with open(log_f_name, "w+") as log_f:
    log_f.write('epoch,training_loss\n')
    # NOTE(review): this scheduler wraps `bnn_optimizer` and is never stepped
    # in this file; the two agents below use their own optimisers.
    scheduler = torch.optim.lr_scheduler.StepLR(bnn_optimizer, 5000, gamma=0.9, last_epoch=-1)

    net_kl = BBB_MNet_kl(lr_bnn)
    net_w = BBB_MNet_w(lr_bnn)
    # `epoch` must remain a module-level name: the agents' loss_step reads it
    # (together with `start_time`) for its progress line.
    for epoch in range(epochs):
        # Both agents are shown the same randomly drawn mushroom each step.
        mushroom = np.random.randint(len(x))
        net_kl.update(mushroom)
        net_w.update(mushroom)
        if (epoch + 1) % 10 == 0:
            print('cum_regrets_kl: ', net_kl.cum_regrets[-1])
            print('cum_regrets_w: ', net_w.cum_regrets[-1])
            # df = pd.DataFrame.from_dict({net.cum_regrets})
            # df.to_csv(results_folder + '/' + bnn_name_string + '/' + 'context_mushroom' + 'mushroom_regrets.csv')


# Plot the cumulative regret of both agents and save it as a PDF.
plt.clf()
figure = plt.figure(figsize=(8, 5.5), facecolor='white')
init_plotting()
# plt.figure()
cum_regrets_kl = np.array(net_kl.cum_regrets)
cum_regrets_w = np.array(net_w.cum_regrets)
# Sample every 100th step. The range is derived from the recorded history
# instead of a hard-coded 10001, so a different `epochs` cannot index out of
# bounds. For the default run this gives the same points (0, 100, ..., 10000).
indices = np.arange(0, min(len(cum_regrets_kl), len(cum_regrets_w)), 100)
plt.plot(indices, cum_regrets_kl[indices], '-ro', ms=3, label='kl')
plt.plot(indices, cum_regrets_w[indices], '-ko', ms=3, label='w')
plt.ylabel(r'cum_regrets')
plt.xlabel('Iteration')
# A single layout pass after both axis labels are set.
plt.tight_layout()
plt.legend(loc='upper center', ncol=3, fontsize='small')
plt.savefig(figures_folder + '/cum_regrets.pdf')


# print total training time
# Summary: start, end and elapsed wall-clock time, framed by separator lines.
print("============================================================================================")
end_time = datetime.now().replace(microsecond=0)
print(
    *(
        "{} {}".format(caption, moment)
        for caption, moment in (
            ("Started training at (GMT) : ", start_time),
            ("Finished training at (GMT) : ", end_time),
            ("Total training time  : ", end_time - start_time),
        )
    ),
    sep="\n",
)
print("============================================================================================")