import numpy as np
from typing import Optional, Callable, Tuple
from numpy import ndarray as Array
from dynamic import integrate
from tqdm import tqdm
from scipy.special import logsumexp
from torch import Tensor
from diffusion.denoiser import Lorenz63Denoiser, ConditionalMMPSDenoiser
from azula.sample import DDIMSampler
import torch


def forecast(
    x_0: Array,
    y: Array,
    dt: float = 1e-3,
    sigma: float = 0.25,
    obs_dt: float = 0.5,
    verbose: bool = True,
) -> Array:
    """
    Propagate the particles through the stochastic dynamics only, never looking at the observations.
    This open-loop ensemble is a reference against which the filters below can be compared.
    Input(s):
        - x_0 (Array): initial particles drawn from p(x^{0}) with dimension (num_particles, 3).
        - y (Array): observations with dimension (num_assim_steps, d); only its length is used,
          to decide how many windows to forecast.
        - dt (float): time step used to integrate the system.
        - sigma (float): noise level of the stochastic term in the stochastic L63 Equation.
        - obs_dt (float): time interval between two observed states.
        - verbose (bool): if True, a tqdm progress bar is displayed.
    Returns
        - posteriors (Array): forecast ensemble at each step with dimension (num_assim_steps + 1, num_particles, 3);
          index 0 holds a copy of x_0.
    """
    num_assim_steps = y.shape[0]
    num_particles = x_0.shape[0]

    # History of ensembles; slot 0 is the initial ensemble
    history = np.zeros((num_assim_steps + 1, num_particles, 3))
    history[0] = x_0
    ensemble = x_0.copy()

    # Number of integrator steps separating two consecutive observation times
    steps_per_window = int(obs_dt / dt)

    windows = range(1, num_assim_steps + 1)
    if verbose:
        windows = tqdm(windows)

    for k in windows:
        # Sample from p(x^{k} | x^{k-1}) by integrating over one observation window
        ensemble = integrate(x_0=ensemble, num_steps=steps_per_window, dt=dt, sigma=sigma)
        # Slice assignment copies the values into the history buffer
        history[k] = ensemble

    return history

def BPF(
    x_0: Array,
    y: Array,
    H: Callable[[Array], Array],
    sigma_y: Array,
    dt: float = 1e-3,
    sigma: float = 0.25,
    obs_dt: float = 0.5,
    verbose: bool = True,
) -> Array:
    """
    Apply the Bootstrap Particle Filter (BPF) to approximate the filtering distribution p(x^{k} | y^{1:k}) at each step.
    At every assimilation step the particles are propagated with the dynamics, weighted by the
    Gaussian observation likelihood and resampled with systematic resampling.
    Input(s):
        - x_0 (Array): input particles drawn from p(x^{0}) with dimension (num_particles, 3).
        - y (Array): observations with dimension (num_assim_steps, d).
        - H (Callable[[Array], Array]): observation operator from (batch_size, 3) to (batch_size, d).
        - sigma_y (Array): diagonal of the observation-noise covariance matrix (i.e. the variances) with dimension (d,).
        - dt (float): time step used to integrate the system.
        - sigma (float): noise level of the stochastic term in the stochastic L63 Equation.
        - obs_dt (float): time interval between two observed states.
        - verbose (bool): if True, a tqdm progress bar is displayed.
    Returns
        - posteriors (Array): estimation of the filtering distribution at each step with dimension (num_assim_steps + 1, num_particles, 3).
    """
    # Internal function to compute the log of normalized weights
    def compute_log_weights(
        obs: Array,
        particles: Array,
        operator: Callable[[Array], Array] = H,
        sigma: Array = sigma_y,
    ) -> Array:
        """
        Compute the log of normalized weights from a Gaussian likelihood with diagonal covariance.
        Input(s):
            - obs (Array): current observation with dimension (d,).
            - particles (Array): current particles with dimension (num_particles, 3).
            - operator (Callable[[Array], Array]): observation operator from (batch_size, 3) to (batch_size, d).
            - sigma (Array): diagonal of the observation covariance (variances) with dimension (d,).
        Returns
            - log_weights (Array): normalized log-weights with dimension (num_particles,).
        """
        v = obs - operator(particles)
        # Gaussian log-likelihood up to an additive constant (cancelled by the normalization)
        unnormalized_weights = -0.5 * np.sum((v**2) / sigma, axis=-1)
        normalized_weights = unnormalized_weights - logsumexp(unnormalized_weights)
        return normalized_weights

    # Internal function to do resampling
    def resampling(log_weights: Array) -> Array:
        """
        Do systematic resampling i.e draw the indices of new particles.
        Input(s):
            - log_weights (Array): log of normalized weights with dimension (num_particles,).
        Returns
            - indices (Array): int64 indices in [0, num_particles - 1] of the selected particles.
        """
        # Compute the weights (without the log), stabilized by subtracting the max
        weights = np.exp(log_weights - np.max(log_weights))
        weights = weights / np.sum(weights)

        # One uniform offset shared by N evenly spaced positions in [0, 1)
        N = log_weights.shape[0]
        u0 = np.random.uniform(0.0, 1.0 / N)
        positions = u0 + np.arange(N) / N
        cumulative_sum = np.cumsum(weights)
        indices = np.searchsorted(cumulative_sum, positions, side='right')

        # Round-off can leave cumulative_sum[-1] slightly below 1 (and u0 + (N-1)/N can round
        # up to 1), in which case searchsorted returns N, an out-of-bounds index. Such positions
        # lie in the last particle's interval, so clamp them to N - 1.
        indices = np.minimum(indices, N - 1)

        return indices.astype(np.int64)

    # Instanciate the output and the particles
    num_assim_steps, num_particles = y.shape[0], x_0.shape[0]
    posteriors = np.zeros((num_assim_steps + 1, num_particles, 3))
    posteriors[0], particles = x_0, x_0.copy()

    # Get the number of integration to transition from x^{k} to x^{k+1}
    num_steps = int(obs_dt / dt)

    # Loop on observations
    iterator = tqdm(range(1, num_assim_steps + 1)) if verbose else range(1, num_assim_steps + 1)
    for k in iterator:
        # Get the current observation (y[0] is the observation of x^{1})
        y_kp1 = y[k - 1]

        # Draw particles from p(x^{k+1} | x^{k})
        particles = integrate(x_0=particles, num_steps=num_steps, dt=dt, sigma=sigma)

        # Compute the log of normalized weights
        log_weights = compute_log_weights(obs=y_kp1, particles=particles, operator=H, sigma=sigma_y)

        # Do resampling using the weights
        indices = resampling(log_weights=log_weights)
        particles = particles[indices]

        # Update the output
        posteriors[k] = particles.copy()

    return posteriors

def EnKF(
    x_0: Array,
    y: Array,
    H: Callable[[Array], Array],
    sigma_y: Array,
    dt: float = 1e-3,
    sigma: float = 0.25,
    obs_dt: float = 0.5,
    verbose: bool = True,
) -> Array:
    """
    Apply the Ensemble Kalman Filter (EnKF) to approximate the filtering distribution p(x^{k} | y^{1:k}) at each step.
    This is the stochastic (perturbed-observation) variant: each member is updated towards its own
    noisy copy of the observation.
    Input(s):
        - x_0 (Array): input particles drawn from p(x^{0}) with dimension (num_particles, 3).
        - y (Array): observations with dimension (num_assim_steps, d).
        - H (Callable[[Array], Array]): observation operator from (batch_size, 3) to (batch_size, d).
        - sigma_y (Array): diagonal of the observation-noise covariance matrix (i.e. the variances) with dimension (d,).
        - dt (float): time step used to integrate the system.
        - sigma (float): noise level of the stochastic term in the stochastic L63 Equation.
        - obs_dt (float): time interval between two observed states.
        - verbose (bool): if True, a tqdm progress bar is displayed.
    Returns
        - posteriors (Array): estimation of the filtering distribution at each step with dimension (num_assim_steps + 1, num_particles, 3).
    """
    num_assim_steps = y.shape[0]
    num_particles = x_0.shape[0]

    # Buffer of analysis ensembles; slot 0 is the initial ensemble
    history = np.zeros((num_assim_steps + 1, num_particles, 3))
    history[0] = x_0
    ensemble = x_0.copy()

    # Number of integrator steps separating two consecutive observation times
    steps_per_window = int(obs_dt / dt)

    windows = range(1, num_assim_steps + 1)
    if verbose:
        windows = tqdm(windows)

    for k in windows:
        obs = y[k - 1]

        # Forecast step: sample from p(x^{k} | x^{k-1})
        ensemble = integrate(x_0=ensemble, num_steps=steps_per_window, dt=dt, sigma=sigma)

        # Anomalies (deviations from the ensemble mean) in state space
        state_anom = ensemble - np.mean(ensemble, axis=0)[None, :]

        # Predicted observations and their anomalies
        pred_obs = H(ensemble)
        obs_anom = pred_obs - np.mean(pred_obs, axis=0)[None, :]

        # Unbiased ensemble estimates of P_xy and P_yy = H P H^T + R
        cross_cov = (1. / (num_particles - 1.)) * (state_anom.T @ obs_anom)
        innov_cov = (1. / (num_particles - 1.)) * (obs_anom.T @ obs_anom) + np.diag(sigma_y)

        # Kalman gain K = P_xy P_yy^{-1}, via a linear solve rather than an explicit inverse
        gain = np.linalg.solve(innov_cov.T, cross_cov.T).T

        # Analysis step with perturbed observations y + eps, eps ~ N(0, diag(sigma_y))
        noise = np.random.randn(num_particles, obs.shape[0])
        perturbed = obs[None, :] + np.sqrt(sigma_y)[None, :] * noise
        ensemble = ensemble + (gain @ (perturbed - pred_obs).T).T

        # Slice assignment copies the values into the history buffer
        history[k] = ensemble

    return history

def EnSF(
    x_0: Array,
    y: Array,
    grad_J: Callable[[Array, Array], Array],
    dt: float = 1e-3,
    sigma: float = 0.25,
    obs_dt: float = 0.5,
    eps_alpha: float = 0.5,
    eps_beta: float = 0.025,
    num_sampling_steps: int = 500,
    vectorized: bool = True,
    verbose: bool = True,
) -> Array:
    """
    Apply the EnSF algorithm to approximate the filtering distribution p(x^{k} | y^{1:k}) at each step.
    See "An Ensemble Score Filter for Tracking High-Dimensional Nonlinear Dynamical Systems" for more details.
    Argument(s):
        - x_0 (Array): input particles drawn from p(x^{0}) with dimension (num_particles, 3).
        - y (Array): observations with dimension (num_assim_steps, d).
        - grad_J (Callable[[Array, Array], Array]): gradient of the energy function from (batch_size, 3) x (d,) to (batch_size, 3).
        - dt (float): time step used to integrate the system.
        - sigma (float): noise level of the stochastic term in the stochastic L63 Equation.
        - obs_dt (float): time interval between two observed states.
        - eps_alpha (float): first hyperparameter of EnSF (value of alpha_t at t = 1).
        - eps_beta (float): second hyperparameter of EnSF (value of beta_t at t = 0).
        - num_sampling_steps (int): number of Euler-Maruyama steps used to solve the reverse diffusion equation.
        - vectorized (bool): if True, all particles are sampled together; otherwise one at a time.
        - verbose (bool): if True, a tqdm progress bar is displayed.
    Returns
        - posteriors (Array): estimation of the filtering distribution at each step with dimension (num_assim_steps + 1, num_particles, 3).
    """
    # Internal function to compute the schedule
    def get_schedule(t: float, eps_alpha: float = eps_alpha, eps_beta: float = eps_beta) -> Tuple[float, float]:
        """
        Compute (alpha_t, beta_t) for a given time t in [0,1].
        alpha_t scales the mean and beta_t is used as the VARIANCE of the forward perturbation
        (squared distances are divided by beta_t below, never by its square).
        Input(s):
            - t (float): current time of the reverse diffusion process in [0,1].
            - eps_alpha (float): first hyperparameter of EnSF.
            - eps_beta (float): second hyperparameter of EnSF.
        """
        alpha = 1. - t * (1 - eps_alpha)
        beta = eps_beta + t * (1. - eps_beta)
        return alpha, beta

    # Internal function to compute the drift and diffusion coefficients
    def get_SDE_coefficients(t: float, eps_alpha: float = eps_alpha, eps_beta: float = eps_beta) -> Tuple[float, float]:
        """
        Compute (b_t, sigma_t^2) for a given time t in [0,1], where
        b_t = d log(alpha_t) / dt and sigma_t^2 = d beta_t / dt - 2 * b_t * beta_t.
        Note that the second value is the SQUARED diffusion coefficient.
        Input(s):
            - t (float): current time of the reverse diffusion process in [0,1].
            - eps_alpha (float): first hyperparameter of EnSF.
            - eps_beta (float): second hyperparameter of EnSF.
        """
        b = (eps_alpha - 1.) / (1. - t * (1 - eps_alpha))
        sigma_sq = (1. - eps_beta) - 2 * b * (eps_beta + t * (1. - eps_beta))
        return b, sigma_sq

    # Internal function to compute the component of the prior score
    def get_prior_score_components(t: float, z_t: Array, predicted: Array) -> Array:
        """
        Compute (z_t - alpha_t * x_j) / beta_t for every predicted particle x_j.
        Input(s):
            - t (float): current time of the reverse diffusion process in [0,1].
            - z_t (Array): current sample(s) with dimension (3,) or (num_particles, 3).
            - predicted (Array): forecast particles with dimension (num_particles, 3).
        Returns
            - components (Array): dimension (num_particles, 3) for a single sample, (num_particles, num_particles, 3) otherwise.
        """
        alpha, beta = get_schedule(t=t)
        if z_t.ndim == 1:
            return (z_t[None, :] - alpha * predicted) / beta
        else:
            return (z_t[:, None, :] - alpha * predicted[None, :, :]) / beta

    # Internal function to compute the log of normalized weights
    def _log_weights(t: float, z_t: Array, predicted: Array) -> Array:
        """
        Compute the normalized log-weights of the Gaussian-mixture prior N(alpha_t x_j, beta_t I).
        Input(s):
            - t (float): current time of the reverse diffusion process in [0,1].
            - z_t (Array): current sample(s) with dimension (3,) or (num_particles, 3).
            - predicted (Array): forecast particles with dimension (num_particles, 3).
        Returns
            - log_weights (Array): dimension (num_particles,) for a single sample, (num_particles, num_particles) otherwise.
        """
        alpha, beta = get_schedule(t=t)
        if z_t.ndim == 1:
            log_weights = -0.5 * (1. / beta) * np.sum((z_t[None, :] - alpha * predicted) ** 2, axis=-1)
            log_weights = log_weights - logsumexp(log_weights)
        else:
            diffs = z_t[:, None, :] - alpha * predicted[None, :, :]
            log_weights = -0.5 * (1. / beta) * np.sum(diffs ** 2, axis=-1)
            log_weights = log_weights - logsumexp(log_weights, axis=1, keepdims=True)
        return log_weights

    # Internal function to compute the prior score
    def get_prior_score(t: float, z_t: Array, predicted: Array) -> Array:
        """
        Estimate the prior score by MC as the weighted average of -(z_t - alpha_t x_j) / beta_t.
        Input(s):
            - t (float): current time of the reverse diffusion process in [0,1].
            - z_t (Array): current sample(s) with dimension (3,) or (num_particles, 3).
            - predicted (Array): forecast particles with dimension (num_particles, 3).
        Returns
            - prior (Array): same dimension as z_t.
        """
        log_weights = _log_weights(t=t, z_t=z_t, predicted=predicted)
        components = get_prior_score_components(t=t, z_t=z_t, predicted=predicted)

        if z_t.ndim == 1:
            weights = np.exp(log_weights - np.max(log_weights))
            weights /= np.sum(weights)
            prior = - np.sum(weights[:, None] * components, axis=0)
        else:
            weights = np.exp(log_weights - np.max(log_weights, axis=1, keepdims=True))
            weights /= np.sum(weights, axis=1, keepdims=True)
            prior = - np.sum(weights[:, :, None] * components, axis=1)
        return prior

    # Internal function to compute the likelihood score
    def get_likelihood_score(t: float, z_t: Array, obs: Array) -> Array:
        """
        Compute the damped likelihood score -(1 - t) * grad_J(z_t, obs).
        The damping factor (1 - t) vanishes at t = 1 and is 1 at t = 0.
        Input(s):
            - t (float): current time of the reverse diffusion process in [0,1].
            - z_t (Array): current sample(s) with dimension (3,) or (num_particles, 3).
            - obs (Array): observation with dimension (d,).
        Returns
            - score (Array): same dimension as z_t.
        """
        if z_t.ndim == 1:
            return (t - 1.) * grad_J(z_t[None, :], obs)[0]
        else:
            return (t - 1.) * grad_J(z_t, obs)

    # Internal function to integrate the reverse-time SDE
    def reverse_diffusion(z: Array, predicted: Array, obs: Array, num_sampling_steps: int) -> Array:
        """
        Integrate the reverse SDE from t = 1 towards t = 0 with Euler-Maruyama.
        Input(s):
            - z (Array): starting sample(s) at t = 1 with dimension (3,) or (num_particles, 3).
            - predicted (Array): forecast particles with dimension (num_particles, 3).
            - obs (Array): observation with dimension (d,).
            - num_sampling_steps (int): number of Euler-Maruyama steps.
        Returns
            - z (Array): final sample(s), same dimension as the input.
        """
        diffusion_dt, t = np.round(1. / num_sampling_steps, 6), 1.
        for _ in range(num_sampling_steps):
            # Drift coefficient and squared diffusion coefficient
            b, sigma_sq = get_SDE_coefficients(t=t)
            # Score of the posterior = prior score + damped likelihood score
            prior_score = get_prior_score(t=t, z_t=z, predicted=predicted)
            likelihood_score = get_likelihood_score(t=t, z_t=z, obs=obs)
            posterior_score = prior_score + likelihood_score
            # Backward Euler-Maruyama update
            dw = np.sqrt(diffusion_dt) * np.random.randn(*z.shape)
            z = z - (b * z - sigma_sq * posterior_score) * diffusion_dt - np.sqrt(sigma_sq) * dw
            t -= diffusion_dt
        return z

    # Internal function to do one step of the algorithm
    def step(
        particles: Array,
        y: Array,
        num_steps: int,
        dt: float = dt,
        sigma: float = sigma,
        num_sampling_steps: int = num_sampling_steps,
    ) -> Array:
        """
        One step of the EnSF algorithm which consists in:
            1) Propagate the particles to the next predictive distribution.
            2) Sample the next filtering distribution with the reverse diffusion.
        Input(s):
            - particles (Array): current particles representing the current filtering distribution with dimension (num_particles, 3).
            - y (Array): current observation with dimension (d,).
            - num_steps (int): number of steps to do during the integration of the system.
            - dt (float): time step used to integrate the system.
            - sigma (float): noise level of the stochastic term in the stochastic L63 Equation.
            - num_sampling_steps (int): number of sampling step to solve the reverse diffusion process.
        Returns
            - next_particles (Array): particles of the next filtering distribution with dimension (num_particles, 3).
        """
        # 1) Propagate the particles
        z_1 = particles.copy()
        z_0 = integrate(x_0=z_1, num_steps=num_steps, dt=dt, sigma=sigma)
        mean, var = np.mean(z_0, axis=0), np.var(z_0, axis=0)

        # Gaussian approximation of the forward marginal at t = 1, used to draw starting points
        alpha, beta = get_schedule(t=1.)
        scale = np.sqrt((alpha ** 2) * var + beta)

        # 2) Propagate each particle to the next filtering distribution
        if vectorized:
            z_start = np.random.normal(loc=(alpha * mean), scale=scale, size=z_0.shape)
            return reverse_diffusion(z_start, z_0, y, num_sampling_steps)

        # Particle-by-particle: draw a start, then run its full trajectory, before the next one
        next_particles = np.zeros_like(z_0)
        for n in range(next_particles.shape[0]):
            z_start = np.random.normal(loc=(alpha * mean), scale=scale)
            next_particles[n] = reverse_diffusion(z_start, z_0, y, num_sampling_steps)
        return next_particles

    # Instanciate the output and the particles
    num_assim_steps, num_particles = y.shape[0], x_0.shape[0]
    posteriors = np.zeros((num_assim_steps + 1, num_particles, 3))
    posteriors[0], particles = x_0, x_0.copy()

    # Get the number of integration to transition from x^{k} to x^{k+1}
    num_steps = int(obs_dt / dt)

    # Loop on observations
    iterator = tqdm(range(1, num_assim_steps + 1)) if verbose else range(1, num_assim_steps + 1)
    for k in iterator:
        # Get the current observation (y[0] is the observation of x^{1})
        y_kp1 = y[k - 1]

        # Do one step of the EnSF algorithm
        particles = step(particles=particles, y=y_kp1, num_steps=num_steps)

        # Update the output
        posteriors[k] = particles.copy()

    return posteriors

def EnFF(
    x_0: Array,
    y: Array,
    grad_J: Callable[[Array, Array], Array],
    dt: float = 1e-3,
    sigma: float = 0.25,
    obs_dt: float = 0.5,
    scheduler: float = 0.2,
    sigma_min: float = 1e-2,
    num_sampling_steps: int = 20,
    vectorized: bool = True,
    verbose: bool = True,
) -> Array:
    """
    Apply the EnFF algorithm to approximate the filtering distribution p(x^{k} | y^{1:k}) at each step.
    See "Flow Matching for Efficient and Scalable Data Assimilation" for more details.
    Argument(s):
        - x_0 (Array): input particles drawn from p(x^{0}) with dimension (num_particles, 3).
        - y (Array): observations with dimension (num_assim_steps, d).
        - grad_J (Callable[[Array, Array], Array]): gradient of the energy function from (batch_size, 3) x (d,) to (batch_size, 3).
        - dt (float): time step used to integrate the system.
        - sigma (float): noise level of the stochastic term in the stochastic L63 Equation.
        - obs_dt (float): time interval between two observed states.
        - scheduler (float): constant multiplying -grad_J in the guidance vector field.
        - sigma_min (float): standard deviation of the conditional path p_t(z | z_{0}, z_{1}) used to weight particles.
        - num_sampling_steps (int): number of Euler steps used to integrate the flow.
        - vectorized (bool): if True, all particles are transported together; otherwise one at a time.
        - verbose (bool): if True, a tqdm progress bar is displayed.
    Returns
        - posteriors (Array): estimation of the filtering distribution at each step with dimension (num_assim_steps + 1, num_particles, 3).
    """
    def conditional_VF(z_0: Array, z_1: Array) -> Array:
        """
        Filtering-to-predictive (F2P) conditional vector field of the straight path, u = z_1 - z_0.
        Input(s):
            - z_0 (Array): previous filtering particles with dimension (num_particles, 3).
            - z_1 (Array): next predictive particles with dimension (num_particles, 3).
        Returns
            - u (Array): dimension (num_particles, 3).
        """
        return z_1 - z_0

    def _log_weights(t: float, z_t: Array, z_0: Array, z_1: Array, sigma_min: float = sigma_min) -> Array:
        """
        Normalized log-weights of each pair (z_0^j, z_1^j) given z_t, with p_t(z | z_0, z_1) = N(t z_1 + (1 - t) z_0, sigma_min^2 I).
        Input(s):
            - t (float): current time of the flow in [0,1].
            - z_t (Array): current sample(s) with dimension (3,) or (num_particles, 3).
            - z_0 (Array): previous filtering particles with dimension (num_particles, 3).
            - z_1 (Array): next predictive particles with dimension (num_particles, 3).
        Returns
            - log_weights (Array): dimension (num_particles,) for a single sample, (num_particles, num_particles) otherwise.
        """
        centers = t * z_1 + (1.0 - t) * z_0
        if z_t.ndim != 1:
            sq_dist = np.sum((z_t[:, None, :] - centers[None, :, :])**2, axis=-1)
            unnormalized = -0.5 * (1.0 / (sigma_min**2)) * sq_dist
            return unnormalized - logsumexp(unnormalized, axis=1, keepdims=True)
        unnormalized = -0.5 * (1. / (sigma_min**2)) * np.sum((z_t[None, :] - centers)**2, axis=-1)
        return unnormalized - logsumexp(unnormalized)

    def marginal_VF(t: float, z_t: Array, z_0: Array, z_1: Array) -> Array:
        """
        MC estimate of the marginal vector field: weighted average of the conditional fields.
        Input(s):
            - t (float): current time of the flow in [0,1].
            - z_t (Array): current sample(s) with dimension (3,) or (num_particles, 3).
            - z_0 (Array): previous filtering particles with dimension (num_particles, 3).
            - z_1 (Array): next predictive particles with dimension (num_particles, 3).
        Returns
            - u (Array): same dimension as z_t.
        """
        fields = conditional_VF(z_0=z_0, z_1=z_1)
        logw = _log_weights(t=t, z_t=z_t, z_0=z_0, z_1=z_1)
        if logw.ndim != 1:
            w = np.exp(logw - np.max(logw, axis=1, keepdims=True))
            w = w / np.sum(w, axis=1, keepdims=True)
            return w @ fields
        w = np.exp(logw - np.max(logw))
        w /= np.sum(w)
        return np.sum(w[:, None] * fields, axis=0)

    def guidance_VF(hat_z_1: Array, obs: Array, scheduler: float = scheduler) -> Array:
        """
        Guidance vector field obtained by linearizing the likelihood around hat_z_1: -scheduler * grad_J(hat_z_1, obs).
        Input(s):
            - hat_z_1 (Array): estimate of E[z_1 | z_t] with dimension (3,) or (num_particles, 3).
            - obs (Array): observation with dimension (d,).
        Returns
            - g (Array): same dimension as hat_z_1.
        """
        if hat_z_1.ndim != 1:
            return -scheduler * grad_J(hat_z_1, obs)
        return -scheduler * grad_J(hat_z_1[None, :], obs)[0]

    def transport(z: Array, z_0: Array, z_1: Array, obs: Array, flow_dt: float, num_sampling_steps: int) -> Array:
        """
        Integrate the guided flow from t = 0 with explicit Euler steps, updating z in place.
        Input(s):
            - z (Array): starting sample(s) (already a private copy) with dimension (3,) or (num_particles, 3).
            - z_0, z_1 (Array): reference and target particles with dimension (num_particles, 3).
            - obs (Array): observation with dimension (d,).
            - flow_dt (float): Euler time step.
            - num_sampling_steps (int): number of Euler steps.
        """
        t = 0.0
        for _ in range(num_sampling_steps):
            u_t = marginal_VF(t=t, z_t=z, z_0=z_0, z_1=z_1)
            # One-step extrapolation to t = 1 gives the point where the likelihood is linearized
            hat_z_1 = z + (1.0 - t) * u_t
            g_t = guidance_VF(hat_z_1=hat_z_1, obs=obs)
            z += flow_dt * (u_t + g_t)
            t += flow_dt
        return z

    def step(
        particles: Array,
        y: Array,
        num_steps: int,
        dt: float = dt,
        sigma: float = sigma,
        num_sampling_steps: int = num_sampling_steps,
    ) -> Array:
        """
        One EnFF assimilation step:
            1) propagate the particles to the next predictive distribution;
            2) transport each previous particle to the next filtering distribution with the guided flow.
        Input(s):
            - particles (Array): current filtering particles with dimension (num_particles, 3).
            - y (Array): current observation with dimension (d,).
            - num_steps (int): number of integrator steps for the dynamics.
            - dt (float): time step used to integrate the system.
            - sigma (float): noise level of the stochastic term in the stochastic L63 Equation.
            - num_sampling_steps (int): number of Euler steps for the flow.
        Returns:
            - next_particles (Array): next filtering particles with dimension (num_particles, 3).
        """
        z_0 = particles.copy()
        z_1 = integrate(x_0=z_0, num_steps=num_steps, dt=dt, sigma=sigma)
        flow_dt = np.round(1. / num_sampling_steps, 6)

        if vectorized:
            return transport(z_0.copy(), z_0, z_1, y, flow_dt, num_sampling_steps)

        next_particles = np.zeros_like(z_1)
        for n, start in enumerate(z_0):
            next_particles[n] = transport(start.copy(), z_0, z_1, y, flow_dt, num_sampling_steps)
        return next_particles

    num_assim_steps = y.shape[0]
    num_particles = x_0.shape[0]

    # Buffer of filtering ensembles; slot 0 is the initial ensemble
    history = np.zeros((num_assim_steps + 1, num_particles, 3))
    history[0] = x_0
    ensemble = x_0.copy()

    # Number of integrator steps separating two consecutive observation times
    steps_per_window = int(obs_dt / dt)

    windows = range(1, num_assim_steps + 1)
    if verbose:
        windows = tqdm(windows)

    for k in windows:
        # y[k - 1] is the observation of x^{k}
        ensemble = step(particles=ensemble, y=y[k - 1], num_steps=steps_per_window)
        # Slice assignment copies the values into the history buffer
        history[k] = ensemble

    return history

@torch.no_grad()
def FA_APF(
    denoiser: Lorenz63Denoiser,
    x_0: Array,
    N_min: int,
    N_max: int,
    y: Array,
    H: Callable[[Tensor], Tensor],
    sigma_y: Array,
    std_z: Array,
    std_x: Array,
    mean_x: Array,
    weights_computation: str = "one-shot",
    num_samples_mc: int = 10_000,
    max_iter_alpha: int = 100,
    verbose: bool = True,
) -> Array:
    """
    Apply the Fully-Adapted Auxiliary Particle Filter (FA-APF) to approximate the filtering distribution p(x^{k} | y^{1:k}) at each step.
    Input(s):
        - denoiser (Lorenz63Denoiser): trained denoiser for the Lorenz63 stochastic system.
        - x_0 (Array): input particles drawn from p(x^{0}) with dimension (num_particles, 3).
        - N_min (int): minimum number of effective particles.
        - N_max (int): maximum number of effective particles.
        - y (Array): observations with dimension (num_assim_steps, d).
        - H (Callable[[Tensor], Tensor]): observation operator from (batch_size, 3) to (batch_size, d).
        - sigma_y (Array): diagonal of the observation covariance matrix (i.e. the observation variances) with dimension (d,).
        - std_z (Array): standard deviations of residuals with dimension (3,).
        - std_x (Array): standard deviations of states with dimension (3,).
        - mean_x (Array): means of states with dimension (3,).
        - weights_computation (str): "one-shot" approximates the weights with a single evaluation of the trained denoiser;
          any other value uses a Monte-Carlo approximation.
        - num_samples_mc (int): number of samples per particle used for the MC approximation of the weights.
        - max_iter_alpha (int): maximum number of bisection iterations used to find an inflation coefficient such that
          N_eff is in [N_min, N_max]. The weights are always computed at least once.
        - verbose (bool): if True, a tqdm progress bar is used.
    Returns
        - posteriors (Array): estimation of the filtering distribution at each step with dimension (num_assim_steps + 1, num_particles, 3).
    """
    # Get device and put the denoiser on device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    denoiser = denoiser.to(device=device).eval()

    # Convert input arrays to float32 tensors on the device
    particles = torch.from_numpy(x_0.copy()).to(device=device, dtype=torch.float32)
    y_torch = torch.from_numpy(y.copy()).to(device=device, dtype=torch.float32)
    std_z_torch = torch.from_numpy(std_z.copy()).to(device=device, dtype=torch.float32)
    std_x_torch = torch.from_numpy(std_x.copy()).to(device=device, dtype=torch.float32)
    mean_x_torch = torch.from_numpy(mean_x.copy()).to(device=device, dtype=torch.float32)
    sigma_y_torch = torch.from_numpy(sigma_y.copy()).to(device=device, dtype=torch.float32)

    # Instantiate an unconditional sampler
    sampler = DDIMSampler(denoiser=denoiser, eta=0.5, steps=128, device=device, silent=True, dtype=torch.float32)

    # Internal function to normalize states
    def normalized_state(x: Tensor, mean_x: Tensor = mean_x_torch, std_x: Tensor = std_x_torch) -> Tensor:
        """
        Normalize a batch of states.
        Input(s):
            - x (Tensor): (unnormalized) states with dimension (batch_size, 3).
            - mean_x (Tensor): means of states with dimension (3,).
            - std_x (Tensor): standard deviation of states with dimension (3,).
        """
        return (x - mean_x[None, :]) / std_x[None, :]

    # Internal function to unnormalize states
    def unnormalized_state(x: Tensor, mean_x: Tensor = mean_x_torch, std_x: Tensor = std_x_torch) -> Tensor:
        """
        Unnormalize a batch of states.
        Input(s):
            - x (Tensor): normalized states with dimension (batch_size, 3).
            - mean_x (Tensor): means of states with dimension (3,).
            - std_x (Tensor): standard deviation of states with dimension (3,).
        """
        return std_x[None, :] * x + mean_x[None, :]

    # Internal function to unnormalize residuals
    def unnormalized_residual(z: Tensor, std_z: Tensor = std_z_torch) -> Tensor:
        """
        Unnormalize a batch of residuals.
        Input(s):
            - z (Tensor): normalized residuals with dimension (batch_size, 3).
            - std_z (Tensor): standard deviation of residuals with dimension (3,).
        """
        return std_z[None, :] * z

    # Internal function to compute next states expectations
    def get_next_state_expectations(
        x_k: Tensor,
        std_z: Tensor = std_z_torch,
        std_x: Tensor = std_x_torch,
        mean_x: Tensor = mean_x_torch,
        sampler: DDIMSampler = sampler,
        denoiser: Lorenz63Denoiser = denoiser,
    ) -> Tensor:
        """
        Compute next states expectations E[x^{k+1}|hat{x}^{k}], later used to approximate the weights.
        The expected residual is the denoiser mean evaluated on pure noise at t=1.
        Input(s):
            - x_k (Tensor): current normalized states of the system with dimension (num_particles, 3).
            - std_z (Tensor): standard deviation of residuals with dimension (3,).
            - std_x (Tensor): standard deviation of states with dimension (3,).
            - mean_x (Tensor): means of states with dimension (3,).
            - sampler (DDIMSampler): sampler used to generate noisy samples at t=1.
            - denoiser (Lorenz63Denoiser): a trained denoiser.
        Returns:
            - (Tensor): unnormalized expected next states with dimension (num_particles, 3).
        """
        noisy_residuals = sampler.init(shape=x_k.shape).to(device=x_k.device, dtype=torch.float32)
        t_1 = torch.ones(noisy_residuals.shape[0], device=x_k.device, dtype=torch.float32)
        residuals = denoiser(z_kp1_t=noisy_residuals, t=t_1, x_k=x_k).mean
        return unnormalized_residual(z=residuals, std_z=std_z) + unnormalized_state(x=x_k, mean_x=mean_x, std_x=std_x)

    # Internal function to duplicate a state and propagate it without observation
    def duplicate_and_propagate(
        x_k: Tensor,
        num_samples: int,
        std_z: Tensor = std_z_torch,
        std_x: Tensor = std_x_torch,
        mean_x: Tensor = mean_x_torch,
        sampler: DDIMSampler = sampler,
    ) -> Tensor:
        """
        Duplicate a normalized current state of the system and propagate it without using any observation.
        Input(s):
            - x_k (Tensor): a current normalized state of the system with dimension (3,).
            - num_samples (int): number of copies of x_k.
            - std_z (Tensor): standard deviation of residuals with dimension (3,).
            - std_x (Tensor): standard deviation of states with dimension (3,).
            - mean_x (Tensor): means of states with dimension (3,).
            - sampler (DDIMSampler): a sampler to generate next residuals.
        Returns:
            - (Tensor): unnormalized next-state samples with dimension (num_samples, 3).
        """
        batch_x_k = x_k.unsqueeze(0).repeat(num_samples, 1)
        noisy_residuals = sampler.init(shape=batch_x_k.shape).to(device=batch_x_k.device, dtype=torch.float32)
        residuals = sampler(x=noisy_residuals, x_k=batch_x_k)
        return unnormalized_residual(z=residuals, std_z=std_z) + unnormalized_state(x=batch_x_k, mean_x=mean_x, std_x=std_x)

    # Internal function to compute the number of effective particles
    def get_num_efficient(log_weights: Tensor) -> float:
        """
        Compute the number of effective particles N_eff = 1 / sum_i w_i^2.
        Input(s):
            - log_weights (Tensor): log of normalized weights [log(w^{k+1}_{1}), ..., log(w^{k+1}_{N})] with dimension (num_particles,).
        """
        log_N_eff = -torch.logsumexp(2 * log_weights, dim=-1)
        return torch.exp(log_N_eff).item()

    # Internal function to compute the log of the weights
    def compute_log_weights(
        y: Tensor,
        alpha: float,
        expectations: Optional[Tensor] = None,
        samples: Optional[Tensor] = None,
        sigma_y: Tensor = sigma_y_torch,
        H: Callable[[Tensor], Tensor] = H,
    ) -> Tensor:
        """
        Compute the log of normalized inflated weights [log(w^{k+1}_{1}), ..., log(w^{k+1}_{N})].
        Input(s):
            - y (Tensor): observation of the next state with dimension (d,).
            - alpha (float): inflation coefficient; the observation variances are multiplied by 1 / alpha.
            - expectations (Optional[Tensor]): next states expectations E[x^{k+1}|hat{x}^{k}] with dimension (num_particles, 3).
            - samples (Optional[Tensor]): samples from p(x^{k+1}|hat{x}^{k}_{(i)}) for each particle i with dimension (num_particles, num_samples, 3).
            - sigma_y (Tensor): observation variances with dimension (d,).
            - H (Callable[[Tensor], Tensor]): observation operator from (batch_size, 3) to (batch_size, d).
        Returns:
            - log_weights (Tensor): log of normalized inflated weights with dimension (num_particles,).
        Raises:
            - ValueError: if neither `expectations` nor `samples` is provided.
        """
        # Inflate the observation covariance
        inflated_covariance = (1.0 / alpha) * sigma_y[None, :]

        if expectations is not None:
            # Approximation with the method proposed in the paper (one evaluation at the expected next state)
            v = y[None, :] - H(expectations)
            a = -0.5 * torch.sum((v**2) / inflated_covariance, dim=-1)
        elif samples is not None:
            # Monte-Carlo approximation, evaluated for all particles and samples in a single batch
            num_particles, num_samples, dim_x = samples.shape
            predicted_obs = H(samples.reshape(num_particles * num_samples, dim_x)).reshape(num_particles, num_samples, -1)
            v = y[None, None, :] - predicted_obs
            # Log of the MC sum; the missing -log(num_samples) term cancels out after normalization
            a = torch.logsumexp(-0.5 * torch.sum((v**2) / inflated_covariance, dim=-1), dim=1)
        else:
            raise ValueError("Either `expectations` or `samples` must be provided.")

        return a - torch.logsumexp(a, dim=-1)

    # Internal function to draw indices given the log of normalized inflated weights
    def get_indices(log_weights: torch.Tensor, method: str = "systematic") -> Tensor:
        """
        Draw indices using the log of normalized weights.
        Input(s):
            - log_weights (Tensor): log of normalized weights [log(w^{k+1}_{1}), ..., log(w^{k+1}_{N})] with dimension (num_particles,).
            - method (str): "categorical" for multinomial resampling; any other value uses systematic resampling.
        Returns:
            - indices (Tensor): indices of the selected particles, in [0, num_particles - 1], with dimension (num_particles,).
        """
        num_particles = log_weights.shape[0]
        weights = torch.exp(log_weights - torch.max(log_weights))
        weights = weights / torch.sum(weights, dim=-1)
        if method == "categorical":
            return torch.multinomial(weights, num_samples=num_particles, replacement=True)
        cumulative_sum = torch.cumsum(weights, dim=-1)
        u0 = torch.rand(1, device=log_weights.device, dtype=torch.float32) / num_particles
        positions = u0 + torch.arange(num_particles, device=log_weights.device, dtype=torch.float32) / num_particles
        indices = torch.searchsorted(cumulative_sum, positions, right=True)
        # Float round-off can leave cumulative_sum[-1] slightly below 1 (or round a position up to 1.0),
        # in which case searchsorted returns num_particles, which is not a valid particle index.
        return torch.clamp(indices, max=num_particles - 1)

    # Internal function to draw samples from the optimal proposal
    def draw_from_optimal(
        x_k: Tensor,
        y: Tensor,
        denoiser: Lorenz63Denoiser = denoiser,
        sigma_y: Tensor = sigma_y_torch,
        H: Callable[[Tensor], Tensor] = H,
        std_z: Tensor = std_z_torch,
        std_x: Tensor = std_x_torch,
        mean_x: Tensor = mean_x_torch,
    ) -> Tensor:
        """
        Draw samples from the optimal proposal q(x^{k+1} | x^{k}_{i}, y^{k+1}).
        Input(s):
            - x_k (Tensor): current normalized states of the system with dimension (num_particles, 3).
            - y (Tensor): observation of the next state with dimension (d,).
            - denoiser (Lorenz63Denoiser): trained denoiser for the Lorenz63 stochastic system.
            - sigma_y (Tensor): observation variances with dimension (d,).
            - H (Callable[[Tensor], Tensor]): observation operator from (batch_size, 3) to (batch_size, d).
            - std_z (Tensor): standard deviation of residuals with dimension (3,).
            - std_x (Tensor): standard deviation of states with dimension (3,).
            - mean_x (Tensor): means of states with dimension (3,).
        Returns:
            - (Tensor): normalized next residuals with dimension (num_particles, 3).
        """
        # Define a conditional denoiser
        # NOTE(review): FA_APF runs under torch.no_grad(); ConditionalMMPSDenoiser presumably re-enables
        # gradients internally if it needs them — confirm.
        conditional_denoiser = ConditionalMMPSDenoiser(
            denoiser=denoiser,
            y=y,
            H=H,
            sigma_y=sigma_y,
            mean_x=mean_x,
            std_x=std_x,
            std_z=std_z,
        ).to(x_k.device)

        # Define a conditional sampler
        conditional_sampler = DDIMSampler(
            denoiser=conditional_denoiser,
            eta=0.5,
            steps=512,
            device=x_k.device,
            silent=True,
            dtype=torch.float32,
        )

        # Sample from the optimal proposal
        noisy_residuals = conditional_sampler.init(shape=x_k.shape).to(device=x_k.device, dtype=torch.float32)
        return conditional_sampler(x=noisy_residuals, x_k=x_k)

    # Instantiate the output
    particles = normalized_state(x=particles)
    num_assim_steps, num_particles = y.shape[0], x_0.shape[0]
    posteriors = np.zeros((num_assim_steps + 1, num_particles, 3))
    posteriors[0] = x_0

    # Loop on the number of steps to do
    iterator = tqdm(range(1, num_assim_steps + 1)) if verbose else range(1, num_assim_steps + 1)
    for k in iterator:
        # Get the current observation (already float32 on the device)
        y_kp1 = y_torch[k - 1]

        # Compute expectations or samples
        if weights_computation == "one-shot":
            expectations = get_next_state_expectations(x_k=particles)
            samples = None
        else:
            expectations = None
            samples = torch.stack(
                [duplicate_and_propagate(x_k=particle, num_samples=num_samples_mc) for particle in particles]
            )

        # Find an inflation coefficient by bisection so that N_eff lies in [N_min, N_max].
        # A larger alpha sharpens the likelihood and lowers N_eff. At least one iteration is always run
        # so that log_weights is defined even when max_iter_alpha <= 0.
        dim_state = particles.numel() / particles.shape[0]
        alpha_min, alpha_max, alpha = 1e-12, 1.0, 1.0 / dim_state
        for _ in range(max(1, max_iter_alpha)):
            log_weights = compute_log_weights(
                y=y_kp1,
                alpha=alpha,
                expectations=expectations,
                samples=samples,
            )
            N_eff = get_num_efficient(log_weights=log_weights)
            if N_eff > N_max:
                alpha_min = alpha
            elif N_eff < N_min:
                alpha_max = alpha
            else:
                break
            alpha = 0.5 * (alpha_min + alpha_max)

        # Draw indices from the log of normalized inflated weights and select particles
        indices = get_indices(log_weights=log_weights)  # type: ignore
        selected_particles = particles[indices]

        # Draw the next particles using the optimal proposal distribution
        next_residuals = draw_from_optimal(
            x_k=selected_particles,
            y=y_kp1,
        )
        next_states = unnormalized_residual(z=next_residuals) + unnormalized_state(x=selected_particles)

        # Update the particles (normalized, float32 on the device)
        particles = normalized_state(x=next_states)

        # Update the array of results
        posteriors[k] = next_states.detach().cpu().numpy()

    return posteriors
