"""1D Gaussian mixture target from the original SVGD paper.

Target density: 1/3 N(-2, 1) + 2/3 N(2, 1) (Liu & Wang, 2016).
"""

from typing import Callable

import numpy as np
from scipy.optimize import brentq
from scipy.stats import norm


def _score_fn(x):
    g1 = np.exp(-np.square(x + 2) / 2)
    g2 = np.exp(-np.square(x - 2) / 2)
    fallback = np.copy(-x)
    return np.divide(
        -1 * ((x + 2) * g1 + 2 * (x - 2) * g2),
        g1 + 2 * g2,
        where=(g1 + 2 * g2) != 0,
        out=fallback,
    )


def _pi(x):
    g1 = np.exp(-np.square(x + 2) / 2) / np.sqrt(2 * np.pi)
    g2 = np.exp(-np.square(x - 2) / 2) / np.sqrt(2 * np.pi)
    return g1 / 3 + 2 * g2 / 3


def _cdf(x):
    return norm.cdf(x, loc=-2) / 3 + 2 * norm.cdf(x, loc=2) / 3


class GMixture1D:
    """Quantile function of the 1D mixture 1/3 N(-2, 1) + 2/3 N(2, 1)."""

    def ppf(self, qs):
        """Return the quantiles (inverse CDF) of the mixture at probabilities ``qs``.

        Parameters
        ----------
        qs : array_like
            Probabilities of any shape. A scalar gives a 0-d array.

        Returns
        -------
        numpy.ndarray
            Float array with the same shape as ``qs``. Following the
            ``scipy.stats`` convention, ``0`` maps to ``-inf`` and ``1`` maps
            to ``+inf``. Values outside ``[0, 1]`` and NaN map to NaN.
        """
        probs = np.asarray(qs, dtype=float)
        # Entries not assigned below (out of range or NaN) stay NaN.
        quantiles = np.full(probs.shape, np.nan)
        for idx, q in np.ndenumerate(probs):
            if q == 0.0:
                quantiles[idx] = -np.inf
            elif q == 1.0:
                quantiles[idx] = np.inf
            elif 0.0 < q < 1.0:
                quantiles[idx] = self._quantile(q)
        return quantiles

    @staticmethod
    def _quantile(q):
        """Solve ``F(x) = q`` for a single probability ``q`` in the open interval (0, 1)."""
        # The mixture CDF lies between the CDFs of its two unit-variance
        # components. Its q-quantile therefore lies between the components'
        # q-quantiles, which gives a valid bracket for any q. No fixed
        # interval is needed.
        if q <= 0.5:
            # Lower half: invert the CDF directly, which is accurate for small q.
            def gap(x):
                return norm.cdf(x, loc=-2) / 3 + 2 * norm.cdf(x, loc=2) / 3 - q

            lo, hi = norm.ppf(q, loc=-2), norm.ppf(q, loc=2)
        else:
            # Upper half: invert the survival function. ``1 - q`` is exact for
            # q >= 0.5, so precision is not lost to rounding near 1.
            p = 1.0 - q

            def gap(x):
                return norm.sf(x, loc=-2) / 3 + 2 * norm.sf(x, loc=2) / 3 - p

            lo, hi = norm.isf(p, loc=-2), norm.isf(p, loc=2)
        return brentq(gap, lo, hi)


def generate_example(
    N: int,
    rng: np.random.Generator,
    *,
    loc: float = 0.0,
    scale: float = 1.0,
) -> tuple[
    np.ndarray, Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]
]:
    """Build the 1D Gaussian-mixture example (target 1/3 N(-2, 1) + 2/3 N(2, 1)).

    Parameters
    ----------
    N : int
        Number of particles.
    rng : numpy.random.Generator
        Source of randomness for the initial particles.
    loc, scale : float, keyword-only
        Mean and standard deviation of the normal distribution from which the
        initial particles are drawn. The defaults give N(0, 1), the same draws
        as before for a given ``rng`` state.
        NOTE(review): the SVGD paper's toy experiment reportedly initialises
        from N(-10, 1); pass ``loc=-10.0`` to reproduce that setup (confirm).

    Returns
    -------
    x0 : numpy.ndarray
        Initial particles, shape ``(N, 1)``.
    score_fn : callable
        Gradient of the target log-density.
    pi : callable
        Target density.
    """
    x0 = rng.normal(loc, scale, size=(N, 1))
    return x0, _score_fn, _pi
