import math
from abc import abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import pandas as pd
import pyro
import sbibm
import torch
from pyro import distributions as pdist
from pyro.distributions import (
    Categorical,
    MixtureSameFamily,
    MultivariateNormal,
    Uniform,
)
from sbibm.tasks.simulator import Simulator
from sbibm.utils.io import get_tensor_from_csv, save_tensor_to_csv
from sbibm.utils.pyro import (
    get_log_prob_fn,
    get_log_prob_grad_fn,
    make_log_prob_grad_fn,
)
from torchdiffeq import odeint


class Task:
    def __init__(
        self,
        dim_data: int,
        dim_parameters: int,
        name: str,
        num_observations: int,
        num_posterior_samples: int,
        num_simulations: List[int],
        path: Path,
        name_display: Optional[str] = None,
        num_reference_posterior_samples: Optional[int] = None,
        observation_seeds: Optional[List[int]] = None,
    ):
        """Base class for tasks.

        Args:
            dim_data: Dimensionality of data.
            dim_parameters: Dimensionality of parameters.
            name: Name of task. Should be the name of the folder in which
                the task is stored. Used with `sbibm.get_task(name)`.
            num_observations: Number of different observations for this task.
            num_posterior_samples: Number of posterior samples to generate
                (a single count, not a list).
            num_simulations: List containing number of different simulations to
                run this task for.
            path: Path to folder of task.
            name_display: Display name of task, with correct upper/lower-case
                spelling and spaces. Defaults to `name`.
            num_reference_posterior_samples: Number of reference posterior samples
                to generate for this task. Defaults to `num_posterior_samples`.
            observation_seeds: List of observation seeds to use. Defaults to
                `1000000, 1000001, ...` (one seed per observation). Override to
                use specific seeds.
        """
        self.dim_data = dim_data
        self.dim_parameters = dim_parameters
        self.name = name
        self.num_observations = num_observations
        self.num_posterior_samples = num_posterior_samples
        self.num_simulations = num_simulations
        self.path = path

        # Optional settings fall back to values derived from the required ones.
        if name_display is None:
            name_display = name
        self.name_display = name_display

        if num_reference_posterior_samples is None:
            num_reference_posterior_samples = num_posterior_samples
        self.num_reference_posterior_samples = num_reference_posterior_samples

        if observation_seeds is None:
            observation_seeds = [i + 1000000 for i in range(self.num_observations)]
        self.observation_seeds = observation_seeds

    @abstractmethod
    def get_prior(self) -> Callable:
        """Return a callable that yields parameters drawn from the prior.

        Subclasses must override this; the base implementation only raises.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def get_prior_dist(self) -> torch.distributions.Distribution:
        """Return the prior distribution stored on `self.prior_dist`.

        The base class never assigns `prior_dist`; subclasses are expected to
        set it in their constructor.
        """
        return self.prior_dist

    def get_prior_params(self) -> Dict[str, torch.Tensor]:
        """Return the prior's parameter dictionary stored on `self.prior_params`.

        The base class never assigns `prior_params`; subclasses are expected to
        set it in their constructor.
        """
        return self.prior_params

    def get_labels_data(self) -> List[str]:
        """Get list containing data labels (`data_1` ... `data_<dim_data>`)"""
        return [f"data_{index}" for index in range(1, self.dim_data + 1)]

    def get_labels_parameters(self) -> List[str]:
        """Get list containing parameter labels (`parameter_1` ... `parameter_<dim_parameters>`)"""
        return [f"parameter_{index}" for index in range(1, self.dim_parameters + 1)]

    def get_observation(self, num_observation: int) -> torch.Tensor:
        """Load the stored observation for a given observation number.

        Reads `<path>/<name>/num_observation_<num_observation>/observation.csv`
        through `get_tensor_from_csv`.
        """
        observation_dir = self.path / self.name / f"num_observation_{num_observation}"
        return get_tensor_from_csv(observation_dir / "observation.csv")

    def get_reference_posterior_samples(self, num_observation: int) -> torch.Tensor:
        """Load the stored reference posterior samples for an observation number.

        Reads the bz2-suffixed file
        `<path>/<name>/num_observation_<num_observation>/reference_posterior_samples.csv.bz2`
        through `get_tensor_from_csv`.
        """
        observation_dir = self.path / self.name / f"num_observation_{num_observation}"
        return get_tensor_from_csv(
            observation_dir / "reference_posterior_samples.csv.bz2"
        )

    @abstractmethod
    def get_simulator(self) -> Callable:
        """Return a callable that simulates data for given parameters.

        Subclasses must override this; the base implementation only raises.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def get_true_parameters(self, num_observation: int) -> torch.Tensor:
        """Load the true parameters (those that generated the data) for an observation number.

        Reads `<path>/<name>/num_observation_<num_observation>/true_parameters.csv`
        through `get_tensor_from_csv`.
        """
        observation_dir = self.path / self.name / f"num_observation_{num_observation}"
        return get_tensor_from_csv(observation_dir / "true_parameters.csv")

    def save_data(self, path: Union[str, Path], data: torch.Tensor):
        """Write `data` to `path` via `save_tensor_to_csv`, passing the data labels"""
        column_labels = self.get_labels_data()
        save_tensor_to_csv(path, data, column_labels)

    def save_parameters(self, path: Union[str, Path], parameters: torch.Tensor):
        """Write `parameters` to `path` via `save_tensor_to_csv`, passing the parameter labels"""
        column_labels = self.get_labels_parameters()
        save_tensor_to_csv(path, parameters, column_labels)

    def flatten_data(self, data: torch.Tensor) -> torch.Tensor:
        """Flatten data into a 2D tensor with `dim_data` columns.

        Data returned by the simulator is always flattened into 2D Tensors.
        """
        flat_shape = (-1, self.dim_data)
        return data.reshape(flat_shape)

    def unflatten_data(self, data: torch.Tensor) -> torch.Tensor:
        """Unflatten data.

        The base implementation simply reshapes to 2D with `dim_data` columns.
        Tasks that require more than 2 dimensions for output of the simulator
        (e.g. returning images) may override this method.
        """
        target_shape = (-1, self.dim_data)
        return data.reshape(target_shape)

    def _get_log_prob_fn(
        self,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
        posterior: bool = True,
        implementation: str = "pyro",
        **kwargs: Any,
    ) -> Callable:
        """Build a function evaluating the unnormalized log probability of the
        posterior or likelihood

        Args:
            num_observation: Observation number
            observation: Instead of passing an observation number, an observation may be
                passed directly
            posterior: If False, will get likelihood instead of posterior
            implementation: Implementation to use, `pyro` or `experimental`
            kwargs: Additional keywords passed to `sbibm.utils.pyro.get_log_prob_fn`

        Returns:
            `log_prob_fn` that returns log probabilities as `batch_size`

        Raises:
            NotImplementedError: If `implementation` is neither `pyro` nor
                `experimental` (checked after the underlying function is built)
        """
        # Exactly one of `num_observation` / `observation` has to be given.
        assert num_observation is not None or observation is not None
        assert num_observation is None or observation is None
        assert type(posterior) is bool

        model = self._get_pyro_model(
            num_observation=num_observation,
            observation=observation,
            posterior=posterior,
        )
        base_log_prob_fn, _ = get_log_prob_fn(
            model, implementation=implementation, **kwargs
        )

        def log_prob_pyro(parameters):
            assert parameters.ndim == 2

            if parameters.shape[0] == 1:
                return base_log_prob_fn({"parameters": parameters})

            # Larger batches are evaluated one row at a time and concatenated.
            return torch.cat(
                [
                    base_log_prob_fn({"parameters": row.reshape(1, -1)})
                    for row in parameters
                ]
            )

        def log_prob_experimental(parameters):
            return base_log_prob_fn({"parameters": parameters})

        by_implementation = {
            "pyro": log_prob_pyro,
            "experimental": log_prob_experimental,
        }
        if implementation not in by_implementation:
            raise NotImplementedError
        return by_implementation[implementation]

    def _get_log_prob_grad_fn(
        self,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
        posterior: bool = True,
        implementation: str = "pyro",
        **kwargs: Any,
    ) -> Callable:
        """Build a function evaluating the gradient of the unnormalized log
        probability with respect to the parameters

        Args:
            num_observation: Observation number
            observation: Instead of passing an observation number, an observation may be
                passed directly
            posterior: If False, will get likelihood instead of posterior
            implementation: Implementation to use; only `pyro` is accepted
            kwargs: Passed to `sbibm.utils.pyro.get_log_prob_grad_fn`

        Returns:
            `log_prob_grad_fn` that returns gradients as `batch_size` x
            `dim_parameter`
        """
        # Exactly one of `num_observation` / `observation` has to be given.
        assert num_observation is not None or observation is not None
        assert num_observation is None or observation is None
        assert type(posterior) is bool
        assert implementation == "pyro"

        model = self._get_pyro_model(
            num_observation=num_observation,
            observation=observation,
            posterior=posterior,
        )
        base_grad_fn, _ = get_log_prob_grad_fn(
            model, implementation=implementation, **kwargs
        )

        def log_prob_grad_pyro(parameters):
            assert parameters.ndim == 2

            batch_size, dim = parameters.shape
            if batch_size == 1:
                grads, _ = base_grad_fn({"parameters": parameters})
                return grads["parameters"].reshape(batch_size, dim)

            # Larger batches are differentiated one row at a time.
            per_row = []
            for row in parameters:
                grad, _ = base_grad_fn({"parameters": row.reshape(1, -1)})
                per_row.append(grad["parameters"].squeeze())
            return torch.stack(per_row).reshape(batch_size, dim)

        # Only reachable with another value when assertions are disabled.
        if implementation != "pyro":
            raise NotImplementedError
        return log_prob_grad_pyro

    def _get_transforms(
        self,
        automatic_transforms_enabled: bool = True,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Gets transforms

        Args:
            automatic_transforms_enabled: If True, will automatically construct
                transforms to unconstrained space
            num_observation: Observation number
            observation: Instead of passing an observation number, an observation may be
                passed directly
            kwargs: Accepted but not used by this implementation

        Returns:
            Dict containing transforms
        """
        model = self._get_pyro_model(
            num_observation=num_observation, observation=observation
        )

        # Only the second element returned by `get_log_prob_fn` is needed here.
        _log_prob_fn, transforms = get_log_prob_fn(
            model,
            automatic_transform_enabled=automatic_transforms_enabled,
        )
        return transforms

    def _get_observation_seed(self, num_observation: int) -> int:
        """Read back the seed stored for a given observation number.

        Loads `<path>/<name>/num_observation_<num_observation>/observation_seed.csv`
        and returns the first entry of its `observation_seed` column as an int.
        """
        observation_dir = self.path / self.name / f"num_observation_{num_observation}"
        seed_column = pd.read_csv(observation_dir / "observation_seed.csv")[
            "observation_seed"
        ]
        return int(seed_column[0])

    def _get_pyro_model(
        self,
        posterior: bool = True,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
    ) -> Callable:
        """Get model function for use with Pyro

        If `num_observation` or `observation` is passed, the model is conditioned
        on that observation at the `data` site; otherwise the unconditioned model
        is returned.

        Args:
            posterior: If False, will mask prior which will result in model useful
                for calculating log likelihoods instead of log posterior probabilities
            num_observation: Observation number
            observation: Instead of passing an observation number, an observation may be
                passed directly
        """
        # At most one way of specifying the observation may be used.
        assert num_observation is None or observation is None

        if num_observation is not None:
            observation = self.get_observation(num_observation=num_observation)

        prior = self.get_prior()
        simulator = self.get_simulator()

        def model_fn():
            # The mask is the `posterior` flag itself: False masks the prior.
            masked_prior = pyro.poutine.mask(prior, torch.tensor(posterior))
            parameters = masked_prior()
            return simulator(parameters)

        if observation is None:
            return model_fn

        conditioning = {"data": self.unflatten_data(observation)}
        return pyro.condition(model_fn, conditioning)

    @abstractmethod
    def _sample_reference_posterior(
        self,
        num_samples: int,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Draw samples from the reference posterior for a given observation

        Subclasses must override this; the base implementation only raises.

        Args:
            num_samples: Number of samples
            num_observation: Observation number
            observation: Instead of passing an observation number, an observation may be
                passed directly

        Returns:
            Samples from reference posterior

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def _save_observation_seed(self, num_observation: int, observation_seed: int):
        """Persist the seed used for a given observation number.

        Writes a one-row CSV with columns `observation_seed` and
        `num_observation` to
        `<path>/<name>/num_observation_<num_observation>/observation_seed.csv`,
        creating the folder if needed.
        """
        observation_dir = self.path / self.name / f"num_observation_{num_observation}"
        seed_file = observation_dir / "observation_seed.csv"
        seed_file.parent.mkdir(parents=True, exist_ok=True)

        record = {
            "observation_seed": [int(observation_seed)],
            "num_observation": [int(num_observation)],
        }
        pd.DataFrame(record).to_csv(seed_file, index=False)

    def _save_observation(self, num_observation: int, observation: torch.Tensor):
        """Persist the observed data for a given observation number.

        Targets `<path>/<name>/num_observation_<num_observation>/observation.csv`,
        creating the folder if needed, and delegates the writing to `save_data`.
        """
        observation_dir = self.path / self.name / f"num_observation_{num_observation}"
        target = observation_dir / "observation.csv"
        target.parent.mkdir(parents=True, exist_ok=True)
        self.save_data(target, observation)

    def _save_reference_posterior_samples(
        self, num_observation: int, reference_posterior_samples: torch.Tensor
    ):
        """Persist reference posterior samples for a given observation number.

        Targets
        `<path>/<name>/num_observation_<num_observation>/reference_posterior_samples.csv.bz2`,
        creating the folder if needed, and delegates the writing to
        `save_parameters`.
        """
        observation_dir = self.path / self.name / f"num_observation_{num_observation}"
        target = observation_dir / "reference_posterior_samples.csv.bz2"
        target.parent.mkdir(parents=True, exist_ok=True)
        self.save_parameters(target, reference_posterior_samples)

    def _save_true_parameters(
        self, num_observation: int, true_parameters: torch.Tensor
    ):
        """Persist the true parameters (those that generated the data) for an observation number.

        Targets `<path>/<name>/num_observation_<num_observation>/true_parameters.csv`,
        creating the folder if needed, and delegates the writing to
        `save_parameters`.
        """
        observation_dir = self.path / self.name / f"num_observation_{num_observation}"
        target = observation_dir / "true_parameters.csv"
        target.parent.mkdir(parents=True, exist_ok=True)
        self.save_parameters(target, true_parameters)

    def _setup(self, n_jobs: int = -1, create_reference: bool = True, **kwargs: Any):
        """Setup the task: generate observations and reference posterior samples

        In most cases, you don't need to execute this method, since its results are stored to disk.

        Re-executing will overwrite existing files.

        Observations are processed sequentially, in the order of
        `self.observation_seeds`, numbered from 1.

        Args:
            n_jobs: Currently unused. Kept so existing callers keep working; the
                parallel execution path is disabled.
            create_reference: If False, skips reference creation
            kwargs: Forwarded to `_sample_reference_posterior`
        """

        def run(num_observation, observation_seed, **kwargs):
            # Seed both RNGs so each observation is reproducible from its seed.
            np.random.seed(observation_seed)
            torch.manual_seed(observation_seed)
            self._save_observation_seed(num_observation, observation_seed)

            prior = self.get_prior()
            true_parameters = prior(num_samples=1)
            self._save_true_parameters(num_observation, true_parameters)

            simulator = self.get_simulator()
            observation = simulator(true_parameters)
            self._save_observation(num_observation, observation)

            if not create_reference:
                return

            reference_posterior_samples = self._sample_reference_posterior(
                num_observation=num_observation,
                num_samples=self.num_reference_posterior_samples,
                **kwargs,
            )
            # Sanity check: the requested number of samples, with no duplicates.
            num_unique = torch.unique(reference_posterior_samples, dim=0).shape[0]
            assert num_unique == self.num_reference_posterior_samples
            self._save_reference_posterior_samples(
                num_observation,
                reference_posterior_samples,
            )

        for num_observation, observation_seed in enumerate(
            self.observation_seeds, start=1
        ):
            run(num_observation, observation_seed, **kwargs)


class TwoMoons(Task):
    def __init__(self, p_dist: Optional[pdist.Distribution] = None):
        """Two Moons

        Args:
            p_dist: Optional prior distribution to use instead of the default.
                If omitted, a uniform prior on [-1, 1] for each parameter is
                used.
        """

        # Observation seeds to use when generating ground truth
        observation_seeds = [
            1000011,  # observation 1
            1000001,  # observation 2
            1000002,  # observation 3
            1000003,  # observation 4
            1000013,  # observation 5
            1000005,  # observation 6
            1000006,  # observation 7
            1000007,  # observation 8
            1000008,  # observation 9
            1000009,  # observation 10
        ]

        super().__init__(
            dim_parameters=2,
            dim_data=2,
            name="two_moons",
            name_display="Two Moons",
            num_observations=10,
            num_posterior_samples=10000,
            num_reference_posterior_samples=10000,
            num_simulations=[100, 1000, 10000, 100000, 1000000],
            observation_seeds=observation_seeds,
            path=Path(__file__).parent.absolute(),
        )

        if p_dist is None:
            prior_bound = 1.0
            self.prior_params = {
                "low": -prior_bound * torch.ones((self.dim_parameters,)),
                "high": +prior_bound * torch.ones((self.dim_parameters,)),
            }
            self.prior_dist = pdist.Uniform(**self.prior_params).to_event(1)
            # NOTE(review): `set_default_validate_args` is a static setter on
            # torch's Distribution class, so this presumably changes the default
            # for distributions created afterwards rather than for this
            # instance only - confirm that is the intent.
            self.prior_dist.set_default_validate_args(False)
        else:
            self.prior_dist = p_dist
            # No parameter dictionary is known for a caller-supplied prior.
            # Define the attribute anyway so `get_prior_params` returns an
            # empty dict instead of failing with AttributeError.
            self.prior_params = {}

        self.simulator_params = {
            "a_low": -math.pi / 2.0,
            "a_high": +math.pi / 2.0,
            "base_offset": 0.25,
            "r_loc": 0.1,
            "r_scale": 0.01,
        }

    def get_prior(self) -> Callable:
        """Return a callable sampling the `parameters` site from the prior.

        The returned function takes `num_samples` (default 1) and draws from
        `self.prior_dist` expanded by that many samples.
        """

        def prior(num_samples=1):
            batched_prior = self.prior_dist.expand_by([num_samples])
            return pyro.sample("parameters", batched_prior)

        return prior

    def get_simulator(self, max_calls: Optional[int] = None) -> Simulator:
        """Get function returning samples from simulator given parameters

        For each parameter row, an angle `a` is drawn uniformly between `a_low`
        and `a_high` and a radius `r` from a normal with `r_loc` / `r_scale`.
        The point `(r cos a + base_offset, r sin a)` is then shifted by
        `_map_fun`.

        Args:
            max_calls: Maximum number of function calls. Additional calls will
                result in SimulationBudgetExceeded exceptions. Defaults to None
                for infinite budget

        Return:
            Simulator callable
        """

        def simulator(parameters):
            settings = self.simulator_params
            # One angle and one radius per parameter row.
            per_row_shape = (parameters.shape[0], 1)

            angle_dist = pdist.Uniform(low=settings["a_low"], high=settings["a_high"])
            a = angle_dist.expand_by(per_row_shape).to_event(1).sample()

            radius_dist = pdist.Normal(settings["r_loc"], settings["r_scale"])
            r = radius_dist.expand_by(per_row_shape).to_event(1).sample()

            first_coordinate = torch.cos(a) * r + settings["base_offset"]
            second_coordinate = torch.sin(a) * r
            p = torch.cat((first_coordinate, second_coordinate), dim=1)

            return self._map_fun(parameters, p)

        return Simulator(task=self, simulator=simulator, max_calls=max_calls)

    @staticmethod
    def _map_fun(parameters: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        """Shift `p` by the parameter-dependent offset `(-|z0|, z1)`.

        `(z0, z1)` are the first two parameter columns rotated by -pi/4.
        """
        angle = torch.tensor([-math.pi / 4.0])
        cos_a, sin_a = torch.cos(angle), torch.sin(angle)
        first, second = parameters[:, 0], parameters[:, 1]

        rotated_0 = (cos_a * first - sin_a * second).reshape(-1, 1)
        rotated_1 = (sin_a * first + cos_a * second).reshape(-1, 1)

        offset = torch.cat((-torch.abs(rotated_0), rotated_1), dim=1)
        return p + offset

    @staticmethod
    def _map_fun_inv(parameters: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Undo `_map_fun`: subtract the parameter-dependent offset `(-|z0|, z1)` from `x`.

        `(z0, z1)` are the first two parameter columns rotated by -pi/4.
        """
        angle = torch.tensor([-math.pi / 4.0])
        cos_a, sin_a = torch.cos(angle), torch.sin(angle)
        first, second = parameters[:, 0], parameters[:, 1]

        rotated_0 = (cos_a * first - sin_a * second).reshape(-1, 1)
        rotated_1 = (sin_a * first + cos_a * second).reshape(-1, 1)

        offset = torch.cat((-torch.abs(rotated_0), rotated_1), dim=1)
        return x - offset

    def _likelihood(
        self,
        parameters: torch.Tensor,
        data: torch.Tensor,
        log: bool = True,
    ) -> torch.Tensor:
        """Evaluate the (log-)likelihood of `data` for each row of `parameters`.

        The data is mapped back through `_map_fun_inv`; the radius of the
        resulting point around `(base_offset, 0)` is scored under a normal
        density with `r_loc` / `r_scale`. Points whose first coordinate lies
        left of `base_offset` get log-likelihood `-inf`.

        NOTE(review): unlike the density used in `_sample_reference_posterior`,
        this expression has no `1 / (pi * r)` factor - confirm which of the two
        is intended.

        Args:
            parameters: Parameters; a 1D tensor is treated as a single row
            data: Data with `dim_data` columns
            log: If True returns log-likelihoods, otherwise likelihoods

        Returns:
            One value per parameter row
        """
        if parameters.ndim == 1:
            parameters = parameters.reshape(1, -1)

        assert parameters.shape[1] == self.dim_parameters
        assert data.shape[1] == self.dim_data

        residual = self._map_fun_inv(parameters, data).squeeze(0)
        if residual.ndim == 1:
            # A single row lost its batch dimension in the squeeze above.
            residual = residual.reshape(1, -1)
        u = residual[:, 0] - self.simulator_params["base_offset"]
        v = residual[:, 1]

        r_loc = self.simulator_params["r_loc"]
        r_scale = self.simulator_params["r_scale"]

        radius = torch.sqrt(u**2 + v**2)
        log_normalizer = 0.5 * torch.log(2 * torch.tensor([math.pi]) * r_scale**2)
        L = -0.5 * ((radius - r_loc) / r_scale) ** 2 - log_normalizer

        # Zero likelihood on the side the simulator cannot reach.
        outside = torch.where(u < 0.0)[0]
        if len(outside) > 0:
            L[outside] = -torch.tensor(math.inf)

        if log:
            return L
        return torch.exp(L)

    def _get_transforms(
        self,
        *args,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Return an identity transform for the `parameters` site.

        All arguments are accepted and ignored.
        """
        identity = torch.distributions.transforms.identity_transform
        parameter_transform = torch.distributions.transforms.IndependentTransform(
            identity, 1
        )
        return {"parameters": parameter_transform}

    def _get_log_prob_fn(
        self,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
        posterior: bool = True,
        **kwargs: Any,
    ) -> Callable:
        """Get function returning the unnormalized log posterior or log likelihood

        The returned function adds the analytic log-likelihood from
        `_likelihood` to the log-density of the prior, and is useful to
        establish and verify the reference posterior.

        Args:
            num_observation: Observation number
            observation: Instead of passing an observation number, an observation may be
                passed directly
            posterior: If False, the prior term is left out so that the function
                returns the log likelihood instead of the log posterior
            kwargs: Accepted for interface compatibility and ignored

        Returns:
            `log_prob_fn` taking either a parameter tensor or a dict with a
            `parameters` entry
        """
        assert not (num_observation is None and observation is None)
        assert not (num_observation is not None and observation is not None)

        prior_dist = self.get_prior_dist()

        if num_observation is not None:
            observation = self.get_observation(num_observation=num_observation)

        observation = self.unflatten_data(observation)

        def log_prob_fn(parameters):
            if isinstance(parameters, dict):
                parameters = parameters["parameters"]
            log_prob = self._likelihood(
                parameters=parameters, data=observation, log=True
            )
            if posterior:
                log_prob = log_prob + prior_dist.log_prob(parameters)
            return log_prob

        return log_prob_fn

    def _get_log_prob_grad_fn(
        self,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Callable:
        """Get function returning gradients of the log probability

        Wraps `_get_log_prob_fn` with `make_log_prob_grad_fn` and evaluates the
        result one parameter row at a time.

        Args:
            num_observation: Observation number
            observation: Instead of passing an observation number, an observation may be
                passed directly
            kwargs: Forwarded to `_get_log_prob_fn`

        Returns:
            `log_prob_grad_fn`. For a single row it returns that row's gradient
            unchanged; for several rows the gradients are concatenated and
            reshaped to a single row.
        """
        lpgf = make_log_prob_grad_fn(
            self._get_log_prob_fn(
                num_observation=num_observation, observation=observation, **kwargs
            )
        )

        def log_prob_grad_fn(parameters):
            if parameters.shape[0] == 0:
                # Previously this fell through to an unbound local variable.
                raise ValueError("`parameters` must contain at least one row")

            # Only the second element returned by `lpgf` is used.
            grads = [lpgf({"parameters": row})[1] for row in parameters]
            if len(grads) == 1:
                return grads[0]
            # NOTE(review): this yields shape (1, num_rows * dim) rather than
            # the `batch_size` x `dim_parameter` layout documented on the base
            # class - kept as is, confirm with callers before changing.
            return torch.cat(grads).reshape(1, -1)

        return log_prob_grad_fn

    def _sample_reference_posterior(
        self,
        num_samples: int,
        observation: Optional[torch.Tensor] = None,  # shape (1,2)
        num_observation: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Rejection-sampler in theta-space using prior proposal and likelihood acceptance,
        yielding "partial-moon" shapes under truncated Gaussian or mixture priors.

        Steps:
          1. propose theta from the prior
          2. compute likelihood p(x|theta)
          3. accept with probability p(x|theta)/M where M bounds the likelihood

        Only the theta-dependent part of the likelihood,
        `exp(-(rho - mu_r)^2 / (2 sigma^2)) / rho`, enters the acceptance ratio;
        constant factors cancel. M is that expression at its interior maximum
        `rho* = (mu_r + sqrt(mu_r^2 - 4 sigma^2)) / 2`, which lies slightly
        below `mu_r`; evaluating at `mu_r` itself underestimates the bound. The
        expression also grows without limit as `rho -> 0`, but with the task's
        settings it only exceeds M for `rho` far below float resolution.

        Args:
            num_samples: Number of samples
            observation: Observation holding exactly two values, e.g. shape (1, 2)
            num_observation: Observation number; the stored observation is
                loaded. Pass either this or `observation`, not both

        Returns:
            Accepted parameters stacked along the first dimension

        Raises:
            ValueError: If not exactly one of `num_observation` / `observation`
                is given, if the observation does not hold two values, or if
                the simulator settings admit no interior maximum
        """
        if (num_observation is None) == (observation is None):
            raise ValueError(
                "Pass exactly one of `num_observation` or `observation`."
            )
        if num_observation is not None:
            observation = self.get_observation(num_observation=num_observation)

        # Same constants the simulator uses, instead of duplicated literals.
        sigma = self.simulator_params["r_scale"]
        mu_r = self.simulator_params["r_loc"]
        base_offset = self.simulator_params["base_offset"]
        rot = 1 / math.sqrt(2)

        def radial_density(rho: float) -> float:
            # Theta-dependent part of p(x|theta); see docstring.
            return math.exp(-((rho - mu_r) ** 2) / (2 * sigma**2)) / rho

        discriminant = mu_r**2 - 4 * sigma**2
        if discriminant <= 0:
            raise ValueError(
                "No interior likelihood maximum: requires r_loc > 2 * r_scale."
            )
        rho_star = 0.5 * (mu_r + math.sqrt(discriminant))
        M_val = radial_density(rho_star)

        # Loop-invariant: the two observed coordinates.
        x_obs, y_obs = observation.reshape(-1).tolist()

        samples = []
        while len(samples) < num_samples:
            # 1) propose theta from the prior
            theta = self.prior_dist.sample(())  # shape (2,)

            # 2) deterministic part of the forward map (rotation by -pi/4)
            t1, t2 = theta[0].item(), theta[1].item()
            m_x = -abs(t1 + t2) * rot
            m_y = (-t1 + t2) * rot

            # 3) residual of the observation around the crescent centre
            u_x = x_obs - (m_x + base_offset)
            u_y = y_obs - m_y
            rho = math.hypot(u_x, u_y)
            if rho <= 0:
                continue

            # 4) half-circle support: angle within [-pi/2, pi/2], i.e. u_x >= 0
            if u_x < 0:
                continue

            # 5) + 6) accept with probability p(x|theta) / M
            if torch.rand(()).item() < radial_density(rho) / M_val:
                samples.append(theta)

        return torch.stack(samples, dim=0)


class SLCP(Task):
    def __init__(
        self,
        p_dist: Optional[pdist.Distribution] = None,
        distractors: bool = False,
    ):
        """SLCP task, optionally with distractor dimensions.

        Args:
            p_dist: Prior distribution over the 5 parameters. If None, an
                independent uniform prior on [-1, 1] per parameter is used.
            distractors: If True, the simulator output is the 8 simulated
                values concatenated with noise and permuted (100 dimensions
                in total); otherwise the output is the simulated data only.
        """
        self.num_data = 4
        self.distractors = distractors

        if not self.distractors:
            dim_data = 2 * self.num_data
            name = "slcp"
            name_display = "SLCP"
        else:
            dim_data = 100
            name = "slcp_distractors"
            name_display = "SLCP Distractors"

        # Observation seeds to use when generating ground truth
        # Avoiding extremely spiked posteriors, e.g., 1000006, 1000007, ...
        observation_seeds = [
            1000000,  # observation 1
            1000001,  # observation 2
            1000002,  # observation 3
            1000003,  # observation 4
            1000004,  # observation 5
            1000005,  # observation 6
            1000010,  # observation 7
            1000012,  # observation 8
            1000008,  # observation 9
            1000009,  # observation 10
        ]

        super().__init__(
            dim_parameters=5,
            dim_data=dim_data,
            name=name,
            name_display=name_display,
            num_observations=10,
            num_posterior_samples=10000,
            num_reference_posterior_samples=10000,
            num_simulations=[1000, 10000, 100000, 1000000],
            path=Path(__file__).parent.absolute(),
            observation_seeds=observation_seeds,
        )

        if p_dist is None:
            self.prior_params = {
                "low": torch.tensor([-1.0 for _ in range(self.dim_parameters)]),
                "high": torch.tensor([+1.0 for _ in range(self.dim_parameters)]),
            }
            self.prior_dist = pdist.Uniform(**self.prior_params).to_event(1)
        else:
            self.prior_dist = p_dist
        self.prior_dist.set_default_validate_args(False)

        # Affine map applied to the parameters at the start of the simulator.
        self.theta_shift = 0.0
        self.theta_rescale = 3.0
        # Standardisation constants for the (non-distractor) simulator output,
        # calculated from 10 million simulations from uniform prior parameters
        self.x_mean = 0.0
        self.x_std = 4.38

    def get_prior(self) -> Callable:
        """Get function returning `num_samples` draws from the prior"""

        def prior(num_samples=1):
            return pyro.sample("parameters", self.prior_dist.expand_by([num_samples]))

        return prior

    def get_simulator(self, max_calls: Optional[int] = None) -> Simulator:
        """Get function returning samples from simulator given parameters

        Args:
            max_calls: Maximum number of function calls. Additional calls will
                result in SimulationBudgetExceeded exceptions. Defaults to None
                for infinite budget

        Return:
            Simulator callable
        """

        def simulator(parameters):
            parameters = parameters * self.theta_rescale + self.theta_shift
            num_samples = parameters.shape[0]

            # Mean of the 2-d Gaussian, shape (num_samples, 2). Plain slicing
            # keeps the batch dimension, so a batch of one needs no special
            # handling (and avoids `.T` on a 1-d tensor).
            m = parameters[:, :2]

            s1 = parameters[:, 2] ** 2
            s2 = parameters[:, 3] ** 2
            # Correlation, kept strictly inside (-1, 1)
            rho = torch.tanh(parameters[:, 4]).clamp(-1 + 1e-6, 1 - 1e-6)

            S = torch.empty((num_samples, 2, 2))
            S[:, 0, 0] = s1**2
            S[:, 0, 1] = rho * s1 * s2
            S[:, 1, 0] = rho * s1 * s2
            S[:, 1, 1] = s2**2

            # Add eps to diagonal to ensure PSD
            eps = 0.000001
            S[:, 0, 0] += eps
            S[:, 1, 1] += eps

            # One Gaussian per parameter set, drawn `num_data` times
            data_dist = pdist.MultivariateNormal(
                m.unsqueeze(1).float(), S.unsqueeze(1).float()
            ).expand((num_samples, self.num_data))

            if not self.distractors:
                data = pyro.sample("data", data_dist)
                return (data - self.x_mean) / self.x_std

            data = pyro.sample("data", data_dist).reshape(
                (num_samples, 2 * self.num_data)
            )

            # Noise distribution and permutation are stored artifacts of the
            # task (see `_generate_noise_dist_parameters`).
            gmm = torch.load(self.path / "files" / "gmm.torch")
            noise = gmm.sample((num_samples,)).type(data.dtype)

            data_and_noise = torch.cat([data, noise], dim=1)

            permutation_idx = torch.load(self.path / "files" / "permutation_idx.torch")

            return data_and_noise[:, permutation_idx]

        return Simulator(task=self, simulator=simulator, max_calls=max_calls)

    def get_observation(self, num_observation: int) -> torch.Tensor:
        """Get observed data for a given observation number"""
        if self.distractors:
            folder, filename = "files", "observation_distractors.csv"
        else:
            folder, filename = self.name, "observation.csv"

        path = self.path / folder / f"num_observation_{num_observation}" / filename
        return get_tensor_from_csv(path)

    def _get_transforms(
        self,
        automatic_transforms_enabled: bool = True,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Gets transforms

        Args:
            num_observation: Observation number
            observation: Instead of passing an observation number, an observation may be
                passed directly
            automatic_transforms_enabled: If True, will automatically construct
                transforms to unconstrained space

        Returns:
            Dict containing transforms
        """
        if not self.distractors:
            return super()._get_transforms(
                automatic_transforms_enabled=automatic_transforms_enabled,
                num_observation=num_observation,
                observation=observation,
                **kwargs,
            )

        # The distractor variant delegates to the plain task. Hand over this
        # task's prior so that a custom `p_dist` is not silently replaced by
        # the default uniform prior.
        task = SLCP(p_dist=self.prior_dist, distractors=False)
        return task._get_transforms(
            automatic_transforms_enabled=automatic_transforms_enabled,
            num_observation=num_observation,
            observation=observation,
            **kwargs,
        )

    def unflatten_data(self, data: torch.Tensor) -> torch.Tensor:
        """Unflattens data into multiple observations"""
        if not self.distractors:
            return data.reshape(-1, self.num_data, 2)
        else:
            raise NotImplementedError

    def _sample_reference_posterior(
        self,
        num_samples: int,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
        method: str = "rejection",
    ) -> torch.Tensor:
        """Sample reference posterior for given observation

        A first SIR run provides samples to which a flow-based proposal is
        fitted; the final samples are then drawn with `method` using that
        proposal.

        Args:
            num_observation: Observation number
            num_samples: Number of samples to generate
            observation: Observed data, if None, will be loaded using `num_observation`
            method: Final sampling scheme, either "rejection" or "sir"

        Returns:
            Samples from reference posterior

        Raises:
            ValueError: If `method` is neither "rejection" nor "sir"
        """
        # Reject an unsupported method before the expensive proposal
        # construction below instead of after it.
        if method not in ("rejection", "sir"):
            raise ValueError(f"Method {method} not supported")

        from sbibm.algorithms.pytorch.baseline_rejection import run as run_rejection
        from sbibm.algorithms.pytorch.utils.proposal import get_proposal

        from priorg.sim.tasks.sir import run_sir

        proposal_samples, _ = run_sir(
            task=self,
            num_observation=num_observation,
            observation=observation,
            num_samples=10000,
            num_simulations=100_000_000,
            batch_size=100_000,
        )

        proposal_dist = get_proposal(
            task=self,
            samples=proposal_samples,
            prior_weight=0.1,
            bounded=False,
            density_estimator="flow",
            flow_model="maf",
        )

        if method == "rejection":
            return run_rejection(
                task=self,
                num_observation=num_observation,
                observation=observation,
                num_samples=num_samples,
                batch_size=10_000,
                num_batches_without_new_max=1_000,
                multiplier_M=1.2,
                proposal_dist=proposal_dist,
            )

        # method == "sir": `run_sir` yields a pair (unpacked the same way
        # above); only the samples belong in the returned tensor.
        samples, _ = run_sir(
            task=self,
            num_observation=num_observation,
            observation=observation,
            num_samples=num_samples,
            num_simulations=10_000_000,
            batch_size=10_000,
            proposal_dist=proposal_dist,
        )
        return samples

    def _generate_noise_dist_parameters(self):
        """Create and store the distractor noise artifacts.

        Writes the noise mixture (`gmm.torch`) and the column permutation
        (`permutation_idx.torch`) into the task's `files` folder, then saves
        one permuted observation-plus-noise CSV per observation.

        Side effects: seeds the global NumPy and torch RNGs with 42 and sets
        `self.dim_data` to the distractor dimensionality.
        """
        noise_dim = 92
        n_noise_comps = 20

        rng = np.random
        rng.seed(42)

        loc = torch.from_numpy(
            np.array([15 * rng.normal(size=noise_dim) for _ in range(n_noise_comps)])
        )

        # Lower-triangular factors with a strictly positive term added on the
        # diagonal
        cholesky_factors = [
            np.tril(rng.normal(size=(noise_dim, noise_dim)))
            + np.diag(np.exp(rng.normal(size=noise_dim)))
            for _ in range(n_noise_comps)
        ]
        scale_tril = torch.from_numpy(3 * np.array(cholesky_factors))

        mix = pdist.Categorical(torch.ones(n_noise_comps))
        comp = pdist.Independent(
            pdist.MultivariateStudentT(df=2, loc=loc, scale_tril=scale_tril),
            0,
        )
        gmm = pdist.MixtureSameFamily(mix, comp)

        # Store next to the task, which is where the simulator loads these
        # files from, rather than relative to the current working directory.
        files_dir = self.path / "files"
        files_dir.mkdir(parents=True, exist_ok=True)
        torch.save(gmm, files_dir / "gmm.torch")

        permutation_idx = torch.from_numpy(rng.permutation(noise_dim + 8))
        torch.save(permutation_idx, files_dir / "permutation_idx.torch")

        torch.manual_seed(42)

        for num_observation in range(1, self.num_observations + 1):
            observation = self.get_observation(num_observation)
            noise = gmm.sample().reshape((1, -1)).type(observation.dtype)

            observation_and_noise = torch.cat([observation, noise], dim=1)

            # NOTE(review): this writes below `self.path / self.name`, whereas
            # `get_observation` reads distractor observations from
            # `self.path / "files"` -- confirm the intended folder layout.
            path = (
                self.path
                / self.name
                / f"num_observation_{num_observation}"
                / "observation_distractors.csv"
            )
            # presumably `save_data` checks the data against `dim_data` --
            # TODO confirm
            self.dim_data = noise_dim + 8
            self.save_data(path, observation_and_noise[:, permutation_idx])


def _sir(t, u, params, N):
    # u: [batch, 3] of current state
    # params: [batch, 2]
    # Returns SIR vector field: [batch, 2]

    beta, gamma = torch.tensor_split(params, 2, dim=1)
    beta, gamma = torch.squeeze(beta), torch.squeeze(gamma)

    S, I, R = u[:, 0], u[:, 1], u[:, 2]
    out = torch.empty_like(u)
    out[:, 0] = -beta * S * I / N
    out[:, 1] = beta * S * I / N - gamma * I
    out[:, 2] = gamma * I
    return out


class SIR(Task):
    def __init__(
        self,
        p_dist: pdist.Distribution = None,
        N: float = 1000000.0,
        I0: float = 1.0,
        R0: float = 0.0,
        days: float = 160.0,
        saveat: float = 1.0,
        total_count: int = 1000,
        summary: Optional[str] = "subsample",
    ):
        """SIR epidemic model

        Inference is performed for two parameters:
        - Contact rate, beta
        - Mean recovery rate, gamma, (in 1/days)

        Args:
            p_dist: Prior over (beta, gamma). If None, a LogNormal prior is used
            N: Total population
            I0: Initial number of infected individuals
            R0: Initial number of recovered individuals
            days: Number of days
            saveat: When to save during solving
            total_count: Number of trials of the Binomial observation model
                used by the "subsample" summary
            summary: Summaries to use (None for raw trajectories, or "subsample")

        References:
            [1]: https://jrmihalj.github.io/estimating-transmission-by-fitting-mechanistic-models-in-Stan/
        """
        # Three compartments, each saved at every solver output time
        self.dim_data_raw = int(3 * (days / saveat + 1))

        if summary not in (None, "subsample"):
            raise NotImplementedError
        dim_data = self.dim_data_raw if summary is None else 10

        self.summary = summary
        self.total_count = total_count

        # Observation seeds to use when generating ground truth
        observation_seeds = [
            1000000,  # observation 1
            1000001,  # observation 2
            1000010,  # observation 3*
            1000011,  # observation 4*
            1000004,  # observation 5
            1000005,  # observation 6
            1000006,  # observation 7
            1000013,  # observation 8*
            1000008,  # observation 9
            1000009,  # observation 10
        ]

        super().__init__(
            dim_parameters=2,
            dim_data=dim_data,
            name="sir",
            name_display="SIR",
            num_observations=10,
            num_posterior_samples=10000,
            num_reference_posterior_samples=10000,
            num_simulations=[100, 1000, 10000, 100000, 1000000],
            path=Path(__file__).parent.absolute(),
            observation_seeds=observation_seeds,
        )

        if p_dist is not None:
            self.prior_dist = p_dist
        else:
            self.prior_params = {
                "loc": torch.tensor([math.log(0.4), math.log(0.125)]),
                "scale": torch.tensor([0.5, 0.2]),
            }
            self.prior_dist = pdist.LogNormal(**self.prior_params).to_event(1)
            self.prior_dist.set_default_validate_args(False)

        # Initial state (S, I, R) and integration interval
        self.u0 = torch.tensor([N - I0 - R0, I0, R0])
        self.tspan = torch.tensor([0.0, days])

        self.days = days
        self.saveat = saveat
        self.N = N

    def get_labels_parameters(self) -> List[str]:
        """Get list containing parameter labels"""
        return [r"$\beta$", r"$\gamma$"]

    def get_prior(self) -> Callable:
        """Get function returning `num_samples` draws from the prior"""

        def prior(num_samples=1):
            batched_prior = self.prior_dist.expand_by([num_samples])
            return pyro.sample("parameters", batched_prior)

        return prior

    def get_simulator(
        self,
        max_calls: Optional[int] = None,
    ) -> Simulator:
        """Get function returning samples from simulator given parameters

        Args:
            max_calls: Maximum number of function calls. Additional calls will
                result in SimulationBudgetExceeded exceptions. Defaults to None
                for infinite budget

        Return:
            Simulator callable
        """

        def simulator(parameters):
            batch = parameters.shape[0]

            def vector_field(t, u):
                return _sir(t, u, parameters, self.N)

            initial_state = self.u0.expand(batch, -1)
            time_grid = torch.arange(0.0, self.days + self.saveat, self.saveat)
            us = odeint(func=vector_field, y0=initial_state, t=time_grid)
            # Move the solver's leading (time) axis to the end
            us = us.permute(1, 2, 0)

            # Rows (parameter sets) whose trajectory contains a NaN
            has_nan = torch.isnan(us.reshape(batch, -1)).any(axis=1)
            nan_rows = torch.where(has_nan)[0]
            clean_rows = torch.where(~has_nan)[0]

            if self.summary is None:
                return us
            if self.summary != "subsample":
                raise NotImplementedError

            # Failed simulations stay NaN in the output
            data = float("nan") * torch.ones((batch, self.dim_data))
            if len(nan_rows) == batch:
                return data

            # Only use I population, taking every 17th saved time point
            infected = us[:, 1, ::17].reshape(batch, -1)
            probs = (infected[clean_rows, :] / self.N).clamp(0.0, 1.0)
            observation_model = pdist.Binomial(
                total_count=self.total_count,
                probs=probs,
            ).to_event(1)
            data[clean_rows, :] = pyro.sample("data", observation_model)
            return data

        return Simulator(task=self, simulator=simulator, max_calls=max_calls)

    def unflatten_data(self, data: torch.Tensor) -> torch.Tensor:
        """Unflattens data into multiple observations"""
        if self.summary is not None:
            return data.reshape(-1, self.dim_data)
        return data.reshape(-1, 3, int(self.dim_data / 3))

    def _sample_reference_posterior(
        self,
        num_samples: int,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Sample reference posterior for given observation

        Slice-kernel MCMC samples are used to fit a flow-based proposal, from
        which the returned samples are drawn by rejection sampling.

        Args:
            num_observation: Observation number
            num_samples: Number of samples to generate
            observation: Observed data, if None, will be loaded using `num_observation`

        Returns:
            Samples from reference posterior
        """
        from sbibm.algorithms.pyro.mcmc import run as run_mcmc
        from sbibm.algorithms.pytorch.baseline_rejection import run as run_rejection
        from sbibm.algorithms.pytorch.utils.proposal import get_proposal

        # Start the chain at the true parameters when they can be looked up
        initial_params = (
            self.get_true_parameters(num_observation=num_observation)
            if num_observation is not None
            else None
        )

        mcmc_samples = run_mcmc(
            task=self,
            kernel="Slice",
            jit_compile=False,
            num_warmup=10_000,
            num_chains=1,
            num_observation=num_observation,
            observation=observation,
            num_samples=num_samples,
            initial_params=initial_params,
            automatic_transforms_enabled=True,
        )

        fitted_proposal = get_proposal(
            task=self,
            samples=mcmc_samples,
            prior_weight=0.1,
            bounded=True,
            density_estimator="flow",
            flow_model="nsf",
        )

        return run_rejection(
            task=self,
            num_observation=num_observation,
            observation=observation,
            num_samples=num_samples,
            batch_size=10_000,
            num_batches_without_new_max=1_000,
            multiplier_M=1.2,
            proposal_dist=fitted_proposal,
        )


def _lotka_volterra(t, u, params):
    # u: [batch, 2] of current state
    # params: [batch, 4]
    # Returns Lotka-Volterra vector field: [batch, 2]

    alpha, beta, gamma, delta = torch.tensor_split(params, 4, dim=1)
    alpha, beta, gamma, delta = (
        torch.squeeze(alpha),
        torch.squeeze(beta),
        torch.squeeze(gamma),
        torch.squeeze(delta),
    )

    X = u[:, 0]
    Y = u[:, 1]

    out = torch.empty_like(u)
    out[:, 0] = alpha * X - beta * X * Y
    out[:, 1] = -gamma * Y + delta * X * Y
    return out


class LotkaVolterra(Task):
    def __init__(
        self,
        p_dist: pdist.Distribution = None,
        days: float = 20.0,
        saveat: float = 0.1,
        summary: Optional[str] = "subsample",
    ):
        """Lotka-Volterra model

        Args:
            p_dist: Prior over (alpha, beta, gamma, delta). If None, a LogNormal
                prior is used
            days: Number of days
            saveat: When to save during solving
            summary: Summaries to use (None for raw trajectories, or "subsample")

        References:
            [1]: https://mc-stan.org/users/documentation/case-studies/lotka-volterra-predator-prey.html
        """
        # Two populations, each saved at every solver output time
        self.dim_data_raw = int(2 * (days / saveat + 1))

        if summary not in (None, "subsample"):
            raise NotImplementedError
        dim_data = self.dim_data_raw if summary is None else 20
        self.summary = summary

        # Observation seeds to use when generating ground truth
        observation_seeds = [
            1000020,  # observation 1
            1000030,  # observation 2
            1000034,  # observation 3
            1000013,  # observation 4
            1000004,  # observation 5
            1000011,  # observation 6
            1000012,  # observation 7
            1000039,  # observation 8
            1000041,  # observation 9
            1000009,  # observation 10
        ]

        super().__init__(
            dim_parameters=4,
            dim_data=dim_data,
            name="lotka_volterra",
            name_display="Lotka-Volterra",
            num_observations=len(observation_seeds),
            num_posterior_samples=10000,
            num_reference_posterior_samples=10000,
            num_simulations=[100, 1000, 10000, 100000, 1000000],
            path=Path(__file__).parent.absolute(),
            observation_seeds=observation_seeds,
        )

        # Prior
        mu_p1 = -0.125
        mu_p2 = -3.0
        sigma_p = 0.5

        if p_dist is not None:
            self.prior_dist = p_dist
        else:
            self.prior_params = {
                "loc": torch.tensor([mu_p1, mu_p2, mu_p1, mu_p2]),
                "scale": torch.tensor([sigma_p, sigma_p, sigma_p, sigma_p]),
            }
            self.prior_dist = pdist.LogNormal(**self.prior_params).to_event(1)
            self.prior_dist.set_default_validate_args(False)

        # Initial state (X, Y) and integration interval
        self.u0 = torch.tensor([30.0, 1.0])
        self.tspan = torch.tensor([0.0, days])
        self.days = days
        self.saveat = saveat

        # NOTE: For subsample statistic
        self.total_count = 1000  # TODO: Value?

    def get_labels_parameters(self) -> List[str]:
        """Get list containing parameter labels"""
        return [r"$\alpha$", r"$\beta$", r"$\gamma$", r"$\delta$"]

    def get_prior(self) -> Callable:
        """Get function returning `num_samples` draws from the prior"""

        def prior(num_samples=1):
            batched_prior = self.prior_dist.expand_by([num_samples])
            return pyro.sample("parameters", batched_prior)

        return prior

    def get_simulator(
        self,
        max_calls: Optional[int] = None,
    ) -> Simulator:
        """Get function returning samples from simulator given parameters

        Args:
            max_calls: Maximum number of function calls. Additional calls will
                result in SimulationBudgetExceeded exceptions. Defaults to None
                for infinite budget

        Return:
            Simulator callable
        """

        def simulator(parameters):
            batch = parameters.shape[0]

            def vector_field(t, p):
                return _lotka_volterra(t, p, parameters)

            initial_state = self.u0.expand(batch, -1)
            time_grid = torch.arange(
                start=0.0, end=self.days + self.saveat, step=self.saveat
            )
            us = odeint(func=vector_field, y0=initial_state, t=time_grid)
            # Move the solver's leading (time) axis to the end
            us = us.permute(1, 2, 0)

            # Rows (parameter sets) whose trajectory contains a NaN
            has_nan = torch.isnan(us.reshape(batch, -1)).any(axis=1)
            nan_rows = torch.where(has_nan)[0]
            clean_rows = torch.where(~has_nan)[0]

            if self.summary is None:
                return us
            if self.summary != "subsample":
                raise NotImplementedError

            # Failed simulations stay NaN in the output
            data = float("nan") * torch.ones((batch, self.dim_data))
            if len(nan_rows) == batch:
                return data

            # Keep every 21st saved time point of both populations
            us = us[:, :, ::21].reshape(batch, -1)
            log_level = torch.log(us[clean_rows, :].clamp(1e-10, 10000.0))
            observation_model = pdist.LogNormal(loc=log_level, scale=0.1).to_event(1)
            data[clean_rows, :] = pyro.sample("data", observation_model)
            return data

        return Simulator(task=self, simulator=simulator, max_calls=max_calls)

    def unflatten_data(self, data: torch.Tensor) -> torch.Tensor:
        """Unflattens data into multiple observations"""
        if self.summary is not None:
            return data.reshape(-1, self.dim_data)
        return data.reshape(-1, 2, int(self.dim_data / 2))

    def _sample_reference_posterior(
        self,
        num_samples: int,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Sample reference posterior for given observation

        Slice-kernel MCMC samples are used to fit a flow-based proposal, from
        which the returned samples are drawn by rejection sampling.

        Args:
            num_observation: Observation number
            num_samples: Number of samples to generate
            observation: Observed data, if None, will be loaded using `num_observation`

        Returns:
            Samples from reference posterior
        """
        from sbibm.algorithms.pyro.mcmc import run as run_mcmc
        from sbibm.algorithms.pytorch.baseline_rejection import run as run_rejection
        from sbibm.algorithms.pytorch.utils.proposal import get_proposal

        # Start the chain at the true parameters when they can be looked up
        initial_params = (
            self.get_true_parameters(num_observation=num_observation)
            if num_observation is not None
            else None
        )

        mcmc_samples = run_mcmc(
            task=self,
            kernel="Slice",
            jit_compile=False,
            num_warmup=10_000,
            num_chains=1,
            num_observation=num_observation,
            observation=observation,
            num_samples=num_samples,
            initial_params=initial_params,
            automatic_transforms_enabled=True,
        )

        fitted_proposal = get_proposal(
            task=self,
            samples=mcmc_samples,
            prior_weight=0.1,
            bounded=True,
            density_estimator="flow",
            flow_model="nsf",
        )

        return run_rejection(
            task=self,
            num_observation=num_observation,
            observation=observation,
            num_samples=num_samples,
            batch_size=10_000,
            num_batches_without_new_max=1_000,
            multiplier_M=1.2,
            proposal_dist=fitted_proposal,
        )


class GaussianLinear(Task):
    """Linear-Gaussian task with an analytically tractable reference posterior.

    The simulator draws data as ``x ~ N(theta, simulator_scale * I)``. Three
    kinds of prior are supported, and the posterior for a single observation
    is available in closed form for each of them:

    * ``MultivariateNormal`` prior: Gaussian posterior.
    * ``MixtureSameFamily`` prior whose components expose ``loc`` and
      ``covariance_matrix``: Gaussian-mixture posterior.
    * ``Uniform`` prior: the likelihood restricted to the prior's
      ``[low, high]`` support. `_get_reference_posterior` returns the
      untruncated Gaussian; `_sample_reference_posterior` enforces the
      truncation by redrawing samples that fall outside the support.
    """

    # Cap on the number of redraw rounds used to truncate posterior samples to
    # the support of a uniform prior. Prevents an endless loop when the
    # observation is so far outside the support that no draw is accepted.
    _MAX_REJECTION_ROUNDS = 10_000

    def __init__(
        self,
        p_dist: Optional[pdist.Distribution] = None,
        dim: int = 10,
        prior_scale: float = 0.1,
        simulator_scale: float = 0.1,
    ):
        """Set up the task, its simulator noise and its prior.

        Args:
            p_dist: Prior distribution. Must be a `MultivariateNormal`,
                `MixtureSameFamily` or `Uniform`. If None, a zero-mean
                `MultivariateNormal` with covariance `prior_scale * I` is used.
            dim: Dimensionality of both parameters and data.
            prior_scale: Multiplies the identity to form the covariance of the
                default prior (it acts as a variance). Ignored when `p_dist`
                is given.
            simulator_scale: Multiplies the identity to form the covariance of
                the simulator noise (it acts as a variance).

        Raises:
            ValueError: If `p_dist` is not one of the supported prior types.
        """
        super().__init__(
            dim_parameters=dim,
            dim_data=dim,
            name=Path(__file__).parent.name,
            name_display="Gaussian Linear",
            num_observations=10,
            num_posterior_samples=10000,
            num_reference_posterior_samples=10000,
            num_simulations=[100, 1000, 10000, 100000, 1000000],
            path=Path(__file__).parent.absolute(),
        )

        self.dim = dim
        # Simulator noise covariance and precision. The covariance is kept so
        # that it never has to be recovered by inverting the precision again.
        self.sim_cov = simulator_scale * torch.eye(dim)
        self.sim_precision = torch.inverse(self.sim_cov)

        # set or build prior
        if p_dist is None:
            self.prior_dist = MultivariateNormal(
                loc=torch.zeros(dim),
                precision_matrix=torch.inverse(prior_scale * torch.eye(dim)),
            )
        else:
            self.prior_dist = p_dist

        # Detect the prior type and cache the quantities needed by the
        # closed-form posterior.
        if isinstance(self.prior_dist, MultivariateNormal):
            self.prior_type = "gaussian"
            self.prior_precision = self.prior_dist.precision_matrix  # [dim,dim]
            self.prior_mean = self.prior_dist.loc  # [dim]
        elif isinstance(self.prior_dist, MixtureSameFamily):
            self.prior_type = "mixture"
            mix = self.prior_dist.mixture_distribution  # Categorical
            # NOTE(review): components are assumed to be multivariate normals
            # with batch_shape=[K]; only `loc` and `covariance_matrix` are read.
            comp = self.prior_dist.component_distribution
            self.prior_weights = mix.probs  # [K]
            self.comp_means = comp.loc  # [K,dim]
            self.comp_covs = comp.covariance_matrix  # [K,dim,dim]
        elif isinstance(self.prior_dist, Uniform):
            self.prior_type = "uniform"
        else:
            raise ValueError(f"Unsupported prior type: {type(self.prior_dist)}")

    def get_prior(self) -> Callable:
        """Get function returning parameters sampled from the prior.

        Returns:
            Callable taking `num_samples` (default 1) and returning the value
            of the pyro sample site "parameters".
        """

        def prior(num_samples=1):
            return pyro.sample("parameters", self.prior_dist.expand_by([num_samples]))

        return prior

    def get_simulator(self, max_calls: Optional[int] = None):
        """Get the simulator, which adds Gaussian noise to the parameters.

        Args:
            max_calls: Forwarded to `Simulator` as `max_calls`.

        Returns:
            `Simulator` wrapping a function that samples the pyro site "data"
            from a `MultivariateNormal` centred on the given parameters with
            precision `self.sim_precision`.
        """

        def simulator(parameters):
            return pyro.sample(
                "data",
                MultivariateNormal(
                    loc=parameters,
                    precision_matrix=self.sim_precision,
                ),
            )

        return Simulator(task=self, simulator=simulator, max_calls=max_calls)

    def _get_reference_posterior(
        self,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
    ) -> torch.distributions.Distribution:
        """Build the closed-form posterior for a single observation.

        Args:
            num_observation: Observation number; the observation is loaded with
                `self.get_observation`. Mutually exclusive with `observation`.
            observation: Observed data; flattened to a vector. Mutually
                exclusive with `num_observation`.

        Returns:
            A `MultivariateNormal` for Gaussian and uniform priors, or a
            `MixtureSameFamily` for mixture priors. For a uniform prior the
            returned Gaussian is NOT truncated to the prior's support; see
            `_sample_reference_posterior` for samples that respect it.

        Raises:
            AssertionError: If not exactly one of `num_observation` and
                `observation` is given.
            RuntimeError: If `self.prior_type` is not a known prior type.
        """
        # must pass exactly one of num_observation or observation
        assert (num_observation is None) ^ (
            observation is None
        ), "Pass exactly one of `num_observation` or `observation`"
        if num_observation is not None:
            observation = self.get_observation(num_observation)

        # single-datum case
        x = observation.reshape(-1)  # [dim]
        n = 1

        if self.prior_type == "gaussian":
            # standard Gaussian-Gaussian update in precision form
            precision_post = self.prior_precision + n * self.sim_precision
            cov_post = torch.inverse(precision_post)
            mean_post = cov_post @ (
                self.prior_precision @ self.prior_mean + n * (self.sim_precision @ x)
            )
            return MultivariateNormal(loc=mean_post, covariance_matrix=cov_post)

        if self.prior_type == "mixture":
            # Gaussian-mixture prior -> Gaussian-mixture posterior. Every
            # component gets its own Gaussian-Gaussian update; all K updates
            # are done at once with batched matrix operations.
            prior_precs = torch.inverse(self.comp_covs)  # [K,dim,dim]
            post_covs = torch.inverse(prior_precs + n * self.sim_precision)

            # The data term is the same for every component.
            data_term = n * (self.sim_precision @ x)  # [dim]
            prior_term = (prior_precs @ self.comp_means.unsqueeze(-1)).squeeze(-1)
            post_means = (
                post_covs @ (prior_term + data_term).unsqueeze(-1)
            ).squeeze(-1)  # [K,dim]

            # Update mixture weights with the marginal likelihood of x under
            # each component: p(x | k) = N(x; mean_k, cov_k + noise_cov).
            pred_dist = MultivariateNormal(
                loc=self.comp_means, covariance_matrix=self.comp_covs + self.sim_cov
            )
            log_px = pred_dist.log_prob(x.expand_as(self.comp_means))  # [K]
            log_w = torch.log(self.prior_weights) + log_px
            post_weights = torch.softmax(log_w, dim=0)  # [K]

            return MixtureSameFamily(
                mixture_distribution=Categorical(probs=post_weights),
                component_distribution=MultivariateNormal(
                    loc=post_means, covariance_matrix=post_covs
                ),
            )

        if self.prior_type == "uniform":
            # Inside the prior's support the prior is flat, so the posterior is
            # proportional to the likelihood there: a Gaussian centred on x
            # with the simulator covariance. Outside the support the true
            # posterior is zero, which this untruncated Gaussian does not
            # encode; sampling handles the truncation separately.
            return MultivariateNormal(loc=x, covariance_matrix=self.sim_cov / n)

        raise RuntimeError("Unknown prior type in posterior computation")

    def _outside_prior_support(self, samples: torch.Tensor) -> torch.Tensor:
        """Flag rows of `samples` with any coordinate outside the uniform prior."""
        inside = (samples >= self.prior_dist.low) & (samples <= self.prior_dist.high)
        return ~inside.all(dim=-1)

    def _sample_reference_posterior(
        self,
        num_samples: int,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Sample the closed-form reference posterior.

        For a uniform prior the posterior is the Gaussian returned by
        `_get_reference_posterior` truncated to the prior's `[low, high]`
        support. Samples outside the support are redrawn until every row lies
        inside, which yields exact draws from the truncated posterior.

        Args:
            num_samples: Number of samples to generate.
            num_observation: Observation number. Mutually exclusive with
                `observation`.
            observation: Observed data. Mutually exclusive with
                `num_observation`.

        Returns:
            Tensor of posterior samples with `num_samples` rows.

        Raises:
            RuntimeError: If, for a uniform prior, some samples are still
                outside the support after `_MAX_REJECTION_ROUNDS` redraws.
        """
        post = self._get_reference_posterior(
            num_observation=num_observation,
            observation=observation,
        )
        samples = post.sample((num_samples,))
        if self.prior_type != "uniform":
            return samples

        outside = self._outside_prior_support(samples)
        rounds = 0
        while outside.any():
            if rounds >= self._MAX_REJECTION_ROUNDS:
                raise RuntimeError(
                    "Could not draw posterior samples inside the support of the "
                    f"uniform prior within {self._MAX_REJECTION_ROUNDS} rounds"
                )
            # Only the rejected rows are redrawn; accepted rows are kept.
            samples[outside] = post.sample((int(outside.sum()),))
            outside = self._outside_prior_support(samples)
            rounds += 1
        return samples


if __name__ == "__main__":
    # Smoke run of the default Gaussian Linear task: draw 1000 parameter
    # vectors from the prior distribution and push them through the simulator.
    task = GaussianLinear()
    prior_dist, simulator = task.get_prior_dist(), task.get_simulator()

    theta = prior_dist.sample((1000,))
    x = simulator(theta)
