import glob
import os
import warnings

import numpy as np
import torch
import trimesh

def _minmax_neg1_pos1(pc):
    mn = pc.min(axis=0, keepdims=True)
    mx = pc.max(axis=0, keepdims=True)
    denom = np.clip(mx - mn, 1e-12, None)
    return ((pc - mn) / denom * 2.0 - 1.0).astype(np.float32)

def _unit_sphere(pc):
    pc = pc - pc.mean(axis=0, keepdims=True)
    scale = np.linalg.norm(pc, axis=1).max()
    return (pc / max(scale, 1e-12)).astype(np.float32)

def _sample_surface(mesh, n_points, seed=None):
    """Draw *n_points* points on the surface of *mesh* with trimesh.

    Args:
        mesh: a ``trimesh.Trimesh`` with at least one face.
        n_points: number of points to draw.
        seed: optional seed forwarded to ``trimesh.sample.sample_surface`` for
            reproducible sampling; ``None`` keeps trimesh's own unseeded RNG.

    Returns:
        float32 array of sampled points.
    """
    # Only forward `seed` when given, so the default call is unchanged.
    kwargs = {} if seed is None else {"seed": seed}
    pts, _ = trimesh.sample.sample_surface(mesh, n_points, **kwargs)
    return pts.astype(np.float32)


def _load_pc(path, n_points, seed=None):
    """Load the mesh file at *path* and return *n_points* float32 surface samples.

    A file that trimesh loads as a ``Scene`` is flattened first. Every
    triangle-mesh instance in the scene graph is copied with its node transform
    applied, then all parts are concatenated.

    Raises:
        ValueError: if the file yields no triangle mesh (for example a
            vertex-only PLY that trimesh loads as a point cloud).
    """
    loaded = trimesh.load(path, process=False)
    if isinstance(loaded, trimesh.Trimesh):
        return _sample_surface(loaded, n_points, seed=seed)
    parts = []
    if isinstance(loaded, trimesh.Scene):
        # Concatenating scene.geometry directly would ignore the placement and
        # scaling stored in the scene graph, and would drop instanced copies.
        for node in loaded.graph.nodes_geometry:
            transform, geom_name = loaded.graph[node]
            geom = loaded.geometry.get(geom_name)
            if isinstance(geom, trimesh.Trimesh):
                part = geom.copy()
                part.apply_transform(transform)
                parts.append(part)
    if not parts:
        raise ValueError(
            f"{path}: no triangle mesh to sample (loaded {type(loaded).__name__})"
        )
    return _sample_surface(trimesh.util.concatenate(parts), n_points, seed=seed)


def _resolve_classes(classes, labels):
    """Return the entries of *classes* selected by *labels*, in request order.

    Each label is either an integer index into *classes* (Python or NumPy
    integer) or a class name matched case-insensitively; duplicates are kept
    once. Labels that match nothing are reported with a warning. If *labels* is
    ``None`` or nothing matches, every class is returned (the latter also warns
    when labels were actually given).
    """
    if labels is None:
        return classes
    labels = list(labels)
    by_lower = {c.lower(): c for c in classes}
    chosen, seen, unknown = [], set(), []
    for lb in labels:
        if isinstance(lb, (int, np.integer)):
            # NumPy integers are not `int` subclasses; accept them as indices too.
            name = classes[int(lb)] if 0 <= lb < len(classes) else None
        else:
            name = by_lower.get(str(lb).lower())
        if name is None:
            unknown.append(lb)
        elif name not in seen:
            chosen.append(name)
            seen.add(name)
    if unknown:
        warnings.warn(
            f"save_n_per_class: ignoring unknown labels {unknown!r}", stacklevel=3
        )
    if not chosen:
        if labels:
            warnings.warn(
                "save_n_per_class: no requested label matched; falling back to all classes",
                stacklevel=3,
            )
        return classes
    return chosen


def _normalize_pc(pc, normalize):
    """Apply the normalization named by *normalize*; always returns float32.

    ``"minmax"`` rescales each axis to [-1, 1], ``"unit"`` centres the cloud in
    the unit sphere, and any other value leaves the coordinates unchanged.
    """
    if normalize == "minmax":
        return _minmax_neg1_pos1(pc)
    if normalize == "unit":
        return _unit_sphere(pc)
    return pc.astype(np.float32)


def save_n_per_class(root_dir, out_dir, split, n_per_class, n_points=2048, normalize="minmax", seed=42, labels=None):
    """Sample point clouds from up to *n_per_class* meshes per class and save them.

    Expects ``root_dir/<class>/<split>/*.off`` / ``*.ply``. Every sub-directory
    of *root_dir* is a class, and classes are indexed in sorted order.

    Args:
        root_dir: dataset root directory.
        out_dir: destination directory; created only if something is saved.
        split: sub-folder inside each class directory; also used in the output
            file names.
        n_per_class: maximum number of meshes taken per class, chosen at random
            without replacement. Values <= 0 select nothing.
        n_points: number of surface points sampled from each mesh.
        normalize: ``"minmax"`` (per-axis [-1, 1]) or ``"unit"`` (unit sphere);
            any other value keeps raw coordinates.
        seed: seeds both the file selection and the surface sampling, so a fixed
            value gives reproducible output.
        labels: optional iterable restricting the processed classes, given as
            class names (case-insensitive) or integer indices into the sorted
            class list. Unknown entries are ignored with a warning; if none
            match, all classes are processed.

    Side effects:
        Writes ``X_{split}.pt`` (stacked float32 point clouds), ``y_{split}.pt``
        (int64 class indices) and ``label_classes.pt`` (the full sorted class
        list the indices refer to) into *out_dir*. Nothing is written when no
        mesh was found.
    """
    rng = np.random.default_rng(seed)
    # Surface sampling gets its own stream, jumped off the same seed. Drawing
    # per-mesh seeds from `rng` itself would shift the file picks of every later
    # class, so a given seed still selects exactly the same files as before.
    point_rng = np.random.Generator(rng.bit_generator.jumped())

    classes = sorted(d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d)))
    class_to_idx = {c: i for i, c in enumerate(classes)}
    used_classes = _resolve_classes(classes, labels)

    X_list, y_list = [], []
    for cls in used_classes:
        folder = os.path.join(root_dir, cls, split)
        files = sorted(glob.glob(os.path.join(folder, "*.off")) + glob.glob(os.path.join(folder, "*.ply")))
        k = min(n_per_class, len(files))
        # `<= 0`, not `== 0`: a negative n_per_class would otherwise slice from
        # the end below and select almost every file.
        if k <= 0:
            continue
        pick = rng.permutation(len(files))[:k]
        for idx in pick:
            # Per-mesh seed makes the sampled points reproducible for a fixed `seed`.
            sample_seed = int(point_rng.integers(2**32))
            pc = _load_pc(files[idx], n_points, seed=sample_seed)
            X_list.append(_normalize_pc(pc, normalize))
            y_list.append(class_to_idx[cls])

    if not X_list:
        return
    os.makedirs(out_dir, exist_ok=True)
    X = torch.from_numpy(np.stack(X_list, axis=0))
    y = torch.tensor(y_list, dtype=torch.long)
    torch.save(X, os.path.join(out_dir, f"X_{split}.pt"))
    torch.save(y, os.path.join(out_dir, f"y_{split}.pt"))
    torch.save(classes, os.path.join(out_dir, "label_classes.pt"))

if __name__ == "__main__":
    # Preprocess the first ten categories of the list below, one output folder
    # per category: up to 100 training meshes and up to 10 test meshes each.
    modelnet40 = [
        "airplane", "bathtub", "bed", "bench", "bookshelf", "bottle", "bowl", "car",
        "chair", "cone", "cup", "curtain", "desk", "door", "dresser", "flower_pot",
        "glass_box", "guitar", "keyboard", "lamp", "laptop", "mantel", "monitor",
        "night_stand", "person", "piano", "plant", "radio", "range_hood", "sink",
        "sofa", "stairs", "stool", "table", "tent", "toilet", "tv_stand", "vase",
        "wardrobe", "xbox",
    ]
    labels = modelnet40[:10]
    root = "data/ModelNet40"
    out_root = "preprocessed_data/ModelNet40"
    # (split, meshes per class) pairs, processed in this order for every label.
    per_split = (("train", 100), ("test", 10))
    for lb in labels:
        out_dir = os.path.join(out_root, lb)
        for split_name, count in per_split:
            save_n_per_class(root, out_dir, split_name, n_per_class=count, labels=[lb])
