"""
Copyright 2025 [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Tuple, Union
import numpy as np
import torch
from torch import Tensor, LongTensor, einsum, searchsorted
from torch.nn import Module, Sequential, ModuleList, Linear, ReLU, Tanh, Softplus, LayerNorm, BatchNorm1d, Dropout, Parameter
from torch.nn.utils.rnn import pad_packed_sequence, PackedSequence
from src.utils import vector_to_lower_triangular_matrix, broadcastable_cat, triangular_inverse, maharanobis_norm_squared, regularize_positive_definite, safe_positive_definite_inverse, safe_cholesky, safe_logdet



# Default compute device for this module: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



class MLP(Module):
    """
    Perceptron with two hidden layers, optional normalization and dropout.

    Layer order: Linear -> [norm] -> activation1 -> Dropout -> Linear -> [norm]
    -> activation2 -> Dropout -> Linear.

    Parameters
    ----------
    input_dim: int, size of the last axis of the input
    output_dim: int, size of the last axis of the output
    hidden_dim: int, width of both hidden layers
    activation: str, one of 'relu', 'tanh', 'softplus'; used for both hidden layers
    activation2: str or None, optional override ('relu', 'tanh', 'softplus') for the
        second hidden layer; any other value keeps ``activation``
    normalization: str or None, 'batch' or 'layer'; any other value adds no normalization
    dropout: float, dropout probability applied after each hidden activation
    dtype: torch.dtype, dtype of all learnable parameters (including normalization layers)

    Raises
    ------
    NotImplementedError
        If ``activation`` is not one of the supported names.
    """
    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int=256, activation: str='relu', activation2: str=None, normalization: str=None, dropout: float=0.0, dtype: torch.dtype=torch.float32):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        activation_classes = {'relu': ReLU, 'tanh': Tanh, 'softplus': Softplus}
        if activation not in activation_classes:
            raise NotImplementedError('Invalid activation function.')
        self.activation1 = activation_classes[activation]()
        # An unrecognized activation2 silently falls back to `activation`.
        self.activation2 = activation_classes.get(activation2, activation_classes[activation])()

        layers = [
            Linear(input_dim, hidden_dim, dtype=dtype),
            self.activation1,
            Dropout(dropout),
            Linear(hidden_dim, hidden_dim, dtype=dtype),
            self.activation2,
            Dropout(dropout),
            Linear(hidden_dim, output_dim, dtype=dtype)
        ]

        # Normalization goes right after each hidden Linear. The layers get the
        # same dtype as the Linear layers so non-float32 networks stay consistent.
        if normalization == 'batch':
            layers.insert(1, BatchNorm1d(hidden_dim, dtype=dtype))
            layers.insert(5, BatchNorm1d(hidden_dim, dtype=dtype))
        elif normalization == 'layer':
            layers.insert(1, LayerNorm(hidden_dim, dtype=dtype))
            layers.insert(5, LayerNorm(hidden_dim, dtype=dtype))

        self.mlp = Sequential(*layers)

        # Initialize the weights
        for m in self.modules():
            if isinstance(m, Linear):
                if activation in ('relu', 'softplus'):
                    torch.nn.init.kaiming_normal_(m.weight)
                elif activation in ('tanh',):
                    torch.nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0.0)


    def forward(self, x):
        """Apply the network to ``x`` (last axis of size ``input_dim``)."""
        x = self.mlp(x)
        return x



class LowerTriangularMLP(Module):
    """
    Perceptron with two hidden layers whose output is reshaped into a lower-triangular matrix.

    The last Linear layer emits ``output_dim*(output_dim + 1)//2`` values, which
    ``forward`` passes to ``vector_to_lower_triangular_matrix``.

    Parameters
    ----------
    input_dim: int, size of the last axis of the input
    output_dim: int, side length of the produced matrix
    hidden_dim: int, width of both hidden layers
    activation: str, one of 'relu', 'tanh', 'softplus'
    activation2: str or None, optional override for the second hidden activation;
        any other value keeps ``activation``
    dropout: float, dropout probability applied after each hidden activation
    dtype: torch.dtype, dtype of the Linear layers
    initial_weight_factor: float, multiplier applied to every Linear weight after initialization

    Raises
    ------
    NotImplementedError
        If ``activation`` is not one of the supported names.
    """
    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int=256, activation: str='relu', activation2: str=None, dropout: float=0.0, dtype: torch.dtype=torch.float32, initial_weight_factor: float=1.0):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        activation_classes = {'relu': ReLU, 'tanh': Tanh, 'softplus': Softplus}
        if activation not in activation_classes:
            raise NotImplementedError('Invalid activation function.')
        self.activation1 = activation_classes[activation]()
        # An unrecognized activation2 silently falls back to `activation`.
        self.activation2 = activation_classes.get(activation2, activation_classes[activation])()

        self.mlp = Sequential(
                    Linear(input_dim, hidden_dim, dtype=dtype),
                    self.activation1,
                    Dropout(dropout),
                    Linear(hidden_dim, hidden_dim, dtype=dtype),
                    self.activation2,
                    Dropout(dropout),
                    Linear(hidden_dim, output_dim*(output_dim + 1)//2, dtype=dtype)
                )

        # Initialize the weights
        for m in self.modules():
            if isinstance(m, Linear):
                if activation in ('relu', 'softplus'):
                    torch.nn.init.kaiming_normal_(m.weight)
                elif activation in ('tanh',):
                    torch.nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0.0)
                # Rescale the freshly initialized weights in place.
                with torch.no_grad():
                    m.weight.mul_(initial_weight_factor)


    def forward(self, x):
        """Run the network and convert the output vector to a lower-triangular matrix."""
        x = self.mlp(x)
        x = vector_to_lower_triangular_matrix(x, self.output_dim)
        return x
    


class ConditionalAffineCoupling(Module):
    """
    Affine coupling layer conditioned on an extra input ``c``.

    Entries of ``x`` where ``mask`` is True are left unchanged and, together
    with ``c``, feed two MLPs producing a log-scale ``s`` and a shift ``t``.
    Entries where ``mask`` is False are mapped to ``x*exp(s) + t``.

    Parameters
    ----------
    dim: int, size of the last axis of ``x``
    mask: list of bool, length ``dim``
    condition_dim: int, size of the last axis of ``c``
    hidden_dim: int, hidden width of the two MLPs
    activation: str, activation name forwarded to ``MLP``
    dtype: torch.dtype, dtype forwarded to ``MLP``
    """
    def __init__(self, dim: int, mask: list, condition_dim: int, hidden_dim: int=256, activation: str='relu', dtype: torch.dtype=torch.float32):
        super().__init__()
        self.dim = dim
        self.mask = mask
        self.hidden_dim = hidden_dim
        self.condition_dim = condition_dim
        self.activation = activation
        # Bound used by `s`: the raw log-scale is squashed into (-log_out_max, log_out_max).
        self.log_out_max = 10

        self.unmask = [not flag for flag in mask]
        self.masked_dim = mask.count(True)
        self.unmasked_dim = mask.count(False)

        net_input_dim = self.masked_dim + condition_dim
        self.s_ = MLP(net_input_dim, self.unmasked_dim, hidden_dim, activation, dtype=dtype)
        self.t_ = MLP(net_input_dim, self.unmasked_dim, hidden_dim, activation, dtype=dtype)


    def s(self, x, c):
        """Log-scale: tanh-bounded MLP output, additionally divided by ``dim``."""
        raw = self.s_(broadcastable_cat([x, c], dim=-1))
        bounded = torch.tanh(raw/self.log_out_max)*self.log_out_max
        return bounded/self.dim


    def t(self, x, c):
        """Shift: MLP output for the concatenation of ``x`` and ``c``."""
        net_input = broadcastable_cat([x, c], dim=-1)
        return self.t_(net_input)


    def forward(self, x, c):
        """Return a copy of ``x`` with the unmasked entries mapped to ``x*exp(s) + t``."""
        kept = x[...,self.mask]
        log_scale = self.s(kept, c)
        shift = self.t(kept, c)
        out = x.clone()
        out[...,self.unmask] = x[...,self.unmask]*torch.exp(log_scale) + shift
        return out


    def inverse(self, x, c):
        """Return a copy of ``x`` with the unmasked entries mapped to ``(x - t)*exp(-s)``."""
        kept = x[...,self.mask]
        log_scale = self.s(kept, c)
        shift = self.t(kept, c)
        out = x.clone()
        out[...,self.unmask] = (x[...,self.unmask] - shift)*torch.exp(-log_scale)
        return out


    def log_det_jacobian(self, x, c):
        """Sum of the log-scales over the last axis."""
        log_scale = self.s(x[...,self.mask], c)
        return log_scale.sum(dim=-1)


    def log_det_jacobian_inverse(self, x, c):
        """Negative of ``log_det_jacobian`` evaluated at the same ``x`` and ``c``."""
        forward_value = self.log_det_jacobian(x, c)
        return -forward_value
    


class ConditionalRealNVP(Module):
    """
    Stack of ``ConditionalAffineCoupling`` layers with alternating masks.

    The mask of coupling ``j`` is True at index ``i`` when ``(i//2 + j)`` is odd.

    Parameters
    ----------
    dim: int, size of the last axis of ``x``
    condition_dim: int, size of the last axis of the condition ``c``
    hidden_dim: int, hidden width forwarded to every coupling
    activation: str, activation name forwarded to every coupling
    num_coupling: int, number of coupling layers
    dtype: torch.dtype, dtype forwarded to every coupling
    """
    def __init__(self, dim: int, condition_dim: int, hidden_dim: int=256, activation: str='relu', num_coupling: int=4, dtype: torch.dtype=torch.float32):
        super().__init__()
        self.dim = dim
        self.condition_dim = condition_dim
        self.hidden_dim = hidden_dim
        self.num_coupling = num_coupling

        self.masks = [
            [bool((i//2 + j) % 2) for i in range(dim)]
            for j in range(num_coupling)
        ]
        layers = [
            ConditionalAffineCoupling(dim, layer_mask, condition_dim, hidden_dim, activation, dtype)
            for layer_mask in self.masks
        ]
        self.couplings = ModuleList(layers)


    def forward(self, x, c):
        """Apply the couplings in order."""
        out = x
        for layer in self.couplings:
            out = layer(out, c)
        return out


    def inverse(self, x, c):
        """Apply the inverse of each coupling in reverse order."""
        out = x
        for layer in reversed(self.couplings):
            out = layer.inverse(out, c)
        return out


    def log_det_jacobian(self, x, c):
        """Accumulate each coupling's log-determinant while pushing ``x`` forward."""
        total = 0
        for layer in self.couplings:
            total += layer.log_det_jacobian(x, c)
            x = layer(x, c)
        return total


    def log_det_jacobian_inverse(self, x, c):
        """Accumulate each coupling's inverse log-determinant while pulling ``x`` backward."""
        total = 0
        for layer in reversed(self.couplings):
            total += layer.log_det_jacobian_inverse(x, c)
            x = layer.inverse(x, c)
        return total
    


class PositiveParameter(Module):
    """
    Learnable table of strictly positive values of shape (num_mode, dim).

    The raw parameter ``elements`` is drawn from a standard normal distribution;
    ``forward`` returns its softplus.
    """
    def __init__(self, dim: int, num_mode: int=1, dtype: torch.dtype=torch.float32):
        super().__init__()
        self.dim = dim
        self.num_mode = num_mode
        initial_values = torch.randn(num_mode, dim, dtype=dtype)*1e-0
        self.elements = Parameter(initial_values)


    def forward(self) -> Tensor:
        """Return softplus(elements), shape (num_mode, dim)."""
        positive = torch.nn.functional.softplus(self.elements)
        return positive



class LowerTriangular(Module):
    """
    Learnable lower-triangular matrices, one per mode.

    Each mode stores ``dim*(dim + 1)//2`` raw values: the first ``dim`` define
    the diagonal through ``exp(0.5*value)``, the rest fill the strictly lower
    triangle unchanged.
    """
    def __init__(self, dim: int, num_mode: int=1, dtype: torch.dtype=torch.float32):
        super().__init__()
        self.dim = dim
        self.num_mode = num_mode
        n_elements = dim*(dim + 1)//2
        self.elements = Parameter(torch.randn(num_mode, n_elements, dtype=dtype))


    def forward(self, return_inverse: bool=False) -> Tensor:
        """
        Returns
        -------
        L: Tensor, shape (num_mode, dim, dim)
        With ``return_inverse=True`` the tuple ``(L, triangular_inverse(L))`` is returned instead.
        """
        L = self.vector_to_lower_triangular_matrix(self.elements, self.dim)
        if not return_inverse:
            return L
        return L, triangular_inverse(L)


    def vector_to_lower_triangular_matrix(self, vector: Tensor, dim: int) -> Tensor:
        """
        Parameters
        ----------
        vector: Tensor, shape (..., dim*(dim + 1)//2)
        dim: int

        Returns
        -------
        Tensor, shape (..., dim, dim), zero above the diagonal
        """
        diag = torch.arange(dim)
        # Strictly-lower positions: lower triangle of a (dim-1)x(dim-1) grid, shifted down one row.
        rows, cols = torch.tril_indices(dim - 1, dim - 1)
        rows = rows + 1

        matrix = vector.new_zeros((*vector.shape[:-1], dim, dim))
        matrix[..., diag, diag] = torch.exp(vector[...,:dim]*0.5)
        matrix[..., rows, cols] = vector[...,dim:]

        return matrix



class PriorCovCholesky(LowerTriangular):
    """Lower-triangular factor per mode; ``elements`` is re-drawn from a standard normal."""
    def __init__(self, dim: int, num_mode: int=1, dtype: torch.dtype=torch.float32):
        super().__init__(dim, num_mode, dtype)
        # Replaces the parameter created by the parent constructor with a fresh draw of the same shape.
        n_elements = dim*(dim + 1)//2
        fresh = torch.randn((num_mode, n_elements), dtype=dtype)
        self.elements = Parameter(fresh)



class SysCovCholesky(LowerTriangular):
    """
    Lower-triangular factor per mode with a shifted initial diagonal.

    The first ``dim`` raw values are drawn as ``randn + sys_logstd_initial``;
    the remaining ``dim*(dim - 1)//2`` values are standard normal.
    """
    def __init__(self, dim: int, num_mode: int=1, sys_logstd_initial: float=-4.0, dtype: torch.dtype=torch.float32):
        super().__init__(dim, num_mode, dtype)
        # Draw order matters for reproducibility: diagonal block first, then off-diagonal block.
        diag_raw = torch.randn((num_mode, dim), dtype=dtype) + sys_logstd_initial
        offdiag_raw = torch.randn((num_mode, dim*(dim - 1)//2), dtype=dtype)
        self.elements = Parameter(torch.cat([diag_raw, offdiag_raw], dim=-1))



class Prior(Module):
    """
    Mode-dependent Gaussian prior over the latent state.

    Holds one learnable mean and one learnable lower-triangular covariance
    factor per mode.

    Parameters
    ----------
    latent_dim: int
    num_mode: int
    dtype: torch.dtype
    """
    def __init__(self, latent_dim: int, num_mode: int=1, dtype: torch.dtype=torch.float32):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_mode = num_mode
        self.means = Parameter(torch.randn(num_mode, latent_dim, dtype=dtype))
        self.cov_choleskys = PriorCovCholesky(latent_dim, num_mode, dtype=dtype)


    def mean(self, mode: LongTensor) -> Tensor:
        """
        Parameters
        ----------
        mode: LongTensor, shape (..., 1)

        Returns
        -------
        Tensor, shape (..., latent_dim)
        """
        return self.means[mode][...,0,:]


    def cov_cholesky(self, mode: LongTensor) -> Tensor:
        """
        Parameters
        ----------
        mode: LongTensor, shape (..., 1)

        Returns
        -------
        Tensor, shape (..., latent_dim, latent_dim)
        """
        return self.cov_choleskys()[mode][...,0,:,:]


    def cov(self, mode: LongTensor) -> Tensor:
        """Covariance ``L @ L.mT`` built from the factor of the selected mode."""
        cov_cholesky = self.cov_cholesky(mode)
        cov = cov_cholesky @ cov_cholesky.mT
        return cov


    def forward(self, mode: LongTensor) -> Tensor:
        """
        Draw one sample ``mean + L @ eps`` per entry of ``mode``.

        Parameters
        ----------
        mode: LongTensor, shape (..., 1)

        Returns
        -------
        z: Tensor, shape (..., latent_dim)
        """
        mean = self.mean(mode)
        cov_cholesky = self.cov_cholesky(mode)
        # The noise must match the parameters, not the integer index tensor `mode`.
        noise = torch.randn(mean.shape, dtype=mean.dtype, device=mean.device)
        z = mean + torch.real(einsum('...ij,...j->...i', cov_cholesky, noise))
        return z
    


class Propagator(Module):
    """
    Mode-dependent latent transition built from damped rotations.

    The latent vector is split into a first half ``zr`` and a second half
    ``zi``. Over a time step ``dt`` each pair ``(zr_k, zi_k)`` is rotated by
    the angle ``frequencies_k*dt`` and scaled by ``exp(-decayrates_k*dt)``.

    Parameters
    ----------
    latent_dim: int, latent size; ``half_dim = latent_dim // 2``
    period_scale: float, initial frequencies are ``randn*2*pi/period_scale``
    num_mode: int, number of parameter sets selectable through ``mode``
    sys_logstd_initial: float, forwarded to ``SysCovCholesky``
    dtype: torch.dtype
    """
    def __init__(self, latent_dim: int, period_scale: float=1.0, num_mode: int=1, sys_logstd_initial: float=-4.0, dtype: torch.dtype=torch.float32):
        super().__init__()
        self.latent_dim = latent_dim
        self.period_scale = period_scale
        self.half_dim = latent_dim // 2
        self.num_mode = num_mode
        self.decayrates = PositiveParameter(self.half_dim, num_mode=num_mode, dtype=dtype)
        self.frequencies = Parameter(torch.randn((num_mode, self.half_dim), dtype=dtype)*2*np.pi/period_scale)
        self.sys_cov_choleskys = SysCovCholesky(latent_dim, num_mode=num_mode, sys_logstd_initial=sys_logstd_initial, dtype=dtype)


    def sys_cov(self, mode: LongTensor) -> Tensor:
        """
        Parameters
        ----------
        mode: LongTensor, shape (..., 1)

        Returns
        -------
        sys_cov: Tensor, shape (..., latent_dim, latent_dim), ``L @ L.mT`` of the selected mode
        """
        sys_cov_cholesky = self.sys_cov_choleskys()[mode][...,0,:,:]
        sys_cov = sys_cov_cholesky @ sys_cov_cholesky.mT
        return sys_cov


    def mean(self, z: Tensor, dt: Tensor, mode: LongTensor) -> Tensor:
        """
        Apply the damped rotation to ``z``.

        Parameters
        ----------
        z: Tensor, shape (..., latent_dim)
        dt: Tensor, shape (..., 1)
        mode: LongTensor, shape (..., 1)

        Returns
        -------
        mean: Tensor, shape (..., latent_dim)
        """
        decayrates, frequencies = self.decayrates(), self.frequencies
        decayrates, frequencies = decayrates[mode][...,0,:], frequencies[mode][...,0,:]
        zr, zi = z[...,:self.half_dim], z[...,self.half_dim:]

        meanr = torch.exp(-decayrates*dt)*(torch.cos(frequencies*dt)*zr - torch.sin(frequencies*dt)*zi)
        meani = torch.exp(-decayrates*dt)*(torch.sin(frequencies*dt)*zr + torch.cos(frequencies*dt)*zi)
        mean = torch.cat([meanr, meani], dim=-1)

        return mean


    def evolve_matrix_left(self, A: Tensor, dt: Tensor, mode: LongTensor) -> Tensor:
        """Apply ``mean`` to every column of ``A`` (dt and mode are repeated once per column)."""
        return self.mean(A.mT, dt.repeat_interleave(A.shape[-1], -1).unsqueeze(-1), mode.repeat_interleave(A.shape[-1], -1).unsqueeze(-1)).mT


    def evolution_matrix(self, dt: Tensor, mode: LongTensor) -> Tensor:
        """Transition matrix, obtained by evolving the identity with ``evolve_matrix_left``."""
        return self.evolve_matrix_left(torch.eye(self.latent_dim, dtype=dt.dtype, device=dt.device), dt, mode)


    def evolve_matrix(self, A: Tensor, dt: Tensor, mode: LongTensor) -> Tensor:
        """
        Two-sided evolution ``F A F.mT`` of a matrix, computed block-wise, where
        ``F = [[C, -S], [S, C]]`` with the diagonal factors ``cos`` and ``sin`` below.

        Parameters
        ----------
        A: Tensor, shape (..., latent_dim, latent_dim)
        dt: Tensor, shape (..., 1)
        mode: LongTensor, shape (..., 1)

        Returns
        -------
        A_evo: Tensor, shape (..., latent_dim, latent_dim)
        """
        decayrates, frequencies = self.decayrates(), self.frequencies
        decayrates, frequencies = decayrates[mode][...,0,:], frequencies[mode][...,0,:]
        Arr, Aii = A[...,:self.half_dim,:self.half_dim], A[...,self.half_dim:,self.half_dim:]
        Ari, Air = A[...,:self.half_dim,self.half_dim:], A[...,self.half_dim:,:self.half_dim]
        cos = torch.exp(-decayrates*dt)*torch.cos(frequencies*dt)
        sin = torch.exp(-decayrates*dt)*torch.sin(frequencies*dt)
        def bilinear(u, M, v):
            # diag(u) @ M @ diag(v) without building the diagonal matrices.
            return einsum('...i,...ij,...j->...ij', u, M, v)

        A_evo_rr = bilinear(cos, Arr, cos) + bilinear(cos, -Ari, sin) + bilinear(sin, -Air, cos) + bilinear(sin, Aii, sin)
        A_evo_ii = bilinear(cos, Aii, cos) + bilinear(cos, Air, sin) + bilinear(sin, Ari, cos) + bilinear(sin, Arr, sin)
        A_evo_ri = bilinear(cos, Ari, cos) + bilinear(cos, Arr, sin) + bilinear(sin, -Aii, cos) + bilinear(sin, -Air, sin)
        A_evo_ir = bilinear(cos, Air, cos) + bilinear(cos, -Aii, sin) + bilinear(sin, Arr, cos) + bilinear(sin, -Ari, sin)

        A_evo = torch.cat([torch.cat([A_evo_rr, A_evo_ri], dim=-1), torch.cat([A_evo_ir, A_evo_ii], dim=-1)], dim=-2)

        return A_evo


    def evolve_cov(self, cov: Tensor, dt: Tensor, mode: LongTensor) -> Tensor:
        """Evolved covariance: ``evolve_matrix(cov, dt, mode) + cov(dt, mode)``."""
        return self.evolve_matrix(cov, dt, mode) + self.cov(dt, mode)


    def cov(self, dt: Tensor, mode: LongTensor) -> Tensor:
        """
        Closed-form covariance added by the transition over ``dt``.

        NOTE(review): presumably the system noise ``sys_cov`` accumulated over
        ``dt`` under the damped-rotation dynamics; the derivation (including its
        overall normalization) is not visible here and should be confirmed.

        Parameters
        ----------
        dt: Tensor, shape (..., 1)
        mode: LongTensor, shape (..., 1)

        Returns
        -------
        cov: Tensor, shape (..., latent_dim, latent_dim)
        """
        dt = dt.unsqueeze(-1)
        decayrates, frequencies = self.decayrates(), self.frequencies
        decayrates, frequencies = decayrates[mode][...,0,:], frequencies[mode][...,0,:]
        sys_cov = self.sys_cov(mode)

        sys_cov_rr = sys_cov[...,:self.half_dim,:self.half_dim]
        sys_cov_ri = sys_cov[...,:self.half_dim,self.half_dim:]
        sys_cov_ir = sys_cov[...,self.half_dim:,:self.half_dim]
        sys_cov_ii = sys_cov[...,self.half_dim:,self.half_dim:]
        sys_cov_diag_sum = sys_cov_rr + sys_cov_ii
        sys_cov_diag_diff = sys_cov_rr - sys_cov_ii
        sys_cov_offdiag_sum = sys_cov_ri + sys_cov_ir
        sys_cov_offdiag_diff = sys_cov_ir - sys_cov_ri

        # Pairwise (half_dim x half_dim) combinations of the per-component rates.
        decayrates_sum = decayrates[...,None] + decayrates[...,None].mT
        frequencies_sum = frequencies[...,None] + frequencies[...,None].mT
        frequencies_diff = frequencies[...,None] - frequencies[...,None].mT

        cov_diag_diff = (
                (torch.exp(-decayrates_sum*dt)*torch.cos(frequencies_sum*dt) - 1)*(-decayrates_sum*sys_cov_diag_diff + frequencies_sum*sys_cov_offdiag_sum)
                + torch.exp(-decayrates_sum*dt)*torch.sin(frequencies_sum*dt)*(decayrates_sum*sys_cov_offdiag_sum + frequencies_sum*sys_cov_diag_diff)
            )/(decayrates_sum**2 + frequencies_sum**2)
        cov_offdiag_sum = (
                (torch.exp(-decayrates_sum*dt)*torch.cos(frequencies_sum*dt) - 1)*(-decayrates_sum*sys_cov_offdiag_sum - frequencies_sum*sys_cov_diag_diff)
                + torch.exp(-decayrates_sum*dt)*torch.sin(frequencies_sum*dt)*(-decayrates_sum*sys_cov_diag_diff + frequencies_sum*sys_cov_offdiag_sum)
            )/(decayrates_sum**2 + frequencies_sum**2)
        cov_diag_sum = (
                (torch.exp(-decayrates_sum*dt)*torch.cos(frequencies_diff*dt) - 1)*(-decayrates_sum*sys_cov_diag_sum + frequencies_diff*sys_cov_offdiag_diff)
                + torch.exp(-decayrates_sum*dt)*torch.sin(frequencies_diff*dt)*(decayrates_sum*sys_cov_offdiag_diff + frequencies_diff*sys_cov_diag_sum)
            )/(decayrates_sum**2 + frequencies_diff**2)
        cov_offdiag_diff = (
                (torch.exp(-decayrates_sum*dt)*torch.cos(frequencies_diff*dt) - 1)*(-decayrates_sum*sys_cov_offdiag_diff - frequencies_diff*sys_cov_diag_sum)
                + torch.exp(-decayrates_sum*dt)*torch.sin(frequencies_diff*dt)*(-decayrates_sum*sys_cov_diag_sum + frequencies_diff*sys_cov_offdiag_diff)
            )/(decayrates_sum**2 + frequencies_diff**2)

        cov_rr = cov_diag_sum + cov_diag_diff
        cov_ri = cov_offdiag_sum - cov_offdiag_diff
        cov_ir = cov_offdiag_sum + cov_offdiag_diff
        cov_ii = cov_diag_sum - cov_diag_diff

        cov = torch.cat([torch.cat([cov_rr, cov_ri], dim=-1), torch.cat([cov_ir, cov_ii], dim=-1)], dim=-2)

        return cov


    def cov_cholesky(self, dt: Tensor, mode: LongTensor) -> Tensor:
        """
        Cholesky factor of ``cov(dt, mode)`` with escalating diagonal regularization.

        On failure the factorization is retried with ``regulator_factor*I``
        added, where the factor grows roughly tenfold per attempt.

        Raises
        ------
        torch.linalg.LinAlgError
            If the factorization still fails once the regulator is no longer
            representable in the covariance dtype (e.g. non-finite covariance).
        """
        cov = self.cov(dt, mode)
        identity = torch.eye(self.latent_dim, dtype=cov.dtype, device=cov.device)
        # A regulator beyond the dtype's range cannot help; it bounds the retry loop.
        max_regulator = torch.finfo(cov.dtype).max
        regulator_factor = 0.0
        while True:
            try:
                cov_cholesky = torch.linalg.cholesky(cov + identity*regulator_factor)
                break
            except torch.linalg.LinAlgError:
                print(f'Cholesky decomposition in Propagator failed at regulator factor = {regulator_factor}.')
                regulator_factor = (regulator_factor + 1e-11)*10
                if not regulator_factor <= max_regulator:
                    raise

        return cov_cholesky


    def forward(self, z: Tensor, dt: Tensor, mode: LongTensor) -> Tensor:
        """Sample the next state: ``mean(z, dt, mode) + cov_cholesky(dt, mode) @ eps``."""
        mean, cov_cholesky = self.mean(z, dt, mode), self.cov_cholesky(dt, mode)
        z = mean + einsum('...ij,...j->...i', cov_cholesky, torch.randn(mean.shape, dtype=z.dtype, device=z.device))

        return z



class Observer(Module):
    """
    Mode-dependent linear-Gaussian observation model with an optional normalizing flow.

    Parameters
    ----------
    dim: int, observation size
    latent_dim: int, latent size
    normalizing_flow: dict or None, flow configuration with keys 'name',
        'hidden_dim', 'activation', 'num_coupling'; only the name
        'ConditionalRealNVP' is supported. With None no flow is used and
        observations are handled directly.
    num_mode: int, number of parameter sets selectable through ``mode``
    obs_logvar_initial: float, offset added to the initial raw log-variances
    dtype: torch.dtype

    Raises
    ------
    NotImplementedError
        If ``normalizing_flow['name']`` is not supported.
    """
    def __init__(self, dim: int, latent_dim: int, normalizing_flow: dict, num_mode: int=1, obs_logvar_initial: float=-8.0, dtype: torch.dtype=torch.float32):
        super().__init__()
        self.dim = dim
        self.latent_dim = latent_dim
        self.num_mode = num_mode
        self.W = Parameter(torch.randn(num_mode, dim, latent_dim, dtype=dtype))
        self.logvar_factor = 1.0
        self._logvar = Parameter(torch.randn(num_mode, dim, dtype=dtype) + obs_logvar_initial)
        self.normalizing_flow = normalizing_flow

        if normalizing_flow is None:
            # No flow configured: `projection` and `forward` treat the flow as the identity.
            self.flow = None
        elif normalizing_flow['name'] == 'ConditionalRealNVP':
            # The flow is conditioned on a single scalar (condition_dim = 1).
            self.flow = ConditionalRealNVP(dim, 1, normalizing_flow['hidden_dim'], normalizing_flow['activation'], normalizing_flow['num_coupling'], dtype=dtype)
        else:
            raise NotImplementedError('Invalid flow name.')


    def logvar(self, mode: LongTensor) -> Tensor:
        """
        Log-variance of the selected mode, softly bounded below by -8.0.

        Parameters
        ----------
        mode: LongTensor, shape (..., 1)

        Returns
        -------
        Tensor, shape (..., dim)
        """
        raw = self._logvar[mode][...,0,:]*self.logvar_factor
        min_val = -8.0
        return torch.nn.functional.softplus(raw - min_val) + min_val


    def latent_precision(self) -> Tensor:
        """
        Returns
        -------
        lat_pre: Tensor, shape (num_mode, latent_dim, latent_dim), ``W.mT @ diag(exp(-logvar)) @ W`` per mode
        """
        # Build the mode indices on the parameter's device.
        all_modes = torch.arange(self.num_mode, device=self._logvar.device)[:,None]
        lat_pre = einsum('mij,mj,mjk->mik', self.W.mT, torch.exp(-self.logvar(all_modes)), self.W)
        return lat_pre


    def projection(self, x: Tensor, t: Tensor, mode: LongTensor) -> Tensor:
        """
        Compute ``W.mT @ (exp(-logvar) * flow(x, t))``; the flow is skipped when none is configured.

        Parameters
        ----------
        x: Tensor, shape (..., dim)
        t: Tensor, condition passed to the flow
        mode: LongTensor, shape (..., 1)

        Returns
        -------
        Tensor, shape (..., latent_dim)
        """
        y = x if self.flow is None else self.flow(x, t)
        return einsum('...ij,...j->...i', self.W[mode][...,0,:,:].mT, torch.exp(-self.logvar(mode))*y)


    def forward(self, z: Tensor, t: Tensor, mode: LongTensor, no_flow: bool=False) -> Tensor:
        """
        Sample an observation ``W @ z + std*eps`` and map it through the inverse flow.

        The inverse flow is skipped when ``no_flow`` is True or no flow is configured.

        Parameters
        ----------
        z: Tensor, shape (..., latent_dim)
        t: Tensor, condition passed to the flow
        mode: LongTensor, shape (..., 1)
        no_flow: bool

        Returns
        -------
        x: Tensor, shape (..., dim)
        """
        std = torch.exp(0.5*self.logvar(mode))
        mean = einsum('...ij,...j->...i', self.W[mode][...,0,:,:], z)
        x = mean + std*torch.randn(mean.shape, dtype=z.dtype, device=z.device)
        if not no_flow and self.flow is not None:
            x = self.flow.inverse(x, t)

        return x



class Filter(Module):
    """
    Sequential Gaussian filter over the latent state.

    Combines the prior, the propagator (prediction) and the observer
    (measurement update expressed through the latent precision ``lat_pre``).
    """
    def __init__(self, prior: Prior, propagator: Propagator, observer: Observer):
        super().__init__()
        self.prior = prior
        self.propagator = propagator
        self.observer = observer


    def step(self, mean0: Tensor, cov0: Tensor, x1: Tensor, t0: Tensor, t1: Tensor, mode: LongTensor, lat_pre: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        One predict-and-update step from time ``t0`` to ``t1``.

        Parameters
        ----------
        mean0: Tensor, shape (..., latent_dim)
        cov0: Tensor, shape (..., latent_dim, latent_dim)
        x1: Tensor, shape (..., dim)
        t0: Tensor, shape (..., 1)
        t1: Tensor, shape (..., 1)
        mode: LongTensor, shape (..., 1)
        lat_pre: Tensor, shape (..., latent_dim, latent_dim)
        
        Returns
        -------
        mean1: Tensor, shape (..., latent_dim)
        cov1: Tensor, shape (..., latent_dim, latent_dim)
        gain: Tensor, shape (..., latent_dim, latent_dim)
        """
        dt = t1 - t0
        # Prediction.
        mean_evo = self.propagator.mean(mean0, dt, mode)
        cov_evo = self.propagator.evolve_cov(cov0, dt, mode)
        # cov_sum = (cov_evo^-1 + lat_pre)^-1
        cov_sum = torch.inverse(torch.inverse(cov_evo) + lat_pre)
        gain = cov_evo - cov_evo @ lat_pre @ cov_sum
        mean1 = mean_evo + einsum('...ij,...j->...i', gain, self.observer.projection(x1, t1, mode) - einsum('...ij,...j->...i', lat_pre, mean_evo))
        cov_evo_pre = cov_evo @ lat_pre
        tmp = - cov_evo_pre @ cov_evo + einsum('...ik,...jl,...kl->...ij', cov_evo_pre, cov_evo_pre, cov_sum)
        # Symmetrize the correction before adding it.
        cov1 = cov_evo + (tmp + tmp.mT)/2
        cov1 = regularize_positive_definite(cov1, 'cov1')

        return mean1, cov1, gain
    

    def cumulants(self, xs: Tensor, ts: Tensor, mode: LongTensor, return_gain: bool=False) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        """
        Run the filter over a whole sequence, starting from the prior at time 0.

        Parameters
        ----------
        xs: Tensor, shape (..., max_n_time, dim)
        ts: Tensor, shape (..., max_n_time, 1)
        mode: LongTensor, shape (..., 1)
        return_gain: bool, also return the per-step gains

        Returns
        -------
        means: Tensor, shape (..., max_n_time + 1, latent_dim)
        covs: Tensor, shape (..., max_n_time + 1, latent_dim, latent_dim)
        gains: Tensor, shape (..., max_n_time, latent_dim, latent_dim)
            Only with ``return_gain=True``; None when the sequence has no time step.
        """
        lat_pre = self.observer.latent_precision()[mode][...,0,:,:]  # (..., latent_dim, latent_dim)
        means = self.prior.mean(mode).unsqueeze(-2)   # (..., 1, latent_dim)
        covs = self.prior.cov(mode).unsqueeze(-3)    # (..., 1, latent_dim, latent_dim)
        gains = None

        for i in range(xs.shape[-2]):
            # The first step starts from the prior at time 0.
            t_prev = torch.zeros_like(ts[...,i,:]) if i == 0 else ts[...,i-1,:]
            mean, cov, gain = self.step(means[...,-1,:], covs[...,-1,:,:], xs[...,i,:], t_prev, ts[...,i,:], mode, lat_pre)
            means = broadcastable_cat([means, mean.unsqueeze(-2)], dim=-2)
            covs = broadcastable_cat([covs, cov.unsqueeze(-3)], dim=-3)
            if return_gain:
                # Every gain gets its own time axis, so each step contributes exactly one entry.
                gain = gain.unsqueeze(-3)
                gains = gain if gains is None else broadcastable_cat([gains, gain], dim=-3)

        if return_gain:
            return means, covs, gains
        else:
            return means, covs
        

    def cumulants_continuous(self, t_sample: Tensor, ts: Tensor, mode: LongTensor, means_f: Tensor, covs_f: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Propagate the filtered cumulants from the preceding observation time to arbitrary times.

        Parameters
        ----------
        t_sample: Tensor, shape (..., n_t_sample, 1)
        ts: Tensor, shape (..., max_n_time, 1)
        mode: LongTensor, shape (..., 1)
        means_f: Tensor, shape (..., max_n_time + 1, latent_dim)
        covs_f: Tensor, shape (..., max_n_time + 1, latent_dim, latent_dim)

        Returns
        -------
        means: Tensor, shape (..., n_t_sample, latent_dim)
        covs: Tensor, shape (..., n_t_sample, latent_dim, latent_dim)
        """
        # Number of observation times strictly before each sample time; this is also
        # the index into the filtered cumulants, whose entry 0 is the prior at time 0.
        idx_sample = searchsorted(ts[...,0], t_sample[...,0], right=False)   # (batch_size, n_t_sample)
        means_f_left = means_f.gather(-2, idx_sample[...,None].expand(*idx_sample.shape, means_f.shape[-1]))   # (batch_size, n_t_sample, latent_dim)
        covs_f_left = covs_f.gather(-3, idx_sample[...,None,None].expand(*idx_sample.shape, covs_f.shape[-2], covs_f.shape[-1]))   # (batch_size, n_t_sample, latent_dim, latent_dim)
        ts_ex = torch.cat([torch.zeros(ts.shape[:-2], device=ts.device)[...,None,None], ts], dim=-2)   # (batch_size, max_n_time + 1, 1)
        dt = t_sample - ts_ex.gather(-2, idx_sample[...,None].expand(*idx_sample.shape, 1))   # (batch_size, n_t_sample, 1)

        means = self.propagator.mean(means_f_left, dt, mode.unsqueeze(-1))   # (batch_size, n_t_sample, latent_dim)
        covs = self.propagator.evolve_cov(covs_f_left, dt, mode.unsqueeze(-1))   # (batch_size, n_t_sample, latent_dim, latent_dim)

        return means, covs



class Regressor(Module):
    """
    Backward conditional of the latent state at ``t0`` given the state at a later time ``t1``,
    built from the filtered cumulants.
    """
    def __init__(self, propagator: Propagator, filter: Filter):
        super().__init__()
        self.propagator = propagator
        self.filter = filter


    def _resolve_mode(self, mode: LongTensor, t0: Tensor) -> LongTensor:
        # Return `mode`, or mode index 0 shaped like `t0` for single-mode propagators.
        if mode is not None:
            return mode
        if getattr(self.propagator, 'num_mode', 1) != 1:
            raise ValueError('mode must be given when the propagator has more than one mode.')
        return torch.zeros_like(t0, dtype=torch.long)


    def cumulant(self, z1: Tensor, t0: Tensor, t1: Tensor, ts: Tensor, means_f: Tensor, covs_f: Tensor, return_gain: bool=False, mode: LongTensor=None) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        """
        Parameters
        ----------
        z1: Tensor, shape (..., latent_dim)
        t0: Tensor, shape (..., 1)
        t1: Tensor, shape (..., 1)
        ts: Tensor, shape (..., max_n_time, 1)
        means_f: Tensor, shape (..., max_n_time + 1, latent_dim)
        covs_f: Tensor, shape (..., max_n_time + 1, latent_dim, latent_dim)
        return_gain: bool, default False
        mode: LongTensor, shape (..., 1), optional
            Mode index passed to the propagator. When omitted, mode 0 is used,
            which is only allowed for a single-mode propagator.

        Returns
        -------
        mean0: Tensor, shape (..., latent_dim)
        cov0: Tensor, shape (..., latent_dim, latent_dim)
        gain: Tensor, shape (..., latent_dim, latent_dim), only with ``return_gain=True``

        Raises
        ------
        ValueError
            If ``mode`` is omitted and the propagator has more than one mode.
        """
        mode = self._resolve_mode(mode, t0)
        ts_ex = torch.cat([torch.zeros(ts.shape[:-2], device=ts.device)[...,None,None], ts], dim=-2)
        # Index of the last entry of ts_ex that is <= t0.
        idx_t0 = searchsorted(ts_ex[...,0], t0, right=True) - 1
        mean_f_left = torch.gather(means_f, -2, idx_t0[...,None].expand(*idx_t0.shape, means_f.shape[-1]))[...,0,:]
        cov_f_left = torch.gather(covs_f, -3, idx_t0[...,None,None].expand(*idx_t0.shape, covs_f.shape[-2], covs_f.shape[-1]))[...,0,:,:]
        t_left = torch.gather(ts_ex, -2, idx_t0[...,None].expand(*idx_t0.shape, 1))[...,0,:]
        dt0 = t0 - t_left
        dt1 = t1 - t_left
        # The propagator needs the mode for every evolution call.
        mean_f0, cov_f0 = self.propagator.mean(mean_f_left, dt0, mode), self.propagator.evolve_cov(cov_f_left, dt0, mode)
        mean_f1, cov_f1 = self.propagator.mean(mean_f_left, dt1, mode), self.propagator.evolve_cov(cov_f_left, dt1, mode)

        cov_f1_inverse = safe_positive_definite_inverse(cov_f1, 'cov_f1')

        cross = self.propagator.evolve_matrix_left(cov_f0, t1 - t0, mode)
        gain = cross.mT @ cov_f1_inverse
        mean0 = mean_f0 + einsum('...ij,...j->...i', gain, z1 - mean_f1)
        cov0 = cov_f0 - gain @ cross
        cov0 = (cov0 + cov0.mT)/2

        if return_gain:
            return mean0, cov0, gain
        else:
            return mean0, cov0
        

    def cov_cholesky(self, t0: Tensor, t1: Tensor, ts: Tensor, mode: LongTensor, covs_f: Tensor, return_gain: bool=False) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Parameters
        ----------
        t0: Tensor, shape (..., n_t_sample, 1)
        t1: Tensor, shape (..., n_t_sample, 1)
        ts: Tensor, shape (..., max_n_time, 1)
        mode: LongTensor, shape (..., 1)
        covs_f: Tensor, shape (..., max_n_time + 1, latent_dim, latent_dim)
        return_gain: bool, default False

        Returns
        -------
        cov_cholesky0: Tensor, shape (..., n_t_sample, latent_dim, latent_dim)
        gain: Tensor, shape (..., n_t_sample, latent_dim, latent_dim)
        """
        idx_t0 = searchsorted(ts[...,0], t0[...,0], right=True) # (..., n_t_sample)
        cov_f_left = covs_f.gather(-3, idx_t0[...,None,None].expand(*idx_t0.shape, covs_f.shape[-2], covs_f.shape[-1])) # (..., n_t_sample, latent_dim, latent_dim)
        ts_ex = torch.cat([torch.zeros(ts.shape[:-2], device=ts.device)[...,None,None], ts], dim=-2)    # (..., max_n_time + 1, 1)
        dt0 = t0 - ts_ex.gather(-2, idx_t0[...,None].expand(*idx_t0.shape, 1))  # (..., n_t_sample, 1)
        dt1 = t1 - ts_ex.gather(-2, idx_t0[...,None].expand(*idx_t0.shape, 1))  # (..., n_t_sample, 1)
        cov_f0 = self.propagator.evolve_cov(cov_f_left, dt0, mode.unsqueeze(-2))    # (..., n_t_sample, latent_dim, latent_dim)
        cov_f1 = self.propagator.evolve_cov(cov_f_left, dt1, mode.unsqueeze(-2))    # (..., n_t_sample, latent_dim, latent_dim)

        cov_f1_inverse = safe_positive_definite_inverse(cov_f1, 'cov_f1')

        gain = self.propagator.evolve_matrix_left(cov_f0, t1 - t0, mode.unsqueeze(-2)).mT @ cov_f1_inverse
        cov0 = cov_f0 - gain @ self.propagator.evolve_matrix_left(cov_f0, t1 - t0, mode.unsqueeze(-2))
        cov0 = (cov0 + cov0.mT)/2
        cov_cholesky0 = safe_cholesky(cov0, 'cov0')

        if return_gain:
            return cov_cholesky0, gain
        else:
            return cov_cholesky0
        

    def forward(self, z1: Tensor, t0: Tensor, t1: Tensor, ts: Tensor, means_f: Tensor, covs_f: Tensor, mode: LongTensor=None):
        """
        Sample ``z0`` at time ``t0`` given ``z1`` at time ``t1``.

        Arguments are those of ``cumulant``; ``mode`` is optional for single-mode propagators.
        """
        mean0, cov0 = self.cumulant(z1, t0, t1, ts, means_f, covs_f, mode=mode)
        cov_cholesky0 = safe_cholesky(cov0, 'cov0')
        z0 = mean0 + einsum('...ij,...j->...i', cov_cholesky0, torch.randn_like(mean0))

        return z0
    


class Smoother(Module):
    """
    Backward smoother built on a propagator, a filter and a regressor.

    The smoothed cumulants are obtained by running the regressor backward in
    time, starting from the filtered cumulants at the last observed time point
    of each sequence.
    """
    def __init__(self, propagator: Propagator, filter: Filter, regressor: Regressor):
        super().__init__()
        self.propagator = propagator
        self.filter = filter
        self.regressor = regressor


    def step(self, mean1: Tensor, cov1: Tensor, t0: Tensor, t1: Tensor, ts: Tensor, means_f: Tensor, covs_f: Tensor):
        """
        Perform one backward step from time ``t1`` to time ``t0``.

        ``self.regressor.cumulant`` provides the mean, covariance and gain at
        ``t0`` given the mean at ``t1``; the covariance at ``t1`` is then
        mapped through the gain and added.

        Returns
        -------
        mean0: Tensor
        cov0: Tensor, symmetrized
        """
        mean0, cov0, gain = self.regressor.cumulant(mean1, t0, t1, ts, means_f, covs_f, return_gain=True)
        # Add the uncertainty of the later state, carried back through the gain.
        cov0 = cov0 + gain @ cov1 @ gain.mT
        # Enforce exact symmetry against round-off.
        cov0 = (cov0 + cov0.mT)/2

        return mean0, cov0
    

    def cumulants(self, ts: Tensor, n_time: LongTensor, means_f: Tensor, covs_f: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Parameters
        ----------
        ts: Tensor, shape (..., max_n_time, 1)
        n_time: Tensor, shape (...,)
        means_f: Tensor, shape (..., max_n_time + 1, latent_dim)
        covs_f: Tensor, shape (..., max_n_time + 1, latent_dim, latent_dim)

        Returns
        -------
        means: Tensor, shape (..., max_n_time, latent_dim)
        covs: Tensor, shape (..., max_n_time, latent_dim, latent_dim)
            Entries at positions >= n_time are left at their initial values
            (zero mean, identity covariance).
        """
        means = torch.zeros_like(means_f[...,:-1,:])
        covs = torch.eye(means.shape[-1], dtype=covs_f.dtype, device=covs_f.device).tile(means.shape[:-1] + (1, 1))

        # Index grids over the batch dimensions, created on the same device as
        # n_time so that all index tensors live together.
        batch_idx = torch.meshgrid(*[torch.arange(s, device=n_time.device) for s in means.shape[:-2]], indexing='ij')
        # Tuples (not lists) are required for multidimensional indexing;
        # indexing with a list of tensors is deprecated in PyTorch.
        idx_last_s = tuple(batch_idx) + (n_time - 1,)
        idx_last_f = tuple(batch_idx) + (n_time,)
        # The smoothed cumulants at the last observed point equal the filtered ones.
        means[idx_last_s], covs[idx_last_s] = means_f[idx_last_f], covs_f[idx_last_f]

        for i in range(means.shape[-2] - 1, 0, -1):
            # Only sequences with more than i observations have a valid point at index i.
            mask = (n_time > i)
            means[mask,i-1], covs[mask,i-1] = self.step(means[mask,i,:], covs[mask,i,:,:], ts[mask,i-1,:], ts[mask,i,:], ts[mask], means_f[mask], covs_f[mask])

        return means, covs
    

    def cumulants_continuous(self, t_sample: Tensor, ts: Tensor, n_time: LongTensor, means_f: Tensor, covs_f: Tensor, means_s: Tensor, covs_s: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Parameters
        ----------
        t_sample: Tensor, shape (..., n_t_sample)
        ts: Tensor, shape (..., max_n_time, 1)
        n_time: Tensor, shape (..., )
        means_f: Tensor, shape (..., max_n_time + 1, latent_dim)
        covs_f: Tensor, shape (..., max_n_time + 1, latent_dim, latent_dim)
        means_s: Tensor, shape (..., max_n_time, latent_dim)
        covs_s: Tensor, shape (..., max_n_time, latent_dim, latent_dim)

        Returns
        -------
        means: Tensor, shape (..., n_t_sample, latent_dim)
        covs: Tensor, shape (..., n_t_sample, latent_dim, latent_dim)
        """
        # Index of the first observation time that is >= each sample time.
        idx_sample = searchsorted(ts[...,0], t_sample, right=False)   # (batch_size, n_t_sample)
        # Samples beyond the last observation are clipped to it and fixed up below.
        idx_sample_clipped = torch.clip(idx_sample, max=n_time[...,None] - 1)
        means_s_right = torch.gather(means_s, -2, idx_sample_clipped[...,None].expand(*idx_sample.shape, means_s.shape[-1]))   # (batch_size, n_t_sample, latent_dim)
        covs_s_right = torch.gather(covs_s, -3, idx_sample_clipped[...,None,None].expand(*idx_sample.shape, covs_s.shape[-2], covs_s.shape[-1]))  # (batch_size, n_t_sample, latent_dim, latent_dim)
        t_right = torch.gather(ts, -2, idx_sample_clipped[...,None].expand(*idx_sample.shape, ts.shape[-1]))   # (batch_size, n_t_sample, 1)
        
        # Repeat the per-sequence quantities once per sample time.
        n_lead = t_sample.dim() - 1
        n_t_sample = t_sample.shape[-1]
        means, covs, gains = self.regressor.cumulant(means_s_right, t_sample[...,None], t_right,
                                                     ts.unsqueeze(-3).repeat([1]*n_lead + [n_t_sample, 1, 1]),
                                                     means_f.unsqueeze(-3).repeat([1]*n_lead + [n_t_sample, 1, 1]),
                                                     covs_f.unsqueeze(-4).repeat([1]*n_lead + [n_t_sample, 1, 1, 1]), return_gain=True)  # (batch_size, n_t_sample, latent_dim), (batch_size, n_t_sample, latent_dim, latent_dim), (batch_size, n_t_sample, latent_dim, latent_dim)
        # gains @ covs_s_right @ gains^T
        covs = covs + einsum('...ij,...jk,...lk->...il', gains, covs_s_right, gains)

        # For t_sample > ts[n_time-1], means and covs are calculated by propagator.
        is_future = idx_sample == n_time[...,None]   # (batch_size, n_t_sample)
        means_future_left = means_s_right[is_future]
        covs_future_left = covs_s_right[is_future]
        t_last = ts.gather(-2, n_time[...,None,None] - 1).repeat([1]*n_lead + [n_t_sample, 1])   # (batch_size, n_t_sample, 1)
        dt_future = t_sample[is_future][...,None] - t_last[is_future]
        means[is_future] = self.propagator.mean(means_future_left, dt_future)
        covs[is_future] = self.propagator.evolve_cov(covs_future_left, dt_future)

        return means, covs



class OUFlow(Module):
    """
    Class for the OUFlow model.

    Parameters
    ----------
    dim: int
        The dimension of the target variable.
    latent_dim: int
        The dimension of the latent variable.
    num_mode: int
        The number of modes.
    normalizing_flow: dict
        The configuration of the normalizing flow.
        It should be a dictionary with the following keys:
        - name: str, the name of the normalizing flow. Currently, only 'ConditionalRealNVP' is supported.
        - hidden_dim: int, the dimension of the hidden layers.
        - activation: str, the activation function. 'relu', 'tanh', or 'softplus'.
        - num_coupling: int, the number of coupling layers.
    
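    period_scale: float, default=1.0
        The scale of the initial period of the latent variable.
    sys_logstd_initial: float, default=-4.0
        Initial value passed to the propagator as ``sys_logstd_initial``
        (presumably the initial log standard deviation of the system noise).
    obs_logvar_initial: float, default=-8.0
        Initial value passed to the observer as ``obs_logvar_initial``
        (presumably the initial log variance of the observation noise).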
    x_scale: float or Tensor with shape (dim,), default=1.0
        The scale of the target variable for normalization.
    x_base: float or Tensor with shape (dim,), default=0.0
        The offset of the target variable for normalization.
    t_scale: float or Tensor, default=1.0
        The scale of the time for normalization.
    double_precision: bool, default=True
        Whether to use double precision.
    """
    def __init__(self, dim: int, latent_dim: int, num_mode: int,
                 normalizing_flow, period_scale: float=1.0,
                 sys_logstd_initial: float=-4.0, obs_logvar_initial: float=-8.0,
                 x_scale=1.0, x_base=0.0, t_scale=1.0, double_precision: bool=True
                 ):
        """
        Build the sub-modules (prior, propagator, observer, filter, regressor,
        smoother) and register the normalization constants as buffers.
        """
        super().__init__()
        dtype = torch.float64 if double_precision else torch.float32
        self.dim = dim
        self.latent_dim = latent_dim
        self.half_dim = latent_dim // 2
        self.num_mode = num_mode
        self.dtype = dtype

        def to_vector(value, length: int) -> Tensor:
            # Convert to a tensor of the model dtype; a scalar is promoted to a vector of `length` entries.
            value = torch.as_tensor(value, dtype=dtype)
            if value.ndim == 0:
                value = value * torch.ones(length, dtype=dtype)
            return value

        x_scale = to_vector(x_scale, dim)
        x_base = to_vector(x_base, dim)
        t_scale = to_vector(t_scale, 1)
        for buffer_name, buffer_value in (('x_scale', x_scale), ('x_base', x_base), ('t_scale', t_scale)):
            self.register_buffer(buffer_name, buffer_value)

        # The propagator works in normalized time, hence the period is rescaled by t_scale.
        self.prior = Prior(latent_dim, num_mode=num_mode, dtype=dtype)
        self.propagator = Propagator(latent_dim, period_scale/t_scale, num_mode=num_mode, sys_logstd_initial=sys_logstd_initial, dtype=dtype)
        self.observer = Observer(dim, latent_dim, normalizing_flow, num_mode=num_mode, obs_logvar_initial=obs_logvar_initial, dtype=dtype)

        self.filter = Filter(self.prior, self.propagator, self.observer)
        self.regressor = Regressor(self.propagator, self.filter)
        self.smoother = Smoother(self.propagator, self.filter, self.regressor)

        # Unnormalized logits of the mixture weights (all zero, i.e. uniform after softmax).
        self._mixture_weights = Parameter(torch.zeros(num_mode, dtype=dtype))


    def mixture_weights(self) -> Tensor:
        """
        Returns the mixture weights.

        Returns
        -------
        Tensor, shape (num_mode,)
        """
        return torch.softmax(self._mixture_weights, dim=0)
    

    def normalize(self, x=None, t=None) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Normalize the target variable and time.

        Parameters
        ----------
        x: Tensor, shape (..., dim)
            The target variable.
        t: Tensor, shape (..., 1)
            The time.

        Returns
        -------
        x: Tensor, shape (..., dim)
            The normalized target variable.
        t: Tensor, shape (..., 1)
            The normalized time.
        """
        if x is not None:
            x = (x - self.x_base) / self.x_scale
        if t is not None:
            t = t / self.t_scale

        if t is None:
            return x
        elif x is None:
            return t
        else:
            return x, t
    

    def denormalize(self, x=None, t=None) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Denormalize the target variable and time.
        
        Parameters
        ----------
        x: Tensor, shape (..., dim)
            The normalized target variable.
        t: Tensor, shape (..., 1)
            The normalized time.

        Returns
        -------
        x: Tensor, shape (..., dim)
            The denormalized target variable.
        t: Tensor, shape (..., 1)
            The denormalized time.
        """
        if x is not None:
            x = self.x_base + self.x_scale * x
        if t is not None:
            t = t / self.t_scale

        if t is None:
            return x
        elif x is None:
            return t
        else:
            return x, t
        

    def nll(self, xs: PackedSequence, ts: PackedSequence, t0: Tensor=None) -> Tuple[Tensor, Tensor]:
        """
        Compute the negative log-likelihood.

        Parameters
        ----------
        xs: PackedSequence, shape (batch_size, max_n_time, dim)
            The time series of the target variable.
        ts: PackedSequence, shape (batch_size, max_n_time, 1)
            The time points of the time series.
        t0: Tensor, shape (batch_size, 1), default=None
            The initial time of the time series. When given, it is subtracted
            from the time points before normalization.
        
        Returns
        -------
        nll: Tensor, shape (batch_size,)
            Negative log-likelihood of the mixture, divided by the number of
            time points of each sequence.
        nll_mode: Tensor, shape (batch_size, num_mode)
            Negative log-likelihood for each mode, summed over the valid time
            points (not divided by the sequence length).
        """
        xs_padded = pad_packed_sequence(xs, batch_first=True, padding_value=0)[0]
        # Padded time points are set just above the largest real time.
        ts_padded, n_time = pad_packed_sequence(ts, batch_first=True, padding_value=ts.data.max() + 1e-4)
        # Derive the validity mask from the sequence lengths rather than by comparing
        # against the padding value: in single precision `max + 1e-4` can round back
        # to `max`, which would wrongly mask out real time points.
        n_time = n_time.to(ts_padded.device)  # (batch_size, )
        mask = torch.arange(ts_padded.shape[-2], device=ts_padded.device)[None,:] < n_time[:,None]   # (batch_size, max_n_time)
        if t0 is not None:
            ts_padded = ts_padded - t0[...,None]
        xs_padded, ts_padded = self.normalize(x=xs_padded, t=ts_padded)
        # Create the mode indices on the device of the data, not on a module-level default device.
        mode = torch.arange(self.num_mode, dtype=torch.long, device=xs_padded.device).tile(xs_padded.shape[:-2] + (1,))[...,None]   # (batch_size, num_mode, 1)

        lat_pre = self.observer.latent_precision()  # (num_mode, latent_dim, latent_dim)
        xs_proj = self.observer.projection(xs_padded.unsqueeze(-3), ts_padded.unsqueeze(-3), mode[...,None])  # (batch_size, num_mode, n_time, latent_dim)
        means_f, covs_f = self.filter.cumulants(xs_padded.unsqueeze(-3), ts_padded.unsqueeze(-3), mode) # (batch_size, num_mode, n_time + 1, latent_dim), (batch_size, num_mode, n_time + 1, latent_dim, latent_dim)
        # Drop the last filtered state; the remaining ones are evolved forward by dts below.
        means_f, covs_f = means_f[...,:-1,:], covs_f[...,:-1,:,:]   # (batch_size, num_mode, n_time, latent_dim), (batch_size, num_mode, n_time, latent_dim, latent_dim)
        # Time increments, with the first increment measured from time zero.
        dts = (ts_padded - torch.cat([torch.zeros(ts_padded.shape[:-2] + (1, 1), device=xs_padded.device), ts_padded[...,:-1,:]], dim=-2)).unsqueeze(-3)    # (batch_size, 1, n_time, 1)
        means_evo, covs_evo = self.propagator.mean(means_f, dts, mode.unsqueeze(-2)), self.propagator.evolve_cov(covs_f, dts, mode.unsqueeze(-2))    # (batch_size, num_mode, n_time, latent_dim), (batch_size, num_mode, n_time, latent_dim, latent_dim)

        logdet_covs_evo = safe_logdet(covs_evo, 'covs_evo')  # (batch_size, num_mode, n_time)
        covs_evo_inv = safe_positive_definite_inverse(covs_evo, 'covs_evo')  # (batch_size, num_mode, n_time, latent_dim, latent_dim)
        logdet_covs_evo_lat_pre = safe_logdet(covs_evo_inv + lat_pre.unsqueeze(-3), 'covs_evo_lat_pre')  # (batch_size, num_mode, n_time)

        nll_mode = 0.5*(
                    self.dim*np.log(2*np.pi)
                    + self.observer.logvar(mode).sum(dim=-1, keepdim=True) + logdet_covs_evo + logdet_covs_evo_lat_pre
                    + einsum('...ti,...mi->...mt', self.observer.flow(xs_padded, ts_padded)**2, torch.exp(-self.observer.logvar(mode)))
                    + einsum('...mti,mij,...mtj->...mt', means_evo, lat_pre, means_evo)
                    - 2*einsum('...ti,...ti->...t', xs_proj, means_evo)
                    - maharanobis_norm_squared(covs_evo_inv + lat_pre.unsqueeze(-3), xs_proj - einsum('mij,...mtj->...mti', lat_pre, means_evo))
                ) - self.observer.flow.log_det_jacobian(xs_padded, ts_padded).unsqueeze(-2)   # (batch_size, num_mode, n_time)
        # Padded time points contribute nothing to the sum over time.
        nll_mode = torch.where(mask.unsqueeze(-2), nll_mode, torch.zeros_like(nll_mode)).sum(dim=-1) # (batch_size, num_mode)
        
        mixture_weights = self.mixture_weights()    # (num_mode, )
        # Log-sum-exp over the modes, shifted by the smallest NLL for numerical stability.
        nll_mode_min = nll_mode.min(dim=-1)[0]   # (batch_size, )
        nll = nll_mode_min - torch.log((torch.exp(-nll_mode + nll_mode_min[:,None])*mixture_weights).sum(dim=-1))    # (batch_size, )

        nll = nll / n_time  # (batch_size, )
        return nll, nll_mode
    

    def losses(self, xs: PackedSequence, ts: PackedSequence, t0: Tensor=None) -> Tensor:
        """
        Compute the losses.

        Parameters
        ----------
        xs: PackedSequence, shape (batch_size, max_n_time, dim)
            The time series of the target variable.
        ts: PackedSequence, shape (batch_size, max_n_time, 1)
            The time points of the time series.
        t0: Tensor, shape (batch_size, 1), default None
            The initial time of the time series.
        
        Returns
        -------
        nll: Tensor, shape (batch_size,)
            The negative log-likelihood.
        nll_mode_mean: Tensor, shape (batch_size,)
            The mean mode-specific negative log-likelihoods.
        responsibility_imbalance: Tensor, shape (1,)
            The responsibility imbalance.
        """
        nll, nll_mode = self.nll(xs, ts, t0)    # (batch_size,), (batch_size, num_mode)
        posterior_weights = torch.softmax(- nll_mode + self._mixture_weights, dim=-1)    # (batch_size, num_mode)
        nll_mode_mean = nll_mode.mean(dim=-1)    # (batch_size,)

        weights_mean = posterior_weights.mean(dim=tuple(range(posterior_weights.dim()-1))) # (batch_size,)
        responsibility_imbalance = (weights_mean * torch.log(weights_mean + 1e-10)).sum()

        return nll, nll_mode_mean, responsibility_imbalance
    

    def posterior_mixture_weights(self, xs_cond: PackedSequence, ts_cond: PackedSequence, t0: Tensor=None) -> Tensor:
        """
        Compute the posterior mixture weights.

        Parameters
        ----------
        xs_cond: PackedSequence, shape (batch_size, max_n_time, dim)
            The time series of the target variable.
        ts_cond: PackedSequence, shape (batch_size, max_n_time, 1)
            The time points of the time series.
        t0: Tensor, shape (batch_size, 1), default None
            The initial time of the time series.

        Returns
        -------
        mixture_weights: Tensor, shape (batch_size, num_mode)
            The posterior mixture weights.
        """
        _, nll_mode = self.nll(xs_cond, ts_cond, t0=t0)   # (batch_size, num_mode)
        mixture_weights = torch.softmax(- nll_mode + self._mixture_weights, dim=-1)    # (batch_size, num_mode)

        return mixture_weights
    

    def sample_mode(self, n_sample: int=1, batch_size: int=1, xs_cond: PackedSequence=None, ts_cond: PackedSequence=None, t0: Tensor=None) -> Tensor:
        """
        Sample the mode.

        Parameters
        ----------
        n_sample: int, default 1
            The number of samples.
        batch_size: int, default 1
            The batch size.
        xs_cond: PackedSequence, shape (batch_size, max_n_time, dim), default None
            The time series of the observed target variable. If None, the modes are sampled from the prior.
        ts_cond: PackedSequence, shape (batch_size, max_n_time, 1), default None
            The time points of the observed time series. If None, the modes are sampled from the prior.
        t0: Tensor, shape (batch_size, 1), default None
            The initial time of the time series.

        Returns
        -------
        mode: Tensor, shape (batch_size, n_sample, 1)
            The sampled modes.
        """
        if xs_cond is None and ts_cond is None: # Prior sampling
            mixture_weights = self.mixture_weights()    # (num_mode, )
            mixture_weights = mixture_weights[None].tile([batch_size, 1])   # (batch_size, num_mode)
        else:   # Posterior sampling
            mixture_weights = self.posterior_mixture_weights(xs_cond, ts_cond, t0)  # (batch_size, num_mode)
        mode = torch.multinomial(mixture_weights, n_sample, replacement=True)[...,None]   # (batch_size, n_sample, 1)

        return mode


    def forward(self, ts_pred: Tensor, n_sample: int=1, xs_cond: PackedSequence=None, ts_cond: PackedSequence=None, mode: int=None) -> Tensor:
        """
        Sample the time series of the target variable.

        When no conditioning series is given, latent paths are drawn from the
        prior and propagated forward. Otherwise the latent path is generated
        forward from the last conditioning time point and backward before it,
        using the filtered cumulants of the conditioning series.

        Parameters
        ----------
        ts_pred: Tensor, shape (batch_size, n_time_pred, 1)
            The time points of the predicted time series.
        n_sample: int, default 1
            The number of samples.
        xs_cond: PackedSequence, shape (batch_size, max_n_time_cond, dim), default None
            The time series of the observed target variable. If None, the time series is sampled from the prior.
        ts_cond: PackedSequence, shape (batch_size, max_n_time_cond, 1), default None
            The time points of the observed time series. If None, the time series is sampled from the prior.
        mode: int, default None
            The mode to sample from for investigation. xs_cond and ts_cond must not be provided if mode is specified;
            in the conditional branch the argument is ignored and the modes are sampled from the posterior.

        Returns
        -------
        xs_pred: Tensor, shape (batch_size, n_sample, max_n_time_pred, dim)
            The time series of the target variable.
        """
        if xs_cond is None and ts_cond is None: # Prior sampling
            # Times are measured from the earliest prediction time of each batch element.
            t0 = ts_pred.min(dim=-2).values  # (batch_size, 1)
            ts_pred = self.normalize(t=ts_pred - t0[...,None]).unsqueeze(-3)   # (batch_size, 1, max_n_time_pred, 1)
            if mode is None:
                mode = self.sample_mode(n_sample=n_sample, batch_size=ts_pred.shape[0]).unsqueeze(-2)   # (batch_size, n_sample, 1, 1)
            else:
                mode = torch.full(ts_pred.shape[:-3] + (n_sample, 1, 1), mode, dtype=torch.long, device=ts_pred.device)   # (batch_size, n_sample, 1, 1)

            noise_trans_mat = self.prior.cov_cholesky(mode)   # (batch_size, n_sample, 1, latent_dim, latent_dim)
            ts_pred_ex = torch.cat([torch.zeros(ts_pred.shape[:-2] + (1, 1), device=ts_pred.device), ts_pred], dim=-2)   # (batch_size, 1, max_n_time_pred + 1, 1)
            dts = ts_pred_ex[...,1:,:] - ts_pred_ex[...,:-1,:]   # (batch_size, 1, max_n_time_pred, 1)
            # Each step's noise factor is evolved by -ts_pred so that the increments can be
            # accumulated with a cumulative sum and evolved by +ts_pred once at the end.
            noise_trans_mat = torch.cat([
                noise_trans_mat,
                self.propagator.evolve_matrix_left(self.propagator.cov_cholesky(dts, mode), -ts_pred, mode)
                ], dim=-3)  # (batch_size, n_sample, max_n_time_pred + 1, latent_dim, latent_dim)
            noise = torch.randn(noise_trans_mat.shape[:-1], dtype=noise_trans_mat.dtype, device=noise_trans_mat.device)   # (batch_size, n_sample, max_n_time_pred + 1, latent_dim)
            noise = einsum('...ij,...j->...i', noise_trans_mat, noise)  # (batch_size, n_sample, max_n_time_pred + 1, latent_dim)
            noise[:,:,0] = noise[:,:,0] + self.prior.mean(mode[...,0,:])   # (batch_size, n_sample, latent_dim)
            zs_pred = noise.cumsum_(dim=-2)[...,1:,:]   # (batch_size, n_sample, max_n_time_pred, latent_dim)
            zs_pred = self.propagator.mean(zs_pred, ts_pred, mode)   # (batch_size, n_sample, max_n_time_pred, latent_dim)
        else:   # Posterior sampling
            # Padding value for the conditioning times: strictly above every real time.
            t_max = max(ts_pred.max(), ts_cond.data.max())
            xs_cond_padded = pad_packed_sequence(xs_cond, batch_first=True, padding_value=0)[0]  # (batch_size, max_n_time_cond, dim)
            ts_cond_padded, n_time_cond = pad_packed_sequence(ts_cond, batch_first=True, padding_value=t_max + 1e-4)  # (batch_size, max_n_time_cond, 1), (batch_size, )
            t0 = broadcastable_cat([ts_cond_padded, ts_pred], dim=-2).min(dim=-2).values  # (batch_size, 1)
            ts_cond_padded = ts_cond_padded - t0[...,None]  # (batch_size, max_n_time_cond, 1)
            n_time_cond = n_time_cond.to(xs_cond_padded.device)  # (batch_size, )
            xs_cond_padded, ts_cond_padded = self.normalize(x=xs_cond_padded, t=ts_cond_padded)  # (batch_size, max_n_time_cond, dim), (batch_size, max_n_time_cond, 1)
            ts_cond_padded = ts_cond_padded.contiguous()
            # The 1e-6 offset keeps prediction times from coinciding exactly with conditioning times.
            ts_pred = self.normalize(t=ts_pred - t0[...,None]) + 1e-6  # (batch_size, max_n_time_pred, 1)
            batch_size = max(xs_cond_padded.shape[0], ts_pred.shape[0])
            mode = self.sample_mode(n_sample=n_sample, batch_size=batch_size, xs_cond=xs_cond, ts_cond=ts_cond, t0=t0).unsqueeze(-2)    # (batch_size, n_sample, 1, 1)

            # Concatenate ts_cond_padded and ts_pred and sort
            ts_cat = broadcastable_cat([ts_cond_padded, ts_pred], dim=-2)   # (batch_size, max_n_time_cond + max_n_time_pred, 1)
            ts_cat, indices = torch.sort(ts_cat, dim=-2, stable=True)   # (batch_size, max_n_time_cond + max_n_time_pred, 1), (batch_size, max_n_time_cond + max_n_time_pred, 1)
            indices_inv = torch.argsort(indices, dim=-2)    # (batch_size, max_n_time_cond + max_n_time_pred, 1)
            # Position (in the sorted sequence) and time of the last real conditioning point.
            idx_boundary = indices_inv.gather(-2, n_time_cond[...,None,None] - 1)   # (batch_size, 1, 1)
            ts_boundary = ts_cond_padded.gather(-2, n_time_cond[...,None,None] - 1)   # (batch_size, 1, 1)

            # Separate ts_cat by idx_boundary and pad with t_max + 1e-4
            max_len_backward = idx_boundary.max().item()
            n_time_forward = n_time_cond[:,None,None] + ts_pred.shape[1] - idx_boundary  # (batch_size, 1, 1)
            max_len_forward = n_time_forward.max().item()
            ts_backward = torch.zeros(ts_cat.shape[:-2] + (max_len_backward, 1), dtype=ts_cat.dtype, device=ts_cat.device)   # (batch_size, max_len_backward, 1)
            ts_forward = torch.full(ts_cat.shape[:-2] + (max_len_forward, 1), t_max + 1e-4, dtype=ts_cat.dtype, device=ts_cat.device)   # (batch_size, max_len_forward, 1)
            # Backward times are right-aligned, forward times are left-aligned.
            mask_backward = torch.arange(max_len_backward, device=ts_cat.device)[None,:,None] >= max_len_backward - idx_boundary   # (batch_size, max_len_backward, 1)
            mask_forward = torch.arange(max_len_forward, device=ts_cat.device)[None,:,None] < n_time_forward  # (batch_size, max_len_forward, 1)
            mask_backward_cat = torch.arange(ts_cat.shape[1], device=ts_cat.device)[None,:,None] < idx_boundary   # (batch_size, max_len_backward + max_len_forward, 1)
            mask_forward_cat = (torch.arange(ts_cat.shape[1], device=ts_cat.device)[None,:,None] >= idx_boundary) & (ts_cat < t_max + 1e-4)   # (batch_size, max_len_backward + max_len_forward, 1)
            ts_backward[mask_backward] =  ts_cat[mask_backward_cat]   # (batch_size, max_len_backward, 1)
            ts_forward[mask_forward] = ts_cat[mask_forward_cat]   # (batch_size, max_len_forward, 1)

            # Filter once for every mode, then pick the cumulants of the sampled modes.
            # The mode indices are created on the device of the data, not on a module-level default device.
            means_f_all_modes, covs_f_all_modes = self.filter.cumulants(xs_cond_padded.unsqueeze(-3), ts_cond_padded.unsqueeze(-3), torch.arange(self.num_mode, dtype=torch.long, device=xs_cond_padded.device)[None,:,None])   # (batch_size, num_mode, max_n_time_cond + 1, latent_dim), (batch_size, num_mode, max_n_time_cond + 1, latent_dim, latent_dim)
            means_f = means_f_all_modes.gather(1, mode.expand(-1, -1, means_f_all_modes.shape[-2], self.latent_dim))   # (batch_size, n_sample, max_n_time_cond + 1, latent_dim)
            covs_f = covs_f_all_modes.gather(1, mode[...,None].expand(-1, -1, means_f_all_modes.shape[-2], self.latent_dim, self.latent_dim))   # (batch_size, n_sample, max_n_time_cond + 1, latent_dim, latent_dim)

            # Forward generation
            noise_trans_mat = safe_cholesky(covs_f[...,-1:,:,:])   # (batch_size, n_sample, 1, latent_dim, latent_dim)
            ts_forward = (ts_forward - ts_boundary).unsqueeze(-3)   # (batch_size, 1, max_len_forward, 1)
            dts = ts_forward[...,1:,:] - ts_forward[...,:-1,:]  # (batch_size, 1, max_len_forward - 1, 1)
            noise_trans_mat = torch.cat([
                noise_trans_mat,
                self.propagator.evolve_matrix_left(self.propagator.cov_cholesky(dts, mode), -ts_forward[...,1:,:], mode)
                ], dim=-3)  # (batch_size, n_sample, max_len_forward, latent_dim, latent_dim)
            noise_forward = torch.randn_like(noise_trans_mat[...,0])   # (batch_size, n_sample, max_len_forward, latent_dim)
            zs_forward = einsum('...ij,...j->...i', noise_trans_mat, noise_forward)  # (batch_size, n_sample, max_len_forward, latent_dim)
            zs_forward[:,:,0] = zs_forward[:,:,0] + means_f[...,-1,:]   # (batch_size, n_sample, latent_dim)
            zs_forward = zs_forward.cumsum_(dim=-2)   # (batch_size, n_sample, max_len_forward, latent_dim)
            zs_forward = self.propagator.mean(zs_forward, ts_forward, mode)   # (batch_size, n_sample, max_len_forward, latent_dim)

            # Backward generation
            ts_cond_ex = torch.cat([torch.zeros(ts_cond_padded.shape[:-2] + (1, 1), device=ts_cond_padded.device), ts_cond_padded], dim=-2).unsqueeze(-3)   # (batch_size, 1, max_n_time_cond, 1)
            covs_f_evo = self.propagator.evolve_cov(covs_f[...,:-1,:,:], ts_cond_ex[...,1:,:] - ts_cond_ex[...,:-1,:], mode)   # (batch_size, n_sample, max_n_time_cond, latent_dim, latent_dim)
            gains = self.propagator.evolve_matrix_left(covs_f[...,:-1,:,:], ts_cond_ex[...,1:,:] - ts_cond_ex[...,:-1,:], mode).mT @ safe_positive_definite_inverse(covs_f_evo)   # (batch_size, n_sample, max_n_time_cond, latent_dim, latent_dim)
            # Accumulate the gains from the end so that entry i holds the product of gains i, i+1, ...
            for i in range(gains.shape[-3] - 1, 0, -1):
                gains[...,i-1,:,:] = gains[...,i-1,:,:] @ gains[...,i,:,:]
            idx_left = searchsorted(ts_cond_padded[...,0], ts_backward[...,0], right=True).unsqueeze(-2)    # (batch_size, 1, max_len_backward)
            means_f_left = means_f.gather(-2, idx_left[...,None].expand(-1, n_sample, -1, self.latent_dim))   # (batch_size, n_sample, max_len_backward, latent_dim)
            covs_f_left = covs_f.gather(-3, idx_left[...,None,None].expand(-1, n_sample, -1, self.latent_dim, self.latent_dim))  # (batch_size, n_sample, max_len_backward, latent_dim, latent_dim)
            ts_left = ts_cond_ex.gather(-2, idx_left[...,None].expand(-1, 1, -1, 1))   # (batch_size, 1, max_len_backward, 1)
            dt0 = ts_backward.unsqueeze(-3) - ts_left   # (batch_size, 1, max_len_backward, 1)
            dt1 = torch.cat([ts_backward[...,1:,:], ts_boundary], dim=-2).unsqueeze(-3) - ts_left   # (batch_size, 1, max_len_backward, 1)
            means0 = self.propagator.mean(means_f_left, dt0, mode)  # (batch_size, n_sample, max_len_backward, latent_dim)
            means1 = self.propagator.mean(means_f_left, dt1, mode)  # (batch_size, n_sample, max_len_backward, latent_dim)
            covs0 = self.propagator.evolve_cov(covs_f_left, dt0, mode)  # (batch_size, n_sample, max_len_backward, latent_dim, latent_dim)
            gains_left = gains.gather(-3, idx_left[...,None,None].expand(-1, n_sample, -1, self.latent_dim, self.latent_dim))    # (batch_size, n_sample, max_len_backward, latent_dim, latent_dim)
            gains_evo = self.propagator.evolve_matrix_left(covs0, -dt0, mode).mT @ safe_positive_definite_inverse(covs_f_left) @ gains_left   # (batch_size, n_sample, max_len_backward, latent_dim, latent_dim)
            gains_evo_inv = torch.inverse(gains_evo)   # (batch_size, n_sample, max_len_backward, latent_dim, latent_dim)
            gains_evo_inv = torch.cat([gains_evo_inv, torch.eye(self.latent_dim, device=ts_backward.device)[None,None].expand(batch_size, n_sample, 1, self.latent_dim, self.latent_dim)], dim=-3)   # (batch_size, n_sample, max_len_backward + 1, latent_dim, latent_dim)
            cov_cholesky_r = self.regressor.cov_cholesky(ts_backward.unsqueeze(-3),
                                                         torch.cat([ts_backward[:,1:], ts_boundary], dim=-2).unsqueeze(-3),
                                                         ts_cond_padded.unsqueeze(-3),
                                                         mode[...,0],
                                                         covs_f)  # (batch_size, n_sample, max_len_backward, latent_dim, latent_dim)

            noise_backward = torch.randn_like(means0)   # (batch_size, n_sample, max_len_backward, latent_dim)
            zs_backward = einsum('...ij,...j->...i', gains_evo_inv[...,:-1,:,:], means0 + einsum('...ij,...j->...i', cov_cholesky_r, noise_backward)) - einsum('...ij,...j->...i', gains_evo_inv[...,1:,:,:], means1)   # (batch_size, n_sample, max_len_backward, latent_dim)
            # Anchor the backward path at the first forward state (the boundary point).
            zs_backward = torch.cat([zs_backward, zs_forward[...,:1,:]], dim=-2)   # (batch_size, n_sample, max_len_backward + 1, latent_dim)
            # Reverse cumulative sum: accumulate from the boundary toward earlier times.
            zs_backward = zs_backward.flip(-2).cumsum_(dim=-2).flip(-2)[...,:-1,:]   # (batch_size, n_sample, max_len_backward, latent_dim)
            zs_backward = einsum('...ij,...j->...i', gains_evo, zs_backward)   # (batch_size, n_sample, max_len_backward, latent_dim)

            # Concatenate zs_backward and zs_forward and extract zs_pred
            zs_pred = torch.cat([zs_backward, zs_forward], dim=-2)  # (batch_size, n_sample, max_len_backward + max_len_forward, latent_dim)
            # Positions of the prediction times in the sorted sequence, shifted for the right-aligned backward part.
            indices_pred = indices_inv[...,-ts_pred.shape[-2]:,:] + max_len_backward - idx_boundary   # (batch_size, max_n_time_pred, 1)
            zs_pred = zs_pred.gather(-2, indices_pred.unsqueeze(-3).expand(-1, n_sample, -1, self.latent_dim))   # (batch_size, n_sample, max_n_time_pred, latent_dim)
            ts_pred = ts_pred.unsqueeze(-3)

        # Map the latent path to the observation space and undo the normalization.
        xs_pred = self.observer(zs_pred, ts_pred, mode)    # (batch_size, n_sample, max_n_time_pred, dim)
        xs_pred = self.denormalize(x=xs_pred)   # (batch_size, n_sample, max_n_time_pred, dim)

        return xs_pred
    