import logging
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from sbi import inference as inference
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
from sbi.inference.potentials import likelihood_estimator_based_potential
from torch import Tensor

from sbibm.algorithms.sbi.utils import (
    wrap_posterior,
    wrap_prior_dist,
    wrap_simulator_fn,
)
from sbibm.tasks.task import Task


class SynthLikNet(nn.Module):
    """Synthetic-likelihood "estimator" that simulates on the fly.

    For every parameter vector ``theta_i`` in a batch, the simulator is run
    ``num_simulations_per_step`` times at ``theta_i``. A multivariate normal
    (sample mean and unbiased sample covariance over the last dimension of the
    simulations) is fitted to the results, and the matching observation ``x_i``
    is scored under that Gaussian.

    Class to be used with SBI's SNLE machinery: it exposes the
    ``log_prob(inputs, context)`` interface in place of a trained density
    estimator. It is a quick implementation of the classical synthetic
    likelihood algorithm published by Wood (2010), reported as SL on the
    website and in the paper's appendix.

    Wood, S. N. (2010). "Statistical inference for noisy nonlinear ecological
    dynamic systems". Nature 466 (August), 1102–1104.

    Args:
        simulator: Callable mapping a batch of parameters of shape
            ``(n, dim_theta)`` to simulations of shape ``(n, dim_x)``.
        num_simulations_per_step: Number of simulations used to fit the
            Gaussian for each parameter vector. Must be at least 2, because the
            unbiased covariance divides by ``n - 1``.
        diag_eps: Value added to the diagonal of the estimated covariance,
            e.g. to keep it positive definite.

    Raises:
        ValueError: If ``num_simulations_per_step < 2``.
    """

    def __init__(self, simulator, num_simulations_per_step=100, diag_eps=0.0):
        super().__init__()
        if num_simulations_per_step < 2:
            # With a single simulation the unbiased covariance is 0 / 0 = NaN.
            raise ValueError(
                "num_simulations_per_step must be >= 2 to estimate a covariance, "
                f"got {num_simulations_per_step}."
            )
        self.simulator = simulator
        self.num_simulations_per_step = num_simulations_per_step
        self.diag_eps = diag_eps

        # Dummy parameter so that sbi, which deduces the device from
        # SynthLikNet.parameters(), finds at least one parameter.
        self.estimated_mn = torch.nn.parameter.UninitializedParameter(
            requires_grad=False
        )

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def log_prob(self, inputs: Tensor, context: Optional[Tensor] = None) -> Tensor:
        """Return the synthetic log-likelihood of each observation.

        Args:
            inputs: Observations, a 2D tensor with one row per parameter
                vector in ``context``.
            context: Parameters ``theta``, a 2D tensor. Row ``i`` is simulated
                and scored against ``inputs[i]``. Required; the ``None``
                default only keeps the keyword optional in the signature.

        Returns:
            Tensor of shape ``(context.shape[0],)`` holding one log-probability
            per parameter vector.

        Raises:
            ValueError: If ``context`` is not provided.
        """
        if context is None:
            raise ValueError(
                "SynthLikNet.log_prob requires `context` (the parameters theta)."
            )
        thetas = context
        observation = inputs

        # torch.cat([]) would raise on an empty batch.
        if thetas.shape[0] == 0:
            return observation.new_empty((0,))

        log_probs = []
        for i, theta in enumerate(thetas):
            # Run the simulator num_simulations_per_step times at this theta.
            xs = self.simulator(
                theta.reshape(1, -1).repeat(self.num_simulations_per_step, 1)
            )

            # Estimate the mean.
            m = torch.mean(xs, dim=0)

            # Estimate the covariance using the unbiased sample variance. The
            # diagonal jitter follows the simulations' dtype and device, so
            # GPU simulators work.
            xm = xs - m
            S = torch.matmul(xm.T, xm) / (xs.shape[0] - 1)
            S = S + self.diag_eps * torch.eye(
                xs.shape[1], dtype=xs.dtype, device=xs.device
            )

            # Score the matching observation.
            dist = torch.distributions.MultivariateNormal(
                loc=m,
                covariance_matrix=S,
                validate_args=False,  # to discard expensive check for psd'ness
            )
            log_probs.append(dist.log_prob(observation[i, :].reshape(1, -1)))

        return torch.cat(log_probs)

    def sample(self, *args, **kwargs):
        raise NotImplementedError

    def train(self, *args, **kwargs):
        # Nothing to train. Return self, as nn.Module.train does, so chained
        # calls such as `net = net.train()` keep working.
        return self

    def eval(self, *args, **kwargs):
        # Nothing to switch. Return self, as nn.Module.eval does, so chained
        # calls such as `net = net.eval()` keep working.
        return self


def run(
    task: Task,
    num_samples: int,
    num_simulations: int,
    num_simulations_per_step: int = 100,
    num_observation: Optional[int] = None,
    observation: Optional[torch.Tensor] = None,
    automatic_transforms_enabled: bool = False,
    mcmc_method: str = "slice_np",
    mcmc_parameters: Optional[Dict[str, Any]] = None,
    diag_eps: float = 0.0,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
    """Runs synthetic likelihood (SL, Wood 2010) with MCMC from `sbi`.

    The likelihood is never trained: `SynthLikNet` simulates and fits a
    Gaussian afresh at every parameter that MCMC evaluates.

    Args:
        task: Task instance
        num_samples: Number of samples to generate from posterior
        num_simulations: Simulation budget. Not used by this algorithm: SL
            simulates inside every likelihood evaluation, so the number of
            simulator calls actually made is returned instead.
        num_simulations_per_step: Number of simulations per likelihood
            evaluation (i.e. per parameter visited by MCMC)
        num_observation: Observation number to load, alternative to `observation`
        observation: Observation, alternative to `num_observation`
        automatic_transforms_enabled: Whether to enable automatic transforms
        mcmc_method: MCMC method
        mcmc_parameters: Extra keyword arguments forwarded to `MCMCPosterior`
            (defaults to none)
        diag_eps: Epsilon added to the diagonal of the estimated covariance

    Returns:
        Samples from posterior, number of simulator calls, and `None` (the log
        probability of the true parameters is not computed)
    """
    assert not (
        num_observation is None and observation is None
    ), "Provide either `num_observation` or `observation`."
    assert not (
        num_observation is not None and observation is not None
    ), "Provide only one of `num_observation` and `observation`."

    # None sentinel instead of a shared mutable `{}` default.
    if mcmc_parameters is None:
        mcmc_parameters = {}

    log = logging.getLogger(__name__)

    log.info("Running SL")

    prior = task.get_prior_dist()
    if observation is None:
        observation = task.get_observation(num_observation)

    simulator = task.get_simulator()

    # Transforms are fetched unconditionally because `wrap_posterior` below
    # uses them either way.
    transforms = task._get_transforms(automatic_transforms_enabled)["parameters"]
    if automatic_transforms_enabled:
        prior = wrap_prior_dist(prior, transforms)
        simulator = wrap_simulator_fn(simulator, transforms)

    likelihood_estimator = SynthLikNet(
        simulator=simulator,
        num_simulations_per_step=num_simulations_per_step,
        diag_eps=diag_eps,
    )

    # sbi's own transform is only enabled when the task's automatic
    # transforms are off.
    potential_fn, theta_transform = likelihood_estimator_based_potential(
        likelihood_estimator=likelihood_estimator,
        prior=prior,
        x_o=None,
        enable_transform=not automatic_transforms_enabled,
    )
    posterior = MCMCPosterior(
        potential_fn=potential_fn,
        proposal=prior,
        theta_transform=theta_transform,
        method=mcmc_method,
        x_shape=observation.shape,
        **mcmc_parameters,
    )

    posterior.set_default_x(observation)

    posterior = wrap_posterior(posterior, transforms)

    # The budget `num_simulations` is deliberately not asserted: the number of
    # simulator calls depends on how many parameters MCMC evaluates.
    samples = posterior.sample((num_samples,)).detach()

    return samples, simulator.num_simulations, None
