import torch
from torch import nn
import math
import numpy as np
from scipy.stats import multivariate_normal

from scipy.spatial import ConvexHull
import matplotlib.pyplot as plt
import ot

from pdPINN.model.model_utiities import divergence, gradient, ST
from pdPINN.model.modules import *
from pdPINN.model.siren import to_numpy, to_tensor, batch_predict2, dataloader_from_np
from pdPINN.model.sampler import *
from sklearn.neighbors import KNeighborsRegressor


def build_fourier_net(input_dim, num_hidden_layers, hidden_features, num_frequencies=256, num_attention_layers=3):
    """Build the layers of a Fourier-embedding -> attention -> fully connected network.

    Each layer's input width is taken from the previous layer's ``out_features``, so the
    returned layers are meant to be applied in list order.

    Args:
        input_dim: number of input coordinates fed to the Fourier embedding.
        num_hidden_layers: number of hidden layers of the final ``FCLayer`` head.
        hidden_features: embedding width of the attention layer and width of the head.
        num_frequencies: number of frequencies of the ``FourierEmbedding`` (default 256).
        num_attention_layers: number of layers of the ``AttentionLayer`` (default 3).

    Returns:
        list: ``[fourier_layer, attention_layer, fc_net]``. The caller is responsible for
        composing them; the head has a single ``tanh`` output.
    """
    fourier_layer = FourierEmbedding(input_dim=input_dim, num_frequencies=num_frequencies)
    attention_layer = AttentionLayer(input_dim=fourier_layer.out_features, embedding_dim=hidden_features,
                                     num_layers=num_attention_layers)
    fc_net = FCLayer(in_features=attention_layer.out_features,
                     out_features=1, nonlinearity="tanh",
                     num_hidden_layers=num_hidden_layers, hidden_features=hidden_features)
    return [fourier_layer, attention_layer, fc_net]


class MassCons3d(nn.Module):
    """Physics-informed model of a density field and a velocity field over (t, x, y, z).

    Two fully connected branches act on normalized space-time coordinates:

    * the density branch predicts a signed amplitude (``sqrt_density``) whose square is the
      density ``rho``;
    * the velocity branch predicts ``Y.shape[1] - 1`` velocity components.

    Training combines a reconstruction loss on observations with the residual of the
    continuity equation ``d rho / dt + div(rho * v)`` evaluated on collocation points
    (``self.grid_samples``), which are (re)drawn by :meth:`do_sample`.

    :meth:`_compile` must be called with the training inputs before the model is used: it
    creates the Gaussian priors and ``log_prior_max`` that :meth:`sqrt_density` reads, and the
    MCMC sampler used by the ``mh_pdpinn`` sampling method.
    """

    mse = nn.MSELoss()

    def __init__(self,
                 X, Y,
                 tmin, tmax,
                 xmin, xmax,
                 ymin, ymax,
                 zmin, zmax,
                 hidden_features=64,
                 vf_hidden_features=64,
                 density_hidden_layers=3,
                 vf_hidden_layers=3,
                 nonlinearity='relu',
                 include_boundary_conditions=True,
                 normalize_loss=True,
                 **kwargs
                 ):
        """
        Args:
            X: training inputs (presumably a NumPy array); column 0 is time, the remaining
                columns are spatial coordinates. The domain bounds assume 4 columns (t, x, y, z).
            Y: training targets; column 0 is the density observation, the remaining columns
                are velocity components.
            tmin, tmax, xmin, xmax, ymin, ymax, zmin, zmax: bounds of the space-time box used
                for uniform collocation sampling.
            hidden_features, density_hidden_layers: width / depth of the density branch.
            vf_hidden_features, vf_hidden_layers: width / depth of the velocity branch.
            nonlinearity: activation passed to both ``FCLayer`` branches.
            include_boundary_conditions: stored on the instance (not read by this class).
            normalize_loss: divide the continuity residual by the range of ``rho``
                (otherwise by the constant 5000).
            **kwargs: ``sine_frequency`` (forwarded to ``FCLayer``), ``n_samples_constraints``
                (number of collocation points), ``rar`` / ``ot_rar`` (mutually exclusive
                adaptive-refinement flags).
        """
        super().__init__()

        self.in_features = X.shape[1]
        self.hidden_features = hidden_features
        self.sobol_engine = torch.quasirandom.SobolEngine(dimension=self.in_features, scramble=True)

        self.tmin, self.tmax = tmin, tmax
        self.xmin, self.xmax = xmin, xmax
        self.ymin, self.ymax = ymin, ymax
        self.zmin, self.zmax = zmin, zmax

        # Bounding box of the (t, x, y, z) domain. Stored as frozen Parameters so they follow the
        # module across devices. ``requires_grad`` must be given to ``Parameter`` itself: the flag
        # of the wrapped tensor is ignored, which previously made these bounds trainable.
        self.min_extent = torch.nn.Parameter(
            torch.tensor([self.tmin, self.xmin, self.ymin, self.zmin]), requires_grad=False)
        self.max_extent = torch.nn.Parameter(
            torch.tensor([self.tmax, self.xmax, self.ymax, self.zmax]), requires_grad=False)

        self.normalize_layer = torch.jit.script(NormalizeLayer.from_data(X=X, ignore_dims=[], trainable=False))
        self.standardize_mass_layer = torch.jit.script(InverseStandardizeLayer(X=Y[..., [0]], trainable=True))
        # Constructed but currently not applied in ``velocity``.
        self.standardize_vf_layer = torch.jit.script(InverseStandardizeLayer(X=Y[..., 1:], trainable=False))

        self.density_branch = torch.jit.script(FCLayer(in_features=self.in_features, out_features=1,
                                                       hidden_features=hidden_features,
                                                       num_hidden_layers=density_hidden_layers,
                                                       outermost_linear=True, outermost_positive=False,
                                                       nonlinearity=nonlinearity,
                                                       sine_frequency=kwargs.get("sine_frequency", None)))

        self.velocity_branch = torch.jit.script(FCLayer(in_features=self.in_features, out_features=Y.shape[1] - 1,
                                                        outermost_linear=True,
                                                        outermost_positive=False,
                                                        hidden_features=vf_hidden_features,
                                                        num_hidden_layers=vf_hidden_layers,
                                                        weight_init=None,
                                                        nonlinearity=nonlinearity,
                                                        sine_frequency=kwargs.get("sine_frequency", None)))

        # Learnable noise models of the reconstruction likelihood.
        self.velocity_cov_tril = torch.nn.Parameter(torch.eye(3), requires_grad=True)
        self.density_sd = torch.nn.Parameter(torch.exp(torch.FloatTensor([0.])),
                                             requires_grad=True)

        self.include_boundary_conditions = include_boundary_conditions
        self.normalize_loss = normalize_loss

        # Empirical mean / covariance of the spatial coordinates (columns 1:).
        self.mean = torch.nn.Parameter(torch.tensor(np.mean(X[..., 1:], 0), dtype=torch.float32), requires_grad=False)
        self.cov = torch.nn.Parameter(torch.tensor(np.cov(X[..., 1:].T), dtype=torch.float32),
                                      requires_grad=False)

        # Fraction of collocation points drawn from the density-informed sampler.
        self.fraction_mcmc = .5

        self.num_samples = kwargs.get("n_samples_constraints")
        self.rar = kwargs.get("rar")
        self.ot_rar = kwargs.get("ot_rar")
        assert not (self.rar and self.ot_rar), "Only one adaptive refinement method at once allowed."

        # Convex hull of the observed spatial locations, used for Dirichlet sampling.
        unique_x = np.unique(X[..., 1:], axis=0)
        hull = ConvexHull(unique_x)
        self.convex_hull_tensor = torch.nn.Parameter(self.to_tensor(unique_x[hull.vertices]), requires_grad=False)
        self.convex_hull_np = unique_x[hull.vertices]

        self.dir = torch.distributions.Dirichlet(self.to_tensor(np.ones(self.convex_hull_np.shape[0]))).expand(
            (self.num_samples,))
        self.bounding_act = torch.nn.Hardtanh(min_val=-1e-5, max_val=1e12)

    def _compile(self, X):
        """Build the data-driven Gaussian priors and the MCMC sampler from training inputs ``X``.

        The distributions are created on the module's current device and are plain attributes
        (not sub-modules), so call this after moving the module to its target device.
        """
        self.prior_sp, self.prior_torch, self.prior_torch_space = self.estimate_gaussian_from_data(X, scale_cov=0)
        self.prior_sp2d, self.prior_torch2d, self.prior_torch_space2d = self.estimate_gaussian_from_data(
            X[..., 1:], scale_cov=0)

        self.mvn = torch.distributions.MultivariateNormal(loc=self.mean, covariance_matrix=self.cov * 3)

        # Log-density of the full (t, x, y, z) prior at its mode; already on ``self.device``
        # because ``prior_torch`` was built there.
        self.log_prior_max = torch.nn.Parameter(
            self.prior_torch.log_prob(self.prior_torch.mean),
            requires_grad=False
        )

        def proposal(num_samples):
            # Uniform time, Gaussian space (fitted to the spatial columns of X).
            time = np.random.uniform(self.tmin, self.tmax, num_samples)
            space = self.prior_sp2d.rvs(num_samples)
            return np.concatenate([time[..., np.newaxis], space], -1)

        self.num_mcmc_chains = 500
        self.sampler = LocalSampler(log_p_unnormalized=self._log_p,
                                    num_chains=self.num_mcmc_chains,
                                    data_dim=X.shape[1],
                                    eps=0.001,
                                    prior_sample=proposal,
                                    prior_logprob=self.prior_sp.logpdf
                                    )

    @property
    def device(self):
        """Device of the module's parameters."""
        return next(self.parameters()).device

    def _log_p(self, x):
        """NumPy wrapper around :meth:`sample_log_prob` used as the MCMC target."""
        return to_numpy(self.sample_log_prob(self.to_tensor(x, dtype=torch.float32, device=self.device)))

    def sample_log_prob(self, x):
        """Unnormalized log target for sampling: ``log(rho + 1e-10)``, with ``rho`` set to 0
        for points whose time (column 0) lies outside ``[tmin, tmax]``."""
        in_time_range = (x[..., 0] >= self.tmin) * (x[..., 0] <= self.tmax)
        rho = self.density(x, normalize_inputs=True) * in_time_range.reshape(-1, 1)
        return torch.log(rho + 1e-10)

    def velocity(self, coords, normalize_inputs=True):
        """Predict the velocity field at ``coords``.

        Returns the raw output of the velocity branch; ``standardize_vf_layer`` is not applied.
        """
        if normalize_inputs:
            coords = self.normalize_layer(coords)
        return self.velocity_branch(coords)

    def dist_to_top(self, x):
        """Return ``1 - x``."""
        return 1. - x

    def dist_to_bottom(self, x):
        """Return ``x`` unchanged."""
        return x

    def sqrt_density(self, coords, normalize_inputs=True):
        """Predict the signed density amplitude at ``coords`` (its square is the density).

        The density branch output is inverse-standardized and multiplied by the weight
        ``exp(log p_space(x, y, z) - log_prior_max)``, where ``p_space`` is the Gaussian prior
        over the spatial columns. The weight is always computed from the un-normalized
        ``coords``.
        """
        coords_transformed = self.normalize_layer(coords) if normalize_inputs else coords

        sqrt_density = self.standardize_mass_layer(self.density_branch(coords_transformed))
        # NOTE(review): ``log_prior_max`` is the mode of the full (t, x, y, z) prior while the
        # numerator uses the spatial marginal, so this weight is not bounded by 1 — confirm
        # that this constant rescaling is intended.
        weight = torch.exp(self.prior_torch_space.log_prob(coords[..., 1:]) - self.log_prior_max).reshape(-1, 1)
        return sqrt_density * weight

    def density(self, coords, normalize_inputs=True):
        """Predict the density ``rho = sqrt_density ** 2``, clipped by ``Hardtanh`` to at most 1e12."""
        rho = torch.pow(self.sqrt_density(coords, normalize_inputs), 2)
        return self.bounding_act(rho)

    def constraint_loss(self, coords=None, get_losses=False):
        """Residual of the continuity equation ``d rho/dt + div(rho * v) = 0``.

        Args:
            coords: collocation points that require grad; defaults to ``self.grid_samples``.
            get_losses: if True, return the per-point squared residual instead of the summary.

        Returns:
            Per-point squared residual (same shape as ``d rho/dt``, i.e. ``(..., 1)``) when
            ``get_losses`` is True; otherwise a dict with ``pde_weight`` (scaled RMSE),
            ``hessian`` (always 0.) and ``sinks_and_sources_deviation`` (unscaled RMSE).
        """
        if coords is None:
            coords = self.grid_samples

        rho = torch.pow(self.sqrt_density(coords, normalize_inputs=True), 2)
        velocity = self.velocity(coords, normalize_inputs=True)

        mass_flux = rho * velocity
        # Column 0 is time, so the spatial divergence starts at offset 1.
        div_massflux = divergence(mass_flux, coords, x_offset=1)
        drho_dt = gradient(rho, coords)[..., [0]]
        assert div_massflux.shape == drho_dt.shape

        sinks_and_sources = drho_dt + div_massflux
        if get_losses:
            return sinks_and_sources.pow(2)

        continuity_of_mass_rmse = sinks_and_sources.pow(2).mean().sqrt()
        if self.normalize_loss:
            # Scale-free residual; the clamp guards against a constant density field.
            rho_range = (torch.max(rho) - torch.min(rho)).detach().clamp_min(1e-12)
            loss_consistency = continuity_of_mass_rmse / rho_range
        else:
            loss_consistency = continuity_of_mass_rmse / 5000

        return dict(pde_weight=loss_consistency, hessian=0.,
                    sinks_and_sources_deviation=continuity_of_mass_rmse)

    def sample_background(self, sampling_method: ST, num_samples) -> torch.Tensor:
        """Draw ``num_samples`` points with a data-independent strategy.

        ``ST.uniform`` uses scrambled Sobol points in the bounding box, ``ST.gaussian`` the
        full data prior, ``ST.dirichlet`` the convex hull of the observed locations. The
        returned tensor does not require grad; wrap it with :meth:`allow_input_derivatives`
        before taking input derivatives.
        """
        if sampling_method == ST.uniform:
            coords = self.sobol_engine.draw(num_samples).to(self.device)
            coords = coords * (self.max_extent - self.min_extent) + self.min_extent
        elif sampling_method == ST.gaussian:
            coords = self.prior_torch.sample((num_samples,))
        elif sampling_method == ST.dirichlet:
            coords = self.to_tensor(self.sample_within_convex_hull(num_samples, numpy=True), device=self.device)
        else:
            raise ValueError(f"Unknown background sampling method '{sampling_method}'.")
        assert coords.shape[0] == num_samples
        return coords

    def sample_within_convex_hull(self, num_samples, numpy=True):
        """Sample points with uniform time in ``[tmin, tmax]`` and space as Dirichlet-weighted
        convex combinations of the hull vertices. Returns NumPy if ``numpy`` else a tensor."""
        if numpy:
            A = np.random.dirichlet(np.ones(self.convex_hull_np.shape[0]), (num_samples))
            samples = A @ self.convex_hull_np
            time = np.random.uniform(self.tmin, self.tmax, samples.shape[0]).reshape(-1, 1)
            samples = np.concatenate([time, samples], -1)
        else:
            A = self.dir.sample()
            samples = A @ self.convex_hull_tensor
            time = torch.rand(samples.shape[0], device=self.device).reshape(-1, 1) * (self.tmax - self.tmin) + self.tmin
            samples = torch.cat([time, samples], -1)

        return samples

    def sample_signal_domain(self, num_samples, sampling_method: ST = ST.uniform, burnin=500, rar=False, ot_rar=False):
        """Draw ``num_samples`` collocation points that require grad.

        Args:
            num_samples: number of points to return.
            sampling_method: ``ST.uniform`` / ``ST.dirichlet`` / ``ST.gaussian`` (background),
                ``ST.importance_sampling`` (residual-proportional), or ``ST.it_pdpinn`` /
                ``ST.mh_pdpinn`` (Gaussian background mixed with density-informed points).
            burnin: MCMC burn-in length for ``ST.mh_pdpinn``.
            rar: keep the points with the largest residual among new and current samples.
            ot_rar: move the points with a linear OT map towards high-residual regions.

        Raises:
            NotImplementedError: for ``ST.pdpinn_euler``.
            ValueError: for any other unknown method.
        """
        if sampling_method in [ST.uniform, ST.dirichlet, ST.gaussian]:
            samples = self.sample_background(sampling_method=sampling_method, num_samples=num_samples)
        elif sampling_method == ST.importance_sampling:
            samples = self._sample_residual_importance(num_samples)
        else:
            # Validate before doing any expensive sampling work.
            if sampling_method == ST.pdpinn_euler:
                raise NotImplementedError("Todo: mcmc_euler")
            if sampling_method not in (ST.it_pdpinn, ST.mh_pdpinn):
                raise ValueError(f"sampling_method='{sampling_method}' is unknown.")

            num_mcmc_samples = int(self.fraction_mcmc * num_samples)
            # Complement, so the two parts always add up to exactly ``num_samples``.
            num_background_samples = num_samples - num_mcmc_samples

            background = self.sample_background(sampling_method=ST.gaussian, num_samples=num_background_samples)
            if sampling_method == ST.it_pdpinn:
                density_samples = self._sample_density_proposals(num_mcmc_samples)
            else:
                density_samples = self._sample_density_mcmc(num_mcmc_samples, burnin)
            samples = torch.cat([background, density_samples], 0)

        if rar:
            samples = self._refine_rar(samples, num_samples)
        if ot_rar:
            samples = self._refine_ot_rar(samples, num_samples)

        return self.allow_input_derivatives(samples)

    def _sample_residual_importance(self, num_samples):
        """Resample 100 000 uniform proposals proportionally to the squared PDE residual,
        interpolated (1-nearest-neighbour) from residuals evaluated on uniform seed points.

        The resulting points are not re-weighted by the proposal probability.
        """
        n_seeds = min(max(10_000, num_samples * 10), 100_000)
        seeds = self.sample_background(sampling_method=ST.uniform, num_samples=n_seeds)
        logits_seeds = torch.log(self.constraint_loss(self.allow_input_derivatives(seeds), get_losses=True) + 1e-10)

        proposals = self.sample_background(sampling_method=ST.uniform, num_samples=100_000)
        knn = KNeighborsRegressor(1, weights="uniform")
        logits_proposals = knn.fit(to_numpy(seeds), to_numpy(logits_seeds)).predict(to_numpy(proposals.squeeze()))

        dist = torch.distributions.Categorical(logits=to_tensor(logits_proposals, device=self.device).squeeze())
        idx = dist.sample((num_samples,))
        return proposals[idx, ...]

    def _sample_density_proposals(self, num_samples):
        """Resample 100 000 uniform proposals proportionally to the predicted density."""
        proposals = self.sample_background(sampling_method=ST.uniform, num_samples=100_000)
        # Only the sampling weights are needed; no autograd graph over 100k points.
        with torch.no_grad():
            logits_proposals = torch.log(self.density(proposals) + 1e-10)
        dist = torch.distributions.Categorical(logits=logits_proposals.squeeze())
        idx = dist.sample((num_samples,))
        return proposals[idx, ...]

    def _sample_density_mcmc(self, num_samples, burnin):
        """Draw ``num_samples`` points from the predicted density with ``self.sampler``.

        All chains start (with small Gaussian jitter) at the highest-density point among
        ``30 * num_chains`` candidates drawn inside the convex hull.
        """
        num_chains = self.sampler.num_chains
        xt_init = self.sample_within_convex_hull(num_chains * 30)
        with torch.no_grad():
            # Use the density itself: ``sqrt_density`` is signed, so its argmax is not the mode.
            density_at_init = self.density(self.to_tensor(xt_init))
        # Plain int index: a (possibly CUDA) tensor cannot index a NumPy array.
        best = int(torch.argmax(density_at_init.reshape(-1)).item())
        xt_init = np.broadcast_to(xt_init[best], (num_chains, *xt_init.shape[1:]))
        xt_init = np.random.normal(xt_init, 0.01)

        samples_all = self.sampler.sample_chains(x_init=xt_init, burnin=burnin,
                                                 n_samples=burnin + num_samples)
        samples = samples_all[np.random.randint(samples_all.shape[0], size=num_samples), ...]
        assert samples.shape[0] == num_samples
        return self.to_tensor(samples)

    def _refine_rar(self, samples, num_samples):
        """Keep the ``num_samples`` points with the largest squared residual among the new
        ``samples`` and the current ``self.grid_samples``."""
        samples_cat = torch.cat([self.allow_input_derivatives(samples), self.grid_samples], dim=0)
        losses_cat = self.constraint_loss(samples_cat, get_losses=True).squeeze()
        _, idx = torch.topk(losses_cat, num_samples)
        return samples_cat[idx, ...]

    def _refine_ot_rar(self, samples, num_samples):
        """Map ``samples`` with a linear OT map learned from uniform points to the top quarter
        (by squared residual) of ``num_samples`` uniform search points."""
        k = num_samples // 4
        # The residual needs input derivatives, so the search points must require grad.
        samples_searchmax = self.allow_input_derivatives(
            self.sample_background(num_samples=num_samples, sampling_method=ST.uniform))
        losses = self.constraint_loss(samples_searchmax, get_losses=True)

        _, idx = torch.topk(losses.squeeze(), k=k)
        samples_topk = samples_searchmax[idx].detach()
        samples_uniform_learnmap = self.sample_background(num_samples=k, sampling_method=ST.uniform)
        Ae, be = ot.da.OT_mapping_linear(samples_uniform_learnmap, samples_topk)
        return samples @ Ae + be

    def do_sample(self, num_samples, sampling_method=ST.uniform, first=False) -> None:
        """Redraw ``self.grid_samples`` unless ``sampling_method`` is ``ST.none``.

        Adaptive refinement (``self.rar`` / ``self.ot_rar``) is skipped when ``first`` is True,
        since it needs an existing ``self.grid_samples`` or a meaningful residual.
        """
        if sampling_method != ST.none:
            self.grid_samples = self.sample_signal_domain(int(num_samples), sampling_method=sampling_method,
                                                          rar=self.rar if not first else False,
                                                          ot_rar=self.ot_rar if not first else False
                                                          )

    def reconstruction_loss(self, coords, target):
        """Data-fit terms against ``target`` (column 0: density, columns 1:: velocity).

        Velocity terms only use rows whose observed density is strictly positive.

        Returns:
            dict with ``density`` (MSE), ``nlogprob_density`` (Gaussian NLL), per-component
            velocity MSEs ``u``, ``v``, ``w`` and the multivariate-normal velocity NLL
            ``velocity``.
        """
        velocity_hat, density_hat = self.velocity(coords), self.sqrt_density(coords)
        density_target = target[..., [0]]

        density_dist = torch.distributions.Normal(loc=density_hat, scale=self.density_sd)

        # Boolean row selection (rows kept in their original order).
        mask = density_target.squeeze(-1) > 0
        velocity_hat_masked = velocity_hat[mask]
        vf_target_masked = target[..., 1:][mask]
        velocity_dist = torch.distributions.MultivariateNormal(loc=velocity_hat_masked,
                                                               scale_tril=self.velocity_cov_tril)
        loss_density = -density_dist.log_prob(density_target).mean()
        loss_velocity = -velocity_dist.log_prob(vf_target_masked).mean()

        mse_density = self.mse(input=density_hat, target=density_target)
        loss_u = self.mse(input=velocity_hat_masked[..., [0]], target=vf_target_masked[..., [0]])
        loss_v = self.mse(input=velocity_hat_masked[..., [1]], target=vf_target_masked[..., [1]])
        loss_w = self.mse(input=velocity_hat_masked[..., [2]], target=vf_target_masked[..., [2]])

        return dict(density=mse_density, nlogprob_density=loss_density,
                    u=loss_u, v=loss_v, w=loss_w, velocity=loss_velocity)

    def training_loss(self, coords, target, weight_dict: dict, n_samples_constraints=1024):
        """Weighted sum of the loss terms named in ``weight_dict``.

        The continuity loss is only evaluated when ``"pde_weight"`` is among the keys.
        ``n_samples_constraints`` is accepted for backward compatibility and is unused.

        Returns:
            ``(total_loss, loss_dict)`` where ``loss_dict`` holds every computed term.
        """
        rec_loss = self.reconstruction_loss(coords, target)
        if "pde_weight" in weight_dict:
            constraint_terms = self.constraint_loss()
        else:
            constraint_terms = {"pde_weight": 0., "hessian": 0}

        loss_dict = {**rec_loss, **constraint_terms}

        total_loss = 0.
        for key, weight in weight_dict.items():
            total_loss += weight * loss_dict[key]

        return total_loss, loss_dict

    @staticmethod
    def allow_input_derivatives(coords):
        """Return a detached leaf copy of ``coords`` that requires grad, so derivatives with
        respect to the inputs can be taken."""
        return coords.detach().requires_grad_(True)

    def estimate_gaussian_from_data(self, X: np.array, scale_cov: float = 5):
        """Fit a Gaussian to the rows of ``X``.

        The diagonal of the sample covariance is multiplied by ``scale_cov + 1``; off-diagonal
        entries are unchanged.

        Returns:
            ``(scipy_mvn, torch_mvn, torch_mvn_without_first_column)``; the last one is the
            marginal over columns 1: (the spatial part when column 0 is time).
        """
        mu_prior = np.average(X, 0)
        cov_prior = np.cov(X, rowvar=False) * (scale_cov * np.eye(X.shape[1]) + np.ones(X.shape[1]))

        prior_sp = multivariate_normal(mean=mu_prior, cov=cov_prior)
        prior_torch = torch.distributions.MultivariateNormal(
            loc=self.to_tensor(mu_prior, device=self.device),
            covariance_matrix=self.to_tensor(cov_prior, device=self.device))
        prior_torch_space = torch.distributions.MultivariateNormal(
            loc=self.to_tensor(mu_prior[1:], device=self.device),
            covariance_matrix=self.to_tensor(cov_prior[1:, 1:], device=self.device))
        return prior_sp, prior_torch, prior_torch_space

    def to_tensor(self, input, dtype=torch.float32, device=None, **kwargs):
        """Convert ``input`` to a tensor of ``dtype`` on ``device`` (default: the module's device)."""
        return to_tensor(input, dtype=dtype, device=self.device if device is None else device)
