from functools import partial
from pathlib import Path
from typing import Callable

import jax
import jax.numpy as jnp
import scipy.stats.qmc as qmc
from chex import Array, PRNGKey
from matplotlib import pyplot as plt

from sves.benchmarks import Benchmark
from sves.benchmarks.utils.spline import MyBSpline
from sves.plotting import plot_pdf
from sves.kernels import RBF, KDE


class RamosPaper(Benchmark):
    """Inspired by example from paper:
        'Path Signatures for Diversity in Probabilistic Trajectory Optimisation'
        by Barcelos et al. (2023)
        Code: https://github.com/lubaroli/sigsvgd

    A candidate is a flat vector of ``n_via * dim`` values holding the free
    via-points of a path. Every path is pinned to start at ``qinit`` and end at
    ``target``. The padded control points are turned into a trajectory by
    ``MyBSpline``. The objective adds a KDE density over 15 Halton-placed
    obstacle centres, summed along the trajectory, to a weighted smoothness term.
    """

    def __init__(
        self,
        lb: Array | float = 0.,
        ub: Array | float = 5.,
        target: tuple = (4.5, 4.),
        n_via: int = 2,
        sim_steps: int = 100,
        dim: int = 2,
        fglob: float = 0.,
        name: str = "F_RAMOS"
    ):
        """Set up the path endpoints, the spline and the obstacle density.

        Args:
            lb: Lower bound of the search space. Obstacle centres are kept at
                least 0.5 above it.
            ub: Upper bound of the search space. Obstacle centres are kept at
                least 0.5 below it.
            target: End state appended to every path.
            n_via: Number of free via-points per candidate.
            sim_steps: Passed to ``MyBSpline`` (presumably the number of
                trajectory samples; confirm against ``MyBSpline``).
            dim: Dimension of each via-point. The benchmark dimension is
                ``n_via * dim``.
            fglob: Forwarded to ``Benchmark``.
            name: Forwarded to ``Benchmark``.
        """
        self.n_via = n_via
        self.sim_steps = sim_steps
        self.bot_dim = dim
        # Fixed start state. It is concatenated with the via-points, so it must
        # have ``dim`` entries (hard-coded to 2 here).
        self.qinit = jnp.array([0.25, 0.75])
        self.target = jnp.array(target)
        # Spline over the n_via free points plus the fixed start and end points.
        self.bspl = MyBSpline(n_via + 2, sim_steps)

        super().__init__(lb, ub, n_via * dim, fglob, name)

        # Obstacle centres come from a fixed-seed 2-D Halton sequence scaled into
        # [lb + 0.5, ub - 0.5], so the layout is the same on every construction.
        n_obstacles = 15  # constant from code in git
        samples = qmc.Halton(2, seed=123).random(n_obstacles)
        self.means = qmc.scale(samples, lb + .5, ub - .5)
        self.cov = 0.25
        # KDE prior, vmapped over the leading (batch) axis of its input.
        self.prior = jax.vmap(KDE(RBF(), self.means, bandwidth=self.cov ** 2))

    def get_objective_derivative(self) -> tuple[Callable, Callable]:
        """Return evaluable objectives.

        Returns:
            ``(objective, grad_fn)``:
                - ``objective`` is ``log_likelihood``, batched over candidates.
                - ``grad_fn`` is the gradient of ``-log_likelihood`` for a single
                  flat candidate, vmapped over the leading axis of its input.
        """
        def negative_objective(x: Array) -> Array:
            # A single flat candidate yields a shape-(1,) result. Squeeze it to
            # the scalar that jax.grad requires.
            return -self.log_likelihood(x).squeeze()

        return self.log_likelihood, jax.vmap(jax.grad(negative_objective))

    def _with_endpoints(self, candidates: Array) -> Array:
        """Reshape flat candidates into via-points and add the fixed start/end states.

        Returns an array of shape ``(n_candidates, n_via + 2, bot_dim)`` laid out
        as ``[qinit, via_1, ..., via_n, target]`` for each candidate.
        """
        via = candidates.reshape(-1, self.n_via, self.bot_dim)
        n_candidates = via.shape[0]
        start = jnp.tile(self.qinit, (n_candidates, 1))[:, None]
        goal = jnp.tile(self.target, (n_candidates, 1))[:, None]
        return jnp.concatenate([start, via, goal], axis=1)

    def get_trajectories(self, x: Array) -> tuple[Array, Array, Array, Array]:
        """Evaluate the B-spline defined by control points ``x``.

        Returns:
            ``(q, qDot, qdDot, x)``. The first three are the outputs of
            ``MyBSpline`` (presumably position, velocity and acceleration). The
            last is the control points, passed through unchanged.
        """
        q, qDot, qdDot = self.bspl(x)
        return q, qDot, qdDot, x

    @partial(jax.jit, static_argnums=(0,))
    def log_likelihood(self, candidates: Array) -> Array:
        """Evaluate the benchmark objective for a batch of flat candidates.

        Args:
            candidates: Via-point values, reshaped to ``(-1, n_via, bot_dim)``.

        Returns:
            Per-candidate values of (KDE prior summed along the trajectory)
            + 5 * (smoothness term).
        """
        # Pre- and append init and target states
        x = self._with_endpoints(candidates)
        q, *_ = self.get_trajectories(x)

        # Finite differences between consecutive trajectory samples (axis 1).
        qDot_approx = q[:, 1:] - q[:, :-1]
        probability_term = jax.vmap(self.prior, 1)(q).sum(axis=0)   # Sum the prior along the trajectory
        # NOTE(review): "ijk,ijk->ik" sums the squared increments over the step
        # axis *before* the sqrt. That gives sum_k sqrt(sum_j dq_jk^2), not the
        # Euclidean path length sum_j ||dq_j|| (which would use "->ij").
        # Left unchanged because the ground-truth samples loaded by ``sample``
        # were presumably generated with this form. Confirm intent.
        smoothness_term = jnp.sqrt(jnp.einsum("ijk,ijk->ik", qDot_approx, qDot_approx)).sum(axis=-1)

        return probability_term + smoothness_term * 5.

    def plot(self, candidates: tuple[Array, Array], lb: Array, ub: Array) -> None:
        """Plot the obstacle density, the target and the candidate trajectories.

        Args:
            candidates: A ``(candidates, fitness)`` pair. Only the candidates
                are drawn; fitness is currently ignored.
            lb: Lower axis limit, also passed to ``plot_pdf``.
            ub: Upper axis limit, also passed to ``plot_pdf``.
        """
        # Create figure with constrained layout
        plt.figure(figsize=(6, 6), constrained_layout=True)

        # Add the obstacles
        plot_pdf(lambda x: self.prior(x), None, lb, ub)

        # Scatter the control points (start, via-points, target) and the target.
        via_points, _fitness = candidates
        x = self._with_endpoints(via_points)
        q, *_ = self.get_trajectories(x)
        plt.scatter(*self.target, marker="*", s=80, c="red")
        plt.scatter(*x.T, c="orange", alpha=.2)

        # Plot every path in the same colour (fitness is not used for colouring).
        for path in q:
            plt.plot(*path.T, c="salmon", alpha=0.9, linewidth=1.5)

        # Set the x and y limits
        plt.xlim([lb, ub])
        plt.ylim([lb, ub])
        plt.xticks([])
        plt.yticks([])

    @staticmethod
    def sample(rng: PRNGKey, num_samples: int) -> Array | None:
        """Load samples here.

        Sampling can only be done via MCMC methods or SGLD.
        Check for script in scripts/.

        Args:
            rng: Unused; kept for interface compatibility.
            num_samples: Selects the file ``gt_samples{num_samples}.pkl``.

        Returns:
            The unpickled ground-truth samples. If no file exists for
            ``num_samples``, the error is printed and ``None`` is returned.
        """
        import pickle

        storage_path = Path(__file__).resolve().parent.parent / "data"
        sample_file = storage_path / "ramos_eval" / f"gt_samples{num_samples}.pkl"

        try:
            with sample_file.open("rb") as handler:
                # SECURITY: pickle can execute arbitrary code on load; only use
                # trusted, bundled data files here.
                gt_samples = pickle.load(handler)
        except FileNotFoundError as e:
            print(e)
            return None
        return gt_samples
