import math
import os
import shutil

import numpy as np
import torch
import torch.nn as nn
from lampe.inference import NPE
from lampe.inference import NPELoss
from tqdm import tqdm

from .base import Model
from .base import ModelFactory
from .np_priors import GPPrior
from .prior_mappers import DistanceBasedPriorMapper
from .prior_mappers import DatasetMeasurementGenerator, UniformMeasurementGenerator, HybridMeasurementGenerator
from .variational_distributions import GaussianNNParametersDistribution
from .variational_distributions import HierarchicalGaussianNNParametersDistribution
from .bayesian_methods.hmc import HMCmodel
from .bayesian_methods.vi import VImodel


import wandb


class BayesianNPEFactory(ModelFactory):
    """Factory producing Bayesian NPE models for a given simulation budget.

    Several hyper-parameters in ``config`` are stored as one value per entry
    of ``config["simulation_budgets"]``; the factory selects the values that
    correspond to the requested budget before delegating to the base class.
    """

    # Config entries holding one value per simulation budget.
    _PER_BUDGET_KEYS = (
        "train_batch_size",
        "temperature",
        "max_temperature",
        "weight_decay",
    )

    def __init__(self, config, benchmark, simulation_budget, model_class=None):
        """Select the per-budget hyper-parameters and initialise the base factory.

        Args:
            config: mapping of hyper-parameters; the keys listed in
                ``_PER_BUDGET_KEYS`` are indexed like ``"simulation_budgets"``.
            benchmark: forwarded unchanged to the base class.
            simulation_budget: budget to look up in ``config["simulation_budgets"]``.
            model_class: model class to build; defaults to ``BayesianNPEModel``.

        Raises:
            ValueError: if ``simulation_budget`` is not one of the configured
                budgets (previously the last entry was silently used).
        """
        if model_class is None:
            model_class = BayesianNPEModel

        for idx, budget in enumerate(config["simulation_budgets"]):
            if budget == simulation_budget:
                break
        else:
            raise ValueError(
                "simulation_budget {} not found in config['simulation_budgets'].".format(
                    simulation_budget
                )
            )

        config_run = config.copy()
        for key in self._PER_BUDGET_KEYS:
            config_run[key] = config[key][idx]

        super().__init__(config_run, benchmark, simulation_budget, model_class)

    def _bnn_nb_networks(self):
        """Return how many networks are evaluated for the configured ``bnn_method``."""
        method = self.config["bnn_method"]
        if method == "vi":
            return self.config["nb_networks"]
        if method == "hmc":
            return self.config["samples_per_chain"] * self.config["nb_chains"]
        raise NotImplementedError(
            "bnn_method '{}' not implemented.".format(method))

    def get_train_time(self, benchmark_time, epochs):
        """Return twice the training time estimated by the base factory."""
        return 2 * super().get_train_time(benchmark_time, epochs)

    def get_coverage_time(self, benchmark_time):
        """Return the base coverage time scaled by the number of networks.

        Raises:
            NotImplementedError: if ``bnn_method`` is neither "vi" nor "hmc".
        """
        return self._bnn_nb_networks() * super().get_coverage_time(benchmark_time)

    def get_test_time(self, benchmark_time):
        """Return the base test time scaled by the number of networks.

        Raises:
            NotImplementedError: if ``bnn_method`` is neither "vi" nor "hmc".
        """
        return self._bnn_nb_networks() * super().get_test_time(benchmark_time)

    def require_multiple_trainings(self):
        """Always report that multiple trainings are required."""
        return True

    def nb_trainings_required(self):
        """The model requires multiple trainings if hmc method is used. The number of trainings corresponds to the number of chains. If vi method is used, only one training is required."""

        if self.config["bnn_method"] == "vi":
            return 1
        elif self.config["bnn_method"] == "hmc":
            return self.config["nb_chains"]
        else:
            raise NotImplementedError(
                "bnn_method '{}' not implemented.".format(self.config["bnn_method"]))

    def is_trained(self, id):
        """Tell whether the training identified by ``id`` has already been done.

        For "hmc", ``id`` is split into a model index (``id // nb_chains``) and
        a chain index (``id % nb_chains``).

        Raises:
            NotImplementedError: if ``bnn_method`` is neither "vi" nor "hmc".
        """
        method = self.config["bnn_method"]
        if method == "vi":
            return self.model_class.is_trained(method, self.get_model_path(id))
        if method == "hmc":
            nb_chains = self.config["nb_chains"]
            return self.model_class.is_trained(
                method, self.get_model_path(id // nb_chains), id % nb_chains
            )
        raise NotImplementedError(
            "bnn_method '{}' not implemented.".format(method))


class NPEWithEmbedding(nn.Module):
    """Chain input normalisation, an embedding network and an NPE estimator."""

    def __init__(
        self,
        npe,
        embedding,
        normalize_observation_fct,
        unnormalize_observation_fct,
        normalize_parameters_fct,
        unnormalize_parameters_fct,
        normalization_log_jacobian,
    ):
        """Store the estimator, the embedding and the (un)normalisation callables.

        Args:
            npe: callable as ``npe(theta, context)``; must expose ``flow`` for sampling.
            embedding: callable mapping a normalised observation to a context.
            normalize_observation_fct: applied to ``x`` before the embedding.
            unnormalize_observation_fct: stored only; not used by this class.
            normalize_parameters_fct: applied to ``theta`` in ``forward``.
            unnormalize_parameters_fct: applied to samples in ``sample``.
            normalization_log_jacobian: term added to the output of ``forward``.
        """
        super().__init__()
        # Sub-modules first, in the same order as before.
        self.npe = npe
        self.embedding = embedding
        # Observation-side transforms.
        self.normalize_observation_fct = normalize_observation_fct
        self.unnormalize_observation_fct = unnormalize_observation_fct
        # Parameter-side transforms and the matching log-Jacobian correction.
        self.normalize_parameters_fct = normalize_parameters_fct
        self.unnormalize_parameters_fct = unnormalize_parameters_fct
        self.normalization_log_jacobian = normalization_log_jacobian

    def forward(self, theta, x):
        """Return ``npe(theta, embedding(x))`` on normalised inputs plus the log-Jacobian."""
        x_normalized = self.normalize_observation_fct(x)
        theta_normalized = self.normalize_parameters_fct(theta)
        context = self.embedding(x_normalized)
        return self.npe(theta_normalized, context) + self.normalization_log_jacobian

    def sample(self, x, shape):
        """Draw ``shape`` samples from the flow conditioned on ``x`` and un-normalise them."""
        context = self.embedding(self.normalize_observation_fct(x))
        normalized_samples = self.npe.flow(context).sample(shape)
        return self.unnormalize_parameters_fct(normalized_samples)


class BayesianLoss(nn.Module):
    """Likelihood loss scaled by the training-set size plus a negative log-prior term."""

    def __init__(self, estimator, likelihoodLoss, prior, n_train):
        """Instantiate the likelihood loss on ``estimator`` and store the prior.

        Args:
            estimator: network passed to ``likelihoodLoss`` and to the prior.
            likelihoodLoss: factory called as ``likelihoodLoss(estimator)``.
            prior: object exposing ``prior_log_prob`` and ``resample_prior``.
            n_train: factor multiplying the likelihood loss.
        """
        super().__init__()
        self.estimator = estimator
        self.likelihoodLoss = likelihoodLoss(estimator)
        self.prior = prior
        self.n_train = n_train
        print("n_train = {}".format(n_train))

    def forward(self, theta, x):
        """Return ``n_train * likelihood_loss(theta, x) - prior_log_prob(estimator)``."""
        data_term = self.likelihoodLoss(theta, x)
        negative_log_prior = -self.prior.prior_log_prob(self.estimator)
        return data_term * self.n_train + negative_log_prior

    def resample_prior(self):
        """Delegate to ``prior.resample_prior()``."""
        self.prior.resample_prior()


class BayesianNPEModel(Model):
    def __init__(self, benchmark, model_path, config, normalization_constants):
        """Assemble the embedding, the NPE flow, the BNN prior and the VI/HMC wrapper.

        Args:
            benchmark: object queried for shapes, device, prior, embedding and
                flow builders (``get_observable_shape``, ``get_embedding_dim``,
                ``get_parameter_dim``, ``get_device``, ``get_prior``,
                ``get_embedding_build``, ``get_flow_build``).
            model_path: directory used for checkpoints; created if missing.
            config: mapping of hyper-parameters; ``"bnn_method"`` must be
                ``"vi"`` or ``"hmc"``.
            normalization_constants: forwarded unchanged to the base class.

        Raises:
            NotImplementedError: if ``config["bnn_method"]`` is not supported.
        """

        class Embedding(nn.Module):
            """Apply ``normalize_function`` to the input, then the wrapped embedding."""

            def __init__(self, embedding, normalize_function):
                super().__init__()
                self.embedding = embedding
                self.normalize_function = normalize_function

            def forward(self, x):
                return self.embedding(self.normalize_function(x))

        super().__init__(normalization_constants)
        self.benchmark = benchmark
        self.config = config
        self.observable_shape = benchmark.get_observable_shape()
        self.embedding_dim = benchmark.get_embedding_dim()
        self.parameter_dim = benchmark.get_parameter_dim()
        self.device = benchmark.get_device()
        print("device = {}".format(self.device))

        self.model_path = model_path
        os.makedirs(self.model_path, exist_ok=True)
        # NOTE(review): redundant, self.config was already assigned above.
        self.config = config

        self.prior = benchmark.get_prior()

        embedding_kwargs = {}
        # NOTE(review): the test uses `or` but both keys are read below, so a
        # config holding only one of them raises KeyError — confirm intent.
        if "embedding_nb_layers" in config.keys() or "embedding_nb_neurons" in config.keys():
            embedding_kwargs["nb_layers"] = config["embedding_nb_layers"]
            embedding_kwargs["nb_neurons"] = config["embedding_nb_neurons"]

            embedding_build = benchmark.get_embedding_build(modified=True)
            self.embedding = embedding_build(self.embedding_dim, self.observable_shape, **embedding_kwargs).to(
                self.device
            )

            # The flow below is conditioned on this (overridden) embedding size.
            self.embedding_dim = config["embedding_nb_neurons"]
        else:
            embedding_build = benchmark.get_embedding_build()
            self.embedding = embedding_build(self.embedding_dim, self.observable_shape).to(
                self.device
            )

        flow_build, flow_kwargs = benchmark.get_flow_build()

        # Update hyperparams if specified
        if "nb_layers" in config.keys():
            # Repeat the first hidden width `nb_layers` times.
            flow_kwargs["hidden_features"] = [
                flow_kwargs["hidden_features"][0]]*config["nb_layers"]

        if "nb_neurons" in config.keys():
            # Keep the number of hidden layers, replace every width.
            flow_kwargs["hidden_features"] = [config["nb_neurons"]
                                              for _ in flow_kwargs["hidden_features"]]

        if "nb_transforms" in config.keys():
            flow_kwargs["transforms"] = config["nb_transforms"]

        self.flow = NPE(
            self.parameter_dim, self.embedding_dim, build=flow_build, **flow_kwargs
        ).to(self.device)
        # Full estimator: normalisation + raw embedding + flow.
        self.model = NPEWithEmbedding(
            self.flow,
            self.embedding,
            self.normalize_observation,
            self.unnormalize_observation,
            self.normalize_parameters,
            self.unnormalize_parameters,
            self.get_normalization_log_jacobian(),
        )

        # Must run while self.model is still the bare NPEWithEmbedding:
        # get_bnn_prior() builds the weight distribution from self.model.
        self.bnn_prior = self.get_bnn_prior()
        self.bnn_prior.to(self.device)

        # Variant sharing the same flow but with identity embedding and
        # identity observation normalisation; only referenced by the
        # commented-out VImodel call below.
        self.vi_model = NPEWithEmbedding(
            self.flow,
            nn.Identity(),
            nn.Identity(),
            self.unnormalize_observation,
            self.normalize_parameters,
            self.unnormalize_parameters,
            self.get_normalization_log_jacobian(),
        )

        # From here on self.embedding normalises its input itself; self.model
        # above still holds the raw embedding.
        self.embedding = Embedding(self.embedding, self.normalize_observation)

        if config["bnn_method"] == "vi":
            # self.model = VImodel(
            #     self.vi_model, config, self.device, model_path, embedding=self.embedding)
            self.model = VImodel(
                 self.model, config, self.device, model_path, embedding=None)
        elif config["bnn_method"] == "hmc":
            self.model = HMCmodel(self.model, config, self.device, model_path)
        else:
            raise NotImplementedError(
                "bnn_method '{}' not implemented.".format(self.config["bnn_method"]))
        self.model = self.model.to(self.device)

    @classmethod
    def is_trained(cls, bnn_method, model_path=None, chain_id=None):

        assert bnn_method in ["vi", "hmc"]
        assert model_path is not None or bnn_method == "vi"
        assert chain_id is not None or bnn_method == "vi"

        if bnn_method == "vi":
            return os.path.exists(os.path.join(model_path, "VIparametrisaton.pt"))
        elif bnn_method == "hmc":
            return os.path.exists(os.path.join(model_path, "trained_{}.pt".format(chain_id)))

    @classmethod
    def is_initialized(cls, model_path):
        return os.path.exists(os.path.join(model_path, "bnn_prior.pt"))

    def log_prob(self, theta, x, id=None) -> torch.Tensor:
        """
        Returns the log probability of theta given x. If id is None, the log probability is computed using "nb_networks" networks defined in the config file for vi and all the networks computed from all the chains for hmc. If id is not None, the log probability is computed using the network with id "id" for vi and the network computed from the chain with id "id" for hmc.

        Args:
            theta (torch.Tensor): tensor of shape (batch_size, parameter_dim)
            x (torch.Tensor): tensor of shape (batch_size, observable_shape)
            id (int, optional): id of the network to use. Defaults to None.

        Returns:
            torch.Tensor: tensor of shape (batch_size)
        """
        return self.model.log_prob(theta, x, id_net=id)

    def log_prob_one_model(self, theta, x):
        # id is None so random network is used
        return self.model.log_prob(theta, x)

    def sample(self, x, shape, id=None):
        """
        Samples from the posterior distribution of theta given x. If id is None, the sample is computed using "nb_networks" networks defined in the config file for vi and all the networks computed from all the chains for hmc. If id is not None, the sample is computed using the network with id "id" for vi and the network computed from the chain with id "id" for hmc.

        Args:
            x (torch.Tensor): tensor of shape (batch_size, observable_shape)
            shape (torch.Size): shape of the sample
            id (int, optional): id of the network to use. Defaults to None.

        Returns:
            torch.Tensor: tensor of shape (shape, batch_size, parameter_dim)
        """
        x = x.to(self.device)
        return self.model.sample(x, shape, id_net=id)

    def get_posterior_fct(self):
        def get_posterior(x):
            class Posterior:
                def __init__(self, sampling_fct, log_prob_fct):
                    self.sample = sampling_fct
                    self.log_prob = log_prob_fct

            return Posterior(
                lambda shape: self.sample(x.to(self.device), shape).cpu(),
                lambda theta: self.log_prob(
                    theta.to(self.device), x.unsqueeze(0).to(self.device)
                )
                .squeeze(0)
                .cpu(),
            )

        return get_posterior

    def prior_log_prob(self, theta, x, n_estimators):
        x = x.to(self.device)
        theta = theta.to(self.device)
        outputs = self.bnn_prior.sample_functions(
            theta, x, n_estimators).squeeze(dim=2)
        outputs = torch.logsumexp(outputs, dim=-1) - np.log(
            n_estimators
        )

        return outputs

    def prior_sample(self, x, shape):
        x = x.to(self.device)
        nb_samples = np.prod(np.array(shape))
        samples = []
        for _ in range(nb_samples):
            samples.append(self.bnn_prior.sample(x))

        samples = torch.stack(samples, dim=0).view(*shape, -1)

        return samples

    def get_prior_fct(self, n_estimators):
        def get_prior(x):
            class Prior:
                def __init__(self, sampling_fct, log_prob_fct):
                    self.sample = sampling_fct
                    self.log_prob = log_prob_fct

            return Prior(
                lambda shape: self.prior_sample(
                    x.to(self.device), shape).cpu(),
                lambda theta: self.prior_log_prob(
                    theta.to(self.device), x.unsqueeze(
                        0).to(self.device), n_estimators
                )
                .squeeze(0)
                .cpu(),
            )

        return get_prior

    def get_bnn_prior(self):
        if "init_low_variance_init" in self.config.keys():
            low_variance_init = self.config["init_low_variance_init"]
        else:
            low_variance_init = False

        if low_variance_init:
            std_init_value = self.config["init_std_init_value"]
        else:
            std_init_value = None

        if self.config["prior_variational_family"] == "gaussian":
            return GaussianNNParametersDistribution(self.model, shared=False, low_variance_init=low_variance_init, std_init_value=std_init_value)
        elif self.config["prior_variational_family"] == "gaussian_softplus":
            return GaussianNNParametersDistribution(self.model, shared=False, softplus=True, beta=self.config["beta"], low_variance_init=low_variance_init, std_init_value=std_init_value)
        elif self.config["prior_variational_family"] == "shared_gaussian":
            return GaussianNNParametersDistribution(self.model, shared=True)
        elif self.config["prior_variational_family"] == "hierarchical_gaussian":
            return HierarchicalGaussianNNParametersDistribution(
                self.model, shared=False
            )
        elif self.config["prior_variational_family"] == "shared_hierarchical_gaussian":
            return HierarchicalGaussianNNParametersDistribution(self.model, shared=True)
        else:
            raise NotImplementedError(
                "Prior variational family not implemented.")

    def __call__(self, theta, x):
        return self.log_prob(theta, x)

    def sampling_enabled(self):
        return True

    def save(self):
        self.model.save()

    def save_init(self):
        torch.save(
            self.bnn_prior.state_dict(), os.path.join(self.model_path, "bnn_prior.pt")
        )

    def load_model(self, chain_id, index):
        self.embedding.load_state_dict(
            torch.load(
                os.path.join(
                    self.model_path, "embedding_{}_{}.pt".format(
                        chain_id, index)
                )
            )
        )
        self.flow.load_state_dict(
            torch.load(
                os.path.join(self.model_path,
                             "flow_{}_{}.pt".format(chain_id, index))
            )
        )

    def load(self):
        self.model.load()

    def load_init(self):
        self.bnn_prior.load_state_dict(
            torch.load(os.path.join(self.model_path, "bnn_prior.pt"), map_location=self.device)
        )

    def delete_models(self, chain_id):
        shutil.rmtree(
            os.path.join(self.model_path, "trained_{}.pt".format(chain_id)),
            ignore_errors=True,
        )
        shutil.rmtree(
            os.path.join(self.model_path,
                         "embedding_{}_*.pt".format(chain_id)),
            ignore_errors=True,
        )
        shutil.rmtree(
            os.path.join(self.model_path, "flow_{}_*.pt".format(chain_id)),
            ignore_errors=True,
        )
        shutil.rmtree(os.path.join(self.model_path,
                      "bnn_prior.pt"), ignore_errors=True)

    def train(self):
        self.model.train()

    def eval(self):
        self.model.eval()

    def get_loss_fct(self):
        """Return the likelihood loss class (``NPELoss``) used to train the estimator."""
        loss_class = NPELoss
        return loss_class

    def train_models(self, train_set, val_set, config, chain_id=None):

        config = self.config
        loss_fct = self.get_loss_fct()
        if self.config["bnn_method"] == "vi":
            self.model.train_models(
                train_set, val_set, config, loss_fct, self.bnn_prior)

        elif self.config["bnn_method"] == "hmc":
            self.model.train_models(
                train_set, val_set, config, chain_id, loss_fct, self.bnn_prior)

        else:
            raise NotImplementedError(
                "bnn_method '{}' not implemented.".format(self.config["bnn_method"]))

    def initialize(self, train_set, val_set, config, debug=False, no_wandb_init=False):
        if config["optimize_prior"]:
            return self.optimize_prior(train_set, val_set, debug=debug, no_wandb_init=no_wandb_init)
        else:
            self.bnn_prior.set_mean(0.0)
            self.bnn_prior.set_std(config["prior_std"])

    def get_np_distribution(self, measurement_generator=None):
        """Build the non-parametric (GP) prior that the BNN prior is matched to.

        Args:
            measurement_generator: optional generator forwarded to ``GPPrior``.
        """
        # return DirichletPrior(self.benchmark, self.config, self.device)
        np_prior = GPPrior(
            self.benchmark,
            self.config,
            self.device,
            measurement_generator=measurement_generator,
        )
        return np_prior

    def wrap_bnn_prior(self, bnn_prior):
        return bnn_prior

    def optimize_prior(self, train_set, val_set, debug=False, no_wandb_init=False):
        """Optimise the BNN prior so that it matches the non-parametric prior.

        Builds the measurement generators, the non-parametric distribution and
        the distance configuration from ``self.config``, then runs a
        ``DistanceBasedPriorMapper``.

        Args:
            train_set: training data used by dataset/hybrid generators and for
                automatic observable bounds.
            val_set: validation data used by dataset/hybrid validation generators.
            debug: forwarded to ``mapper.optimize``.
            no_wandb_init: when True, ``wandb.init`` is skipped even if
                ``config["use_wandb"]`` is set.

        Returns:
            Whatever ``DistanceBasedPriorMapper.optimize`` returns.

        Raises:
            NotImplementedError: for an unknown measurement generator type or
                distance type.
        """
        cfg = self.config

        if cfg["use_wandb"] and not no_wandb_init:
            wandb.init(
                project=cfg["wandb_project"],
                entity=cfg["wandb_user"],
                config=cfg,
            )

        # Create measurement dataset
        generator_type = cfg["measurement_generator_type"]

        # Bounds are mandatory for the uniform generator. Other generator types
        # only need them if a uniform discriminator generator is built later,
        # in which case they are computed on demand.
        observable_bounds = None
        if generator_type == "uniform":
            observable_bounds = self._get_observable_bounds(train_set)

        measurement_generator, val_measurement_generator = (
            self._build_measurement_generators(
                generator_type, train_set, val_set, observable_bounds
            )
        )

        # Initialize the non parametric distribution to match
        if "automatic_kernel" in cfg.keys() and cfg["automatic_kernel"]:
            np_distribution = self.get_np_distribution(
                measurement_generator=measurement_generator)
        else:
            np_distribution = self.get_np_distribution()

        log_dir = os.path.join(self.model_path, "logs")
        os.makedirs(log_dir, exist_ok=True)

        output_dim = 1

        n_data = cfg["measurement_set_size"]
        n_samples = cfg["init_function_samples"]
        log_space = cfg["gp_log_space"]

        init_lr = cfg["init_lr"]
        init_iter = cfg["init_iter"]
        optimizer = cfg["init_optimizer"]

        optimizer_params = {"optimizer": optimizer, "lr": init_lr}
        if optimizer == "sgd":
            optimizer_params["momentum"] = cfg["init_momentum"]
        if optimizer == "adamw":
            optimizer_params["weight_decay"] = cfg["init_weight_decay"]

        use_wandb = cfg["use_wandb"]

        nb_val_steps = cfg["init_nb_val_steps"]
        val_step_every = cfg["init_val_step_every"]
        schedule_init_lr = cfg["schedule_init_lr"]

        distance_type = cfg["distance_type"]
        distance_config = self._build_distance_config(
            distance_type, generator_type, train_set, observable_bounds
        )

        clip_gradient = cfg["generator_clip_gradient"]
        clipping_norm = None
        clipping_quantile = None
        if clip_gradient == "threshold":
            clipping_norm = cfg["generator_clipping_norm"]
        elif clip_gradient == "quantile":
            clipping_quantile = cfg["generator_clipping_quantile"]

        # Create the mapper
        # NOTE(review): `self.device == "cpu"` assumes the device is a plain
        # string; confirm what benchmark.get_device() returns.
        mapper = DistanceBasedPriorMapper(
            np_distribution,
            self.wrap_bnn_prior(self.bnn_prior),
            measurement_generator,
            log_dir,
            log_space,
            distance_type,
            distance_config,
            output_dim=output_dim,
            n_data=n_data,
            gpu_np_model=not (self.device == "cpu"),
            device=self.device,
            clip_gradient=clip_gradient,
            clipping_norm=clipping_norm,
            clipping_quantile=clipping_quantile,
            use_wandb=use_wandb,
            validation_data_generator=val_measurement_generator,
            nb_val_steps=nb_val_steps,
            val_steps_every=val_step_every,
            schedule_lr=schedule_init_lr,
        )

        return mapper.optimize(
            init_iter, optimizer_params, n_samples=n_samples, debug=debug
        )

    def _get_observable_bounds(self, train_set):
        """Return the (lower, upper) observable bounds for uniform generators.

        With ``config["automatic_observable_bounds"]`` set, the bounds are the
        element-wise min/max over dimension 0 of every ``x`` yielded by
        ``train_set``; otherwise they come from the benchmark.
        """
        cfg = self.config
        automatic = "automatic_observable_bounds" in cfg.keys() and cfg["automatic_observable_bounds"]
        if not automatic:
            return self.benchmark.get_observation_domain()

        lower = None
        upper = None
        # train_set yields pairs; only the second element (x) is used here.
        for _, x in train_set:
            if lower is None:
                lower = torch.min(x, dim=0)[0]
            else:
                lower = torch.min(torch.cat((x, lower.unsqueeze(0))), dim=0)[0]

            if upper is None:
                upper = torch.max(x, dim=0)[0]
            else:
                upper = torch.max(torch.cat((x, upper.unsqueeze(0))), dim=0)[0]

        observable_bounds = (lower, upper)
        print("observable_bounds = {}".format(observable_bounds))
        return observable_bounds

    def _build_measurement_generators(self, generator_type, train_set, val_set, observable_bounds):
        """Return the (training, validation) measurement generators for ``generator_type``."""

        def extended_domain():
            # Recomputed for every generator so that none of them share a list.
            return [x * self.config["extend_parameter_domain"]
                    for x in self.benchmark.get_domain()]

        if generator_type == "dataset":
            return (
                DatasetMeasurementGenerator(train_set),
                DatasetMeasurementGenerator(val_set),
            )

        if generator_type == "uniform":
            return (
                UniformMeasurementGenerator(extended_domain(), observable_bounds),
                UniformMeasurementGenerator(extended_domain(), observable_bounds),
            )

        if generator_type == "hybrid":
            return (
                HybridMeasurementGenerator(extended_domain(), train_set),
                HybridMeasurementGenerator(extended_domain(), val_set),
            )

        raise NotImplementedError(
            "Measurement generator type {} not implemented.".format(
                generator_type
            )
        )

    def _build_distance_config(self, distance_type, generator_type, train_set, observable_bounds):
        """Return the distance-specific configuration dict passed to the prior mapper."""
        cfg = self.config
        distance_config = {}

        if distance_type in ("wasserstein_old", "wasserstein", "discriminator"):
            critic_config = {
                "model": cfg["critic_model"],
                "nb_layers": cfg["critic_nb_layers"],
                "nb_neurons": cfg["critic_nb_neurons"],
                "batch_size": cfg["critic_batch_size"],
            }

            if cfg["use_discriminator_generator"]:
                if generator_type == "dataset":
                    discriminator_generator = DatasetMeasurementGenerator(
                        train_set)
                else:
                    if generator_type != "uniform":
                        # Bounds were not needed so far (e.g. "hybrid"); this
                        # used to fail with an unbound local variable.
                        observable_bounds = self._get_observable_bounds(train_set)
                    discriminator_generator = UniformMeasurementGenerator(
                        self.benchmark.get_domain(),
                        observable_bounds,
                    )
            else:
                discriminator_generator = None

        if distance_type == "wasserstein_old":
            distance_config["critic_config"] = critic_config
            distance_config["wasserstein_steps"] = cfg["wasserstein_steps"]
            distance_config["wasserstein_lr"] = cfg["wasserstein_lr"]
            distance_config["wasserstein_threshold"] = cfg["wasserstein_threshold"]
            distance_config["lipschitz_constraint_type"] = cfg["lipschitz_constraint_type"]
            distance_config["restart_lipschitz"] = cfg["restart_lipschitz"]

        elif distance_type == "wasserstein":
            distance_config["critic_config"] = critic_config
            distance_config["wasserstein_steps"] = cfg["wasserstein_steps"]
            distance_config["wasserstein_lr"] = cfg["wasserstein_lr"]
            distance_config["wasserstein_threshold"] = cfg["wasserstein_threshold"]
            distance_config["lipschitz_constraint_type"] = cfg["lipschitz_constraint_type"]
            distance_config["benchmark"] = self.benchmark
            distance_config["discriminator_generator"] = discriminator_generator
            distance_config["penalty_coef"] = cfg["penalty_coef"]
            distance_config["restart_lipschitz"] = cfg["restart_lipschitz"]
            distance_config["clip_gradient"] = cfg["lipschitz_clip_gradient"]
            if distance_config["clip_gradient"]:
                distance_config["clipping_norm"] = cfg["lipschitz_clipping_norm"]
            else:
                distance_config["clipping_norm"] = None

        elif distance_type == "discriminator":
            distance_config["critic_config"] = critic_config
            distance_config["discriminator_steps"] = cfg["discriminator_steps"]
            distance_config["discriminator_lr"] = cfg["discriminator_lr"]
            distance_config["benchmark"] = self.benchmark
            distance_config["discriminator_generator"] = discriminator_generator
            distance_config["restart_discriminator"] = cfg["restart_discriminator"]
            distance_config["regularize_discriminator"] = cfg["regularize_discriminator"]
            if distance_config["regularize_discriminator"]:
                distance_config["penalty_coef"] = cfg["penalty_coef"]
            else:
                distance_config["penalty_coef"] = None

            distance_config["clip_gradient"] = cfg["discriminator_clip_gradient"]
            if distance_config["clip_gradient"]:
                distance_config["clipping_norm"] = cfg["discriminator_clipping_norm"]
            else:
                distance_config["clipping_norm"] = None

        elif distance_type == "KL":
            pass

        elif distance_type == "stein_KL":
            eta = cfg["eta"] if cfg["set_eta"] else None
            num_eigs = cfg["num_eigs"] if cfg["set_num_eigs"] else None

            distance_config["eta"] = eta
            distance_config["num_eigs"] = num_eigs
            distance_config["joint_entropy"] = cfg["joint_entropy"]
            distance_config["loss_divider"] = cfg["loss_divider"]

        else:
            raise NotImplementedError(
                "Distance version '{}' does not exist.".format(distance_type)
            )

        return distance_config

    def is_ensemble(self) -> bool:
        """
        Returns True if the model is composed of more than one network, False otherwise.
        """
        return self.get_nb_networks() > 1

    def get_nb_networks(self):
        """
        Returns the number of networks used by the model to compute the bayesian model average. If the bnn_method is vi, the number of networks is defined in the config file. If the bnn_method is hmc, the number of networks corresponds to all the networks computed from all the chains.
        """

        if self.config["bnn_method"] == "vi":
            return self.config["nb_networks"]
        elif self.config["bnn_method"] == "hmc":
            nb_networks = self.config["samples_per_chain"] * \
                self.config["nb_chains"]
            return nb_networks
        else:
            raise NotImplementedError(
                "bnn_method '{}' not implemented.".format(self.config["bnn_method"]))
