from abc import abstractmethod
from collections.abc import Callable
from chex import Array
import jax.numpy as jnp

from sves.plotting import plot_particles_pdf


class Benchmark:
    def __init__(self, lb: Array | float, ub: Array | float, dim: int, fglob: float, name: str) -> None:
        self.lower_bounds = lb
        self.upper_bounds = ub
        self.dim = dim
        self.fglob = fglob
        self.name = name

    @abstractmethod
    def get_objective_derivative(self) -> [Callable, Callable]:
        pass

    def get_objective(self) -> Callable:
        return self.get_objective_derivative()[0]

    def plot(self, x: Array, lb: Array, ub: Array):
        """Plotting code. Works for synthetic benchmarks."""
        plot_particles_pdf(
            x[0],
            lambda y: jnp.exp(-self.get_objective()(y)),
            None,
            lb,
            ub
        )
