import warnings
warnings.simplefilter("ignore", UserWarning)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical
from torch.utils.data import TensorDataset, DataLoader
from collections import OrderedDict

import itertools
import numpy as np
import math


def weights_init(m):
    """Initialise an ``nn.Linear`` module in place; other modules are left untouched.

    Weights get Xavier-normal initialisation and the bias, when the layer has
    one, is drawn from a standard normal. The single-module signature makes it
    usable with ``Module.apply``.

    Args:
        m: nn.Module, the module to (possibly) initialise.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)
        # nn.Linear(..., bias=False) stores None here; skip it instead of crashing.
        if m.bias is not None:
            torch.nn.init.normal_(m.bias)

def init_norm_layer(input_dim, norm_layer):
    """Build the normalisation module applied after a hidden layer.

    Args:
        input_dim: int, number of features to normalise.
        norm_layer: "batchnorm" for a non-affine ``BatchNorm1d`` that always
            normalises with the current batch statistics (no running stats),
            or None for an identity mapping.
    Returns:
        nn.Module, the normalisation module.
    Raises:
        ValueError: if `norm_layer` is any other value.
    """
    if norm_layer is None:
        return nn.Identity()
    if norm_layer == "batchnorm":
        # eps=0: a feature with zero variance in a batch divides by zero.
        return nn.BatchNorm1d(input_dim, eps=0, momentum=None,
                              affine=False, track_running_stats=False)
    raise ValueError(
        "Unknown norm_layer %r; accepted values: 'batchnorm' or None" % (norm_layer,))


class BNNMLP(nn.Module):
    """Linear layer with a mean-field Gaussian variational posterior.

    Every weight and bias has a learnable mean (`W_mu`, `b_mu`) and a learnable
    raw scale (`W_std`, `b_std`); the standard deviation actually used is
    `softplus(raw scale)`. A fixed Gaussian prior in the same parametrisation
    (mean 0.001, raw scale 1.0) is stored for `kld` and `wd`.
    """

    def __init__(self, n_in, n_out, W_mu=None, b_mu=None, W_std=None,
                 b_std=None, scaled_variance=False, prior_per='parameter'):
        """Initialization.
        Args:
            n_in: int, the size of the input data.
            n_out: int, the size of the output.
            W_mu, b_mu, W_std, b_std: accepted for interface compatibility but
                NOT used: means are initialised from
                U(-1/sqrt(n_in), 1/sqrt(n_in)) and raw scales from N(1, 1).
            scaled_variance: bool, if True every weight matrix used in a
                forward pass is divided by sqrt(n_in).
            prior_per: str, `parameter` gives every weight/bias its own
                mean/scale; `layer` shares one mean/scale for all weights and
                one for all biases of the layer.
        Raises:
            ValueError: if `prior_per` is neither `parameter` nor `layer`.
        """
        super(BNNMLP, self).__init__()

        self.n_in = n_in
        self.n_out = n_out
        self.scaled_variance = scaled_variance

        if prior_per == "layer":
            W_shape, b_shape = [1], [1]
        elif prior_per == "parameter":
            W_shape, b_shape = [self.n_in, self.n_out], [self.n_out]
        else:
            raise ValueError("Accepted values: `parameter` or `layer`")

        bound = 1. / math.sqrt(self.n_in)
        uniform = torch.distributions.uniform.Uniform(
            torch.tensor([-bound]), torch.tensor([bound]))
        self.W_mu = nn.Parameter(uniform.sample(W_shape).squeeze(-1), requires_grad=True)
        self.b_mu = nn.Parameter(uniform.sample(b_shape).squeeze(-1), requires_grad=True)
        self.W_std = nn.Parameter(torch.randn(W_shape) + 1, requires_grad=True)
        self.b_std = nn.Parameter(torch.randn(b_shape) + 1, requires_grad=True)

        # Prior in the same (mean, raw scale) parametrisation. These are plain
        # tensors, not buffers, so `Module.to()` does not move them;
        # `_priors_to_device` does that lazily.
        self.W_mu_prior = torch.zeros(W_shape) + 0.001
        self.b_mu_prior = torch.zeros(b_shape) + 0.001
        self.W_std_prior = torch.zeros(W_shape) + 1.0
        self.b_std_prior = torch.zeros(b_shape) + 1.0

    def _scale(self, W):
        """Apply the optional 1/sqrt(n_in) weight scaling."""
        if self.scaled_variance:
            return W / math.sqrt(self.n_in)
        return W

    def forward(self, X):
        """Forward pass with one freshly sampled weight matrix and bias.
        Args:
            X: torch.tensor, [batch_size, n_in], the input data.
        Returns:
            output: torch.tensor, [batch_size, n_out], the output data.
        """
        W = self.W_mu + F.softplus(self.W_std) * \
            torch.randn((self.n_in, self.n_out), device=self.W_std.device)
        W = self._scale(W)
        b = self.b_mu + F.softplus(self.b_std) * \
            torch.randn((self.n_out), device=self.b_std.device)
        return torch.mm(X, W) + b

    def forward_mu(self, X):
        """Deterministic forward pass using the posterior means only.
        Args:
            X: torch.tensor, [batch_size, n_in].
        Returns:
            torch.tensor, [batch_size, n_out].
        """
        # expand() turns a shared (`layer`) weight of shape [1] into the
        # [n_in, n_out] matrix torch.mm requires; it is a no-op view otherwise.
        W = self._scale(self.W_mu.expand(self.n_in, self.n_out))
        return torch.mm(X, W) + self.b_mu

    def forward_mu_delta(self, X):
        """Forward pass using the means plus fixed N(0, 1e-6) jitter.
        Args:
            X: torch.tensor, [batch_size, n_in].
        Returns:
            torch.tensor, [batch_size, n_out].
        """
        W = self.W_mu + 1e-3 * torch.randn((self.n_in, self.n_out), device=self.W_mu.device)
        W = self._scale(W)
        b = self.b_mu + 1e-3 * torch.randn((self.n_out), device=self.W_mu.device)
        return torch.mm(X, W) + b

    def forward_eval(self, X, num_sample):
        """Forward pass with `num_sample` independent weight/bias draws.
        Args:
            X: torch.tensor, [num_sample, batch_size, n_in] (one slice per
                draw, as consumed by torch.bmm).
            num_sample: int, number of weight/bias draws.
        Returns:
            output: torch.tensor, [num_sample, batch_size, n_out].
        """
        W = self.W_mu + F.softplus(self.W_std) * \
            torch.randn((num_sample, self.n_in, self.n_out), device=self.W_std.device)
        W = self._scale(W)
        # One bias draw per sample (broadcast over the batch), so each sampled
        # function gets its own bias noise, matching `sample_w_d`.
        b = self.b_mu + F.softplus(self.b_std) * \
            torch.randn((num_sample, 1, self.n_out), device=self.b_std.device)
        return torch.bmm(X, W) + b

    def sample_w_d(self, num_sample):
        """Draw `num_sample` flattened weight vectors and bias vectors.
        Returns:
            (W, b): torch.tensor [num_sample, n_in * n_out] and
            torch.tensor [num_sample, n_out]. No 1/sqrt(n_in) scaling is applied.
        """
        W = self.W_mu.reshape(1, -1) + F.softplus(self.W_std.reshape(1, -1)) * \
            torch.randn((num_sample, self.n_in * self.n_out), device=self.W_std.device)
        b = self.b_mu + F.softplus(self.b_std) * \
            torch.randn((num_sample, self.n_out), device=self.b_std.device)
        return W, b

    def _priors_to_device(self):
        """Move the prior tensors onto the device of the matching parameters."""
        self.W_std_prior = self.W_std_prior.to(self.W_std.device)
        self.W_mu_prior = self.W_mu_prior.to(self.W_mu.device)
        self.b_std_prior = self.b_std_prior.to(self.b_std.device)
        self.b_mu_prior = self.b_mu_prior.to(self.b_mu.device)

    @staticmethod
    def _gaussian_kl(mu_q, raw_std_q, mu_p, raw_std_p):
        """Mean over entries of KL(N(mu_q, s_q^2) || N(mu_p, s_p^2)), s = softplus(raw)."""
        std_q = F.softplus(raw_std_q)
        std_p = F.softplus(raw_std_p)
        return torch.mean(torch.log(std_p) - torch.log(std_q)
                          + (std_q ** 2 + (mu_q - mu_p) ** 2) / (2 * std_p ** 2)
                          - 0.5)

    @staticmethod
    def _smoothed_w2(mu_q, raw_std_q, mu_p, raw_std_p):
        """Mean over entries of sqrt(d_mu^2 + d_std^2), the 1-D Gaussian
        2-Wasserstein distance, with a constant 0.001 added to each difference."""
        d_mu = mu_q - mu_p + 0.001
        d_std = F.softplus(raw_std_q) - F.softplus(raw_std_p) + 0.001
        return torch.mean(torch.sqrt(d_mu ** 2 + d_std ** 2))

    def kld(self):
        """KL divergence between posterior and the stored prior.
        Returns:
            0-dim tensor, mean weight KL + mean bias KL (0 when they coincide).
        """
        self._priors_to_device()
        kl_w = self._gaussian_kl(self.W_mu, self.W_std, self.W_mu_prior, self.W_std_prior)
        kl_b = self._gaussian_kl(self.b_mu, self.b_std, self.b_mu_prior, self.b_std_prior)
        return kl_w + kl_b

    def wd(self):
        """Smoothed Wasserstein distance between posterior and the stored prior.
        Returns:
            0-dim tensor, mean weight distance + mean bias distance.
        """
        self._priors_to_device()
        return self.para_wd(self.W_mu_prior, self.b_mu_prior,
                            self.W_std_prior, self.b_std_prior)

    def para_wd(self, W_mu_prior, b_mu_prior, W_std_prior, b_std_prior):
        """Smoothed Wasserstein distance to an explicitly given prior.
        Args:
            W_mu_prior, b_mu_prior: tensors of means broadcastable to W_mu / b_mu.
            W_std_prior, b_std_prior: raw (pre-softplus) scales, same convention.
                The caller must place them on the parameters' device.
        Returns:
            0-dim tensor, mean weight distance + mean bias distance.
        """
        wdist_w = self._smoothed_w2(self.W_mu, self.W_std, W_mu_prior, W_std_prior)
        wdist_b = self._smoothed_w2(self.b_mu, self.b_std, b_mu_prior, b_std_prior)
        return wdist_w + wdist_b


class BNN(nn.Module):
    """Bayesian MLP: input -> hidden_dims[0] -> hidden_dims[1] -> output.

    Each of the three layers is a `BNNMLP`. A normalisation module and
    `activation_fn` follow each hidden layer; the output layer is linear.
    """

    def __init__(self, input_dim, output_dim, hidden_dims, activation_fn, has_continuous_action_space,
                 W_mu=None, b_mu=None, W_std=None, b_std=None, scaled_variance=False, norm_layer=None):
        """Initialization.
        Args:
            input_dim: int, the size of the input data.
            output_dim: int, the size of the output data.
            hidden_dims: sequence of int; the first two entries are the sizes
                of the two hidden layers.
            activation_fn: str, one of 'cos', 'tanh', 'relu', 'softplus',
                'leaky_relu', or a callable applied after each hidden layer.
            has_continuous_action_space: stored as an attribute; not used
                inside this class.
            W_mu, b_mu, W_std, b_std: forwarded to every BNNMLP layer
                (`b_std` falls back to `W_std` when None).
            scaled_variance: bool, forwarded to every BNNMLP layer.
            norm_layer: None or "batchnorm", passed to `init_norm_layer`.
        Raises:
            ValueError: if `activation_fn` is a string that is not a known name.
        """
        super(BNN, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.norm_layer = norm_layer

        options = {'cos': torch.cos, 'tanh': torch.tanh, 'relu': F.relu,
                   'softplus': F.softplus, 'leaky_relu': F.leaky_relu}
        if isinstance(activation_fn, str):
            # Fail at construction instead of with "'str' object is not callable"
            # on the first forward pass.
            if activation_fn not in options:
                raise ValueError("Unknown activation_fn %r; expected one of %s or a callable"
                                 % (activation_fn, sorted(options)))
            self.activation_fn = options[activation_fn]
        else:
            self.activation_fn = activation_fn

        if b_std is None:
            b_std = W_std

        self.input_layer = BNNMLP(
            input_dim, hidden_dims[0], W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)
        self.norm_layer1 = init_norm_layer(hidden_dims[0], self.norm_layer)
        self.mid_layer = BNNMLP(
            hidden_dims[0], hidden_dims[1], W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)
        self.norm_layer2 = init_norm_layer(hidden_dims[1], self.norm_layer)
        self.output_layer = BNNMLP(
            hidden_dims[1], output_dim, W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)

    def _bnn_layers(self):
        """Return the three BNNMLP layers in forward order."""
        return (self.input_layer, self.mid_layer, self.output_layer)

    def _propagate(self, X, f_in, f_mid, f_out):
        """Run X through f_in/f_mid/f_out with norm + activation after the hidden ones."""
        X = self.activation_fn(self.norm_layer1(f_in(X)))
        X = self.activation_fn(self.norm_layer2(f_mid(X)))
        return f_out(X)

    def _mean_layer_term(self, method_name):
        """Average `layer.<method_name>()` (e.g. `kld`, `wd`) over the three layers."""
        layers = self._bnn_layers()
        return sum(getattr(layer, method_name)() for layer in layers) / len(layers)

    def forward(self, X):
        """Forward pass with one freshly sampled set of weights.
        Args:
            X: torch.tensor, reshaped to [-1, input_dim].
        Returns:
            torch.tensor, [batch_size, output_dim].
        """
        X = X.view(-1, self.input_dim)
        return self._propagate(X, self.input_layer, self.mid_layer, self.output_layer)

    def forward_mu(self, X):
        """Deterministic forward pass using the posterior means of every layer.
        Args:
            X: torch.tensor, reshaped to [-1, input_dim].
        Returns:
            torch.tensor, [batch_size, output_dim].
        """
        X = X.view(-1, self.input_dim)
        return self._propagate(X, self.input_layer.forward_mu,
                               self.mid_layer.forward_mu,
                               self.output_layer.forward_mu)

    def forward_mu_delta(self, X):
        """Forward pass using the posterior means plus small fixed jitter.
        Args:
            X: torch.tensor, reshaped to [-1, input_dim].
        Returns:
            torch.tensor, [batch_size, output_dim].
        """
        X = X.view(-1, self.input_dim)
        return self._propagate(X, self.input_layer.forward_mu_delta,
                               self.mid_layer.forward_mu_delta,
                               self.output_layer.forward_mu_delta)

    def sample_functions(self, X, num_sample):
        """Predictions from `num_sample` independent weight draws (same as `forward_eval`).
        Args:
            X: torch.tensor, reshaped to [-1, input_dim].
            num_sample: int, the number of weight samples.
        Returns:
            torch.tensor, [num_sample, batch_size, output_dim].
        """
        return self.forward_eval(X, num_sample)

    def sample_w_b(self, num_sample):
        """Sample flattened weights and biases of all layers.
        Returns:
            (W, b): W is [num_sample, sum of n_in*n_out over layers], layer
            blocks in forward order; b is [num_sample, sum of n_out].
        """
        samples = [layer.sample_w_d(num_sample) for layer in self._bnn_layers()]
        W_sample = torch.cat([w for w, _ in samples], -1)
        b_sample = torch.cat([b for _, b in samples], -1)
        return W_sample, b_sample

    def sample_mu_std(self):
        """Return all means and all raw scales as two flat vectors.
        Returns:
            (mu, std): each laid out as [W0, W1, W2, b0, b1, b2], flattened.
        """
        layers = self._bnn_layers()
        W_mu_list = torch.cat([layer.W_mu.view(-1) for layer in layers], -1)
        b_mu_list = torch.cat([layer.b_mu.view(-1) for layer in layers], -1)
        W_std_list = torch.cat([layer.W_std.view(-1) for layer in layers], -1)
        b_std_list = torch.cat([layer.b_std.view(-1) for layer in layers], -1)

        mu_list = torch.cat((W_mu_list, b_mu_list), -1)
        std_list = torch.cat((W_std_list, b_std_list), -1)
        return mu_list, std_list

    def forward_flow(self, X, num_sample, flow, hidden_dims):
        """Forward pass with posterior weights pushed through `flow`.

        One weight/bias vector laid out as [W0, W1, W2, b0, b1, b2] (each
        flattened, see `sample_w_b`) is transformed by `flow`, unpacked into the
        three layers and used for a deterministic forward pass.

        Args:
            X: torch.tensor, reshaped to [-1, input_dim].
            num_sample: int, number of draws fed to `flow`; the flow output
                must contain exactly one sample.
            flow: callable returning (transformed vector, log-Jacobian term);
                assumed to preserve the vector's shape -- TODO confirm.
            hidden_dims: sequence of int, the two hidden layer sizes.
        Returns:
            (output [batch_size, output_dim], mean per-layer KL of the base
            posterior against its prior, log_jacobians as returned by `flow`).
        Raises:
            ValueError: if the flow output does not hold exactly one sample.
        """
        W_sample, b_sample = self.sample_w_b(num_sample)
        w_b_0 = torch.cat((W_sample, b_sample), -1)
        w_b_k, log_jacobians = flow(w_b_0)

        if w_b_k.shape[0] != 1:
            raise ValueError("forward_flow unpacks a single weight sample; got %d"
                             % w_b_k.shape[0])
        flat = w_b_k[0]

        h0, h1 = hidden_dims[0], hidden_dims[1]
        in_dim, out_dim = self.input_dim, self.output_dim
        # Weights are read from the head of the vector, one block per layer.
        end_w0 = in_dim * h0
        end_w1 = end_w0 + h0 * h1
        end_w2 = end_w1 + h1 * out_dim
        wk_0 = flat[:end_w0].reshape(in_dim, h0)
        wk_1 = flat[end_w0:end_w1].reshape(h0, h1)
        wk_2 = flat[end_w1:end_w2].reshape(h1, out_dim)
        # Biases are read from the tail: [..., b0 (h0), b1 (h1), b2 (out_dim)].
        bk_2 = flat[-out_dim:]
        bk_1 = flat[-h1 - out_dim:-out_dim]
        bk_0 = flat[-h0 - h1 - out_dim:-h1 - out_dim]

        X = X.view(-1, self.input_dim)
        X = self._propagate(X,
                            lambda h: torch.mm(h, wk_0) + bk_0,
                            lambda h: torch.mm(h, wk_1) + bk_1,
                            lambda h: torch.mm(h, wk_2) + bk_2)
        return X, self._mean_layer_term("kld"), log_jacobians

    def forward_eval(self, X, num_sample):
        """Forward pass with `num_sample` independent weight draws.
        Args:
            X: torch.tensor, reshaped to [-1, input_dim] and repeated to
                [num_sample, batch_size, input_dim].
            num_sample: int, the number of weight samples.
        Returns:
            torch.tensor, [num_sample, batch_size, output_dim].
        """
        X = X.view(-1, self.input_dim)
        X = X.repeat(num_sample, 1, 1)
        # NOTE(review): with norm_layer="batchnorm" the 3-D activations reach
        # BatchNorm1d as (num_sample, batch, hidden), which it reads as
        # (N, C, L) -- confirm this is intended.
        return self._propagate(X,
                               lambda h: self.input_layer.forward_eval(h, num_sample),
                               lambda h: self.mid_layer.forward_eval(h, num_sample),
                               lambda h: self.output_layer.forward_eval(h, num_sample))

    def forward_kl(self, X):
        """Sampled forward pass plus the mean per-layer KL to the prior.
        Args:
            X: torch.tensor, reshaped to [-1, input_dim].
        Returns:
            (torch.tensor [batch_size, output_dim], 0-dim KL tensor).
        """
        output = self.forward(X)
        return output, self._mean_layer_term("kld")

    def forward_w(self, X):
        """Sampled forward pass plus the mean per-layer Wasserstein distance to the prior.
        Args:
            X: torch.tensor, reshaped to [-1, input_dim].
        Returns:
            (torch.tensor [batch_size, output_dim], 0-dim distance tensor).
        """
        output = self.forward(X)
        return output, self._mean_layer_term("wd")

    def save(self, checkpoint_path):
        """Write the state dict to `checkpoint_path`."""
        torch.save(self.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load a state dict from `checkpoint_path`, mapping tensors to CPU storage."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))


class BNNMLPMC(nn.Module):
    """Linear layer holding one fixed Monte-Carlo draw of its weights.

    At construction W ~ W_mu + softplus(W_std) * N(0, 1) and
    b ~ b_mu + softplus(b_std) * N(0, 1) are drawn once and stored as trainable
    parameters `W` and `b`; forward passes are deterministic afterwards.
    """

    def __init__(self, n_in, n_out, W_mu=None, b_mu=None, W_std=None,
                 b_std=None, scaled_variance=False, prior_per='parameter'):
        """Initialization.
        Args:
            n_in: int, the size of the input data.
            n_out: int, the size of the output.
            W_mu, b_mu: float or tensor, means of the weight / bias draw
                (broadcastable to [n_in, n_out] / [n_out]).
            W_std, b_std: float or tensor, raw (pre-softplus) scales of the draw.
            scaled_variance: bool, if True the weight is divided by sqrt(n_in)
                in every forward pass; also changes the W_mu/W_std defaults.
            prior_per: str, `parameter` or `layer`; only validated here.
        Raises:
            ValueError: if `prior_per` is neither `parameter` nor `layer`.
        """
        super(BNNMLPMC, self).__init__()

        self.n_in = n_in
        self.n_out = n_out
        self.scaled_variance = scaled_variance

        if W_mu is None:
            W_mu = 1. if self.scaled_variance else 1. / math.sqrt(self.n_in)
        if b_mu is None:
            b_mu = 1.
        if W_std is None:
            W_std = 1. if self.scaled_variance else 1. / math.sqrt(self.n_in)
        if b_std is None:
            b_std = 1.

        if prior_per == "layer":
            W_shape, b_shape = [1], [1]
        elif prior_per == "parameter":
            W_shape, b_shape = [self.n_in, self.n_out], [self.n_out]
        else:
            raise ValueError("Accepted values: `parameter` or `layer`")

        # These uniform draws are unused; they are kept so that seeded runs
        # consume the global RNG exactly as before and produce the same W / b.
        bound = 1. / math.sqrt(self.n_in)
        uniform = torch.distributions.uniform.Uniform(
            torch.tensor([-bound]), torch.tensor([bound]))
        uniform.sample(W_shape)
        uniform.sample(b_shape)

        # Python-float defaults must become tensors: F.softplus and `.device`
        # need a Tensor (the float defaults previously crashed here).
        W_mu = torch.as_tensor(W_mu)
        b_mu = torch.as_tensor(b_mu)
        W_std = torch.as_tensor(W_std)
        b_std = torch.as_tensor(b_std)

        W_temp = W_mu + F.softplus(W_std) * \
            torch.randn((self.n_in, self.n_out), device=W_std.device)
        b_temp = b_mu + F.softplus(b_std) * \
            torch.randn((self.n_out), device=b_std.device)

        self.W = nn.Parameter(W_temp, requires_grad=True)
        self.b = nn.Parameter(b_temp, requires_grad=True)

    def forward(self, X):
        """Deterministic forward pass with the stored draw.
        Args:
            X: torch.tensor, [batch_size, n_in], the input data.
        Returns:
            output: torch.tensor, [batch_size, n_out], the output data.
        """
        W = self.W
        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)
        return torch.mm(X, W) + self.b

    def forward_eval(self, X, num_sample):
        """Batched forward pass; every slice uses the same stored draw.
        Args:
            X: torch.tensor, [num_sample, batch_size, n_in].
            num_sample: int, number of slices in X.
        Returns:
            output: torch.tensor, [num_sample, batch_size, n_out].
        """
        W = self.W.repeat(num_sample, 1, 1)
        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)
        return torch.bmm(X, W) + self.b


class BNNMC(nn.Module):
    """MLP input -> hidden_dims[0] -> hidden_dims[1] -> output of `BNNMLPMC` layers.

    Each layer holds one fixed weight/bias draw made at construction.
    """

    def __init__(self, input_dim, output_dim, hidden_dims, activation_fn, has_continuous_action_space,
                 W_mu, b_mu, W_std, b_std, scaled_variance=False, norm_layer=None):
        """Initialization.
        Args:
            input_dim: int, the size of the input data.
            output_dim: int, the size of the output data.
            hidden_dims: sequence of int; the first two entries are the sizes
                of the two hidden layers.
            activation_fn: str, one of 'cos', 'tanh', 'relu', 'softplus',
                'leaky_relu', or a callable applied after each hidden layer.
            has_continuous_action_space: stored as an attribute; not used
                inside this class.
            W_mu, b_mu, W_std, b_std: per-layer sequences indexed 0, 1, 2 for
                the input, middle and output layer; entry i is passed to the
                matching BNNMLPMC (the std entries are raw pre-softplus scales).
            scaled_variance: bool, forwarded to every layer.
            norm_layer: None or "batchnorm", passed to `init_norm_layer`.
        Raises:
            ValueError: if `activation_fn` is a string that is not a known name.
        """
        super(BNNMC, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.norm_layer = norm_layer

        options = {'cos': torch.cos, 'tanh': torch.tanh, 'relu': F.relu,
                   'softplus': F.softplus, 'leaky_relu': F.leaky_relu}
        if isinstance(activation_fn, str):
            # Fail at construction instead of on the first forward pass.
            if activation_fn not in options:
                raise ValueError("Unknown activation_fn %r; expected one of %s or a callable"
                                 % (activation_fn, sorted(options)))
            self.activation_fn = options[activation_fn]
        else:
            self.activation_fn = activation_fn

        if b_std is None:
            # NOTE(review): b_std[i] then becomes W_std[i]; if that entry is
            # weight-shaped ([n_in, n_out]) the bias draw gets the wrong shape.
            b_std = W_std

        self.input_layer = BNNMLPMC(
            input_dim, hidden_dims[0], W_mu[0], b_mu[0], W_std[0], b_std[0],
            scaled_variance=scaled_variance)
        self.norm_layer1 = init_norm_layer(hidden_dims[0], self.norm_layer)
        self.mid_layer = BNNMLPMC(
            hidden_dims[0], hidden_dims[1], W_mu[1], b_mu[1], W_std[1], b_std[1],
            scaled_variance=scaled_variance)
        self.norm_layer2 = init_norm_layer(hidden_dims[1], self.norm_layer)
        self.output_layer = BNNMLPMC(
            hidden_dims[1], output_dim, W_mu[2], b_mu[2], W_std[2], b_std[2],
            scaled_variance=scaled_variance)

    def forward(self, X):
        """Deterministic forward pass with the stored draws.
        Args:
            X: torch.tensor, reshaped to [-1, input_dim].
        Returns:
            torch.tensor, [batch_size, output_dim].
        """
        X = X.view(-1, self.input_dim)
        X = self.activation_fn(self.norm_layer1(self.input_layer(X)))
        X = self.activation_fn(self.norm_layer2(self.mid_layer(X)))
        return self.output_layer(X)

    def sample_w_b(self):
        """Return the stored weights and biases as flat row vectors.
        Returns:
            (W, b): [1, sum of n_in*n_out] and [1, sum of n_out], layer blocks
            in forward order.
        """
        layers = (self.input_layer, self.mid_layer, self.output_layer)
        W_sample = torch.cat([layer.W.view(-1) for layer in layers], -1).unsqueeze(0)
        b_sample = torch.cat([layer.b.view(-1) for layer in layers], -1).unsqueeze(0)
        return W_sample, b_sample

    def sample_functions(self, X, num_sample):
        """Batched predictions (same as `forward_eval`).
        Args:
            X: torch.tensor, reshaped to [-1, input_dim].
            num_sample: int, number of repeated slices.
        Returns:
            torch.tensor, [num_sample, batch_size, output_dim].
        """
        return self.forward_eval(X, num_sample)

    def forward_eval(self, X, num_sample):
        """Batched forward pass; all slices use the same stored draws.
        Args:
            X: torch.tensor, reshaped to [-1, input_dim] and repeated to
                [num_sample, batch_size, input_dim].
            num_sample: int, number of repeated slices.
        Returns:
            torch.tensor, [num_sample, batch_size, output_dim].
        """
        X = X.view(-1, self.input_dim)
        X = X.repeat(num_sample, 1, 1)

        X = self.activation_fn(self.norm_layer1(self.input_layer.forward_eval(X, num_sample)))
        X = self.activation_fn(self.norm_layer2(self.mid_layer.forward_eval(X, num_sample)))
        return self.output_layer.forward_eval(X, num_sample)

    def save(self, checkpoint_path):
        """Write the state dict to `checkpoint_path`."""
        torch.save(self.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load a state dict from `checkpoint_path`, mapping tensors to CPU storage."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))


class BNNMLPF(nn.Module):
    """Deterministic linear layer using learnable means `W_mu`/`b_mu` directly.

    Same constructor interface and mean initialisation as `BNNMLP`, but no
    noise and no scale parameters.
    """

    def __init__(self, n_in, n_out, W_mu=None, b_mu=None, W_std=None,
                 b_std=None, scaled_variance=False, prior_per='parameter'):
        """Initialization.
        Args:
            n_in: int, the size of the input data.
            n_out: int, the size of the output.
            W_mu, b_mu, W_std, b_std: accepted for interface compatibility but
                NOT used; means are initialised from U(-1/sqrt(n_in), 1/sqrt(n_in)).
            scaled_variance: bool, if True the weight is divided by sqrt(n_in)
                in every forward pass.
            prior_per: str, `parameter` gives every weight/bias its own value;
                `layer` shares one value for all weights and one for all biases.
        Raises:
            ValueError: if `prior_per` is neither `parameter` nor `layer`.
        """
        super(BNNMLPF, self).__init__()

        self.n_in = n_in
        self.n_out = n_out
        self.scaled_variance = scaled_variance

        if prior_per == "layer":
            W_shape, b_shape = [1], [1]
        elif prior_per == "parameter":
            W_shape, b_shape = [self.n_in, self.n_out], [self.n_out]
        else:
            raise ValueError("Accepted values: `parameter` or `layer`")

        bound = 1. / math.sqrt(self.n_in)
        uniform = torch.distributions.uniform.Uniform(
            torch.tensor([-bound]), torch.tensor([bound]))
        self.W_mu = nn.Parameter(uniform.sample(W_shape).squeeze(-1), requires_grad=True)
        self.b_mu = nn.Parameter(uniform.sample(b_shape).squeeze(-1), requires_grad=True)

    def _weight(self):
        """Return the [n_in, n_out] weight matrix with optional 1/sqrt(n_in) scaling."""
        # expand() turns a shared (`layer`) weight of shape [1] into the full
        # matrix that mm/bmm require; it is a no-op view for `parameter`.
        W = self.W_mu.expand(self.n_in, self.n_out)
        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)
        return W

    def forward(self, X):
        """Deterministic forward pass.
        Args:
            X: torch.tensor, [batch_size, n_in], the input data.
        Returns:
            output: torch.tensor, [batch_size, n_out], the output data.
        """
        return torch.mm(X, self._weight()) + self.b_mu

    def forward_eval(self, X, num_sample):
        """Batched forward pass; every slice uses the same weights.
        Args:
            X: torch.tensor, [num_sample, batch_size, n_in].
            num_sample: int, number of slices in X.
        Returns:
            output: torch.tensor, [num_sample, batch_size, n_out].
        """
        W = self._weight().repeat(num_sample, 1, 1)
        return torch.bmm(X, W) + self.b_mu

class BNNF(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dims, activation_fn, has_continuous_action_space,
                 W_mu=None, b_mu=None, W_std=None, b_std=None, scaled_variance=False, norm_layer=None):
        """Initialization.

        Builds ``input_dim -> hidden_dims[0] -> output_dim`` from two
        ``BNNMLPF`` layers, with an optional normalization layer followed by
        the activation on the hidden units.

        Args:
            input_dim: int, the size of the input data.
            output_dim: int, the size of the output data.
            hidden_dims: list of int, hidden layer sizes; only
                ``hidden_dims[0]`` is used.
            activation_fn: str or callable. Strings must be one of 'cos',
                'tanh', 'relu', 'softplus', 'leaky_relu'; any non-string
                value is used as the activation function as-is.
            has_continuous_action_space: bool, stored on the module; not used
                by the computations in this class.
            W_mu, b_mu, W_std, b_std: forwarded unchanged to both
                ``BNNMLPF`` layers. ``b_std`` falls back to ``W_std`` when
                None.
            scaled_variance: bool, forwarded to both ``BNNMLPF`` layers.
            norm_layer: "batchnorm" or None, passed to ``init_norm_layer`` to
                build the normalization applied to the hidden units.
        Raises:
            ValueError: if ``activation_fn`` is a string that is not one of
                the supported names.
        """
        super(BNNF, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.norm_layer = norm_layer

        # Known names map to torch functions; a non-string (e.g. a callable)
        # is used directly. An unknown name used to be stored as a string and
        # only failed at the first forward pass, so reject it here instead.
        options = {'cos': torch.cos, 'tanh': torch.tanh, 'relu': F.relu,
                   'softplus': F.softplus, 'leaky_relu': F.leaky_relu}
        if isinstance(activation_fn, str):
            if activation_fn not in options:
                raise ValueError(
                    f"Unknown activation_fn {activation_fn!r}; "
                    f"expected one of {sorted(options)} or a callable")
            self.activation_fn = options[activation_fn]
        else:
            self.activation_fn = activation_fn

        if b_std is None:
            b_std = W_std

        # Submodule attribute names determine state_dict keys; keep them
        # stable so existing checkpoints still load.
        self.input_layer = BNNMLPF(
            input_dim, hidden_dims[0], W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)
        self.norm_layer1 = init_norm_layer(hidden_dims[0], self.norm_layer)
        self.output_layer = BNNMLPF(
            hidden_dims[0], output_dim, W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)

    def _normalize_hidden(self, h):
        """Apply ``norm_layer1`` over the last (feature) axis of ``h``."""
        if h.dim() <= 2:
            return self.norm_layer1(h)
        # BatchNorm1d treats dim 1 of a 3-D input as the channel axis, which
        # for [num_sample, batch, hidden] would be the batch axis. Flatten the
        # leading dims so features are normalized. Statistics are then pooled
        # over samples and batch together.
        flat = h.reshape(-1, h.shape[-1])
        return self.norm_layer1(flat).reshape(h.shape)

    def _forward_samples(self, X, num_sample):
        """Shared implementation of ``sample_functions`` / ``forward_eval``."""
        X = X.reshape(-1, self.input_dim).repeat(num_sample, 1, 1)
        h = self.input_layer.forward_eval(X, num_sample)
        h = self.activation_fn(self._normalize_hidden(h))
        return self.output_layer.forward_eval(h, num_sample)

    def forward(self, X):
        """Performs forward pass given input data.

        Args:
            X: torch.tensor, the input data; reshaped to
                [batch_size, input_dim].
        Returns:
            torch.tensor, [batch_size, output_dim], the output data.
        """
        X = X.reshape(-1, self.input_dim)
        h = self.activation_fn(self.norm_layer1(self.input_layer(X)))
        return self.output_layer(h)

    def sample_functions(self, X, num_sample):
        """Evaluate the network on ``num_sample`` stacked copies of ``X``.

        Args:
            X: torch.tensor, the input data; reshaped to
                [batch_size, input_dim].
            num_sample: int, number of stacked copies to evaluate.
        Returns:
            torch.tensor, [num_sample, batch_size, output_dim], the output
            data.
        """
        return self._forward_samples(X, num_sample)

    def forward_eval(self, X, num_sample):
        """Performs forward pass on ``num_sample`` stacked copies of ``X``.

        Args:
            X: torch.tensor, the input data; reshaped to
                [batch_size, input_dim].
            num_sample: int, number of stacked copies to evaluate.
        Returns:
            torch.tensor, [num_sample, batch_size, output_dim], the output
            data.
        """
        return self._forward_samples(X, num_sample)

    def save(self, checkpoint_path):
        """Write this module's state_dict to ``checkpoint_path``."""
        torch.save(self.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load a state_dict from ``checkpoint_path``, mapping tensors to CPU.

        ``torch.load`` unpickles the file, so only load trusted checkpoints.
        """
        self.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BN stages plus a skip connection.

    The skip path is a 1x1 conv + BN projection whenever the stride is not 1
    or the channel count changes; otherwise it is an empty ``nn.Sequential``
    (identity).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        width = self.expansion * planes

        # Registration order (conv1, bn1, conv2, bn2, shortcut) fixes the
        # parameter order and state_dict keys.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != width
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        h = h + self.shortcut(x)
        return F.relu(h)

class ResNet(nn.Module):
    """ResNet for small images.

    The network has a 3x3 stem convolution (no max-pool), four residual
    stages with 64/128/256/512 base widths, global average pooling and a
    linear classifier.
    """

    def __init__(self, block, num_blocks, num_classes, in_channels=1):
        """
        Args:
            block: residual block class, constructed as
                ``block(in_planes, planes, stride)``; must define an integer
                ``expansion`` class attribute.
            num_blocks: sequence of 4 ints, number of blocks in each stage.
            num_classes: int, number of output logits.
            in_channels: int, channels of the input images (default 1, the
                value that was previously hard-coded).
        """
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape [batch, num_classes] for images ``x``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pool to 1x1. With BasicBlock and 28x28 or 32x32
        # inputs, the layer4 map is 4x4, so this equals the former
        # ``F.avg_pool2d(out, 4)``. Unlike a fixed 4x4 kernel, it also works
        # for other input sizes.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)

    def save(self, checkpoint_path):
        """Write this module's state_dict to ``checkpoint_path``."""
        torch.save(self.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load a state_dict from ``checkpoint_path``, mapping tensors to CPU.

        ``torch.load`` unpickles the file, so only load trusted checkpoints.
        """
        self.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))

def ResNet18(num_classes=10):
    """Build a ResNet-18: ``BasicBlock`` with [2, 2, 2, 2] blocks per stage.

    Args:
        num_classes: int, number of output logits (default 10, the value
            that was previously hard-coded).
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)


class BayesWrapper:
    """Mean-field Gaussian wrapper around ``net`` (reparameterised sampling).

    Each parameter ``p`` of ``net`` gets a mean ``mu`` (initialised to a copy
    of ``p``) and a pre-activation ``rho``; the standard deviation is
    ``sigma = softplus(rho)``. Every call samples ``w = mu + eps * sigma``
    with ``eps ~ N(0, 1)``, writes it into ``net`` and runs ``net``.
    """

    def __init__(self, name, net, rho_init=-5, lr=1e-2):
        """
        Args:
            name: str, identifier stored on the wrapper.
            net: nn.Module whose parameters are treated as random variables.
            rho_init: float, initial value of every ``rho`` entry.
            lr: float, Adam learning rate for ``mu`` and ``rho``.
        """
        self.name = name
        self.net = net
        # One tuple per parameter: (param_name, mu, rho, sigma, eps).
        # sigma and eps are preallocated buffers refreshed in forward().
        self.bayes_params = [
            (param_name,
             p.clone().detach(),                 # mu
             torch.zeros_like(p) + rho_init,     # rho
             torch.zeros_like(p),                # sigma
             torch.zeros_like(p))                # eps
            for param_name, p in self.net.named_parameters()
        ]
        # Live parameters by name, so sampled weights are written in place and
        # gradients are matched by name rather than by iteration order.
        self._net_params = dict(self.net.named_parameters())
        self.criterion = nn.CrossEntropyLoss()
        params = ([mu for _, mu, _, _, _ in self.bayes_params]
                  + [rho for _, _, rho, _, _ in self.bayes_params])
        self.optimizer = torch.optim.Adam(params, lr=lr)

    def forward(self, input):
        """Sample a fresh set of weights into ``net`` and return ``net(input)``."""
        with torch.no_grad():
            for param_name, mu, rho, sigma, eps in self.bayes_params:
                eps.normal_()
                # softplus == log1p(exp(rho)) but does not overflow for large rho.
                sigma.copy_(F.softplus(rho))
                self._net_params[param_name].copy_(mu + eps * sigma)
        return self.net(input)

    def step(self, outputs, targets):
        """Back-propagate the cross-entropy loss and update the means.

        Args:
            outputs: logits returned by the most recent call to this wrapper
                (their graph must still be alive).
            targets: class targets accepted by ``nn.CrossEntropyLoss``.
        Returns:
            The loss tensor. Its graph has already been back-propagated.
        """
        self.net.zero_grad()
        loss = self.criterion(outputs, targets)
        # The original code never ran backward, so p.grad was always
        # None/zero and the optimizer never moved mu.
        loss.backward()

        for param_name, mu, rho, sigma, eps in self.bayes_params:
            grad = self._net_params[param_name].grad
            # With w = mu + eps * sigma, dL/dmu == dL/dw.
            mu.grad = None if grad is None else grad.detach().clone()
            # TODO(review): rho never gets a gradient, so Adam leaves it
            # unchanged. Via the reparameterisation, dL/drho would be
            # grad * eps * sigmoid(rho) (d softplus/drho = sigmoid), plus any
            # KL/complexity term, which this loss does not include.

        self.optimizer.step()

        return loss

    def __call__(self, input):
        return self.forward(input)

    def train(self):
        """Put the wrapped network in training mode."""
        self.net.train()




