import abc
from copy import deepcopy
from typing import Optional

import jax
import numpy as np
import torch

import jax.numpy as jnp


class DifferentiableTemperedPosterior(abc.ABC):
    def __init__(self, log_likelihood: callable, log_prior: callable, parameter_transform: callable = None):
        """
        Base class for the log tempered posterior ``beta * log_likelihood(x) + log_prior(x)`` and its gradient
        w.r.t. x.

        The ``_*_cache`` attributes set here are shared with subclasses, which may use them to memoize results.
        They are not keyed on x, so call :meth:`clear_cache` whenever x changes.

        :param log_likelihood: callable that returns an array of log likelihoods given transformed parameters.
        :param log_prior: callable that returns an array of log priors given transformed parameters.
        :param parameter_transform: optional callable that takes a parameter array and returns a tuple
            ``(transformed_parameters, log_jacobian)``. ``None`` means the identity transform.
        """
        self._log_likelihood_fn = log_likelihood
        self._log_prior_fn = log_prior
        self._parameter_transform_fn = parameter_transform

        self._log_likelihood_cache = None
        self._log_likelihood_gradient_cache = None
        self._log_prior_cache = None
        self._log_prior_gradient_cache = None

    def parameter_transform(self, x, **kwargs):
        """
        Return the transformed parameters (the log Jacobian is discarded).

        When no transform was configured, x is returned unchanged instead of calling ``None``.
        """
        if self._parameter_transform_fn is None:
            return x
        return self._parameter_transform_fn(x, **kwargs)[0]

    def clear_cache(self):
        """
        Clear the computed log likelihood, log prior, and their gradients by setting them to None.
        """
        self._log_likelihood_cache = None
        self._log_likelihood_gradient_cache = None
        self._log_prior_cache = None
        self._log_prior_gradient_cache = None

    @abc.abstractmethod
    def log_likelihood(self, x, **kwargs):
        raise NotImplementedError

    @abc.abstractmethod
    def log_prior(self, x, **kwargs):
        raise NotImplementedError

    @abc.abstractmethod
    def log_likelihood_gradient(self, x, **kwargs):
        raise NotImplementedError

    @abc.abstractmethod
    def log_prior_gradient(self, x, **kwargs):
        raise NotImplementedError

    def value(self, beta: float, x: Optional[torch.Tensor], **kwargs) -> torch.Tensor:
        """
        Compute the value of the log tempered posterior: ``beta * log_likelihood(x) + log_prior(x)``.

        :param beta: inverse temperature multiplying the log likelihood.
        :param x: parameters at which to evaluate; forwarded to the subclass implementations.
        :param kwargs: forwarded to both ``log_likelihood`` and ``log_prior``.
        """
        return beta * self.log_likelihood(x, **kwargs) + self.log_prior(x, **kwargs)

    def gradient(self, beta: float, x: Optional[torch.Tensor], **kwargs) -> torch.Tensor:
        """
        Compute the gradient of the log tempered posterior w.r.t. x:
        ``beta * d log_likelihood / dx + d log_prior / dx``.

        :param beta: inverse temperature multiplying the log-likelihood gradient.
        :param x: parameters at which to differentiate; forwarded to the subclass implementations.
        :param kwargs: forwarded to both gradient methods.
        """
        return beta * self.log_likelihood_gradient(x, **kwargs) + self.log_prior_gradient(x, **kwargs)

    def compute_all(self, beta: float, x: torch.Tensor, compute_gradients: bool = True, **kwargs):
        """
        Evaluate the tempered posterior value (and, optionally, its gradient) at x.

        The results are not returned. The calls are made so that subclasses that memoize populate their caches.
        """
        self.value(beta, x, **kwargs)
        if compute_gradients:
            self.gradient(beta, x, **kwargs)


class TorchPosterior(DifferentiableTemperedPosterior):
    def __init__(self, log_likelihood: callable, log_prior: callable, parameter_transform: callable = None):
        """
        Differentiable tempered posterior in torch; gradients are obtained with ``torch.autograd``.

        :param log_likelihood: callable that returns a torch Tensor of log likelihoods given (transformed) parameters.
        :param log_prior: callable that returns a torch Tensor of log priors given (transformed) parameters.
        :param parameter_transform: optional callable returning ``(transformed_x, log_jacobian)``. When set, the
            likelihood and prior are evaluated at ``transformed_x`` and ``log_jacobian`` is added to the log prior.
        """
        super().__init__(log_likelihood, log_prior, parameter_transform)

    def _evaluate_log_likelihood(self, x, **kwargs):
        """Log likelihood at x (after the optional transform), without reading or writing the cache."""
        if self._parameter_transform_fn is not None:
            x, _ = self._parameter_transform_fn(x)  # the log Jacobian only enters the prior term
        return self._log_likelihood_fn(x, **kwargs)

    def _evaluate_log_prior(self, x, **kwargs):
        """Log prior at x plus the log Jacobian of the optional transform, without reading or writing the cache."""
        if self._parameter_transform_fn is None:
            return self._log_prior_fn(x, **kwargs)
        x, logj = self._parameter_transform_fn(x)
        # Change of variables: the density over the untransformed parameters picks up the log Jacobian.
        return self._log_prior_fn(x, **kwargs) + logj

    @staticmethod
    def _value_and_gradient(fn, x, **kwargs):
        """
        Evaluate ``fn`` at a detached leaf copy of x and return ``(value, d sum(value) / dx)``, both detached.

        Using a fresh leaf leaves the caller's tensor (and its ``requires_grad`` flag) untouched. It also means the
        gradient never depends on a previously cached value that was built without an autograd graph.
        """
        x_leaf = x.detach().requires_grad_(True)
        with torch.enable_grad():  # also works when the caller is inside torch.no_grad()
            value = fn(x_leaf, **kwargs)
            if not value.requires_grad:
                # fn does not depend on x at all (e.g. a flat prior): the gradient is identically zero.
                return value.detach(), torch.zeros_like(x)
            grad = torch.autograd.grad(value.sum(), x_leaf, allow_unused=True)[0]
        if grad is None:
            grad = torch.zeros_like(x)
        return value.detach(), grad

    def log_likelihood(self, x, **kwargs):
        """Return the (cached) log likelihood at x. The cache is not keyed on x; call clear_cache() when x changes."""
        if self._log_likelihood_cache is None:
            self._log_likelihood_cache = self._evaluate_log_likelihood(x, **kwargs)
        return self._log_likelihood_cache

    def log_prior(self, x, **kwargs):
        """Return the (cached) log prior at x, including the transform's log Jacobian when a transform is set."""
        if self._log_prior_cache is None:
            self._log_prior_cache = self._evaluate_log_prior(x, **kwargs)
        return self._log_prior_cache

    def log_likelihood_gradient(self, x, **kwargs):
        """Return the (cached) gradient of the summed log likelihood w.r.t. x; also fills an empty value cache."""
        if self._log_likelihood_gradient_cache is None:
            value, grad = self._value_and_gradient(self._evaluate_log_likelihood, x, **kwargs)
            if self._log_likelihood_cache is None:
                self._log_likelihood_cache = value
            self._log_likelihood_gradient_cache = grad
        return self._log_likelihood_gradient_cache

    def log_prior_gradient(self, x, **kwargs):
        """Return the (cached) gradient of the summed log prior w.r.t. x; also fills an empty value cache."""
        if self._log_prior_gradient_cache is None:
            value, grad = self._value_and_gradient(self._evaluate_log_prior, x, **kwargs)
            if self._log_prior_cache is None:
                self._log_prior_cache = value
            self._log_prior_gradient_cache = grad
        return self._log_prior_gradient_cache


# class TorchPosterior(DifferentiableTemperedPosterior):
#     def __init__(self, log_likelihood: callable, log_prior: callable):
#         """
#         Differentiable tempered posterior in torch.
#
#         :param log_likelihood: callable that returns a torch Tensor of log likelihoods.
#         :param log_prior: callable that returns a torch Tensor of log priors.
#         """
#         super().__init__(log_likelihood=log_likelihood, log_prior=log_prior)
#
#     def compute_all(self, x: torch.Tensor, compute_gradients: bool = True, **kwargs):
#         """
#         Compute the log likelihood, log prior, gradient of log likelihood w.r.t. x, and gradient of log prior w.r.t. x.
#         The computed values are stored in self._log_likelihood_cache, self._log_prior_cache,
#         self._log_likelihood_gradient_cache, self._log_prior_gradient_cache.
#
#         :param x: DLA particles.
#         :param compute_gradients: if False, compute log likelihood and log prior values, but not their gradients.
#         :param kwargs: keyword arguments for the log likelihood and log prior functions.
#         TODO it would probably be better to split these keyword arguments into two parts.
#         """
#         # Likelihood computations
#         x1 = deepcopy(x.detach())
#         x1.requires_grad_(True)
#
#         log_likelihood = self._log_likelihood_fn(x1, **kwargs)
#         self._log_likelihood_cache = log_likelihood.detach()
#
#         if compute_gradients:
#             log_likelihood_gradient = torch.autograd.grad(log_likelihood.sum(), x1, create_graph=True)[0]
#             self._log_likelihood_gradient_cache = log_likelihood_gradient.detach()
#
#         # Prior computations
#         x2 = deepcopy(x.detach())
#         x2.requires_grad_(True)
#
#         log_prior = self._log_prior_fn(x2, **kwargs)
#         self._log_prior_cache = log_prior.detach()
#
#         if compute_gradients:
#             log_prior_gradient = torch.autograd.grad(log_prior.sum(), x2, create_graph=True)[0]
#             self._log_prior_gradient_cache = log_prior_gradient.detach()


class JaxPosterior(DifferentiableTemperedPosterior):
    def __init__(self, log_likelihood: callable, log_prior: callable, parameter_transform: callable = None):
        """
        Differentiable tempered posterior in Jax.
        Data is exchanged using torch Tensors, but the likelihood and prior computations (and their gradients,
        via ``jax.grad``) are done in Jax.

        :param log_likelihood: callable that returns a jax array of log likelihoods.
        :param log_prior: callable that returns a jax array of log priors.
        :param parameter_transform: optional jax-compatible callable returning ``(transformed_x, log_jacobian)``.
            The log Jacobian is added to the log prior only.
        """

        def log_likelihood_wrapper(x, **kwargs):
            if parameter_transform is not None:
                x = parameter_transform(x)[0]
            return log_likelihood(x, **kwargs)

        def log_prior_wrapper(x, **kwargs):
            if parameter_transform is None:
                return log_prior(x, **kwargs)
            x, logj = parameter_transform(x)
            return log_prior(x, **kwargs) + logj

        # jax.grad needs a scalar output. Differentiating the sum over particles yields per-particle gradients
        # provided each particle's log density depends only on that particle (assumed here).
        def summed_log_likelihood(x, **kwargs):
            return jnp.sum(log_likelihood_wrapper(x, **kwargs))

        def summed_log_prior(x, **kwargs):
            return jnp.sum(log_prior_wrapper(x, **kwargs))

        # The per-element wrappers are stored, so values keep their per-particle shape.
        super().__init__(log_likelihood_wrapper, log_prior_wrapper, parameter_transform)

        self._log_likelihood_grad_fn = jax.grad(summed_log_likelihood)
        self._log_prior_grad_fn = jax.grad(summed_log_prior)

    @staticmethod
    def _to_jax(x):
        """Convert a torch tensor to a jax array (via host numpy, so GPU and grad-tracking tensors work too)."""
        return jnp.asarray(x.detach().cpu().numpy())

    @staticmethod
    def _to_torch(value, like):
        """Convert a jax array to a torch tensor with the dtype and device of ``like``."""
        return torch.as_tensor(np.array(value), dtype=like.dtype, device=like.device)

    def log_likelihood(self, x, **kwargs):
        """Return the (cached) log likelihood at x as a torch tensor; call clear_cache() when x changes."""
        if self._log_likelihood_cache is None:
            value = self._log_likelihood_fn(self._to_jax(x), **kwargs)
            self._log_likelihood_cache = self._to_torch(value, x)
        return self._log_likelihood_cache

    def log_prior(self, x, **kwargs):
        """Return the (cached) log prior (plus log Jacobian, if transformed) at x as a torch tensor."""
        if self._log_prior_cache is None:
            value = self._log_prior_fn(self._to_jax(x), **kwargs)
            self._log_prior_cache = self._to_torch(value, x)
        return self._log_prior_cache

    def log_likelihood_gradient(self, x, **kwargs):
        """Return the (cached) gradient of the summed log likelihood w.r.t. x, computed with jax.grad."""
        if self._log_likelihood_gradient_cache is None:
            grad = self._log_likelihood_grad_fn(self._to_jax(x), **kwargs)
            self._log_likelihood_gradient_cache = self._to_torch(grad, x)
        return self._log_likelihood_gradient_cache

    def log_prior_gradient(self, x, **kwargs):
        """Return the (cached) gradient of the summed log prior w.r.t. x, computed with jax.grad."""
        if self._log_prior_gradient_cache is None:
            grad = self._log_prior_grad_fn(self._to_jax(x), **kwargs)
            self._log_prior_gradient_cache = self._to_torch(grad, x)
        return self._log_prior_gradient_cache
