from typing import Optional

import numpy as np
import matplotlib.pyplot as plt


def plot_pdf(
    pdf,
    ax: Optional["plt.Axes"] = None,
    xmin: float = -3.,
    xmax: float = 3.,
    n_points: int = 50,
):
    """Draw a filled contour plot of ``pdf`` over the square [xmin, xmax]^2.

    Args:
        pdf: Callable evaluated once on an ``(n_points**2, 2)`` array of grid
            points (column 0 is the x coordinate, column 1 the y coordinate).
            Its result must support ``.reshape(n_points, n_points)``.
        ax: Axes to draw on. When ``None``, a new 7x7-inch figure is created
            and the contour is drawn with ``plt.contourf`` on it.
        xmin: Lower bound of the grid on both axes.
        xmax: Upper bound of the grid on both axes.
        n_points: Number of grid points along each axis (default 50).

    Returns:
        The contour set returned by ``contourf``.
    """
    x = np.linspace(xmin, xmax, n_points)
    # meshgrid uses 'xy' indexing: after flattening, the points of row i all
    # have y = x[i]. Reshaping the flat results back to (n, n) therefore gives
    # Z[i, j] = pdf(x[j], x[i]), which is the layout contourf(X, Y, Z) expects.
    grid = np.stack(np.meshgrid(x, x), axis=-1).reshape(-1, 2)
    values = pdf(grid).reshape(n_points, n_points)

    if ax is None:
        # Create a figure only when no axes were supplied. Creating one
        # unconditionally leaves an empty stray figure open whenever the
        # caller draws on their own axes.
        plt.figure(figsize=(7, 7))
        return plt.contourf(x, x, values, levels=20)
    return ax.contourf(x, x, values, levels=20)


def plot_particles_pdf(x, objective, score_fn, xmin: float = -3., xmax: float = 3.):
    """Plot ``objective`` as filled contours and overlay particle positions.

    Args:
        x: Particle positions, converted with ``np.asarray``; presumably an
            ``(N, 2)`` array-like (TODO confirm). Its transposed rows are
            unpacked into ``plt.scatter``.
        objective: Callable passed to ``plot_pdf`` as the function to contour.
        score_fn: Currently unused; kept in the signature for existing callers.
        xmin: Lower bound of the plotted region on both axes.
        xmax: Upper bound of the plotted region on both axes.
    """
    # Pass the bounds by keyword. Passing score_fn positionally would put it in
    # plot_pdf's ``ax`` slot, so plot_pdf would call ``score_fn.contourf``
    # and fail.
    plot_pdf(objective, xmin=xmin, xmax=xmax)
    points = np.asarray(x)
    plt.scatter(*points.T, color="salmon")
    plt.xlim(xmin, xmax)
    plt.ylim(xmin, xmax)
