import logging
import torch
from tqdm import tqdm
from typing import Dict, Tuple


# Public samplers exported by ``from <module> import *``. Every sampler defined
# at module level is listed so wildcard imports expose the full set.
__all__ = [
    'SDE_sampler_manifolds_CLangevin',
    'SDE_sampler_manifolds_OLLA',
    'SDE_sampler_manifolds_CHMC_OBABO',
    'SDE_sampler_manifolds_CHMC_OABOA',
    'SDE_sampler_manifolds_ULLA_OABOA',
    'SDE_sampler_manifolds_CHMC_EM',
    'SDE_sampler_manifolds_ULLA_EM',
]

def _extract_mass_gamma(kwargs: Dict, *, device: torch.device):
    """Pull the ``mass`` and ``gamma`` hyper-parameters out of ``kwargs``.

    Both values are coerced with ``float``. A missing key falls back to its
    default (mass = 1.0, gamma = 100.0), and each fallback is reported through
    ``logging.info``.

    Args:
        kwargs: keyword arguments forwarded by a sampler.
        device: device on which the returned tensors are created.

    Returns:
        ``(mass, gamma)`` as tensors built with ``torch.as_tensor`` on ``device``.
    """
    has_mass = "mass" in kwargs
    has_gamma = "gamma" in kwargs

    # Coerce both values first, so a bad value raises before anything is logged.
    mass_value = float(kwargs["mass"]) if has_mass else 1.0
    gamma_value = float(kwargs["gamma"]) if has_gamma else 100.0

    if not has_mass:
        logging.info("mass not specified – defaulting to 1.0")
    if not has_gamma:
        logging.info("gamma not specified – defaulting to 100.0")

    mass_tensor = torch.as_tensor(mass_value, device=device)
    gamma_tensor = torch.as_tensor(gamma_value, device=device)
    return mass_tensor, gamma_tensor

# Module-level sampling temperature. Gaussian noise that is divided by
# sqrt(temperature) is left unscaled at the default of 1.0.
temperature: float = 1.0

@torch.no_grad()
def SDE_sampler_manifolds_CLangevin(sde,
                                    manifold,
                                    init,
                                    reverse,
                                    score_net=None,
                                    keep_quiet=False,
                                    **kwargs):
    """
    Constrained (projected) Euler–Maruyama Langevin sampler on a manifold.

    Each step builds the tangent increment
        P_x(drift * dt / 2 + sigma_t * sqrt(|dt|) * z),   z ~ N(0, I),
    and maps it back onto the manifold with
    ``manifold.project_onto_manifold_with_base``. A trajectory is kept only if
    that projection reports success at every step.

    Args:
        sde: provides ``N``, ``T``, ``sde(x, t)`` and ``reverse(score_net)``.
        manifold: provides ``project_onto_tangent_space`` and
            ``project_onto_manifold_with_base``.
        init: initial points, shape [bsz, dim] (higher-rank [bsz, ...]
            inputs are also handled).
        reverse: if True, integrate from T to 0 with the reverse SDE built from
            ``score_net``; otherwise integrate forward from 0 to T.
        score_net: score model, only used when ``reverse`` is True.
        keep_quiet: disables the progress bar and the final log message.

    Internal layout: timesteps is [N+1, bsz]; x_hist is [N+1, bsz, dim] with
    x_hist[0] == init.

    Returns:
        x: final points of the converged trajectories.
        x_hist: history of the converged trajectories, [N+1, n_converged, dim].
        other_dict: "tangent_vec_hist" [N, bsz, dim], "std_hist" [N, bsz],
            "converged_traj" [bsz] bool, "x_hist_all" [N+1, bsz, dim].
    """
    device = init.device
    shape = init.shape
    N = sde.N

    if reverse:
        rsde = sde.reverse(score_net)
        drift_diffusion_fn = rsde.sde
        timesteps = torch.linspace(sde.T, 0., N+1, device=device).reshape(-1, 1).repeat(1, shape[0])
    else:
        drift_diffusion_fn = sde.sde
        timesteps = torch.linspace(0., sde.T, N+1, device=device).reshape(-1, 1).repeat(1, shape[0])  # shape [N+1, bsz]

    x_hist = torch.zeros(N+1, *init.shape).to(device)
    std_hist = torch.zeros(N, init.shape[0]).to(device)
    tangent_vec_hist = torch.zeros(N, *init.shape).to(device)

    def update_fn_predictor(step, x, t, delta_t):
        """Advance ``x`` by one step and record std/tangent history at ``step``."""
        # diffusion = sigma_t, drift = sigma_t^2 * b(x, t)
        # std = sigma_t * sqrt(|delta_t|), shape [bsz, 1]
        delta_t = delta_t.reshape(-1, 1)
        z = torch.randn_like(x, device=device)
        if reverse:
            drift, diffusion = drift_diffusion_fn(x, score_t=t, diff_t=t + delta_t.flatten())
        else:
            drift, diffusion = drift_diffusion_fn(x, t)
        std = diffusion.reshape(-1, 1) * torch.sqrt(torch.abs(delta_t))

        # The per-sample coefficients are [bsz, 1]; view them as [bsz, 1, ..., 1]
        # so they broadcast against x of any rank.
        coef_shape = (-1,) + (1,) * (x.dim() - 1)
        increment = drift * delta_t.reshape(coef_shape) / 2 + std.reshape(coef_shape) * z
        tangent_vec = manifold.project_onto_tangent_space(increment, base_point=x)

        # converged_traj_one_step: [bsz] bool, True where the projection succeeded
        x_prime, converged_traj_one_step = manifold.project_onto_manifold_with_base(tangent_vec, base_point=x)

        std_hist[step] = std.reshape(-1).clone()
        tangent_vec_hist[step] = tangent_vec.clone()
        return x_prime.detach(), converged_traj_one_step

    x = init
    x_hist[0] = x.clone()
    converged_traj = torch.ones(init.shape[0], dtype=torch.bool).to(device)

    for i in tqdm(range(N), mininterval=2., disable=keep_quiet):
        vec_t = timesteps[i]
        dt = timesteps[i+1] - timesteps[i]
        x, converged_traj_one_step = update_fn_predictor(i, x, vec_t, dt)
        # converged_traj: [bsz] bool, True while every projection so far succeeded
        converged_traj = torch.logical_and(converged_traj, converged_traj_one_step)
        x_hist[i + 1] = x.clone()

    num_unconverged_traj = int(init.shape[0] - converged_traj.sum().item())
    if not keep_quiet:
        logging.info(f'{num_unconverged_traj} of {init.shape[0]} trajectories are dropped.')
    other_dict = {
        "tangent_vec_hist": tangent_vec_hist,
        "std_hist": std_hist,
        "converged_traj": converged_traj,
        "x_hist_all": x_hist,
    }
    return x[converged_traj, :], x_hist[:, converged_traj, :], other_dict

@torch.no_grad()
def SDE_sampler_manifolds_OLLA(sde,
                                manifold,
                                init,
                                reverse,
                                score_net=None,
                                keep_quiet=False,
                                **kwargs):
    """
    Overdamped Langevin sampler with a decaying constraint correction (OLLA).

    Each step builds the tangent increment
        P_x(drift * dt / 2 + sigma_t * sqrt(|dt|) * z),   z ~ N(0, I),
    and applies it through ``manifold.adding_correction_decaying`` with
    strength ``alpha``, instead of an exact projection. Intermediate steps
    therefore never mark a trajectory as failed. Only the final step calls
    ``manifold.project_onto_manifold_with_base`` (with a zero tangent vector)
    to land on the manifold. Trajectories whose final projection fails are
    dropped.

    Args:
        sde: provides ``N``, ``T``, ``sde(x, t)`` and ``reverse(score_net)``.
        manifold: provides ``project_onto_tangent_space``,
            ``adding_correction_decaying`` and ``project_onto_manifold_with_base``.
            If it also has ``report_violations``, that method's result for the
            current x is printed before every step, unless ``keep_quiet`` is set.
        init: initial points, shape [bsz, dim] (higher-rank [bsz, ...]
            inputs are also handled).
        reverse: if True, integrate from T to 0 with the reverse SDE built from
            ``score_net``; otherwise integrate forward from 0 to T.
        score_net: score model, only used when ``reverse`` is True.
        keep_quiet: disables the progress bar, the per-step violation report
            and the final log message.
        **kwargs: ``alpha`` (default 100.0), the correction strength.

    Returns:
        x: final points of the converged trajectories.
        x_hist: history of the converged trajectories, [N+1, n_converged, dim].
        other_dict: "tangent_vec_hist" [N, bsz, dim], "std_hist" [N, bsz],
            "converged_traj" [bsz] bool, "x_hist_all" [N+1, bsz, dim].
    """
    device = init.device
    shape = init.shape
    N = sde.N

    if 'alpha' in kwargs:
        alpha = kwargs['alpha']
    else:
        alpha = 100.0
        logging.info(f'Warning: alpha is not specified. Set alpha = {alpha}.')

    if reverse:
        rsde = sde.reverse(score_net)
        drift_diffusion_fn = rsde.sde
        timesteps = torch.linspace(sde.T, 0., N+1, device=device).reshape(-1, 1).repeat(1, shape[0])
    else:
        drift_diffusion_fn = sde.sde
        timesteps = torch.linspace(0., sde.T, N+1, device=device).reshape(-1, 1).repeat(1, shape[0])  # shape [N+1, bsz]

    x_hist = torch.zeros(N+1, *init.shape).to(device)
    std_hist = torch.zeros(N, init.shape[0]).to(device)
    tangent_vec_hist = torch.zeros(N, *init.shape).to(device)

    def update_fn_predictor(step, x, t, delta_t, last=False):
        """Advance ``x`` by one OLLA step and record the history at ``step``.

        On the last step the result is additionally projected onto the
        manifold, and the projection's success flags are returned. Otherwise
        the flags are all True.
        """
        delta_t = delta_t.reshape(-1, 1)
        z = torch.randn_like(x, device=device)
        # drift = sigma_t^2 * b(x, t), diffusion = sigma_t
        if reverse:
            drift, diffusion = drift_diffusion_fn(x, score_t=t, diff_t=t + delta_t.flatten())
        else:
            drift, diffusion = drift_diffusion_fn(x, t)
        diffusion = diffusion.reshape(-1, 1)  # [bsz, 1]
        std = diffusion * torch.sqrt(torch.abs(delta_t))  # [bsz, 1]

        # View the per-sample coefficients as [bsz, 1, ..., 1] so they broadcast
        # against x of any rank.
        coef_shape = (-1,) + (1,) * (x.dim() - 1)
        increment = drift * delta_t.reshape(coef_shape) / 2 + std.reshape(coef_shape) * z
        tangent_vec = manifold.project_onto_tangent_space(increment, base_point=x)

        x_prime = manifold.adding_correction_decaying(tangent_vec, base_point=x, delta_t=delta_t, alpha=alpha, sigma_sq=diffusion**2)
        converged_traj_one_step = torch.ones(x_prime.shape[0], dtype=torch.bool, device=device)
        if last:
            # Final snap onto the manifold: project with a zero tangent step.
            tangent_vec_zero = torch.zeros_like(tangent_vec, device=device)
            x_prime, converged_traj_one_step = manifold.project_onto_manifold_with_base(tangent_vec_zero, base_point=x_prime)

        std_hist[step] = std.reshape(-1).clone()
        tangent_vec_hist[step] = tangent_vec.clone()
        return x_prime.detach(), converged_traj_one_step

    x = init
    x_hist[0] = x.clone()
    # Intermediate OLLA steps always succeed; only the final projection can fail.
    converged_traj = torch.ones(init.shape[0], dtype=torch.bool).to(device)
    report_violations = (not keep_quiet) and hasattr(manifold, 'report_violations')

    for i in tqdm(range(N), mininterval=2., disable=keep_quiet):
        vec_t = timesteps[i]
        dt = timesteps[i+1] - timesteps[i]

        if report_violations:
            violation_str = manifold.report_violations(x)
            print(f"Step {i+1}/{N}: {violation_str}")

        x, converged_traj_one_step = update_fn_predictor(i, x, vec_t, dt, last=(i == N - 1))
        # The flags are all True except on the last step, so AND-ing every step is harmless.
        converged_traj = torch.logical_and(converged_traj, converged_traj_one_step)
        x_hist[i + 1] = x.clone()

    num_unconverged_traj = int(init.shape[0] - converged_traj.sum().item())
    if not keep_quiet:
        logging.info(f'{num_unconverged_traj} of {init.shape[0]} trajectories are dropped.')
    other_dict = {
        "tangent_vec_hist": tangent_vec_hist,
        "std_hist": std_hist,
        "converged_traj": converged_traj,
        "x_hist_all": x_hist,
    }
    return x[converged_traj, :], x_hist[:, converged_traj, :], other_dict

@torch.no_grad()
def SDE_sampler_manifolds_CHMC_OBABO(
        sde,                       # SDE_Brownian_manifolds instance
        manifold,                  # manifold object with project_* helpers
        init_x,                    # tensor [bsz, dim] – data minibatch
        reverse: bool,             # False = forward  (training / NLL),  True  = backward (sampling / generation)
        score_net=None,            # needed only when reverse=True
        init_v=None,               # optional initial velocity
        keep_quiet: bool = False,
        **kwargs):
    """
    Constrained underdamped (CHMC) sampler with an O-BAB-O style splitting.

    One step with signed step dt, per-sample diffusion sigma and mass m:
        a      = exp(-sigma^2 |dt| gamma / (4 m))
        v'     = P_x(a v + sqrt(m) sqrt(1 - a^2) z1 + dt/4 * drift(x, v))
        x_next = project_onto_manifold_with_base(sigma^2 dt / (2 m) * v', base=x)
        v_mid  = P_{x_next}(2 m / (sigma^2 dt) * (x_next - x))
        v_next = P_{x_next}(a v_mid + sqrt(m) sqrt(1 - a^2) z2
                            + a dt/4 * drift(x_next, v_mid))
    Here z1, z2 ~ N(0, I) / sqrt(temperature), and P_x projects onto the tangent
    space at x.

    Forward (reverse=False): drift and sigma come from ``sde.sde(x, t)``.
    Backward (reverse=True): they come from
    ``sde.reverse(score_net, underdamped=True).sde(x, v, score_t=..., diff_t=...)``,
    which requires ``score_net``.

    kwargs: ``mass`` (default 1.0), ``gamma`` (default 100.0).
    ``keep_quiet`` silences the progress bar, the per-step projection warnings,
    the late-stage velocity-norm diagnostics (reverse only) and the final log.

    Returns
    -------
    x_v      : [n_conv, 2*dim]       torch.cat([x, v], -1) of converged trajectories
    x_v_hist : [N+1, n_conv, 2*dim]  matching history
    other    : dict with "converged_traj" [bsz] bool,
               "x_hist_all" and "v_hist_all" [N+1, bsz, dim]
    """
    mass, gamma = _extract_mass_gamma(kwargs, device=init_x.device)

    device = init_x.device
    shape = init_x.shape
    N = sde.N
    sqrt_mass = mass ** 0.5

    if reverse:
        rsde = sde.reverse(score_net, underdamped=True)
        drift_diffusion_fn = rsde.sde
        timesteps = torch.linspace(sde.T, 0., N+1, device=device).reshape(-1, 1).repeat(1, shape[0])
    else:
        drift_diffusion_fn = sde.sde
        timesteps = torch.linspace(0., sde.T, N+1, device=device).reshape(-1, 1).repeat(1, shape[0])  # shape [N+1, bsz]

    def update_fn_predictor(x, v, t, delta_t):
        """One O-BAB-O step; returns (x_next, v_next, projection success flags)."""
        delta_t = delta_t.reshape(-1, 1)  # [bsz, 1]

        # ── O + half B at (x, v) ──────────────────────────────────────────────
        if reverse:
            drift, diffusion = drift_diffusion_fn(x, v, score_t=t, diff_t=t + delta_t.flatten())
        else:
            drift, diffusion = drift_diffusion_fn(x, t)
        diffusion = diffusion.reshape(-1, 1)  # [bsz, 1]

        # OU coefficient; the same `a` is reused for the closing half-step.
        a = torch.exp(-diffusion**2 * torch.abs(delta_t) * gamma / (4. * mass))
        sqrt_1ma2 = torch.sqrt(torch.abs(1. - a**2))

        z_1 = torch.randn_like(x, device=device) / (temperature ** 0.5)
        v_mid_pre = a * v + sqrt_mass * sqrt_1ma2 * z_1 + (delta_t / 4) * drift
        v_mid_pre = manifold.project_onto_tangent_space(v_mid_pre, base_point=x)
        step_vec1 = (diffusion ** 2 * delta_t / (2 * mass)) * v_mid_pre

        # ── A: constrained position move (the Newton projection may fail per sample) ─
        x_next, converged_traj_one_step = manifold.project_onto_manifold_with_base(step_vec1, base_point=x)

        if not keep_quiet:
            num_failed = int(x_next.shape[0] - converged_traj_one_step.sum().item())
            if num_failed != 0:
                print(f"Warning: {num_failed} of {x_next.shape[0]} trajectories are not projected successfully.")

        # Velocity implied by the constrained displacement.
        v_mid = (2 * mass) / (diffusion ** 2 * delta_t) * (x_next - x)
        v_mid = manifold.project_onto_tangent_space(v_mid, base_point=x_next)

        # ── half B + O at (x_next, v_mid) ─────────────────────────────────────
        if reverse:
            drift_next, _ = drift_diffusion_fn(x_next, v_mid, score_t=t + delta_t.flatten() / 2, diff_t=t + delta_t.flatten())
        else:
            drift_next, _ = drift_diffusion_fn(x_next, t)

        z_2 = torch.randn_like(x_next, device=device) / (temperature ** 0.5)
        v_next_pre = a * v_mid + sqrt_mass * sqrt_1ma2 * z_2 + a * (delta_t / 4) * drift_next
        v_next = manifold.project_onto_tangent_space(v_next_pre, base_point=x_next)

        return x_next.detach(), v_next.detach(), converged_traj_one_step

    x_hist = torch.zeros(N+1, *init_x.shape, device=device)
    v_hist = torch.zeros(N+1, *init_x.shape, device=device)
    x = init_x
    if init_v is None:
        # Gaussian velocity with variance = mass; made tangent just below.
        init_v = sqrt_mass * torch.randn_like(x, device=device)
    v = manifold.project_onto_tangent_space(init_v, base_point=x)
    x_hist[0] = x.clone()
    v_hist[0] = v.clone()

    converged_traj = torch.ones(init_x.shape[0], dtype=torch.bool).to(device)

    for i in tqdm(range(N), mininterval=2., disable=keep_quiet):
        vec_t = timesteps[i]
        dt = timesteps[i+1] - timesteps[i]
        if reverse and not keep_quiet and i > int(N * 0.8):
            print(f"norm of v at step {i}: {v.norm(dim=-1).mean().item():.4f}")
        x, v, converged_traj_one_step = update_fn_predictor(x, v, vec_t, dt)

        converged_traj = torch.logical_and(converged_traj, converged_traj_one_step)
        x_hist[i + 1] = x.clone()
        v_hist[i + 1] = v.clone()

    num_unconverged_traj = int(init_x.shape[0] - converged_traj.sum().item())
    if not keep_quiet:
        logging.info(f'{num_unconverged_traj} of {init_x.shape[0]} trajectories are dropped.')

    other_dict = {
        "converged_traj": converged_traj,
        "x_hist_all": x_hist,
        "v_hist_all": v_hist,
    }

    # Concatenate position and velocity for the final return values.
    x_v = torch.cat([x, v], dim=-1)
    x_v_hist = torch.cat([x_hist, v_hist], dim=-1)

    return x_v[converged_traj, :], x_v_hist[:, converged_traj, :], other_dict
        
@torch.no_grad()
def SDE_sampler_manifolds_CHMC_OABOA(
        sde,                      # SDE_Brownian_manifolds instance
        manifold,                 # manifold helper with project_* methods
        init_x,                   # [bsz, dim] initial positions on Σ
        reverse: bool,            # False → forward   True → backward
        score_net=None,           # required when reverse=True
        init_v=None,              # optional initial velocities
        keep_quiet: bool = False,
        **kwargs):
    """
    Constrained HMC sampler with the non-symmetric OA-BOA splitting (URDDPM-OABOA).

    One step with signed step dt and mass m:
        a       = exp(-s^2 |dt| gamma / (4 m)), where s comes from
                  sde.get_diffusion (at t + dt when reverse, else at t)
        OA:  w       = P_x(a v + sqrt(m) sqrt(1 - a^2) z1 + dt/2 * drift(x, v))
             x_half  = project_onto_manifold_with_base(sigma^2 dt / (4 m) * w, base=x)
             v_third = P_{x_half}(4 m / (sigma^2 dt) * (x_half - x))
        BOA: w'      = P_{x_half}(a v_third + sqrt(m) sqrt(1 - a^2) z2
                                  + dt/2 * drift(x_half, v_third))
             x_next  = project_onto_manifold_with_base(sigma^2 dt / (4 m) * w', base=x_half)
             v_next  = P_{x_next}(4 m / (sigma^2 dt) * (x_next - x_half))
    Here sigma and drift come from ``sde.sde`` (forward) or from
    ``sde.reverse(score_net, underdamped=True).sde`` (reverse), and z1, z2 ~ N(0, I).
    A trajectory is dropped if either projection ever fails.

    kwargs:
        mass (default 1.0), gamma (default 100.0).
        reparam (default False): if True, each drift is rescaled by
            2 sqrt(m) sqrt(1 - a^2) / (sigma^2 |dt|) before use.

    Returns:
        x_v:      torch.cat([x, x_mid, v], dim=-1) of converged trajectories,
                  i.e. [n_conv, 3*dim].
        x_v_hist: matching history, [N+1, n_conv, 3*dim].
        other:    dict with "converged_traj" [bsz] bool, "x_hist_all",
                  "x_mid_hist_all" (row 0 is left as zeros) and "v_hist_all",
                  each [N+1, bsz, dim].
    """
    # ── parameters ──────────────────────────────────────────────────────────────
    mass, gamma = _extract_mass_gamma(kwargs, device=init_x.device)
    reparam = bool(kwargs.get("reparam", False))
    sqrt_mass = mass ** 0.5
    device = init_x.device
    N = sde.N

    # choose drift / diffusion operator -------------------------------------------------
    if reverse:
        rsde = sde.reverse(score_net, underdamped=True)
        drift_diffusion_fn = rsde.sde
        timesteps = torch.linspace(sde.T, 0., N + 1, device=device).reshape(-1, 1).repeat(1, init_x.shape[0])
    else:
        drift_diffusion_fn = sde.sde
        timesteps = torch.linspace(0., sde.T, N + 1, device=device).reshape(-1, 1).repeat(1, init_x.shape[0])

    # ── helper: one OABOA step ─────────────────────────────────────────────────
    def update_fn_oaboa(x, v, t, delta_t):
        """
        x, v:    [bsz, dim] current state on the manifold & tangent bundle
        t:       [bsz]      current time
        delta_t: [bsz]      signed step (next - current)

        Returns (x_next, x_half, v_next, success flags of both projections).
        """
        delta_t = delta_t.reshape(-1, 1)                   # [bsz,1]
        t_next = t + delta_t.flatten()

        # The OU coefficient uses the scheduler's diffusion, not the SDE call's.
        diffusion = sde.get_diffusion(t_next) if reverse else sde.get_diffusion(t)
        diffusion = diffusion.reshape(-1, 1)               # [bsz,1]

        # exact OU coefficient a_t = exp(-σ² γ Δt /(4m))
        a = torch.exp(- diffusion ** 2 * torch.abs(delta_t) * gamma / (4. * mass))
        sqrt_1ma2 = torch.sqrt(torch.clamp(1. - a ** 2, min=0))

        if reverse:
            drift, diffusion = drift_diffusion_fn(x, v, score_t=t, diff_t=t_next)
        else:
            drift, diffusion = drift_diffusion_fn(x, t)
        diffusion = diffusion.reshape(-1, 1)               # [bsz,1]
        if reparam:
            pre_factor = 2 * sqrt_mass * sqrt_1ma2 / (diffusion ** 2 * torch.abs(delta_t))
            drift = drift * pre_factor

        # ── OA block ────────────────────────────────────────────────────────────
        z1 = torch.randn_like(x)
        oa_vel = manifold.project_onto_tangent_space(a * v + sqrt_mass * sqrt_1ma2 * z1 + (delta_t / 2) * drift, base_point=x)
        step_vec_oa = (diffusion ** 2 * delta_t / (4. * mass)) * oa_vel
        x_half, conv_oa = manifold.project_onto_manifold_with_base(step_vec_oa, base_point=x)

        # auxiliary velocity implied by the constrained OA displacement
        v_third = (4. * mass / (diffusion ** 2 * delta_t)) * (x_half - x)
        v_third = manifold.project_onto_tangent_space(v_third, base_point=x_half)

        if reverse:
            drift, diffusion = drift_diffusion_fn(x_half, v_third, score_t=t + delta_t.flatten() / 2, diff_t=t_next)
        else:
            drift, diffusion = drift_diffusion_fn(x_half, t)
        diffusion = diffusion.reshape(-1, 1)               # [bsz,1]
        if reparam:
            pre_factor = 2 * sqrt_mass * sqrt_1ma2 / (diffusion ** 2 * torch.abs(delta_t))
            drift = drift * pre_factor

        # ── BOA block ─────────────────────────────────────────────────────────────────
        z2 = torch.randn_like(x)
        boa_vel = manifold.project_onto_tangent_space(a * v_third + sqrt_mass * sqrt_1ma2 * z2 + (delta_t / 2) * drift, base_point=x_half)
        step_vec_boa = (diffusion ** 2 * delta_t / (4. * mass)) * boa_vel
        x_next, conv_boa = manifold.project_onto_manifold_with_base(step_vec_boa, base_point=x_half)

        # new velocity from the constrained BOA displacement
        v_next = (4. * mass / (diffusion ** 2 * delta_t)) * (x_next - x_half)
        v_next = manifold.project_onto_tangent_space(v_next, base_point=x_next)

        return x_next.detach(), x_half.detach(), v_next.detach(), torch.logical_and(conv_oa, conv_boa)

    # ── initialise tensors ─────────────────────────────────────────────────────
    x_hist = torch.zeros(N + 1, *init_x.shape, device=device)
    x_mid_hist = torch.zeros(N + 1, *init_x.shape, device=device)
    v_hist = torch.zeros(N + 1, *init_x.shape, device=device)

    x = init_x
    if init_v is None:
        init_v = sqrt_mass * torch.randn_like(x)
    v = manifold.project_onto_tangent_space(init_v, base_point=x)
    x_hist[0], v_hist[0] = x.clone(), v.clone()

    converged = torch.ones(init_x.shape[0], dtype=torch.bool, device=device)

    # ── main loop ──────────────────────────────────────────────────────────────
    for i in tqdm(range(N), mininterval=2., disable=keep_quiet):
        dt = timesteps[i + 1] - timesteps[i]
        x, x_mid, v, conv_step = update_fn_oaboa(x, v, timesteps[i], dt)
        converged &= conv_step
        x_hist[i + 1], x_mid_hist[i + 1], v_hist[i + 1] = x.clone(), x_mid.clone(), v.clone()

    if not keep_quiet:
        logging.info(f"{(init_x.shape[0] - converged.sum()).item()} of {init_x.shape[0]} "
                     f"trajectories were dropped due to projection failure.")

    other = {
        "converged_traj": converged,
        "x_hist_all": x_hist,
        "x_mid_hist_all": x_mid_hist,
        "v_hist_all": v_hist
    }

    # concatenate position, half-step position & velocity along the last axis
    x_v = torch.cat([x, x_mid, v], dim=-1)
    x_v_hist = torch.cat([x_hist, x_mid_hist, v_hist], dim=-1)
    return x_v[converged], x_v_hist[:, converged], other

@torch.no_grad()
def SDE_sampler_manifolds_ULLA_OABOA(
        sde,                      # SDE_Brownian_manifolds instance
        manifold,                 # manifold helper with project_* methods
        init_x,                   # [bsz, dim] initial positions on Σ
        reverse: bool,            # False → forward   True → backward
        score_net=None,           # required when reverse=True
        init_v=None,              # optional initial velocities
        keep_quiet: bool = False,
        **kwargs):
    """
    Underdamped sampler with OA-BOA splitting and a momentum-based constraint
    correction (ULLA) in place of per-step manifold projection.

    Each half-step forms the tangent velocity
        P(a v + sqrt(m) sqrt(1 - a^2) z + tau * dt * drift),
        a = exp(-sigma^2 |dt| gamma / (4 m)),
    and passes it through ``manifold.adding_correction_decaying_momentum``.
    It then moves the position by the plain step sigma^2 dt / (4 m) * velocity,
    with no projection. tau comes from ``sde.get_tau_scheduler``.

    Only the last step calls ``manifold.project_onto_manifold_with_base`` (with
    a zero tangent vector). That is the only point at which a trajectory can be
    marked as failed.

    kwargs: ``mass`` (default 1.0), ``gamma`` (default 100.0), ``alpha``
    (default 100.0; passed to ``adding_correction_decaying_momentum``).

    Returns:
        x_v:      torch.cat([x, x_mid, v], dim=-1) of converged trajectories.
        x_v_hist: matching [N+1, n_conv, 3*dim] history (row 0 of the x_mid
                  part is left as zeros).
        other:    dict with "converged_traj", "x_hist_all", "x_mid_hist_all",
                  "v_hist_all".
    """

    # ── parameters ──────────────────────────────────────────────────────────────
    mass, gamma = _extract_mass_gamma(kwargs, device=init_x.device)
    if kwargs.__contains__('alpha'):
        alpha = kwargs['alpha']
    else:
        alpha = 100.0
        logging.info(f'Warning: alpha is not specified. Set alpha = {alpha}.')
    sqrt_mass   = mass ** 0.5
    device      = init_x.device
    N           = sde.N

    # choose drift / diffusion operator -------------------------------------------------
    if reverse:
        rsde = sde.reverse(score_net, underdamped=True)
        drift_diffusion_fn = rsde.sde
        timesteps = torch.linspace(sde.T, 0., N + 1, device=device).reshape(-1, 1).repeat(1, init_x.shape[0])
    else:
        drift_diffusion_fn = sde.sde
        timesteps = torch.linspace(0., sde.T, N + 1, device=device).reshape(-1, 1).repeat(1, init_x.shape[0])

    # ── helper: one OABOA step ─────────────────────────────────────────────────
    def update_fn_oaboa(x, v, t, delta_t, last = False):
        """
        x, v:   [bsz, dim] current state on the manifold & tangent bundle
        t:      [bsz]      current time
        delta_t:[bsz]      signed step (next - current)
        last:   if True, snap x_next onto the manifold with a zero-step projection

        Returns (x_next, x_half, v_next, success flags). The flags come only from
        the final projection; they are all True when ``last`` is False.
        """
        delta_t = delta_t.reshape(-1, 1)                   # [bsz,1]

        # scheduler diffusion: used only to build the OU coefficient `a` below
        diffusion = sde.get_diffusion(t + delta_t.flatten()) if reverse else sde.get_diffusion(t)  
        diffusion = diffusion.reshape(-1, 1)               # [bsz,1]   
        # tau scales the drift term in the OA velocity update
        tau = sde.get_tau_scheduler(t + delta_t.flatten()) if reverse else sde.get_tau_scheduler(t)
        tau = tau.reshape(-1, 1)

        # exact OU coefficient a_t = exp(-σ² γ Δt /(4m))
        a = torch.exp(- diffusion ** 2 * torch.abs(delta_t) * gamma / (4. * mass))
        sqrt_1ma2 = torch.sqrt(torch.clamp(1. - a ** 2, min=0))

        # overwrites `diffusion` with the value returned by the SDE for the position step
        drift, diffusion = drift_diffusion_fn(x, v, score_t=t, diff_t=t + delta_t.flatten()) if reverse else drift_diffusion_fn(x, t)
        diffusion = diffusion.reshape(-1, 1)               # [bsz,1]

        # ── OA block ────────────────────────────────────────────────────────────
        z1 = torch.randn_like(x)
        oa_vel = manifold.project_onto_tangent_space(a * v + sqrt_mass * sqrt_1ma2 * z1 + (tau * delta_t) * drift, base_point=x)

        # constraint correction applied to the velocity (semantics defined by the manifold helper)
        oa_vel = manifold.adding_correction_decaying_momentum(oa_vel, base_point=x, base_momentum=v, delta_t=delta_t, alpha=alpha, sigma_sq=diffusion ** 2, mass=mass)

        # plain (unprojected) position update, so it cannot fail
        step_vec_oa = (diffusion ** 2 * delta_t / (4. * mass)) * oa_vel
        x_half = x + step_vec_oa
        conv_oa = torch.ones(x_half.shape[0], dtype=torch.bool, device=device)

        # # ── OA block ────────────────────────────────────────────────────────────
        # z1 = torch.randn_like(x)
        # oa_vel = manifold.project_onto_tangent_space(a * v + sqrt_mass * sqrt_1ma2 * z1, base_point=x)
        # step_vec_oa = (diffusion ** 2 * delta_t / (4. * mass)) * oa_vel
        # x_half, conv_oa = manifold.project_onto_manifold_with_base(step_vec_oa, base_point=x)


        # auxiliary velocity after OA, implied by the displacement
        v_third = (4. * mass / (diffusion ** 2 * delta_t)) * (x_half - x)
        v_third = manifold.project_onto_tangent_space(v_third, base_point=x_half)

        drift, diffusion = drift_diffusion_fn(x_half, v_third, score_t = t + delta_t.flatten() / 2, diff_t=t + delta_t.flatten()) if reverse else drift_diffusion_fn(x_half, t)
        diffusion = diffusion.reshape(-1, 1)               # [bsz,1]
        tau = sde.get_tau_scheduler(t + delta_t.flatten()) if reverse else sde.get_tau_scheduler(t)
        tau = tau.reshape(-1, 1)

        # ── BOA block ─────────────────────────────────────────────────────────────────
        z2 = torch.randn_like(x)
        boa_vel = manifold.project_onto_tangent_space(a * v_third + sqrt_mass * sqrt_1ma2 * z2 + (tau * delta_t) * drift, base_point=x_half)
        boa_vel = manifold.adding_correction_decaying_momentum(boa_vel, base_point=x_half, base_momentum=v_third, delta_t=delta_t, alpha=alpha, sigma_sq=diffusion ** 2, mass=mass)

        step_vec_boa = (diffusion ** 2 * delta_t / (4. * mass)) * boa_vel
        x_next = x_half + step_vec_boa
        conv_boa = torch.ones(x_half.shape[0], dtype=torch.bool, device=device)

        # new velocity from displacement
        v_next = (4. * mass / (diffusion ** 2 * delta_t)) * (x_next - x_half)
        v_next = manifold.project_onto_tangent_space(v_next, base_point=x_next)

        if last:
            # NOTE(review): v_next above was computed from the unprojected x_next;
            # it is not recomputed after this final projection.
            tangent_vec_zero = torch.zeros_like(v_next, device=device)  # last step, no correction
            x_next, conv_boa = manifold.project_onto_manifold_with_base(tangent_vec_zero, base_point=x_next)

            # x_next = manifold.force_project_SO(x_next)
            # conv_boa = torch.ones(x_next.shape[0], dtype=torch.bool, device=device)

        return x_next.detach(), x_half.detach(), v_next.detach(), torch.logical_and(conv_oa, conv_boa)

    # ── initialise tensors ─────────────────────────────────────────────────────
    # x_mid_hist[0] is never written and stays zero.
    x_hist = torch.zeros(N + 1, *init_x.shape, device=device)
    x_mid_hist = torch.zeros(N + 1, *init_x.shape, device=device)
    v_hist = torch.zeros(N + 1, *init_x.shape, device=device)

    x      = init_x
    if init_v is None:
        init_v = sqrt_mass * torch.randn_like(x)
    v      = manifold.project_onto_tangent_space(init_v, base_point=x)
    x_hist[0], v_hist[0] = x.clone(), v.clone()

    converged = torch.ones(init_x.shape[0], dtype=torch.bool, device=device)

    # ── main loop ──────────────────────────────────────────────────────────────
    for i in tqdm(range(N), mininterval=2., disable=keep_quiet):

        dt = timesteps[i + 1] - timesteps[i]
        if i == N-1:
            # only the final step projects onto the manifold, so only it can fail
            x, x_mid, v, conv_step = update_fn_oaboa(x, v, timesteps[i], dt, last=True)
            converged &= conv_step
        else:
            x, x_mid, v, conv_step = update_fn_oaboa(x, v, timesteps[i], dt, last=False)
        
        # if i > int(N * 0.95) and reverse:
        #     print(f"norm of x at step {i}: {torch.linalg.norm(x, dim=-1).mean().item():.4f}")

        x_hist[i + 1], x_mid_hist[i + 1], v_hist[i + 1] = x.clone(), x_mid.clone(), v.clone()

    if not keep_quiet:
        logging.info(f"{(init_x.shape[0] - converged.sum()).item()} of {init_x.shape[0]} "
                     f"trajectories were dropped due to projection failure.")

    other = {
        "converged_traj": converged,
        "x_hist_all": x_hist,
        "x_mid_hist_all": x_mid_hist,
        "v_hist_all": v_hist
    }

    # concat position, half-step position & velocity along the last axis
    x_v      = torch.cat([x, x_mid, v], dim=-1)
    x_v_hist = torch.cat([x_hist, x_mid_hist, v_hist], dim=-1)
    return x_v[converged], x_v_hist[:, converged], other


@torch.no_grad()
def SDE_sampler_manifolds_CHMC_EM(sde,
                                    manifold,
                                    init,
                                    reverse,
                                    init_v=None,
                                    score_net=None,
                                    keep_quiet=False,
                                    **kwargs):
    """
    Constrained underdamped sampler with an Euler–Maruyama-style position update.

    The velocity is not carried as state. It is rebuilt every step from the
    previous displacement,
        v = P_x((x - x_prev) / (sigma_prev^2 * dt)).
    On the first step, v is ``init_v`` if given; otherwise it is a
    tangent-projected standard Gaussian. The step is then
        a      = exp(-sigma^2 |dt| gamma)
        w      = a v + sigma^2 |dt| * drift(x, v, t) + sqrt(|1 - a^2|) z,   z ~ N(0, I)
        x_next = project_onto_manifold_with_base(sigma^2 dt * P_x(w), base=x)
    where P_x projects onto the tangent space at x.

    Forward:  drift = sde.drift_b(x); sigma = sde.get_diffusion(t).
    Backward: drift = rsde.drift_score(x, v, t) - sde.drift_b(x), with
              rsde = sde.reverse(score_net, underdamped=True);
              sigma = sde.get_diffusion(t + dt).

    kwargs: ``gamma`` (default 100.0). ``mass`` is read, but not used here.

    Shapes: init [bsz, dim]; timesteps [N+1, bsz]; per-step t and dt are [bsz].

    Returns:
        x:          [n_conv, dim] final positions of converged trajectories.
        x_hist:     [N+1, n_conv, dim] their history.
        other_dict: "converged_traj" [bsz] bool, "x_hist_all" [N+1, bsz, dim].
    """
    device = init.device
    shape = init.shape
    N = sde.N

    _, gamma = _extract_mass_gamma(kwargs, device=init.device)

    if reverse:
        rsde = sde.reverse(score_net, underdamped=True)
        timesteps = torch.linspace(sde.T, 0., N+1, device=device).reshape(-1, 1).repeat(1, shape[0])
    else:
        timesteps = torch.linspace(0., sde.T, N+1, device=device).reshape(-1, 1).repeat(1, shape[0])  # shape [N+1, bsz]

    def update_fn_predictor(x, x_back, t, delta_t, initial_step=False):
        """One step from x; x_back is the previous position (ignored on the first step)."""
        z = torch.randn_like(x, device=device)
        if reverse:
            # delta_t < 0
            sigma = sde.get_diffusion(t + delta_t)
            sigma_back = sde.get_diffusion(t)

            def drift_fn(x, v, t):
                return rsde.drift_score(x, v, t) - sde.drift_b(x)
        else:
            # delta_t > 0
            sigma = sde.get_diffusion(t)
            sigma_back = sde.get_diffusion(t - delta_t)

            def drift_fn(x, v=None, t=None):
                return sde.drift_b(x)

        # Bring every per-sample quantity to [bsz, 1] *before* combining them, so
        # that a [bsz] and a [bsz, 1] input cannot broadcast into a [bsz, bsz] tensor.
        delta_t = delta_t.reshape(-1, 1)
        sigma = sigma.reshape(-1, 1)
        sigma_back = sigma_back.reshape(-1, 1)

        a = torch.exp(- (sigma ** 2) * torch.abs(delta_t) * gamma)  # [bsz, 1]
        sqrt_1ma2 = torch.sqrt(torch.abs(1. - a**2))                # [bsz, 1]

        if initial_step:
            if init_v is not None:
                v = init_v
            else:
                v = manifold.project_onto_tangent_space(torch.randn_like(x, device=device), base_point=x)
        else:
            # velocity implied by the previous step's displacement
            v = (x - x_back) / (sigma_back ** 2 * delta_t)  # [bsz, dim]
            v = manifold.project_onto_tangent_space(v, base_point=x)

        mu_f_pre_w_noise = a * v + (sigma ** 2) * torch.abs(delta_t) * drift_fn(x, v, t) + sqrt_1ma2 * z
        step_vec = (sigma ** 2) * delta_t * manifold.project_onto_tangent_space(mu_f_pre_w_noise, base_point=x)
        x_prime, converged_traj_one_step = manifold.project_onto_manifold_with_base(step_vec, base_point=x)

        return x_prime.detach(), converged_traj_one_step

    x_hist = torch.zeros(N+1, *init.shape).to(device)
    x = init
    x_hist[0] = x.clone()

    converged_traj = torch.ones(init.shape[0], dtype=torch.bool).to(device)

    for i in tqdm(range(N), mininterval=2., disable=keep_quiet):
        vec_t = timesteps[i]
        dt = timesteps[i+1] - timesteps[i]
        is_first = i == 0
        x_back = x if is_first else x_hist[i - 1]
        x, converged_traj_one_step = update_fn_predictor(x, x_back, vec_t, dt, initial_step=is_first)

        # converged_traj: [bsz] bool, True while every projection so far succeeded
        converged_traj = torch.logical_and(converged_traj, converged_traj_one_step)
        x_hist[i + 1] = x.clone()

    num_unconverged_traj = int(init.shape[0] - converged_traj.sum().item())
    if not keep_quiet:
        logging.info(f'{num_unconverged_traj} of {init.shape[0]} trajectories are dropped.')
    other_dict = {"converged_traj": converged_traj, "x_hist_all": x_hist}
    return x[converged_traj, :], x_hist[:, converged_traj, :], other_dict


@torch.no_grad()
def SDE_sampler_manifolds_ULLA_EM(sde,
                                    manifold,
                                    init,
                                    reverse,
                                    init_v=None,
                                    score_net=None,
                                    keep_quiet=False,
                                    **kwargs):
    """
    Underdamped Langevin-type sampler on a manifold with an Euler-Maruyama
    position update and a decaying constraint correction.

    Velocities are not stored between steps. From the second step on, the
    velocity is reconstructed from the last two positions as
    ``(x_i - x_{i-1}) / (sigma_back^2 * dt)`` and projected onto the tangent
    space at ``x_i``. Each step then:
      1. mixes the velocity with fresh Gaussian noise using the factor
         ``a = exp(-sigma^2 * |dt| * gamma)`` and adds ``sigma^2 * |dt| * drift``;
      2. projects the result onto the tangent space at ``x`` and scales it by
         ``sigma^2 * dt`` to obtain a displacement;
      3. moves ``x`` with ``manifold.adding_correction_decaying``.
    Only on the last step is the result passed through
    ``manifold.project_onto_manifold_with_base``. Per-trajectory convergence
    is determined solely by that final projection.

    Args:
        sde: object providing ``N``, ``T``, ``get_diffusion``, ``drift_b`` and,
            when ``reverse`` is True, ``reverse(score_net, underdamped=True)``.
        manifold: object providing ``project_onto_tangent_space``,
            ``adding_correction_decaying``, ``project_onto_manifold_with_base``
            and, optionally, ``report_violations`` (per-step diagnostics).
        init: starting positions, shape [bsz, dim].
        reverse: if True, integrate from ``sde.T`` down to 0 using the reverse
            (score-based) drift; otherwise integrate from 0 up to ``sde.T``.
        init_v: optional initial velocity. If None, a random tangent vector
            is drawn at ``init``.
        score_net: score network forwarded to ``sde.reverse``.
        keep_quiet: if True, disable the progress bar, per-step violation
            reports and the summary log line.
        **kwargs:
            alpha: decay rate passed to ``adding_correction_decaying``
                (default 100.0).
            gamma: friction coefficient (default 100.0).
            mass: read by ``_extract_mass_gamma`` but not used here.

    Returns:
        ``(x_final, x_hist, other_dict)``.
        - ``x_final``: shape [n_converged, dim].
        - ``x_hist``: shape [N+1, n_converged, dim].
        Both are restricted to converged trajectories. ``other_dict`` holds
        the boolean mask ``"converged_traj"`` (shape [bsz]) and the
        unfiltered history ``"x_hist_all"`` (shape [N+1, bsz, dim]).

    Shapes: timesteps [N+1, bsz]; dt and vec_t [bsz].
    """
    device = init.device
    shape = init.shape
    N = sde.N

    if 'alpha' in kwargs:
        alpha = kwargs['alpha']
    else:
        alpha = 100.0
        logging.info(f'Warning: alpha is not specified. Set alpha = {alpha}.')

    _, gamma = _extract_mass_gamma(kwargs, device=device)

    if reverse:
        rsde = sde.reverse(score_net, underdamped=True)
        timesteps = torch.linspace(sde.T, 0., N+1, device=device).reshape(-1, 1).repeat(1, shape[0])
    else:
        timesteps = torch.linspace(0., sde.T, N+1, device=device).reshape(-1, 1).repeat(1, shape[0])  # shape [N+1, bsz]

    def update_fn_predictor(x, x_back, t, delta_t, initial_step=False, last=False):
        """Advance positions `x` by one step.

        `x_back` is the previous position (used to rebuild the velocity).
        Returns the new positions and a per-trajectory convergence mask.
        The mask is all True unless `last` is set.
        """
        z = torch.randn_like(x, device=device)
        if reverse:
            # delta_t < 0
            sigma = sde.get_diffusion(t + delta_t)
            sigma_back = sde.get_diffusion(t)
            drift_fn = lambda x, v, t: rsde.drift_score(x, v, t) - sde.drift_b(x)
        else:
            # delta_t > 0
            sigma = sde.get_diffusion(t)
            sigma_back = sde.get_diffusion(t - delta_t)
            drift_fn = lambda x, v=None, t=None: sde.drift_b(x)

        # Reshape to column vectors BEFORE combining them. If get_diffusion
        # returned [bsz, 1] while delta_t is [bsz], the product would
        # broadcast to [bsz, bsz].
        delta_t = delta_t.reshape(-1, 1)
        sigma = sigma.reshape(-1, 1)
        sigma_back = sigma_back.reshape(-1, 1)

        a = torch.exp(-(sigma ** 2) * torch.abs(delta_t) * gamma)  # shape [bsz, 1]
        sqrt_1ma2 = torch.sqrt(torch.abs(1. - a ** 2))  # shape [bsz, 1]

        if initial_step and init_v is not None:
            v = init_v
        elif initial_step:
            v = manifold.project_onto_tangent_space(torch.randn_like(x, device=device), base_point=x)
        else:
            # Undo the sigma^2 * dt displacement scaling applied below to
            # recover a velocity. This is approximate, because the correction
            # step also moves x.
            v = (x - x_back) / (sigma_back ** 2 * delta_t)  # shape [bsz, dim]
            v = manifold.project_onto_tangent_space(v, base_point=x)

        # Damped velocity + drift kick + fresh noise.
        mu_f_pre_w_noise = a * v + (sigma ** 2) * torch.abs(delta_t) * drift_fn(x, v, t) + sqrt_1ma2 * z
        # Convert the tangent velocity into a position displacement.
        mu_f_pre_w_noise = (sigma ** 2) * delta_t * manifold.project_onto_tangent_space(mu_f_pre_w_noise, base_point=x)

        x_prime = manifold.adding_correction_decaying(
            mu_f_pre_w_noise, base_point=x, delta_t=delta_t, alpha=alpha, sigma_sq=sigma ** 2)

        converged_traj_one_step = torch.ones(x_prime.shape[0], dtype=torch.bool, device=device)
        if last:
            # Last step: project onto the manifold with a zero tangent step.
            # This is the only place a trajectory can be flagged unconverged.
            tangent_vec_zero = torch.zeros_like(mu_f_pre_w_noise, device=device)
            x_prime, converged_traj_one_step = manifold.project_onto_manifold_with_base(
                tangent_vec_zero, base_point=x_prime)

        return x_prime.detach(), converged_traj_one_step

    # Keep the history in init's dtype and on its device. x_hist[i-1] feeds
    # the velocity reconstruction, so a float32 buffer would truncate
    # float64 runs.
    x_hist = torch.zeros(N + 1, *shape, dtype=init.dtype, device=device)
    x = init
    x_hist[0] = x.clone()

    converged_traj = torch.ones(shape[0], dtype=torch.bool, device=device)

    for i in tqdm(range(N), mininterval=2., disable=keep_quiet):
        vec_t = timesteps[i]
        dt = timesteps[i + 1] - timesteps[i]

        # Optional per-step constraint diagnostics, silenced by keep_quiet.
        if not keep_quiet and hasattr(manifold, 'report_violations'):
            violation_str = manifold.report_violations(x)
            logging.info(f"Step {i+1}/{N}: {violation_str}")

        # On the first step there is no previous position. `last` is computed
        # independently of `initial_step`, so a single-step run (N == 1)
        # still gets the final manifold projection.
        x_back = x if i == 0 else x_hist[i - 1]
        x, converged_traj_one_step = update_fn_predictor(
            x, x_back, vec_t, dt, initial_step=(i == 0), last=(i == N - 1))

        # converged_traj: shape [bsz] (bool); True while the whole trajectory is successful so far.
        converged_traj = torch.logical_and(converged_traj, converged_traj_one_step)
        x_hist[i + 1] = x.clone()

    num_unconverged_traj = shape[0] - converged_traj.sum()
    if not keep_quiet:
        logging.info(f'{num_unconverged_traj} of {init.shape[0]} trajectories are dropped.')
    other_dict = {"converged_traj": converged_traj, "x_hist_all": x_hist}
    return x[converged_traj, :], x_hist[:, converged_traj, :], other_dict