from typing import Optional

import jax
from flax.struct import dataclass
from chex import Array
import jax.numpy as jnp


def pairwise_distances(X):
    """Return squared Euclidean distances between every pair of rows of ``X``.

    For a 2-D ``X`` of shape (n, d) the result has shape (n, n), with entry
    (i, j) equal to ``sum_k (X[i, k] - X[j, k]) ** 2``; the diagonal is zero.
    """
    # Broadcast (n, 1, d) against (1, n, d) to get every row difference at once.
    delta = X[:, jnp.newaxis, :] - X[jnp.newaxis, :, :]
    return jnp.square(delta).sum(axis=-1)


@dataclass
class Kernel:
    """Base class for kernels evaluated as ``kernel(x1, x2, bandwidth)``.

    Subclasses override ``__call__`` to evaluate the kernel on a single pair
    of points. ``median_heuristic`` offers a data-driven bandwidth estimate.
    """

    def __init__(self):
        pass

    def __call__(self, x1: Array, x2: Array, bandwidth: float) -> Array:
        """Evaluate the kernel between ``x1`` and ``x2``.

        Raises:
            NotImplementedError: always; concrete kernels must override this.
        """
        # Fail loudly instead of returning None, which would otherwise only
        # surface later as an obscure error inside jnp arithmetic or vmap.
        raise NotImplementedError(
            f"{type(self).__name__} must implement __call__"
        )

    @staticmethod
    def median_heuristic(X):
        """Estimate a bandwidth from the median pairwise squared distance.

        Computes ``median(D) / log(n + 1)``, where ``D`` is the (n, n) matrix
        returned by ``pairwise_distances(X)`` and ``n = X.shape[0]``. The n
        zero self-distances on the diagonal of ``D`` are included in the
        median. The result is floored at 1e-8 so it never reaches zero
        (e.g. when all points coincide).

        The value is on the squared-distance scale, i.e. suitable for kernels
        that divide a squared distance by ``bandwidth``.
        """
        dists = pairwise_distances(X)
        median_sq_dist = jnp.median(dists)
        median_ls = median_sq_dist / jnp.log(X.shape[0] + 1)
        return jnp.maximum(median_ls, 1e-8)


class RBF(Kernel):
    """Gaussian / RBF kernel ``exp(-0.5 * ||x1 - x2||^2 / bandwidth)``.

    ``bandwidth`` divides the squared distance directly, so it acts as a
    squared length-scale rather than a length-scale.
    """

    # Default bandwidth, used by __call__ when no explicit bandwidth is given.
    # NOTE(review): RBF is not itself decorated with flax.struct.dataclass, so
    # this is a plain attribute rather than a dataclass/pytree field. Equality
    # and hashing are presumably inherited from the field-less Kernel dataclass
    # and therefore ignore ``bandwidth``. Confirm before using RBF instances as
    # jit static arguments or dict keys.
    bandwidth: float

    def __init__(self, bandwidth: float = 1.):
        super().__init__()
        self.bandwidth = bandwidth

    def __call__(self, x1: Array, x2: Array, bandwidth: Optional[float] = None) -> Array:
        """Evaluate the kernel between ``x1`` and ``x2``.

        Args:
            x1: First point.
            x2: Second point; the squared difference with ``x1`` is summed
                over all elements.
            bandwidth: Divisor of the squared distance. If None, falls back to
                ``self.bandwidth`` (set in the constructor).

        Returns:
            A scalar array ``exp(-0.5 * sum((x1 - x2) ** 2) / bandwidth)``.
        """
        if bandwidth is None:
            bandwidth = self.bandwidth
        sq_dist = jnp.sum((x1 - x2) ** 2)
        return jnp.exp(-0.5 * sq_dist / bandwidth)


def KDE(kernel: Kernel, modes: Array, weights: Optional[Array] = None, bandwidth: float = 1.):
    """Build a weighted kernel density estimate centred on ``modes``.

    Args:
        kernel: Callable as ``kernel(x, mode, bandwidth)``. It is vmapped over
            the leading axis of ``modes``.
        modes: Kernel centres, indexed along axis 0.
        weights: Per-mode weights. If None, every mode gets weight
            ``1 / modes.shape[0]``. Supplied weights are used as-is (not
            renormalised).
        bandwidth: Passed unchanged to every kernel evaluation.

    Returns:
        A function mapping a point ``xi`` to
        ``sum_j weights[j] * kernel(xi, modes[j], bandwidth)``.
    """
    if weights is None:
        count = modes.shape[0]
        weights = jnp.ones(count) / count

    def density(xi):
        # Evaluate the kernel against every mode at once: xi and bandwidth are
        # broadcast (in_axes=None), while modes is mapped over axis 0.
        per_mode = jax.vmap(kernel, in_axes=(None, 0, None))(xi, modes, bandwidth)
        return jnp.sum(per_mode * weights)

    return density
