from typing import NamedTuple
from jaxtyping import PyTree

# TODO: line-search SGD + preconditioned SGLD + account for flat landscapes in CSGLD.
# Reference: "Learning with Differentiable Perturbed Optimizers" (Berthet et al., 2020).
# In CSGLD, consider scaling the random-walk variance with the flatness of the landscape.


class ParticlesApprox(NamedTuple):
    """
    Particle approximation: particle positions paired with their weights.

    Both fields are pytrees, i.e. the actual (nested) leaf data, not a
    ``PyTreeDef`` structure descriptor.

    Attributes:
        thetas (PyTree): The particle positions.
        weights (PyTree): The particle weights. Presumably one weight per
            particle in ``thetas`` -- TODO confirm against callers.
    """

    thetas: PyTree  # particle positions
    weights: PyTree  # particle weights
