import warnings
warnings.simplefilter("ignore", UserWarning)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical
from torch.utils.data import TensorDataset, DataLoader

from torch.nn.functional import relu
import gpytorch

import itertools
import numpy as np
import math

# from utils.spectral import SpectralScoreEstimator

# Select the compute device once at import time: the first CUDA device when
# CUDA is available, otherwise the CPU. The choice is announced on stdout.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")

def real_imaginary_relu(z):
    """Apply ReLU independently to the real and imaginary parts of `z`.

    Args:
        z: tensor exposing `.real` and `.imag` views.
    Returns:
        Complex tensor `relu(z.real) + 1j * relu(z.imag)`.
    """
    rectified_real = relu(z.real)
    rectified_imag = relu(z.imag)
    return rectified_real + 1.j * rectified_imag

def weights_init(m):
    """Re-initialise an `nn.Linear` module in place; ignore any other module.

    The weight matrix is filled with Xavier-normal values and the bias with
    standard-normal values. Intended to be usable with `Module.apply`.

    Args:
        m: torch.nn.Module, the module to (possibly) re-initialise.
    """
    if not isinstance(m, nn.Linear):
        return
    # Weight first, then bias: the order fixes how the RNG stream is consumed.
    nn.init.xavier_normal_(m.weight.data)
    nn.init.normal_(m.bias.data)

def init_norm_layer(input_dim, norm_layer):
    """Build the normalisation module selected by `norm_layer`.

    Args:
        input_dim: int, number of features handed to `nn.BatchNorm1d`.
        norm_layer: `"batchnorm"` or `None`.
    Returns:
        `nn.Identity()` when `norm_layer` is None; a non-affine
        `nn.BatchNorm1d` without running statistics (eps=0, momentum=None)
        when it is `"batchnorm"`; `None` for any other value.
    """
    if norm_layer is None:
        return nn.Identity()
    if norm_layer == "batchnorm":
        return nn.BatchNorm1d(
            input_dim,
            eps=0,
            momentum=None,
            affine=False,
            track_running_stats=False,
        )
    # Unrecognised names fall through and yield None.
    return None


def sample_measurement_set(X, num_data):
    """Draw up to `num_data` rows of `X` uniformly at random without replacement.

    Args:
        X: torch.tensor, indexed along its first dimension.
        num_data: number of rows to draw; truncated to an int. Values larger
            than `X.size(0)` return every row (in shuffled order).
    Returns:
        torch.tensor holding the selected rows of `X`.
    """
    # An earlier variant drew the set size from a geometric distribution;
    # the size is now fixed at `num_data`.
    # Convert directly to int rather than round-tripping through a float32
    # tensor, which is lossy for large counts and allocates for no reason.
    count = int(num_data)

    shuffled = torch.randperm(X.size(0))
    return X[shuffled[:count]]



class Gradientnetwork(nn.Module):
    """Fully connected network: two ReLU hidden layers and a linear output."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Create the layers.

        Args:
            input_dim: int, size of the input features.
            hidden_dim: int, width of both hidden layers.
            output_dim: int, size of the output.
        """
        super().__init__()
        # Attribute names and creation order determine the state_dict keys
        # and the order in which parameters are randomly initialised.
        self.lin1 = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.lin2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.lin3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Cast `x` to float32 and run it through the three layers."""
        hidden = self.relu1(self.lin1(x.float()))
        hidden = self.relu2(self.lin2(hidden))
        return self.lin3(hidden)

class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP model with a linear mean and a scaled RBF kernel."""

    def __init__(self, train_x, train_y, likelihood, input_dim):
        """Set up the mean and covariance modules.

        Args:
            train_x: training inputs, forwarded to `gpytorch.models.ExactGP`.
            train_y: training targets, forwarded to `gpytorch.models.ExactGP`.
            likelihood: likelihood object, forwarded to
                `gpytorch.models.ExactGP`.
            input_dim: int, input size handed to `gpytorch.means.LinearMean`.
        """
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.LinearMean(input_dim)
        # Kernels tried previously in place of the plain RBF:
        # Periodic + RBF, RBF * Periodic, and Matern(nu=0.5), each scaled.
        base_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel)

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel at `x`."""
        mean = self.mean_module(x)
        covariance = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covariance)


class BatchIndependentMultitaskGPModel(gpytorch.models.ExactGP):
    """Exact GP whose mean and kernel carry a batch dimension of size 2.

    The batched normal produced in `forward` is converted into a
    `MultitaskMultivariateNormal` via `from_batch_mvn`.
    """

    def __init__(self, train_x, train_y, likelihood):
        """Set up batched mean and covariance modules.

        Args:
            train_x: training inputs, forwarded to `gpytorch.models.ExactGP`.
            train_y: training targets, forwarded to `gpytorch.models.ExactGP`.
            likelihood: likelihood object, forwarded to
                `gpytorch.models.ExactGP`.
        """
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([2]))
        rbf_kernel = gpytorch.kernels.RBFKernel(batch_shape=torch.Size([2]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            rbf_kernel, batch_shape=torch.Size([2]))

    def forward(self, x):
        """Return the multitask normal built from the batched mean/kernel at `x`."""
        mean = self.mean_module(x)
        covariance = self.covar_module(x)
        batched = gpytorch.distributions.MultivariateNormal(mean, covariance)
        return gpytorch.distributions.MultitaskMultivariateNormal.from_batch_mvn(batched)


class BNNMLP(nn.Module):
    def __init__(self, n_in, n_out, W_mu=None, b_mu=None, W_std=None,
                 b_std=None, scaled_variance=False, prior_per='parameter'):
        """Initialization.
        Args:
            n_in: int, the size of the input data.
            n_out: int, the size of the output.
            W_std: float, the initial value of
                the standard deviation of the weights.
            b_std: float, the initial value of
                the standard deviation of the biases.
            prior_per: str, indicates whether using different prior for
                each parameter, option `parameter`, or use the share the
                prior for all parameters in the same layer, option `layer`.
        """
        super(BNNMLP, self).__init__()

        self.n_in = n_in
        self.n_out = n_out
        self.scaled_variance = scaled_variance

        if W_mu is None:
            if self.scaled_variance:
                W_mu = 1.
            else:
                W_mu = 1. / math.sqrt(self.n_in)
        if b_mu is None:
            b_mu = 1.

        if W_std is None:
            if self.scaled_variance:
                W_std = 1.
            else:
                W_std = 1. / math.sqrt(self.n_in)
        if b_std is None:
            b_std = 1.

        if prior_per == "layer":
            W_shape, b_shape = (1), (1)
        elif prior_per == "parameter":
            W_shape, b_shape = [self.n_in, self.n_out], [self.n_out]
        else:
            raise ValueError("Accepted values: `parameter` or `layer`")

        #
        bound = 1. / math.sqrt(self.n_in)
        m = torch.distributions.uniform.Uniform(torch.tensor([-bound]), torch.tensor([bound]))

        # W_mu_tmp = m.sample(W_shape).squeeze(-1)
        # b_mu_tmp = m.sample(b_shape).squeeze(-1)
        # W_mu_tmp = torch.zeros(W_shape)
        # b_mu_tmp = torch.zeros(b_shape)
        W_mu_tmp = torch.randn(W_shape)
        b_mu_tmp = torch.randn(b_shape)

        self.W_mu = nn.Parameter(
            W_mu_tmp, requires_grad=True)

        # self.W_std = nn.Parameter(
        #     torch.rand(W_shape) , requires_grad=True)
        self.W_std = nn.Parameter(
            torch.ones(W_shape) * 2.1, requires_grad=True)

        self.b_mu = nn.Parameter(
            b_mu_tmp, requires_grad=True)

        # self.b_std = nn.Parameter(
        #     torch.rand(b_shape), requires_grad=True)
        self.b_std = nn.Parameter(
            torch.ones(b_shape) * 2.1, requires_grad=True)

        # save prior
        self.W_mu_prior = torch.zeros(W_shape)
        self.b_mu_prior = torch.zeros(b_shape)
        self.W_std_prior = torch.zeros(W_shape) + 1.0
        self.b_std_prior = torch.zeros(b_shape) + 1.0

    def forward(self, X):
        """Performs forward pass given input data.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            output: torch.tensor, [batch_size, output_dim], the output data.
        """
        W = self.W_mu + F.softplus(self.W_std) * \
            torch.randn((self.n_in, self.n_out), device=self.W_std.device)
        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)
        b = self.b_mu + F.softplus(self.b_std) * \
            torch.randn((self.n_out), device=self.b_std.device)

        output = torch.mm(X, W) + b

        return output

    def forward_with_wb(self, X, W, b):
        """Performs forward pass given input data.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            output: torch.tensor, [batch_size, output_dim], the output data.
        """
        # W = self.W_mu + F.softplus(self.W_std) * \
        #     torch.randn((self.n_in, self.n_out), device=self.W_std.device)
        # if self.scaled_variance:
        #     W = W / math.sqrt(self.n_in)
        # b = self.b_mu + F.softplus(self.b_std) * \
        #     torch.randn((self.n_out), device=self.b_std.device)

        output = torch.mm(X, W) + b

        return output

    def forward_mean(self, X):
        """Performs forward pass given input data.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            output: torch.tensor, [batch_size, output_dim], the output data.
        """
        W = self.W_mu
        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)
        b = self.b_mu

        output = torch.mm(X, W) + b

        return output

    def forward_mean_wb(self, X):
        """Performs forward pass given input data.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            output: torch.tensor, [batch_size, output_dim], the output data.
        """
        W = self.W_mu
        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)
        b = self.b_mu

        output = torch.mm(X, W) + b

        return output, W, b

    def forward_wb(self, X):
        """Performs forward pass given input data.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            output: torch.tensor, [batch_size, output_dim], the output data.
        """
        W = self.W_mu + F.softplus(self.W_std) * \
            torch.randn((self.n_in, self.n_out), device=self.W_std.device)
        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)
        b = self.b_mu + F.softplus(self.b_std) * \
            torch.randn((self.n_out), device=self.b_std.device)

        output = torch.mm(X, W) + b

        return output, W, b

    def forward_epsilon(self, X):
        """Performs forward pass given input data.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            output: torch.tensor, [batch_size, output_dim], the output data.
        """
        epsilon_w = torch.randn((self.n_in, self.n_out), device=self.W_std.device)
        W = self.W_mu + F.softplus(self.W_std) * epsilon_w

        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)

        epsilon_b = torch.randn((self.n_out), device=self.b_std.device)
        b = self.b_mu + F.softplus(self.b_std) * epsilon_b

        output = torch.mm(X, W) + b
        W = W.abs()
        sumw = torch.sum(W)
        # W = W ** 2
        # sumw = torch.sum(W)
        return output, epsilon_w, epsilon_b, sumw

    def forward_eval(self, X, num_sample):
        """Performs forward pass given input data.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            output: torch.tensor, [batch_size, output_dim], the output data.
        """
        Ws = self.W_mu + F.softplus(self.W_std) * \
             torch.randn([num_sample, self.n_in, self.n_out],
                         device=self.W_std.device)
        if self.scaled_variance:
            Ws = Ws / math.sqrt(self.n_in)
        bs = self.b_mu + F.softplus(self.b_std) * \
             torch.randn([num_sample, 1, self.n_out],
                         device=self.b_std.device)

        output = torch.matmul(X, Ws) + bs

        return output

    def sample_w_tilde(self):
        W = self.W_mu + F.softplus(self.W_std) * \
            torch.randn((self.n_in, self.n_out), device=self.W_std.device)
        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)
        b = self.b_mu + F.softplus(self.b_std) * \
            torch.randn((self.n_out), device=self.b_std.device)

        # return torch.cat([torch.flatten(W), torch.flatten(b)])
        return torch.cat([W.view(-1), b.view(-1)])

    def w_size(self):
        w_size = torch.cat([self.W_mu.view(-1), self.b_mu.view(-1)]).shape[0]

        return w_size

class ifBNN(nn.Module):
    """Bayesian MLP with two hidden layers, built from three `BNNMLP` layers.

    The network is `input_layer -> activation -> mid_layer -> activation ->
    output_layer`. The different `forward_*` methods differ only in whether
    each layer uses its mean weights or sampled weights, and in what extra
    information (weights, noise) is returned alongside the output.
    """

    def __init__(self, input_dim, output_dim, hidden_dims, activation_fn, has_continuous_action_space,
                 W_mu=None, b_mu=None, W_std=None, b_std=None, scaled_variance=False, norm_layer=None):
        """Initialization.
        Args:
            input_dim: int, the size of the input data.
            output_dim: int, the size of the output data.
            hidden_dims: list of int, the list containing the size of
                hidden layers; entries 0 and 1 are used.
            activation_fn: one of 'cos', 'tanh', 'relu', 'softplus',
                'leaky_relu'; any other value is stored unchanged and
                called as the activation function.
            has_continuous_action_space: stored on the instance; not
                otherwise used by this class.
            W_mu, b_mu, W_std, b_std: forwarded to every `BNNMLP` layer
                (`b_std` falls back to `W_std` when it is None).
            scaled_variance: bool, forwarded to every `BNNMLP` layer.
            norm_layer: stored as `self.norm_layer`; the two normalisation
                slots are currently always `nn.Identity`.
        """
        super(ifBNN, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.norm_layer = norm_layer

        self.eta = 0.1
        self.n_eigen_threshold = 2

        # Setup activation function: known names map to torch functions,
        # anything else is kept as passed in.
        options = {'cos': torch.cos, 'tanh': torch.tanh, 'relu': F.relu,
                   'softplus': F.softplus, 'leaky_relu': F.leaky_relu}
        self.activation_fn = options.get(activation_fn, activation_fn)

        if b_std is None:
            b_std = W_std

        # Initialize layers. Creation order fixes parameter/state_dict order
        # and the order in which random initial values are drawn.
        self.input_layer = BNNMLP(
            input_dim, hidden_dims[0], W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)
        # Normalisation is disabled for now (was: init_norm_layer(...)).
        self.norm_layer1 = nn.Identity()

        self.mid_layer = BNNMLP(
            hidden_dims[0], hidden_dims[1], W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)
        self.norm_layer2 = nn.Identity()

        self.output_layer = BNNMLP(
            hidden_dims[1], output_dim, W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)

        # TODO: define gradient network, input: X, output: gradient of nn
        # weights, together with its optimizer.

    def _layers(self):
        """Return the three Bayesian layers in forward order."""
        return (self.input_layer, self.mid_layer, self.output_layer)

    def _forward_collect_wb(self, X, sample):
        """Run the network and collect the weights/biases each layer used.

        Args:
            X: torch.tensor, [batch_size, input_dim], already reshaped.
            sample: bool, True to use each layer's `forward_wb` (sampled
                weights), False to use `forward_mean_wb` (mean weights).
        Returns:
            tuple `(output, Wb)` where `Wb` is the 1-D concatenation
            `(W1, b1, W2, b2, W3, b3)` with every `W` flattened.
        """
        layers = self._layers()
        last = len(layers) - 1
        pieces = []
        for index, layer in enumerate(layers):
            step = layer.forward_wb if sample else layer.forward_mean_wb
            X, W, b = step(X)
            pieces.extend((W.view(-1), b))
            # No activation after the output layer.
            if index != last:
                X = self.activation_fn(X)
        return X, torch.cat(pieces)

    def forward_normal(self, X):
        """Performs forward pass with weights sampled in every layer.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            torch.tensor, [batch_size, output_dim], the output data.
        """
        X = X.view(-1, self.input_dim)

        X = self.activation_fn(self.input_layer(X))
        X = self.activation_fn(self.mid_layer(X))
        return self.output_layer(X)

    def forward(self, X):
        """Performs forward pass using the mean weights of every layer.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            torch.tensor, [batch_size, output_dim], the output data.
        """
        X = X.view(-1, self.input_dim)

        X = self.activation_fn(self.norm_layer1(self.input_layer.forward_mean(X)))
        X = self.activation_fn(self.norm_layer2(self.mid_layer.forward_mean(X)))
        return self.output_layer.forward_mean(X)

    def forward_mean_wb(self, X):
        """Mean-weight forward pass that also returns the weights used.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            tuple `(output, Wb)`: the [batch_size, output_dim] output and
            the flat vector `(W1, b1, W2, b2, W3, b3)` of mean weights.
        """
        X = X.view(-1, self.input_dim)
        return self._forward_collect_wb(X, sample=False)

    def forward_diff_wb(self, X):
        """Difference between a sampled pass and the mean pass.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            tuple `(output_diff, Wb_diff)`: sampled output minus mean
            output, and sampled flat weights minus mean flat weights.
        """
        X_raw = X.view(-1, self.input_dim)

        # Mean pass first (draws no random numbers), then the sampled pass.
        X_mean, Wb_mean = self._forward_collect_wb(X_raw, sample=False)
        X_sampled, Wb = self._forward_collect_wb(X_raw, sample=True)

        return X_sampled - X_mean, Wb - Wb_mean

    def get_variance(self):
        """Return the per-parameter variance as one flat vector.

        The raw scale parameters are concatenated in the order
        `(W1, b1, W2, b2, W3, b3)`, passed through softplus and squared.
        """
        raw_scale = torch.cat([
            param.view(-1)
            for layer in self._layers()
            for param in (layer.W_std, layer.b_std)
        ])
        return F.softplus(raw_scale) ** 2

    def forward_wb(self, X):
        """Sampled forward pass that also returns the weights used.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            tuple `(output, Wb)`: the [batch_size, output_dim] output and
            the flat vector `(W1, b1, W2, b2, W3, b3)` of sampled weights.
        """
        X = X.view(-1, self.input_dim)
        return self._forward_collect_wb(X, sample=True)

    def forward_with_wb(self, X, wb):
        """Performs forward pass using an explicit flat weight vector.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
            wb: torch.tensor, 1-D, laid out as `(W1, b1, W2, b2, W3, b3)`
                (the layout returned by `forward_wb` / `forward_mean_wb`).
        Returns:
            torch.tensor, [batch_size, output_dim], the output data.
        """
        X = X.view(-1, self.input_dim)

        layers = self._layers()
        last = len(layers) - 1
        offset = 0
        for index, layer in enumerate(layers):
            # Carve this layer's weight matrix, then its bias, out of `wb`.
            n_w = layer.W_mu.numel()
            W = wb[offset:offset + n_w].reshape_as(layer.W_mu)
            offset += n_w

            n_b = layer.b_mu.numel()
            b = wb[offset:offset + n_b].reshape_as(layer.b_mu)
            offset += n_b

            X = layer.forward_with_wb(X, W, b)
            if index != last:
                X = self.activation_fn(X)

        return X

    def forward_epsilon(self, X):
        """Sampled forward pass that also returns the noise that was drawn.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            tuple `(output, epsilons)` where `epsilons` is the list
            `[epsilon_w1, epsilon_b1, epsilon_w2, epsilon_b2, epsilon_w3,
            epsilon_b3]`.
        """
        X = X.view(-1, self.input_dim)

        # Each layer also returns the sum of |W|; it is not used here.
        X, epsilon_w1, epsilon_b1, _ = self.input_layer.forward_epsilon(X)
        X = self.activation_fn(self.norm_layer1(X))

        X, epsilon_w2, epsilon_b2, _ = self.mid_layer.forward_epsilon(X)
        X = self.activation_fn(self.norm_layer2(X))

        X, epsilon_w3, epsilon_b3, _ = self.output_layer.forward_epsilon(X)

        return X, [epsilon_w1, epsilon_b1, epsilon_w2, epsilon_b2, epsilon_w3, epsilon_b3]

    def sample_functions(self, X, num_sample):
        """Performs predictions using `num_sample` sets of weights.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
            num_sample: int, the number of weight samples used to make
                predictions.
        Returns:
            torch.tensor, [num_sample, batch_size, output_dim], the output
            data.
        """
        X = X.view(-1, self.input_dim)
        X = X.repeat(num_sample, 1, 1)

        X = self.activation_fn(self.norm_layer1(self.input_layer.forward_eval(X, num_sample)))
        X = self.activation_fn(self.norm_layer2(self.mid_layer.forward_eval(X, num_sample)))
        return self.output_layer.forward_eval(X, num_sample)

    def distance_prior(self, X):
        """Average of each layer's `wd()` value over the three layers.

        NOTE(review): relies on the layers providing a `wd()` method, which
        the `BNNMLP` class visible in this file does not define -- confirm
        where it comes from.
        Args:
            X: torch.tensor; only reshaped (which checks its size), it does
                not otherwise influence the result.
        """
        X.view(-1, self.input_dim)
        layers = self._layers()
        return sum(layer.wd() for layer in layers) / len(layers)

    def distance_prior_wd(self, X):
        """Average of each layer's `wd()` value over the three layers.

        NOTE(review): relies on the layers providing a `wd()` method, which
        the `BNNMLP` class visible in this file does not define -- confirm
        where it comes from.
        Args:
            X: torch.tensor; only reshaped (which checks its size), it does
                not otherwise influence the result.
        """
        X.view(-1, self.input_dim)
        layers = self._layers()
        return sum(layer.wd() for layer in layers) / len(layers)

    def distance_prior_kl(self, X):
        """Average of each layer's `kl()` value over the three layers.

        NOTE(review): relies on the layers providing a `kl()` method, which
        the `BNNMLP` class visible in this file does not define -- confirm
        where it comes from.
        Args:
            X: torch.tensor; only reshaped (which checks its size), it does
                not otherwise influence the result.
        """
        X.view(-1, self.input_dim)
        layers = self._layers()
        return sum(layer.kl() for layer in layers) / len(layers)

    def sample_w_star(self):
        """Sample the first two layers' weights; zero-fill the last layer's.

        Returns:
            torch.tensor, 1-D: sampled flat weights of the input and mid
            layers followed by zeros the size of the output layer's flat
            weights.
        """
        # The output layer is still sampled (then zeroed) so that shape,
        # dtype, device and RNG consumption all match a full sample.
        w_star = torch.cat([self.input_layer.sample_w_tilde(),
                            self.mid_layer.sample_w_tilde(),
                            torch.zeros_like(self.output_layer.sample_w_tilde())], 0)
        return w_star

    def last_layer_wb_size(self):
        """Return the number of weight and bias entries in the output layer."""
        return self.output_layer.w_size()

    def forward_eval(self, X, num_sample):
        """Performs forward pass given input data using a number of parameter samples.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
            num_sample: int, the number of weight samples to draw.
        Returns:
            torch.tensor, [num_sample, batch_size, output_dim], the output
            data.
        """
        X = X.view(-1, self.input_dim)
        X = X.repeat(num_sample, 1, 1)

        # Same as `sample_functions` but without the normalisation slots.
        X = self.activation_fn(self.input_layer.forward_eval(X, num_sample))
        X = self.activation_fn(self.mid_layer.forward_eval(X, num_sample))
        return self.output_layer.forward_eval(X, num_sample)

    def train_gradnet(self, prior_mean, prior_cov, X_batch, y_batch, measurement_set):
        """Placeholder for training the gradient network; returns -1."""
        # TODO: train gradient network using self.optimizor
        #   - get delta function
        #   - evaluate posterior mean and cov
        return -1

    def eval_gradnet(self):
        """Placeholder for evaluating the gradient network; returns -1."""
        # TODO: output gradient given X
        return -1

    def save(self, checkpoint_path):
        """Write this module's state_dict to `checkpoint_path`."""
        torch.save(self.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load a state_dict from `checkpoint_path` into this module."""
        # The identity map_location leaves storages as deserialised instead
        # of moving them to the device they were saved from.
        state = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.load_state_dict(state)

