import os, json
from pathlib import Path
import math
import numpy as np
from color_rules import step_size

# Degeneracy thresholds checked in _order_clockwise_in_basis: an ordering is
# rejected as degenerate when any measured quantity falls below its threshold.
TAU_ISO = 0.015   # isoperimetric quotient 4*pi*A / P**2 of the ordered polygon (_poly_metrics)
TAU_LAM2 = 0.05   # second principal component's share of the top-two variance (_pca2)
TAU_FLOW = 0.20   # |mean sign of successive angular steps| (_angle_flow_strength)


def _fit_link_basis(XY: np.ndarray):
    """Fit a 2-D PCA reference ("link") basis to the points in XY.

    Args:
        XY: point coordinates, coerced to a float array.

    Returns:
        dict(V_link=(D, 2) basis from ``_pca2``, lam2_share=float). If ``_pca2``
        raises, falls back to a basis whose first column is all ones and whose
        second column is zero, with lam2_share 0.0.
    """
    # Convert before the try: the fallback inspects XY.shape / XY.ndim, which a
    # plain list does not have.
    XY = np.asarray(XY, float)
    try:
        _, _, V2, lam2 = _pca2(XY)
    except Exception:
        # Deliberate best-effort fallback so callers always get a basis.
        d = XY.shape[1] if XY.ndim == 2 else 0
        V2 = np.zeros((d, 2))
        if d >= 1:
            V2[:, 0] = 1.0
        lam2 = 0.0
    return dict(V_link=V2, lam2_share=float(lam2))

def _fit_joint_2d(P: np.ndarray, Q: np.ndarray, V_link: np.ndarray | None = None):
    """Fit one 2-D PCA basis to the stacked point sets P and Q and fix its signs.

    The first axis is oriented so that the displacement from P's centroid to
    Q's centroid has a non-negative component along it. If ``V_link`` (a
    reference basis with two columns) is given, the basis is then made to agree
    with it: det(V_joint.T @ V_link) is made non-negative and the first axes are
    made to point the same way. Every sign flip is applied to the basis and to
    both projected coordinate arrays so they stay consistent.

    Returns:
        dict with P_w / Q_w (whitened projections of P and Q), P_r / Q_r (raw
        projections), V_joint (the (D, 2) basis) and lam2_share_joint.
    """
    P = np.asarray(P, float); Q = np.asarray(Q, float)
    X = np.vstack([P, Q])
    Yw, Yraw, V2, lam2_share_joint = _pca2(X)

    def _flip(*cols):
        # Negate the given basis columns together with the matching coordinates.
        for c in cols:
            V2[:, c] *= -1.0
            Yw[:, c] *= -1.0
            Yraw[:, c] *= -1.0

    if P.size and Q.size and V2.size:
        # The centroid difference is a displacement, so it is projected as-is.
        # Subtracting X's mean (a position) made the result depend on where the
        # data are located rather than on how P and Q are arranged.
        a2 = (Q.mean(axis=0) - P.mean(axis=0)) @ V2
        if a2[0] < 0 or (abs(a2[0]) < 1e-12 and a2[1] < 0):
            _flip(0)
    if V_link is not None and V_link.size:
        # Negating BOTH columns of a 2-column basis leaves det() unchanged, so
        # handedness has to be corrected by negating a single column.
        if np.linalg.det(V2.T @ V_link) < 0:
            _flip(1)
        # A half-turn (negate both columns) aligns the first axes while keeping
        # the handedness established above.
        if (V2[:, 0] @ V_link[:, 0]) < 0:
            _flip(0, 1)
    nP = P.shape[0]
    return dict(
        P_w=Yw[:nP], Q_w=Yw[nP:],
        P_r=Yraw[:nP], Q_r=Yraw[nP:],
        V_joint=V2,
        lam2_share_joint=float(lam2_share_joint),
    )

def _signed_area(XY: np.ndarray) -> float:
    """Shoelace signed area of the closed polygon whose vertices are the rows of XY.

    Positive for counter-clockwise vertex order and negative for clockwise order
    (x to the right, y up). Only the first two columns are used.
    """
    xs = XY[:, 0]
    ys = XY[:, 1]
    cross_terms = xs * np.roll(ys, -1) - np.roll(xs, -1) * ys
    return 0.5 * np.sum(cross_terms)

def _direction_sign_from_area(XY_ordered: np.ndarray) -> int:
    """Orientation sign of the closed polygon through the rows of XY_ordered.

    Returns:
        +1 for positive signed area (counter-clockwise), -1 for negative signed
        area (clockwise), and 0 when there is no orientation: fewer than 3
        vertices, or a zero or NaN signed area.
    """
    if XY_ordered.shape[0] < 3:
        return 0
    area = _signed_area(XY_ordered)
    if area > 0:
        return +1
    if area < 0:
        return -1
    # Collinear (zero) or NaN area has no orientation; do not report it as clockwise.
    return 0

def _angle_flow_strength(XY_ordered: np.ndarray) -> float:
    """Mean sign of the angular steps of a closed, ordered point sequence.

    Angles are measured around the centroid of XY_ordered, and the sequence is
    treated as closed (the last point steps back to the first).

    Returns:
        A value in [-1, 1]: +1 when every step turns counter-clockwise, -1 when
        every step turns clockwise, near 0 with no consistent rotation. Fewer
        than 3 points yields 0.0.
    """
    if XY_ordered.shape[0] < 3:
        return 0.0
    c = XY_ordered.mean(axis=0)
    th = np.arctan2(XY_ordered[:, 1] - c[1], XY_ordered[:, 0] - c[0])
    steps = np.diff(np.r_[th, th[:1]])
    # Wrap each step into [-pi, pi) on its own. np.unwrap applied to the step
    # sequence never corrects the first step and propagates corrections
    # cumulatively, so a first step across the arctan2 branch cut flipped
    # every sign.
    steps = (steps + np.pi) % (2.0 * np.pi) - np.pi
    return float(np.mean(np.sign(steps)))

def _order_plus_d(resid: np.ndarray, g: int, d: int) -> np.ndarray:
    """Indices that sort ``resid`` along the cyclic walk r0, r0+d, r0+2d, ... (mod g).

    r0 is the smallest residue present (after reduction mod g). The sort is
    stable, so equal residues keep their input order. Residues that the walk
    never reaches (possible when gcd(d, g) > 1) are placed after all reachable
    ones. Empty input gives an empty index array.
    """
    r = np.asarray(resid, int) % g
    if r.size == 0:
        return np.zeros(0, dtype=np.intp)
    step = int(d % g)
    start = int(np.min(r))
    # rank[v] = first position of residue v on the walk; unreached residues keep rank g.
    rank = np.full(g, g, dtype=int)
    for t in range(g):
        v = (start + t * step) % g
        if rank[v] == g:
            rank[v] = t
    return np.argsort(rank[r], kind="stable")

def _order_plus_d_with_start(resid: np.ndarray, g: int, d: int, step_sign: int, start_resid: int) -> np.ndarray:
    """Indices that sort ``resid`` along a cyclic walk from ``start_resid`` (mod g).

    The walk steps by +d when ``step_sign >= 0`` and by -d otherwise. The sort
    is stable. Residues the walk never reaches (possible when gcd(d, g) > 1)
    are placed after all reachable ones. Empty input gives an empty index array.
    """
    r = np.asarray(resid, int) % g
    d_mod = int(d % g)
    step = d_mod if step_sign >= 0 else (g - d_mod) % g
    start = int(start_resid)
    # rank[v] = first position of residue v on the walk; unreached residues keep rank g.
    rank = np.full(g, g, dtype=int)
    for t in range(g):
        v = (start + t * step) % g
        if rank[v] == g:
            rank[v] = t
    return np.argsort(rank[r], kind="stable")

def _roll_oi_to_start(resid: np.ndarray, oi: np.ndarray, g: int, start_resid: int):
    """Rotate the ordering ``oi`` so it begins at residue ``start_resid`` (mod g).

    Returns:
        (rolled_oi, True) when some element of ``resid[oi]`` equals start_resid
        mod g (the first such position becomes position 0), otherwise the
        untouched ``oi`` and False.
    """
    ordered = (np.asarray(resid, int) % g)[oi]
    matches = np.flatnonzero(ordered == (start_resid % g))
    if matches.size == 0:
        return oi, False
    return np.roll(oi, -int(matches[0])), True

def _poly_metrics(XY_ordered: np.ndarray):
    """Shape metrics of the closed polygon through the rows of XY_ordered.

    Returns:
        dict(area_abs=|shoelace area|, perim=closed perimeter,
        iso=4*pi*area / perim**2, or 0.0 when the perimeter is 0). Fewer than
        3 vertices yields all zeros.
    """
    if XY_ordered.shape[0] < 3:
        return dict(area_abs=0.0, perim=0.0, iso=0.0)
    area = abs(_signed_area(XY_ordered))
    edges = np.diff(np.vstack([XY_ordered, XY_ordered[:1]]), axis=0)
    perimeter = float(np.sum(np.linalg.norm(edges, axis=1)))
    iso = 0.0
    if perimeter > 0:
        iso = (4.0 * math.pi * area) / (perimeter * perimeter)
    return dict(area_abs=float(area), perim=perimeter, iso=float(iso))

def _pca_k(X, k=3):
    """Project X onto its top-k principal axes.

    X is coerced to a float 2-D array (N, D); any other rank is reshaped to
    (-1, X.shape[-1]).

    Returns:
        (Y, V, shares): V is (D, k) with principal directions as columns
        (zero-padded when fewer than k exist), Y = (X - mean) @ V is (N, k),
        and shares[i] = S[i]**2 / sum of the retained S**2. These are shares
        among the retained components, not of the total variance; padded
        entries are 0.0. Empty input gives all-zero outputs.
    """
    X = np.asarray(X, float)
    if X.ndim != 2:
        X = X.reshape(-1, X.shape[-1])
    N, D = X.shape
    if N == 0 or D == 0:
        # Return before centring: the mean of an empty array emits a RuntimeWarning.
        return np.zeros((N, k)), np.zeros((D, k)), [0.0] * k
    Xc = X - X.mean(axis=0, keepdims=True)
    _, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    r = min(k, S.size)
    V = Vt[:r].T
    if r < k:
        V = np.hstack([V, np.zeros((D, k - r))])
    Y = Xc @ V
    lam = S[:r] ** 2
    total = float(lam.sum()) if lam.size else 0.0
    shares = [float(v) / total if total > 0 else 0.0 for v in lam]
    shares += [0.0] * (k - r)
    return Y, V, shares[:k]

def _pca2(X):
    """Project X onto its top-2 principal axes.

    X is coerced to a float 2-D array (N, D); any other rank is reshaped to
    (-1, X.shape[-1]).

    Returns:
        (Y_whiten, Y_raw, V2, lam2_share):
          * V2 is (D, 2) with the principal directions as columns (a zero second
            column when only one exists).
          * Y_raw = (X - mean) @ V2 is (N, 2).
          * Y_whiten is Y_raw with each column divided by its singular value
            (columns whose singular value is exactly 0 are left unscaled).
          * lam2_share = S[1]**2 / (S[0]**2 + S[1]**2), with the denominator
            floored at 1e-12.

        When N < 2 or D < 2 the result is degenerate: V2 has ones in its first
        column and zeros in its second, Y_raw holds the centred first coordinate
        in column 0, Y_whiten equals Y_raw, and lam2_share is 0.0.
    """
    X = np.asarray(X, float)
    if X.ndim != 2:
        X = X.reshape(-1, X.shape[-1])
    N, D = X.shape
    # Centre only when there is data: the mean of an empty array emits a
    # RuntimeWarning ("Mean of empty slice") and produces NaNs.
    Xc = X - X.mean(axis=0, keepdims=True) if N > 0 else X.copy()
    if N < 2 or D < 2:
        V2 = np.zeros((D, 2), dtype=float)
        if D >= 1:
            V2[:, 0] = 1.0
        Y_raw = np.zeros((N, 2), dtype=float)
        if D >= 1:
            Y_raw[:, 0] = Xc[:, 0]
        Y_whiten = Y_raw.copy()
        return Y_whiten, Y_raw, V2, 0.0
    _, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    k = min(2, S.size)
    V = Vt[:k].T
    if k < 2:
        V2 = np.hstack([V, np.zeros((D, 1))])
    else:
        V2 = V[:, :2]
    Y_raw = Xc @ V2
    lam = S[:k] ** 2
    lam2_share = float((lam[1] if lam.size >= 2 else 0.0) / max(1e-12, lam.sum()))
    S2 = np.ones(2, dtype=float)
    S2[:k] = S[:k].copy()
    # Avoid division by zero for a missing/zero component.
    S2[S2 == 0] = 1.0
    Y_whiten = Y_raw / S2
    return Y_whiten, Y_raw, V2, lam2_share

def _fit_local_2d(X: np.ndarray):
    """Fit a 2-D PCA basis to one point set alone via ``_pca2``.

    Returns:
        dict(Y_w=whitened projection, Y_r=raw projection, V=(D, 2) basis,
        lam2_share=float).
    """
    whitened, raw, basis, share = _pca2(X)
    return dict(Y_w=whitened, Y_r=raw, V=basis, lam2_share=float(share))

def _transport_step_sign(step_sign_local: int, V_local2: np.ndarray, V_joint2: np.ndarray) -> int:
    """Re-express a traversal sign found in a local basis relative to the joint basis.

    The sign is multiplied by sign(det(V_local2.T @ V_joint2)), so it flips
    when the projected bases have opposite orientation. A zero determinant
    counts as +1. If either basis is empty, the sign is returned unchanged.
    """
    if not (V_local2.size and V_joint2.size):
        return int(step_sign_local)
    handedness = np.sign(np.linalg.det(V_local2.T @ V_joint2))
    if handedness == 0:
        handedness = +1
    return int(step_sign_local * handedness)

def _order_clockwise_in_basis(resid: np.ndarray, Yw: np.ndarray, Yraw: np.ndarray,
                              g: int, d: int, lam2_share_for_this_basis: float):
    """Choose which cyclic traversal of ``resid`` (step +d or -d mod g) is clockwise.

    Both traversals are built with ``_order_plus_d``. Orientation sign and
    angular flow are measured on the whitened coordinates ``Yw``, and the
    isoperimetric quotient on the raw coordinates ``Yraw`` (both indexed by the
    same orderings as ``resid``).

    Returns:
        dict with keys oi, step_sign, s_area, strength, degenerate, deg_reason,
        tie, iso, lam2_share. Outcomes:
          * empty input -> degenerate=True, deg_reason="empty";
          * lam2 share < TAU_LAM2, best iso < TAU_ISO or best |flow| < TAU_FLOW
            (checked in that order) -> degenerate=True with deg_reason
            "lam2_small" / "iso_small" / "weak_flow", plus diagnostic keys
            flow, tau_iso, tau_flow, tau_lam2;
          * exactly one traversal gets orientation sign -1 from
            ``_direction_sign_from_area`` -> that ordering in oi, step_sign +1
            (for +d) or -1 (for -d), s_area=-1, strength = its flow;
          * otherwise tie=True.
    """
    result = dict(oi=None, step_sign=+1, s_area=0, strength=0.0,
                  degenerate=False, deg_reason=None, tie=False,
                  iso=None, lam2_share=float(lam2_share_for_this_basis))
    if Yw.shape[0] == 0 or resid.size == 0:
        result.update(degenerate=True, deg_reason="empty")
        return result

    fwd_step = int(d % g)
    bwd_step = (g - fwd_step) % g
    fwd_order = _order_plus_d(resid, g, fwd_step)
    bwd_order = _order_plus_d(resid, g, bwd_step)

    fwd_w = Yw[fwd_order]
    bwd_w = Yw[bwd_order]
    fwd_sign = _direction_sign_from_area(fwd_w)
    bwd_sign = _direction_sign_from_area(bwd_w)
    fwd_flow = _angle_flow_strength(fwd_w)
    bwd_flow = _angle_flow_strength(bwd_w)
    best_iso = max(_poly_metrics(Yraw[fwd_order])["iso"],
                   _poly_metrics(Yraw[bwd_order])["iso"])
    best_flow = max(abs(fwd_flow), abs(bwd_flow))

    lam2 = lam2_share_for_this_basis
    if lam2 < TAU_LAM2:
        reason = "lam2_small"
    elif best_iso < TAU_ISO:
        reason = "iso_small"
    elif best_flow < TAU_FLOW:
        reason = "weak_flow"
    else:
        reason = None
    if reason is not None:
        result.update(degenerate=True, deg_reason=reason,
                      iso=float(best_iso), flow=float(best_flow),
                      lam2_share=float(lam2),
                      tau_iso=TAU_ISO, tau_flow=TAU_FLOW, tau_lam2=TAU_LAM2)
        return result

    fwd_cw = fwd_sign == -1
    bwd_cw = bwd_sign == -1
    if fwd_cw and not bwd_cw:
        result.update(oi=fwd_order, step_sign=+1, s_area=-1,
                      strength=float(fwd_flow), iso=float(best_iso))
    elif bwd_cw and not fwd_cw:
        result.update(oi=bwd_order, step_sign=-1, s_area=-1,
                      strength=float(bwd_flow), iso=float(best_iso))
    else:
        result.update(tie=True, iso=float(best_iso))
    return result

def _farthest_chord_features(P_ord: np.ndarray, Q_ord: np.ndarray):
    """Length and direction of the longest chord between a P point and a Q point.

    Returns:
        dict(far_ok, far_dist, far_angle). far_angle is arctan2 of the chord
        from P[i*] to Q[j*] (first two coordinates). Ties resolve to the first
        pair in row-major order. If either set is empty, far_ok is False and the
        other fields are None.
    """
    if P_ord.size == 0 or Q_ord.size == 0:
        return dict(far_ok=False, far_dist=None, far_angle=None)
    sq_dist = ((P_ord[:, None, :] - Q_ord[None, :, :]) ** 2).sum(axis=2)
    i_star, j_star = np.unravel_index(sq_dist.argmax(), sq_dist.shape)
    chord = Q_ord[j_star] - P_ord[i_star]
    return dict(
        far_ok=True,
        far_dist=float(np.sqrt(sq_dist[i_star, j_star])),
        far_angle=float(np.arctan2(chord[1], chord[0])),
    )

def _farthest_chord_indices(P_ord: np.ndarray, Q_ord: np.ndarray):
    """Locate the farthest-apart pair (P_ord[i], Q_ord[j]).

    Returns:
        (i, j, distance), with ties resolved to the first pair in row-major
        order, or (None, None, None) if either set is empty.
    """
    if P_ord.size == 0 or Q_ord.size == 0:
        return None, None, None
    sq_dist = np.sum((P_ord[:, None, :] - Q_ord[None, :, :]) ** 2, axis=2)
    i_star, j_star = np.unravel_index(sq_dist.argmax(), sq_dist.shape)
    return int(i_star), int(j_star), float(np.sqrt(sq_dist[i_star, j_star]))

def _anchor_shift_from_farthest_pair(P_ord, Q_ord, I_resid, J_resid, g, rng=None):
    """Cyclic shift (j* - i*) mod m implied by the farthest-apart pair P_ord[i*], Q_ord[j*].

    m = min(len(P_ord), len(Q_ord)). When several pairs share the maximal
    squared distance, one is picked with ``rng.integers`` (a fresh, unseeded
    ``np.random.default_rng()`` if rng is None). Returns None if either set is
    empty or the chosen pair has an index >= m. ``I_resid``, ``J_resid`` and
    ``g`` are accepted but unused.
    """
    if P_ord.shape[0] == 0 or Q_ord.shape[0] == 0:
        return None
    gen = rng if rng is not None else np.random.default_rng()
    sq_dist = ((P_ord[:, None, :] - Q_ord[None, :, :]) ** 2).sum(axis=2)
    rows, cols = np.nonzero(sq_dist == sq_dist.max())
    choice = int(gen.integers(0, len(rows)))
    i_star = int(rows[choice])
    j_star = int(cols[choice])
    m = min(P_ord.shape[0], Q_ord.shape[0])
    if i_star >= m or j_star >= m:
        return None
    return int((j_star - i_star) % m)

def _constant_k_consistency(K: np.ndarray) -> tuple[int, float]:
    """Most frequent value in K and the fraction of entries equal to it.

    Ties go to the smallest value, because np.unique returns sorted values
    and argmax picks the first maximum.
    """
    values, counts = np.unique(K, return_counts=True)
    top = counts.argmax()
    share = counts[top] / K.size
    return int(values[top]), float(share)

def _direction_and_basis_for_family(
    resid, XY_family, g, d,
    V_joint, joint_piece_w, joint_piece_r,
    lam2_joint
):
    """Pick a clockwise ordering for one family, preferring the joint basis.

    The joint-basis projection is tried first. If it is degenerate or tied, a
    local 2-D basis fitted to ``XY_family`` alone is tried. The returned dict is
    the ``_order_clockwise_in_basis`` result plus ``basis`` ("joint" or
    "local"). When the local basis succeeds, its step_sign is re-expressed
    relative to ``V_joint`` with ``_transport_step_sign``.
    """
    joint_attempt = _order_clockwise_in_basis(
        resid, joint_piece_w, joint_piece_r, g, d,
        lam2_share_for_this_basis=lam2_joint,
    )
    if not (joint_attempt["degenerate"] or joint_attempt["tie"]):
        joint_attempt["basis"] = "joint"
        return joint_attempt

    local_fit = _fit_local_2d(XY_family)
    local_attempt = _order_clockwise_in_basis(
        resid, local_fit["Y_w"], local_fit["Y_r"], g, d,
        lam2_share_for_this_basis=local_fit["lam2_share"],
    )
    if local_attempt["degenerate"] or local_attempt["tie"]:
        local_attempt["basis"] = "local"
        return local_attempt

    # NOTE(review): only step_sign is transported to the joint frame; "oi" stays
    # the local-basis clockwise ordering even when the sign flips - confirm intended.
    transported = _transport_step_sign(local_attempt["step_sign"], local_fit["V"], V_joint)
    result = dict(local_attempt)
    result.update(step_sign=int(transported), basis="local")
    return result

def plus_d_direction_fixed_b(a_vals, b_vals, XY, p, f, b_fix, V_link=None, rng=None):
    """Clockwise-traversal diagnostics for the stripe ``b_vals == b_fix``.

    Stripe points are split into a "rot" family (a in [0, g)) and a "ref"
    family (a in [p, p + g)), where g = p // gcd(p, f) (or p when the gcd is 0)
    and residues are taken mod g. When both families have at least 3 points, a
    joint 2-D basis is fitted (``_fit_joint_2d``) and each family's clockwise
    ordering is chosen by ``_direction_and_basis_for_family``. If neither family
    is degenerate or tied, the farthest-chord features (far_ok, far_dist,
    far_angle) of the ordered raw projections are added.

    ``rng`` is accepted but not used here.

    Returns:
        dict with n_rot, n_ref, g, d, per-family results (s_*, strength_*,
        step_sign_*, deg_*, tie_*, deg_reason_*, basis_*, iso_*), the orderings
        oi / oj, and optionally the far_* features.
    """
    a_vals = np.asarray(a_vals, int)
    b_vals = np.asarray(b_vals, int)
    XY = np.asarray(XY, float)
    gcd_pf = math.gcd(p, f)
    g = p // gcd_pf if gcd_pf != 0 else p
    d = step_size(f, p)

    in_stripe = b_vals == b_fix
    rot_sel = in_stripe & (a_vals >= 0) & (a_vals < g)
    ref_sel = in_stripe & (a_vals >= p) & (a_vals < p + g)
    rot_resid = a_vals[rot_sel] % g
    ref_resid = (a_vals[ref_sel] - p) % g
    P_pts = XY[rot_sel]
    Q_pts = XY[ref_sel]

    out = dict(
        n_rot=int(P_pts.shape[0]), n_ref=int(Q_pts.shape[0]),
        g=int(g), d=int(d),
        s_rot=0, s_ref=0, strength_rot=0.0, strength_ref=0.0,
        step_sign_rot=+1, step_sign_ref=+1,
        oi=None, oj=None,
        deg_rot=False, deg_ref=False, tie_rot=False, tie_ref=False,
        deg_reason_rot=None, deg_reason_ref=None,
        basis_rot=None, basis_ref=None,
        iso_rot=None, iso_ref=None,
    )
    # Orientation needs a polygon (>= 3 points) in each family.
    if P_pts.shape[0] < 3 or Q_pts.shape[0] < 3:
        return out

    joint = _fit_joint_2d(P_pts, Q_pts, V_link=V_link)
    if rot_resid.size > 0:
        rot = _direction_and_basis_for_family(
            rot_resid, P_pts, g, d,
            joint["V_joint"], joint["P_w"], joint["P_r"], joint["lam2_share_joint"],
        )
        out.update(
            oi=rot["oi"], s_rot=rot["s_area"], strength_rot=rot["strength"],
            step_sign_rot=int(rot["step_sign"]),
            deg_rot=bool(rot["degenerate"]), tie_rot=bool(rot["tie"]),
            deg_reason_rot=rot["deg_reason"],
            basis_rot=rot.get("basis"), iso_rot=rot.get("iso"),
        )
    if ref_resid.size > 0:
        ref = _direction_and_basis_for_family(
            ref_resid, Q_pts, g, d,
            joint["V_joint"], joint["Q_w"], joint["Q_r"], joint["lam2_share_joint"],
        )
        out.update(
            oj=ref["oi"], s_ref=ref["s_area"], strength_ref=ref["strength"],
            step_sign_ref=int(ref["step_sign"]),
            deg_ref=bool(ref["degenerate"]), tie_ref=bool(ref["tie"]),
            deg_reason_ref=ref["deg_reason"],
            basis_ref=ref.get("basis"), iso_ref=ref.get("iso"),
        )
    if out["deg_rot"] or out["deg_ref"] or out["tie_rot"] or out["tie_ref"]:
        return out

    P_seq = joint["P_r"] if out.get("oi") is None else joint["P_r"][out["oi"]]
    Q_seq = joint["Q_r"] if out.get("oj") is None else joint["Q_r"][out["oj"]]
    out.update(_farthest_chord_features(P_seq, Q_seq))
    return out

def plus_d_direction_fixed_a(a_vals, b_vals, XY, p, f, a_fix, V_link=None, rng=None):
    """Clockwise-traversal diagnostics for the stripe ``a_vals == a_fix``.

    Stripe points are split into a "rot" family (b in [0, g)) and a "ref"
    family (b in [p, p + g)), where g = p // gcd(p, f) (or p when the gcd is 0)
    and residues are taken mod g. When both families have at least 3 points, a
    joint 2-D basis is fitted (``_fit_joint_2d``) and each family's clockwise
    ordering is chosen by ``_direction_and_basis_for_family``. If neither family
    is degenerate or tied, the farthest-chord features (far_ok, far_dist,
    far_angle) of the ordered raw projections are added.

    ``rng`` is accepted but not used here.

    Returns:
        dict with n_rot, n_ref, g, d, per-family results (s_*, strength_*,
        step_sign_*, deg_*, tie_*, deg_reason_*, basis_*, iso_*), the orderings
        oi / oj, and optionally the far_* features.
    """
    a_vals = np.asarray(a_vals, int)
    b_vals = np.asarray(b_vals, int)
    XY = np.asarray(XY, float)
    gcd_pf = math.gcd(p, f)
    g = p // gcd_pf if gcd_pf != 0 else p
    d = step_size(f, p)

    in_stripe = a_vals == a_fix
    rot_sel = in_stripe & (b_vals >= 0) & (b_vals < g)
    ref_sel = in_stripe & (b_vals >= p) & (b_vals < p + g)
    rot_resid = b_vals[rot_sel] % g
    ref_resid = (b_vals[ref_sel] - p) % g
    P_pts = XY[rot_sel]
    Q_pts = XY[ref_sel]

    out = dict(
        n_rot=int(P_pts.shape[0]), n_ref=int(Q_pts.shape[0]),
        g=int(g), d=int(d),
        s_rot=0, s_ref=0, strength_rot=0.0, strength_ref=0.0,
        step_sign_rot=+1, step_sign_ref=+1,
        oi=None, oj=None,
        deg_rot=False, deg_ref=False, tie_rot=False, tie_ref=False,
        deg_reason_rot=None, deg_reason_ref=None,
        basis_rot=None, basis_ref=None,
        iso_rot=None, iso_ref=None,
    )
    # Orientation needs a polygon (>= 3 points) in each family.
    if P_pts.shape[0] < 3 or Q_pts.shape[0] < 3:
        return out

    joint = _fit_joint_2d(P_pts, Q_pts, V_link=V_link)
    if rot_resid.size > 0:
        rot = _direction_and_basis_for_family(
            rot_resid, P_pts, g, d,
            joint["V_joint"], joint["P_w"], joint["P_r"], joint["lam2_share_joint"],
        )
        out.update(
            oi=rot["oi"], s_rot=rot["s_area"], strength_rot=rot["strength"],
            step_sign_rot=int(rot["step_sign"]),
            deg_rot=bool(rot["degenerate"]), tie_rot=bool(rot["tie"]),
            deg_reason_rot=rot["deg_reason"],
            basis_rot=rot.get("basis"), iso_rot=rot.get("iso"),
        )
    if ref_resid.size > 0:
        ref = _direction_and_basis_for_family(
            ref_resid, Q_pts, g, d,
            joint["V_joint"], joint["Q_w"], joint["Q_r"], joint["lam2_share_joint"],
        )
        out.update(
            oj=ref["oi"], s_ref=ref["s_area"], strength_ref=ref["strength"],
            step_sign_ref=int(ref["step_sign"]),
            deg_ref=bool(ref["degenerate"]), tie_ref=bool(ref["tie"]),
            deg_reason_ref=ref["deg_reason"],
            basis_ref=ref.get("basis"), iso_ref=ref.get("iso"),
        )
    if out["deg_rot"] or out["deg_ref"] or out["tie_rot"] or out["tie_ref"]:
        return out

    P_seq = joint["P_r"] if out.get("oi") is None else joint["P_r"][out["oi"]]
    Q_seq = joint["Q_r"] if out.get("oj") is None else joint["Q_r"][out["oj"]]
    out.update(_farthest_chord_features(P_seq, Q_seq))
    return out

def _find_anchor_for_b(a_vals, b_vals, XY, p, f):
    """Start residues taken from the first usable fixed-b stripe, or None.

    Stripes b in [0, g) and then [p, p + g) are tried in order. A stripe is
    usable when both families are present and ``plus_d_direction_fixed_b``
    reports no degeneracy or tie and found a farthest chord. The residues of
    that chord's endpoints, in the clockwise orderings, become
    rot_start / ref_start.

    Returns:
        dict(kind="b", g, d, b_seed, rot_start, ref_start) or None.
    """
    a_vals = np.asarray(a_vals, int); b_vals = np.asarray(b_vals, int)
    # Boolean-mask indexing below needs an array; a plain-list XY used to raise here.
    XY = np.asarray(XY, float)
    gcd_pf = math.gcd(p, f)
    g = p // gcd_pf if gcd_pf != 0 else p
    d = step_size(f, p)
    for b_fix in list(range(g)) + list(range(p, p + g)):
        m = (b_vals == b_fix)
        mask_rot = m & (a_vals >= 0) & (a_vals < g)
        mask_ref = m & (a_vals >= p) & (a_vals < p + g)
        if not (mask_rot.any() and mask_ref.any()):
            continue
        dire = plus_d_direction_fixed_b(a_vals, b_vals, XY, p, f, b_fix)
        if dire.get("deg_rot") or dire.get("deg_ref") or dire.get("tie_rot") or dire.get("tie_ref"):
            continue
        if not dire.get("far_ok"):
            continue
        Pf = XY[mask_rot]; Qb = XY[mask_ref]
        I_resid = (a_vals[mask_rot] % g)
        J_resid = ((a_vals[mask_ref] - p) % g)
        oi, oj = dire["oi"], dire["oj"]
        i_star, j_star, _ = _farthest_chord_indices(Pf[oi], Qb[oj])
        if i_star is None:
            continue
        r0 = int(I_resid[oi][i_star])
        j0 = int(J_resid[oj][j_star])
        return dict(kind="b", g=int(g), d=int(d), b_seed=int(b_fix), rot_start=r0, ref_start=j0)
    return None

def _find_anchor_for_a(a_vals, b_vals, XY, p, f):
    """Start residues taken from the first usable fixed-a stripe, or None.

    Stripes a in [0, g) and then [p, p + g) are tried in order. A stripe is
    usable when both families are present and ``plus_d_direction_fixed_a``
    reports no degeneracy or tie and found a farthest chord. The residues of
    that chord's endpoints, in the clockwise orderings, become
    rot_start / ref_start.

    Returns:
        dict(kind="a", g, d, a_seed, rot_start, ref_start) or None.
    """
    a_vals = np.asarray(a_vals, int); b_vals = np.asarray(b_vals, int)
    # Boolean-mask indexing below needs an array; a plain-list XY used to raise here.
    XY = np.asarray(XY, float)
    gcd_pf = math.gcd(p, f)
    g = p // gcd_pf if gcd_pf != 0 else p
    d = step_size(f, p)
    for a_fix in list(range(g)) + list(range(p, p + g)):
        m = (a_vals == a_fix)
        mask_rot = m & (b_vals >= 0) & (b_vals < g)
        mask_ref = m & (b_vals >= p) & (b_vals < p + g)
        if not (mask_rot.any() and mask_ref.any()):
            continue
        dire = plus_d_direction_fixed_a(a_vals, b_vals, XY, p, f, a_fix)
        if dire.get("deg_rot") or dire.get("deg_ref") or dire.get("tie_rot") or dire.get("tie_ref"):
            continue
        if not dire.get("far_ok"):
            continue
        Pf = XY[mask_rot]; Qb = XY[mask_ref]
        I_resid = (b_vals[mask_rot] % g)
        J_resid = ((b_vals[mask_ref] - p) % g)
        oi, oj = dire["oi"], dire["oj"]
        i_star, j_star, _ = _farthest_chord_indices(Pf[oi], Qb[oj])
        if i_star is None:
            continue
        r0 = int(I_resid[oi][i_star])
        j0 = int(J_resid[oj][j_star])
        return dict(kind="a", g=int(g), d=int(d), a_seed=int(a_fix), rot_start=r0, ref_start=j0)
    return None

def test_k_fixed_b(a_vals, b_vals, XY, p, f, b_fix,
                   model: str = "auto",
                   s_mode: str = "scan",
                   V_link=None,
                   rng=None,
                   anchor: dict | None = None):
    """Estimate a constant offset k relating the two residue families of stripe b == b_fix.

    Both families (a in [0, g) -> I residues, a in [p, p + g) -> J residues,
    mod g) are ordered clockwise with ``_direction_and_basis_for_family``,
    optionally rolled to the anchor's start residues, and truncated to a common
    length m. For each candidate cyclic shift s of J, the most common value of
    (I + roll(J, s)) % g (the "reflection" model, reported as
    "generator right mult") and of (roll(J, s) - I) % g (the "rotation" model,
    reported as "generator left mult") is found together with its frequency
    ("consistency"). The earliest shift wins ties.

    Args:
        model: "reflection", "rotation", or anything else (e.g. "auto") to take
            whichever has the higher consistency (ties go to reflection).
        s_mode: "anchor" tries the single shift from
            ``_anchor_shift_from_farthest_pair``; otherwise, or if that returns
            None, all shifts 0..m-1 are scanned.
        V_link: optional reference basis forwarded to ``_fit_joint_2d``.
        rng: numpy Generator, used only for tie-breaking in the "anchor" s_mode.
        anchor: optional dict with "rot_start" / "ref_start". It is applied when
            its "kind" is missing/None or "b". If either family could be rolled
            to its start residue, only shift 0 is tried.

    Returns:
        On failure, dict(success=False, reason=...) where reason is one of
        "empty stripe", "degenerate", "tie", "too few pairs" (with per-family
        flags for the degenerate/tie cases). On success, dict(success=True, g,
        m, model, k_hat, consistency, s_used, step_sign_rot, step_sign_ref,
        basis_rot, basis_ref, start_aligned_rot, start_aligned_ref).
    """
    a_vals = np.asarray(a_vals, int); b_vals = np.asarray(b_vals, int); XY = np.asarray(XY, float)
    g = p // math.gcd(p, f) if math.gcd(p, f) != 0 else p
    d = step_size(f, p)
    # Select the stripe and split it into the two families; residues are mod g.
    m_mask = (b_vals == b_fix)
    mask_rot = m_mask & (a_vals >= 0) & (a_vals < g)
    mask_ref = m_mask & (a_vals >= p) & (a_vals < p + g)
    I_resid = (a_vals[mask_rot] % g); J_resid = ((a_vals[mask_ref] - p) % g)
    Pf = XY[mask_rot]; Qb = XY[mask_ref]
    if I_resid.size == 0 or J_resid.size == 0:
        return dict(success=False, reason="empty stripe", g=int(g))
    # Clockwise ordering of each family, joint basis first, local basis as fallback.
    J2d = _fit_joint_2d(Pf, Qb, V_link=V_link)
    pick_i = _direction_and_basis_for_family(I_resid, Pf, g, d, J2d["V_joint"], J2d["P_w"], J2d["P_r"], J2d["lam2_share_joint"])
    pick_j = _direction_and_basis_for_family(J_resid, Qb, g, d, J2d["V_joint"], J2d["Q_w"], J2d["Q_r"], J2d["lam2_share_joint"])
    if pick_i["degenerate"] or pick_j["degenerate"]:
        return dict(success=False, reason="degenerate",
                    g=int(g), deg_rot=bool(pick_i["degenerate"]), deg_ref=bool(pick_j["degenerate"]),
                    deg_reason_rot=pick_i.get("deg_reason"), deg_reason_ref=pick_j.get("deg_reason"))
    if pick_i["tie"] or pick_j["tie"]:
        return dict(success=False, reason="tie",
                    g=int(g), tie_rot=bool(pick_i["tie"]), tie_ref=bool(pick_j["tie"]))
    oi, oj = pick_i["oi"].copy(), pick_j["oi"].copy()
    aligned_rot = aligned_ref = False
    # Optionally rotate each ordering so it starts at the anchor's residue.
    if anchor is not None and anchor.get("kind") in (None, "b"):
        oi, aligned_rot = _roll_oi_to_start(I_resid, oi, g, anchor["rot_start"])
        oj, aligned_ref = _roll_oi_to_start(J_resid, oj, g, anchor["ref_start"])
    I = I_resid[oi]; J = J_resid[oj]
    Pf_ord2d = J2d["P_r"][oi]; Qb_ord2d = J2d["Q_r"][oj]
    m = min(I.size, J.size)
    if m < 2:
        return dict(success=False, reason="too few pairs", g=int(g))
    I2, J2 = I[:m], J[:m]
    # NOTE(review): shift 0 is used when EITHER family was aligned; if only one
    # was, the other's start is arbitrary - confirm `or` (not `and`) is intended.
    if anchor is not None and (aligned_rot or aligned_ref):
        s_candidates = [0]
    else:
        s_candidates = []
        if s_mode == "anchor":
            s0 = _anchor_shift_from_farthest_pair(Pf_ord2d[:m], Qb_ord2d[:m], I2, J2, g, rng=rng)
            if s0 is not None: s_candidates = [s0]
        if not s_candidates: s_candidates = list(range(m))
    best_rot = dict(model="rotation",  k_hat=None, consistency=-1.0, s_used=None)
    best_ref = dict(model="reflection", k_hat=None, consistency=-1.0, s_used=None)
    for s in s_candidates:
        # NOTE(review): np.roll(J2, s)[i] == J2[(i - s) % m]. With s = (j* - i*) % m
        # from the farthest pair, that pair is not matched up (np.roll(J2, -s)
        # would match it) - confirm the intended shift convention.
        Jroll = np.roll(J2, s)
        kR, cR = _constant_k_consistency((I2 + Jroll) % g)
        if cR > best_ref["consistency"]:
            best_ref.update(k_hat=int(kR), consistency=float(cR), s_used=int(s))
        kT, cT = _constant_k_consistency((Jroll - I2) % g)
        if cT > best_rot["consistency"]:
            best_rot.update(k_hat=int(kT), consistency=float(cT), s_used=int(s))
    if model == "reflection":
        chosen = best_ref
    elif model == "rotation":
        chosen = best_rot
    else:
        chosen = best_ref if best_ref["consistency"] >= best_rot["consistency"] else best_rot
    return dict(success=True, g=int(g), m=int(m),
                model=("generator right mult" if chosen is best_ref else "generator left mult"),
                k_hat=chosen["k_hat"], consistency=chosen["consistency"], s_used=chosen["s_used"],
                step_sign_rot=int(pick_i["step_sign"]), step_sign_ref=int(pick_j["step_sign"]),
                basis_rot=pick_i.get("basis"), basis_ref=pick_j.get("basis"),
                start_aligned_rot=bool(aligned_rot), start_aligned_ref=bool(aligned_ref))

def test_k_fixed_a(a_vals, b_vals, XY, p, f, a_fix,
                   model: str = "auto",
                   s_mode: str = "scan",
                   V_link=None,
                   rng=None,
                   anchor: dict | None = None):
    """Estimate a constant offset k relating the two residue families of stripe a == a_fix.

    Both families (b in [0, g) -> I residues, b in [p, p + g) -> J residues,
    mod g) are ordered clockwise with ``_direction_and_basis_for_family``,
    optionally rolled to the anchor's start residues, and truncated to a common
    length m. For each candidate cyclic shift s of J, the most common value of
    (I + roll(J, s)) % g (the "reflection" model, reported as
    "generator right mult") and of (roll(J, s) - I) % g (the "rotation" model,
    reported as "generator left mult") is found together with its frequency
    ("consistency"). The earliest shift wins ties.

    Args:
        model: "reflection", "rotation", or anything else (e.g. "auto") to take
            whichever has the higher consistency (ties go to reflection).
        s_mode: "anchor" tries the single shift from
            ``_anchor_shift_from_farthest_pair``; otherwise, or if that returns
            None, all shifts 0..m-1 are scanned.
        V_link: optional reference basis forwarded to ``_fit_joint_2d``.
        rng: numpy Generator, used only for tie-breaking in the "anchor" s_mode.
        anchor: optional dict with "rot_start" / "ref_start". It is applied when
            its "kind" is missing/None or "a". If either family could be rolled
            to its start residue, only shift 0 is tried.

    Returns:
        On failure, dict(success=False, reason=...) where reason is one of
        "empty stripe", "degenerate", "tie", "too few pairs" (with per-family
        flags for the degenerate/tie cases). On success, dict(success=True, g,
        m, model, k_hat, consistency, s_used, step_sign_rot, step_sign_ref,
        basis_rot, basis_ref, start_aligned_rot, start_aligned_ref).
    """
    a_vals = np.asarray(a_vals, int); b_vals = np.asarray(b_vals, int); XY = np.asarray(XY, float)
    g = p // math.gcd(p, f) if math.gcd(p, f) != 0 else p
    d = step_size(f, p)
    # Select the stripe and split it into the two families; residues are mod g.
    m_mask = (a_vals == a_fix)
    mask_rot = m_mask & (b_vals >= 0) & (b_vals < g)
    mask_ref = m_mask & (b_vals >= p) & (b_vals < p + g)
    I_resid = (b_vals[mask_rot] % g); J_resid = ((b_vals[mask_ref] - p) % g)
    Pf = XY[mask_rot]; Qb = XY[mask_ref]
    if I_resid.size == 0 or J_resid.size == 0:
        return dict(success=False, reason="empty stripe", g=int(g))
    # Clockwise ordering of each family, joint basis first, local basis as fallback.
    J2d = _fit_joint_2d(Pf, Qb, V_link=V_link)
    pick_i = _direction_and_basis_for_family(I_resid, Pf, g, d, J2d["V_joint"], J2d["P_w"], J2d["P_r"], J2d["lam2_share_joint"])
    pick_j = _direction_and_basis_for_family(J_resid, Qb, g, d, J2d["V_joint"], J2d["Q_w"], J2d["Q_r"], J2d["lam2_share_joint"])
    if pick_i["degenerate"] or pick_j["degenerate"]:
        return dict(success=False, reason="degenerate",
                    g=int(g), deg_rot=bool(pick_i["degenerate"]), deg_ref=bool(pick_j["degenerate"]),
                    deg_reason_rot=pick_i.get("deg_reason"), deg_reason_ref=pick_j.get("deg_reason"))
    if pick_i["tie"] or pick_j["tie"]:
        return dict(success=False, reason="tie",
                    g=int(g), tie_rot=bool(pick_i["tie"]), tie_ref=bool(pick_j["tie"]))
    oi, oj = pick_i["oi"].copy(), pick_j["oi"].copy()
    aligned_rot = aligned_ref = False
    # Optionally rotate each ordering so it starts at the anchor's residue.
    if anchor is not None and anchor.get("kind") in (None, "a"):
        oi, aligned_rot = _roll_oi_to_start(I_resid, oi, g, anchor["rot_start"])
        oj, aligned_ref = _roll_oi_to_start(J_resid, oj, g, anchor["ref_start"])
    I = I_resid[oi]; J = J_resid[oj]
    Pf_ord2d = J2d["P_r"][oi]; Qb_ord2d = J2d["Q_r"][oj]
    m = min(I.size, J.size)
    if m < 2:
        return dict(success=False, reason="too few pairs", g=int(g))
    I2, J2 = I[:m], J[:m]
    # NOTE(review): shift 0 is used when EITHER family was aligned; if only one
    # was, the other's start is arbitrary - confirm `or` (not `and`) is intended.
    if anchor is not None and (aligned_rot or aligned_ref):
        s_candidates = [0]
    else:
        s_candidates = []
        if s_mode == "anchor":
            s0 = _anchor_shift_from_farthest_pair(Pf_ord2d[:m], Qb_ord2d[:m], I2, J2, g, rng=rng)
            if s0 is not None: s_candidates = [s0]
        if not s_candidates: s_candidates = list(range(m))
    best_rot = dict(model="rotation",  k_hat=None, consistency=-1.0, s_used=None)
    best_ref = dict(model="reflection", k_hat=None, consistency=-1.0, s_used=None)
    for s in s_candidates:
        # NOTE(review): np.roll(J2, s)[i] == J2[(i - s) % m]. With s = (j* - i*) % m
        # from the farthest pair, that pair is not matched up (np.roll(J2, -s)
        # would match it) - confirm the intended shift convention.
        Jroll = np.roll(J2, s)
        kR, cR = _constant_k_consistency((I2 + Jroll) % g)
        if cR > best_ref["consistency"]:
            best_ref.update(k_hat=int(kR), consistency=float(cR), s_used=int(s))
        kT, cT = _constant_k_consistency((Jroll - I2) % g)
        if cT > best_rot["consistency"]:
            best_rot.update(k_hat=int(kT), consistency=float(cT), s_used=int(s))
    if model == "reflection":
        chosen = best_ref
    elif model == "rotation":
        chosen = best_rot
    else:
        chosen = best_ref if best_ref["consistency"] >= best_rot["consistency"] else best_rot
    return dict(success=True, g=int(g), m=int(m),
                model=("generator right mult" if chosen is best_ref else "generator left mult"),
                k_hat=chosen["k_hat"], consistency=chosen["consistency"], s_used=chosen["s_used"],
                step_sign_rot=int(pick_i["step_sign"]), step_sign_ref=int(pick_j["step_sign"]),
                basis_rot=pick_i.get("basis"), basis_ref=pick_j.get("basis"),
                start_aligned_rot=bool(aligned_rot), start_aligned_ref=bool(aligned_ref))

def summarize_by_b(a_vals, b_vals, XY, p, f,
                   include_direction=True,
                   model="auto", s_mode="scan",
                   V_link=None,
                   rng=None, anchor=None):
    """Build one summary row per fixed-b stripe (b in [0, g), then [p, p + g)).

    Every row has b_fix, g and d. With ``include_direction`` it also carries the
    direction diagnostics of ``plus_d_direction_fixed_b``. The outcome of
    ``test_k_fixed_b`` is merged in:
      * on success: generator label, k_hat, consistency, s_used, step signs,
        bases and start-alignment flags;
      * on failure: generator / k_hat / consistency / s_used are None and
        "reason" holds the failure reason.
    "reason" is always present (None on success).
    """
    g = p // math.gcd(p, f) if math.gcd(p, f) != 0 else p
    rows = []
    for b_fix in list(range(g)) + list(range(p, p + g)):
        row = dict(b_fix=int(b_fix), g=int(g), d=int(step_size(f, p)))
        if include_direction:
            dire = plus_d_direction_fixed_b(a_vals, b_vals, XY, p, f, b_fix, V_link=V_link, rng=rng)
            far_angle = dire.get("far_angle")
            row.update(
                s_rot=dire["s_rot"], s_ref=dire["s_ref"],
                strength_rot=dire["strength_rot"], strength_ref=dire["strength_ref"],
                step_sign_rot=dire["step_sign_rot"], step_sign_ref=dire["step_sign_ref"],
                basis_rot=dire.get("basis_rot"), basis_ref=dire.get("basis_ref"),
                far_ok=bool(dire.get("far_ok", False)),
                far_dist=dire.get("far_dist"),
                far_angle=far_angle,
                far_angle_deg=(float(np.degrees(far_angle)) if far_angle is not None else None),
                deg_rot=bool(dire.get("deg_rot")), deg_ref=bool(dire.get("deg_ref")),
                tie_rot=bool(dire.get("tie_rot")), tie_ref=bool(dire.get("tie_ref")),
                deg_reason_rot=dire.get("deg_reason_rot"), deg_reason_ref=dire.get("deg_reason_ref"),
                iso_rot=dire.get("iso_rot"), iso_ref=dire.get("iso_ref"),
            )
        res = test_k_fixed_b(a_vals, b_vals, XY, p, f, b_fix, V_link=V_link, model=model,
                             s_mode=s_mode, rng=rng, anchor=anchor)
        if res.get("success"):
            row.update(
                generator=res["model"], k_hat=res["k_hat"], consistency=res["consistency"], s_used=res["s_used"],
                # Was `res.get("step_sign_ref"])`, which is a SyntaxError.
                step_sign_rot=res.get("step_sign_rot"), step_sign_ref=res.get("step_sign_ref"),
                # Take the bases from the test result so they are filled even
                # when include_direction is False (they used to be copied from row).
                basis_rot=res.get("basis_rot"), basis_ref=res.get("basis_ref"),
                start_aligned_rot=res.get("start_aligned_rot"), start_aligned_ref=res.get("start_aligned_ref"),
            )
        else:
            row.update(generator=None, k_hat=None, consistency=None, s_used=None, reason=res.get("reason"))
        row.setdefault("reason", None)
        rows.append(row)
    return rows

def summarize_by_a(a_vals, b_vals, XY, p, f,
                   include_direction=True,
                   model="auto", s_mode="scan",
                   V_link=None,
                   rng=None, anchor=None):
    """Build one summary row per fixed-a stripe (a in [0, g), then [p, p + g)).

    Every row has a_fix, g and d. With ``include_direction`` it also carries the
    direction diagnostics of ``plus_d_direction_fixed_a``. The outcome of
    ``test_k_fixed_a`` is merged in:
      * on success: generator label, k_hat, consistency, s_used, step signs,
        bases and start-alignment flags;
      * on failure: generator / k_hat / consistency / s_used are None and
        "reason" holds the failure reason.
    "reason" is always present (None on success).
    """
    g = p // math.gcd(p, f) if math.gcd(p, f) != 0 else p
    rows = []
    for a_fix in list(range(g)) + list(range(p, p + g)):
        row = dict(a_fix=int(a_fix), g=int(g), d=int(step_size(f, p)))
        if include_direction:
            dire = plus_d_direction_fixed_a(a_vals, b_vals, XY, p, f, a_fix, V_link=V_link, rng=rng)
            far_angle = dire.get("far_angle")
            row.update(
                s_rot=dire["s_rot"], s_ref=dire["s_ref"],
                strength_rot=dire["strength_rot"], strength_ref=dire["strength_ref"],
                step_sign_rot=dire["step_sign_rot"], step_sign_ref=dire["step_sign_ref"],
                basis_rot=dire.get("basis_rot"), basis_ref=dire.get("basis_ref"),
                far_ok=bool(dire.get("far_ok", False)),
                far_dist=dire.get("far_dist"),
                far_angle=far_angle,
                far_angle_deg=(float(np.degrees(far_angle)) if far_angle is not None else None),
                deg_rot=bool(dire.get("deg_rot")), deg_ref=bool(dire.get("deg_ref")),
                tie_rot=bool(dire.get("tie_rot")), tie_ref=bool(dire.get("tie_ref")),
                deg_reason_rot=dire.get("deg_reason_rot"), deg_reason_ref=dire.get("deg_reason_ref"),
                iso_rot=dire.get("iso_rot"), iso_ref=dire.get("iso_ref"),
            )
        res = test_k_fixed_a(a_vals, b_vals, XY, p, f, a_fix, V_link=V_link, model=model,
                             s_mode=s_mode, rng=rng, anchor=anchor)
        if res.get("success"):
            row.update(
                generator=res["model"], k_hat=res["k_hat"], consistency=res["consistency"], s_used=res["s_used"],
                # Was `res.get("step_sign_ref"])`, which is a SyntaxError.
                step_sign_rot=res.get("step_sign_rot"), step_sign_ref=res.get("step_sign_ref"),
                # Take the bases from the test result so they are filled even
                # when include_direction is False (they used to be copied from row).
                basis_rot=res.get("basis_rot"), basis_ref=res.get("basis_ref"),
                start_aligned_rot=res.get("start_aligned_rot"), start_aligned_ref=res.get("start_aligned_ref"),
            )
        else:
            row.update(generator=None, k_hat=None, consistency=None, s_used=None, reason=res.get("reason"))
        row.setdefault("reason", None)
        rows.append(row)
    return rows

def _collect_anchors_for_b(a_vals, b_vals, XY, p, f):
    """Collect one anchor per usable fixed-b stripe.

    A stripe is usable when ``plus_d_direction_fixed_b`` (with V_link=None)
    reports no degeneracy or tie and found a farthest chord. The anchor's
    rot_start / ref_start are the residues of that chord's endpoints in the
    clockwise orderings.

    Returns:
        list of dict(kind="b", b_seed, g, rot_start, ref_start).
    """
    # Convert up front: comparisons and boolean masks below need arrays (list inputs raised TypeError).
    a_vals = np.asarray(a_vals, int)
    b_vals = np.asarray(b_vals, int)
    XY = np.asarray(XY, float)
    gcd_pf = math.gcd(p, f)
    g = p // gcd_pf if gcd_pf else p
    anchors = []
    for b_fix in list(range(g)) + list(range(p, p + g)):
        dire = plus_d_direction_fixed_b(a_vals, b_vals, XY, p, f, b_fix, V_link=None)
        if (dire.get("deg_rot") or dire.get("deg_ref") or dire.get("tie_rot")
                or dire.get("tie_ref") or not dire.get("far_ok")):
            continue
        in_stripe = b_vals == b_fix
        mask_rot = in_stripe & (a_vals >= 0) & (a_vals < g)
        mask_ref = in_stripe & (a_vals >= p) & (a_vals < p + g)
        I_resid = a_vals[mask_rot] % g
        J_resid = (a_vals[mask_ref] - p) % g
        oi, oj = dire["oi"], dire["oj"]
        i_star, j_star, _ = _farthest_chord_indices(XY[mask_rot][oi], XY[mask_ref][oj])
        if i_star is None:
            continue
        anchors.append(dict(kind="b", b_seed=int(b_fix),
                            g=int(g),
                            rot_start=int(I_resid[oi][i_star]),
                            ref_start=int(J_resid[oj][j_star])))
    return anchors

def _collect_anchors_for_a(a_vals, b_vals, XY, p, f):
    """Collect one anchor per usable fixed-a stripe.

    A stripe is usable when ``plus_d_direction_fixed_a`` (with V_link=None)
    reports no degeneracy or tie and found a farthest chord. The anchor's
    rot_start / ref_start are the residues of that chord's endpoints in the
    clockwise orderings.

    Returns:
        list of dict(kind="a", a_seed, g, rot_start, ref_start).
    """
    # Convert up front: comparisons and boolean masks below need arrays (list inputs raised TypeError).
    a_vals = np.asarray(a_vals, int)
    b_vals = np.asarray(b_vals, int)
    XY = np.asarray(XY, float)
    gcd_pf = math.gcd(p, f)
    g = p // gcd_pf if gcd_pf else p
    anchors = []
    for a_fix in list(range(g)) + list(range(p, p + g)):
        dire = plus_d_direction_fixed_a(a_vals, b_vals, XY, p, f, a_fix, V_link=None)
        if (dire.get("deg_rot") or dire.get("deg_ref") or dire.get("tie_rot")
                or dire.get("tie_ref") or not dire.get("far_ok")):
            continue
        in_stripe = a_vals == a_fix
        mask_rot = in_stripe & (b_vals >= 0) & (b_vals < g)
        mask_ref = in_stripe & (b_vals >= p) & (b_vals < p + g)
        I_resid = b_vals[mask_rot] % g
        J_resid = (b_vals[mask_ref] - p) % g
        oi, oj = dire["oi"], dire["oj"]
        i_star, j_star, _ = _farthest_chord_indices(XY[mask_rot][oi], XY[mask_ref][oj])
        if i_star is None:
            continue
        anchors.append(dict(kind="a", a_seed=int(a_fix),
                            g=int(g),
                            rot_start=int(I_resid[oi][i_star]),
                            ref_start=int(J_resid[oj][j_star])))
    return anchors

def _rows_consistency(rows):
    ok = [r for r in rows if r.get("generator") is not None]
    if not ok:
        return dict(all_gen_equal=False, all_k_equal=False, success_frac=0.0,
                    generator=None, k=None)
    gen_set = {r["generator"] for r in ok}
    k_set   = {r["k_hat"]     for r in ok}
    return dict(
        all_gen_equal=(len(gen_set) == 1),
        all_k_equal  =(len(k_set)   == 1),
        success_frac = len(ok)/len(rows),
        generator    = next(iter(gen_set)) if len(gen_set)==1 else None,
        k            = next(iter(k_set))   if len(k_set)==1   else None,
    )

def _sweep_family(a_vals, b_vals, XY, p, f, family, V_link, model, s_mode, rng):
    """Re-run the per-slice summary once per anchor and report cross-anchor agreement.

    ``family == "b"`` uses ``_collect_anchors_for_b`` / ``summarize_by_b``;
    any other value uses the ``a`` counterparts.  Each anchor's rows are
    reduced with ``_rows_consistency``.

    Returns:
        dict with ``num_anchors`` and, when at least one anchor exists,
        whether every anchor saw a unique generator / unique k and the
        smallest ``success_frac``; those three are None when there are no
        anchors.
    """
    if family == "b":
        collect, summarize = _collect_anchors_for_b, summarize_by_b
    else:
        collect, summarize = _collect_anchors_for_a, summarize_by_a

    anchors = collect(a_vals, b_vals, XY, p, f)
    if not anchors:
        return dict(num_anchors=0, all_gen_equal_across_anchors=None, all_k_equal_across_anchors=None,
                    min_success_frac=None)

    # Anchors are processed in order so the shared ``rng`` is consumed
    # exactly as in a sequential loop.
    per_anchor = [
        _rows_consistency(
            summarize(a_vals, b_vals, XY, p, f, include_direction=True,
                      model=model, s_mode=s_mode, rng=rng, anchor=anchor, V_link=V_link)
        )
        for anchor in anchors
    ]
    return dict(
        num_anchors=len(anchors),
        all_gen_equal_across_anchors=all(stat["all_gen_equal"] for stat in per_anchor),
        all_k_equal_across_anchors=all(stat["all_k_equal"] for stat in per_anchor),
        min_success_frac=min(stat["success_frac"] for stat in per_anchor),
    )

def run_and_save_stripe_analysis(
    XY: np.ndarray,
    a_vals: np.ndarray,
    b_vals: np.ndarray,
    p: int,
    f_list: list[int],
    out_dir: str,
    *,
    label: str,
    tag_q: str = "full",
    s_mode: str = "scan",
    model: str = "auto",
    seed: int | None = None
) -> None:
    rng = np.random.default_rng(seed)
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    N = XY.shape[0]
    is_full_grid = (tag_q == "full") and (N == (2*p)*(2*p))
    if not is_full_grid:
        print(f"[{label}] stripe analysis skipped: need (tag_q='full' & N=(2p)^2). "
              f"Now tag_q={tag_q!r}, N={N}, p={p}")
        return
    def _to_rows_dict(rows):
        out = []
        for r in rows:
            rd = {k: (int(v) if isinstance(v, (np.integer,)) else
                      float(v) if isinstance(v, (np.floating,)) else v)
                  for k, v in r.items()}
            out.append(rd)
        return out
    for f in f_list or []:
        f_abs = abs(int(f))
        if f_abs % p == 0:
            continue
        if XY.shape[1] < 2:
            XY = np.hstack([XY, np.zeros((XY.shape[0], 1), dtype=XY.dtype)])
        LINK = _fit_link_basis(XY)
        anchor_b = _find_anchor_for_b(a_vals, b_vals, XY, p, f_abs)
        anchor_a = _find_anchor_for_a(a_vals, b_vals, XY, p, f_abs)
        rows_b = summarize_by_b(a_vals, b_vals, XY, p, f_abs,
                                include_direction=True, model=model, s_mode=s_mode, rng=rng, anchor=anchor_b, V_link=LINK["V_link"])
        rows_a = summarize_by_a(a_vals, b_vals, XY, p, f_abs,
                                include_direction=True, model=model, s_mode=s_mode, rng=rng, anchor=anchor_a, V_link=LINK["V_link"])
        def _global_stats(rows, key_model="generator", key_k="k_hat", key_cons="consistency",
                          key_srot="s_rot", key_sref="s_ref"):
            import collections
            if not rows:
                return {}
            m_counter = collections.Counter([r.get(key_model) for r in rows if r.get(key_model) is not None])
            k_counter = collections.Counter([r.get(key_k) for r in rows if r.get(key_k) is not None])
            cons_values = [r.get(key_cons, 0.0) for r in rows if isinstance(r.get(key_cons), (int,float))]
            sprod = [r.get(key_srot,0)*r.get(key_sref,0) for r in rows
                     if isinstance(r.get(key_srot), int) and isinstance(r.get(key_sref), int)]
            out = dict(
                dominant_generator=(m_counter.most_common(1)[0][0] if m_counter else None),
                dominant_k=(k_counter.most_common(1)[0][0] if k_counter else None),
                mean_consistency=(float(np.mean(cons_values)) if cons_values else 0.0),
                frac_srot_times_sref_eq_minus1=(float(np.mean([1 if v==-1 else 0 for v in sprod])) if sprod else 0.0),
            )
            return out
        stats_b = _global_stats(rows_b, key_model="generator")
        stats_a = _global_stats(rows_a, key_model="generator")
        sweep_b = _sweep_family(a_vals, b_vals, XY, p, f_abs, family="b",
                                V_link=LINK["V_link"], model=model, s_mode=s_mode, rng=rng)
        sweep_a = _sweep_family(a_vals, b_vals, XY, p, f_abs, family="a",
                                V_link=LINK["V_link"], model=model, s_mode=s_mode, rng=rng)
        anchor_blk = dict(
            b_family=sweep_b,
            a_family=sweep_a,
        )
        stats_b["anchor_sweep"] = anchor_blk
        stats_a["anchor_sweep"] = anchor_blk
        if anchor_b is not None:
            stats_b.update(anchor_seed_b=anchor_b.get("b_seed"),
                           rot_start=anchor_b.get("rot_start"), ref_start=anchor_b.get("ref_start"))
        if anchor_a is not None:
            stats_a.update(anchor_seed_a=anchor_a.get("a_seed"),
                           rot_start=anchor_a.get("rot_start"), ref_start=anchor_a.get("ref_start"))
        def _mean_axis_from_angles(angle_list):
            if not angle_list:
                return None, None
            A = np.asarray(angle_list, float)
            C = float(np.mean(np.cos(2*A)))
            S = float(np.mean(np.sin(2*A)))
            theta = 0.5 * math.atan2(S, C)
            u = np.array([math.cos(theta), math.sin(theta)])
            return theta, u
        angles_b = [r["far_angle"] for r in rows_b if r.get("far_ok") and (r.get("far_angle") is not None)]
        angles_a = [r["far_angle"] for r in rows_a if r.get("far_ok") and (r.get("far_angle") is not None)]
        th_b, ub = _mean_axis_from_angles(angles_b)
        th_a, ua = _mean_axis_from_angles(angles_a)
        angle_between_means_deg = None
        abs_dot_between_means = None
        if (ub is not None) and (ua is not None):
            abs_dot_between_means = float(abs(np.dot(ub, ua)))
            angle_between_means_deg = float(np.degrees(math.acos(min(1.0, max(0.0, abs_dot_between_means)))))
        PERP_DELTA_DEG = 15.0
        global_cross = dict(
            mean_far_angle_a_deg=(float(np.degrees(th_a)) if th_a is not None else None),
            mean_far_angle_b_deg=(float(np.degrees(th_b)) if th_b is not None else None),
            angle_between_mean_ab_deg=angle_between_means_deg,
            abs_dot_between_mean_ab=abs_dot_between_means,
            perp_delta_deg=PERP_DELTA_DEG,
            means_perpendicular=(angle_between_means_deg is not None and
                                 abs(angle_between_means_deg - 90.0) <= PERP_DELTA_DEG)
        )
        base = os.path.join(out_dir, f"{label.lower()}_stripe_summary_f{f_abs}")
        with open(base + "_by_b.json", "w") as fh:
            json.dump({"rows": _to_rows_dict(rows_b), "global": stats_b, "global_cross": global_cross,
                       "p": int(p), "f": int(f_abs)}, fh, indent=2)
        with open(base + "_by_a.json", "w") as fh:
            json.dump({"rows": _to_rows_dict(rows_a), "global": stats_a, "global_cross": global_cross,
                       "p": int(p), "f": int(f_abs)}, fh, indent=2)
        def _write_csv(path, rows):
            if not rows:
                return
            import csv
            fieldnames = sorted(set().union(*(r.keys() for r in rows)))
            with open(path, "w", newline="") as fcsv:
                w = csv.DictWriter(fcsv, fieldnames=fieldnames, extrasaction="ignore")
                w.writeheader()
                for r in rows:
                    row_filled = {k: r.get(k, None) for k in fieldnames}
                    w.writerow(row_filled)
        _write_csv(base + "_by_b.csv", _to_rows_dict(rows_b))
        _write_csv(base + "_by_a.csv", _to_rows_dict(rows_a))
        print(f"[{label}] stripe analysis saved for f={f_abs} at '{out_dir}'")
