import math
from typing import Optional, Union, Tuple

import torch
import torch.nn.functional as F
import os

import numpy as np
try:
    import ot  # POT: Python Optimal Transport
except Exception as e:
    ot = None

# =========================
# Module-level utilities
# =========================
def reshape_last_dim_to_2d(
    x: torch.Tensor,
    N: int,
    M: int,
    out: Optional[torch.Tensor] = None,
    make_contiguous: bool = False,
) -> torch.Tensor:
    """
    Split the trailing axis of `x` (length L) into two trailing axes (N, M).

    Only exact factorizations are accepted: L must equal N * M.

    Args:
      x: tensor of shape (*prefix, L).
      N, M: positive sizes of the two new trailing axes.
      out: optional destination of shape (*prefix, N, M) with the same dtype and
           device as `x`; when given, the result is copied into it and `out` is
           returned.
      make_contiguous: return a contiguous tensor (has no effect when `out` is given).

    Returns:
      Tensor of shape (*prefix, N, M): a view of `x` when `reshape` can produce
      one, otherwise a copy.
    """
    torch._assert(N > 0 and M > 0, "N and M must be positive")
    last, expected = x.size(-1), N * M
    torch._assert(last == expected, f"Last dim {last} must equal N*M={expected}")

    split = x.reshape(*x.shape[:-1], N, M)

    if out is None:
        return split.contiguous() if make_contiguous else split

    # Validate the destination buffer before writing into it.
    torch._assert(out.shape == split.shape, "`out` has wrong shape")
    torch._assert(out.dtype == x.dtype, "`out` has wrong dtype")
    torch._assert(out.device == x.device, "`out` on wrong device")
    out.copy_(split)
    return out

def atan_k(s: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """
    Curvature-dependent inverse tangent, chosen by the sign of K:
      K<0: atanh(s),  K=0: s,  K>0: atan(s).

    |K| <= machine eps of `s.dtype` is treated as K = 0. On the K<0 branch the
    atanh argument is clamped to [-(1 - eps), 1 - eps] so the result stays finite.
    Shapes must be broadcastable; the result has the broadcast shape.

    Each transcendental branch only sees inputs from its own region (zeros
    elsewhere) and the final value is selected with `torch.where`. Blending the
    branches by multiplying with 0/1 masks would turn a non-finite `s` on a
    curved component into 0 * inf = NaN through the identity term.
    """
    K = K.to(device=s.device, dtype=s.dtype)

    tol = torch.finfo(s.dtype).eps
    negb = K < -tol
    posb = K > tol
    zeros = torch.zeros_like(s)

    # atanh branch (K<0): mask other entries to 0, then keep the argument inside (-1, 1).
    s_neg = torch.where(negb, s, zeros).clamp(min=-(1 - tol), max=(1 - tol))
    v_neg = torch.atanh(s_neg)

    # atan branch (K>0).
    v_pos = torch.atan(torch.where(posb, s, zeros))

    # Entries with |K| <= tol fall through to the identity.
    return torch.where(negb, v_neg, torch.where(posb, v_pos, s))

def tan_k(s: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """
    Curvature-dependent tangent, chosen by the sign of K:
      K<0: tanh(s),  K=0: s,  K>0: tan(s).

    |K| <= machine eps of `s.dtype` is treated as K = 0. Shapes must be
    broadcastable; the result has the broadcast shape.

    `tan` is only evaluated on K>0 entries (others see 0), so its poles cannot
    affect other components, and the branch is picked with `torch.where` rather
    than mask multiplication, which would turn a non-finite `s` into
    0 * inf = NaN via the identity term.
    """
    K = K.to(device=s.device, dtype=s.dtype)

    tol = torch.finfo(s.dtype).eps
    negb = K < -tol
    posb = K > tol

    # tanh is bounded, so it is safe to evaluate everywhere.
    v_neg = torch.tanh(s)
    v_pos = torch.tan(torch.where(posb, s, torch.zeros_like(s)))

    # Entries with |K| <= tol fall through to the identity.
    return torch.where(negb, v_neg, torch.where(posb, v_pos, s))

def sin_k(r: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """
    Curvature-dependent sine S_K:
      K<0: sinh(√|K| r)/√|K|,  K=0: r,  K>0: sin(√K r)/√K.

    |K| <= machine eps of `r.dtype` is treated as K = 0. Shapes must be
    broadcastable; the result has the broadcast shape.

    The sinh/sin branches only receive arguments from their own region (zeros
    elsewhere) and the result is selected with `torch.where`. Multiplying every
    branch by a 0/1 mask instead lets sinh overflow to inf on flat or spherical
    entries (e.g. r > ~89 in float32 when K = 0) and the blend becomes
    0 * inf = NaN.
    """
    K = K.to(device=r.device, dtype=r.dtype)

    tol = torch.finfo(r.dtype).eps
    negb = K < -tol
    zerb = K.abs() <= tol
    posb = K > tol

    # √|K|, with |K| replaced by 1 on flat components so neither the sqrt (and
    # its gradient) nor the division below see a zero.
    denom = torch.sqrt(torch.where(zerb, torch.ones_like(K), K.abs()))

    a = denom * r
    zeros = torch.zeros_like(a)
    v_neg = torch.sinh(torch.where(negb, a, zeros)) / denom
    v_pos = torch.sin(torch.where(posb, a, zeros)) / denom

    # Entries with |K| <= tol fall through to the identity.
    return torch.where(negb, v_neg, torch.where(posb, v_pos, r))

def asin_k(s: torch.Tensor, K: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Inverse of the curvature-dependent sine S_K:
      K<0: asinh(√|K| s)/√|K|,  K=0: s,  K>0: arcsin(√K s)/√K.

    |K| <= machine eps of `s.dtype` is treated as K = 0. On the K>0 branch the
    arcsin argument is clamped to [-(1 - eps), 1 - eps]. Shapes must be
    broadcastable; the result has the broadcast shape.

    Branches are evaluated on masked inputs and combined with `torch.where`, so
    a non-finite value on one component cannot leak into others as
    0 * inf = NaN.
    """
    K = K.to(device=s.device, dtype=s.dtype)

    tol = torch.finfo(s.dtype).eps
    negb = K < -tol
    zerb = K.abs() <= tol
    posb = K > tol

    # √|K| with |K| replaced by 1 on flat components (safe sqrt and division).
    denom = torch.sqrt(torch.where(zerb, torch.ones_like(K), K.abs()))

    z = denom * s
    zeros = torch.zeros_like(z)

    v_neg = torch.asinh(torch.where(negb, z, zeros)) / denom

    # Clamp keeps arcsin inside its domain on the spherical branch.
    epsv = torch.as_tensor(eps, dtype=s.dtype, device=s.device)
    z_pos = torch.where(posb, z, zeros).clamp(min=-(1 - epsv), max=(1 - epsv))
    v_pos = torch.asin(z_pos) / denom

    # Entries with |K| <= tol fall through to the identity.
    return torch.where(negb, v_neg, torch.where(posb, v_pos, s))

def radius_from_K(K: torch.Tensor, device=None) -> torch.Tensor:
    """
    Per-component gyro-ball radius:
      K != 0: 1/sqrt(|K|)
      K == 0: +inf

    Args:
      K: curvatures with N entries, shaped (N,) or (N, 1) (anything accepted by
         `torch.as_tensor` works). Integer input is converted to the default
         floating dtype.
      device: target device; defaults to K's own device.

    Returns:
      r: (1, N, 1)
    """
    # as_tensor keeps a tensor's device when `device` is None.
    Kf = torch.as_tensor(K, device=device).reshape(-1)                   # (N,)
    if not Kf.is_floating_point():
        Kf = Kf.to(dtype=torch.get_default_dtype())

    absK = Kf.abs()                                                      # (N,)
    curved = absK > 0
    # Use a dummy |K| = 1 on flat components so no 1/sqrt(0) is evaluated:
    # otherwise the discarded inf branch would still produce NaN gradients.
    safe_absK = torch.where(curved, absK, torch.ones_like(absK))
    inf = torch.full_like(absK, float("inf"))

    r = torch.where(curved, 1.0 / torch.sqrt(safe_absK), inf)            # (N,)
    return r.view(1, -1, 1)                                              # (1, N, 1)


def project_inside_ball(x: torch.Tensor, K: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Shrink each component of x with K < 0 so that its Euclidean norm is at most
      (1 - eps) / sqrt(|K|)      (i.e. just inside the radius of its κ-ball).
    Components with K >= 0 are returned unchanged (scale exactly 1).

    Input:
      x: (..., N, M) — any number of leading dims, including none
      K: N entries, shaped (N,) or (N, 1)
    Return:
      x_proj: same shape as x
    """
    dev, dt = x.device, x.dtype

    K_col = torch.as_tensor(K, device=dev).reshape(-1, 1).to(dtype=dt)   # (N, 1)
    hyper = K_col < 0                                                    # (N, 1)

    # Target radius only matters for K < 0. Other components get a finite
    # placeholder (|K| := 1) so no inf enters the graph: inf / n would yield
    # NaN gradients even though those entries are masked out below.
    safe_absK = torch.where(hyper, -K_col, torch.ones_like(K_col))
    r = (1.0 - float(eps)) * (1.0 / torch.sqrt(safe_absK))               # (N, 1)

    # Per-component norms.
    n = x.norm(dim=-1, keepdim=True)                                     # (..., N, 1)

    tiny = torch.finfo(dt).eps
    ratio = r / n.clamp_min(tiny)                                        # (..., N, 1)
    scale = torch.minimum(torch.ones_like(ratio), ratio)                 # (..., N, 1)

    # Exact no-op for K >= 0.
    scale = torch.where(hyper, scale, torch.ones_like(scale))            # (..., N, 1)

    # r is (N, 1) rather than (1, N, 1), so an unbatched (N, M) input keeps its shape.
    return x * scale                                                     # (..., N, M)


def mobius_add_core(
    x: torch.Tensor,   # (..., A, N, M)
    y: torch.Tensor,   # (..., B, N, M)
    K: torch.Tensor,   # (N,) or (N,1)
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    All-pairs κ-Möbius addition x ⊕_K y, applied independently in each component.

    For every pair (a, b) and component n, with k = K[n]:
        x ⊕ y = ((1 - 2k<x,y> - k|y|^2) x + (1 + k|x|^2) y)
                / (1 - 2k<x,y> + k^2 |x|^2 |y|^2)
    The denominator is clamped from below at `eps`; no projection onto the ball
    is applied to the result.

    Args:
      x: (..., A, N, M); y: (..., B, N, M); leading dims must broadcast.
      K: per-component curvature with N entries.
      eps: lower bound for the denominator.

    Returns:
      (..., A, B, N, M)
    """
    torch._assert(x.size(-1) == y.size(-1), "M mismatch between x and y")
    torch._assert(x.size(-2) == y.size(-2), "N mismatch between x and y")
    N = x.size(-2)
    torch._assert(K.numel() == N, "K must have N entries")

    # Pair axes: xp (..., A, 1, N, M) against yp (..., 1, B, N, M).
    xp = x.unsqueeze(-3)
    yp = y.unsqueeze(-4)
    k = _Kb_pair_like(xp, N, K)  # (..., 1, 1, N, 1)

    # Inner products / squared norms over the M axis, kept as (..., N, 1).
    dot = (xp * yp).sum(dim=-1, keepdim=True)
    x_sq = (xp * xp).sum(dim=-1, keepdim=True)
    y_sq = (yp * yp).sum(dim=-1, keepdim=True)

    unit = xp.new_tensor(1.0)
    shared = unit - 2.0 * k * dot          # appears in both numerator and denominator
    coef_x = shared - k * y_sq
    coef_y = unit + k * x_sq
    denominator = (shared + (k * k) * x_sq * y_sq).clamp_min(xp.new_tensor(eps))

    return (coef_x * xp + coef_y * yp) / denominator

def _Kb_pair_like(x_pair: torch.Tensor, N: int, K: torch.Tensor) -> torch.Tensor:
    """
    Reshape the N curvatures in K so they broadcast over a pairwise tensor.

    The result has x_pair's dtype/device and shape (1, ..., 1, N, 1), with as
    many dims as x_pair (but at least 4), e.g. (..., 1, 1, N, 1) for an
    x_pair of shape (..., A, 1, N, M).
    """
    n_lead = x_pair.ndim - 4
    target = [1] * n_lead + [1, 1, N, 1]
    flat = K.reshape(-1).to(dtype=x_pair.dtype, device=x_pair.device)
    return flat.view(*target)

def _lambda_x_K_single(x: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """
    Conformal factor λ_x^K = 2 / (1 + K ||x||^2), one value per component.

    x: (..., A, N, M), K: N entries  ->  returns (..., A, N)
    """
    lead = [1] * (x.ndim - 2)
    curv = K.to(dtype=x.dtype, device=x.device).view(*lead, -1)   # (1, ..., 1, N)
    sq_norm = (x * x).sum(dim=-1)                                  # (..., A, N)
    return 2.0 / (1.0 + curv * sq_norm)

def gather_along_dim(x: torch.Tensor, index: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Pick entries of `x` along axis `dim` with a batched index.

    `index` has shape (*lead, B): its leading dims line up with the first
    len(lead) dims of `x` (they may be 1 to broadcast), and its last axis lists
    the B positions to take along `dim`. `dim` may be negative but must not fall
    inside the index's leading dims.

    Returns a tensor shaped like `x` with axis `dim` replaced by B.
    """
    nd = x.ndim
    idx_nd = index.ndim
    dim = dim % nd
    torch._assert(idx_nd - 1 <= dim and dim < nd, "dim out of range")

    index = index.to(device=x.device, dtype=torch.long)
    n_pick = index.shape[-1]

    # Output shape: x with axis `dim` resized to B.
    target = [*x.shape]
    target[dim] = n_pick

    # Index reshaped to (*lead, 1, ..., B at `dim`, ..., 1), then broadcast to target.
    shaped = [*index.shape[:-1], *([1] * (nd - idx_nd + 1))]
    shaped[dim] = n_pick

    return torch.gather(x, dim=dim, index=index.view(shaped).expand(target))

def _exp_map(x: torch.Tensor, v: torch.Tensor, K: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    All-pairs exponential map, applied per component:
      exp_x^K(v) = x ⊕_K ( tan_K( √|K| λ_x^K ||v|| / 2 ) * v / (√|K| ||v||) )

    The scalar factor in front of v tends to λ_x^K / 2 both as K -> 0 and as
    ||v|| -> 0; that limit is used whenever K is (close to) 0 or
    √|K| ||v|| < eps. (Dividing by an eps-clamped denominator instead would
    shrink the factor toward 0 for tiny tangent vectors.)

    Inputs:
      x: (..., A, N, M) base points
      v: (..., B, N, M) tangent vectors
      K: N curvatures
      eps: small-norm threshold; also the Möbius-denominator clamp
    Output:
      (..., A, B, N, M)
    """
    N = x.shape[-2]
    x_pair = x.unsqueeze(-3)                                    # (..., A, 1, N, M)
    v_pair = v.unsqueeze(-4)                                    # (..., 1, B, N, M)

    lam_x = _lambda_x_K_single(x, K)[..., None, :, None]        # (..., A, 1, N, 1)

    Kb = _Kb_pair_like(x_pair, N, K)                            # (..., 1, 1, N, 1)
    sqK = torch.sqrt(Kb.abs())                                  # (..., 1, 1, N, 1)
    nv = (v_pair * v_pair).sum(dim=-1, keepdim=True).sqrt()     # (..., 1, B, N, 1)
    scaled_nv = sqK * nv                                        # (..., 1, B, N, 1)

    s = 0.5 * sqK * lam_x * nv                                  # (..., A, B, N, 1)
    t_s = tan_k(s, Kb)
    factor = t_s / scaled_nv.clamp_min(x.new_tensor(eps))       # (..., A, B, N, 1)

    # Flat components and (near-)zero tangent vectors use the analytic limit λ_x/2.
    factor0 = 0.5 * lam_x                                       # (..., A, 1, N, 1)
    use_limit = torch.isclose(Kb, torch.zeros_like(Kb)) | (scaled_nv < eps)
    factor = torch.where(use_limit, factor0, factor)

    u_pair = factor * v_pair                                    # (..., A, B, N, M)
    # mobius_add_core adds its own pair axis; dropping it yields (..., A, B, N, M).
    return mobius_add_core(x_pair, u_pair, K, eps).squeeze(-4)


# =========================
# Mixed-Curvature Space
# =========================

class MCS:
    def __init__(self, N: int, M: int, K: torch.Tensor, device=None):
        """
        Container for a mixed-curvature space with:
          N: number of components, each of dimension M
          K: per-component curvature tensor shaped (N,) or (N, 1)
        """
        self.N = N
        self.M = M
        self.NM = N * M
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.K = torch.as_tensor(K, device=self.device)
        self.mus = None
        if self.K.ndim == 0:
            self.K = self.K.unsqueeze(0)
        
        self.zeros = torch.zeros(self.N, self.M).to(device)
        
        # print(K.shape, N)
        assert self.K.shape[-1] == N, "K must have N rows"

        if 'COMPILE' in os.environ and os.environ['COMPILE'] in ['1', 'TRUE']:
          self._ensure_2d = torch.compile(self._ensure_2d)
          self.mobius_add = torch.compile(self.mobius_add)
          self.mobius_add_pairwise = torch.compile(self.mobius_add_pairwise)
          self.distance = torch.compile(self.distance)
          self.distance_points_lines_ori = torch.compile(self.distance_points_lines_ori)
          self.distance_points_lines = torch.compile(self.distance_points_lines)
          self.exp_map = torch.compile(self.exp_map)
          self.exp_map_pairwise = torch.compile(self.exp_map_pairwise)
          self.log_map = torch.compile(self.log_map)
          self.gyration = torch.compile(self.gyration)
          self.parallel_transport = torch.compile(self.parallel_transport)
          self.log_wrapped_normal = torch.compile(self.log_wrapped_normal)


    @staticmethod
    def _ensure_2d(x: torch.Tensor, N: int, M: int) -> torch.Tensor:
        """Accept (..., A, N, M) or (..., A, N*M); return (..., A, N, M)."""
        if x.shape[-2:] == (N, M):
            return x
        # else expect flattened last dim
        torch._assert(x.shape[-1] == N * M, "Last dim must equal N*M")
        return reshape_last_dim_to_2d(x, N, M)

    # -------------------------
    # Public ops
    # -------------------------
    def mobius_add(
        self,
        x: torch.Tensor,  # (..., A, N, M)
        y: torch.Tensor,  # (..., B, N, M)
        eps: float = 1e-6,
    ) -> torch.Tensor:
        """
        All-pairs κ-Möbius addition x ⊕ y using this space's curvatures.

        `eps` is the lower clamp on the Möbius denominator.
        Returns: (..., A, B, N, M).
        """
        curvature = self.K
        return mobius_add_core(x, y, curvature, eps=eps)

    def mobius_add_pairwise(
        self,
        x: torch.Tensor,  # (..., N, M)
        y: torch.Tensor,  # (..., N, M)
        eps: float = 1e-6,
    ) -> torch.Tensor:
        """
        Element-wise κ-Möbius addition x ⊕ y (x[i] with y[i], no all-pairs axis).

        Returns: (..., N, M).
        """
        x_row = x.unsqueeze(-3)                                   # (..., 1, N, M)
        y_row = y.unsqueeze(-3)                                   # (..., 1, N, M)
        paired = mobius_add_core(x_row, y_row, self.K, eps=eps)   # (..., 1, 1, N, M)
        return paired.squeeze(-3).squeeze(-3)

    def distance(
        self,
        x: torch.Tensor,   # (..., A, N, M) or (..., A, N*M)
        y: torch.Tensor,   # (..., B, N, M) or (..., B, N*M)
        eps: float = 1e-12,
        reduce: bool = True,
    ) -> torch.Tensor:
        """
        All-pairs geodesic distance, per component and optionally combined.

        Per component with curvature k and z = (-x) ⊕_k y:
            k != 0: d = 2 * atan_k(sqrt(|k|) * |z|) / sqrt(|k|)
            k == 0: d = 2 * |z|            (z = y - x in this case)
        Components with |k| <= eps use the flat formula.

        Args:
          x: (..., A, N, M) or flattened (..., A, N*M).
          y: (..., B, N, M) or flattened (..., B, N*M).
          eps: Möbius-denominator clamp, zero-curvature tolerance, lower bound on
               sqrt(|k|), and the constant added under the square root when
               reducing (so the reduced distance of a point to itself is sqrt(eps)).
          reduce: combine components as sqrt(sum_n d_n^2 + eps).

        Returns:
          (..., A, B) if reduce else (..., A, B, N).
        """
        x = self._ensure_2d(x, self.N, self.M)   # (..., A, N, M)
        y = self._ensure_2d(y, self.N, self.M)   # (..., B, N, M)
        device, dtype = x.device, x.dtype

        # Pairwise Möbius displacement z = (-x) ⊕ y
        z = self.mobius_add(-x, y, eps=eps)           # (..., A, B, N, M)
        nz = z.norm(dim=-1)                           # (..., A, B, N)

        # Curvatures on the data's device/dtype, broadcast to nz.
        Kb = self.K.to(device=device, dtype=dtype).view(*([1] * (nz.ndim - 1)), self.N).expand_as(nz)

        # Curved components: d = 2 * atan_k(sqrt(|K|) * nz) / sqrt(|K|)
        sqrt_absK = torch.sqrt(Kb.abs())
        aK = atan_k(nz * sqrt_absK, Kb)
        denom = torch.where(Kb != 0, sqrt_absK.clamp_min(eps), torch.ones_like(sqrt_absK))
        d_curved = 2 * aK / denom

        # zeros_like(Kb) is already on the data's device; building it on
        # self.device would fail whenever the inputs live elsewhere.
        flat = torch.isclose(Kb, torch.zeros_like(Kb), atol=eps)
        d = torch.where(flat, 2 * nz, d_curved)       # (..., A, B, N)

        if reduce:
            return torch.sqrt((d * d).sum(dim=-1) + eps)
        return d

    def distance_pairwise(self, x, y, eps=1e-6, reduce=True):
        """
        Element-wise geodesic distance between x[i] and y[i] (no all-pairs axis).

        Args:
          x, y: (..., N, M) or flattened (..., N*M), with matching leading dims.
          eps: forwarded to `distance`.
          reduce: combine components (see `distance`).

        Returns:
          (...) if reduce else (..., N).
        """
        # Split flattened inputs first: unsqueezing a (..., N*M) tensor before the
        # split would turn its last batch axis into the A/B axes and produce an
        # all-pairs result instead of an element-wise one.
        x = self._ensure_2d(x, self.N, self.M)
        y = self._ensure_2d(y, self.N, self.M)
        d = self.distance(x.unsqueeze(-3), y.unsqueeze(-3), eps, reduce)
        if reduce:
            return d.squeeze(-1).squeeze(-1)        # (..., 1, 1) -> (...)
        return d.squeeze(-2).squeeze(-2)            # (..., 1, 1, N) -> (..., N)

    def distance_points_lines_ori(self, points, intercept, index, eps=1e-6):
        """
        Distance from points to lines through the origin, each line lying in one component.

        For a line in component c with direction u:
          * in-component part: from the constant-curvature sine rule,
            S_K(d_in) = S_K(r_c) * sin(theta), where r_c is the point's distance
            to the origin in component c and theta the angle between the point's
            c-component and u;
          * out-of-component part: the point's distances to the origin in all
            other components.
        The result is sqrt(d_in^2 + sum_{n != c} r_n^2).

        Args:
          points:    (..., C, A, N, M) (or flattened last dim)
          intercept: (..., C, B, M)  line directions
          index:     (..., C, B)     component of each line, in [0, N)
          eps: clamp for norms and for cos(theta) away from ±1.

        Returns:
          (..., C, B, A)
        """
        points = self._ensure_2d(points, self.N, self.M)
        index = index.to(points.device, dtype=torch.long)
        n_comp, dim_comp = self.N, self.M
        n_lines = index.shape[-1]

        # Each point's coordinates in each line's component: (..., C, A, B, M)
        comp_points = gather_along_dim(points, index, dim=-2)

        # Curvature of each line's component, broadcast over the index batch dims.
        K_batched = self.K.view(*([1] * (index.ndim - 1)), -1).expand(*index.shape[:-1], n_comp)  # (..., C, N)
        line_K = gather_along_dim(K_batched, index, dim=-1).unsqueeze(-2)                        # (..., C, 1, B)

        # Angle at the origin between the point and the line direction.
        direction = intercept.unsqueeze(-3)                        # (..., C, 1, B, M)
        inner = (comp_points * direction).sum(dim=-1)              # (..., C, A, B)
        dir_len = direction.norm(dim=-1).clamp_min(eps)            # (..., C, 1, B)
        pt_len = comp_points.norm(dim=-1).clamp_min(eps)           # (..., C, A, B)
        cos_theta = (inner / (dir_len * pt_len)).clamp(min=-1 + eps, max=1 - eps)
        sin_theta = torch.sqrt(torch.clamp(1.0 - cos_theta * cos_theta, min=0.0))

        # Per-component geodesic distance of every point to the origin (computed once).
        origin = points.new_zeros(*points.shape[:-3], 1, n_comp, dim_comp)  # (..., C, 1, N, M)
        radial = self.distance(points, origin, reduce=False).squeeze(-2)    # (..., C, A, N)

        # In-component leg via the sine rule.
        radial_sel = gather_along_dim(radial, index, dim=-1)                # (..., C, A, B)
        d_in = asin_k(sin_k(radial_sel, line_K) * sin_theta, line_K)        # (..., C, A, B)
        d_in2 = d_in * d_in

        # Squared distances from every component except the line's own.
        radial_exp = radial.unsqueeze(-2).expand(*radial.shape[:-1], n_lines, n_comp)        # (..., C, A, B, N)
        own_comp = F.one_hot(index, num_classes=n_comp).to(dtype=torch.bool).unsqueeze(-3)   # (..., C, 1, B, N)
        d_out2 = (radial_exp.masked_fill(own_comp, 0.0) ** 2).sum(dim=-1)                    # (..., C, A, B)

        return torch.sqrt(d_in2 + d_out2).transpose(-1, -2)                 # (..., C, B, A)

    def distance_points_lines(
        self,
        points: torch.Tensor,     # (..., A, N, M)  A = n_points
        roots: torch.Tensor,      # (..., C, N, M)  C = n_trees (per-line origins / translations)
        intercept: torch.Tensor,  # (..., C, B, M)  B = n_lines / tree (per-line direction in its component)
        index: torch.Tensor,      # (..., C, B) component indices in [0, N)
    ) -> torch.Tensor:
        """
        Distance from points to lines that pass through per-tree roots.

        The points are re-centred at each root with (-root) ⊕ point, then
        `distance_points_lines_ori` measures distances to lines through the
        origin. `intercept` is used as given (it is assumed to already be
        expressed in the root-centred frame).

        Args:
          points: (..., A, N, M) or flattened (..., A, N*M).
          roots:  (..., C, N, M) or flattened (..., C, N*M).
          intercept: (..., C, B, M) line directions.
          index: (..., C, B) torch.long component indices in [0, N).

        Returns:
          (..., C, B, A)
        """
        torch._assert(index.dtype==torch.long, "index dtype must be torch.long/torch.int64")
        # Accept flattened layouts like the other public ops do.
        points = self._ensure_2d(points, self.N, self.M)
        roots = self._ensure_2d(roots, self.N, self.M)
        pts_t = self.mobius_add(-roots, points)  # (..., C, A, N, M)
        return self.distance_points_lines_ori(pts_t, intercept, index)

    # -------------------------
    # exp map and log map
    # -------------------------
    def exp_map(self, x: torch.Tensor, v: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """
        All-pairs exponential map with this space's curvatures:
          exp_x^K(v) = x ⊕_K ( tan_K( √|K| λ_x^K ||v|| / 2 ) * v / (√|K| ||v||) )

        x: (..., A, N, M) base points; v: (..., B, N, M) tangent vectors.
        Returns (..., A, B, N, M).
        """
        curvature = self.K
        return _exp_map(x, v, curvature, eps=eps)

    def exp_map_pairwise(self, x: torch.Tensor, v: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """
        Element-wise exponential map: exp at x[i] applied to v[i] (no all-pairs axis).
          exp_x^K(v) = x ⊕_K ( tan_K( √|K| λ_x^K ||v|| / 2 ) * v / (√|K| ||v||) )

        x, v: (..., N, M)  ->  (..., N, M)
        """
        x_row = x.unsqueeze(-3)                        # (..., 1, N, M)
        v_row = v.unsqueeze(-3)                        # (..., 1, N, M)
        mapped = _exp_map(x_row, v_row, self.K, eps)   # (..., 1, 1, N, M)
        return mapped.squeeze(-3).squeeze(-3)

    def log_map(self, x: torch.Tensor, y: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """
        All-pairs logarithmic map, per component:
          log_x^K(y) = (2 / λ_x^K) * atan_K(√|K| ||z||) / (√|K| ||z||) * z,
          z = (-x) ⊕_K y

        The ratio atan_K(√|K| ||z||) / (√|K| ||z||) tends to 1 both as K -> 0 and
        as ||z|| -> 0; that limit is used when K is (close to) 0 or
        √|K| ||z|| < eps, instead of dividing by an eps-clamped denominator
        (which would shrink the result toward 0 for nearby points).

        Inputs:
          x: (..., A, N, M) or flattened (..., A, N*M)
          y: (..., B, N, M) or flattened (..., B, N*M)
        Output:
          (..., A, B, N, M)
        """
        x = self._ensure_2d(x, self.N, self.M)
        y = self._ensure_2d(y, self.N, self.M)

        z = self.mobius_add(-x, y, eps)                                   # (..., A, B, N, M)

        nz = (z * z).sum(dim=-1, keepdim=True).sqrt()                     # (..., A, B, N, 1)
        Kb = _Kb_pair_like(z, self.N, self.K)                             # (..., 1, 1, N, 1)
        sqK = torch.sqrt(Kb.abs())                                        # (..., 1, 1, N, 1)
        scaled_nz = sqK * nz                                              # (..., A, B, N, 1)

        a = atan_k(scaled_nz, Kb)
        ratio = a / scaled_nz.clamp_min(x.new_tensor(eps))                # (..., A, B, N, 1)
        use_limit = torch.isclose(Kb, torch.zeros_like(Kb)) | (scaled_nz < eps)
        ratio = torch.where(use_limit, torch.ones_like(ratio), ratio)

        lam_x = _lambda_x_K_single(x, self.K)[..., None, :, None]         # (..., A, 1, N, 1)
        coef = 2.0 / lam_x                                                # (..., A, 1, N, 1)

        return coef * ratio * z                                           # (..., A, B, N, M)

    def log_map_pairwise(self, x, y, eps=1e-6):
        '''
        Element-wise logarithmic map: log at x[i] of y[i] (no all-pairs axis).

        x: (..., A, N, M) or flattened (..., A, N*M)
        y: (..., A, N, M) or flattened (..., A, N*M)
        out: (..., A, N, M)
        '''
        # Split flattened inputs before adding pair axes; otherwise log_map would
        # treat the last batch axis as the A/B axes and compute all pairs.
        x = self._ensure_2d(x, self.N, self.M)
        y = self._ensure_2d(y, self.N, self.M)
        return self.log_map(x.unsqueeze(-3), y.unsqueeze(-3), eps).squeeze(-3).squeeze(-3)

    def gyration(self, a: torch.Tensor, b: torch.Tensor, v: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """
        Gyration operator, element-wise over the leading dims:
          gyr[a,b]v = ⊖(a⊕b) ⊕ ( a ⊕ ( b ⊕ v ) )

        a, b, v : (..., N, M) with equal ndim
        returns : (..., N, M)
        """
        torch._assert(a.ndim >= 2 and b.ndim == a.ndim and v.ndim == a.ndim, "a,b,v must have same ndim")
        torch._assert(a.shape[-2] == b.shape[-2] == v.shape[-2] == self.K.numel(), "N mismatch")

        # ⊖w is plain negation (0 ⊕ w == w), so no extra Möbius addition or
        # origin tensor is needed; an origin built on self.device would also
        # clash with inputs living on another device.
        neg_a_plus_b = -self.mobius_add_pairwise(a, b, eps)              # (..., N, M)

        # a ⊕ (b ⊕ v)
        b_plus_v = self.mobius_add_pairwise(b, v, eps)                   # (..., N, M)
        a_plus_bv = self.mobius_add_pairwise(a, b_plus_v, eps)           # (..., N, M)

        return self.mobius_add_pairwise(neg_a_plus_b, a_plus_bv, eps)    # (..., N, M)


    def parallel_transport(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """
        PT^K_{x->y}(v) = (lambda_x / lambda_y) * gyr[y, -x](v)
        x, y, v : (..., N, M)
        returns  : (..., N, M)
        """
        torch._assert(x.ndim >= 2 and y.ndim == x.ndim and v.ndim == x.ndim, "x,y,v must have same ndim")
        torch._assert(x.shape[-2] == y.shape[-2] == v.shape[-2] == self.K.numel(), "N mismatch")

        if torch.isclose(x, y).all():
            return v
        else:
            # lambda factors: (..., N)
            lam_x = _lambda_x_K_single(x, self.K)  # (..., N)
            lam_y = _lambda_x_K_single(y, self.K)  # (..., N)
            # print('lam x, lam y: ', lam_x, lam_y)


            # gyration with (y, -x): (..., N, M)
            gyr_out = self.gyration(y, -x, v, eps)  # (..., N, M)
            # print('gyr: ', gyr_out)

            # scalar ratio and final: (..., N, 1) * (..., N, M) -> (..., N, M)
            # lam_y_safe = lam_y.clamp_min(1e-12)  # (..., N)
            ratio = (lam_x / lam_y).unsqueeze(-1)  # (..., N, 1)
            transported = ratio * gyr_out  # (..., N, M)
            return transported

    # stereographic back projection
    def stereographi_back_proj(self, x: torch.Tensor, eps=1e-4):
        x_norm = x.norm(dim=-1)
        first = (1-self.K*x_norm**2)/(1+self.K*x_norm**2) / self.K.abs()**(1/2)
        first[:, torch.isclose(self.K, torch.zeros_like(self.K, device=self.device), atol=eps)] = 0.
        first = first.unsqueeze(-1)
        second = 2/(1+self.K*x_norm**2).unsqueeze(-1) * x
        cat = torch.cat([first, second], dim=-1)
        return cat

    # tree sampling
    @torch.no_grad()
    def generate_trees_frames(self, ntrees: int, nlines: int, batch=()):
        """
        Sample random tree frames: one root per tree and one direction per line.

        Args:
          ntrees: number of trees.
          nlines: number of lines per tree.
          batch: leading batch shape — an int (0 means no batch dims) or a tuple.

        Returns:
          root: (*batch, ntrees, N, M) — standard-normal tangent vectors at the
                origin pushed onto the manifold with exp_map.
          intercept: (*batch, ntrees, nlines, M) — standard-normal vectors
                normalized to unit Euclidean norm (no exp_map is applied).
          intercept_index: (*batch, ntrees, nlines) torch.long — the component in
                [0, N) each line lives in, drawn uniformly.
        """
        # Normalize batch -> tuple; an int 0 means "no batch" (tuple(0) would raise).
        if isinstance(batch, int):
            batch_dims = (int(batch),) if batch else ()
        else:
            batch_dims = tuple(batch)

        device = self.device
        dtype = torch.get_default_dtype()

        # Roots: Gaussian tangent vectors at the origin mapped onto the manifold.
        root_raw = torch.randn(*batch_dims, ntrees, self.N, self.M, device=device, dtype=dtype)  # (*batch, ntrees, N, M)
        zero_root = torch.zeros(*batch_dims, 1, self.N, self.M, device=device, dtype=dtype)      # (*batch, 1, N, M)
        root_mapped = self.exp_map(zero_root, root_raw)                                         # (*batch, 1, ntrees, N, M)
        root = root_mapped.reshape(*batch_dims, ntrees, self.N, self.M)                         # (*batch, ntrees, N, M)

        # Lines: random unit directions plus the component each line belongs to.
        intercept_raw = torch.randn(*batch_dims, ntrees, nlines, self.M, device=device, dtype=dtype)  # (*batch, ntrees, nlines, M)
        intercept_index = torch.randint(high=self.N, size=(*batch_dims, ntrees, nlines, 1), device=device, dtype=torch.long)  # (*batch, ntrees, nlines, 1)
        intercept = intercept_raw / intercept_raw.norm(dim=-1, keepdim=True)

        return root, intercept, intercept_index.squeeze(-1)

    def sample_wrap_normal(
        self,
        Mu: torch.Tensor,  # (N, M)
        sigma: torch.Tensor,  # (N),  per component sigma
        batch=()
    ) -> torch.Tensor:
        """
        Draw samples from a wrapped normal centred at Mu.

        Steps:
          1. draw standard-normal noise of shape (*batch, N, M) and scale
             component n by sigma[n] (a tangent vector at the origin),
          2. parallel-transport it from the origin to Mu,
          3. map it onto the manifold with the exponential map at Mu.

        Args:
          Mu: centre with N * M elements (reshaped to (N, M)); its dtype is used
              for the samples.
          sigma: per-component standard deviation with N entries, or a single
                 value shared by all components (tensor or Python number).
          batch: leading sample shape — an int (0 means no batch dims) or a tuple.

        Returns:
          z: (*batch, N, M) on `self.device`; (N, M) when batch is empty.
        """
        # Normalize batch -> tuple; an int 0 means "no batch" (tuple(0) would raise).
        if isinstance(batch, int):
            batch_dims = (int(batch),) if batch else ()
        else:
            batch_dims = tuple(batch)

        device = self.device
        dtype = Mu.dtype
        out_shape = (*batch_dims, self.N, self.M)

        # reshape (not view) so non-contiguous Mu is accepted.
        Mu_exp = Mu.reshape(self.N, self.M).to(device=device, dtype=dtype).expand(out_shape)

        # 1) Tangent sample at the origin.
        noise = torch.randn(*batch_dims, self.N, self.M, device=device, dtype=dtype)
        # (N, 1) or (1, 1): broadcasts over batch dims and M; moved to the sampling device.
        sigma = torch.as_tensor(sigma, device=device, dtype=dtype).reshape(-1, 1)
        sample_at_origin = noise * sigma  # differentiable w.r.t. sigma

        # 2) Parallel-transport from the origin to Mu.
        origin = torch.zeros_like(Mu_exp)
        transported = self.parallel_transport(origin, Mu_exp, sample_at_origin)  # (..., N, M)

        # 3) Exponential map at Mu (element-wise over the batch).
        return self.exp_map_pairwise(Mu_exp, transported)

    def log_wrapped_normal(
        self,
        z: torch.Tensor,                 # (..., N, M)    sample(s) on the manifold
        mu: torch.Tensor,                # (..., N, M)    component-wise mean(s)
        sigma: torch.Tensor,             # (N,) or (..., N)  per-component tangent std
        *,
        return_per_component: bool = False,
        eps: float = 1e-9,
    ) -> torch.Tensor:
        """
        Log-density of a wrapped Normal on the mixed-curvature product space.

        For every component i (each of dimension M) the density is built from:
          u_i = log^K_{mu_i}(z_i)          tangent vector at mu_i pointing at z_i
          v_i = PT^K_{mu_i -> 0}(u_i)      the same vector transported to the origin
          r_i = d^K(mu_i, z_i)             geodesic distance inside component i

          log p_i(z_i | mu_i, sigma_i) =
              log N(v_i ; 0, sigma_i^2 I_M) - (M - 1) * log(S_{K_i}(r_i) / r_i)

        and the full log-likelihood is sum_i log p_i.

        Args:
          z, mu: (..., N, M) points / means; reshaped via ``self._ensure_2d``.
          sigma: (N,) or anything broadcastable to (..., N); per-component std.
          return_per_component: if True, skip the final sum over components.
          eps: floor for sigma^2 and additive stabilizer in the S_K(r)/r ratio.

        Returns:
          (...,) summed log-likelihood, or (..., N) when return_per_component=True.
        """
        n_comp, comp_dim = self.N, self.M

        # Put both inputs into the (..., N, M) component layout.
        mu = self._ensure_2d(mu, n_comp, comp_dim)
        z = self._ensure_2d(z, n_comp, comp_dim)
        device, dtype = z.device, z.dtype

        # Tangent vector at mu towards z.
        tangent_at_mu = self.log_map_pairwise(mu, z)
        if tangent_at_mu.ndim >= 5:
            # NOTE(review): assumes the extra axes are singleton pair axes
            # produced by log_map_pairwise; confirm for multi-dim batches.
            tangent_at_mu = tangent_at_mu[..., 0, 0, :, :]

        # Move the tangent vector from mu to the origin.
        zero_point = torch.zeros_like(mu, device=device, dtype=dtype)
        tangent_at_origin = self.parallel_transport(mu, zero_point, tangent_at_mu)

        # Isotropic Gaussian log-density inside each component.
        sigma = torch.as_tensor(sigma, device=device, dtype=dtype)
        if sigma.ndim == 1:
            leading_ones = [1] * (tangent_at_origin.ndim - 2)
            sigma = sigma.view(*leading_ones, n_comp)
        # Otherwise sigma is trusted to broadcast against (..., N).
        variance = (sigma * sigma).clamp_min(eps)
        sq_norm = (tangent_at_origin * tangent_at_origin).sum(dim=-1)
        log_gauss = -0.5 * (sq_norm / variance + comp_dim * torch.log(2 * torch.pi * variance))

        # Change-of-volume correction of the exponential map.
        geo_dist = self.distance_pairwise(mu, z, reduce=False)
        if geo_dist.ndim >= 4:
            # NOTE(review): same singleton pair-axis assumption as above.
            geo_dist = geo_dist[..., 0, 0, :]
        # eps in both numerator and denominator keeps the ratio ~1 as r -> 0.
        volume_ratio = (sin_k(geo_dist, self.K).to(dtype) + eps) / (geo_dist + eps)
        log_jacobian = -(comp_dim - 1) * torch.log(volume_ratio)

        per_component = log_gauss + log_jacobian
        if not return_per_component:
            return per_component.sum(dim=-1)
        return per_component


    def get_mus(self, n_mixture):
        """
        Return the cached mixture means, resampling them when the cache is stale.

        ``self.mus`` is refilled with standard-normal draws of shape
        (n_mixture, self.N, self.M), moved to ``self.device``, whenever it is
        unset (None) or its shape differs from that target shape.

        Args:
          n_mixture: number of mixture means requested.

        Returns:
          torch.Tensor of shape (n_mixture, N, M). The same object is returned
          on repeated calls as long as the requested shape does not change.
        """
        expected_shape = (n_mixture, self.N, self.M)
        # Use `is None`: `==` on a tensor has elementwise semantics. Compare the
        # full shape, not just the leading dim, so a stale (N, M) is not reused.
        if self.mus is None or tuple(self.mus.shape) != expected_shape:
            # Draw with the default generator and then move the result, so that
            # seeded runs produce the same values independent of self.device.
            self.mus = torch.randn(*expected_shape).to(self.device)
        return self.mus

    def pairwise_cost(
        self,
        Xs: torch.Tensor,   # (As, N, M) or (.., As, N, M)
        Xt: torch.Tensor,   # (At, N, M) or (.., At, N, M)
        *,
        p: int | float = 2,
        reduce: bool = True,
    ) -> torch.Tensor:
        """
        Pairwise ground cost between two point clouds on the mixed-curvature space.
        Uses the manifold geodesic distance d_MCS, and returns d_MCS(x,y)^p.

        Returns:
          C: (..., As, At) matrix (torch, same device as inputs)
        """
        # d: (..., As, At)  when reduce=True (it aggregates across components with L2)
        d = self.distance(Xs, Xt, reduce=reduce)
        return d.pow(p)

    @torch.no_grad()
    def wasserstein(
        self,
        Xs: torch.Tensor,   # (As, N, M)
        Xt: torch.Tensor,   # (At, N, M)
        *,
        p: int | float = 2,
        a: np.ndarray | None = None,
        b: np.ndarray | None = None,
        sinkhorn_reg: float | None = None,
        numItermax: int = 1000,
        device: str | None = None,
    ) -> torch.Tensor:
        """
        Optimal-transport cost between two point clouds on the mixed-curvature space.

        Ground cost:  c(x, y) = d_MCS(x, y)^p, built by ``self.pairwise_cost``.
        The value returned is the transport cost  min_T <T, C>,  i.e. W_p^p
        (NOT its p-th root): for p=1 it is the 1-Wasserstein distance, for
        p=2 take ``sqrt`` of the result to obtain W_2.

        Args:
          Xs, Xt:  unbatched point clouds, shape (A, N, M) and (B, N, M).
          p:       exponent on the ground metric (p=1 classic, p=2, ...).
          a, b:    probability weights of shape (A,) and (B,), array-like; if
                   None, uniform weights are used.
          sinkhorn_reg:
                   - None  -> exact EMD cost via ot.emd2
                   - float -> entropic-regularized cost via ot.sinkhorn2
          numItermax: passed to POT when using Sinkhorn.
          device:   where to place the returned tensor (default: Xs.device).

        Returns:
          0-dim float32 torch tensor holding the transport cost.

        Raises:
          ImportError: if POT (package ``ot``) is not installed.
          ValueError: if the cost matrix is not 2-D (batched inputs).
        """
        if ot is None:
            raise ImportError("POT (package 'ot') is required for MCS.wasserstein. Install with `pip install POT`.")

        device = device or (Xs.device if isinstance(Xs, torch.Tensor) else "cpu")

        # Build the cost matrix with torch, then hand it to POT as numpy.
        C = self.pairwise_cost(Xs, Xt, p=p, reduce=True)  # (A, B)
        if C.ndim != 2:
            raise ValueError(
                f"MCS.wasserstein expects unbatched point clouds; got cost matrix of shape {tuple(C.shape)}"
            )
        # Move to CPU *before* upcasting (MPS has no float64). float64 keeps C, a
        # and b on one dtype and stops Sinkhorn's exp(-C/reg) from underflowing
        # in float32.
        C_np = C.detach().cpu().to(torch.float64).numpy()

        A, B = C_np.shape
        if a is None:
            a = np.ones((A,), dtype=np.float64) / max(A, 1)
        else:
            a = np.asarray(a, dtype=np.float64)
        if b is None:
            b = np.ones((B,), dtype=np.float64) / max(B, 1)
        else:
            b = np.asarray(b, dtype=np.float64)

        if sinkhorn_reg is None:
            val = ot.emd2(a, b, C_np)
        else:
            val = ot.sinkhorn2(a, b, C_np, sinkhorn_reg, numItermax=numItermax)

        # Some POT versions return a size-1 array from sinkhorn2. Collapse it
        # so that the result is always a true scalar.
        return torch.tensor(np.asarray(val).item(), dtype=torch.float32, device=device)
