import torch
import math
import logging

from torch.utils.data import DataLoader

# Module-level temperature. Within this file it multiplies the per-step loss
# in loss_underdamped_path_OBABO and loss_underdamped_path_OABOA.
# NOTE(review): the inline comment calls it a sampling temperature; confirm
# whether other modules read it for that purpose.
temperature = 1.0  # Default temperature for sampling

def tang_proj(manifold, y, base):
    """Return ``y`` projected onto the tangent space of ``manifold`` at ``base``.

    Both arguments must be rank-2 ``(batch, dim)`` tensors. Anything else is
    rejected before the manifold is consulted. The projection itself is
    delegated to ``manifold.project_onto_tangent_space``.

    Raises:
        ValueError: if ``y`` or ``base`` does not have exactly two dimensions.
    """
    if y.ndim != 2:
        raise ValueError(f"Input y must have shape (batch, dim), got {y.shape}")
    if base.ndim != 2:
        raise ValueError(f"Input base must have shape (batch, dim), got {base.shape}")

    projected = manifold.project_onto_tangent_space(y, base_point=base)
    return projected


def G_fwd_overdamped(manifold, x_prev, x_next, sigmas, dt, drift):
    """Standardised forward-step residual for the overdamped scheme.

    Computes

        G = P_{x_prev}(x_next - x_prev + drift * Δt) / (σ_k √Δt)

    where ``P_{x_prev}`` is the tangent-space projection at ``x_prev``.
    Callers in this file pass ``drift = σ_k² b(x_prev)``, which gives
    G(x_prev, x_next) = (x_next - x_prev + σ_k² b(x_prev) Δt) / (σ_k √Δt).

    Args:
        manifold: object with ``project_onto_tangent_space``.
        x_prev, x_next: [rows, D] consecutive states (rank 2 enforced by tang_proj).
        sigmas: σ_k per row, broadcastable against the residual.
        dt: step size Δt (passed to ``math.sqrt``).
        drift: [rows, D] drift term, already scaled by the caller.

    Returns:
        [rows, D] standardised residual.
    """
    raw_residual = x_next - x_prev + drift * dt
    tangent_residual = tang_proj(manifold, raw_residual, x_prev)
    noise_scale = sigmas * math.sqrt(dt)
    return tangent_residual / noise_scale

def G_bwd_overdamped(manifold, x_prev, x_next, sigmas, dt, drift):
    """Standardised backward-step residual for the overdamped scheme.

    Computes

        G(x_prev, x_next) = P_{x_next}(x_prev - x_next - drift * Δt / 2) / (σ_k √Δt)

    where ``P_{x_next}`` is the tangent-space projection at ``x_next``.
    Callers in this file pass ``drift = σ_k² (s_θ(x_next, t) - b(x_next))``,
    so ``drift`` already carries the σ_k² factor.

    NOTE(review): the drift is halved here but not in G_fwd_overdamped.
    Both training (loss_overdamped_path) and evaluation (nll_overdamped_path)
    use this function, so the score network can absorb the factor. Confirm
    that the ½ is the intended parameterisation.

    Args:
        manifold: object with ``project_onto_tangent_space``.
        x_prev, x_next: [rows, D] consecutive states (rank 2 enforced by tang_proj).
        sigmas: σ_k per row, broadcastable against the residual.
        dt: step size Δt (passed to ``math.sqrt``).
        drift: [rows, D] backward drift evaluated at ``x_next``.

    Returns:
        [rows, D] standardised residual.
    """
    # Residual of the backward (x_next -> x_prev) Euler step.
    delta_x = x_prev - x_next - drift * dt / 2  # [N·B,D]

    # Project onto the tangent space at the backward step's starting point.
    delta_x_proj = tang_proj(manifold, delta_x, x_next)

    # Express the residual in units of the per-step noise scale σ_k √Δt.
    G_list = delta_x_proj / (sigmas * math.sqrt(dt))  # [N·B,D]

    return G_list

def G_fwd_underdamped_EM(manifold, x_prev, x_next, v_prev, sigmas, dt, drift, gamma, a, sqrt_1ma2):
    """Standardised forward residual for the underdamped Euler–Maruyama step.

    With h = σ_k² Δt:

        μ_v = a · v_prev + h · drift
        μ_x = x_prev + h · μ_v
        G   = P_{x_prev}(x_next - μ_x) / (√(1-a²) · σ_k² · Δt)

    ``sqrt_1ma2`` is supplied by the caller (√(1-a²)). ``gamma`` is accepted
    for signature compatibility but not used here; friction enters only
    through ``a``.

    Returns:
        [rows, D] standardised residual.
    """
    sig_sq = sigmas ** 2
    h = sig_sq * dt
    velocity_mean = a * v_prev + h * drift
    position_mean = x_prev + h * velocity_mean

    residual = tang_proj(manifold, x_next - position_mean, x_prev)
    return residual / (sqrt_1ma2 * sig_sq * dt)  # [N·B,D]


def G_bwd_underdamped_EM(manifold, x_prev, x_next, v_next, sigmas, dt, drift, gamma, a, sqrt_1ma2, temp= None):
    """Standardised backward residual for the underdamped Euler–Maruyama step.

    With h = σ_k² Δt and τ = ``temp`` (1.0 when None):

        μ_v = a · v_next + h · drift · τ
        μ_x = x_next - h · μ_v
        G   = P_{x_next}(x_prev - μ_x) / (√(1-a²) · σ_k² · Δt)

    ``gamma`` is accepted for signature compatibility but not used here.

    Returns:
        [rows, D] standardised residual.
    """
    drift_scale = 1.0 if temp is None else temp

    sig_sq = sigmas ** 2
    h = sig_sq * dt
    velocity_mean = a * v_next + h * drift * drift_scale
    position_mean = x_next - h * velocity_mean

    residual = tang_proj(manifold, x_prev - position_mean, x_next)
    return residual / (sqrt_1ma2 * sig_sq * dt)  # [N·B,D]


def loss_overdamped_path(manifold, x_hist, score_net, func_b, sigmas, dt):
    """Backward-path training loss for the overdamped scheme.

    Returns ½ · mean_b Σ_k ‖G_bwd(x_k, x_{k+1})‖², where the backward drift
    for transition k is σ_k² (s_θ(x_{k+1}, t_{k+1}) - b(x_{k+1})) and
    t_{k+1} = (k+1)·Δt.

    x_hist  : [N+1,B,D]  entire forward path (from sampler)
    sigmas  : [N]        σ_k   ( NOT σ_k√Δt )
    dt      : scalar Δt
    """
    n_steps = sigmas.shape[0]
    batch = x_hist.shape[1]
    dim = x_hist.size(-1)

    # One σ_k per flattened row [N·B, 1]; the time index varies slowest.
    sigma_rows = sigmas.repeat_interleave(batch).reshape(-1, 1)

    # Each transition is labelled with its end time t_{k+1}.
    end_times = torch.linspace(0., n_steps * dt, n_steps + 1, device=x_hist.device)[1:]
    t_rows = end_times[:, None].expand(n_steps, batch).reshape(-1)  # [N·B]

    x_next = x_hist[1:].reshape(-1, dim)       # [N·B, D]
    x_prev = x_hist[:-1].reshape_as(x_next)

    b_next = func_b(x_next)
    score = score_net(x_next, t_rows)
    drift = sigma_rows ** 2 * (score - b_next)

    residual = G_bwd_overdamped(manifold, x_prev, x_next, sigma_rows, dt, drift)
    per_path = (residual ** 2).sum(-1).view(n_steps, batch).sum(0)  # [B]
    return 0.5 * per_path.mean()

def loss_underdamped_path_EM(manifold, x_hist, score_net, func_b, sigmas, dt, mass, gamma, sample_last_v = True, temp = None, apply_mask = False):
    """Backward-path training loss for the underdamped Euler–Maruyama scheme.

    Velocities are not stored in ``x_hist``. They are reconstructed by finite
    differences, v_{k+1} = (x_{k+2} - x_{k+1}) / (σ_{k+1}² Δt), and projected
    onto the tangent space at x_{k+1}. The velocity attached to the final
    state is supplied separately (see ``sample_last_v``). The loss is
    ½ Σ_k ‖G_bwd,k‖² summed over the N transitions of each path and averaged
    over the batch.

    Args:
        manifold: object with ``project_onto_tangent_space``. When
            ``apply_mask`` is True it may also expose ``l`` (number of
            inequality constraints) and ``g`` (constraint values).
        x_hist: [N+1, B, D] entire forward path (from sampler).
        score_net: called as ``score_net(cat([x_next, v], -1), t)``.
        func_b: callable evaluated at ``x_next``.
        sigmas: [N] σ_k (NOT σ_k √Δt).
        dt: scalar Δt.
        mass: unused; kept for signature compatibility.
        gamma: friction, entering via a = exp(-σ_k² γ Δt).
        sample_last_v: if True, the final velocity is a standard-normal draw.
            If False, a zero velocity is used.
        temp: multiplier on the drift in the backward kernel; None means 1.0.
        apply_mask: if True and ``manifold.l > 0``, transitions whose start
            or end point has any ``manifold.g(x) > 0`` are zeroed out of the
            loss. Default False (no masking).

    Returns:
        Scalar training loss.
    """
    temp = 1.0 if temp is None else temp

    N = sigmas.shape[0]
    B, D = x_hist.shape[1], x_hist.shape[2]
    device = x_hist.device

    # time grid for the score network: transition k is labelled t_{k+1}
    t_vec = torch.linspace(0., N * dt, N + 1, device=device)[1:]
    t_full = t_vec[:, None].expand(N, B).reshape(-1)  # [N·B]

    x_prev = x_hist[:-1].reshape(-1, D)  # [N·B, D]
    x_next = x_hist[1:].reshape(-1, D)   # [N·B, D]

    # Optional constraint mask: a transition counts only if neither endpoint
    # violates any inequality constraint. Computed only when it will be used.
    step_mask = None
    if apply_mask and hasattr(manifold, 'l') and manifold.l > 0:
        invalid = torch.any(manifold.g(x_prev) > 0, dim=1) | torch.any(manifold.g(x_next) > 0, dim=1)
        step_mask = torch.logical_not(invalid)  # [N·B]

    # Finite-difference velocities at x_1 .. x_{N-1}: (N-1, B, D)
    v_bwd = (x_hist[2:] - x_hist[1:-1]) / ((sigmas[1:, None, None] ** 2) * dt)

    # Velocity at the final state x_N (no x_{N+1} exists to difference against).
    if sample_last_v:
        v_last = torch.randn_like(x_hist[-1], device=device)
    else:
        v_last = torch.zeros_like(x_hist[-1])

    v_bwd = torch.cat([v_bwd, v_last.unsqueeze(0)], dim=0).reshape(-1, D)  # [N·B, D]
    v_bwd = tang_proj(manifold, v_bwd, base=x_next)

    sigmas_step = sigmas.repeat_interleave(B).reshape(-1, 1)  # [N·B, 1]
    a = torch.exp(- (sigmas_step ** 2) * gamma * dt)
    sqrt_1ma2 = torch.sqrt(torch.abs(1.0 - a ** 2))  # [N·B, 1]

    b_val = func_b(x_next)                                              # [N·B, D]
    scores = score_net(torch.cat([x_next, v_bwd], dim=-1), t_full)      # [N·B, D]
    drift = scores - b_val

    G_bwd_list = G_bwd_underdamped_EM(manifold, x_prev, x_next, v_bwd, sigmas_step, dt, drift, gamma, a, sqrt_1ma2, temp=temp)
    G_bwd_sq = (G_bwd_list ** 2).sum(-1)  # [N·B]

    if step_mask is not None:
        G_bwd_sq = G_bwd_sq * step_mask.to(G_bwd_sq.dtype)

    return 0.5 * G_bwd_sq.view(N, B).sum(0).mean()


def G_abs_underdamped_OBABO(manifold, x_front, x_back, v, drift_next, sigma_sq, dt, mass, a_pick, sqrt_a_pick, mode = "OBA", backward = False):
    """Standardised residual of one OBABO sub-step (``mode`` is "OBA" or "BO").

    Computes

        G = 2√m / (σ² Δt · sqrt_a_pick)
            · P(x_front - x_back + s_v · a_pick σ² Δt / (2m) · v
                                 + s_d · σ² Δt² / (8m) · drift_next)

    with s_v = -1 when ``backward == False`` (else +1), and s_d = -1 for
    "OBA" / +1 for "BO". The tangent projection P is taken at ``x_back``
    for "OBA" and at ``x_front`` for "BO".

    ``sqrt_a_pick`` is the noise coefficient supplied by the caller. In this
    file that is √(1-a²) for "OBA" and √(ā²-1) for "BO".

    Raises:
        ValueError: for any other ``mode``.
    """
    # Keep the `== False` comparison: it fixes the sign for non-bool inputs
    # (e.g. None) exactly as the original check did.
    coeff_v_sign = -1 if backward == False else 1  # noqa: E712

    if mode not in ("OBA", "BO"):
        raise ValueError(f"Unknown mode: {mode}. Choose 'OBA' or 'BO'.")
    is_oba = mode == "OBA"
    coeff_drift_sign = -1 if is_oba else 1
    x_pivot = x_back if is_oba else x_front

    coeff_v = (a_pick * sigma_sq * dt) / (2.0 * mass) * coeff_v_sign
    coeff_drift = (sigma_sq * dt**2) / (8.0 * mass) * coeff_drift_sign

    residual = (x_front - x_back) + coeff_v * v + coeff_drift * drift_next
    residual = tang_proj(manifold, residual, base=x_pivot)

    prefactor = 2.0 * math.sqrt(mass) / (sigma_sq * dt * sqrt_a_pick)
    return prefactor * residual

def G_abs_underdamped_OABOA(manifold, x_front, x_back, v, drift, sigma_sq, dt, mass, a_pick, sqrt_a_pick, tau, mode = 'OA', backward = False):
    """Standardised residual of one OABOA sub-step.

    Computes

        G = -4√m / (σ² Δt · sqrt_a_pick)
            · P_{x_back}(x_front - x_back + s_v · a_pick σ² Δt / (4m) · v
                                         - τ σ² Δt² / (8m) · drift)

    with s_v = -1 when ``backward == False`` (else +1). No reparametrisation
    of the drift is applied (the original comment notes this assumes b = 0).

    ``mode`` is accepted for API compatibility but does not change the
    result: every mode uses the same drift coefficient.
    """
    # Keep the `== False` comparison: it fixes the sign for non-bool inputs
    # (e.g. None) exactly as the original check did.
    coeff_v_sign = -1 if backward == False else 1  # noqa: E712

    coeff_drift = - (sigma_sq * dt**2) / (8.0 * mass) * tau
    coeff_v = (a_pick * sigma_sq * dt) / (4.0 * mass) * coeff_v_sign

    residual = (x_front - x_back) + coeff_v * v + coeff_drift * drift
    residual = tang_proj(manifold, residual, base=x_back)

    scale = -(4.0 * math.sqrt(mass) / (sigma_sq * dt * sqrt_a_pick))
    return scale * residual

def loss_underdamped_path_OBABO(manifold, data_hist, score_net, func_b, sigmas, dt: float, mass: float, gamma: float, include_oba: bool = False):
    """Backward-path training loss for the underdamped OBABO splitting.

    Args:
        manifold: object with ``project_onto_tangent_space``.
        data_hist: [N+1, B, 2D] path history. The last axis is the
            concatenation ``[x, v]``, split here with ``torch.chunk``.
        score_net: called as ``score_net(cat([x, v], -1), label=t)``.
        func_b: callable evaluated at positions.
        sigmas: [N] σ_k per step.
        dt: step size Δt.
        mass: mass parameter m.
        gamma: friction, entering via a = exp(-σ_k² Δt γ / (4m)).
        include_oba: if True, the G_OBA residual (at (x_{k+1}, v_{k+1}) with
            coefficients a, √(1-a²)) is added to the loss. The default False
            gives the BO-only loss and skips the score-network and ``func_b``
            evaluations that only the OBA term needs.

    Returns:
        Scalar ``temperature · ½ · mean_b Σ_k ‖G‖²``, where ``temperature`` is
        the module-level constant.
    """
    # Split the concatenated data history into position and velocity
    x_hist, v_hist = torch.chunk(data_hist, 2, dim=-1)

    device = x_hist.device
    N, B, D = x_hist.shape[0] - 1, x_hist.shape[1], x_hist.shape[2]

    def per_row(coeff):
        # [N] -> [N·B, 1], time index varying slowest (matches the flattening below).
        return coeff.repeat_interleave(B).reshape(-1, 1)

    # 1. Per-step scalar coefficients
    sigmas = sigmas.to(device)
    sigmas_sq = sigmas ** 2
    sigmas_sq_dt = sigmas_sq * dt
    a = torch.exp(- sigmas_sq_dt * gamma / (4. * mass))
    bar_a = 1.0 / a
    sqrt_bar_a2m1 = torch.sqrt(torch.abs(bar_a**2 - 1.0))  # [N]

    sigmas_sq_broad = per_row(sigmas_sq)
    bar_a_broad = per_row(bar_a)
    sqrt_bar_a2m1_broad = per_row(sqrt_bar_a2m1)

    # 2. Slice path history into flattened transitions [N·B, D]
    x_prev = x_hist[:-1].reshape(-1, D)
    x_next = x_hist[1:].reshape_as(x_prev)
    v_prev = v_hist[:-1].reshape_as(x_prev)

    t_vec = torch.linspace(0., N*dt, N+1, device=device)[1:]
    t_full = t_vec[:, None].expand(N, B).reshape(-1)  # [N·B]

    # 3. Drift and score evaluations (same call order as the full OBABO loss)
    b_prev = func_b(x_prev)
    if include_oba:
        v_next = v_hist[1:].reshape_as(x_prev)
        b_next = func_b(x_next)
        s_next = score_net(torch.cat([x_next, v_next], dim=-1), label=t_full)

    # Mid-step velocity from the position increment, projected at x_prev.
    v_mid = (2 * mass) / (sigmas_sq_broad * dt) * (x_next - x_prev)
    v_mid = tang_proj(manifold, v_mid, base=x_prev)

    s_prev = score_net(torch.cat([x_prev, v_mid], dim=-1), label=t_full - dt/2)
    drift_prev = sigmas_sq_broad * (b_prev - s_prev)

    # 4. BO residual (always part of the loss)
    G_bar_BO = G_abs_underdamped_OBABO(manifold, x_prev, x_next, v_prev, drift_prev, sigmas_sq_broad, dt, mass, bar_a_broad, sqrt_bar_a2m1_broad, mode="BO", backward=True)
    step_sq = (G_bar_BO ** 2).sum(-1)  # [N·B]

    # Optional OBA residual
    if include_oba:
        sqrt_1ma2_broad = per_row(torch.sqrt(torch.abs(1.0 - a**2)))
        a_broad = per_row(a)
        drift_next = sigmas_sq_broad * (b_next - s_next)
        G_bar_OBA = G_abs_underdamped_OBABO(manifold, x_prev, x_next, v_next, drift_next, sigmas_sq_broad, dt, mass, a_broad, sqrt_1ma2_broad, mode="OBA", backward=True)
        step_sq = (G_bar_OBA ** 2).sum(-1) + step_sq

    # 5. Sum over steps, average over the batch
    loss_kb = temperature * 0.5 * step_sq
    return loss_kb.view(N, B).sum(0).mean()


def loss_underdamped_path_OABOA(manifold, data_hist, score_net, func_b, sigmas, taus, dt: float, mass: float, gamma: float):
    """Backward-path training loss for the underdamped OABOA splitting (BOA term).

    Args:
        manifold: object with ``project_onto_tangent_space``.
        data_hist: [N+1, B, 3D] path history. The last axis is split with
            ``torch.chunk`` into ``[x, x_mid, v]``. The stored velocity ``v``
            is not used by this loss.
        score_net: called as ``score_net(cat([x, v], -1), label=t)``.
        func_b: callable evaluated at the mid-step positions.
        sigmas: [N] σ_k per step.
        taus: [N] per-step τ_k multiplying the drift coefficient.
        dt: step size Δt.
        mass: mass parameter m.
        gamma: friction, entering via a = exp(-σ_k² Δt γ / (4m)).

    For each transition k, the mid-step velocity is reconstructed as
    4m/(σ_k² Δt)·(x_{k+1} - x_mid), projected at x_mid. The drift is
    σ_k²(b(x_mid) - s_θ([x_mid, v_mid], t_{k+1} - Δt/2)).

    Returns:
        Scalar ``temperature · ½ · mean_b Σ_k ‖G_BOA‖²``, where
        ``temperature`` is the module-level constant.
    """
    x_hist, x_mid_hist, _ = torch.chunk(data_hist, 3, dim=-1)
    device = x_hist.device
    N, B, D = x_hist.shape[0] - 1, x_hist.shape[1], x_hist.shape[2]

    # 1. Per-step scalar coefficients
    sigmas = sigmas.to(device)
    taus = taus.to(device)
    sigmas_sq = sigmas ** 2
    sigmas_sq_dt = sigmas_sq * dt
    a = torch.exp(- sigmas_sq_dt * gamma / (4. * mass))
    sqrt_1ma2 = torch.sqrt(torch.abs(1.0 - a**2))  # [N]

    # Broadcast [N] -> [N·B, 1], time index varying slowest
    sigmas_sq_broad = sigmas_sq.repeat_interleave(B).reshape(-1, 1)
    tau_broad = taus.repeat_interleave(B).reshape(-1, 1)
    a_broad = a.repeat_interleave(B).reshape(-1, 1)
    sqrt_1ma2_broad = sqrt_1ma2.repeat_interleave(B).reshape(-1, 1)

    # 2. Slice path history into flattened transitions [N·B, D]
    x_prev = x_hist[:-1].reshape(-1, D)
    x_next = x_hist[1:].reshape_as(x_prev)
    x_mid = x_mid_hist[1:].reshape_as(x_prev)

    # Mid-step velocity reconstructed from the second half-step, projected at x_mid.
    v_mid = (4 * mass) / (sigmas_sq_broad * dt) * (x_next - x_mid)
    v_mid = tang_proj(manifold, v_mid, base=x_mid)

    # 3. Drift and score at the mid point
    b_mid = func_b(x_mid)

    t_vec = torch.linspace(0., N*dt, N+1, device=device)[1:]
    t_full = t_vec[:, None].expand(N, B).reshape(-1)

    s_mid = score_net(torch.cat([x_mid, v_mid], dim=-1), label=t_full - dt / 2)
    drift_mid = sigmas_sq_broad * (b_mid - s_mid)

    # 4. BOA residual between x_prev and x_mid
    G_bar_BOA = G_abs_underdamped_OABOA(manifold, x_prev, x_mid, v_mid, drift_mid, sigmas_sq_broad, dt, mass, a_broad, sqrt_1ma2_broad, tau_broad, mode='BOA', backward=True)

    # 5. Sum over steps, average over the batch
    loss_kb = temperature * 0.5 * (G_bar_BOA ** 2).sum(-1)
    return loss_kb.view(N, B).sum(0).mean()

@torch.no_grad()
def nll_overdamped_path(
        data: torch.Tensor,
        # --- model / geometry pieces ---------------------------------
        network,              # score network   s_θ(x,t)
        sde,                  # SDE object with .T, .N, .func_b, .sde()
        manifold,             # manifold object with .project_onto_tangent_space, .log_volume()
        sampler_fn,           # call signature: xN, x_hist, info = sampler_fn(sde, manifold, **kwargs)
        # --- hyper-params & misc -------------------------------------
        nll_bs: int = 64,     # batch size for DataLoader
        nll_K: int = 8,       # # trajectories per datum
        device: torch.device = None,
        sde_kwargs: dict = None,   # kwargs forwarded to sampler_fn
        keep_quiet: bool = True,
        return_mean: bool = True):
    """
    Importance-sampled estimate and Jensen upper bound of NLL(x₀) for the
    over-damped formulation.

    For every datum, ``nll_K`` forward paths are drawn with ``sampler_fn``.
    Only paths flagged in ``info["converged_traj"]`` are used. Per path,
    log w = -½ Σ_k (‖G_bwd,k‖² - ‖G_fwd,k‖²). With the prior term
    log vol(Σ) from ``manifold.log_volume()``:

        NLL ≈ log vol - log mean_K(w)                         (importance)
        NLL ≤ log vol + ½ mean_K Σ_k (‖G_bwd‖² - ‖G_fwd‖²)    (Jensen)

    Returns
    -------
    A 5-tuple ``(nll, nll_upper, train_loss, train_loss, train_loss)``.
    The last three entries are the same value, ½ mean_K Σ_k ‖G_bwd‖².

    • ``return_mean=True``: scalars averaged over data points with at least
      one converged trajectory (``nll`` additionally skips NaN entries).
    • ``return_mean=False``: per-datum tensors aligned with ``data``. Data
      points with no converged trajectory are NaN in every output.
    """
    def masked_mean(mat, mask):
        # mat, mask : [B, K]. Row mean over entries with mask == 1.
        # Rows with no valid entries are NaN; with return_mean they are dropped.
        valid = mask.sum(1) > 0
        out = torch.full_like(mat[:, 0], float("nan"))
        if valid.any():
            out[valid] = (mat[valid] * mask[valid]).sum(1) / mask[valid].sum(1)
        return out[valid] if return_mean else out

    device = data.device if device is None else torch.device(device)

    data = data.to(device)
    network = network.to(device)
    sde_kwargs = {} if sde_kwargs is None else dict(sde_kwargs)

    # ------------------------------------------------------------
    # 1) time grid & deterministic coefficients (loop-invariant)
    # ------------------------------------------------------------
    t_grid = torch.linspace(0., sde.T, sde.N + 1, device=device)   # (N+1,)
    dt = torch.diff(t_grid)[0]                                     # uniform grid -> scalar Δt
    _, sigma_t = sde.sde(torch.zeros(1, device=device), t_grid)    # (N+1)
    sigma_k = sigma_t[:-1]                                         # (N,)

    # log-volume prior term (uniform prior on Σ)
    prior_x_N = torch.as_tensor(manifold.log_volume(), device=device, dtype=torch.float32)

    nll_list, upper_nll_list, train_loss_list = [], [], []
    loader = DataLoader(torch.arange(data.size(0)), batch_size=nll_bs)

    for idx in loader:
        n_data = idx.size(0)
        # replicate each datum K times
        x0_rep = torch.repeat_interleave(data[idx], repeats=nll_K, dim=0)  # [B·K, d]

        # sample full forward path
        _, _, info = sampler_fn(
            sde, manifold, init=x0_rep,
            reverse=False, keep_quiet=keep_quiet, **sde_kwargs)
        x_hist = info["x_hist_all"]                                 # [N+1, B·K, D]
        mask = info["converged_traj"].view(n_data, nll_K).float()   # [B, K]
        bsz = x_hist.size(1)

        # flatten so that time index k moves slowest
        x_prev = x_hist[:-1].reshape(-1, x_hist.size(-1))  # [(N*bsz), D]
        x_next = x_hist[1:].reshape_as(x_prev)
        sigmas = sigma_k[:, None].repeat(1, bsz).reshape(-1, 1)

        # forward G_k (proposal q)
        drift_fwd = sigmas ** 2 * sde.func_b(x_prev)
        G_fwd = G_fwd_overdamped(manifold, x_prev, x_next, sigmas, dt, drift_fwd)
        G_fwd_sq = (G_fwd**2).sum(-1)                        # (N*bsz,)

        # backward G_k (model p_θ)
        t_rep = t_grid[1:, None].repeat(1, bsz).flatten()    # (N*bsz,)
        drift_bwd = sigmas ** 2 * (network(x_next, t_rep).detach() - sde.func_b(x_next))
        G_bwd = G_bwd_overdamped(manifold, x_prev, x_next, sigmas, dt, drift_bwd)
        G_bwd_sq = (G_bwd**2).sum(-1)

        # log-weights, restricted to converged trajectories
        log_w_k = -0.5 * (G_bwd_sq - G_fwd_sq)                      # (N*bsz,)
        log_w = log_w_k.view(sde.N, n_data, nll_K).sum(0)           # [B, K]

        valid_K = mask.sum(dim=1)
        log_w[~mask.bool()] = -torch.inf          # drop invalid paths from the sum
        # log mean(w) over valid paths; clamp avoids log(0) for empty rows
        log_w_bar = torch.logsumexp(log_w, dim=1) - torch.log(valid_K.clamp_min(1.0))

        nll_i = prior_x_N - log_w_bar
        nll_i[valid_K == 0] = torch.nan
        nll_list.append(nll_i)

        # Jensen upper bound
        G_bwd_sum = G_bwd_sq.view(sde.N, n_data, nll_K).sum(0)
        G_fwd_sum = G_fwd_sq.view(sde.N, n_data, nll_K).sum(0)

        upper_i = 0.5 * masked_mean(G_bwd_sum, mask) \
                  + prior_x_N \
                  - 0.5 * masked_mean(G_fwd_sum, mask)
        upper_nll_list.append(upper_i)

        # training loss (= ½‖G_bwd‖²)
        train_loss_list.append(0.5 * masked_mean(G_bwd_sum, mask))

    nll = torch.cat(nll_list)
    nll_up = torch.cat(upper_nll_list)
    train_loss = torch.cat(train_loss_list)

    if return_mean:
        # Filter out NaNs only when aggregating; per-sample outputs stay aligned with `data`.
        nll = nll[~torch.isnan(nll)]
        return nll.mean(), nll_up.mean(), train_loss.mean(), train_loss.mean(), train_loss.mean()

    return nll, nll_up, train_loss, train_loss, train_loss


@torch.no_grad()
def nll_underdamped_path_EM(
        data: torch.Tensor,
        # --- model / geometry pieces ---------------------------------
        network,              # score network s_θ(x,t)
        sde,                  # SDE object with .T, .N, .func_b, .sde()
        manifold,             # manifold object with .project_onto_tangent_space, .log_volume()
        sampler_fn,           # call signature: xN, x_hist, info = sampler_fn(sde, manifold, **kwargs)
        # --- hyper-params & misc -------------------------------------
        nll_bs: int = 64,     # batch size for DataLoader
        nll_K: int = 8,       # # trajectories per datum
        device: torch.device = None,
        sde_kwargs: dict = None,   # kwargs forwarded to sampler_fn
        keep_quiet: bool = True,
        return_mean: bool = True):
    """
    Importance-sampled estimate and Jensen upper bound of NLL(x₀) for the
    underdamped Euler–Maruyama scheme.

    For each datum, ``nll_K`` forward paths are sampled from x₀ with an initial
    velocity v₋₁, drawn as a tangent-projected standard normal. Along the path,
    velocities are rebuilt by finite differences v_k = (x_{k+1} - x_k)/(σ_k² Δt).
    The terminal velocity v_N is a fresh tangent-projected standard normal.
    Per path,

        log w = -½ Σ_k (‖G_bwd,k‖² - ‖G_fwd,k‖²) - log vol(Σ)
                - ½ (‖v_N‖² - ‖v₋₁‖²)

    and NLL ≈ -log mean_K(w). The Jensen bound is -mean_K(log w).

    ``sde_kwargs`` is forwarded to ``sampler_fn``. Its ``"gamma"`` entry
    (default 100.0) is also used as the friction here.

    Returns
    -------
    A 5-tuple ``(nll, nll_upper, train_loss, train_loss, train_loss)``.
    The last three entries are the same value, ½ mean_K Σ_k ‖G_bwd‖².
    With ``return_mean=True`` these are scalars over data points with at
    least one converged trajectory. With ``return_mean=False`` they are
    per-datum tensors aligned with ``data``, NaN where no trajectory
    converged.
    """
    gamma = 100.0 if sde_kwargs is None else sde_kwargs.get("gamma", 100.0)

    def masked_mean(mat, mask):
        # mat, mask : [B, K]. Row mean over entries with mask == 1.
        # Rows with no valid entries are NaN; with return_mean they are dropped.
        valid = mask.sum(1) > 0
        out = torch.full_like(mat[:, 0], float("nan"))
        if valid.any():
            out[valid] = (mat[valid] * mask[valid]).sum(1) / mask[valid].sum(1)
        return out[valid] if return_mean else out

    device = data.device if device is None else torch.device(device)

    data = data.to(device)
    network = network.to(device)
    sde_kwargs = {} if sde_kwargs is None else dict(sde_kwargs)

    # ------------------------------------------------------------
    # 1) time grid & deterministic coefficients (loop-invariant)
    # ------------------------------------------------------------
    t_grid = torch.linspace(0., sde.T, sde.N + 1, device=device)  # (N+1,)
    dt = torch.diff(t_grid)[0]                                    # scalar Δt
    sigma_t = sde.get_diffusion(t_grid)                           # (N+1)
    sigma_k = sigma_t[:-1]                                        # (N,)

    # ln p(x_N) = -ln vol(Σ) for the uniform prior on Σ
    log_prior_x = -torch.as_tensor(manifold.log_volume(), device=device, dtype=torch.float32)

    nll_list, upper_nll_list, train_loss_list = [], [], []
    loader = DataLoader(torch.arange(data.size(0)), batch_size=nll_bs)

    for idx in loader:
        n_data = idx.size(0)
        # replicate each datum K times and draw the initial velocity v_{-1}
        x0_rep = torch.repeat_interleave(data[idx], repeats=nll_K, dim=0)  # [B·K, d]
        v_m_1 = tang_proj(manifold, torch.randn_like(x0_rep, device=device), x0_rep)

        # sample full forward path
        _, _, info = sampler_fn(
            sde, manifold, init=x0_rep,
            reverse=False, init_v=v_m_1, keep_quiet=keep_quiet, **sde_kwargs)

        x_hist = info["x_hist_all"]  # [N+1, B·K, D]
        v_N = tang_proj(manifold, torch.randn_like(x_hist[-1], device=device), x_hist[-1])

        mask = info["converged_traj"].view(n_data, nll_K).float()  # [B, K]
        bsz = x_hist.size(1)  # B x K
        D = x_hist.size(-1)

        # finite-difference velocities, one per transition: (N, bsz, D)
        sigmas = sigma_k.view(-1, 1, 1).expand(-1, bsz, D)
        v = (x_hist[1:] - x_hist[:-1]) / ((sigmas ** 2) * dt)

        # forward kernel uses the velocity entering step k; backward uses the one leaving it
        v_fwd = torch.cat([v_m_1.unsqueeze(0), v[:-1]], dim=0)
        v_bwd = torch.cat([v[1:], v_N.unsqueeze(0)], dim=0)

        x_prev = x_hist[:-1].reshape(-1, D)  # [(N*bsz), D]
        x_next = x_hist[1:].reshape_as(x_prev)
        sigmas = sigmas.reshape_as(x_prev)

        v_fwd = tang_proj(manifold, v_fwd.reshape_as(x_prev), base=x_prev)
        v_bwd = tang_proj(manifold, v_bwd.reshape_as(x_prev), base=x_next)

        a = torch.exp(- (sigmas ** 2) * gamma * dt)
        sqrt_1ma2 = torch.sqrt(1 - a**2)

        # forward G_k (proposal q)
        # NOTE(review): forward uses sde.drift_b while backward uses sde.func_b;
        # confirm these are meant to be different methods.
        drift_fwd = sde.drift_b(x_prev)
        G_fwd = G_fwd_underdamped_EM(manifold, x_prev, x_next, v_fwd, sigmas, dt, drift_fwd, gamma, a, sqrt_1ma2)
        G_fwd_sq = (G_fwd**2).sum(-1)  # (N*bsz,)

        # backward G_k (model p_θ)
        t_rep = t_grid[1:, None].repeat(1, bsz).flatten()  # (N*bsz,)
        drift_bwd = network(torch.cat([x_next, v_bwd], dim=-1), t_rep).detach() - sde.func_b(x_next)
        G_bwd = G_bwd_underdamped_EM(manifold, x_prev, x_next, v_bwd, sigmas, dt, drift_bwd, gamma, a, sqrt_1ma2)
        G_bwd_sq = (G_bwd**2).sum(-1)

        # log-weights including the position and velocity prior terms
        log_w_k = -0.5 * (G_bwd_sq - G_fwd_sq)  # (N*bsz,)
        log_prior_v = (-0.5 * ((v_N**2).sum(-1) - (v_m_1 ** 2).sum(-1))).reshape(n_data, nll_K)
        log_w = log_w_k.view(sde.N, n_data, nll_K).sum(0) + (log_prior_x + log_prior_v)  # [B, K]

        valid_K = mask.sum(dim=1)
        log_w[~mask.bool()] = -torch.inf  # drop invalid paths from the sum
        # log mean(w) over valid paths; clamp avoids log(0) for empty rows
        log_w_bar = torch.logsumexp(log_w, dim=1) - torch.log(valid_K.clamp_min(1.0))

        nll_i = -log_w_bar
        nll_i[valid_K == 0] = torch.nan
        nll_list.append(nll_i)

        # Jensen upper bound
        G_bwd_sum = G_bwd_sq.view(sde.N, n_data, nll_K).sum(0)
        G_fwd_sum = G_fwd_sq.view(sde.N, n_data, nll_K).sum(0)

        upper_i = 0.5 * masked_mean(G_bwd_sum, mask) \
                  - log_prior_x - masked_mean(log_prior_v, mask) \
                  - 0.5 * masked_mean(G_fwd_sum, mask)
        upper_nll_list.append(upper_i)

        # training loss (= ½‖G_bwd‖²)
        train_loss_list.append(0.5 * masked_mean(G_bwd_sum, mask))

    nll = torch.cat(nll_list)
    nll_up = torch.cat(upper_nll_list)
    train_loss = torch.cat(train_loss_list)

    if return_mean:
        # Filter out NaNs only when aggregating; per-sample outputs stay aligned with `data`.
        nll = nll[~torch.isnan(nll)]
        return nll.mean(), nll_up.mean(), train_loss.mean(), train_loss.mean(), train_loss.mean()

    return nll, nll_up, train_loss, train_loss, train_loss

@torch.no_grad()
def nll_underdamped_path_OBABO(
        data: torch.Tensor,
        network,
        sde,
        manifold,
        sampler_fn,
        nll_bs: int = 64,
        nll_K: int = 8,
        device: torch.device = None,
        sde_kwargs: dict = None,
        keep_quiet: bool = True,
        return_mean: bool = True):
    """
    Compute the Jensen upper bound and the importance-sampled estimate of
    NLL(x₀, v₀) following the underdamped (OBABO) formulation.

    For every datum, ``nll_K`` forward trajectories are drawn with
    ``sampler_fn`` starting from the datum. Along each path the squared
    backward (model, ``network``-dependent) and forward (reference) residuals
    are accumulated over the ``sde.N`` steps. Their difference, together with
    a uniform position prior (negative log-volume of ``manifold``) and a
    velocity boundary term, forms a log-importance weight per path. Only
    trajectories flagged as converged by the sampler contribute.

    Args:
        data: Data points; the last dimension is the ambient dimension.
        network: Score model, called as ``network(cat([x, v], -1), label=t)``.
        sde: SDE object exposing ``T``, ``N``, ``dt``, ``sde(x, t)`` (element
            ``[1]`` is used as the diffusion coefficient) and ``func_b``.
        manifold: Manifold providing ``log_volume()`` and tangent projection.
        sampler_fn: Forward path sampler; its third return value must be a
            dict with ``"x_hist_all"``, ``"v_hist_all"`` and
            ``"converged_traj"``.
        nll_bs: Number of data points processed per batch.
        nll_K: Number of sampled trajectories per data point.
        device: Device to run on; defaults to ``data.device``.
        sde_kwargs: Extra keyword arguments forwarded to ``sampler_fn``.
            ``"mass"`` (default 1.0) and ``"gamma"`` (default 100.0) are read
            from it; the defaults are forwarded to the sampler as well.
        keep_quiet: Forwarded to ``sampler_fn``.
        return_mean: If True, return means over the data set (data without
            any converged trajectory are excluded); otherwise per-datum tensors.

    Returns:
        ``(nll, nll_upper, train_loss, train_loss_OBA, train_loss_BO)``: the
        importance-sampled NLL estimate, the Jensen upper bound, the training
        loss ½·Σₖ‖G_bwd‖² and its OBA / BO parts. With ``return_mean=False``
        the per-datum NLL is NaN for data that have no converged trajectory,
        and the other per-datum quantities are 0 for such data.
    """

    def masked_mean(mat, mask):
        """Row-wise mean of ``mat`` ([B, K]) over the entries selected by ``mask``.

        Rows whose mask is all zero give 0, or are dropped when ``return_mean``.
        """
        keep = mask.bool()
        valid = keep.any(dim=1)
        # masked_fill before weighting: NaN/inf values of non-converged paths
        # would survive a plain multiplication by a zero mask (NaN * 0 = NaN).
        weighted = mat.masked_fill(~keep, 0.0) * mask
        out = torch.zeros_like(mat[:, 0])
        if valid.any():
            out[valid] = weighted[valid].sum(1) / mask[valid].sum(1)
        return out[valid] if return_mean else out

    def per_step(coef, reps):
        """Broadcast a per-step coefficient [N] to [N * reps, 1] (time-major)."""
        return coef.repeat_interleave(reps).unsqueeze(-1)

    def per_path(step_values, n_data):
        """Sum per-step values (time-major, N * B * K entries) over time -> [B, K]."""
        return step_values.reshape(sde.N, -1).sum(0).reshape(n_data, nll_K)

    device = data.device if device is None else torch.device(device)
    data = data.to(device)
    network = network.to(device)

    # The same mass / friction are used for sampling and for the density
    # evaluated below, including when the defaults apply.
    sampler_kwargs = {"mass": 1.0, "gamma": 100.0, **(sde_kwargs or {})}
    mass = sampler_kwargs["mass"]
    gamma = sampler_kwargs["gamma"]

    dim = data.shape[-1]

    # Time grid; each step uses the diffusion coefficient at its start time.
    t_grid = torch.linspace(0., sde.T, sde.N + 1, device=device)
    dt_vec = torch.diff(t_grid)
    sigma_t = sde.sde(torch.zeros(1, 1, device=device), t_grid)[1]
    sigmas = sigma_t[:-1]

    # Per-step coefficients. NOTE(review): dt_vec[0] and sde.dt are both used
    # as the step size below, which assumes a uniform grid with dt == T / N.
    sigmas_sq = sigmas ** 2
    sigmas_sq_dt = sigmas_sq * dt_vec[0]
    a = torch.exp(-sigmas_sq_dt * gamma / (4. * mass))
    bar_a = 1.0 / a
    sqrt_1ma2 = torch.sqrt(1.0 - a ** 2)
    sqrt_bar_a2m1 = torch.sqrt(bar_a ** 2 - 1.0)

    # Log-volume prior term (uniform prior on the manifold).
    log_prior_x = -torch.tensor(manifold.log_volume(), device=device, dtype=torch.float32)

    # Per-step log ratio of normalisation constants; currently taken as zero.
    log_norm_ratio = torch.zeros_like(a)  # [N]

    nll_list, upper_nll_list = [], []
    train_loss_list, train_loss_OBA_list, train_loss_BO_list = [], [], []
    loader = DataLoader(torch.arange(data.size(0)), batch_size=nll_bs)

    for idx in loader:
        n_data = idx.size(0)
        bsz_nll_k = n_data * nll_K

        sigmas_sq_broad = per_step(sigmas_sq, bsz_nll_k)
        a_broad = per_step(a, bsz_nll_k)
        bar_a_broad = per_step(bar_a, bsz_nll_k)
        sqrt_1ma2_broad = per_step(sqrt_1ma2, bsz_nll_k)
        sqrt_bar_a2m1_broad = per_step(sqrt_bar_a2m1, bsz_nll_k)

        # Each datum is repeated nll_K times consecutively (datum-major).
        x0_rep = torch.repeat_interleave(data[idx], repeats=nll_K, dim=0)

        # Sample full forward paths.
        _, _, info = sampler_fn(
            sde, manifold, init_x=x0_rep, reverse=False, keep_quiet=keep_quiet, **sampler_kwargs)

        mask = info["converged_traj"].view(n_data, nll_K).float()  # [B, K]

        # Histories are indexed by time on dim 0; flatten steps to [N * B * K, dim].
        x_hist, v_hist = info["x_hist_all"], info["v_hist_all"]
        x_prev, x_next = x_hist[:-1].reshape(-1, dim), x_hist[1:].reshape(-1, dim)
        v_prev, v_next = v_hist[:-1].reshape(-1, dim), v_hist[1:].reshape(-1, dim)

        # Half-step velocity reconstructed from the position increment,
        # projected onto the tangent space at x_prev.
        v_mid = (2 * mass) / (sigmas_sq_broad * sde.dt) * (x_next - x_prev)
        v_mid = tang_proj(manifold, v_mid, base=x_prev)

        # Reference drift and learned score at both ends of every step.
        b_prev = sde.func_b(x_prev)
        b_next = sde.func_b(x_next)

        t_full = t_grid[1:].repeat_interleave(bsz_nll_k)  # [N * B * K]
        s_prev = network(torch.cat([x_prev, v_mid], dim=-1), label=t_full - sde.dt / 2).detach()
        s_next = network(torch.cat([x_next, v_next], dim=-1), label=t_full).detach()

        # Backward (model, p_theta) residuals with drift sigma^2 * (b - s).
        drift_bwd_prev = sigmas_sq_broad * (b_prev - s_prev)
        drift_bwd_next = sigmas_sq_broad * (b_next - s_next)
        G_bar_OBA = G_abs_underdamped_OBABO(manifold, x_prev, x_next, v_next, drift_bwd_next, sigmas_sq_broad, sde.dt, mass, a_broad, sqrt_1ma2_broad, mode="OBA", backward=True)
        G_bar_BO = G_abs_underdamped_OBABO(manifold, x_prev, x_next, v_prev, drift_bwd_prev, sigmas_sq_broad, sde.dt, mass, bar_a_broad, sqrt_bar_a2m1_broad, mode="BO", backward=True)
        G_bar_OBA_sq = (G_bar_OBA ** 2).sum(-1)  # [N * B * K]
        G_bar_BO_sq = (G_bar_BO ** 2).sum(-1)
        G_bwd_sq = G_bar_OBA_sq + G_bar_BO_sq

        # Forward (reference, q) residuals with drift sigma^2 * b; endpoints swapped.
        drift_fwd_prev = sigmas_sq_broad * b_prev
        drift_fwd_next = sigmas_sq_broad * b_next
        G_OBA = G_abs_underdamped_OBABO(manifold, x_next, x_prev, v_prev, drift_fwd_prev, sigmas_sq_broad, sde.dt, mass, a_broad, sqrt_1ma2_broad, mode="OBA", backward=False)
        G_BO = G_abs_underdamped_OBABO(manifold, x_next, x_prev, v_next, drift_fwd_next, sigmas_sq_broad, sde.dt, mass, bar_a_broad, sqrt_bar_a2m1_broad, mode="BO", backward=False)
        G_fwd_sq = (G_OBA ** 2).sum(-1) + (G_BO ** 2).sum(-1)

        # The module-level `temperature` scales the per-step residual difference.
        G_sq = temperature * (G_bwd_sq - G_fwd_sq).reshape(sde.N, -1)  # [N, B * K]

        # Log-importance weight of each path: sum over the N steps.
        log_w = (log_norm_ratio.reshape(-1, 1) - 0.5 * G_sq).sum(0)  # [B * K]

        # Velocity boundary term at t = 0 and t = T (tangent components only).
        v_N_tangent = tang_proj(manifold, v_next.reshape(sde.N, -1, dim)[-1], x_next.reshape(sde.N, -1, dim)[-1])
        v_0_tangent = tang_proj(manifold, v_prev.reshape(sde.N, -1, dim)[0], x_prev.reshape(sde.N, -1, dim)[0])
        # Only the quadratic terms appear (equal Gaussian normalisers for v_0
        # and v_N would cancel). NOTE(review): the factor 0.5 / mass
        # corresponds to variance `mass`; a kinetic energy K(v) = mass/2 ‖v‖²
        # would imply variance 1 / mass. They agree only for mass == 1 — confirm.
        log_prior_v = 0.5 / mass * ((v_0_tangent ** 2).sum(-1) - (v_N_tangent ** 2).sum(-1))

        log_prior = log_prior_x + log_prior_v  # [B * K]
        log_w = log_w + log_prior

        # Importance-sampled estimate over converged paths only:
        # log((1 / K_valid) * sum_k w_k), via logsumexp for numerical stability.
        log_w = log_w.reshape(n_data, nll_K).masked_fill(~mask.bool(), -torch.inf)
        valid_K = mask.sum(dim=1)
        # Clamp only the divisor: the raw count is needed to flag data that
        # have no converged path at all.
        log_w_bar = torch.logsumexp(log_w, dim=1) - torch.log(valid_K.clamp(min=1.0))
        nll_i = -log_w_bar
        nll_i[valid_K == 0] = torch.nan
        nll_list.append(nll_i)

        # Jensen upper bound: -E[log w] over converged paths.
        upper_i = (-masked_mean(log_prior.reshape(n_data, nll_K), mask)
                   - log_norm_ratio.sum()
                   + 0.5 * masked_mean(per_path(G_sq, n_data), mask))
        upper_nll_list.append(upper_i.flatten())

        # Training losses: 1/2 * sum over steps of ||G_bwd||^2 (total, OBA, BO).
        train_loss_list.append(0.5 * masked_mean(per_path(G_bwd_sq, n_data), mask))
        train_loss_OBA_list.append(0.5 * masked_mean(per_path(G_bar_OBA_sq, n_data), mask))
        train_loss_BO_list.append(0.5 * masked_mean(per_path(G_bar_BO_sq, n_data), mask))

    nll = torch.cat(nll_list)
    nll_up = torch.cat(upper_nll_list)
    train_loss = torch.cat(train_loss_list)
    train_loss_OBA = torch.cat(train_loss_OBA_list)
    train_loss_BO = torch.cat(train_loss_BO_list)

    if return_mean:
        # Data without converged paths carry NaN; drop them, just as
        # masked_mean drops them from the other quantities.
        return (nll[~torch.isnan(nll)].mean(), nll_up.mean(), train_loss.mean(),
                train_loss_OBA.mean(), train_loss_BO.mean())

    return nll, nll_up, train_loss, train_loss_OBA, train_loss_BO
