from collections.abc import Callable
from functools import partial
from pathlib import Path

import jax.numpy as jnp
import jax
import chex

from sves.kernels import RBF, KDE
from sves.benchmarks import Benchmark


class GMM(Benchmark):
    """Gaussian-mixture benchmark with randomly drawn mode locations and weights."""

    def __init__(
        self,
        rng: chex.PRNGKey,
        lb: float = -6.,
        ub: float = 6.,
        kernel_rad: float = 1.,
        n_modes: int = 4,
        dim: int = 2,
        name: str = "GMM"
    ) -> None:
        """Draw random mixture weights and mode locations.

        Args:
            rng: Key used (after splitting) to draw the weights and the modes.
            lb: Lower domain bound, forwarded to ``Benchmark``.
            ub: Upper domain bound, forwarded to ``Benchmark``.
            kernel_rad: Bandwidth passed to the KDE; also the scale of the
                Gaussian noise used by ``sample``.
            n_modes: Number of mixture components.
            dim: Dimensionality of the domain.
            name: Benchmark name, forwarded to ``Benchmark``.
        """
        # NOTE(review): closed-form value forwarded to Benchmark as ``fglob``;
        # its derivation is not visible in this file — confirm.
        fglob = -1 / (kernel_rad * (2 ** dim + 1))
        super().__init__(lb, ub, dim, fglob, name)
        self.kernel_rad = kernel_rad

        # Instantiate problem
        rng_w, rng_m = jax.random.split(rng)
        # Mixture weights: uniform draws in [0, 10), normalised to sum to one.
        self.weights = jax.random.uniform(rng_w, (n_modes,), minval=0., maxval=10.)
        self.weights /= jnp.sum(self.weights)
        # Mode centres, shape (n_modes, dim), kept 2 units inside the bounds
        # on every coordinate.
        self.modes = jax.random.uniform(rng_m, (n_modes, dim), minval=lb+2, maxval=ub-2)

    def get_objective_derivative(self) -> tuple[Callable, Callable]:
        """Return batched ``(objective, derivative)`` functions.

        The objective is ``-log KDE(x)``, where the KDE is built with an RBF
        kernel over ``self.modes`` weighted by ``self.weights`` with bandwidth
        ``self.kernel_rad``. The derivative is the gradient of ``log KDE(x)``,
        i.e. the *negative* gradient of the objective. Both functions are
        vmapped over the leading axis of their input.
        """
        eval_fn = jax.jit(lambda x: -jnp.log(KDE(RBF(), self.modes, self.weights, self.kernel_rad)(x)))
        return jax.vmap(eval_fn), jax.vmap(jax.grad(lambda x: -eval_fn(x)))

    def sample(self, rng: chex.PRNGKey, num_samples: int) -> chex.Array:
        """Draw ``num_samples`` points from the mixture.

        A component index is chosen with probabilities ``self.weights``. Then
        Gaussian noise with standard deviation ``self.kernel_rad`` is added to
        the chosen mode.

        Args:
            rng: PRNG key; it is split internally so the two draws are independent.
            num_samples: Number of points to draw.

        Returns:
            Array of shape ``(num_samples, self.dim)``.
        """
        # Split the key: reusing ``rng`` for both draws would make the chosen
        # components and the Gaussian noise statistically dependent.
        rng_idx, rng_noise = jax.random.split(rng)
        component_indices = jax.random.choice(
            rng_idx, self.modes.shape[0], p=self.weights, shape=(num_samples,)
        )
        normal_samples = jax.random.normal(
            rng_noise,
            (num_samples, self.dim)
        ) * self.kernel_rad
        samples = normal_samples + self.modes[component_indices]
        return samples


class DoubleBanana(Benchmark):
    """Double-banana target distribution (exploration benchmark)."""

    def __init__(
        self,
        rng: chex.PRNGKey,
        lb: float = -2.,
        ub: float = 2.,
        dim: int = 2,
        name: str = "Double Banana"
    ) -> None:
        """Code: from Matrix valued kernel SVGD paper.

        https://github.com/dilinwang820/matrix_svgd/blob/master/2d_toy/code/environment.py.

        Args:
            rng: Not used by this constructor (presumably kept for signature
                parity with other benchmarks such as ``GMM``).
            lb: Lower domain bound, forwarded to ``Benchmark``.
            ub: Upper domain bound, forwarded to ``Benchmark``.
            dim: Dimensionality, forwarded to ``Benchmark``.
            name: Benchmark name, forwarded to ``Benchmark``.
        """
        # Global optimum value forwarded to Benchmark is 0.
        super().__init__(lb, ub, dim, 0., name)
        self.sig1 = 1.   # divisor (times 2) of the ||x||^2 term in eval_fun
        self.sig2 = .09  # divisor (times 2) of the (y - log f(x))^2 term
        self.y = jnp.log(30)  # target value for log f(x) in eval_fun

    def get_objective_derivative(self) -> tuple[Callable, Callable]:
        """Exploration benchmark.

        Returns batched ``(objective, derivative)`` functions. The objective is
        ``-eval_fun(x)`` and the derivative is ``grad(eval_fun)(x)``, i.e. the
        *negative* gradient of the objective. Both functions are vmapped over
        the leading axis of their input.
        """
        return jax.vmap(lambda x: -self.eval_fun(x)), jax.vmap(jax.grad(lambda x: self.eval_fun(x)))

    @partial(jax.jit, static_argnums=(0,))
    def eval_fun(self, x):
        """Unnormalised log-density of the double banana at a single point ``x``.

        Computes ``-(x.x / (2 sig1) + (y - log f(x))^2 / (2 sig2))``, where ``f``
        is the Rosenbrock function of ``x[0]`` and ``x[1]``.
        ``x`` is assumed to be a 1-D vector with at least 2 entries — TODO confirm.
        """
        # The 1e-10 keeps the log finite at the Rosenbrock zero (1, 1).
        fx = jnp.log((1 - x[0])**2 + 100 * (x[1] - x[0]**2)**2 + 1e-10)
        p = x.T @ x / (2 * self.sig1) + (self.y - fx) ** 2 / (2 * self.sig2)
        return -p

    def sample(self, rng: chex.PRNGKey, num_samples: int) -> chex.Array:
        """Load samples here.

        Sampling can only be done via MCMC methods or SGLD.
        Check for script in scripts/.

        Args:
            rng: Unused; samples are read from disk, not drawn.
            num_samples: Selects the file ``data/banana/gt_samples{num_samples}.pkl``.

        Returns:
            The unpickled ground-truth samples, or ``None`` if the file does
            not exist (the ``FileNotFoundError`` is printed).
        """
        import pickle

        storage_path = Path(__file__).resolve().parent.parent / "data"
        sample_file = storage_path / f"banana/gt_samples{num_samples}.pkl"

        # Only the open can legitimately raise FileNotFoundError; keep the try minimal.
        try:
            handler = open(sample_file, "rb")
        except FileNotFoundError as e:
            print(e)
            return None

        # NOTE(security): pickle.load can execute arbitrary code; only load
        # trusted, locally generated sample files.
        with handler:
            gt_samples = pickle.load(handler)
        return gt_samples
