import math
from typing import Dict, Optional, Tuple
import torch
from utils.MCS import MCS
import torch.nn.functional as F
from torch import nn, optim


# ---------------------------
# Evenly-spread anchors (means) on MCS
# ---------------------------

@torch.no_grad()
def build_evenly_spread_mixture(
    mcs: MCS,
    n_components: int,
    *,
    sigma_per_comp: Optional[torch.Tensor] = None,  # (N,)  tangent std per component
    weights: Optional[torch.Tensor] = None,         # (K,)
    radii_pos: Optional[float] = None,              # geodesic radius for K>0 comps
    radii_euclid: Optional[float] = None,           # Euclidean radius for K=0 comps
    radii_neg: Optional[float] = None,              # geodesic radius for K<0 comps
    seed: Optional[int] = 0,
) -> Dict[str, torch.Tensor]:
    """
    Build mixture parameters whose means are spread quasi-uniformly on the MCS.

    Returns a dict with
      mus:     output of mcs.exp_map_pairwise(0, v) for tangent means v of shape (K, N, M)
               (presumably (K, N, M) points on the manifold -- confirm against MCS)
      sigmas:  (N,)  tangent std, one value per MCS component (default 0.3)
      weights: (K,)  mixture weights (uniform by default, otherwise normalized to sum=1)
    where K = n_components, N = mcs.N and M = mcs.M.

    Strategy:
      For each MCS component i, generate K quasi-uniform directions on S^{M-1} via
      Sobol -> inverse normal CDF -> normalize, choose a fixed radius r_i according to the
      sign of the curvature mcs.K[i], form tangent vectors v[k, i] = r_i * dir_i[perm_i[k]],
      then map them through mcs.exp_map_pairwise with an all-zero base point. Each component
      uses its own random permutation (Latin-hypercube style) so that mixture index k is not
      tied to the same Sobol index in every component.

    Seeding:
      seed=<int>  -> Sobol scrambling and permutations are fully reproducible.
      seed=None   -> nothing is seeded explicitly; the permutations are drawn from torch's
                     global RNG instead of a generator stuck at its fixed default seed.
    """
    device = mcs.device
    dtype = torch.get_default_dtype()
    N, M = mcs.N, mcs.M
    Kmix = n_components

    Kvec = mcs.K.reshape(-1).to(device=device, dtype=dtype)  # (N,) curvatures

    # Curvature sign classification; anything within machine eps of 0 counts as flat.
    tol = torch.finfo(dtype).eps
    is_pos = Kvec > tol
    is_zer = Kvec.abs() <= tol

    # Radii (geodesic where applicable).
    # K>0: sphere diameter = π/√K; default is a safe interior radius 0.6 * (π/√K)
    # K=0: plain Euclid; default radius 1.0
    # K<0: unbounded r; default radius 2.0 → ρ = tanh(r/2) ≈ 0.76 in Poincaré
    if radii_pos is None:
        # The clamp only protects entries with K≈0, which torch.where discards below.
        r_pos = 0.6 * math.pi / torch.sqrt(Kvec.abs()).clamp_min(1e-12)  # (N,)
    else:
        r_pos = torch.as_tensor(radii_pos, device=device, dtype=dtype)
    r_eucl = torch.as_tensor(1.0 if radii_euclid is None else radii_euclid, device=device, dtype=dtype)
    r_neg = torch.as_tensor(2.0 if radii_neg is None else radii_neg, device=device, dtype=dtype)

    # Everything that is neither positive nor (near) zero is treated as negative curvature.
    r_each = torch.where(is_pos, r_pos, torch.where(is_zer, r_eucl, r_neg))  # (N,)

    # Quasi-uniform directions per component (Sobol -> inverse normal -> normalize).
    # Sobol points live in [0, 1); a draw of exactly 0 would give erfinv(-1) = -inf and
    # NaN directions after normalization, so keep u strictly inside (0, 1).
    u_margin = torch.finfo(dtype).eps
    dirs_per_comp = []
    for i in range(N):
        sob = torch.quasirandom.SobolEngine(
            dimension=M,
            scramble=True,
            seed=(None if seed is None else seed + 131 * (i + 1)),
        )
        u = sob.draw(Kmix).to(dtype).clamp(u_margin, 1.0 - u_margin)  # (K, M)
        # Φ^{-1}(u) = √2 * erfinv(2u-1)
        x = math.sqrt(2.0) * torch.erfinv(2.0 * u - 1.0)              # (K, M)
        x = x / x.norm(dim=-1, keepdim=True).clamp_min(1e-12)         # -> S^{M-1}
        dirs_per_comp.append(x.to(device))

    # Permutations per component to decorrelate choices across components.
    # A fresh torch.Generator always starts from the same default seed, so with seed=None
    # we pass generator=None (global RNG) instead of silently returning fixed permutations.
    perm_gen = None
    if seed is not None:
        perm_gen = torch.Generator(device='cpu')
        perm_gen.manual_seed(seed)
    perms = [torch.randperm(Kmix, generator=perm_gen).to(device) for _ in range(N)]

    # Tangent means per mixture index k: v[k, i, :] = r_i * dir_i[perm_i[k]]
    v_tan = torch.zeros(Kmix, N, M, device=device, dtype=dtype)
    for i in range(N):
        v_tan[:, i, :] = r_each[i] * dirs_per_comp[i][perms[i], :]

    # Map by exp at the origin (all-zero base point) to manifold means.
    zeros = torch.zeros_like(v_tan, device=device, dtype=dtype)
    mus = mcs.exp_map_pairwise(zeros, v_tan)

    # sigmas and weights
    if sigma_per_comp is None:
        sigma_per_comp = torch.full((N,), 0.3, device=device, dtype=dtype)  # sensible default
    else:
        sigma_per_comp = sigma_per_comp.to(device=device, dtype=dtype).reshape(N)

    if weights is None:
        weights = torch.full((Kmix,), 1.0 / Kmix, device=device, dtype=dtype)
    else:
        weights = weights.to(device=device, dtype=dtype).reshape(-1)
        weights = weights / weights.sum().clamp_min(1e-12)

    return {"mus": mus, "sigmas": sigma_per_comp, "weights": weights}


# ---------------------------
# Mixture sampling
# ---------------------------

@torch.no_grad()
def mixture_sample(
    mcs: MCS,
    params: Dict[str, torch.Tensor],
    n_samples: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample from the mixture  p(z) = sum_k π_k  WrappedNormal_MCS(μ_k, σ).

    A component index is drawn for every sample from params["weights"]; the
    samples belonging to one component are then produced with a single call to
    mcs.sample_wrap_normal(mus[k], sigmas, batch=count).

    Returns:
      samples:  (B, N, M) tensor allocated on mcs.device with the dtype of params["mus"]
      comp_ix:  (B,)      chosen component index per sample
    where B = int(n_samples) and (K, N, M) = params["mus"].shape.
    """
    means = params["mus"]          # (K, N, M)
    stds = params["sigmas"]        # (N,)
    mix_weights = params["weights"]  # (K,)

    _, n_comp, dim = means.shape
    n_draw = int(n_samples)

    # Categorical draw of the component for every sample.
    comp_ix = torch.multinomial(mix_weights, n_draw, replacement=True)  # (B,)
    samples = torch.empty(n_draw, n_comp, dim, device=mcs.device, dtype=means.dtype)

    # One sampler call per component that was actually selected.
    for k in torch.unique(comp_ix).tolist():
        chosen = comp_ix == k
        how_many = int(chosen.sum().item())
        samples[chosen] = mcs.sample_wrap_normal(means[k], stds, batch=how_many)  # (how_many, N, M)

    return samples, comp_ix


# ---------------------------
# Mixture log-likelihood
# ---------------------------

def mixture_log_prob(
    mcs: MCS,
    params: Dict[str, torch.Tensor],
    z: torch.Tensor,                       # (B, N, M) or (..., N, M)
    return_componentwise: bool = False,   # if True: returns (B, K)
) -> torch.Tensor:
    """
    Computes log p(z) under the mixture:
      log p(z) = logsumexp_k [ log π_k + log q(z | μ_k, σ) ]

    Args:
      mcs:    object providing log_wrapped_normal(z, mu, sigmas, return_per_component=False);
              it is called with (B, K, N, M) inputs and is expected to return (B, K)
              (sum across the N MCS components done internally -- confirm against MCS).
      params: dict with "mus" (K, N, M), "sigmas" (N,), "weights" (K,).
      z:      points to score; all leading dimensions are flattened into one batch axis B.
      return_componentwise: if True, return the per-component terms
              log π_k + log q(z | μ_k, σ) with shape (B, K) instead of reducing over K.

    Returns:
      (B,) log-density, or (B, K) when return_componentwise is True.
    """
    mus = params["mus"]          # (K, N, M)
    sigmas = params["sigmas"]    # (N,)
    weights = params["weights"]  # (K,)

    dtype = z.dtype
    K, N, M = mus.shape

    # reshape (not view) so that non-contiguous inputs, e.g. slices/transposes, are accepted
    z = z.reshape(-1, N, M)                  # (B, N, M)
    B = z.shape[0]

    # broadcast to (B, K, N, M); materialized because the MCS routine receives plain tensors
    z_b = z.unsqueeze(1).expand(B, K, N, M).contiguous()
    mu_b = mus.unsqueeze(0).expand(B, K, N, M).contiguous()

    # per-(B,K) log-likelihood on MCS
    ll_bk = mcs.log_wrapped_normal(z_b, mu_b, sigmas, return_per_component=False)  # (B, K)

    # eps keeps zero-weight components finite (log(eps) instead of -inf)
    logw = torch.log(weights + torch.finfo(dtype).eps).reshape(1, K)
    logits = logw + ll_bk                                 # (B, K)

    if return_componentwise:
        return logits  # caller can logsumexp over K

    # Numerically stable reduction over mixture components; unlike a manual
    # max-shift this stays -inf (not NaN) when every component has -inf log-likelihood.
    return torch.logsumexp(logits, dim=1)  # (B,)


# =========================
# Learn evenly-spread mus on MCS by maximizing pairwise angular separation
# =========================
def _cosine_spread_loss(P_unit: torch.Tensor) -> torch.Tensor:
    """
    HPN-style spread objective for K row vectors P_unit (K x D), assumed unit-norm.

    Builds the Gram matrix G = P P^T, flips the sign of its diagonal and shifts
    every entry by +1, so off-diagonal entries become cos + 1 while the diagonal
    becomes 1 - G_ii (0 for unit rows). The loss is the mean over rows of the
    row-wise maximum, i.e. the average "largest similarity to another row"
    (shifted by one).
    """
    gram = P_unit @ P_unit.t()                      # (K, K), diagonal == 1 for unit rows
    self_sim = torch.diag(torch.diagonal(gram))     # diagonal part only
    shifted = gram - 2.0 * self_sim + 1.0           # diag -> 1 - G_ii, off-diag -> cos + 1
    worst_per_row, _ = shifted.max(dim=1)
    return worst_per_row.mean()

@torch.no_grad()
def _reshape_KxNM_to_KxNxM(mcs, X):
    """View X as (-1, mcs.N, mcs.M); the leading size is inferred from X's element count."""
    per_comp_shape = (mcs.N, mcs.M)
    return X.view(-1, *per_comp_shape)

def _normalize_flat(P):
    """Scale every vector along the last axis to unit L2 norm (norms floored at 1e-12)."""
    # same computation F.normalize performs with its default eps
    row_norm = P.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12)
    return P / row_norm.expand_as(P)

def _equiangular_circle(K: int, device, dtype):
    """Return K points on the unit circle at angles 0, 2π/K, ..., 2π(K-1)/K as a (K, 2) tensor."""
    step = 2 * math.pi / K
    angles = torch.arange(K, device=device, dtype=dtype) * step
    xy = (angles.cos(), angles.sin())
    return torch.stack(xy, dim=1)  # (K, 2): column 0 = cos, column 1 = sin

def train_prototype_directions(
    mcs,
    n_components: int,
    *,
    steps: int = 2000,
    lr: float = 0.1,
    momentum: float = 0.9,
    seed: int = 0,
    verbose: bool = False,
):
    """
    Learn unit directions by minimizing the maximum pairwise cosine similarity
    (HPN-style), using SGD with momentum and re-projection onto the sphere.

    What is optimized: a single set of n_components * mcs.N vectors in R^{mcs.M},
    i.e. points on S^{M-1}. All of them repel each other jointly; they are only
    split into (n_components, N, M) afterwards.
    NOTE(review): an earlier description of this function spoke of n_components
    directions on S^{N*M-1}; the code below does not do that -- confirm which
    variant is intended.

    Special case: when mcs.M == 2 no optimization is run; the vectors are placed
    at equal angles on the unit circle.

    Side effect: calls torch.manual_seed(seed), which reseeds torch's global RNG.

    Args:
      mcs:          object exposing .device, .N and .M.
      n_components: number of mixture components.
      steps, lr, momentum: SGD settings.
      seed:         seed for the random initialization.
      verbose:      print loss and the largest off-diagonal cosine about 10 times.

    Returns:
      directions: (n_components, N, M) view of the learned vectors; every (k, i) row is unit-norm.
      flat_unit:  (n_components * N, M) the same unit vectors before reshaping.
    """
    torch.manual_seed(seed)
    device = mcs.device
    dtype = torch.get_default_dtype()
    K, D = int(n_components * mcs.N), int(mcs.M)

    if D == 2:
        # closed-form optimum on the circle
        flat = _equiangular_circle(K, device, dtype)
        flat = _normalize_flat(flat)
        return _reshape_KxNM_to_KxNxM(mcs, flat), flat

    # initialize and optimize
    flat = nn.Parameter(torch.randn(K, D, device=device, dtype=dtype))
    opt = optim.SGD([flat], lr=lr, momentum=momentum)
    report_every = max(1, steps // 10)

    for t in range(steps):
        P_unit = _normalize_flat(flat)
        loss = _cosine_spread_loss(P_unit)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        # re-normalize in-place (keep optimizer state)
        with torch.no_grad():
            flat.copy_(_normalize_flat(flat))
        if verbose and ((t + 1) % report_every == 0):
            # Report the true largest off-diagonal cosine of the vectors the loss was
            # evaluated on: flipping the sign of the diagonal (1 -> -1) removes it from
            # the max without shifting the off-diagonal values.
            with torch.no_grad():
                S = P_unit @ P_unit.t()
                S = S - 2.0 * torch.diag(torch.diag(S))
                mx = S.max().item()
            print(f"[{t+1:5d}/{steps}] loss={loss.item():.6f} max_cos={mx:.6f}")

    flat_unit = _normalize_flat(flat.detach())
    directions = _reshape_KxNM_to_KxNxM(mcs, flat_unit)
    return directions, flat_unit

@torch.no_grad()
def directions_to_mus(
    mcs,
    directions: torch.Tensor,             # (K, N, M), not necessarily per-comp unit
    radius_per_comp: torch.Tensor | float # (N,) or scalar; geodesic radius in each component
):
    """
    Map per-component directions to manifold means at a specified *geodesic* radius.
    For each mixture k and component i:
        v_{k,i} = r_i * dir_{k,i} / ||dir_{k,i}||   (tangent at origin)
        mu_{k,i} = exp_0^K( v_{k,i} )

    Returns:
      mus: (K, N, M) points on the MCS manifold.
    """
    device = mcs.device
    dtype  = directions.dtype
    K, N, M = directions.shape
    dirs = directions.clone()

    # ensure per-component unit directions (avoid degenerate near-zero blocks)
    eps = torch.finfo(dtype).eps
    norms = dirs.norm(dim=-1, keepdim=True).clamp_min(eps)  # (K,N,1)
    dirs = dirs / norms

    # broadcast radii per component
    if not torch.is_tensor(radius_per_comp):
        radius_per_comp = torch.as_tensor(radius_per_comp, device=device, dtype=dtype)
        radius_per_comp = radius_per_comp.expand(N)
    r = radius_per_comp.to(device=device, dtype=dtype).view(1, N, 1)  # (1,N,1)

    v_tan = r * dirs  # (K,N,M) tangent vectors at origin with geodesic length r_i in comp i
    zeros = torch.zeros_like(v_tan, device=device, dtype=dtype)
    mus = mcs.exp_map_pairwise(zeros, v_tan)  # (K,N,M)
    return mus

def learn_mus_by_spread(
    mcs,
    n_components: int,
    *,
    radius_per_comp: torch.Tensor | float = 1.0,
    steps: int = 2000,
    lr: float = 0.1,
    momentum: float = 0.9,
    seed: int = 0,
    verbose: bool = False,
):
    """
    Convenience wrapper: run train_prototype_directions, then hand the learned
    directions to directions_to_mus with the requested radii.

    Returns a dict with keys
      "mus":        output of directions_to_mus
      "directions": first output of train_prototype_directions
      "flat_unit":  second output of train_prototype_directions
    """
    training_options = {
        "steps": steps,
        "lr": lr,
        "momentum": momentum,
        "seed": seed,
        "verbose": verbose,
    }
    directions, flat_unit = train_prototype_directions(mcs, n_components, **training_options)
    placed = directions_to_mus(mcs, directions, radius_per_comp)
    return {"mus": placed, "directions": directions, "flat_unit": flat_unit}


# Heuristics for choosing radius_per_comp and sigma
def _sphere_surface_area(n: int) -> float:
    """
    Area of the unit sphere S^n embedded in R^{n+1}:
        |S^n| = 2 * pi^{(n+1)/2} / Gamma((n+1)/2)
    e.g. n=1 gives the circle length 2*pi, n=2 gives 4*pi.
    """
    half_ambient_dim = (n + 1) / 2.0
    numerator = 2.0 * math.pi ** half_ambient_dim
    return numerator / math.gamma(half_ambient_dim)

def _pred_sep_angle(M: int, n_mix: int) -> float:
    """
    Estimate the angle θ associated with n_mix quasi-uniform points on S^{M-1}
    from the small-angle spherical-cap relation
        n_mix * (|S^{M-2}| * θ^{M-1} / (M-1)) ≈ |S^{M-1}|,
    solved for θ. Returns π when M <= 1 or n_mix <= 1 (degenerate cases).
    """
    if M > 1 and n_mix > 1:
        sphere_area = _sphere_surface_area(M - 1)   # |S^{M-1}|
        equator_area = _sphere_surface_area(M - 2)  # |S^{M-2}|
        ratio = (M - 1) * sphere_area / (n_mix * equator_area)
        return float(ratio ** (1.0 / (M - 1)))
    return math.pi  # degenerate cases

def _nn_distance_given_r_theta(Ki: float, r: float, theta: float) -> float:
    """
    Geodesic distance between two points that both lie at radial distance r from
    the origin and whose directions enclose the angle theta, in a space of
    constant curvature Ki.

      |Ki| <= 1e-12 : Euclidean chord length 2 r sin(theta/2)
      Ki > 0        : spherical law of cosines,  cos(√K d)  = cos² + sin² cos(theta)
      Ki < 0        : hyperbolic law of cosines, cosh(√|K| d) = cosh² - sinh² cos(theta)
    with the trigonometric / hyperbolic functions evaluated at √|K| r.
    """
    if abs(Ki) <= 1e-12:  # flat
        return 2.0 * r * math.sin(0.5 * theta)

    root_k = math.sqrt(abs(Ki))  # 1 / radius of curvature
    scaled_r = root_k * r
    cos_theta = math.cos(theta)

    if Ki > 0:
        c, s = math.cos(scaled_r), math.sin(scaled_r)
        cos_d = c * c + (s * s) * cos_theta
        # keep acos inside its domain despite rounding
        cos_d = max(-1.0, min(1.0, cos_d))
        return math.acos(cos_d) / root_k

    ch, sh = math.cosh(scaled_r), math.sinh(scaled_r)
    cosh_d = ch * ch - (sh * sh) * cos_theta
    # keep acosh inside its domain despite rounding
    cosh_d = max(1.0, cosh_d)
    return math.acosh(cosh_d) / root_k

@torch.no_grad()
def suggest_radius_sigma(
    mcs,
    n_mix: int,
    *,
    d_target: float = 1.0,         # target NN gap used to set initial r (Euclidean small-angle), then clamped
    overlap_frac: float = 0.5,     # how much of the NN gap we allow “covered” by the component
    n_sigma: float = 2.0,          # how many sigmas should cover that fraction (→ σ = overlap_frac * d_nn / (2 n_sigma))
    sphere_margin: float = 0.10,   # keep r safely inside π/√K on spheres: r <= (1 - margin)*π/√K
    hyp_max_rho: float = 0.90,     # keep hyperbolic centers inside Poincaré radius ρ ≤ hyp_max_rho
) -> dict:
    """
    Heuristic recommendation of center radii and tangent stds for an evenly spread
    mixture of n_mix components on an MCS.

    Inputs:
      mcs:   MCS instance (fields read here: N, M, K, device)
      n_mix: number of mixture components (evenly spread)

    Procedure:
      1. predict the angle θ between neighbouring directions (_pred_sep_angle),
      2. start from the Euclidean relation d_target ≈ 2 r sin(θ/2) to get r,
      3. cap r per component: (1 - sphere_margin)·π/√K for K>0 and
         (2/√|K|)·atanh(hyp_max_rho) for K<0 (flat components are not capped),
      4. compute the predicted nearest-neighbour gap with the exact
         constant-curvature cosine laws,
      5. set σ = overlap_frac · d_nn / (2 n_sigma) and shrink it on curved components
         so that r + 6σ stays below the respective cap (floored at 1e-6).

    Returns:
      dict with:
        r_per_comp:     (N,) tensor of recommended geodesic radii for centers
        sigma_per_comp: (N,) tensor of recommended tangent-space stds
        dnn_per_comp:   (N,) predicted NN geodesic gap at those radii
        theta_sep:      scalar angle (float) used
        caps:           dict of per-comp caps used (r_sph_cap, r_hyp_cap)
    """
    device = mcs.device
    dtype = torch.get_default_dtype()
    N, M = mcs.N, mcs.M
    curv = mcs.K.reshape(-1).to(device=device, dtype=dtype)

    # 1) predicted separation angle on S^{M-1} (radians)
    theta = _pred_sep_angle(M, n_mix)

    # 2) Euclidean first guess r ≈ d_target / (2 sin(θ/2)); denominator floored at 1e-8
    r_init = d_target / max(1e-8, 2.0 * math.sin(0.5 * theta))

    # 3) curvature classes and the corresponding caps
    tol = torch.finfo(dtype).eps
    spherical = curv > tol
    hyperbolic = curv < -tol

    sph_cap = torch.zeros(N, device=device, dtype=dtype)  # stays 0 where K <= 0
    if spherical.any():
        sph_cap[spherical] = (1.0 - sphere_margin) * (math.pi / torch.sqrt(curv[spherical]))

    hyp_cap = torch.full((N,), float('inf'), device=device, dtype=dtype)  # stays inf where K >= 0
    if hyperbolic.any():
        # ρ = tanh(√|K| r / 2) <= hyp_max_rho   <=>   r <= (2/√|K|) * atanh(hyp_max_rho)
        rho_limit = torch.tensor(hyp_max_rho, device=device, dtype=dtype)
        inv_half_scale = 2.0 / torch.sqrt((-curv[hyperbolic]).clamp_min(1e-12))
        hyp_cap[hyperbolic] = inv_half_scale * torch.atanh(rho_limit)

    # 4) radii: the Euclidean guess, clipped on curved components
    radii = torch.full((N,), r_init, device=device, dtype=dtype)
    radii[spherical] = torch.minimum(radii[spherical], sph_cap[spherical])
    radii[hyperbolic] = torch.minimum(radii[hyperbolic], hyp_cap[hyperbolic])

    # 5) predicted NN gap per component from the exact constant-curvature law
    gaps = [
        _nn_distance_given_r_theta(float(curv[i].item()), float(radii[i].item()), theta)
        for i in range(N)
    ]
    dnn = torch.tensor(gaps, device=device, dtype=dtype)

    # 6) sigma such that 2 * n_sigma * σ ≈ overlap_frac * d_nn
    sigma = (overlap_frac * dnn) / (2.0 * max(n_sigma, 1e-6))

    # spherical safety: r_i + 3 * (2σ) must stay below (1 - margin) * π/√K
    if spherical.any():
        geo_limit = math.pi / torch.sqrt(curv[spherical])
        room = (1.0 - sphere_margin) * geo_limit - radii[spherical]
        sigma_ceiling = torch.clamp(room / 6.0, min=1e-6)  # 3*(2σ) = 6σ
        sigma[spherical] = torch.minimum(sigma[spherical], sigma_ceiling)

    # hyperbolic safety: r_i + 3 * (2σ) must stay below the hyperbolic cap
    if hyperbolic.any():
        room = hyp_cap[hyperbolic] - radii[hyperbolic]
        sigma_ceiling = torch.clamp(room / 6.0, min=1e-6)
        sigma[hyperbolic] = torch.minimum(sigma[hyperbolic], sigma_ceiling)

    recommendation = {
        "r_per_comp": radii,            # (N,) geodesic radii
        "sigma_per_comp": sigma,        # (N,) tangent stds for log_wrapped_normal
        "dnn_per_comp": dnn,            # (N,) predicted nearest-neighbor gaps
        "theta_sep": theta,
        "caps": {"r_sph_cap": sph_cap, "r_hyp_cap": hyp_cap},
    }
    return recommendation
