import os, json, math
import numpy as np
from pathlib import Path
from typing import Dict, Any, Tuple, Optional

def _plane_normal_by_pca(X: np.ndarray) -> Optional[np.ndarray]:
    X = np.asarray(X, float)
    if X.ndim != 2 or X.shape[0] < 3:
        return None
    if X.shape[1] < 3:
        X = np.pad(X, ((0,0),(0,3-X.shape[1])), mode="constant")
    Xc = X - X.mean(axis=0, keepdims=True)
    if np.linalg.matrix_rank(Xc) < 2:
        return None
    U,S,Vt = np.linalg.svd(Xc, full_matrices=False)
    n = Vt[-1]; n = n / (np.linalg.norm(n) + 1e-12)
    return n

def _angle_between_normals_deg(n1: np.ndarray, n2: np.ndarray) -> float:
    x = float(abs(np.dot(n1, n2)))
    x = max(-1.0, min(1.0, x))
    return float(np.degrees(math.acos(x)))


def _split_masks(a_vals: np.ndarray, b_vals: np.ndarray, p: int, mode: str, *, tag_q="full") -> Tuple[np.ndarray,np.ndarray]:
    if tag_q != "full":
        raise ValueError("Need full grid to cunduct the plane fit.")
    a_vals = np.asarray(a_vals, int)
    b_vals = np.asarray(b_vals, int)
    top   = (a_vals >= p)
    right = (b_vals >= p)

    if mode == "a":
        rot = ~top         
        ref = top          
    elif mode == "b":
        rot = ~right       
        ref = right        
    elif mode == "c":
        rot = (~top & ~right) | (top & right)   # BL/TR
        ref = ~rot                               # TL/BR
    else:
        raise ValueError(f"unknown mode={mode!r}, expect 'a'/'b'/'c'")
    return rot, ref

def plane_angle_per_cluster(
    coords: np.ndarray,
    a_vals: np.ndarray,
    b_vals: np.ndarray,
    p: int,
    cluster_ids: np.ndarray,
    *,
    mode: str,
    tag_q: str = "full",
    save_dir: Optional[str] = None,
    title: str = ""
) -> Dict[str, Any]:
    """Compute, per cluster, the angle between planes fitted to "rot" and "ref" points.

    Points are split into a "rot" and a "ref" subset by ``_split_masks``
    (thresholding ``a_vals`` / ``b_vals`` at ``p`` according to ``mode``).
    Within each cluster a plane is fitted to each subset by PCA and the angle
    between the two plane normals is reported in degrees (0-90, sign of the
    normals ignored). A cluster gets NaN when either subset has fewer than 3
    points or the PCA fit is degenerate (``_plane_normal_by_pca`` returns None).

    Args:
        coords: Point coordinates, one row per point (indexed as ``coords[I, :]``).
        a_vals, b_vals: Per-point integer values used with ``p`` to build the
            rot/ref masks.
        p: Threshold on ``a_vals`` / ``b_vals``.
        cluster_ids: Integer cluster label per point.
        mode: ``'a'``, ``'b'`` or ``'c'`` (see ``_split_masks``).
        tag_q: Must be ``"full"``; anything else raises ``ValueError``.
        save_dir: If given, ``{mode}_plane_angles_per_cluster.csv`` and
            ``{mode}_plane_angles_summary.json`` are written there, plus
            ``{mode}_plane_angles_hist.png`` on a best-effort basis.
        title: Stored in the JSON summary and used as the plot title.

    Returns:
        Dict with keys ``mode``, ``angles_deg`` (finite angles only),
        ``per_cluster`` (cluster id -> angle, NaN where undefined), ``hist``
        (``bin_edges``/``counts`` for 36 bins over [0, 90] deg) and ``paths``
        (written files; a ``png_error`` entry if plotting failed).

    Raises:
        ValueError: If the per-point inputs differ in length, or (from
            ``_split_masks``) for an unknown ``mode`` / non-full ``tag_q``.
    """
    coords = np.asarray(coords, float)
    a_vals = np.asarray(a_vals, int)
    b_vals = np.asarray(b_vals, int)
    clusters = np.asarray(cluster_ids, int)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    n_points = coords.shape[0]
    if not (a_vals.size == b_vals.size == clusters.size == n_points):
        raise ValueError(
            "coords, a_vals, b_vals and cluster_ids must have the same number of points; "
            f"got {n_points}, {a_vals.size}, {b_vals.size}, {clusters.size}"
        )

    rot_mask_global, ref_mask_global = _split_masks(a_vals, b_vals, p, mode, tag_q=tag_q)

    per_cluster: Dict[int, float] = {}
    for cid in np.unique(clusters):
        in_cluster = clusters == cid
        I_rot = np.flatnonzero(in_cluster & rot_mask_global)
        I_ref = np.flatnonzero(in_cluster & ref_mask_global)

        angle = float("nan")
        # A plane needs at least 3 points on each side.
        if I_rot.size >= 3 and I_ref.size >= 3:
            n_rot = _plane_normal_by_pca(coords[I_rot, :])
            n_ref = _plane_normal_by_pca(coords[I_ref, :])
            if n_rot is not None and n_ref is not None:
                angle = _angle_between_normals_deg(n_rot, n_ref)
        per_cluster[int(cid)] = angle

    angles = np.array([v for v in per_cluster.values() if np.isfinite(v)], float)
    # np.histogram already returns all-zero counts and the same edges for empty input.
    hist_counts, bin_edges = np.histogram(angles, bins=36, range=(0, 90))

    paths: Dict[str, str] = {}
    if save_dir:
        Path(save_dir).mkdir(parents=True, exist_ok=True)

        # CSV: one row per cluster; an empty cell marks an undefined angle.
        csv_path = os.path.join(save_dir, f"{mode}_plane_angles_per_cluster.csv")
        with open(csv_path, "w") as f:
            f.write("cluster_id,angle_deg\n")
            for cid, ang in sorted(per_cluster.items()):
                f.write(f"{cid},{'' if math.isnan(ang) else f'{ang:.6f}'}\n")
        paths["csv"] = csv_path

        # JSON: json.dump would emit bare `NaN` tokens (not valid JSON) for
        # undefined clusters, so write null instead for strict parsers.
        per_cluster_json = {
            cid: (None if math.isnan(ang) else ang) for cid, ang in per_cluster.items()
        }
        json_path = os.path.join(save_dir, f"{mode}_plane_angles_summary.json")
        with open(json_path, "w") as f:
            json.dump(dict(
                mode=mode, p=int(p), tag_q=str(tag_q),
                n_valid=int(angles.size),
                mean=(float(np.mean(angles)) if angles.size else None),
                median=(float(np.median(angles)) if angles.size else None),
                bin_edges=bin_edges.tolist(), counts=hist_counts.tolist(),
                per_cluster=per_cluster_json, title=title
            ), f, indent=2)
        paths["json"] = json_path

        # PNG: best effort (matplotlib may be missing or unable to render);
        # failures are recorded in `paths` instead of aborting the analysis.
        fig = None
        try:
            import matplotlib.pyplot as plt
            fig, ax = plt.subplots(figsize=(6, 4))
            ax.hist(angles, bins=36, range=(0, 90))
            ax.set_xlabel("angle (deg)")
            ax.set_ylabel("count")
            ax.set_title(title or f"{mode}-rot vs {mode}-ref")
            fig.tight_layout()
            png_path = os.path.join(save_dir, f"{mode}_plane_angles_hist.png")
            fig.savefig(png_path, dpi=150)
            paths["png"] = png_path
        except Exception as e:
            paths["png_error"] = f"{type(e).__name__}: {e}"
        finally:
            # Release the figure even when drawing/saving failed part-way.
            if fig is not None:
                plt.close(fig)

    return dict(
        mode=mode, angles_deg=angles.tolist(),
        per_cluster=per_cluster,
        hist=dict(bin_edges=bin_edges.tolist(), counts=hist_counts.tolist()),
        paths=paths
    )

def _plane_mesh_from_points(P: np.ndarray,
                            dims: tuple[int,int,int],
                            grid: int = 12,
                            pad_ratio: float = 0.05):
    
    P = np.asarray(P, float)
    i, j, k = map(int, dims)
    if P.shape[0] < 3:
        return None

    Q = P[:, [i, j, k]]
    Qc = Q - Q.mean(axis=0, keepdims=True)

    U, S, Vt = np.linalg.svd(Qc, full_matrices=False)
    n = Vt[-1]; n = n / (np.linalg.norm(n) + 1e-12)

    a = np.array([1.0, 0.0, 0.0])
    if abs(np.dot(a, n)) > 0.9:
        a = np.array([0.0, 1.0, 0.0])
    u = np.cross(n, a); u = u / (np.linalg.norm(u) + 1e-12)
    v = np.cross(n, u); v = v / (np.linalg.norm(v) + 1e-12)

    Ucoord = Qc @ u
    Vcoord = Qc @ v
    umin, umax = np.min(Ucoord), np.max(Ucoord)
    vmin, vmax = np.min(Vcoord), np.max(Vcoord)

    du, dv = umax-umin, vmax-vmin
    umin -= pad_ratio*du; umax += pad_ratio*du
    vmin -= pad_ratio*dv; vmax += pad_ratio*dv

    uu = np.linspace(umin, umax, grid)
    vv = np.linspace(vmin, vmax, grid)
    UU, VV = np.meshgrid(uu, vv)
    center = Q.mean(axis=0)
    
    XYZ = center[None,None,:] + UU[...,None]*u[None,None,:] + VV[...,None]*v[None,None,:]
    X, Y, Z = XYZ[...,0], XYZ[...,1], XYZ[...,2]
    return X, Y, Z
