import os
import json
import math
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import numpy as np


def _plane_normal_by_pca(X: np.ndarray) -> Optional[np.ndarray]:
    """Return the unit normal of the least-squares plane through the rows of ``X``.

    Rows are points. Inputs with fewer than 3 columns are zero-padded to 3
    columns, so 2-D points yield the normal (0, 0, +/-1). The normal is the
    right singular vector of the centred points with the smallest singular
    value; its sign is whatever the SVD returns.

    Returns ``None`` when ``X`` is not 2-D, has fewer than 3 rows, or its
    centred points have rank < 2 (all identical or collinear), since no
    unique plane exists then.
    """
    pts = np.asarray(X, float)
    if pts.ndim != 2 or pts.shape[0] < 3:
        return None

    n_missing = 3 - pts.shape[1]
    if n_missing > 0:
        # Lift low-dimensional points into 3-D with zero extra coordinates.
        pts = np.pad(pts, ((0, 0), (0, n_missing)), mode="constant")

    centered = pts - pts.mean(axis=0, keepdims=True)
    if np.linalg.matrix_rank(centered) < 2:
        return None

    right_vectors = np.linalg.svd(centered, full_matrices=False)[2]
    normal = right_vectors[-1]
    return normal / (float(np.linalg.norm(normal)) + 1e-12)


def _angle_between_normals_deg(n1: np.ndarray, n2: np.ndarray) -> float:
    """Return the unsigned angle in degrees (0..90) between two plane normals.

    The absolute value of the dot product is used, so flipping either
    normal's sign gives the same angle. The normals are expected to be unit
    length; the cosine is clamped so rounding slightly above 1 cannot make
    ``acos`` fail.
    """
    clipped = max(-1.0, min(1.0, float(abs(np.dot(n1, n2)))))
    angle_rad = math.acos(clipped)
    return float(np.degrees(angle_rad))


def _split_masks(
    a_vals: np.ndarray,
    b_vals: np.ndarray,
    p: int,
    mode: str,
    *,
    tag_q: str = "full",
) -> Tuple[np.ndarray, np.ndarray]:
    """Split points into complementary boolean masks ``(rot, ref)``.

    A point is "upper" along an axis when its value is ``>= p``. The ``ref``
    mask is:

    * ``"a"``: points with upper ``a``;
    * ``"b"``: points with upper ``b``;
    * ``"c"``: points upper along exactly one of ``a``/``b`` (the
      off-diagonal quadrants).

    ``rot`` is always the complement of ``ref``.

    Raises:
        ValueError: if ``tag_q`` is not ``"full"`` or ``mode`` is unknown.
    """
    if tag_q != "full":
        raise ValueError("rot/ref splitting is only defined for tag_q='full' (2p x 2p grid).")

    threshold = int(p)
    upper_a = np.asarray(a_vals, int) >= threshold
    upper_b = np.asarray(b_vals, int) >= threshold

    if mode == "a":
        ref = upper_a
    elif mode == "b":
        ref = upper_b
    elif mode == "c":
        # Off-diagonal quadrants: exactly one coordinate is in the upper half.
        ref = upper_a ^ upper_b
    else:
        raise ValueError("unknown mode, expected 'a', 'b', or 'c'")

    return ~ref, ref


def plane_angle_per_cluster(
    coords: np.ndarray,
    a_vals: np.ndarray,
    b_vals: np.ndarray,
    p: int,
    cluster_ids: np.ndarray,
    *,
    mode: str,
    tag_q: str = "full",
    save_dir: Optional[str] = None,
    title: str = "",
) -> Dict[str, Any]:
    """Measure, per cluster, the angle between the planes of its rot and ref points.

    Each point is assigned to the "rot" or "ref" subset by ``_split_masks``
    from its ``(a, b)`` values relative to ``p`` and the split ``mode``. For
    every cluster id, a plane is fit by PCA to the cluster's rot points and
    to its ref points. The unsigned angle between the two plane normals is
    recorded in degrees (0..90).

    Args:
        coords: 2-D array of point coordinates, one row per point.
        a_vals, b_vals: Integer values per point used for the rot/ref split.
        p: Split threshold passed to ``_split_masks``.
        cluster_ids: Integer cluster label per point.
        mode: Split mode, ``"a"``, ``"b"`` or ``"c"``.
        tag_q: Forwarded to ``_split_masks``; only ``"full"`` is accepted there.
        save_dir: If truthy, this directory is created and receives a
            per-cluster CSV, a JSON summary and, on a best-effort basis, a
            histogram PNG.
        title: Plot title, also stored in the JSON summary.

    Returns:
        A dict with these keys:

        * ``"mode"``.
        * ``"angles_deg"``: the finite angles only.
        * ``"per_cluster"``: cluster id -> angle. The angle is NaN when either
          subset has fewer than 3 points or is degenerate.
        * ``"hist"``: 36 bins over [0, 90] degrees.
        * ``"paths"``: the files written, plus ``"png_error"`` if plotting
          failed.

    Raises:
        ValueError: if ``coords`` is not 2-D, input lengths differ, or
            ``_split_masks`` rejects ``mode``/``tag_q``.
    """
    coords = np.asarray(coords, float)
    a_vals = np.asarray(a_vals, int)
    b_vals = np.asarray(b_vals, int)
    clusters = np.asarray(cluster_ids, int)

    # Rows are indexed as coords[idx, :] below, so reject other ranks up front
    # instead of failing later with an opaque IndexError.
    if coords.ndim != 2:
        raise ValueError("coords must be a 2-D array with one row per point")
    if not (coords.shape[0] == a_vals.size == b_vals.size == clusters.size):
        raise ValueError("coords, a_vals, b_vals, cluster_ids must have the same length")

    rot_mask_global, ref_mask_global = _split_masks(a_vals, b_vals, int(p), str(mode), tag_q=tag_q)

    per_cluster: Dict[int, float] = {}
    for cid in np.unique(clusters):
        in_cluster = clusters == cid
        I_rot = np.where(in_cluster & rot_mask_global)[0]
        I_ref = np.where(in_cluster & ref_mask_global)[0]

        # A plane fit needs at least 3 points on each side.
        if I_rot.size < 3 or I_ref.size < 3:
            per_cluster[int(cid)] = float("nan")
            continue

        n_rot = _plane_normal_by_pca(coords[I_rot, :])
        n_ref = _plane_normal_by_pca(coords[I_ref, :])
        # None means the subset has no unique plane (e.g. collinear points).
        if n_rot is None or n_ref is None:
            per_cluster[int(cid)] = float("nan")
            continue

        per_cluster[int(cid)] = _angle_between_normals_deg(n_rot, n_ref)

    angles = np.array([v for v in per_cluster.values() if np.isfinite(v)], float)
    # With an explicit range, np.histogram also handles an empty array
    # (all-zero counts, same edges).
    hist_counts, bin_edges = np.histogram(angles, bins=36, range=(0.0, 90.0))

    paths: Dict[str, Any] = {}
    if save_dir:
        Path(save_dir).mkdir(parents=True, exist_ok=True)

        csv_path = os.path.join(save_dir, f"{mode}_plane_angles_per_cluster.csv")
        with open(csv_path, "w", encoding="utf-8") as f:
            f.write("cluster_id,angle_deg\n")
            for cid, ang in sorted(per_cluster.items()):
                if math.isnan(ang):
                    f.write(f"{cid},\n")
                else:
                    f.write(f"{cid},{ang:.6f}\n")
        paths["csv"] = csv_path

        # JSON has no NaN literal: json.dump would emit the non-standard token
        # `NaN`, which strict parsers reject. Write missing angles as null,
        # mirroring the empty CSV cell above.
        per_cluster_json = {
            cid: (None if math.isnan(ang) else ang) for cid, ang in per_cluster.items()
        }
        json_path = os.path.join(save_dir, f"{mode}_plane_angles_summary.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "mode": str(mode),
                    "p": int(p),
                    "tag_q": str(tag_q),
                    "n_valid": int(angles.size),
                    "mean": float(np.mean(angles)) if angles.size else None,
                    "median": float(np.median(angles)) if angles.size else None,
                    "bin_edges": bin_edges.tolist(),
                    "counts": hist_counts.tolist(),
                    "per_cluster": per_cluster_json,
                    "title": str(title),
                },
                f,
                indent=2,
            )
        paths["json"] = json_path

        png_path = os.path.join(save_dir, f"{mode}_plane_angles_hist.png")
        try:
            import matplotlib.pyplot as plt

            fig, ax = plt.subplots(figsize=(6, 4))
            try:
                ax.hist(angles, bins=bin_edges)
                ax.set_xlabel("angle (deg)")
                ax.set_ylabel("count")
                ax.set_title(title or f"{mode}-rot vs {mode}-ref")
                fig.tight_layout()
                fig.savefig(png_path, dpi=150)
            finally:
                # Close even if drawing/saving fails, so pyplot does not keep
                # the figure alive.
                plt.close(fig)
            paths["png"] = png_path
        except Exception as e:
            # Plotting is best-effort (matplotlib may be missing or unusable);
            # record the failure instead of raising.
            paths["png_error"] = f"{type(e).__name__}: {e}"

    return {
        "mode": str(mode),
        "angles_deg": angles.tolist(),
        "per_cluster": per_cluster,
        "hist": {"bin_edges": bin_edges.tolist(), "counts": hist_counts.tolist()},
        "paths": paths,
    }


def _plane_mesh_from_points(
    P: np.ndarray,
    dims: Tuple[int, int, int],
    grid: int = 12,
    pad_ratio: float = 0.05,
):
    """Sample a ``grid`` x ``grid`` mesh on the least-squares plane through ``P``.

    The three columns of ``P`` selected by ``dims`` are used as the 3-D
    coordinates. A plane is fit by SVD of the centred points. The points are
    projected onto two orthonormal in-plane axes, and the covered extent is
    widened by ``pad_ratio`` on each side before sampling.

    Returns:
        ``(Xg, Yg, Zg)`` arrays of shape ``(grid, grid)`` in the selected
        columns' coordinates, or ``None`` if ``P`` has fewer than 3 rows.
    """
    pts = np.asarray(P, float)
    cols = [int(dims[0]), int(dims[1]), int(dims[2])]
    if pts.shape[0] < 3:
        return None

    def _unit(vec):
        # Normalise, with a tiny epsilon so a zero vector does not divide by zero.
        return vec / (float(np.linalg.norm(vec)) + 1e-12)

    def _padded_span(values):
        # (min, max) of the projections, widened by pad_ratio of the span per side.
        lo, hi = float(np.min(values)), float(np.max(values))
        margin = float(pad_ratio) * (hi - lo)
        return lo - margin, hi + margin

    sub = pts[:, cols]
    origin = sub.mean(axis=0)
    centered = sub - origin

    # The smallest right singular vector is the plane normal.
    normal = _unit(np.linalg.svd(centered, full_matrices=False)[2][-1])

    # Seed the in-plane basis with the x-axis unless it is nearly parallel to
    # the normal, in which case the y-axis is used.
    seed = np.array([1.0, 0.0, 0.0], dtype=float)
    if abs(float(np.dot(seed, normal))) > 0.9:
        seed = np.array([0.0, 1.0, 0.0], dtype=float)

    axis_u = _unit(np.cross(normal, seed))
    axis_v = _unit(np.cross(normal, axis_u))

    u_lo, u_hi = _padded_span(centered @ axis_u)
    v_lo, v_hi = _padded_span(centered @ axis_v)

    n_steps = int(grid)
    UU, VV = np.meshgrid(np.linspace(u_lo, u_hi, n_steps), np.linspace(v_lo, v_hi, n_steps))

    mesh = origin + UU[..., None] * axis_u + VV[..., None] * axis_v
    return mesh[..., 0], mesh[..., 1], mesh[..., 2]
