import torch
import numpy as np

from tqdm import tqdm
from torch import Tensor
from typing import Callable
from numpy import ndarray as Array
from azula.sample import DDIMSampler
from diffusion.denoiser import NavierStokesDenoiser, ConditionalMMPSDenoiser

@torch.no_grad()
def FA_APF(
    denoiser: NavierStokesDenoiser,
    hat_x: Tensor,
    N_min: int,
    N_max: int,
    y: Tensor,
    H: Callable[[Tensor], Tensor],
    sigma_y: Tensor,
    std_z: Tensor,
    std_x: Tensor,
    mean_x: Tensor,
    lower: Tensor,
    upper: Tensor,
    batch_size: int = 4,
    max_iter_alpha: int = 100,
    verbose: bool = True,
) -> Array:
    """
    Apply the Fully-Adapted Auxiliary Particle Filter (FA-APF) to approximate the filtering distribution p(x^{k} | y^{1:k}) at each step.

    At the first assimilation step the particles are selected with uniform weights. At every later step,
    the weights are computed from the one-step expectations E[x^{k+1} | x^{k}] with an inflated observation
    covariance, the inflation coefficient being tuned by bisection so that N_eff lies in [N_min, N_max].
    The selected particles are then propagated with the observation-conditioned sampler.

    Input(s):
        - denoiser (NavierStokesDenoiser): trained denoiser for the 2D Navier-Stokes system.
        - hat_x (Tensor): input normalized particles drawn from p(x^{0}) with dimension (num_particles, 2, 128, 128).
        - N_min (int): minimum number of effective particles.
        - N_max (int): maximum number of effective particles.
        - y (Tensor): observations with dimension (num_assim_steps, d).
        - H (Callable[[Tensor], Tensor]): observation operator from (batch_size, 1, 128, 128) to (batch_size, d).
        - sigma_y (Tensor): diagonal covariance matrix of the observations with dimension (d,).
        - std_z (Tensor): standard deviations of residuals with dimension (1, 128, 128).
        - std_x (Tensor): standard deviations of states with dimension (1, 128, 128).
        - mean_x (Tensor): means of states with dimension (1, 128, 128).
        - lower (Tensor): lower bound for MMPS with dimension (1, 128, 128).
        - upper (Tensor): upper bound for MMPS with dimension (1, 128, 128).
        - batch_size (int): batch size to use when propagating the particles.
        - max_iter_alpha (int): maximum number of iterations to find an inflation coefficient such that N_eff is in [N_min, N_max].
          At least one iteration is always performed so that the weights are never left stale.
        - verbose (bool): if True, tqdm progress bars are shown (outer loop and inner conditional sampler).
    Returns
        - posteriors (Array): estimation of the filtering distribution at each step with dimension (num_assim_steps + 1, num_particles, 1, 128, 128).
    """
    # Get device and put the denoiser on device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    denoiser = denoiser.to(device=device).eval()

    # Put tensors to device (hat_x is cloned so the caller's tensor is never modified)
    particles = hat_x.clone().to(device=device, dtype=torch.float32)
    y = y.to(device=device, dtype=torch.float32)
    std_z = std_z.to(device=device, dtype=torch.float32)
    std_x = std_x.to(device=device, dtype=torch.float32)
    mean_x = mean_x.to(device=device, dtype=torch.float32)
    sigma_y = sigma_y.to(device=device, dtype=torch.float32)
    lower = lower.to(device=device, dtype=torch.float32)
    upper = upper.to(device=device, dtype=torch.float32)

    # Internal function to normalize states
    def normalized_state(x: Tensor) -> Tensor:
        """
        Normalize a batch of states.
        Input(s):
            - x (Tensor): unnormalized states with dimension (batch_size, 1, 128, 128).
        """
        return (x - mean_x[None, :]) / std_x[None, :]

    # Internal function to unnormalize states
    def unnormalized_state(x: Tensor) -> Tensor:
        """
        Unnormalize a batch of states.
        Input(s):
            - x (Tensor): normalized states with dimension (batch_size, 1, 128, 128).
        """
        return std_x[None, :] * x + mean_x[None, :]

    # Internal function to unnormalize residuals
    def unnormalized_residual(z: Tensor) -> Tensor:
        """
        Unnormalize a batch of residuals.
        Input(s):
            - z (Tensor): normalized residuals with dimension (batch_size, 1, 128, 128).
        """
        return std_z[None, :] * z

    # Internal function to compute next states expectations
    def get_next_state_expectations(x_k: Tensor) -> Tensor:
        """
        Compute next states expectations E[x^{k+1}|hat{x}^{k}], later used to approximate the weights.
        The denoiser is evaluated once, at t = 1, on noise drawn from the sampler's initial distribution.
        Input(s):
            - x_k (Tensor): current normalized states of the system with dimension (num_particles, 2, 128, 128).
        """
        shape = (x_k.shape[0], 1, 128, 128)
        sampler = DDIMSampler(denoiser=denoiser, eta=0.5, device=device)
        noisy_residuals = sampler.init(shape=shape).to(device=x_k.device, dtype=torch.float32)
        t_1 = torch.ones(noisy_residuals.shape[0], device=x_k.device, dtype=torch.float32)
        normalized_residuals = denoiser(z_kp1_t=noisy_residuals, t=t_1, x_k=x_k).mean
        # Next state = last known state + predicted residual, both in physical units
        return unnormalized_residual(normalized_residuals) + unnormalized_state(x_k[:, -1:, :, :])

    # Internal function to compute the number of efficient particles
    def get_num_efficient(log_weights: Tensor) -> float:
        """
        Compute the number of efficient particles N_eff = 1 / sum_i w_i^2, evaluated in log space.
        Input(s):
            - log_weights (Tensor): log of normalized weights [log(w^{k+1}_{1}), ..., log(w^{k+1}_{N})] with dimension (num_particles,).
        """
        return torch.exp(-torch.logsumexp(2 * log_weights, dim=-1)).item()

    # Internal function to compute the log of the weights
    def compute_log_weights(y: Tensor, alpha: float, expectations: Tensor) -> Tensor:
        """
        Compute the log of normalized inflated weights [log(w^{k+1}_{1}), ..., log(w^{k+1}_{N})] with dimension (num_particles,).
        Input(s):
            - y (Tensor): observation of the next state with dimension (d,).
            - alpha (float): inflation coefficient (the observation covariance is multiplied by 1 / alpha).
            - expectations (Tensor): unnormalized next states expectations E[x^{k+1}|hat{x}^{k}] with dimension (num_particles, 1, 128, 128).
        Returns:
            - log_weights (Tensor): log of normalized inflated weights with dimension (num_particles,).
        """
        # Inflate the (diagonal) covariance matrix
        inflated_covariance = (1.0 / alpha) * sigma_y[None, :]

        # Gaussian log-likelihood of the innovation, up to an additive constant
        innovation = y[None, :] - H(expectations)
        log_likelihood = -0.5 * torch.sum((innovation**2) / inflated_covariance, dim=-1)

        # Normalize in log space so that the weights sum to one
        return log_likelihood - torch.logsumexp(log_likelihood, dim=-1)

    # Internal function to draw indices given the log of normalized inflated weights
    def get_indices(log_weights: Tensor, method: str = "systematic") -> Tensor:
        """
        Draw indices using the log of normalized weights.
        Input(s):
            - log_weights (Tensor): log of normalized weights [log(w^{k+1}_{1}), ..., log(w^{k+1}_{N})] with dimension (num_particles,).
            - method (str): "categorical" for multinomial resampling, anything else for systematic resampling.
        """
        num_particles = log_weights.shape[0]
        weights = torch.exp(log_weights - torch.max(log_weights))
        weights = weights / torch.sum(weights, dim=-1)
        if method == "categorical":
            return torch.multinomial(weights, num_samples=num_particles, replacement=True)

        cumulative_sum = torch.cumsum(weights, dim=-1)
        u0 = torch.rand(1, device=log_weights.device, dtype=torch.float32) / num_particles
        positions = u0 + torch.arange(num_particles, device=log_weights.device, dtype=torch.float32) / num_particles
        indices = torch.searchsorted(cumulative_sum, positions, right=True)
        # In float32 the last cumulative value can round below the largest position, in which case
        # searchsorted returns num_particles; clamp so that every index stays valid.
        return torch.clamp(indices, max=num_particles - 1)

    # Internal function to draw samples from the optimal proposal
    def draw_from_optimal(x_k: Tensor, y: Tensor) -> Tensor:
        """
        Draw samples from the optimal proposal q(x^{k+1} | x^{k}_{i}, y^{k+1}) = p(x^{k+1} | x^{k}_{i}, y^{k+1}).
        Input(s):
            - x_k (Tensor): current normalized states of the system with dimension (batch_size, 2, 128, 128).
            - y (Tensor): observation of the next state with dimension (d,).
        Returns:
            - normalized residuals produced by the conditional sampler for the given batch.
        """
        # Duplicate the observation, one copy per element of the batch
        num_samples = x_k.shape[0]
        y_batched = y.unsqueeze(0).repeat(num_samples, 1)

        # Define a conditional denoiser
        conditional_denoiser = ConditionalMMPSDenoiser(
            denoiser=denoiser,
            y=y_batched,
            H=H,
            sigma_y=sigma_y,
            mean_x=mean_x,
            std_x=std_x,
            std_z=std_z,
            lower=lower,
            upper=upper,
            num_iterations=1,
        ).to(x_k.device)

        # Define a conditional sampler; its progress bar follows the `verbose` flag
        conditional_sampler = DDIMSampler(
            denoiser=conditional_denoiser,
            eta=0.5,
            steps=128,
            device=x_k.device,
            silent=not verbose,
            dtype=torch.float32,
        )

        # Do sampling from the optimal proposal
        shape = (num_samples, 1, 128, 128)
        noisy_residuals = conditional_sampler.init(shape=shape).to(device=x_k.device, dtype=torch.float32)
        return conditional_sampler(x=noisy_residuals, x_k=x_k)

    # Internal function to propagate particles
    def propagate_particles(x_k: Tensor, y: Tensor) -> Tensor:
        """
        Propagate the particles to the next time step using the optimal proposal, batch by batch.
        Input(s):
            - x_k (Tensor): current normalized states of the system with dimension (num_particles, 2, 128, 128).
            - y (Tensor): observation of the next state with dimension (d,).
        Returns:
            - unnormalized next states with dimension (num_particles, 1, 128, 128).
        """
        total = x_k.shape[0]
        next_particles = torch.zeros(total, 1, 128, 128, device=x_k.device, dtype=torch.float32)
        for start in range(0, total, batch_size):
            stop = start + batch_size
            batch_x_k = x_k[start:stop]
            next_residuals_batch = draw_from_optimal(x_k=batch_x_k, y=y)
            next_particles[start:stop] = (
                unnormalized_residual(next_residuals_batch) + unnormalized_state(batch_x_k[:, -1:, :, :])
            )
        return next_particles

    # Instantiate the output
    num_assim_steps, num_particles = y.shape[0], particles.shape[0]
    posteriors = np.zeros((num_assim_steps + 1, num_particles, 1, 128, 128))
    posteriors[0] = unnormalized_state(particles[:, -1:, :, :]).cpu().numpy()

    # Loop on the number of steps to do
    steps = range(1, num_assim_steps + 1)
    iterator = tqdm(steps) if verbose else steps
    for k in iterator:
        # Get the current observation
        y_kp1 = y[k - 1]

        if k > 1:
            # Compute expectations
            expectations = get_next_state_expectations(x_k=particles)

            # Bisection on the inflation coefficient until N_eff falls in [N_min, N_max]
            dim_state = particles.numel() / particles.shape[0]
            alpha_min, alpha_max, alpha = 1e-12, 1.0, 1.0 / dim_state
            for _ in range(max(1, max_iter_alpha)):
                # Compute the log of normalized inflated weights
                log_weights = compute_log_weights(y=y_kp1, alpha=alpha, expectations=expectations)

                # Compute the number of efficient particles
                N_eff = get_num_efficient(log_weights=log_weights)

                # Update alpha: a larger alpha sharpens the weights and lowers N_eff
                if N_eff > N_max:
                    alpha_min = alpha
                elif N_eff < N_min:
                    alpha_max = alpha
                else:
                    break
                alpha = 0.5 * (alpha_min + alpha_max)
        else:
            # First step: uniform weights
            weights = torch.full(
                (num_particles,), 1.0 / num_particles, dtype=torch.float32, device=particles.device, requires_grad=False
            )
            log_weights = torch.log(weights)

        # Draw indices from the log of normalized inflated weights and select particles in one gather
        indices = get_indices(log_weights=log_weights)
        selected_particles = particles[indices]

        # Draw the next particles using the optimal proposal distribution
        next_states = propagate_particles(x_k=selected_particles, y=y_kp1)
        next_normalized_states = normalized_state(next_states)

        # Update the particles: shift the last state to the first channel and append the new state
        next_particles = torch.zeros_like(particles)
        next_particles[:, 0, :, :] = selected_particles[:, -1, :, :]
        next_particles[:, -1, :, :] = next_normalized_states[:, 0, :, :]
        particles = next_particles

        # Update the array of results
        posteriors[k] = next_states.cpu().numpy()

    return posteriors


@torch.no_grad()
def FlowDAS(
    denoiser: NavierStokesDenoiser,
    hat_x: Tensor,
    y: Tensor,
    H: Callable[[Tensor], Tensor],
    sigma_y: Tensor,
    std_z: Tensor,
    std_x: Tensor,
    mean_x: Tensor,
    lower: Tensor,
    upper: Tensor,
    verbose: bool = True,
) -> Array:
    """
    Track a single trajectory: at each assimilation step, one sample is drawn from p(x^{k+1} | y^{k+1}, x^{k})
    and used as the estimate of the filtering distribution p(x^{k} | y^{1:k}).

    Input(s):
        - denoiser (NavierStokesDenoiser): trained denoiser for the 2D Navier-Stokes system.
        - hat_x (Tensor): input normalized particle drawn from p(x^{0}) with dimension (1, 2, 128, 128).
        - y (Tensor): observations with dimension (num_assim_steps, d).
        - H (Callable[[Tensor], Tensor]): observation operator from (batch_size, 1, 128, 128) to (batch_size, d).
        - sigma_y (Tensor): diagonal covariance matrix of the observations with dimension (d,).
        - std_z (Tensor): standard deviations of residuals with dimension (1, 128, 128).
        - std_x (Tensor): standard deviations of states with dimension (1, 128, 128).
        - mean_x (Tensor): means of states with dimension (1, 128, 128).
        - lower (Tensor): lower bound for MMPS with dimension (1, 128, 128).
        - upper (Tensor): upper bound for MMPS with dimension (1, 128, 128).
        - verbose (bool): if True, a tqdm bar is shown for the assimilation loop.
    Returns
        - posteriors (Array): estimate at each step with dimension (num_assim_steps + 1, 1, 1, 128, 128).
    """
    # Select the device and move the denoiser there, in evaluation mode
    device = "cuda" if torch.cuda.is_available() else "cpu"
    denoiser = denoiser.to(device=device).eval()

    # Move every tensor to the device as float32 (hat_x is cloned first)
    on_device = {"device": device, "dtype": torch.float32}
    particle = hat_x.clone().to(**on_device)
    y = y.to(**on_device)
    std_z = std_z.to(**on_device)
    std_x = std_x.to(**on_device)
    mean_x = mean_x.to(**on_device)
    sigma_y = sigma_y.to(**on_device)
    lower = lower.to(**on_device)
    upper = upper.to(**on_device)

    def to_normalized(state: Tensor) -> Tensor:
        """Map unnormalized states (batch_size, 1, 128, 128) to normalized ones."""
        centered = state - mean_x[None, :]
        return centered / std_x[None, :]

    def to_physical(state: Tensor) -> Tensor:
        """Map normalized states (batch_size, 1, 128, 128) back to unnormalized ones."""
        scaled = std_x[None, :] * state
        return scaled + mean_x[None, :]

    def residual_to_physical(residual: Tensor) -> Tensor:
        """Rescale normalized residuals (batch_size, 1, 128, 128) by std_z."""
        return std_z[None, :] * residual

    def sample_residual(current: Tensor, observation: Tensor) -> Tensor:
        """
        Draw one normalized residual from the optimal proposal
        q(x^{k+1} | x^{k}, y^{k+1}) = p(x^{k+1} | x^{k}, y^{k+1}).
        Input(s):
            - current (Tensor): current normalized state with dimension (1, 2, 128, 128).
            - observation (Tensor): observation of the next state with dimension (d,).
        """
        target_device = current.device

        # Denoiser conditioned on the observation
        guided_denoiser = ConditionalMMPSDenoiser(
            denoiser=denoiser,
            y=observation.unsqueeze(0),
            H=H,
            sigma_y=sigma_y,
            mean_x=mean_x,
            std_x=std_x,
            std_z=std_z,
            lower=lower,
            upper=upper,
            num_iterations=1,
        ).to(target_device)

        # Sampler driven by the conditional denoiser
        guided_sampler = DDIMSampler(
            denoiser=guided_denoiser,
            eta=0.5,
            steps=128,
            device=target_device,
            silent=True,
            dtype=torch.float32,
        )

        # Start from the sampler's initial noise and run the sampler
        start_noise = guided_sampler.init(shape=(1, 1, 128, 128))
        start_noise = start_noise.to(device=target_device, dtype=torch.float32)
        return guided_sampler(x=start_noise, x_k=current)

    # Allocate the output and store the initial state (last channel of the particle)
    num_assim_steps = y.shape[0]
    posteriors = np.zeros((num_assim_steps + 1, 1, 1, 128, 128))
    initial_state = to_physical(particle[:, -1, :, :].unsqueeze(1))
    posteriors[0] = initial_state.cpu().numpy()

    # Assimilation loop
    steps = range(1, num_assim_steps + 1)
    if verbose:
        steps = tqdm(steps)
    for step in steps:
        # Observation attached to the state being estimated
        observation = y[step - 1]

        # Sample the residual and add it to the latest state, in physical units
        residual = sample_residual(current=particle, observation=observation)
        latest_state = to_physical(particle[:, -1, :, :].unsqueeze(1))
        next_state = residual_to_physical(residual) + latest_state
        next_normalized_state = to_normalized(next_state)

        # Shift the window: previous last channel goes first, new state goes last
        shifted = torch.zeros_like(particle, dtype=torch.float32, device=device)
        shifted[:, 0, :, :] = particle[:, -1, :, :]
        shifted[:, -1, :, :] = next_normalized_state
        particle = shifted

        # Record the new state
        posteriors[step] = next_state.cpu().numpy()

    return posteriors


@torch.no_grad()
def BPF(
    denoiser: NavierStokesDenoiser,
    hat_x: Tensor,
    N_min: int,
    N_max: int,
    y: Tensor,
    H: Callable[[Tensor], Tensor],
    sigma_y: Tensor,
    std_z: Tensor,
    std_x: Tensor,
    mean_x: Tensor,
    max_iter_alpha: int = 100,
    verbose: bool = True,
) -> Array:
    """
    Apply the Bootstrap Particle Filter (BPF) to approximate the filtering distribution p(x^{k} | y^{1:k}) at each step.

    At every step the particles are first propagated with the unconditional sampler, then resampled.
    The first step resamples with uniform weights; later steps use observation weights computed with an
    inflated covariance, the inflation coefficient being tuned by bisection so that N_eff lies in [N_min, N_max].

    Input(s):
        - denoiser (NavierStokesDenoiser): trained denoiser for the 2D Navier-Stokes system.
        - hat_x (Tensor): input normalized particles drawn from p(x^{0}) with dimension (num_particles, 2, 128, 128).
        - N_min (int): minimum number of effective particles.
        - N_max (int): maximum number of effective particles.
        - y (Tensor): observations with dimension (num_assim_steps, d).
        - H (Callable[[Tensor], Tensor]): observation operator from (batch_size, 1, 128, 128) to (batch_size, d).
        - sigma_y (Tensor): diagonal covariance matrix of the observations with dimension (d,).
        - std_z (Tensor): standard deviations of residuals with dimension (1, 128, 128).
        - std_x (Tensor): standard deviations of states with dimension (1, 128, 128).
        - mean_x (Tensor): means of states with dimension (1, 128, 128).
        - max_iter_alpha (int): maximum number of iterations to find an inflation coefficient such that N_eff is in [N_min, N_max].
          At least one iteration is always performed so that the weights are never left stale.
        - verbose (bool): if True, a tqdm bar is shown for the assimilation loop.
    Returns
        - posteriors (Array): estimation of the filtering distribution at each step with dimension (num_assim_steps + 1, num_particles, 1, 128, 128).
    """
    # Get device and put the denoiser on device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    denoiser = denoiser.to(device=device).eval()

    # Put tensors to device (hat_x is cloned so the caller's tensor is never modified)
    particles = hat_x.clone().to(device=device, dtype=torch.float32)
    y = y.to(device=device, dtype=torch.float32)
    std_z = std_z.to(device=device, dtype=torch.float32)
    std_x = std_x.to(device=device, dtype=torch.float32)
    mean_x = mean_x.to(device=device, dtype=torch.float32)
    sigma_y = sigma_y.to(device=device, dtype=torch.float32)

    # Internal function to normalize states
    def normalized_state(x: Tensor) -> Tensor:
        """
        Normalize a batch of states.
        Input(s):
            - x (Tensor): unnormalized states with dimension (batch_size, 1, 128, 128).
        """
        return (x - mean_x[None, :]) / std_x[None, :]

    # Internal function to unnormalize states
    def unnormalized_state(x: Tensor) -> Tensor:
        """
        Unnormalize a batch of states.
        Input(s):
            - x (Tensor): normalized states with dimension (batch_size, 1, 128, 128).
        """
        return std_x[None, :] * x + mean_x[None, :]

    # Internal function to unnormalize residuals
    def unnormalized_residual(z: Tensor) -> Tensor:
        """
        Unnormalize a batch of residuals.
        Input(s):
            - z (Tensor): normalized residuals with dimension (batch_size, 1, 128, 128).
        """
        return std_z[None, :] * z

    # Internal function to compute the number of efficient particles
    def get_num_efficient(log_weights: Tensor) -> float:
        """
        Compute the number of efficient particles N_eff = 1 / sum_i w_i^2, evaluated in log space.
        Input(s):
            - log_weights (Tensor): log of normalized weights [log(w^{k+1}_{1}), ..., log(w^{k+1}_{N})] with dimension (num_particles,).
        """
        return torch.exp(-torch.logsumexp(2 * log_weights, dim=-1)).item()

    # Internal function to compute the log of the weights
    def compute_log_weights(y: Tensor, alpha: float, next_states: Tensor) -> Tensor:
        """
        Compute the log of normalized inflated weights [log(w^{k+1}_{1}), ..., log(w^{k+1}_{N})] with dimension (num_particles,).
        Input(s):
            - y (Tensor): observation of the next state with dimension (d,).
            - alpha (float): inflation coefficient (the observation covariance is multiplied by 1 / alpha).
            - next_states (Tensor): unnormalized next states with dimension (num_particles, 1, 128, 128).
        Returns:
            - log_weights (Tensor): log of normalized inflated weights with dimension (num_particles,).
        """
        # Inflate the (diagonal) covariance matrix
        inflated_covariance = (1.0 / alpha) * sigma_y[None, :]

        # Gaussian log-likelihood of the innovation, up to an additive constant
        innovation = y[None, :] - H(next_states)
        log_likelihood = -0.5 * torch.sum((innovation**2) / inflated_covariance, dim=-1)

        # Normalize in log space so that the weights sum to one
        return log_likelihood - torch.logsumexp(log_likelihood, dim=-1)

    # Internal function to draw indices given the log of normalized inflated weights
    def get_indices(log_weights: Tensor, method: str = "systematic") -> Tensor:
        """
        Draw indices using the log of normalized weights.
        Input(s):
            - log_weights (Tensor): log of normalized weights [log(w^{k+1}_{1}), ..., log(w^{k+1}_{N})] with dimension (num_particles,).
            - method (str): "categorical" for multinomial resampling, anything else for systematic resampling.
        """
        num_particles = log_weights.shape[0]
        weights = torch.exp(log_weights - torch.max(log_weights))
        weights = weights / torch.sum(weights, dim=-1)
        if method == "categorical":
            return torch.multinomial(weights, num_samples=num_particles, replacement=True)

        cumulative_sum = torch.cumsum(weights, dim=-1)
        u0 = torch.rand(1, device=log_weights.device, dtype=torch.float32) / num_particles
        positions = u0 + torch.arange(num_particles, device=log_weights.device, dtype=torch.float32) / num_particles
        indices = torch.searchsorted(cumulative_sum, positions, right=True)
        # In float32 the last cumulative value can round below the largest position, in which case
        # searchsorted returns num_particles; clamp so that every index stays valid.
        return torch.clamp(indices, max=num_particles - 1)

    # Internal function to draw samples from the bootstrap (prior) proposal
    def propagate_particles(x_k: Tensor) -> Tensor:
        """
        Draw samples from the classical proposal q(x^{k+1} | x^{k}_{i}, y^{k+1}) = p(x^{k+1} | x^{k}_{i}).
        Input(s):
            - x_k (Tensor): current normalized states of the system with dimension (num_particles, 2, 128, 128).
        Returns:
            - normalized residuals produced by the sampler for all particles.
        """
        # Define an unconditional sampler (the observation is not used here)
        sampler = DDIMSampler(
            denoiser=denoiser,
            eta=0.5,
            steps=32,
            device=x_k.device,
            silent=True,
            dtype=torch.float32,
        )

        # Do sampling from the transition model
        shape = (x_k.shape[0], 1, 128, 128)
        noisy_residuals = sampler.init(shape=shape).to(device=x_k.device, dtype=torch.float32)
        return sampler(x=noisy_residuals, x_k=x_k)

    # Instantiate the output
    num_assim_steps, num_particles = y.shape[0], particles.shape[0]
    posteriors = np.zeros((num_assim_steps + 1, num_particles, 1, 128, 128))
    posteriors[0] = unnormalized_state(particles[:, -1:, :, :]).cpu().numpy()

    # Loop on the number of steps to do
    steps = range(1, num_assim_steps + 1)
    iterator = tqdm(steps) if verbose else steps
    for k in iterator:
        # Propagate the particles
        next_residuals = propagate_particles(x_k=particles)
        next_states = unnormalized_residual(next_residuals) + unnormalized_state(particles[:, -1:, :, :])
        next_normalized_states = normalized_state(next_states)

        # Get the current observation
        y_kp1 = y[k - 1]

        # Compute the log of normalized weights
        if k > 1:
            # Bisection on the inflation coefficient until N_eff falls in [N_min, N_max]
            dim_state = particles.numel() / particles.shape[0]
            alpha_min, alpha_max, alpha = 1e-12, 1.0, 1.0 / dim_state
            for _ in range(max(1, max_iter_alpha)):
                # Compute the log of normalized inflated weights
                log_weights = compute_log_weights(y=y_kp1, alpha=alpha, next_states=next_states)

                # Compute the number of efficient particles
                N_eff = get_num_efficient(log_weights=log_weights)

                # Update alpha: a larger alpha sharpens the weights and lowers N_eff
                if N_eff > N_max:
                    alpha_min = alpha
                elif N_eff < N_min:
                    alpha_max = alpha
                else:
                    break
                alpha = 0.5 * (alpha_min + alpha_max)
        else:
            # First step: uniform weights
            weights = torch.full(
                (num_particles,), 1.0 / num_particles, dtype=torch.float32, device=particles.device, requires_grad=False
            )
            log_weights = torch.log(weights)

        # Resampling: gather the selected ancestors and their propagated states in one indexing operation
        indices = get_indices(log_weights=log_weights)
        selected_particles = torch.zeros_like(particles)
        selected_particles[:, 0, :, :] = particles[indices, -1, :, :]
        selected_particles[:, -1, :, :] = next_normalized_states[indices, 0, :, :]

        # Update the particles
        particles = selected_particles

        # Update the array of results
        posteriors[k] = unnormalized_state(particles[:, -1:, :, :]).cpu().numpy()

    return posteriors


@torch.no_grad()
def LETKF(
    denoiser: NavierStokesDenoiser,
    hat_x: Tensor,
    y: Tensor,
    obs_coords: Tensor,
    H: Callable[[Tensor], Tensor],
    sigma_y: Tensor,
    std_z: Tensor,
    std_x: Tensor,
    mean_x: Tensor,
    inflation_factor: float = 1.05,
    localization_radius: float = 1.0,
    verbose: bool = True,
) -> Array:
    """
    Apply the LETKF algorithm to approximate the filtering distribution p(x^{k} | y^{1:k}) at each step.
    Input(s):
        - denoiser (NavierStokesDenoiser): trained denoiser for the 2D Navier-Stokes system.
        - hat_x (Tensor): input normalized particles drawn from p(x^{0}) with dimension (num_particles, 2, 128, 128).
        - y (Tensor): observations with dimension (num_assim_steps, d).
        - obs_coords (Tensor): coordinates (x, y) of the observations with dimension (d, 2), compared
          against a grid spanning [0, 2*pi] in both directions with periodic distances.
        - H (Callable[[Tensor], Tensor]): observation operator from (batch_size, 1, 128, 128) to (batch_size, d).
        - sigma_y (Tensor): diagonal covariance matrix of the observations with dimension (d,).
        - std_z (Tensor): standard deviations of residuals with dimension (1, 128, 128).
        - std_x (Tensor): standard deviations of states with dimension (1, 128, 128).
        - mean_x (Tensor): means of states with dimension (1, 128, 128).
        - inflation_factor (float): multiplicative factor applied to the forecast state perturbations.
        - localization_radius (float): Gaspari-Cohn half-width c; the localization weight vanishes
          for distances greater than or equal to 2 * c.
        - verbose (bool): if True, a tqdm progress bar is used.
    Returns
        - posteriors (Array): estimation of the filtering distribution at each step with dimension (num_assim_steps + 1, num_particles, 1, 128, 128).
    """
    # Get device and put the denoiser on device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    denoiser = denoiser.to(device=device).eval()

    # Put tensors to device (obs_coords included: it is combined with the on-device grid below)
    particles = hat_x.clone().to(device=device, dtype=torch.float32)
    y = y.to(device=device, dtype=torch.float32)
    obs_coords = obs_coords.to(device=device, dtype=torch.float32)
    std_z = std_z.to(device=device, dtype=torch.float32)
    std_x = std_x.to(device=device, dtype=torch.float32)
    mean_x = mean_x.to(device=device, dtype=torch.float32)
    sigma_y = sigma_y.to(device=device, dtype=torch.float32)

    # Spatial size of the states, taken from the particles (128 x 128 in the documented setting)
    height, width = particles.shape[-2], particles.shape[-1]

    # Internal function to normalize states
    def normalized_state(x: Tensor, mean_x: Tensor = mean_x, std_x: Tensor = std_x) -> Tensor:
        """
        Normalize a batch of states.
        Input(s):
            - x (Tensor): unnormalized state with dimension (batch_size, 1, 128, 128).
            - mean_x (Tensor): means of states with dimension (1, 128, 128).
            - std_x (Tensor): standard deviation of states with dimension (1, 128, 128).
        """
        return (x - mean_x[None, :]) / std_x[None, :]

    # Internal function to unnormalize states
    def unnormalized_state(x: Tensor, mean_x: Tensor = mean_x, std_x: Tensor = std_x) -> Tensor:
        """
        Unnormalize a batch of states.
        Input(s):
            - x (Tensor): normalized state with dimension (batch_size, 1, 128, 128).
            - mean_x (Tensor): means of states with dimension (1, 128, 128).
            - std_x (Tensor): standard deviation of states with dimension (1, 128, 128).
        """
        return std_x[None, :] * x + mean_x[None, :]

    # Internal function to unnormalize residuals
    def unnormalized_residual(z: Tensor, std_z: Tensor = std_z) -> Tensor:
        """
        Unnormalize a batch of residuals.
        Input(s):
            - z (Tensor): normalized residuals with dimension (batch_size, 1, 128, 128).
            - std_z (Tensor): standard deviation of residuals with dimension (1, 128, 128).
        """
        return std_z[None, :] * z

    # Internal function to draw samples from the transition kernel
    def propagate_particles(
        x_k: Tensor,
        denoiser: NavierStokesDenoiser = denoiser,
    ) -> Tensor:
        """
        Draw samples from the classical proposal q(x^{k+1} | x^{k}_{i}, y^{k+1}) = p(x^{k+1} | x^{k}_{i}).
        Input(s):
            - x_k (Tensor): current normalized states of the system with dimension (num_particles, 2, 128, 128).
            - denoiser (NavierStokesDenoiser): trained denoiser for the 2D Navier-Stokes equation.
        Returns
            - unnormalized next states with dimension (num_particles, 1, 128, 128).
        """
        # Define a sampler conditioned on the current states
        sampler = DDIMSampler(
            denoiser=denoiser,
            eta=0.5,
            steps=128,
            device=x_k.device,
            silent=True,
            dtype=torch.float32,
        )

        # Sample a normalized residual and add it to the last (unnormalized) state
        shape = (x_k.shape[0], 1, x_k.shape[-2], x_k.shape[-1])
        noisy_residuals = sampler.init(shape=shape).to(device=x_k.device, dtype=torch.float32)
        normalized_residuals = sampler(x=noisy_residuals, x_k=x_k)
        return unnormalized_residual(z=normalized_residuals) + unnormalized_state(x=x_k[:, -1, :, :].unsqueeze(1))

    # Internal function for Gaspari-Cohn
    def gaspari_cohn(dist: Tensor, radius: float, eps: float = 1e-12) -> Tensor:
        """
        Gaspari-Cohn fifth-order compactly supported correlation function.
        Input(s):
            - dist (Tensor): non-negative distances (any shape).
            - radius (float): half-width c of the function (support is [0, 2 * c)).
            - eps (float): small constant avoiding a division by zero when radius is 0.
        """
        r = dist / (radius + eps)
        w = torch.zeros_like(r)

        # Inner branch: 0 <= r < 1
        inner = (r >= 0) & (r < 1)
        ri = r[inner]
        w[inner] = 1.0 - (5.0 / 3.0) * ri**2 + (5.0 / 8.0) * ri**3 + 0.5 * ri**4 - 0.25 * ri**5

        # Outer branch: 1 <= r < 2 (continuous with the inner branch at r = 1, zero at r = 2)
        outer = (r >= 1) & (r < 2)
        ro = r[outer]
        w[outer] = (
            (1.0 / 12.0) * ro**5 - 0.5 * ro**4 + (5.0 / 8.0) * ro**3
            + (5.0 / 3.0) * ro**2 - 5.0 * ro + 4.0 - 2.0 / (3.0 * ro)
        )

        # Guard against tiny negative values caused by round-off near r = 2
        return w.clamp(min=0.0)

    # Compute 2D grid coordinates, flattened in the same row-major order as the states
    # NOTE(review): linspace includes both 0 and 2*pi, which coincide under the periodic distance
    # used below -- confirm this matches the convention used to build obs_coords.
    grid_1d_y = torch.linspace(0, 2 * np.pi, height, device=device)
    grid_1d_x = torch.linspace(0, 2 * np.pi, width, device=device)
    grid_y, grid_x = torch.meshgrid(grid_1d_y, grid_1d_x, indexing='ij')
    grid_coords = torch.stack([grid_x, grid_y], dim=-1).reshape(-1, 2)
    num_grid_points = grid_coords.shape[0]

    # Localization weights between every grid point and every observation.
    # They only depend on the geometry, so they are computed once for all assimilation steps.
    d_x = torch.abs(grid_coords[:, None, 0] - obs_coords[None, :, 0])
    d_y = torch.abs(grid_coords[:, None, 1] - obs_coords[None, :, 1])
    d_x = torch.min(d_x, 2 * np.pi - d_x)
    d_y = torch.min(d_y, 2 * np.pi - d_y)
    rho_mat = gaspari_cohn(torch.sqrt(d_x**2 + d_y**2), localization_radius)
    active_mask = rho_mat > 1e-5
    # Python list so that the inner loop does not synchronize with the device to test emptiness
    has_active_obs = active_mask.any(dim=1).tolist()

    # Instantiate the output
    num_assim_steps, num_particles = y.shape[0], particles.shape[0]
    posteriors = np.zeros((num_assim_steps + 1, num_particles, 1, height, width))
    posteriors[0] = unnormalized_state(x=particles[:, -1, :, :].unsqueeze(1)).cpu().numpy()

    # Quantities that do not change across steps / grid points
    R_inv_diag = 1.0 / (sigma_y + 1e-8)
    ident = torch.eye(num_particles, device=device)
    sqrt_nm1 = float(np.sqrt(num_particles - 1))

    # Loop on the number of steps to do
    iterator = tqdm(range(1, num_assim_steps + 1)) if verbose else range(1, num_assim_steps + 1)
    for k in iterator:
        # Propagate the particles
        next_states = propagate_particles(x_k=particles)

        # Get the current observation
        y_kp1 = y[k - 1]

        # Forecast mean and perturbations in state space.
        # The multiplicative inflation is applied to the state perturbations only;
        # the observation-space anomalies below are left uninflated.
        X_f = next_states.reshape(num_particles, -1)
        mean_x_f = X_f.mean(dim=0)
        Z_f = (X_f - mean_x_f.unsqueeze(0)) * inflation_factor

        # Forecast mean and anomalies in observation space
        Y_f = H(next_states)
        mean_y_f = Y_f.mean(dim=0)
        Y_anom = Y_f - mean_y_f.unsqueeze(0)
        innovation = y_kp1 - mean_y_f

        # Grid points without any nearby observation keep their forecast mean and perturbations
        x_a_mean = mean_x_f.clone()
        x_a_pert = Z_f.clone()

        # Local analysis, one grid point at a time
        for s in range(num_grid_points):
            if not has_active_obs[s]:
                continue
            active_idx = active_mask[s]

            # Local observations with R^{-1} tapered by the localization weights
            Y_loc = Y_anom[:, active_idx]
            innov_loc = innovation[active_idx]
            inv_R_loc = R_inv_diag[active_idx] * rho_mat[s][active_idx]
            Ys_scaled = Y_loc * inv_R_loc.unsqueeze(0)

            # Eigen-decomposition of (N - 1) I + Y R^{-1} Y^T in ensemble space
            P_tilde_inv = (num_particles - 1) * ident + Ys_scaled @ Y_loc.T
            vals, vecs = torch.linalg.eigh(P_tilde_inv)
            vals_inv = 1.0 / (vals + 1e-8)

            # Scaling the columns of vecs is equivalent to vecs @ diag(.)
            P_tilde = (vecs * vals_inv) @ vecs.T
            W_s = ((vecs * torch.sqrt(vals_inv)) @ vecs.T) * sqrt_nm1

            # Mean update weights and analysis at grid point s
            w_bar = P_tilde @ (Ys_scaled @ innov_loc)
            x_a_mean[s] = mean_x_f[s] + torch.dot(Z_f[:, s], w_bar)
            x_a_pert[:, s] = Z_f[:, s] @ W_s

        X_a = x_a_mean.unsqueeze(0) + x_a_pert
        x_a = X_a.view(num_particles, 1, height, width)

        # Update the particles: shift the last state to the first channel and store the analysis.
        # The channel axis of the normalized analysis is dropped to match the (N, H, W) slot.
        next_particles = torch.zeros_like(particles)
        next_particles[:, 0, :, :] = particles[:, -1, :, :]
        next_particles[:, -1, :, :] = normalized_state(x=x_a)[:, 0, :, :]
        particles = next_particles

        # Update the tensor of results (the assignment copies into the NumPy array)
        posteriors[k] = x_a.cpu().numpy()

    return posteriors
