import torch.nn as nn
import torch
import numpy as np
from numpy.random import gamma
from torch.optim import Optimizer
from torch.utils.data import TensorDataset, DataLoader
import itertools
import os
import copy
from itertools import repeat, islice, cycle
import math
import torch.nn.functional as F
from torch.distributions import Gamma

class H_SA_SGHMC(Optimizer):
    """Scale-adapted Stochastic Gradient Hamiltonian Monte-Carlo sampler.

    Reproduced from https://github.com/JavierAntoran/Bayesian-Neural-Networks/blob/master/src/Stochastic_Gradient_HMC_SA/optimizers.py

    Stochastic Gradient Hamiltonian Monte-Carlo Sampler that uses scale adaption during the
    burn-in procedure to find some hyperparameters. A gaussian prior is placed over parameters
    and a Gamma hyperprior is placed over the prior's precision (stored as ``weight_decay``).

    The equation numbers quoted in the comments ("Eq 8/9/10 in [1]") are kept from the
    upstream implementation; reference [1] itself is not cited in this file.
    """

    def __init__(self, params, lr=1e-2, base_C=0.05, gauss_sig=0.1, alpha0=10, beta0=10):
        """Configure the sampler.

        Args:
            params: iterable of parameters (or parameter groups) to sample.
            lr: step size; must be non-negative.
            base_C: friction term; must be non-negative.
            gauss_sig: standard deviation of the gaussian prior. The initial
                weight decay is ``1 / gauss_sig ** 2``.
            alpha0: shape offset of the Gamma used when resampling the prior.
            beta0: rate offset of the Gamma used when resampling the prior.

        Raises:
            ValueError: if the resulting weight decay is not strictly positive
                (this includes ``gauss_sig == 0``), or if ``lr`` or ``base_C``
                is negative.
        """
        self.eps = 1e-6
        self.alpha0 = alpha0
        self.beta0 = beta0

        if gauss_sig == 0:
            self.weight_decay = 0
        else:
            self.weight_decay = 1 / (gauss_sig ** 2)

        if self.weight_decay <= 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(self.weight_decay))
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if base_C < 0:
            raise ValueError("Invalid friction term: {}".format(base_C))

        defaults = dict(
            lr=lr,
            base_C=base_C,
        )
        super(H_SA_SGHMC, self).__init__(params, defaults)

    def step(self, burn_in=False, resample_momentum=False, resample_prior=False):
        """Simulate discretized Hamiltonian dynamics for one step.

        Args:
            burn_in: when True, the running statistics ``tau``, ``g`` and
                ``V_hat`` (used to build the preconditioner) are updated.
            resample_momentum: when True, the momentum of every parameter is
                redrawn from a zero-mean gaussian before the update.
            resample_prior: when True, a new prior precision is drawn for every
                parameter tensor from a Gamma distribution and stored in the
                optimiser state.

        Returns:
            Always ``None`` (no closure is evaluated).

        Note:
            The prior term is added to ``p.grad`` in place, so after the call
            ``p.grad`` holds ``grad + weight_decay * p`` (with ``p`` taken
            before the parameter update).
        """
        loss = None

        for group in self.param_groups:  # iterate over blocks -> the ones defined in defaults. We dont use groups.
            for p in group["params"]:  # these are weight and bias matrices
                if p.grad is None:
                    continue
                state = self.state[p]  # define dict for each individual param
                if len(state) == 0:
                    state["iteration"] = 0
                    state["tau"] = torch.ones_like(p)
                    state["g"] = torch.ones_like(p)
                    state["V_hat"] = torch.ones_like(p)
                    state["v_momentum"] = torch.zeros_like(p)
                    state['weight_decay'] = self.weight_decay

                state["iteration"] += 1  # this is kind of useless now but lets keep it provisionally

                if resample_prior:
                    # Gamma draw with shape alpha and rate beta (numpy takes scale = 1 / rate).
                    alpha = self.alpha0 + p.data.nelement() / 2
                    beta = self.beta0 + (p.data ** 2).sum().item() / 2
                    gamma_sample = gamma(shape=alpha, scale=1 / (beta), size=None)
                    state['weight_decay'] = gamma_sample

                base_C, lr = group["base_C"], group["lr"]
                weight_decay = state["weight_decay"]
                tau, g, V_hat = state["tau"], state["g"], state["V_hat"]

                d_p = p.grad.data
                if weight_decay != 0:
                    # Keyword form of the scaled add; the positional
                    # ``add_(scalar, tensor)`` overload is deprecated.
                    d_p.add_(p.data, alpha=weight_decay)

                # update parameters during burn-in
                if burn_in:  # We update g first as it makes most sense
                    # specifies the moving average window, see Eq 9 in [1] left
                    tau.add_(-tau * (g ** 2) / (V_hat + self.eps) + 1)
                    tau_inv = 1. / (tau + self.eps)
                    g.add_(-tau_inv * g + tau_inv * d_p)  # average gradient see Eq 9 in [1] right
                    V_hat.add_(-tau_inv * V_hat + tau_inv * (d_p ** 2))  # gradient variance see Eq 8 in [1]

                V_sqrt = torch.sqrt(V_hat)
                V_inv_sqrt = 1. / (V_sqrt + self.eps)  # preconditioner

                if resample_momentum:  # equivalent to var = M under momentum reparametrisation
                    state["v_momentum"] = torch.normal(mean=torch.zeros_like(d_p),
                                                       std=torch.sqrt((lr ** 2) * V_inv_sqrt))
                v_momentum = state["v_momentum"]

                # The variance is clamped so that the square root stays defined
                # when the lr ** 4 correction exceeds the friction term.
                noise_var = (2. * (lr ** 2) * V_inv_sqrt * base_C - (lr ** 4))
                noise_std = torch.sqrt(torch.clamp(noise_var, min=1e-16))
                # sample random epsilon
                noise_sample = torch.normal(mean=torch.zeros_like(d_p), std=torch.ones_like(d_p) * noise_std)

                # update momentum (Eq 10 right in [1])
                v_momentum.add_(- (lr ** 2) * V_inv_sqrt * d_p - base_C * v_momentum + noise_sample)

                # update theta (Eq 10 left in [1])
                p.data.add_(v_momentum)

        return loss


    def transform_param_name(self, name):
        """Return ``name`` with every "." replaced by "@"."""
        return name.replace(".", "@")

    def inverse_transform_param_name(self, name):
        """Return ``name`` with every "@" replaced by "."."""
        return name.replace("@", ".")

    def sample_functions(self, theta, x, n_samples, log=False):
        """Evaluate ``self.model`` under ``n_samples`` randomly perturbed parameter sets.

        Each parameter is drawn as ``m + randn * exp(log_s_)`` and the model is
        called functionally with ``(theta, x)``. The outputs are stacked along
        dim 1 and a singleton dim 2 is inserted; unless ``log`` is True they are
        exponentiated before being returned.

        NOTE(review): this method reads ``self.m``, ``self.log_s_`` and
        ``self.model``, none of which is set anywhere in this class -- it
        presumably belongs to a different (model) class; confirm before use.
        """
        outputs = []

        for _ in range(n_samples):

            param = {key: self.m[self.transform_param_name(key)] + torch.randn_like(value) * torch.exp(self.log_s_[self.transform_param_name(key)])
                     for key, value in dict(self.model.named_parameters()).items()}
            
            outputs.append(torch.func.functional_call(self.model, param, (theta, x), strict=False))

        outputs = torch.stack(outputs, dim=1)
        outputs = outputs.unsqueeze(dim=2)

        if not log:
            outputs = outputs.exp()

        return outputs
        
    
class MeasurementGenerator():
    """Endless source of (theta, x) batches mixing joint and shuffled pairs.

    ``train_set`` is iterated once and then replayed forever through
    ``itertools.cycle``; every item it yields must unpack into a
    ``(theta, x)`` pair of tensors that can be concatenated along dim 0.
    """

    def __init__(self, train_set, parameters_norm_fct=None, observation_norm_fct=None):
        """Store the data iterator and the optional normalisation callables.

        Args:
            train_set: iterable yielding ``(theta, x)`` batches.
            parameters_norm_fct: optional callable applied to ``theta``.
            observation_norm_fct: optional callable applied to ``x``.
        """
        self.train_iterator = iter(train_set)
        
        self.data_iterator = cycle(self.train_iterator)

        self.parameters_norm_fct = parameters_norm_fct
        self.observation_norm_fct = observation_norm_fct

    def get(self, n_data):
        """Return ``(theta, x)`` with ``2 * ceil(n_data / 2)`` rows each.

        The first half pairs each ``x`` with its own ``theta`` (joint samples).
        The second half repeats the same ``x`` but pairs it with ``theta``
        rolled by one position along dim 0 (marginal samples).

        Raises:
            StopIteration: if ``train_set`` yielded no batch at all.
        """
        theta = None
        x = None

        # Create samples from the joint
        while x is None or theta is None or x.shape[0] < n_data/2 or theta.shape[0] < n_data/2:
            theta_cur, x_cur = next(self.data_iterator)
            
            if theta is None:
                theta = theta_cur
            else:
                theta = torch.cat((theta, theta_cur), dim=0)

            if x is None:
                x = x_cur
            else:
                x = torch.cat((x, x_cur), dim=0)

        # random permute in order to not always select the same samples in a batch.
        perm = torch.randperm(x.shape[0])
        x = x[perm, ...]
        theta = theta[perm, ...]

        # Cut to only keep the required amount of samples
        half = math.ceil(n_data/2)
        x = x[:half, ...]
        theta = theta[:half, ...]

        # Create samples from the marginal
        theta_prime = torch.roll(theta, 1, dims=0)

        theta = torch.cat((theta, theta_prime), dim=0)
        x = torch.cat((x, x), dim=0)

        if self.parameters_norm_fct:
            theta = self.parameters_norm_fct(theta)

        if self.observation_norm_fct:
            # Observations must go through their own normaliser, not the
            # parameter one.
            x = self.observation_norm_fct(x)

        return theta, x


class LipschitzFunction(nn.Module):
    """Scalar-valued critic built from Linear and Softplus layers.

    ``lipschitz_config`` must provide ``"nb_layers"`` and ``"nb_neurons"``.
    With ``nb_layers == 0`` the network is ``Linear(dim, 1) -> Softplus``.
    Otherwise it is ``Linear(dim, nb_neurons)``, followed by
    ``nb_layers - 1`` pairs of ``Linear -> Softplus``, followed by
    ``Linear(nb_neurons, 1)``.

    NOTE(review): there is no activation right after the first Linear layer,
    so the first two Linear layers compose linearly -- confirm this is intended.
    """

    def __init__(self, dim, lipschitz_config):
        super().__init__()
        depth = lipschitz_config["nb_layers"]
        width = lipschitz_config["nb_neurons"]

        if depth == 0:
            layers = [nn.Linear(dim, 1), nn.Softplus()]
        else:
            layers = [nn.Linear(dim, width)]
            for _ in range(depth - 1):
                layers.extend((nn.Linear(width, width), nn.Softplus()))
            layers.append(nn.Linear(width, 1))

        # Registered under the name "head" so parameter names stay "head.<i>.*".
        self.head = nn.ModuleList(layers)

    def forward(self, x):
        """Apply every layer of ``self.head`` in order and return the result."""
        out = x
        for module in self.head:
            out = module(out)
        return out


class LipschitzFunctionWithInputs(nn.Module):
    """Critic that scores model outputs together with their inputs.

    Every measurement is described by the concatenation of ``theta``, an
    embedding of ``x`` and the model output. Two ways of combining the
    measurements are supported, selected by ``lipschitz_config["model"]``:

    * ``"deep_set"``: each measurement goes through a shared per-measurement
      network, the resulting features are averaged over measurements and fed
      to the head.
    * ``"mlp"``: all measurements are flattened into one feature vector of
      size ``dim * measurement_size`` and fed to the head.

    ``benchmark`` must expose ``get_observable_shape()``,
    ``get_embedding_dim()``, ``get_parameter_dim()`` and
    ``get_embedding_build()``; the latter returns a callable
    ``(embedding_dim, observable_shape) -> module``.

    Raises:
        NotImplementedError: for any other ``lipschitz_config["model"]``.
    """

    def __init__(self, measurement_size, benchmark, lipschitz_config):
        super(LipschitzFunctionWithInputs, self).__init__()
        
        nb_layers = lipschitz_config["nb_layers"]
        nb_neurons = lipschitz_config["nb_neurons"]

        self.observable_shape = benchmark.get_observable_shape()
        self.embedding_dim = benchmark.get_embedding_dim()
        self.parameter_dim = benchmark.get_parameter_dim()
        output_dim = 1
        # Per-measurement feature size: embedding + parameters + one output value.
        dim = self.embedding_dim + self.parameter_dim + output_dim

        embedding_build = benchmark.get_embedding_build()
        self.embedding = embedding_build(self.embedding_dim, self.observable_shape)
        
        if lipschitz_config["model"] == "deep_set":
            deep_set = []
            if nb_layers == 0:
                deep_set.append(nn.Linear(dim, 1))
                deep_set.append(nn.Softplus())
                # The per-measurement network ends with a single feature, so
                # the head has to accept 1 input rather than nb_neurons.
                head_dim = 1
            else:
                deep_set.append(nn.Linear(dim, nb_neurons))
                for _ in range(nb_layers):
                    deep_set.append(nn.Linear(nb_neurons, nb_neurons))
                    deep_set.append(nn.Softplus())
                head_dim = nb_neurons

            self.deep_set = nn.ModuleList(deep_set)
            self.model_type = "deep_set"

        elif lipschitz_config["model"] == "mlp":
            self.model_type = "mlp"
            head_dim = dim * measurement_size

        else:
            raise NotImplementedError("Lipschitz model not implemented.")

        head = []
        if nb_layers == 0:
            head.append(nn.Linear(head_dim, 1))
            head.append(nn.Softplus())
        
        else:
            head.append(nn.Linear(head_dim, nb_neurons))
            for _ in range(nb_layers - 1):
                head.append(nn.Linear(nb_neurons, nb_neurons))
                head.append(nn.Softplus())

            head.append(nn.Linear(nb_neurons, 1))

        self.head = nn.ModuleList(head)

    def forward(self, theta, x, out):
        """Score a batch of function samples.

        Args:
            theta, x, out: per the original comment, tensors of size
                ``(n_samples, n_measurement, dim)``. They are concatenated on
                dim 2 in the order ``(theta, embedding(x), out)``.

        Returns:
            The head output, one row per sample.
        """
        n_samples = x.shape[0]
        n_measurement = x.shape[1]

        # The embedding sees each measurement as an independent row.
        x = x.view((n_samples * n_measurement, -1))
        embed = self.embedding(x)
        embed = embed.view((n_samples, n_measurement, -1))
        y = torch.cat((theta, embed, out), dim=2)
        
        if self.model_type == "deep_set":
            # Process all the measurements independently
            y = y.view((n_samples * n_measurement, -1))
            for layer in self.deep_set:
                y = layer(y)

            # Average the features of all measurements to make it permutation invariant.
            y = y.view((n_samples, n_measurement, -1))
            y = torch.mean(y, dim=1)

        elif self.model_type == "mlp":
            # Concat all measurements in a single feature vector
            y = y.view(n_samples, -1)

        for layer in self.head:
            y = layer(y)

        return y


def weights_init(m):
    """Re-initialise an ``nn.Linear`` module in place; ignore any other module.

    Reproduced from https://github.com/tranbahien/you-need-a-good-prior/blob/master/optbnn/prior_mappers/wasserstein_mapper.py

    The weight is drawn with Xavier-normal initialisation and the bias from a
    standard normal. Linear layers created with ``bias=False`` only get their
    weight re-initialised. Intended to be used with ``module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return

    torch.nn.init.xavier_normal_(m.weight.data)
    # ``bias`` is None for layers built with bias=False.
    if m.bias is not None:
        torch.nn.init.normal_(m.bias.data)


class WassersteinDistance():
    """Wasserstein-distance estimate between BNN and NP function samples.

    Adapted from https://github.com/tranbahien/you-need-a-good-prior/blob/master/optbnn/prior_mappers/wasserstein_mapper.py

    The critic is a ``LipschitzFunctionWithInputs`` trained with Adagrad. Its
    Lipschitz constraint is enforced either by a gradient penalty (``"gp"`` or
    ``"lp"``) or, when ``use_lipschitz_constraint`` is False, by clipping the
    critic parameters to ``[-0.1, 0.1]`` after every step.
    """

    def __init__(self, bnn, np_model, measurement_size, lipschitz_config,
                 benchmark, output_dim, use_lipschitz_constraint=True,
                 lipschitz_constraint_type="gp", wasserstein_lr=0.01,
                 device='cpu', gpu_np_model=True, data_generator=None,
                 penalty_coef=10):
        """Build the critic and its optimiser.

        Args:
            bnn, np_model: objects exposing ``sample_functions(theta, x, n)``.
            measurement_size: number of measurements handed to the critic; also
                the size requested from ``data_generator``.
            lipschitz_config, benchmark: forwarded to the critic constructor.
            output_dim: number of output dimensions looped over.
            use_lipschitz_constraint: use the gradient penalty (True) or
                weight clipping (False).
            lipschitz_constraint_type: ``"gp"`` or ``"lp"``.
            wasserstein_lr: Adagrad learning rate for the critic.
            device: device of the critic and of the sampled functions.
            gpu_np_model: when False, inputs are moved to CPU before calling
                ``np_model``.
            data_generator: optional object with ``get(n)`` returning
                ``(theta, x)``; when set, fresh inputs are drawn at every
                critic step.
            penalty_coef: weight of the gradient penalty.
        """
        self.bnn = bnn
        self.np_model = np_model
        self.device = device
        self.output_dim = output_dim
        self.measurement_size = measurement_size
        self.lipschitz_constraint_type = lipschitz_constraint_type
        self.data_generator = data_generator
        assert self.lipschitz_constraint_type in ["gp", "lp"]

        self.lipschitz_f = LipschitzFunctionWithInputs(measurement_size, benchmark, lipschitz_config)
        self.lipschitz_f = self.lipschitz_f.to(self.device)
        self.gpu_np_model = gpu_np_model
        self.values_log = []

        self.optimiser = torch.optim.Adagrad(self.lipschitz_f.parameters(),
                                             lr=wasserstein_lr)
        self.use_lipschitz_constraint = use_lipschitz_constraint
        self.penalty_coeff = penalty_coef

    def calculate(self, theta, x, nnet_samples, np_samples):
        """Return the critic gap ``mean f(bnn) - mean f(np)`` summed over output dims.

        ``theta`` and ``x`` are tiled once per function sample (first dimension
        of ``nnet_samples``) before being passed to the critic.
        """
        d = 0.
        theta = theta.repeat(nnet_samples.shape[0], 1, 1)
        x = x.repeat(nnet_samples.shape[0], 1, 1)
        for dim in range(self.output_dim):
            f_samples = self.lipschitz_f(theta, x, nnet_samples[:, :, dim].unsqueeze(dim=2))
            f_np = self.lipschitz_f(theta, x, np_samples[:, :, dim].unsqueeze(dim=2))
            d += torch.mean(f_samples, 0) - torch.mean(f_np, 0)
        return d

    def compute_gradient_penalty(self, theta, x, samples_p, samples_q):
        """Gradient penalty of the critic at random interpolates of two sample sets.

        Args:
            theta, x: inputs, tiled once per function sample.
            samples_p, samples_q: 2-D tensors indexed ``[sample, measurement]``.

        Returns:
            Mean penalty on the norm of the critic gradient with respect to the
            interpolated outputs: two-sided for ``"gp"``, one-sided for ``"lp"``.
        """
        theta = theta.repeat(samples_p.shape[0], 1, 1)
        x = x.repeat(samples_p.shape[0], 1, 1)

        # One interpolation coefficient per function sample.
        eps = torch.rand((samples_p.shape[0], 1)).to(samples_p.device)
        out = eps * samples_p.detach() + (1 - eps) * samples_q.detach()
        out = out.unsqueeze(dim=2)

        theta.requires_grad = True
        x.requires_grad = True
        out.requires_grad = True
        Y = self.lipschitz_f(theta, x, out)

        gradients_out = torch.autograd.grad(
            Y, out, grad_outputs=torch.ones(Y.size(), device=self.device),
            create_graph=True, retain_graph=True, only_inputs=True)[0]
        
        # Drop only the trailing singleton axis added above. A bare squeeze()
        # would also remove the sample axis when there is a single sample (or
        # a single measurement) and break the per-sample norm below.
        gradients_out = gradients_out.squeeze(dim=2)
        f_gradient_norm = gradients_out.norm(2, dim=1)

        if self.lipschitz_constraint_type == "gp":
            # Gulrajani2017, Improved Training of Wasserstein GANs
            return ((f_gradient_norm - 1) ** 2).mean()

        elif self.lipschitz_constraint_type == "lp":
            # Henning2018, On the Regularization of Wasserstein GANs
            # Eq (8) in Section 5
            return ((torch.clamp(f_gradient_norm - 1, min=0.)) ** 2).mean()

    def wasserstein_optimisation(self, theta, x, n_samples, n_steps=10, threshold=None, debug=False):
        """Train the critic for up to ``n_steps`` steps.

        Args:
            theta, x: inputs at which functions are sampled; replaced at every
                step when a ``data_generator`` was supplied.
            n_samples: number of functions drawn from each model per step.
            n_steps: maximum number of critic updates.
            threshold: when not None, training stops as soon as the critic
                gradient norm falls below this value.
            debug: when True, the distance estimate of every step is appended
                to ``self.values_log`` (padded with the last value on early
                stop so that ``n_steps`` entries are always added).

        The critic parameters have ``requires_grad`` enabled on entry and
        disabled on exit.
        """
        for p in self.lipschitz_f.parameters():
            p.requires_grad = True

        for i in range(n_steps):
            
            print("wasserstein step {}".format(i))
            if self.data_generator:
                theta, x = self.data_generator.get(self.measurement_size)

            if not self.gpu_np_model:
                x = x.to("cpu")
                theta = theta.to("cpu")

            # Draw functions from GP
            np_samples = self.np_model.sample_functions(
                theta, x, n_samples).detach().float().to(self.device)
            if self.output_dim > 1:
                np_samples = np_samples.squeeze()

            if not self.gpu_np_model:
                x = x.to(self.device)
                theta = theta.to(self.device)

            self.optimiser.zero_grad()

            # Draw functions from Bayesian Neural network
            nnet_samples = self.bnn.sample_functions(
                theta, x, n_samples).detach().float().to(self.device)
            if self.output_dim > 1:
                nnet_samples = nnet_samples.squeeze()

            #  It was of size: [n_dim, N, n_out]
            # will be of size: [N, n_dim, n_out]
            np_samples = np_samples.transpose(0, 1)
            nnet_samples = nnet_samples.transpose(0, 1)

            # The critic maximises the gap, hence the minus sign.
            objective = -self.calculate(theta, x, nnet_samples, np_samples)

            if debug:
                self.values_log.append(-objective.item())

            if self.use_lipschitz_constraint:
                penalty = 0.
                for dim in range(self.output_dim):
                    penalty += self.compute_gradient_penalty(
                        theta, x, nnet_samples[:, :, dim], np_samples[:, :, dim])
                objective += self.penalty_coeff * penalty
            objective.backward()

            if threshold is not None:
                # Gradient Norm
                params = self.lipschitz_f.parameters()
                grad_norm = torch.cat([p.grad.data.flatten() for p in params]).norm()

            self.optimiser.step()
            if not self.use_lipschitz_constraint:
                # Weight clipping stands in for the gradient penalty.
                for p in self.lipschitz_f.parameters():
                    p.data = torch.clamp(p, -.1, .1)
            if threshold is not None and grad_norm < threshold:
                print('WARNING: Grad norm (%.3f) lower than threshold (%.3f). '
                      % (float(grad_norm), threshold), end='')
                print('Stopping optimization at step %d' % (i))
                if debug:
                    ## '-1' because the last wssr value is not recorded
                    self.values_log = self.values_log + [self.values_log[-1]] * (n_steps-i-1)
                break
        for p in self.lipschitz_f.parameters():
            p.requires_grad = False


class WassersteinDistanceOld():
    """Wasserstein-distance estimate using a critic that only sees function values.

    Adapted from https://github.com/tranbahien/you-need-a-good-prior/blob/master/optbnn/prior_mappers/wasserstein_mapper.py

    The critic is a ``LipschitzFunction`` of input size ``lipschitz_f_dim``
    trained with Adagrad. The Lipschitz constraint is enforced either by a
    gradient penalty with a fixed coefficient of 10 (``"gp"`` or ``"lp"``) or,
    when ``use_lipschitz_constraint`` is False, by clipping the critic
    parameters to ``[-0.1, 0.1]`` after every step.
    """

    def __init__(self, bnn, np_model, lipschitz_f_dim, lipschitz_config, 
                 output_dim, use_lipschitz_constraint=True,
                 lipschitz_constraint_type="gp", wasserstein_lr=0.01,
                 device='cpu', gpu_np_model=True):
        """Build the critic and its optimiser.

        Args:
            bnn, np_model: objects exposing ``sample_functions(theta, x, n)``.
            lipschitz_f_dim: input size of the critic.
            lipschitz_config: forwarded to ``LipschitzFunction``.
            output_dim: number of output dimensions looped over.
            use_lipschitz_constraint: use the gradient penalty (True) or
                weight clipping (False).
            lipschitz_constraint_type: ``"gp"`` or ``"lp"``.
            wasserstein_lr: Adagrad learning rate for the critic.
            device: device of the critic and of the sampled functions.
            gpu_np_model: when False, inputs are moved to CPU before calling
                ``np_model``.
        """
        self.bnn = bnn
        self.np_model = np_model
        self.device = device
        self.output_dim = output_dim
        self.lipschitz_f_dim = lipschitz_f_dim
        self.lipschitz_constraint_type = lipschitz_constraint_type
        assert self.lipschitz_constraint_type in ["gp", "lp"]

        self.lipschitz_f = LipschitzFunction(lipschitz_f_dim, lipschitz_config)
        self.lipschitz_f = self.lipschitz_f.to(self.device)
        self.gpu_np_model = gpu_np_model
        self.values_log = []

        self.optimiser = torch.optim.Adagrad(self.lipschitz_f.parameters(),
                                             lr=wasserstein_lr)
        self.use_lipschitz_constraint = use_lipschitz_constraint
        self.penalty_coeff = 10

    def calculate(self, nnet_samples, np_samples):
        """Return the critic gap ``mean f(bnn) - mean f(np)`` summed over output dims.

        Both inputs are indexed ``[:, :, dim]`` and transposed before being
        passed to the critic, so the critic receives one row per function
        sample.
        """
        d = 0.
        for dim in range(self.output_dim):
            f_samples = self.lipschitz_f(nnet_samples[:, :, dim].T)
            f_np = self.lipschitz_f(np_samples[:, :, dim].T)
            d += torch.mean(torch.mean(f_samples, 0) - torch.mean(f_np, 0))
        return d

    def compute_gradient_penalty(self, samples_p, samples_q):
        """Gradient penalty of the critic at random interpolates of two sample sets.

        ``samples_p`` and ``samples_q`` are 2-D and are transposed before
        interpolation, so one interpolation coefficient is drawn per column of
        the inputs. Returns the two-sided penalty for ``"gp"`` and the
        one-sided penalty for ``"lp"``.
        """
        eps = torch.rand(samples_p.shape[1], 1).to(samples_p.device)
        X = eps * samples_p.t().detach() + (1 - eps) * samples_q.t().detach()
        X.requires_grad = True
        Y = self.lipschitz_f(X)
        gradients = torch.autograd.grad(
            Y, X, grad_outputs=torch.ones(Y.size(), device=self.device),
            create_graph=True, retain_graph=True, only_inputs=True)[0]
        f_gradient_norm = gradients.norm(2, dim=1)

        if self.lipschitz_constraint_type == "gp":
            # Gulrajani2017, Improved Training of Wasserstein GANs
            return ((f_gradient_norm - 1) ** 2).mean()

        elif self.lipschitz_constraint_type == "lp":
            # Henning2018, On the Regularization of Wasserstein GANs
            # Eq (8) in Section 5
            return ((torch.clamp(f_gradient_norm - 1, 0., np.inf))**2).mean()

    def wasserstein_optimisation(self, theta, x, n_samples, n_steps=10, threshold=None, debug=False):
        """Train the critic for up to ``n_steps`` steps on one bag of samples.

        Functions are sampled once from both models at ``(theta, x)`` and the
        same bag is then cycled through for every critic step.

        Args:
            theta, x: inputs at which functions are sampled.
            n_samples: number of functions drawn from each model (also the
                batch size of the internal loader).
            n_steps: maximum number of critic updates.
            threshold: when not None, training stops as soon as the critic
                gradient norm falls below this value.
            debug: when True, the distance estimate of every step is appended
                to ``self.values_log`` (padded with the last value on early
                stop so that ``n_steps`` entries are always added).

        The critic parameters have ``requires_grad`` enabled on entry and
        disabled on exit.
        """
        for p in self.lipschitz_f.parameters():
            p.requires_grad = True

        n_samples_bag = n_samples * 1
        if not self.gpu_np_model:
            x = x.to("cpu")
            theta = theta.to("cpu")

        # Draw functions from GP
        np_samples_bag = self.np_model.sample_functions(
            theta, x, n_samples_bag).detach().float().to(self.device)
        if self.output_dim > 1:
            np_samples_bag = np_samples_bag.squeeze()

        if not self.gpu_np_model:
            x = x.to(self.device)
            theta = theta.to(self.device)

        # Draw functions from Bayesian Neural network
        nnet_samples_bag = self.bnn.sample_functions(
            theta, x, n_samples_bag).detach().float().to(self.device)
        if self.output_dim > 1:
            nnet_samples_bag = nnet_samples_bag.squeeze()

        #  It was of size: [n_dim, N, n_out]
        # will be of size: [N, n_dim, n_out]
        np_samples_bag = np_samples_bag.transpose(0, 1)
        nnet_samples_bag = nnet_samples_bag.transpose(0, 1)
        dataset = TensorDataset(np_samples_bag, nnet_samples_bag)
        data_loader = DataLoader(dataset, batch_size=n_samples, num_workers=0)
        batch_generator = itertools.cycle(data_loader)

        for i in range(n_steps):
            np_samples, nnet_samples = next(batch_generator)
            #         was of size: [N, n_dim, n_out]
            # needs to be of size: [n_dim, N, n_out]
            np_samples = np_samples.transpose(0, 1)
            nnet_samples = nnet_samples.transpose(0, 1)

            self.optimiser.zero_grad()
            # The critic maximises the gap, hence the minus sign.
            objective = -self.calculate(nnet_samples, np_samples)
            if debug:
                self.values_log.append(-objective.item())

            if self.use_lipschitz_constraint:
                penalty = 0.
                for dim in range(self.output_dim):
                    penalty += self.compute_gradient_penalty(
                        nnet_samples[:, :, dim], np_samples[:, :, dim])
                objective += self.penalty_coeff * penalty
            objective.backward()

            if threshold is not None:
                # Gradient Norm
                params = self.lipschitz_f.parameters()
                grad_norm = torch.cat([p.grad.data.flatten() for p in params]).norm()

            self.optimiser.step()
            if not self.use_lipschitz_constraint:
                # Weight clipping stands in for the gradient penalty.
                for p in self.lipschitz_f.parameters():
                    p.data = torch.clamp(p, -.1, .1)
            if threshold is not None and grad_norm < threshold:
                print('WARNING: Grad norm (%.3f) lower than threshold (%.3f). '
                      % (float(grad_norm), threshold), end='')
                print('Stopping optimization at step %d' % (i))
                if debug:
                    ## '-1' because the last wssr value is not recorded
                    self.values_log = self.values_log + [self.values_log[-1]] * (n_steps-i-1)
                break
        for p in self.lipschitz_f.parameters():
            p.requires_grad = False


class MapperWasserstein(object):
    """Tune a BNN prior so that its function samples match those of ``np_model``.

    Adapted from https://github.com/tranbahien/you-need-a-good-prior/blob/master/optbnn/prior_mappers/wasserstein_mapper.py

    Each outer iteration trains a critic (``self.wasserstein``) and then takes
    one RMSprop step on the BNN parameters to reduce the estimated Wasserstein
    distance. An MMD estimate is reported alongside for monitoring.
    """

    def __init__(self, np_model, bnn, data_generator, out_dir,
                 lipschitz_config, output_dim=1, n_data=256,
                 wasserstein_steps=(200, 200), wasserstein_lr=0.01, wasserstein_thres=0.01, 
                 logger=None, gpu_np_model=False, lipschitz_constraint_type="gp", device="cpu",
                 distance_version="old", benchmark=None, discriminator_generator=None,
                 penalty_coef=10):
        """Store the configuration and build the distance module.

        Args:
            np_model, bnn: models exposing ``sample_functions(theta, x, n)``
                and ``.to(device)``.
            data_generator: object with ``get(n_data)`` returning ``(theta, x)``.
            out_dir: directory where debug values are written by ``optimize``.
            lipschitz_config: critic configuration.
            output_dim: number of output dimensions.
            n_data: number of measurement points requested per iteration.
            wasserstein_steps: number of critic steps; a scalar is expanded to
                a pair. Only the second entry is read by ``optimize``.
            wasserstein_lr: critic learning rate.
            wasserstein_thres: gradient-norm threshold for stopping the critic
                early.
            logger: optional logger; its ``info`` method is used instead of
                ``print``.
            gpu_np_model: whether ``np_model`` runs on ``device``; otherwise
                its inputs are moved to CPU.
            lipschitz_constraint_type: ``"gp"`` or ``"lp"``.
            device: device for the BNN, the critic and the samples.
            distance_version: ``"old"`` (``WassersteinDistanceOld``) or
                ``"new"`` (``WassersteinDistance``).
            benchmark, discriminator_generator, penalty_coef: only used by the
                ``"new"`` distance.

        Raises:
            NotImplementedError: for an unknown ``distance_version``.
        """
        self.np_model = np_model
        self.bnn = bnn
        self.data_generator = data_generator
        self.n_data = n_data
        self.output_dim = output_dim
        self.out_dir = out_dir
        self.device = device
        self.gpu_np_model = gpu_np_model
        self.distance_version = distance_version

        assert lipschitz_constraint_type in ["gp", "lp"]
        self.lipschitz_constraint_type = lipschitz_constraint_type

        if not isinstance(wasserstein_steps, (list, tuple)):
            wasserstein_steps = (wasserstein_steps, wasserstein_steps)
        self.wasserstein_steps = wasserstein_steps
        self.wasserstein_threshold = wasserstein_thres

        # Move models to configured device
        if gpu_np_model:
            self.np_model = self.np_model.to(self.device)
        self.bnn = self.bnn.to(self.device)

        # Initialize the module of wasserstein distance
        if distance_version == "old":
            self.wasserstein = WassersteinDistanceOld(
                self.bnn, self.np_model,
                self.n_data, lipschitz_config, 
                output_dim=self.output_dim, wasserstein_lr=wasserstein_lr, 
                device=self.device, gpu_np_model=self.gpu_np_model,
                lipschitz_constraint_type=self.lipschitz_constraint_type)
            
        elif distance_version == "new":
            self.wasserstein = WassersteinDistance(
                self.bnn, self.np_model,self.n_data,
                lipschitz_config, benchmark,
                output_dim=self.output_dim, wasserstein_lr=wasserstein_lr, 
                device=self.device, gpu_np_model=self.gpu_np_model,
                lipschitz_constraint_type=self.lipschitz_constraint_type,
                data_generator=discriminator_generator, penalty_coef=penalty_coef)

        else:
            raise NotImplementedError("Distance version does not exist.")

        # Setup logger
        self.print_info = print if logger is None else logger.info


    def optimize(self, num_iters, n_samples=128, lr=1e-2,
                 print_every=1, debug=False, restart_lipschitz=False):
        """Run the prior-matching loop.

        Args:
            num_iters: number of outer iterations (one BNN update each).
            n_samples: number of functions drawn from each model per iteration.
            lr: RMSprop learning rate for the BNN parameters.
            print_every: reporting period, in iterations (the first iteration
                is always reported).
            debug: when True, the critic's intermediate distance values are
                logged and saved to ``wsr_intermediate_values.log`` in
                ``self.out_dir``.
            restart_lipschitz: when True, the critic is re-initialised at
                every iteration instead of only once.

        Returns:
            List with the Wasserstein distance estimate of every iteration.
        """
        wdist_hist = []

        wasserstein_steps = self.wasserstein_steps
        prior_optimizer = torch.optim.RMSprop(self.bnn.parameters(), lr=lr)

        ## Initialisation of lipschitz_f
        self.wasserstein.lipschitz_f.apply(weights_init)

        # Prior loop
        for it in range(1, num_iters+1):
            # Draw X
            theta, x = self.data_generator.get(self.n_data)
            x = x.to(self.device)
            theta = theta.to(self.device)
            if not self.gpu_np_model:
                x = x.to("cpu")
                theta = theta.to("cpu")

            # Draw functions from NP model
            np_samples = self.np_model.sample_functions(
                theta, x, n_samples).detach().float().to(self.device)
            if self.output_dim > 1:
                np_samples = np_samples.squeeze()

            if not self.gpu_np_model:
                x = x.to(self.device)
                theta = theta.to(self.device)

            # Draw functions from BNN (kept attached: gradients flow to the prior)
            nnet_samples = self.bnn.sample_functions(
                theta, x, n_samples).float().to(self.device)
            if self.output_dim > 1:
                nnet_samples = nnet_samples.squeeze()

            if restart_lipschitz:
                ## Initialisation of lipschitz_f
                self.wasserstein.lipschitz_f.apply(weights_init)

            # Optimisation of lipschitz_f
            self.wasserstein.wasserstein_optimisation(theta, x, 
                n_samples, n_steps=wasserstein_steps[1],
                threshold=self.wasserstein_threshold, debug=debug)
            prior_optimizer.zero_grad()


            if self.distance_version == "old":
                wdist = self.wasserstein.calculate(nnet_samples, np_samples)
                np_samples = np_samples.transpose(0, 1)
                nnet_samples = nnet_samples.transpose(0, 1)
            elif self.distance_version == "new":
                #  It was of size: [n_dim, N, n_out]
                # will be of size: [N, n_dim, n_out]
                np_samples = np_samples.transpose(0, 1)
                nnet_samples = nnet_samples.transpose(0, 1)
                wdist = self.wasserstein.calculate(theta, x, nnet_samples, np_samples)

            wdist.backward()
            prior_optimizer.step()

            # Both branches above leave the samples indexed by sample first,
            # which is the layout the MMD helper expects.
            with torch.no_grad():
                mmd = self.compute_mmd_given_samples(nnet_samples, np_samples)

            wdist_hist.append(float(wdist))
            if (it % print_every == 0) or it == 1:
                self.print_info(">>> Iteration # {:3d}: "
                                "Wasserstein Dist {:.4f} "
                                "MMD {:.4f}".format(
                                    it, float(wdist), float(mmd)))


        # Save accumulated list of intermediate wasserstein values
        if debug:
            values = np.array(self.wasserstein.values_log).reshape(-1, 1)
            path = os.path.join(self.out_dir, "wsr_intermediate_values.log")
            np.savetxt(path, values, fmt='%.6e')
            self.print_info('Saved intermediate wasserstein values in: ' + path)

        return wdist_hist
    
    def compute_mmd_given_samples(self, nnet_samples, np_samples):
        """Estimate the (squared) MMD between two sample sets with an RBF kernel.

        Args:
            nnet_samples, np_samples: tensors indexed by sample on dim 0 and by
                measurement on dim 1.

        Each set is split into two halves of ``floor(n / 2)`` samples (a
        trailing odd sample is ignored) and the kernel is evaluated on paired
        rows. The lengthscale is ``sqrt(number of measurements)``.

        Returns:
            Python float ``E[k(np, np')] + E[k(nn, nn')] - 2 E[k(nn, np)]``.
        """
        number_samples = math.floor(nnet_samples.shape[0]/2)

        # Both halves must have the same length for the paired kernel below.
        nnet_samples_1 = nnet_samples[:number_samples, ...]
        nnet_samples_2 = nnet_samples[number_samples:2 * number_samples, ...]
        np_samples_1 = np_samples[:number_samples, ...]
        np_samples_2 = np_samples[number_samples:2 * number_samples, ...]

        lengthscale = math.sqrt(nnet_samples_1.shape[1])

        def rbf_kernel(p_samples, q_samples):
            # k(p, q) = exp(-||p - q||^2 / (2 l^2)), reduced over the measurement axis.
            sq_dist = torch.sum((p_samples - q_samples)**2, dim=1)
            return torch.exp(-sq_dist / (2 * lengthscale**2))
        
        term_1 = torch.mean(rbf_kernel(np_samples_1, np_samples_2)).item()
        term_2 = torch.mean(rbf_kernel(nnet_samples_1, nnet_samples_2)).item()
        term_3 = torch.mean(rbf_kernel(nnet_samples_1, np_samples_1)).item()

        return term_1 + term_2 - 2 * term_3

    def compute_mmd(self, theta, x, n_samples):
        """Sample ``n_samples`` functions from both models at ``(theta, x)`` and return their MMD.

        Runs entirely under ``torch.no_grad()``.
        """
        with torch.no_grad():
            x = x.to(self.device)
            theta = theta.to(self.device)
            if not self.gpu_np_model:
                x = x.to("cpu")
                theta = theta.to("cpu")

            np_samples = self.np_model.sample_functions(
                theta, x, n_samples).detach().float().to(self.device)
            if self.output_dim > 1:
                np_samples = np_samples.squeeze()

            if not self.gpu_np_model:
                x = x.to(self.device)
                theta = theta.to(self.device)

            nnet_samples = self.bnn.sample_functions(
                theta, x, n_samples).float().to(self.device)
            if self.output_dim > 1:
                nnet_samples = nnet_samples.squeeze()

            # Same layout change as in optimize(): put the sample axis first
            # before handing the tensors to the MMD helper.
            np_samples = np_samples.transpose(0, 1)
            nnet_samples = nnet_samples.transpose(0, 1)

            return self.compute_mmd_given_samples(nnet_samples, np_samples)


def freeze_model(model: nn.Module):
    """Disable gradient tracking for every parameter of ``model`` (in place)."""
    for parameter in model.parameters():
        parameter.requires_grad_(False)