import os
import math
import json
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
from scipy.spatial.distance import pdist, squareform

import pca_diffusion_plots_w_helpers
from color_rules import (
    colour_quad_a_only,
    colour_quad_b_only,
    colour_c_mod_p,
)
from math import gcd


def _mean_pairwise_distance(X: np.ndarray) -> float:
    X = np.asarray(X, float)
    n = int(X.shape[0])
    if n < 2:
        return 0.0
    return float(pdist(X, metric="euclidean").mean())


def _cluster_collapse_stats(XY: np.ndarray, cluster_ids: np.ndarray) -> dict:
    """Compare within-cluster spread against the spread of the whole point set.

    For every distinct value in ``cluster_ids`` the mean pairwise distance of
    the member rows of ``XY`` is computed.  Those per-cluster means are then
    combined into a single within-cluster mean, weighting each cluster by its
    number of point pairs, and divided by the mean pairwise distance over all
    rows.

    Returns a dict with keys:
        ``global_mean_dist``  mean pairwise distance over all rows,
        ``within_mean_dist``  pair-weighted mean of the per-cluster means,
        ``ratio``             ``within_mean_dist / global_mean_dist``,
        ``score``             ``1 - ratio``,
        ``per_cluster``       list of ``{"id", "size", "mean_dist"}`` entries;
                              ``mean_dist`` is ``None`` for clusters with
                              fewer than two members.

    ``within_mean_dist``, ``ratio`` and ``score`` are ``None`` when no cluster
    has at least two members or when the global mean distance is not positive.

    Raises:
        ValueError: if ``XY`` and ``cluster_ids`` differ in length.
    """
    points = np.asarray(XY, float)
    ids = np.asarray(cluster_ids)
    if points.shape[0] != ids.shape[0]:
        raise ValueError("XY and cluster_ids must have the same length")

    global_mean = _mean_pairwise_distance(points)

    per_cluster = []
    pair_weighted_total = 0.0
    pair_total = 0.0
    for cid in np.unique(ids):
        members = np.where(ids == cid)[0]
        size = int(members.size)
        if size < 2:
            # A lone point has no pairs, so it contributes nothing to the average.
            per_cluster.append({"id": int(cid), "size": size, "mean_dist": None})
            continue

        mean_dist = float(_mean_pairwise_distance(points[members]))
        per_cluster.append({"id": int(cid), "size": size, "mean_dist": mean_dist})

        # Weight by the number of unordered pairs inside this cluster.
        pairs = size * (size - 1) / 2.0
        pair_weighted_total += mean_dist * pairs
        pair_total += pairs

    result = {
        "global_mean_dist": float(global_mean),
        "within_mean_dist": None,
        "ratio": None,
        "score": None,
        "per_cluster": per_cluster,
    }
    if pair_total == 0.0 or global_mean <= 0.0:
        return result

    within_mean = float(pair_weighted_total / pair_total)
    ratio = float(within_mean / global_mean)
    result["within_mean_dist"] = within_mean
    result["ratio"] = ratio
    result["score"] = float(1.0 - ratio)
    return result


def _cluster_within_between_stats(coords_3d: np.ndarray, labels: np.ndarray) -> dict:
    coords_3d = np.asarray(coords_3d, float)
    labels = np.asarray(labels)
    N = int(coords_3d.shape[0])

    if N < 2:
        return {
            "within_mean_dist": 0.0,
            "between_mean_dist": 0.0,
            "ratio": 1.0,
            "score": 0.0,
            "per_cluster": [],
        }

    D = squareform(pdist(coords_3d, metric="euclidean"))
    uniq = np.unique(labels)
    idx_map = {lab: np.where(labels == lab)[0] for lab in uniq}

    within_sum = 0.0
    within_cnt = 0
    per_cluster = []

    for lab in uniq:
        idx = idx_map[lab]
        m = int(idx.size)
        if m < 2:
            continue
        sub = D[np.ix_(idx, idx)]
        iu = np.triu_indices(m, k=1)
        vals = sub[iu]
        if vals.size == 0:
            continue
        mean_c = float(vals.mean())
        lab_out = int(lab) if np.isscalar(lab) and isinstance(lab, (int, np.integer)) else lab
        per_cluster.append({"label": lab_out, "size": int(m), "mean_dist": float(mean_c)})
        within_sum += float(vals.sum())
        within_cnt += int(vals.size)

    within_mean = float(within_sum / within_cnt) if within_cnt > 0 else 0.0

    between_sum = 0.0
    between_cnt = 0
    L = int(len(uniq))

    for i in range(L):
        lab_i = uniq[i]
        idx_i = idx_map[lab_i]
        if idx_i.size == 0:
            continue
        for j in range(i + 1, L):
            lab_j = uniq[j]
            idx_j = idx_map[lab_j]
            if idx_j.size == 0:
                continue
            sub = D[np.ix_(idx_i, idx_j)]
            if sub.size == 0:
                continue
            between_sum += float(sub.sum())
            between_cnt += int(sub.size)

    if between_cnt > 0:
        between_mean = float(between_sum / between_cnt)
    else:
        between_mean = float(within_mean) if within_mean > 0 else 0.0

    ratio = float(within_mean / between_mean) if between_mean > 0 else 1.0
    score = float(1.0 - ratio)

    return {
        "within_mean_dist": float(within_mean),
        "between_mean_dist": float(between_mean),
        "ratio": float(ratio),
        "score": float(score),
        "per_cluster": per_cluster,
    }


def _min_enclosing_ball(points: np.ndarray) -> tuple[np.ndarray, float]:
    P = np.asarray(points, float)
    if P.ndim != 2:
        raise ValueError("points must be 2D array of shape (n, d)")
    n, d = P.shape
    if n == 0:
        return np.zeros(d, dtype=float), 0.0
    if n == 1:
        return P[0].copy(), 0.0

    P = P.copy()
    rng = np.random.default_rng(42)
    rng.shuffle(P)
    eps = 1e-12

    def sphere_from_1(p: np.ndarray) -> tuple[np.ndarray, float]:
        return p.copy(), 0.0

    def sphere_from_2(p: np.ndarray, q: np.ndarray) -> tuple[np.ndarray, float]:
        c = (p + q) / 2.0
        r = float(np.linalg.norm(p - c))
        return c, r

    def sphere_from_3(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> tuple[np.ndarray, float]:
        ab = b - a
        ac = c - a

        ab_norm = float(np.linalg.norm(ab))
        if ab_norm < 1e-12:
            return sphere_from_2(a, c)
        e1 = ab / ab_norm

        proj = float(np.dot(ac, e1)) * e1
        ortho = ac - proj
        ortho_norm = float(np.linalg.norm(ortho))
        if ortho_norm < 1e-12:
            pts = np.stack([a, b, c], axis=0)
            best_c, best_r = None, -1.0
            for i in range(3):
                for j in range(i + 1, 3):
                    cc, rr = sphere_from_2(pts[i], pts[j])
                    if rr > best_r:
                        best_r = rr
                        best_c = cc
            return best_c, float(best_r)
        e2 = ortho / ortho_norm

        d_len = ab_norm
        cx = float(np.dot(ac, e1))
        cy = float(np.dot(ac, e2))

        if abs(cy) < 1e-12:
            pts = np.stack([a, b, c], axis=0)
            best_c, best_r = None, -1.0
            for i in range(3):
                for j in range(i + 1, 3):
                    cc, rr = sphere_from_2(pts[i], pts[j])
                    if rr > best_r:
                        best_r = rr
                        best_c = cc
            return best_c, float(best_r)

        ux = d_len / 2.0
        uy = (cx * cx - d_len * cx + cy * cy) / (2.0 * cy)

        center = a + ux * e1 + uy * e2
        radius = float(np.linalg.norm(center - a))
        return center, radius

    def sphere_from_4(a: np.ndarray, b: np.ndarray, c: np.ndarray, d_: np.ndarray) -> tuple[np.ndarray, float]:
        if d != 3:
            raise ValueError("sphere_from_4 supports only d=3")
        p0 = a
        P3 = np.stack([b, c, d_], axis=0)
        A = P3 - p0
        bvec = 0.5 * (np.sum(P3 * P3, axis=1) - np.sum(p0 * p0))
        try:
            center = np.linalg.solve(A, bvec)
        except np.linalg.LinAlgError:
            pts = [a, b, c, d_]
            best_c, best_r = None, math.inf
            for idxs in ((0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)):
                cc, rr = sphere_from_3(pts[idxs[0]], pts[idxs[1]], pts[idxs[2]])
                if all(float(np.linalg.norm(p - cc)) <= float(rr) + 1e-9 for p in pts):
                    if rr < best_r:
                        best_r = rr
                        best_c = cc
            if best_c is None:
                best_c, best_r = sphere_from_2(a, d_)
            return best_c, float(best_r)
        radius = float(np.linalg.norm(center - a))
        return center, radius

    def ball_from(R: list[np.ndarray]) -> tuple[np.ndarray, float]:
        k = len(R)
        if k == 0:
            return np.zeros(d, dtype=float), 0.0
        if k == 1:
            return sphere_from_1(R[0])
        if k == 2:
            return sphere_from_2(R[0], R[1])
        if k == 3:
            return sphere_from_3(R[0], R[1], R[2])
        if k == 4:
            return sphere_from_4(R[0], R[1], R[2], R[3])
        raise ValueError("R too large in ball_from")

    import sys as _sys
    _sys.setrecursionlimit(10000)

    def welzl(P_arr: np.ndarray, R: list[np.ndarray], n_pts: int) -> tuple[np.ndarray, float]:
        if n_pts == 0 or len(R) == 4:
            return ball_from(R)
        p = P_arr[n_pts - 1]
        c, r = welzl(P_arr, R, n_pts - 1)
        if float(np.linalg.norm(p - c)) <= float(r) + eps:
            return c, r
        return welzl(P_arr, R + [p], n_pts - 1)

    center, radius = welzl(P, [], int(P.shape[0]))
    return np.asarray(center, float), float(radius)


def miniball_badoiu_clarkson(
    P: np.ndarray, eps: float = 1e-2, max_iter: Optional[int] = None, seed: int = 0
) -> tuple[np.ndarray, float]:
    """Approximate an enclosing ball with the Badoiu-Clarkson iteration.

    Starting from a row of ``P`` chosen by a generator seeded with ``seed``,
    the center is repeatedly moved towards the currently farthest point with
    a step of ``1 / (t + 1)`` on iteration ``t``.  The reported radius is the
    largest distance from the final center to any row, so the returned ball
    always covers every point.

    Args:
        P: array-like of shape ``(n, d)``.
        eps: accuracy parameter; only used to derive the iteration count
            ``ceil(1 / eps**2)`` when ``max_iter`` is ``None``.
        max_iter: explicit number of iterations.
        seed: seed for picking the starting point.

    Returns:
        ``(center, radius)``.  Zero rows give a zero center with radius
        ``0.0``; a single row gives that row with radius ``0.0``.
    """
    pts = np.asarray(P, float)
    n_pts, dim = pts.shape
    if n_pts == 0:
        return np.zeros(dim), 0.0
    if n_pts == 1:
        return pts[0].copy(), 0.0

    start = int(np.random.default_rng(int(seed)).integers(0, n_pts))
    center = pts[start].copy()

    if max_iter is None:
        tol = float(eps)
        max_iter = int(np.ceil(1.0 / (tol * tol)))

    # ``divisor`` plays the role of t + 1 for t = 1 .. max_iter.
    for divisor in range(2, int(max_iter) + 2):
        offsets = pts - center[None, :]
        sq_dists = np.einsum("ij,ij->i", offsets, offsets)
        farthest = pts[int(np.argmax(sq_dists))]
        center = center + (1.0 / divisor) * (farthest - center)

    radius = float(np.linalg.norm(pts - center[None, :], axis=1).max())
    return np.asarray(center, float), radius


def _coset_miniball_stats(coords: np.ndarray, labels: np.ndarray) -> Dict[str, Any]:
    coords = np.asarray(coords, float)
    labels = np.asarray(labels)

    uniq = np.unique(labels)
    centers = []
    radii = []
    sizes = []
    label_list = []

    for lab in uniq:
        idx = np.where(labels == lab)[0]
        pts = coords[idx]
        if pts.shape[0] == 0:
            continue
        if pts.shape[0] == 1:
            c = pts[0].copy()
            r = 0.0
        else:
            c, r = _min_enclosing_ball(pts)
        centers.append(c)
        radii.append(float(r))
        sizes.append(int(pts.shape[0]))
        if np.isscalar(lab) and isinstance(lab, (int, np.integer)):
            label_list.append(int(lab))
        else:
            label_list.append(lab)

    if not centers:
        return {
            "labels": [],
            "sizes": [],
            "centers": [],
            "radii": [],
            "center_dist": None,
            "per_cluster_pairs": [],
        }

    centers_arr = np.vstack(centers)
    radii_arr = np.asarray(radii, float)
    n_c = int(centers_arr.shape[0])

    if n_c == 1:
        center_dist = np.zeros((1, 1), dtype=float)
        overlap = np.zeros((1, 1), dtype=bool)
        disjoint = np.zeros((1, 1), dtype=bool)
    else:
        D_flat = pdist(centers_arr, metric="euclidean")
        center_dist = squareform(D_flat)
        r_sum = radii_arr[:, None] + radii_arr[None, :]
        overlap = center_dist < r_sum
        disjoint = center_dist > r_sum
        np.fill_diagonal(overlap, False)
        np.fill_diagonal(disjoint, False)

    per_cluster_pairs = []
    if n_c == 1:
        per_cluster_pairs.append(
            {
                "label": label_list[0],
                "size": sizes[0],
                "num_other_cosets": 0,
                "num_intersect": 0,
                "num_disjoint": 0,
                "num_ambiguous": 0,
                "frac_intersect": None,
                "frac_disjoint": None,
            }
        )
    else:
        n_other = n_c - 1
        for i, lab in enumerate(label_list):
            n_inter = int(overlap[i].sum())
            n_disj = int(disjoint[i].sum())
            n_amb = int(n_other - n_inter - n_disj)
            per_cluster_pairs.append(
                {
                    "label": lab,
                    "size": sizes[i],
                    "num_other_cosets": int(n_other),
                    "num_intersect": int(n_inter),
                    "num_disjoint": int(n_disj),
                    "num_ambiguous": int(n_amb),
                    "frac_intersect": (n_inter / n_other) if n_other > 0 else None,
                    "frac_disjoint": (n_disj / n_other) if n_other > 0 else None,
                }
            )

    return {
        "labels": label_list,
        "sizes": sizes,
        "centers": [c.tolist() for c in centers],
        "radii": [float(r) for r in radii],
        "center_dist": center_dist.tolist(),
        "per_cluster_pairs": per_cluster_pairs,
    }


def run_and_save_coset_collapse(
    embedding_weights: np.ndarray,
    p: int,
    f: int,
    out_dir: str,
    *,
    label: str,
    tag_q: str = "full",
    num_pca_dims: int = 4,
    seed: Optional[int] = None,
) -> None:
    X = np.asarray(embedding_weights, float)
    N = int(X.shape[0])
    side = int(math.isqrt(N))

    if not (tag_q == "full" and side == 2 * int(p) and side * side == N):
        print(
            f"[{label}] coset collapse skipped: need full grid with N=(2p)^2. "
            f"tag_q={tag_q!r}, N={N}, side={side}, p={p}"
        )
        return

    if gcd(int(p), int(f)) == 1:
        print(f"[{label}] coset collapse skipped: approximate-coset case")
        return

    core = pca_diffusion_plots_w_helpers.run_pca_core(
        mat=X,
        p=int(p),
        save_dir=out_dir,
        seed=int(seed) if seed is not None else 0,
        tag=f"{label}_coset_p{int(p)}_f{abs(int(f))}",
        tag_q=str(tag_q),
        max_components=int(num_pca_dims),
    )

    coords_pca = np.asarray(core["pcs"], float)
    metric_xy = coords_pca[:, : min(3, int(coords_pca.shape[1]))]
    evr = core.get("var_ratio", None)

    indices = np.arange(N, dtype=int)
    a_vals = indices // side
    b_vals = indices % side

    col_a, _caption_a, p_cbar_a, _ = colour_quad_a_only(a_vals, b_vals, int(p), int(f), "full")
    col_b, _caption_b, p_cbar_b, _ = colour_quad_b_only(a_vals, b_vals, int(p), int(f), "full")
    col_c, caption_c, p_cbar_c, _ = colour_c_mod_p(a_vals, b_vals, int(p), int(f), "full")

    col_a = np.asarray(col_a, int)
    col_b = np.asarray(col_b, int)
    col_c = np.asarray(col_c, int)

    n_a = int(p_cbar_a)
    n_b = int(p_cbar_b)
    n_c = int(p_cbar_c)

    cluster_c = col_c
    cluster_ab = col_a * n_b + col_b

    stats_c_global = _cluster_collapse_stats(metric_xy, cluster_c)
    stats_c_between = _cluster_within_between_stats(metric_xy, cluster_c)
    stats_ab_global = _cluster_collapse_stats(metric_xy, cluster_ab)
    stats_ab_between = _cluster_within_between_stats(metric_xy, cluster_ab)

    ball_stats_ab = _coset_miniball_stats(metric_xy, cluster_ab)
    ball_stats_c = _coset_miniball_stats(metric_xy, cluster_c)

    if isinstance(evr, np.ndarray):
        evr = evr.tolist()

    payload: Dict[str, Any] = {
        "label": str(label),
        "p": int(p),
