# This code is based on the code by Yarin Gal used for his
# paper "Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning" 

# This file contains code to train dropout networks on the UCI datasets using the following algorithm:
# 1. Create 20 random splits of the training-test dataset.
# 2. For each split:
# 3.   Create a validation (val) set taking 20% of the training set.
# 4.   Get best hyperparameters: dropout_rate and tau by training on (train-val) set and testing on val set.
# 5.   Train a network on the entire training set with the best pair of hyperparameters.
# 6.   Get the performance (MC RMSE and log-likelihood) on the test set.
# 7. Report the averaged performance (Monte Carlo RMSE and log-likelihood) on all 20 splits.

import math
import numpy as np
import argparse
import sys, os
from scipy.special import logsumexp
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn import Parameter
from torch.utils.data import DataLoader, TensorDataset

import time
from subprocess import call

parser = argparse.ArgumentParser()

# Command-line options, declared in the same order they appear in --help.
# --klw is the weight on the KL term of the training loss (passed on as kl_weight).
# --MC-train and --droprate are parsed but not read anywhere in this script.
for _flag, _options in (
    ('--dataset', dict(default='yacht', help='name of the UCI Dataset directory.')),
    ('--epochx', dict(type=int, default=50, help='multiplier for the number of epochs for training.')),
    ('--batchsize', dict(type=int, default=128, help='batch size')),
    ('--seed', dict(type=int, default=1)),
    ('--method', dict(type=str, default='ours')),
    ('--MC-train', dict(type=int, default=1)),
    ('--MC-test', dict(type=int, default=10000)),
    ('--droprate', dict(type=float, default=0.5)),
    ('--klw', dict(type=float, default=0.005, help='the KL annealing')),
    ('--num-HH', dict(type=int, default=2, help='number of transformation steps T.')),
):
    parser.add_argument(_flag, **_options)
del _flag, _options

args = parser.parse_args()
# Silence every warning for the remainder of the run.
warnings.filterwarnings("ignore")

# Device string used by the model code and the training loop below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class HHtrans(nn.Module):
    """One batched Householder reflection.

    For every row ``v_b`` of ``v`` the reflection
    ``H_b = I - 2 v_b v_b^T / ||v_b||^2`` is built, stored in ``H[str(i)]``
    as a (batch, K, K) tensor, and applied to the matching row of ``s``.
    """

    def __init__(self):
        super(HHtrans, self).__init__()

    def forward(self, i, v, s, H):
        """
        @param i  Step index; the reflection matrices are stored under key ``str(i)``.
        @param v  (batch, K) reflection vectors. A zero row divides by zero (NaN).
        @param s  (batch, K) vectors to reflect.
        @param H  Dict that receives the (batch, K, K) reflection matrices.
        @return   (batch, K) tensor whose row b is ``H_b @ s_b``.
        """
        K = v.shape[1]
        vvT = torch.bmm(v.unsqueeze(2), v.unsqueeze(1))      # (batch, K, K) outer products
        norm_sq = torch.sum(v * v, 1).view(-1, 1, 1)         # broadcasts over the K x K matrix
        # Build the identity on v's own device and dtype instead of relying on a
        # module-level device setting, so the layer works wherever its inputs live.
        eye = torch.eye(K, device=v.device, dtype=v.dtype)
        H[str(i)] = eye - 2 * vvT / norm_sq
        return torch.bmm(H[str(i)], s.unsqueeze(2)).squeeze(2)

class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` subclass whose forward pass calls ``F.conv2d`` directly.

    Note the constructor takes ``bias`` before ``groups`` (the reverse of
    ``nn.Conv2d``); both are forwarded to the parent by keyword.
    Not instantiated anywhere in this file.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=0, dilation=1, bias=True, groups=1):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        # Plain functional convolution with the layer's own hyper-parameters.
        return F.conv2d(
            x,
            self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

class Linear(nn.Module):
    """Linear layer with multiplicative Gaussian input noise rotated by Householder reflections.

    Forward pass for a 2-D input ``x`` of shape (batch, in_features):

    1. draw ``r0 = sqrt(alpha) * eps`` with ``eps ~ N(0, I)`` (one draw per
       element of ``x``), where ``alpha`` is a learned per-input-feature variance;
    2. apply ``num_HH`` Householder reflections ``H_1 ... H_T``; the reflection
       vector of step ``i`` is ``leaky_relu(v_layers[i-1](v_{i-1}))`` with ``v_0 = x``;
    3. return ``(x * (1 + r_T)) @ W.T + bias``.

    The combined rotation ``U = H_T ... H_1`` (identity when ``num_HH == 0``)
    is cached in ``self.U`` for ``kl_reg``.
    """

    def __init__(self, in_features, out_features, args, alpha=1.):
        """
        @param in_features   Input dimensionality.
        @param out_features  Output dimensionality.
        @param args          Namespace providing ``num_HH`` (number of reflections).
        @param alpha         Initial noise variance for every input feature; its log is
                             the learned parameter, so it should be > 0.
        """
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.args = args
        self.W = Parameter(torch.Tensor(out_features, in_features))
        self.bias = Parameter(torch.Tensor(1, out_features))

        # The noise variance is learned in log space so it stays positive.
        log_alpha = (torch.ones(self.in_features) * alpha).log()
        self.log_alpha_ = nn.Parameter(log_alpha)

        # Householder transformation: one small linear layer per reflection
        # produces that step's reflection vector.
        self.num_HH = args.num_HH
        self.HHtrans = HHtrans()

        self.v_layers = nn.ModuleList()
        for i in range(0, self.num_HH):
            self.v_layers.append(nn.Linear(in_features, in_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``W`` (Kaiming-uniform) and ``bias`` the same way ``nn.Linear`` does."""
        init.kaiming_uniform_(self.W, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def tile(self, A, dim, n_tile):
        """Repeat each slice of ``A`` along ``dim`` ``n_tile`` times consecutively.

        Equivalent to ``torch.repeat_interleave(A, n_tile, dim)``. Not called within this file.
        """
        init_dim = A.size(dim)
        repeat_idx = [1] * A.dim()
        repeat_idx[dim] = n_tile
        A = A.repeat(*repeat_idx)
        # Build the gather index on A's own device rather than a global device setting.
        order_index = torch.as_tensor(
            np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]),
            dtype=torch.long, device=A.device)
        return torch.index_select(A, dim, order_index)

    def q_s_HHtrans(self, s, v0, H):
        """
        Apply the chain of Householder reflections to ``s['0']``.

        @param s   Dict holding the starting vectors under key ``'0'``; intermediate
                   results are written under ``'1'`` ... ``str(num_HH)``.
        @param v0  Input used to generate the first reflection vector.
        @param H   Dict that receives each step's reflection matrices.
        @return    ``s[str(num_HH)]``, or ``s['0']`` unchanged when ``num_HH == 0``.
        """
        v = {}
        if self.num_HH > 0:
            v['0'] = v0
            for i in range(0, self.num_HH):
                v[str(i + 1)] = self.v_layers[i](v[str(i)])
                v[str(i + 1)] = F.leaky_relu(v[str(i + 1)])
                s[str(i + 1)] = self.HHtrans(i + 1, v[str(i + 1)], s[str(i)], H)
            return s[str(self.num_HH)]
        return s['0']

    def forward(self, x):
        """
        Apply the noisy linear map and cache the combined rotation in ``self.U``.

        @param x  2-D input batch (batch, in_features).
        @return   (batch, out_features) outputs.
        """
        alpha = self.log_alpha_.exp()
        r = {}
        H = {}

        # Diagonal Gaussian noise with per-feature variance alpha, drawn directly on
        # x's device and dtype (no host-to-device copy).
        r['0'] = torch.sqrt(alpha) * torch.randn_like(x)
        r_K = self.q_s_HHtrans(r, x, H)
        s_K = 1 + r_K

        if self.num_HH > 0:
            # U = H_T ... H_1: the full rotation that was applied to r0.
            U = H[str(self.num_HH)]
            for i in reversed(range(1, self.num_HH)):  # inverse order
                U = torch.bmm(U, H[str(i)])
        else:
            # No reflections were applied, so H is empty and the rotation is the identity.
            n_batch, K = x.shape
            U = torch.eye(K, device=x.device, dtype=x.dtype).expand(n_batch, K, K)
        self.U = U

        X_noised = x * s_K
        activation = F.linear(X_noised, self.W)
        return activation + self.bias

    def kl_reg(self):
        """
        KL regulariser for this layer, as implemented here:

            0.5 * out_features * mean_b sum_i log((1 + sum_j alpha_j U[b, i, j]^2) / alpha_i)

        ``U`` is the rotation cached by the most recent ``forward`` call, so
        ``forward`` must run first (otherwise ``self.U`` does not exist).
        """
        alpha = self.log_alpha_.exp()
        M, K = self.U.shape[:2]
        # sum_j alpha_j * U_ij^2: variance of the i-th component of the rotated noise.
        kl = torch.log((1 + torch.sum(alpha * self.U ** 2, dim=-1)) / alpha.unsqueeze(0).expand(M, K))
        kl = torch.sum(kl, dim=-1)
        kl = kl * self.out_features
        return kl.mean() / 2

class Net(nn.Module):
    """Regression MLP ``Linear(input_dim, 50) -> ReLU -> Linear(50, 1)`` built from noisy layers."""

    def __init__(self, args, input_dim, tau, droprate=0.1, kl_weight=1.0):
        """
        @param args       Namespace forwarded to each ``Linear`` layer.
        @param input_dim  Number of input features.
        @param tau        Accepted for interface compatibility; not used by the network.
        @param droprate   Dropout rate p; every layer starts with noise variance p / (1 - p).
        @param kl_weight  Weight of the KL term in the loss minimised by ``fit``.
        """
        super(Net, self).__init__()
        self.args = args
        self.kl_weight = kl_weight
        hidden_size = 50
        alpha = droprate / (1 - droprate)
        self.fc1 = Linear(input_dim, hidden_size, args=self.args, alpha=alpha)
        self.fc2 = Linear(hidden_size, 1, args=self.args, alpha=alpha)

    def forward(self, x):
        x = F.relu(self.fc1(x.float().to(device)))
        return self.fc2(x)

    def fit(self, X_train, y_train_normalized, optimizer, num_samples, batch_size=128, n_epochs=400):
        """
        Minimise ``MSE + kl_weight * (sum of layer KL terms) / num_samples`` over
        shuffled mini-batches for ``n_epochs`` epochs, printing progress every 100 epochs.
        """
        train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train_normalized))
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        for t in range(n_epochs):
            train_loss = 0.
            for data, target in train_loader:
                optimizer.zero_grad()
                prediction = self(data.to(device))
                # The KL terms must be computed after the forward pass: each layer
                # caches the rotation ``U`` they need during forward.
                kl = 0.0
                for module in self.children():
                    if hasattr(module, 'kl_reg'):
                        kl = kl + module.kl_reg()
                kl = kl / num_samples
                mse = F.mse_loss(prediction, target.to(device))
                loss = mse + self.kl_weight * kl
                loss.backward()
                optimizer.step()
                # .item() keeps the running total a plain float instead of a tensor
                # that drags every batch's autograd history through the epoch.
                train_loss += loss.item()

            if t % 100 == 0:
                print(t, 'train_loss', train_loss / X_train.shape[0] * batch_size)

    def predict(self, X_test):
        """Return one stochastic forward pass over ``X_test`` as a numpy array."""
        # Inference only: no autograd graph is needed (predict is called thousands of
        # times per evaluation for Monte-Carlo averaging).
        with torch.no_grad():
            y_preds = self(X_test.cpu())
        return y_preds.cpu().numpy()


class Method:
    """Train one network on a training split and evaluate it with Monte-Carlo sampling.

    Handles input/target standardisation, construction and training of ``Net``,
    and Monte-Carlo evaluation (RMSE and Gaussian test log-likelihood).
    """

    def __init__(self, X_train, y_train, X_validation, y_validation, args, input_dim, normalize=False, n_epochs=400,
                 tau=1.0, dropout=0.1, kl_weight=1.0):
        """
            Standardise the training data and train a network on it.

            @param X_train       Matrix with the features for the training data.
            @param y_train       Vector with the target variables for the
                                 training data.
            @param X_validation  Unused; kept for interface compatibility.
            @param y_validation  Unused; kept for interface compatibility.
            @param args          Parsed command-line arguments. ``args.batchsize`` sets
                                 the mini-batch size and ``args.MC_test`` (when present,
                                 default 10000) the number of Monte-Carlo samples drawn
                                 by ``predict``.
            @param input_dim     Number of input features.
            @param normalize     Whether to standardise the input features to zero mean
                                 and unit standard deviation on the training set.
                                 Constant features are centred but left unscaled.
            @param n_epochs      Number of epochs for which to train the network.
            @param tau           Model precision. Only used by ``predict`` for the
                                 Gaussian log-likelihood; it does not regularise training.
            @param dropout       Dropout rate p; ``Net`` turns it into the initial noise
                                 variance p / (1 - p) of every layer.
            @param kl_weight     Weight of the KL term in the training loss.
            @raises ValueError   If ``args.MC_test`` is smaller than 1.
        """
        # Validate before the (expensive) training run.
        self.n_mc_samples = int(getattr(args, 'MC_test', 10000))
        if self.n_mc_samples < 1:
            raise ValueError('MC_test must be >= 1, got %d' % self.n_mc_samples)

        # We normalize the training data to have zero mean and unit standard
        # deviation in the training set if necessary
        if normalize:
            self.std_X_train = np.std(X_train, 0)
            self.std_X_train[self.std_X_train == 0] = 1
            self.mean_X_train = np.mean(X_train, 0)
        else:
            self.std_X_train = np.ones(X_train.shape[1])
            self.mean_X_train = np.zeros(X_train.shape[1])

        # Broadcasting applies the per-feature statistics to every row.
        X_train = (X_train - self.mean_X_train) / self.std_X_train

        self.mean_y_train = np.mean(y_train)
        self.std_y_train = np.std(y_train)

        y_train_normalized = (y_train - self.mean_y_train) / self.std_y_train
        # Column vector so it matches the network's single-output predictions.
        y_train_normalized = np.array(y_train_normalized, ndmin=2).T

        # We construct the network
        model = Net(args=args, input_dim=input_dim, tau=tau, droprate=dropout, kl_weight=kl_weight).to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        # We iterate the learning process
        start_time = time.time()
        model.fit(X_train, y_train_normalized, optimizer=optimizer, num_samples=len(X_train),
                  n_epochs=n_epochs, batch_size=args.batchsize)

        self.model = model
        self.tau = tau
        self.running_time = time.time() - start_time

    def predict(self, X_test, y_test):
        """
            Evaluate the trained network on a test set.

            @param X_test   The matrix of features for the test data.
            @param y_test   The vector of test targets.
            @return rmse_standard_pred  RMSE of a single stochastic forward pass
                                        (0-d torch tensor).
            @return rmse                RMSE of the Monte-Carlo predictive mean
                                        (0-d torch tensor).
            @return test_ll             Mean per-point log-likelihood under the mixture of
                                        ``n_mc_samples`` Gaussians N(y; sample, 1 / tau).
        """
        X_test = np.array(X_test, ndmin=2)
        y_true = np.array(y_test, ndmin=2).T
        y_test = torch.Tensor(y_true)

        # We normalize the test set with the training statistics
        X_test = torch.Tensor((X_test - self.mean_X_train) / self.std_X_train)

        model = self.model
        T = self.n_mc_samples
        # The layers always sample noise, so the "standard" prediction is itself a
        # single stochastic pass; the T further passes form the MC estimate.
        with torch.no_grad():
            standard_pred = model.predict(X_test)
            Yt_hat = np.array([model.predict(X_test) for _ in range(T)], dtype=np.float64)

        standard_pred = standard_pred * self.std_y_train + self.mean_y_train
        rmse_standard_pred = torch.mean((y_test.squeeze() - torch.Tensor(standard_pred).squeeze()) ** 2.) ** 0.5

        Yt_hat = Yt_hat * self.std_y_train + self.mean_y_train
        MC_pred = np.mean(Yt_hat, 0)
        rmse = torch.mean((y_test.squeeze() - torch.Tensor(MC_pred).squeeze()) ** 2.) ** 0.5

        # We compute the test log-likelihood of each point under the T-component mixture,
        # stably via logsumexp and in float64.
        ll = (logsumexp(-0.5 * self.tau * (y_true[None] - Yt_hat) ** 2., 0) - np.log(T)
              - 0.5 * np.log(2 * np.pi) + 0.5 * np.log(self.tau))
        test_ll = np.mean(ll)

        return rmse_standard_pred, rmse, test_ll


np.random.seed(args.seed)
torch.manual_seed(args.seed)
epochs_multiplier = args.epochx

base_dir = 'UCI_Datasets/'
# Results and data live under the dataset chosen with --dataset (default 'yacht');
# the directory name used to be hard-coded, silently ignoring the flag.
data_result_path = base_dir + args.dataset + '/results/' + args.method
os.makedirs(data_result_path, exist_ok=True)
suffix = '_' + str(args.num_HH) + "_step"

# Common tail of every result-file name, encoding this run's hyper-parameters.
_RUN_TAG = (str(epochs_multiplier) + "_xepochs_" + str(args.klw) + "_klw_"
            + str(args.batchsize) + "_bs" + suffix)

args.save_valid_model = data_result_path + "/save_valid_model_" + _RUN_TAG + '.pth'
_RESULTS_VALIDATION_LL = data_result_path + "/validation_ll_" + _RUN_TAG
_RESULTS_VALIDATION_RMSE = data_result_path + "/validation_rmse_" + _RUN_TAG
_RESULTS_VALIDATION_MC_RMSE = data_result_path + "/validation_MC_rmse_" + _RUN_TAG

_RESULTS_TEST_LL = data_result_path + "/test_ll_" + _RUN_TAG
_RESULTS_TEST_RMSE = data_result_path + "/test_rmse_" + _RUN_TAG
_RESULTS_TEST_MC_RMSE = data_result_path + "/test_MC_rmse_" + _RUN_TAG
_RESULTS_TEST_TAU = data_result_path + "/test_tau_" + _RUN_TAG
_RESULTS_TEST_LOG = data_result_path + "/log_results_" + _RUN_TAG

_DATA_DIRECTORY_PATH = base_dir + args.dataset + "/data/"
_DROPOUT_RATES_FILE = _DATA_DIRECTORY_PATH + "dropout_rates.txt"
_TAU_VALUES_FILE = _DATA_DIRECTORY_PATH + "tau_values.txt"
_DATA_FILE = _DATA_DIRECTORY_PATH + "data.txt"
_HIDDEN_UNITS_FILE = _DATA_DIRECTORY_PATH + "n_hidden.txt"
_EPOCHS_FILE = _DATA_DIRECTORY_PATH + "n_epochs.txt"
_INDEX_FEATURES_FILE = _DATA_DIRECTORY_PATH + "index_features.txt"
_INDEX_TARGET_FILE = _DATA_DIRECTORY_PATH + "index_target.txt"
_N_SPLITS_FILE = _DATA_DIRECTORY_PATH + "n_splits.txt"


def _get_index_train_test_path(split_num, train=True):
    """
       Build the path of the index file listing the training (or test) rows of one split.

       @param split_num      Split number (the loop in this script passes 0 .. n_splits - 1).
       @param train          Truthy for the training-index file, falsy for the test-index file.
       @return path          Path of that index file inside the data directory.
    """
    which = "train" if train else "test"
    return _DATA_DIRECTORY_PATH + "index_" + which + "_" + str(split_num) + ".txt"

# We delete previous results. A missing file simply means there is nothing to delete;
# any other failure propagates, because a stale file that cannot be removed would be
# appended to and silently mix old and new results.
print("Removing existing result files...")
for _stale_path in (_RESULTS_VALIDATION_LL, _RESULTS_VALIDATION_RMSE, _RESULTS_VALIDATION_MC_RMSE,
                    _RESULTS_TEST_LL, _RESULTS_TEST_TAU, _RESULTS_TEST_RMSE,
                    _RESULTS_TEST_MC_RMSE, _RESULTS_TEST_LOG):
    try:
        os.remove(_stale_path)
    except FileNotFoundError:
        pass
print("Result files removed.")

print("Loading data and other hyperparameters...")
# We load the data
data = np.loadtxt(_DATA_FILE)

# We load the number of hidden units (not used below: Net has a fixed hidden width)
n_hidden = np.loadtxt(_HIDDEN_UNITS_FILE).tolist()

# We load the base number of training epochs (scaled by --epochx when training)
n_epochs = np.loadtxt(_EPOCHS_FILE).tolist()

# We load the indexes for the features and for the target
index_features = np.loadtxt(_INDEX_FEATURES_FILE)
index_target = np.loadtxt(_INDEX_TARGET_FILE)

# np.atleast_1d: a features file holding a single column index loads as a 0-d
# array whose .tolist() is a bare float that cannot be iterated.
X = data[:, [int(i) for i in np.atleast_1d(index_features).tolist()]]
y = data[:, int(index_target.tolist())]

# Number of training/test splits to iterate over
n_splits = np.loadtxt(_N_SPLITS_FILE)
print("Done.")

# List of hyperparameters which we will try out using grid-search. The grid is the
# same for every split, so it is read once; np.atleast_1d handles single-value files.
dropout_rates = np.atleast_1d(np.loadtxt(_DROPOUT_RATES_FILE)).tolist()
tau_values = np.atleast_1d(np.loadtxt(_TAU_VALUES_FILE)).tolist()

errors, MC_errors, lls = [], [], []
for split in range(int(n_splits)):
    # We load the indexes of the training and test sets
    train_index_path = _get_index_train_test_path(split, train=True)
    test_index_path = _get_index_train_test_path(split, train=False)
    print('Loading file: ' + train_index_path)
    print('Loading file: ' + test_index_path)
    index_train = np.loadtxt(train_index_path)
    index_test = np.loadtxt(test_index_path)

    train_rows = [int(i) for i in index_train.tolist()]
    test_rows = [int(i) for i in index_test.tolist()]
    X_train, y_train = X[train_rows], y[train_rows]
    X_test, y_test = X[test_rows], y[test_rows]

    # The last 20% of the training rows form the validation set used by the grid
    # search; the full training set is kept for the final refit.
    X_train_original = X_train
    y_train_original = y_train
    num_training_examples = int(0.8 * X_train.shape[0])
    X_validation = X_train[num_training_examples:, :]
    y_validation = y_train[num_training_examples:]
    X_train = X_train[0:num_training_examples, :]
    y_train = y_train[0:num_training_examples]

    # Printing the size of the training, validation and test sets
    print('Number of training examples: ' + str(X_train.shape[0]))
    print('Number of validation examples: ' + str(X_validation.shape[0]))
    print('Number of test examples: ' + str(X_test.shape[0]))
    print('Number of train_original examples: ' + str(X_train_original.shape[0]))

    # We perform grid-search to select the best hyperparameters based on the highest log-likelihood value
    best_network = None
    best_ll = -float('inf')
    # Default to the first grid point: if no validation log-likelihood beats -inf
    # (e.g. all are NaN), the refit below must still use a valid pair. The old
    # defaults tau=0 / dropout=0 meant zero noise variance, i.e. log(0) in the model.
    best_tau = tau_values[0]
    best_dropout = dropout_rates[0]

    for dropout_rate in dropout_rates:
        for tau in tau_values:
            print('Grid search step: ' + ' Tau: ' + str(tau) + ' Dropout rate: ' + str(dropout_rate))
            network = Method(X_train, y_train, X_validation, y_validation, args=args, input_dim=X_train.shape[1], normalize=True,
                             n_epochs=int(n_epochs * epochs_multiplier), tau=tau, dropout=dropout_rate, kl_weight=args.klw)

            # We obtain the test RMSE and the test ll from the validation sets
            error, MC_error, ll = network.predict(X_validation, y_validation)
            if ll > best_ll:
                best_ll = ll
                best_network = network
                best_tau = tau
                best_dropout = dropout_rate
                print('Best log_likelihood changed to: ' + str(best_ll))
                print('Best tau changed to: ' + str(best_tau))
                print('Best dropout rate changed to: ' + str(best_dropout))

            # Storing validation results
            with open(_RESULTS_VALIDATION_RMSE, "a") as myfile:
                myfile.write('Dropout_Rate: ' + repr(dropout_rate) + ' Tau: ' + repr(tau) + ' :: ')
                myfile.write(repr(error) + '\n')

            with open(_RESULTS_VALIDATION_MC_RMSE, "a") as myfile:
                myfile.write('Dropout_Rate: ' + repr(dropout_rate) + ' Tau: ' + repr(tau) + ' :: ')
                myfile.write(repr(MC_error) + '\n')

            with open(_RESULTS_VALIDATION_LL, "a") as myfile:
                myfile.write('Dropout_Rate: ' + repr(dropout_rate) + ' Tau: ' + repr(tau) + ' :: ')
                myfile.write(repr(ll) + '\n')

    # Refit on the full training set with the selected hyperparameters, then test.
    best_network = Method(X_train_original, y_train_original, X_validation, y_validation, args=args, input_dim=X_train.shape[1], normalize=True,
                          n_epochs=int(n_epochs * epochs_multiplier), tau=best_tau, dropout=best_dropout, kl_weight=args.klw)

    error, MC_error, ll = best_network.predict(X_test, y_test)
    print('MC_rmse, log_likelihood: ', MC_error, ll)

    # Storing test results
    with open(_RESULTS_TEST_RMSE, "a") as myfile:
        myfile.write('fold_ids ' + repr(split) + ', ' + repr(error) + ' , ' + 'Best_dropout_rate: ' + repr(best_dropout) + ' Best_tau: ' + repr(best_tau) + '\n')

    with open(_RESULTS_TEST_MC_RMSE, "a") as myfile:
        myfile.write('fold_ids ' + repr(split) + ', ' + repr(MC_error) + ' , ' + 'Best_dropout_rate: ' + repr(best_dropout) + ' Best_tau: ' + repr(best_tau) + '\n')

    with open(_RESULTS_TEST_LL, "a") as myfile:
        myfile.write('fold_ids ' + repr(split) + ', ' + repr(ll) + ' , ' + 'Best_dropout_rate: ' + repr(best_dropout) + ' Best_tau: ' + repr(best_tau) + '\n')

    with open(_RESULTS_TEST_TAU, "a") as myfile:
        myfile.write(repr(best_network.tau) + '\n')

    print("Tests on split " + str(split) + " complete.")
    errors.append(error.numpy())
    MC_errors.append(MC_error.numpy())
    lls.append(ll)

def _summary_line(label, values):
    """Format mean, stddev, std error, median and quartiles of *values* as one log line."""
    spread = np.std(values)
    return '%s %f +- %f (stddev) +- %f (std error), median %f 25p %f 75p %f \n' % (
        label,
        np.mean(values),
        spread,
        spread / math.sqrt(n_splits),
        np.percentile(values, 50),
        np.percentile(values, 25),
        np.percentile(values, 75),
    )


# Aggregate the per-split test metrics into the run's summary log.
with open(_RESULTS_TEST_LOG, "a") as myfile:
    for label, values in (('errors', errors), ('MC errors', MC_errors), ('lls', lls)):
        myfile.write(_summary_line(label, values))
