
import warnings
warnings.simplefilter("ignore", UserWarning)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical
from torch.utils.data import TensorDataset, DataLoader

import itertools
import numpy as np
import math


def weights_init(m):
    """Re-initialize an ``nn.Linear`` module in place; ignore anything else.

    Intended to be passed to ``nn.Module.apply``. The weight matrix is drawn
    with Xavier-normal initialization and the bias (when the layer has one)
    from a standard normal.

    Args:
        m: nn.Module, the module to (possibly) initialize.
    """
    if not isinstance(m, nn.Linear):
        return

    torch.nn.init.xavier_normal_(m.weight.data)
    # Linear layers built with ``bias=False`` expose ``bias = None``.
    if m.bias is not None:
        torch.nn.init.normal_(m.bias.data)

def init_norm_layer(input_dim, norm_layer):
    """Build the normalization module placed after a hidden layer.

    Args:
        input_dim: int, number of features the layer normalizes.
        norm_layer: ``"batchnorm"`` or ``None``. ``None`` yields a no-op
            ``nn.Identity``.
    Returns:
        nn.Module, the normalization layer.
    Raises:
        ValueError: if `norm_layer` is not one of the supported values
            (previously this silently returned ``None``, which only failed
            later when the result was called).
    """
    if norm_layer is None:
        return nn.Identity()

    if norm_layer == "batchnorm":
        # No learnable affine and no running statistics: the layer always
        # normalizes with the statistics of the current batch.
        # NOTE(review): eps=0 removes the numerical stabilizer, so a feature
        # with zero batch variance will divide by zero -- confirm intended.
        return nn.BatchNorm1d(input_dim, eps=0, momentum=None,
                              affine=False, track_running_stats=False)

    raise ValueError(
        f"Unsupported norm_layer: {norm_layer!r}; "
        "accepted values: 'batchnorm' or None")


class BNNMLP(nn.Module):
    """Fully connected layer with a factorized Gaussian posterior.

    A sampling forward pass draws ``W = W_mu + softplus(W_std) * eps`` with
    ``eps ~ N(0, 1)`` (and likewise for the bias). ``W_std`` / ``b_std`` are
    therefore unconstrained pre-softplus values; the effective standard
    deviation is ``softplus(W_std)``. The same convention applies to the
    stored prior tensors ``W_std_prior`` / ``b_std_prior``.
    """

    def __init__(self, n_in, n_out, W_mu=None, b_mu=None, W_std=None,
                 b_std=None, scaled_variance=False, prior_per='parameter'):
        """Initialization.
        Args:
            n_in: int, the size of the input data.
            n_out: int, the size of the output.
            W_mu, b_mu, W_std, b_std: accepted for interface compatibility
                only; the initial values set below do not depend on them.
            scaled_variance: bool, if True the sampled weight matrix is
                divided by sqrt(n_in) in every forward pass.
            prior_per: str, `parameter` keeps one mean/std (and one prior)
                per weight and per bias entry; `layer` keeps a single
                shared scalar for all weights and another for all biases.
        Raises:
            ValueError: if `prior_per` is neither `parameter` nor `layer`.
        """
        super(BNNMLP, self).__init__()

        self.n_in = n_in
        self.n_out = n_out
        self.scaled_variance = scaled_variance

        if prior_per == "layer":
            W_shape, b_shape = [1], [1]    # (1), (1)
        elif prior_per == "parameter":
            W_shape, b_shape = [self.n_in, self.n_out], [self.n_out]
        else:
            raise ValueError("Accepted values: `parameter` or `layer`")

        # Means start uniformly in [-1/sqrt(n_in), 1/sqrt(n_in)].
        bound = 1. / math.sqrt(self.n_in)
        m = torch.distributions.uniform.Uniform(
            torch.tensor([-bound]), torch.tensor([bound]))

        # The distribution has event shape (1,), so drop only that trailing
        # axis; a bare squeeze() would also remove a genuine size-1 dim.
        W_mu_tmp = m.sample(W_shape).squeeze(-1)
        b_mu_tmp = m.sample(b_shape).squeeze(-1)

        self.W_mu = nn.Parameter(W_mu_tmp, requires_grad=True)
        self.b_mu = nn.Parameter(b_mu_tmp, requires_grad=True)

        # Pre-softplus standard deviations, initialised around 1.
        self.W_std = nn.Parameter(
            torch.randn(W_shape) + 1, requires_grad=True)
        self.b_std = nn.Parameter(
            torch.randn(b_shape) + 1, requires_grad=True)

        # Prior: mean 0.001 and pre-softplus std 1.0. These are plain
        # tensors (not parameters or buffers), so they are not part of the
        # state dict and are moved to the right device lazily.
        self.W_mu_prior = torch.zeros(W_shape) + 0.001
        self.b_mu_prior = torch.zeros(b_shape) + 0.001
        self.W_std_prior = torch.zeros(W_shape) + 1.0
        self.b_std_prior = torch.zeros(b_shape) + 1.0

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    def _sample_params(self):
        """Draw one (W, b) pair; W noise is drawn before b noise."""
        W = self.W_mu + F.softplus(self.W_std) * \
            torch.randn((self.n_in, self.n_out), device=self.W_std.device)
        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)

        b = self.b_mu + F.softplus(self.b_std) * \
            torch.randn((self.n_out), device=self.b_std.device)
        return W, b

    def _move_priors_to_param_device(self):
        """Place the prior tensors on the same device as the parameters."""
        self.W_std_prior = self.W_std_prior.to(self.W_std.device)
        self.W_mu_prior = self.W_mu_prior.to(self.W_mu.device)
        self.b_std_prior = self.b_std_prior.to(self.b_std.device)
        self.b_mu_prior = self.b_mu_prior.to(self.b_mu.device)

    @staticmethod
    def _mean_gaussian_kl(mu, rho, mu_prior, rho_prior):
        """Mean element-wise Gaussian KL term (constant -1/2 omitted)."""
        sigma = F.softplus(rho)
        sigma_prior = F.softplus(rho_prior)
        return torch.mean(
            torch.log(sigma_prior) - torch.log(sigma)
            + (torch.pow(sigma, 2) + torch.pow((mu - mu_prior), 2)) / (
                2 * torch.pow(sigma_prior, 2)))

    # ------------------------------------------------------------------
    # forward passes
    # ------------------------------------------------------------------
    def forward(self, X):
        """Performs forward pass with one freshly sampled set of parameters.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            output: torch.tensor, [batch_size, output_dim], the output data.
        """
        W, b = self._sample_params()
        return torch.mm(X, W) + b

    def forward_plain(self, X):
        """Performs a deterministic forward pass using the posterior means.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            output: torch.tensor, [batch_size, output_dim], the output data.
        """
        # With prior_per='layer' W_mu is a single shared scalar of shape
        # (1,); expand it to a full matrix so torch.mm accepts it. For
        # per-parameter means this is a no-op view.
        W = self.W_mu.expand(self.n_in, self.n_out)

        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)
        b = self.b_mu
        return torch.mm(X, W) + b

    def forward_norm(self, X):
        """Performs a sampled forward pass and reports the weight L1 norm.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            output: torch.tensor, [batch_size, output_dim], the output data.
            sumw: scalar torch.tensor, sum of absolute values of the sampled
                (and, if enabled, variance-scaled) weight matrix.
        """
        W, b = self._sample_params()
        output = torch.mm(X, W) + b
        sumw = torch.sum(W.abs())
        return output, sumw

    def forward_eval(self, X, num_sample):
        """Performs forward pass with `num_sample` independent parameter draws.
        Args:
            X: torch.tensor, [num_sample, batch_size, input_dim], the input
                data, already tiled along the sample axis.
            num_sample: int, number of parameter samples.
        Returns:
            output: torch.tensor, [num_sample, batch_size, output_dim].
        """
        # Broadcasting over the leading sample axis gives the same values as
        # explicitly repeating the mean/std tensors.
        W = self.W_mu + F.softplus(self.W_std) * \
            torch.randn((num_sample, self.n_in, self.n_out),
                        device=self.W_std.device)

        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)

        # One bias draw per sample (shape [num_sample, 1, n_out]) so every
        # sample is an independent draw of the full parameter set rather
        # than all samples sharing one bias realisation.
        b = self.b_mu + F.softplus(self.b_std) * \
            torch.randn((num_sample, 1, self.n_out),
                        device=self.b_std.device)

        return torch.bmm(X, W) + b

    def forward_eval_plain(self, X, num_sample):
        """Performs a deterministic batched forward pass using the means.
        Args:
            X: torch.tensor, [num_sample, batch_size, input_dim], the input
                data, already tiled along the sample axis.
            num_sample: int, size of the leading sample axis.
        Returns:
            output: torch.tensor, [num_sample, batch_size, output_dim].
        """
        # expand (instead of repeat) also covers the shared-scalar
        # prior_per='layer' case, where W_mu has shape (1,).
        W = self.W_mu.expand(num_sample, self.n_in, self.n_out)

        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)

        b = self.b_mu
        return torch.bmm(X, W) + b

    # ------------------------------------------------------------------
    # regularizers
    # ------------------------------------------------------------------
    def kld(self):
        """KL-style penalty between the posterior and the stored prior.

        For weights and biases separately, the element-wise Gaussian term
        ``log(s_p / s_q) + (s_q^2 + (mu_q - mu_p)^2) / (2 s_p^2)`` is
        averaged over all entries; the two means are then added. The usual
        constant -1/2 is not subtracted, so the value is offset from the
        exact KL divergence by a constant.

        Returns:
            scalar torch.tensor.
        """
        self._move_priors_to_param_device()

        kl_w = self._mean_gaussian_kl(
            self.W_mu, self.W_std, self.W_mu_prior, self.W_std_prior)
        kl_b = self._mean_gaussian_kl(
            self.b_mu, self.b_std, self.b_mu_prior, self.b_std_prior)
        return kl_w + kl_b

    def wd(self):
        """Wasserstein-style distance between the posterior and stored prior.

        Returns:
            scalar torch.tensor, see `para_wd` for the exact expression.
        """
        self._move_priors_to_param_device()
        return self.para_wd(self.W_mu_prior, self.b_mu_prior,
                            self.W_std_prior, self.b_std_prior)

    def para_wd(self, W_mu_prior, b_mu_prior, W_std_prior, b_std_prior):
        """Wasserstein-style distance between the posterior and a given prior.

        Per entry this is ``sqrt((mu_q - mu_p + 0.001)^2 +
        (s_q - s_p + 0.001)^2)`` with ``s = softplus(std)``; it is averaged
        over weights and over biases and the two means are added.

        Args:
            W_mu_prior, b_mu_prior: torch.tensor, prior means.
            W_std_prior, b_std_prior: torch.tensor, pre-softplus prior stds.
        Returns:
            scalar torch.tensor.
        """
        wdist_w = torch.mean(torch.pow(torch.pow((self.W_mu - W_mu_prior + 0.001), 2) + torch.pow(
            (F.softplus(self.W_std) - F.softplus(W_std_prior) + 0.001), 2), 0.5))
        wdist_b = torch.mean(torch.pow(torch.pow((self.b_mu - b_mu_prior + 0.001), 2) + torch.pow(
            (F.softplus(self.b_std) - F.softplus(b_std_prior) + 0.001), 2), 0.5))

        return wdist_w + wdist_b



class BNN(nn.Module):
    """Bayesian MLP with two hidden layers, built from `BNNMLP` layers.

    Every layer uses the per-parameter setting of `BNNMLP` (its default
    `prior_per`). Only `hidden_dims[0]` and `hidden_dims[1]` are used.
    """

    def __init__(self, input_dim, output_dim, hidden_dims, activation_fn, has_continuous_action_space,
                 W_mu=None, b_mu=None, W_std=None, b_std=None, scaled_variance=False, norm_layer=None):
        """Initialization.
        Args:
            input_dim: int, the size of the input data.
            output_dim: int, the size of the output data.
            hidden_dims: list of int, the sizes of the hidden layers; the
                first two entries are used.
            activation_fn: str or callable, either one of `cos`, `tanh`,
                `relu`, `softplus`, `leaky_relu`, or a callable applied
                after each hidden layer.
            has_continuous_action_space: stored on the instance; not used
                by this class itself.
            W_mu, b_mu, W_std, b_std: forwarded unchanged to every
                `BNNMLP` layer (`b_std` falls back to `W_std` when None).
            scaled_variance: bool, forwarded to every `BNNMLP` layer.
            norm_layer: `"batchnorm"` or None, normalization applied after
                each hidden layer (see `init_norm_layer`).
        Raises:
            ValueError: if `activation_fn` is neither a known name nor a
                callable.
        """
        super(BNN, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.norm_layer = norm_layer

        # Setup activation function. Fail here rather than at the first
        # forward pass when the name is unknown.
        options = {'cos': torch.cos, 'tanh': torch.tanh, 'relu': F.relu,
                   'softplus': F.softplus, 'leaky_relu': F.leaky_relu}
        if isinstance(activation_fn, str):
            if activation_fn not in options:
                raise ValueError(
                    f"Unknown activation_fn {activation_fn!r}; "
                    f"accepted names: {sorted(options)}")
            self.activation_fn = options[activation_fn]
        elif callable(activation_fn):
            self.activation_fn = activation_fn
        else:
            raise ValueError(
                "activation_fn must be a known name or a callable, "
                f"got {activation_fn!r}")

        if b_std is None:
            b_std = W_std

        # Initialize layers (registration order is kept so state-dict keys
        # keep their order).
        self.input_layer = BNNMLP(
            input_dim, hidden_dims[0], W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)
        self.norm_layer1 = init_norm_layer(hidden_dims[0], self.norm_layer)

        self.mid_layer = BNNMLP(
            hidden_dims[0], hidden_dims[1], W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)
        self.norm_layer2 = init_norm_layer(hidden_dims[1], self.norm_layer)

        self.output_layer = BNNMLP(
            hidden_dims[1], output_dim, W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    def _propagate(self, X, apply_layer):
        """Run X through the three layers; `apply_layer(layer, X)` does one layer."""
        X = self.activation_fn(self.norm_layer1(apply_layer(self.input_layer, X)))
        X = self.activation_fn(self.norm_layer2(apply_layer(self.mid_layer, X)))
        return apply_layer(self.output_layer, X)

    def _forward_with_penalty(self, X, penalty_name, track_weight_norm):
        """Sampled forward pass that also accumulates a per-layer penalty.

        Args:
            X: torch.tensor, the input data.
            penalty_name: str, name of the `BNNMLP` method (`kld` or `wd`)
                evaluated on each layer.
            track_weight_norm: bool, if True layers run `forward_norm` and
                the summed absolute sampled weights are accumulated.
        Returns:
            (output, penalty averaged over the layers, summed weight norm);
            the last entry is 0 when `track_weight_norm` is False.
        """
        X = X.reshape(-1, self.input_dim)
        layers = (self.input_layer, self.mid_layer, self.output_layer)
        norms = (self.norm_layer1, self.norm_layer2, None)

        penalty = 0
        weight_norm = 0
        for layer, norm in zip(layers, norms):
            if track_weight_norm:
                X, layer_norm = layer.forward_norm(X)
                weight_norm = weight_norm + layer_norm
            else:
                X = layer(X)
            # The output layer has no normalization / activation.
            if norm is not None:
                X = self.activation_fn(norm(X))
            penalty = penalty + getattr(layer, penalty_name)()

        return X, penalty / len(layers), weight_norm

    # ------------------------------------------------------------------
    # forward passes
    # ------------------------------------------------------------------
    def forward(self, X):
        """Performs forward pass with one sampled set of weights per layer.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            torch.tensor, [batch_size, output_dim], the output data.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        X = X.reshape(-1, self.input_dim)
        return self._propagate(X, lambda layer, h: layer(h))

    def forward_plain(self, X):
        """Performs a deterministic forward pass using the posterior means.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            torch.tensor, [batch_size, output_dim], the output data.
        """
        X = X.reshape(-1, self.input_dim)
        return self._propagate(X, lambda layer, h: layer.forward_plain(h))

    def sample_measurement_set(self, X, num_data, set_size=200):
        """Draws a random subset of rows of X without replacement.
        Args:
            X: torch.tensor, [>= num_data, ...], the pool to sample from.
            num_data: int, rows are drawn from the first `num_data` indices.
            set_size: int, maximum number of rows returned (default 200,
                the previously hard-coded value).
        Returns:
            torch.tensor with min(set_size, num_data) rows of X.
        """
        perm = torch.randperm(int(num_data))
        idx = perm[:int(set_size)]
        return X[idx, :]

    def sample_functions(self, X, num_sample):
        """Performs predictions using `num_sample` sets of weights.

        Same computation as `forward_eval`.

        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
            num_sample: int, the number of weight samples used to make
                predictions.
        Returns:
            torch.tensor, [num_sample, batch_size, output_dim], the output
            data.
        """
        return self.forward_eval(X, num_sample)

    def forward_eval(self, X, num_sample):
        """Performs forward pass given input data using a number of parameter samples.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
            num_sample: int, the number of parameter samples.
        Returns:
            torch.tensor, [num_sample, batch_size, output_dim], the output
            data.
        """
        # NOTE(review): with norm_layer="batchnorm" the normalization layers
        # receive 3-D activations here; BatchNorm1d treats dim 1 as the
        # channel axis for 3-D input -- confirm this is intended.
        X = X.reshape(-1, self.input_dim).repeat(num_sample, 1, 1)
        return self._propagate(
            X, lambda layer, h: layer.forward_eval(h, num_sample))

    def forward_eval_plain(self, X, num_sample):
        """Performs a deterministic forward pass on input tiled `num_sample` times.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
            num_sample: int, number of copies along the leading axis.
        Returns:
            torch.tensor, [num_sample, batch_size, output_dim], the output
            data.
        """
        X = X.reshape(-1, self.input_dim).repeat(num_sample, 1, 1)
        return self._propagate(
            X, lambda layer, h: layer.forward_eval_plain(h, num_sample))

    def forward_kl(self, X):
        """Performs forward pass given input data and output KL distance (with prior) as well
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            torch.tensor, [batch_size, output_dim], the output data.
            scalar torch.tensor, `kld()` averaged over the three layers.
        """
        X, kl_prior, _ = self._forward_with_penalty(X, "kld", False)
        return X, kl_prior

    def forward_w(self, X):
        """Performs forward pass given input data and output Wasserstein distance (with prior) as well
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            torch.tensor, [batch_size, output_dim], the output data.
            scalar torch.tensor, `wd()` averaged over the three layers.
        """
        X, w_prior, _ = self._forward_with_penalty(X, "wd", False)
        return X, w_prior

    def forward_w_norm(self, X):
        """Performs forward pass and outputs the Wasserstein distance and weight norm as well
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            torch.tensor, [batch_size, output_dim], the output data.
            scalar torch.tensor, `wd()` averaged over the three layers.
            scalar torch.tensor, sum of absolute sampled weights over all
            three layers.
        """
        return self._forward_with_penalty(X, "wd", True)

    def forward_kl_norm(self, X):
        """Performs forward pass and outputs the KL distance and weight norm as well
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            torch.tensor, [batch_size, output_dim], the output data.
            scalar torch.tensor, `kld()` averaged over the three layers.
            scalar torch.tensor, sum of absolute sampled weights over all
            three layers.
        """
        return self._forward_with_penalty(X, "kld", True)

    def save(self, checkpoint_path):
        """Writes the state dict to `checkpoint_path`."""
        torch.save(self.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Loads a state dict from `checkpoint_path`, mapping tensors to CPU storage."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))

class OPTBNN(nn.Module):
    """Bayesian MLP with two hidden layers whose `BNNMLP` layers use `prior_per='layer'`.

    With that setting each layer holds one shared mean/std scalar for all
    of its weights and one for all of its biases. Only `hidden_dims[0]`
    and `hidden_dims[1]` are used.
    """

    def __init__(self, input_dim, output_dim, hidden_dims, activation_fn, has_continuous_action_space,
                 W_mu=None, b_mu=None, W_std=None, b_std=None, scaled_variance=False, norm_layer=None):
        """Initialization.
        Args:
            input_dim: int, the size of the input data.
            output_dim: int, the size of the output data.
            hidden_dims: list of int, the sizes of the hidden layers; the
                first two entries are used.
            activation_fn: str or callable, either one of `cos`, `tanh`,
                `relu`, `softplus`, `leaky_relu`, or a callable applied
                after each hidden layer.
            has_continuous_action_space: stored on the instance; not used
                by this class itself.
            W_mu, b_mu, W_std, b_std: forwarded unchanged to every
                `BNNMLP` layer (`b_std` falls back to `W_std` when None).
            scaled_variance: bool, forwarded to every `BNNMLP` layer.
            norm_layer: `"batchnorm"` or None, normalization applied after
                each hidden layer (see `init_norm_layer`).
        Raises:
            ValueError: if `activation_fn` is neither a known name nor a
                callable.
        """
        super(OPTBNN, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.norm_layer = norm_layer

        # Setup activation function. Fail here rather than at the first
        # forward pass when the name is unknown.
        options = {'cos': torch.cos, 'tanh': torch.tanh, 'relu': F.relu,
                   'softplus': F.softplus, 'leaky_relu': F.leaky_relu}
        if isinstance(activation_fn, str):
            if activation_fn not in options:
                raise ValueError(
                    f"Unknown activation_fn {activation_fn!r}; "
                    f"accepted names: {sorted(options)}")
            self.activation_fn = options[activation_fn]
        elif callable(activation_fn):
            self.activation_fn = activation_fn
        else:
            raise ValueError(
                "activation_fn must be a known name or a callable, "
                f"got {activation_fn!r}")

        if b_std is None:
            b_std = W_std

        # Initialize layers (registration order is kept so state-dict keys
        # keep their order).
        self.input_layer = BNNMLP(
            input_dim, hidden_dims[0], W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance, prior_per='layer')
        self.norm_layer1 = init_norm_layer(hidden_dims[0], self.norm_layer)

        self.mid_layer = BNNMLP(
            hidden_dims[0], hidden_dims[1], W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance, prior_per='layer')
        self.norm_layer2 = init_norm_layer(hidden_dims[1], self.norm_layer)

        self.output_layer = BNNMLP(
            hidden_dims[1], output_dim, W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance, prior_per='layer')

    def forward(self, X):
        """Performs forward pass with one sampled set of weights per layer.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            torch.tensor, [batch_size, output_dim], the output data.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        hidden = X.reshape(-1, self.input_dim)

        hidden = self.activation_fn(self.norm_layer1(self.input_layer(hidden)))
        hidden = self.activation_fn(self.norm_layer2(self.mid_layer(hidden)))
        return self.output_layer(hidden)

    def sample_measurement_set(self, X, num_data, set_size=200):
        """Draws a random subset of rows of X without replacement.
        Args:
            X: torch.tensor, [>= num_data, ...], the pool to sample from.
            num_data: int, rows are drawn from the first `num_data` indices.
            set_size: int, maximum number of rows returned (default 200,
                the previously hard-coded value).
        Returns:
            torch.tensor with min(set_size, num_data) rows of X.
        """
        perm = torch.randperm(int(num_data))
        idx = perm[:int(set_size)]
        return X[idx, :]

    def sample_functions(self, X, num_sample):
        """Performs predictions using `num_sample` sets of weights.

        Same computation as `forward_eval`.

        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
            num_sample: int, the number of weight samples used to make
                predictions.
        Returns:
            torch.tensor, [num_sample, batch_size, output_dim], the output
            data.
        """
        return self.forward_eval(X, num_sample)

    def forward_eval(self, X, num_sample):
        """Performs forward pass given input data using a number of parameter samples.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
            num_sample: int, the number of parameter samples.
        Returns:
            torch.tensor, [num_sample, batch_size, output_dim], the output
            data.
        """
        # NOTE(review): with norm_layer="batchnorm" the normalization layers
        # receive 3-D activations here; BatchNorm1d treats dim 1 as the
        # channel axis for 3-D input -- confirm this is intended.
        hidden = X.reshape(-1, self.input_dim).repeat(num_sample, 1, 1)

        hidden = self.activation_fn(
            self.norm_layer1(self.input_layer.forward_eval(hidden, num_sample)))
        hidden = self.activation_fn(
            self.norm_layer2(self.mid_layer.forward_eval(hidden, num_sample)))
        return self.output_layer.forward_eval(hidden, num_sample)

    def save(self, checkpoint_path):
        """Writes the state dict to `checkpoint_path`."""
        torch.save(self.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Loads a state dict from `checkpoint_path`, mapping tensors to CPU storage."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))

