
import warnings
warnings.simplefilter("ignore", UserWarning)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical
from torch.utils.data import TensorDataset, DataLoader

import gpytorch


import itertools
import numpy as np
import math

from utils.spectral import SpectralScoreEstimator

def weights_init(m):
    """Initialise an ``nn.Linear`` module in place; any other module is left untouched.

    Weights get Xavier-normal initialisation and biases a standard-normal
    initialisation. Suitable for ``model.apply(weights_init)``.

    Args:
        m: nn.Module, the module to (possibly) initialise.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)
    # nn.Linear(..., bias=False) stores ``bias = None``; skip it instead of crashing.
    if m.bias is not None:
        nn.init.normal_(m.bias)

def init_norm_layer(input_dim, norm_layer):
    """Build the normalisation module applied after a hidden layer.

    Args:
        input_dim: int, number of features to normalise.
        norm_layer: ``"batchnorm"`` for a non-affine ``BatchNorm1d`` or
            ``None`` for ``nn.Identity``.

    Returns:
        nn.Module, the normalisation layer.

    Raises:
        ValueError: if ``norm_layer`` is any other value, so a typo fails at
            construction time rather than later inside ``forward``.
    """
    if norm_layer is None:
        return nn.Identity()
    if norm_layer == "batchnorm":
        # track_running_stats=False: always normalise with the current batch's
        # statistics; affine=False: no learnable scale/shift.
        return nn.BatchNorm1d(input_dim, eps=0, momentum=None,
                              affine=False, track_running_stats=False)
    raise ValueError(
        f"Unsupported norm_layer {norm_layer!r}; expected 'batchnorm' or None")


class BNNMLP(nn.Module):
    """Fully connected layer with a factorised Gaussian variational posterior.

    Every weight/bias has a learnable mean (``W_mu``/``b_mu``) and a raw
    scale (``W_std``/``b_std``); the standard deviation actually used is
    ``softplus(raw)``. Weights and biases are re-sampled with the
    reparameterisation trick on every forward call. A fixed prior (mean
    0.001, raw scale 1.0) is stored for ``kld`` / ``wd``.
    """

    def __init__(self, n_in, n_out, W_mu=None, b_mu=None, W_std=None,
                 b_std=None, scaled_variance=False, prior_per='parameter'):
        """Initialization.

        Args:
            n_in: int, the size of the input data.
            n_out: int, the size of the output.
            W_mu, b_mu, W_std, b_std: accepted for interface compatibility
                only. The variational parameters are always initialised
                randomly (means ~ U(-1/sqrt(n_in), 1/sqrt(n_in)), raw scales
                ~ N(1, 1)) regardless of these values.
            scaled_variance: bool, if True every sampled weight matrix is
                divided by sqrt(n_in).
            prior_per: str, `parameter` gives every weight and bias its own
                variational parameters and prior; `layer` shares one scalar
                across the layer (one for weights, one for biases).

        Raises:
            ValueError: if `prior_per` is neither `parameter` nor `layer`.
        """
        super(BNNMLP, self).__init__()

        self.n_in = n_in
        self.n_out = n_out
        self.scaled_variance = scaled_variance

        if prior_per == "layer":
            # Real 1-tuples are required: a bare ``(1)`` is the int 1, which
            # ``Distribution.sample`` rejects as a sample shape.
            W_shape, b_shape = (1,), (1,)
        elif prior_per == "parameter":
            W_shape, b_shape = (self.n_in, self.n_out), (self.n_out,)
        else:
            raise ValueError("Accepted values: `parameter` or `layer`")

        bound = 1. / math.sqrt(self.n_in)
        uniform = torch.distributions.uniform.Uniform(
            torch.tensor([-bound]), torch.tensor([bound]))
        # The distribution has batch shape [1]; drop that trailing dimension.
        self.W_mu = nn.Parameter(uniform.sample(W_shape).squeeze(-1), requires_grad=True)
        self.b_mu = nn.Parameter(uniform.sample(b_shape).squeeze(-1), requires_grad=True)

        # Raw (pre-softplus) scales ~ N(1, 1).
        self.W_std = nn.Parameter(torch.randn(W_shape) + 1, requires_grad=True)
        self.b_std = nn.Parameter(torch.randn(b_shape) + 1, requires_grad=True)

        # Fixed prior, stored as plain tensors (not buffers, so they are not in
        # the state_dict); they are moved to the parameters' device lazily.
        # The raw prior scale 1.0 is passed through softplus in kld/wd.
        self.W_mu_prior = torch.zeros(W_shape) + 0.001
        self.b_mu_prior = torch.zeros(b_shape) + 0.001
        self.W_std_prior = torch.zeros(W_shape) + 1.0
        self.b_std_prior = torch.zeros(b_shape) + 1.0

    def _sample_params(self):
        """Draw one (W, b) pair: W is [n_in, n_out], b is [n_out]."""
        W = self.W_mu + F.softplus(self.W_std) * \
            torch.randn((self.n_in, self.n_out), device=self.W_std.device)
        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)
        b = self.b_mu + F.softplus(self.b_std) * \
            torch.randn((self.n_out), device=self.b_std.device)
        return W, b

    def forward(self, X):
        """Performs a forward pass with one fresh weight sample.

        Args:
            X: torch.tensor, [batch_size, n_in], the input data.
        Returns:
            output: torch.tensor, [batch_size, n_out], the output data.
        """
        W, b = self._sample_params()
        return torch.mm(X, W) + b

    def forward_norm(self, X):
        """Forward pass that also returns the L1 norm of the sampled weights.

        Args:
            X: torch.tensor, [batch_size, n_in], the input data.
        Returns:
            output: torch.tensor, [batch_size, n_out], the output data.
            sumw: scalar tensor, sum of |W| over the sampled (and, if
                scaled_variance, rescaled) weight matrix; biases excluded.
        """
        W, b = self._sample_params()
        output = torch.mm(X, W) + b
        sumw = torch.sum(W.abs())
        return output, sumw

    def forward_eval(self, X, num_sample):
        """Forward pass with `num_sample` independent weight AND bias samples.

        Args:
            X: torch.tensor, [num_sample, batch_size, n_in], the input data
                (one slice per sample, as required by ``torch.bmm``).
            num_sample: int, number of parameter samples.
        Returns:
            output: torch.tensor, [num_sample, batch_size, n_out].
        """
        # Broadcasting the [n_in, n_out] (or [1]) parameters against the noise
        # gives one independent weight matrix per sample.
        W = self.W_mu + F.softplus(self.W_std) * \
            torch.randn((num_sample, self.n_in, self.n_out), device=self.W_std.device)
        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)
        # One bias vector per sample (shape [num_sample, 1, n_out]) so that
        # each function sample uses its own complete parameter draw.
        b = self.b_mu + F.softplus(self.b_std) * \
            torch.randn((num_sample, 1, self.n_out), device=self.b_std.device)

        return torch.bmm(X, W) + b

    def _sync_prior_devices(self):
        """Move the stored prior tensors to the device of the parameters."""
        self.W_std_prior = self.W_std_prior.to(self.W_std.device)
        self.W_mu_prior = self.W_mu_prior.to(self.W_mu.device)
        self.b_std_prior = self.b_std_prior.to(self.b_std.device)
        self.b_mu_prior = self.b_mu_prior.to(self.b_mu.device)

    def kld(self):
        """Mean (over elements) Gaussian KL(q || prior) for weights plus biases.

        Uses log(s_p/s_q) + (s_q^2 + (m_q - m_p)^2) / (2 s_p^2) with
        s = softplus(raw scale). The closed-form constant -1/2 is omitted, so
        the value is offset by +1.0 in total (0.5 per term); gradients are
        unaffected.
        """
        self._sync_prior_devices()

        kl_mu = torch.mean(torch.log(F.softplus(self.W_std_prior)) - torch.log(F.softplus(self.W_std))
                    + (torch.pow(F.softplus(self.W_std), 2) + torch.pow((self.W_mu - self.W_mu_prior), 2)) / (
                    2 * torch.pow(F.softplus(self.W_std_prior), 2)) )
        kl_b = torch.mean(torch.log(F.softplus(self.b_std_prior)) - torch.log(F.softplus(self.b_std))
                    + (torch.pow(F.softplus(self.b_std), 2) + torch.pow((self.b_mu - self.b_mu_prior), 2)) / (
                    2 * torch.pow(F.softplus(self.b_std_prior), 2)))

        return kl_mu + kl_b

    def wd(self):
        """Mean (over elements) 2-Wasserstein distance between q and the prior.

        For 1-D Gaussians this is sqrt((m_q - m_p)^2 + (s_q - s_p)^2). The
        +0.001 offsets presumably keep the sqrt differentiable when q equals
        the prior — TODO confirm intent.
        """
        self._sync_prior_devices()

        wdist_w = torch.mean(torch.pow(torch.pow((self.W_mu - self.W_mu_prior + 0.001), 2) + torch.pow(
            (F.softplus(self.W_std) - F.softplus(self.W_std_prior) + 0.001), 2), 0.5) )
        wdist_b = torch.mean(torch.pow(torch.pow((self.b_mu - self.b_mu_prior + 0.001), 2) + torch.pow(
            (F.softplus(self.b_std) - F.softplus(self.b_std_prior) + 0.001), 2), 0.5) )

        return wdist_w + wdist_b


class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP with a linear mean function and a scaled RBF covariance.

    Args:
        train_x, train_y: training inputs and targets, forwarded to ExactGP.
        likelihood: gpytorch likelihood, forwarded to ExactGP.
        input_dim: int, input dimensionality for ``LinearMean``.
    """

    def __init__(self, train_x, train_y, likelihood, input_dim):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.LinearMean(input_dim)
        rbf = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf)

    def forward(self, x):
        """Build the MultivariateNormal given by the mean and kernel at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x))

class ExactGPModel2(gpytorch.models.ExactGP):
    """Exact GP with a zero mean function and a scaled RBF covariance.

    Args:
        train_x, train_y: training inputs and targets, forwarded to ExactGP.
        likelihood: gpytorch likelihood, forwarded to ExactGP.
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        rbf = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf)

    def forward(self, x):
        """Build the MultivariateNormal given by the mean and kernel at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x))

class MultitaskGPModel(gpytorch.models.ExactGP):
    """Exact multitask GP: shared periodic kernel with a rank-1 task covariance.

    Args:
        train_x, train_y: training inputs and targets, forwarded to ExactGP.
        likelihood: gpytorch likelihood, forwarded to ExactGP (a multitask
            likelihood is presumably expected — TODO confirm at call sites).
        num_tasks: int, number of output tasks (default 2).
    """

    # The constructor was previously misspelled ``__int__``, so ExactGP's own
    # __init__ ran and mean_module/covar_module were never created.
    def __init__(self, train_x, train_y, likelihood, num_tasks=2):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=num_tasks)
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.PeriodicKernel(), num_tasks=num_tasks, rank=1)

    def forward(self, x):
        """Build the MultitaskMultivariateNormal at ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

class BatchIndependentMultitaskGPModel(gpytorch.models.ExactGP):
    """Two independent GPs, one per task, evaluated as a batch of size 2.

    Each task has its own constant mean and scaled RBF kernel
    hyperparameters; the batched MVN is re-packed as a multitask MVN.
    """

    def __init__(self, train_x, train_y, likelihood):
        super(BatchIndependentMultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        task_batch = torch.Size([2])
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=task_batch)
        rbf = gpytorch.kernels.RBFKernel(batch_shape=task_batch)
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf, batch_shape=task_batch)

    def forward(self, x):
        """Return a MultitaskMultivariateNormal built from the batched per-task MVN."""
        batched = gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x))
        return gpytorch.distributions.MultitaskMultivariateNormal.from_batch_mvn(batched)


class fBNN(nn.Module):
    """Three-layer Bayesian MLP (input -> hidden -> hidden -> output).

    Every linear layer is a ``BNNMLP``, so each forward pass uses freshly
    sampled weights. The ``fkl*`` methods compute a function-space KL
    surrogate between this network and a prior, using a
    ``SpectralScoreEstimator`` on function samples taken at a random
    measurement set.
    """

    def __init__(self, input_dim, output_dim, hidden_dims, activation_fn, has_continuous_action_space,
                 W_mu=None, b_mu=None, W_std=None, b_std=None, scaled_variance=False, norm_layer=None):
        """Initialization.

        Args:
            input_dim: int, the size of the input data.
            output_dim: int, the size of the output data.
            hidden_dims: sequence of int; only the first two entries are used
                (sizes of the two hidden layers).
            activation_fn: str, one of 'cos', 'tanh', 'relu', 'softplus',
                'leaky_relu'; any other value is used as-is (expected to be
                a callable).
            has_continuous_action_space: bool, stored on the instance; not
                used inside this class.
            W_mu, b_mu, W_std, b_std: forwarded to every ``BNNMLP`` layer;
                ``b_std`` defaults to ``W_std``. In the current ``BNNMLP``
                these values do not affect the (random) parameter init.
            scaled_variance: bool, forwarded to every ``BNNMLP`` layer.
            norm_layer: None or "batchnorm", passed to ``init_norm_layer``
                for both hidden layers.
        """
        super(fBNN, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.norm_layer = norm_layer

        # Hyperparameters of the SpectralScoreEstimator used in fkl*.
        self.eta = 0.1
        self.n_eigen_threshold = 2

        options = {'cos': torch.cos, 'tanh': torch.tanh, 'relu': F.relu,
                   'softplus': F.softplus, 'leaky_relu': F.leaky_relu}
        self.activation_fn = options.get(activation_fn, activation_fn) \
            if isinstance(activation_fn, str) else activation_fn

        if b_std is None:
            b_std = W_std

        # Layers (registration order kept stable for state_dict keys).
        self.input_layer = BNNMLP(
            input_dim, hidden_dims[0], W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)
        self.norm_layer1 = init_norm_layer(hidden_dims[0], self.norm_layer)

        self.mid_layer = BNNMLP(
            hidden_dims[0], hidden_dims[1], W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)
        self.norm_layer2 = init_norm_layer(hidden_dims[1], self.norm_layer)

        self.output_layer = BNNMLP(
            hidden_dims[1], output_dim, W_mu, b_mu, W_std, b_std,
            scaled_variance=scaled_variance)

    def forward(self, X):
        """Single stochastic forward pass (one weight sample per layer).

        Args:
            X: torch.tensor, reshaped to [-1, input_dim].
        Returns:
            torch.tensor, [N, output_dim].
        """
        X = X.view(-1, self.input_dim)

        X = self.activation_fn(self.norm_layer1(self.input_layer(X)))
        X = self.activation_fn(self.norm_layer2(self.mid_layer(X)))
        return self.output_layer(X)

    def forward_norm(self, X):
        """Forward pass that also returns the summed L1 norm of all sampled weights.

        Args:
            X: torch.tensor, reshaped to [-1, input_dim].
        Returns:
            (torch.tensor [N, output_dim], scalar tensor) — output and the sum
            of |W| over the three sampled weight matrices.
        """
        X = X.view(-1, self.input_dim)

        X, sumw1 = self.input_layer.forward_norm(X)
        X = self.activation_fn(self.norm_layer1(X))

        X, sumw2 = self.mid_layer.forward_norm(X)
        X = self.activation_fn(self.norm_layer2(X))

        X, sumw3 = self.output_layer.forward_norm(X)

        return X, sumw1 + sumw2 + sumw3

    def forward_eval(self, X, num_sample):
        """Forward pass using `num_sample` independent parameter samples.

        Args:
            X: torch.tensor, reshaped to [-1, input_dim] and repeated to
                [num_sample, N, input_dim].
            num_sample: int, the number of parameter samples.
        Returns:
            torch.tensor, [num_sample, N, output_dim].

        NOTE(review): with norm_layer="batchnorm", BatchNorm1d receives a 3-D
        [num_sample, N, hidden] tensor and treats dim 1 as channels, which
        only matches when N == hidden; likely intended for norm_layer=None.
        """
        X = X.view(-1, self.input_dim).repeat(num_sample, 1, 1)

        X = self.activation_fn(self.norm_layer1(self.input_layer.forward_eval(X, num_sample)))
        X = self.activation_fn(self.norm_layer2(self.mid_layer.forward_eval(X, num_sample)))
        return self.output_layer.forward_eval(X, num_sample)

    def sample_functions(self, X, num_sample):
        """Evaluate `num_sample` sampled functions at X.

        Identical to ``forward_eval``.

        Returns:
            torch.tensor, [num_sample, N, output_dim].
        """
        return self.forward_eval(X, num_sample)

    def _spectral_fkl(self, func_samples, prior_samples):
        """Spectral-Stein function-space KL surrogate from two sample sets.

        The score estimates are detached, so the returned scalar is a
        surrogate whose gradient w.r.t. this network's parameters flows only
        through `func_samples`; its value is not the KL itself.
        """
        estimator = SpectralScoreEstimator(eta=self.eta, n_eigen_threshold=self.n_eigen_threshold)

        # Entropy term: estimated score of q(f) at the network's own samples.
        dlog_q = estimator.compute_gradients(func_samples)
        entropy_sur = torch.mean(torch.sum(dlog_q.detach() * func_samples, -1))

        # Cross-entropy term: prior score estimated from prior samples,
        # evaluated at the network's samples.
        cross_entropy_gradients = estimator.compute_gradients(prior_samples, func_samples)
        cross_entropy_sur = torch.mean(torch.sum(cross_entropy_gradients.detach() * func_samples, -1))

        return entropy_sur - cross_entropy_sur

    def fkl(self, Xall, num_data, prior, num_sample):
        """Function-space KL surrogate against a prior's marginal distribution.

        Args:
            Xall: torch.tensor, candidate measurement points (indexed by row).
            num_data: int, measurement points are drawn from the first
                `num_data` rows of Xall.
            prior: callable returning an object with `.mean` and
                `.covariance_matrix` (presumably a gpytorch model — TODO
                confirm). Only the diagonal of the covariance is used, i.e.
                prior values are sampled independently per point.
            num_sample: int, number of function samples from each side.
        Returns:
            scalar tensor surrogate.
        """
        measurement_set = self.sample_measurement_set(Xall, num_data)
        # [N, num_sample, output_dim]
        func_samples = self.sample_functions(measurement_set, num_sample).transpose(0, 1)

        prior_marginal = prior(measurement_set)
        mean_prior = prior_marginal.mean[:, None].repeat(num_sample, 1, 1)

        # The covariance diagonal holds variances; the noise must be scaled by
        # the standard deviation (the previous code multiplied by the variance).
        var_prior = torch.diag(prior_marginal.covariance_matrix)
        std_prior = var_prior.clamp_min(0).sqrt()[:, None].repeat(num_sample, 1, 1)

        # [N, num_sample, 1]
        prior_samples = (mean_prior + std_prior * torch.randn_like(mean_prior)).transpose(0, 1)

        return self._spectral_fkl(func_samples, prior_samples)

    def fkl3(self, Xall, num_data, prior, num_sample):
        """Function-space KL surrogate against a prior that can sample functions.

        Args:
            Xall, num_data, num_sample: as in ``fkl``.
            prior: object with ``sample_functions(X, num_sample)`` (e.g.
                another fBNN — assumption, confirm at call sites).
        Returns:
            scalar tensor surrogate.
        """
        measurement_set = self.sample_measurement_set(Xall, num_data)
        func_samples = self.sample_functions(measurement_set, num_sample).transpose(0, 1)
        prior_samples = prior.sample_functions(measurement_set, num_sample).transpose(0, 1)

        return self._spectral_fkl(func_samples, prior_samples)

    def fkl2(self, Xall, num_data, prior, num_sample):
        """Function-space KL surrogate using samples from ``prior(measurement_set)``.

        NOTE(review): the prior is always sampled 128 times regardless of
        `num_sample`, and the network samples are NOT transposed here unlike
        in ``fkl``/``fkl3`` — confirm both are intended.
        """
        measurement_set = self.sample_measurement_set(Xall, num_data)
        func_samples = self.sample_functions(measurement_set, num_sample)

        prior_samples = prior(measurement_set).sample(torch.Size((128,)))
        prior_samples = torch.transpose(prior_samples, -1, 1)

        return self._spectral_fkl(func_samples, prior_samples)

    def sample_measurement_set(self, X, num_data, num_measurement=40):
        """Pick `num_measurement` random rows from the first `num_data` rows of X.

        Args:
            X: torch.tensor indexed as ``X[idx, :]``.
            num_data: number of candidate rows (converted with ``int``).
            num_measurement: int, size of the measurement set (default 40);
                fewer rows are returned when ``num_data < num_measurement``.
        Returns:
            torch.tensor, the selected rows of X.
        """
        perm = torch.randperm(int(num_data))
        idx = perm[:num_measurement]
        return X[idx, :]

    def save(self, checkpoint_path):
        """Save the state_dict to `checkpoint_path`."""
        torch.save(self.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load a state_dict from `checkpoint_path`, mapping tensors to CPU storage.

        torch.load unpickles the file — only load trusted checkpoints.
        """
        self.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
