import os
import pickle
from typing import List, Tuple, Dict, Any

import numpy as np
import pandas as pd
from tqdm import tqdm


# ================== Config ================== #

# Input root: each immediate sub-directory is one class holding that class's CSV files.
RAW_DIR = "/home/data/AOPHand"

# Destination of the preprocessed pickle.
OUT_ROOT = "/home/data/mm_data_unified/AOPHand_aop"
OUT_PATH = os.path.join(OUT_ROOT, "AOPHand_aop_pre.pkl")

# Fixed tensor layout: every sample becomes (F_MAX, P_MAX, FEAT_DIM).
F_MAX, P_MAX = 4, 64  # frames kept per sample, points kept per frame
FEAT_COLS = ["x", "y", "z", "velocity"]  # per-point CSV columns, in output order
FEAT_DIM = len(FEAT_COLS)

SEED = 42
np.random.seed(seed=SEED)  # seed NumPy's global RNG


def ensure_dir(path: str) -> None:
    """Create directory ``path`` (and any missing parents) if it does not exist.

    Idempotent, and free of the check-then-create race that an
    ``os.path.exists`` test followed by ``os.makedirs`` would have.

    Raises:
        FileExistsError: if ``path`` already exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)


def list_classes_and_samples(root: str) -> Tuple[List[str], Dict[str, int], List[Tuple[str, int, str]]]:
    """Discover class sub-directories of ``root`` and the CSV files inside them.

    Each immediate sub-directory of ``root`` is one class; label ids are
    assigned in sorted directory-name order. Files directly under ``root``,
    non-``.csv`` entries (extension checked case-insensitively), and
    directories whose names end in ``.csv`` are ignored.

    Args:
        root: Dataset root directory.

    Returns:
        classes: Sorted class names.
        class2id: Mapping class name -> integer label id.
        samples: ``(csv_path, label_id, class_name)`` tuples, ordered by class
            and then by file name.
    """
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    class2id = {c: i for i, c in enumerate(classes)}

    samples: List[Tuple[str, int, str]] = []
    for cname in classes:
        cdir = os.path.join(root, cname)
        # os.listdir order is filesystem-dependent; sort so the dataset order
        # (and therefore the saved pickle) is reproducible.
        for fname in sorted(os.listdir(cdir)):
            path = os.path.join(cdir, fname)
            if fname.lower().endswith(".csv") and os.path.isfile(path):
                samples.append((path, class2id[cname], cname))
    return classes, class2id, samples


def _resample_or_pad(arr: np.ndarray, target: int) -> np.ndarray:
    """Return ``arr`` with exactly ``target`` entries along axis 0 (float32 padding).

    Longer inputs are subsampled at evenly spaced indices (first and last kept,
    indices truncated toward zero); shorter inputs are zero-padded at the end;
    inputs already of length ``target`` are returned unchanged.
    """
    n = arr.shape[0]
    if n > target:
        idx = np.linspace(0, n - 1, target).astype(int)
        return arr[idx]
    if n < target:
        pad = np.zeros((target - n, *arr.shape[1:]), dtype=np.float32)
        return np.concatenate([arr, pad], axis=0)
    return arr


def process_one_csv(csv_path: str) -> np.ndarray:
    """Load one CSV and convert it into a fixed-size point-cloud tensor.

    Rows are grouped by the ``frame`` column, with frames taken in ascending
    frame-id order. Each frame's ``FEAT_COLS`` values are evenly subsampled or
    zero-padded to ``P_MAX`` points. The resulting frame sequence is then
    evenly subsampled or zero-padded to ``F_MAX`` frames.

    NOTE(review): padding uses all-zero points/frames, which cannot be told
    apart from genuine points at the origin with zero velocity; no validity
    mask is produced.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        float32 array of shape ``(F_MAX, P_MAX, FEAT_DIM)``. An empty CSV
        (header-only or zero-byte) yields all zeros.

    Raises:
        ValueError: if a non-empty CSV lacks ``frame`` or any of ``FEAT_COLS``.
    """
    try:
        df = pd.read_csv(csv_path)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header at all; treat it like a header-only CSV.
        df = pd.DataFrame()

    if len(df) == 0:
        return np.zeros((F_MAX, P_MAX, FEAT_DIM), dtype=np.float32)

    if not {"frame", *FEAT_COLS}.issubset(df.columns):
        raise ValueError(
            f"[AOPHand-pre] CSV {csv_path} must contain columns "
            f"'frame' + {FEAT_COLS}, got: {list(df.columns)}"
        )

    # groupby(sort=True) yields groups in ascending frame-id order.
    frames = [
        _resample_or_pad(group[FEAT_COLS].to_numpy(dtype=np.float32), P_MAX)
        for _, group in df.groupby("frame", sort=True)
    ]
    if frames:
        stacked = np.stack(frames, axis=0)  # (n_frames, P_MAX, FEAT_DIM)
    else:
        # Every row had a NaN frame id (groupby drops those groups).
        stacked = np.zeros((0, P_MAX, FEAT_DIM), dtype=np.float32)
    return _resample_or_pad(stacked, F_MAX)  # (F_MAX, P_MAX, FEAT_DIM)


def compute_mean_std(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute per-feature mean and standard deviation over all points.

    Every axis except the last is treated as a sample axis. For an
    ``(N, F, P, D)`` array, the statistics are therefore taken over all
    ``N * F * P`` rows.

    NOTE(review): zero-padded points/frames are included, which pulls the
    mean toward 0 and distorts the std. Confirm this is intended.

    Args:
        X: Array whose last axis holds the ``D`` features.

    Returns:
        ``(mean, std)`` as float32 arrays of shape ``(D,)``. ``1e-6`` is added
        to ``std`` so it can be used directly as a divisor.
    """
    arr = np.asarray(X)
    flat = arr.reshape(-1, arr.shape[-1])  # (rows, D)
    # Accumulate in float64: a float32 reduction over millions of rows along
    # axis 0 sums naively and loses precision.
    mean = flat.mean(axis=0, dtype=np.float64)
    std = flat.std(axis=0, dtype=np.float64) + 1e-6
    return mean.astype(np.float32), std.astype(np.float32)


def normalize(X: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Standardize ``X`` feature-wise: ``(X - mean) / std``.

    Args:
        X: Array of any rank whose last axis holds the ``D`` features.
        mean: Per-feature means; any shape with ``D`` elements.
        std: Per-feature standard deviations; any shape with ``D`` elements.

    Returns:
        Array with the same shape as ``X``.
    """
    # Flattening to 1-D broadcasts against X's last axis for any rank, instead
    # of forcing a 4-D result as a (1, 1, 1, D) reshape would for lower ranks.
    return (X - np.reshape(mean, -1)) / np.reshape(std, -1)


# ================== Build Dataset ================== #

def build_dataset() -> Tuple[np.ndarray, np.ndarray, List[Dict[str, Any]], Dict[str, int]]:
    """Load every class CSV under ``RAW_DIR`` into one padded array.

    Returns:
        data: float32 array ``(N, F_MAX, P_MAX, FEAT_DIM)``.
        labels: int64 array ``(N,)`` of class ids.
        meta: One dict per sample with keys ``"path"``, ``"class_name"``
            and ``"label_id"``.
        class2id: Mapping class directory name -> class id.

    Raises:
        ValueError: if no CSV files are found under ``RAW_DIR``.
    """
    classes, class2id, samples = list_classes_and_samples(RAW_DIR)
    print("[AOPHand-pre] Classes:", classes)
    print(f"[AOPHand-pre] Total CSV files: {len(samples)}")
    if not samples:
        # Without this check np.stack([]) fails with an uninformative message.
        raise ValueError(f"[AOPHand-pre] No CSV files found under {RAW_DIR}")

    data_list: List[np.ndarray] = []
    label_list: List[int] = []
    meta_list: List[Dict[str, Any]] = []

    for path, cid, cname in tqdm(samples, desc="Processing AOPHand CSV"):
        data_list.append(process_one_csv(path))  # (F_MAX, P_MAX, FEAT_DIM)
        label_list.append(cid)
        meta_list.append({
            "path": path,
            "class_name": cname,
            "label_id": cid,
        })

    # Per-sample arrays are already float32; copy=False avoids a second full copy.
    data = np.stack(data_list, axis=0).astype(np.float32, copy=False)  # (N, F_MAX, P_MAX, D)
    labels = np.array(label_list, dtype=np.int64)                       # (N,)

    print(f"[AOPHand-pre] Raw data shape: {data.shape}, labels shape: {labels.shape}")
    print(f"[AOPHand-pre] Num classes = {len(classes)}, N = {data.shape[0]}")
    return data, labels, meta_list, class2id


def main():
    """Build the AOPHand dataset and pickle it to ``OUT_PATH``.

    The stored ``"data"`` array is not normalized. The per-feature
    ``coord_mean`` / ``coord_std`` from ``compute_mean_std`` are saved
    alongside it, together with layout info and the class label map.
    """
    ensure_dir(OUT_ROOT)

    X, y, meta, label_map = build_dataset()

    coord_mean, coord_std = compute_mean_std(X)
    print("[coord_mean]", coord_mean)
    print("[coord_std ]", coord_std)

    obj = {
        # copy=False: X is already float32, so no second copy of the dataset.
        "data": X.astype(np.float32, copy=False),  # (N, F_MAX, P_MAX, D)
        "labels": y,              # (N,)
        "meta": meta,             # list[dict]

        "coord_mean": coord_mean, # (D,)
        "coord_std": coord_std,   # (D,)
        "max_frames": F_MAX,
        "max_points": P_MAX,
        "feat_dim": FEAT_DIM,

        "label_map": label_map,   # class_name -> id
        "num_classes": len(label_map),
    }

    # Write to a temp file, then atomically rename, so an interrupted run never
    # leaves a truncated pickle at OUT_PATH.
    tmp_path = OUT_PATH + ".tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_path, OUT_PATH)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    print(f"[AOPHand-pre] Saved unified point-cloud pkl → {OUT_PATH}")


if __name__ == "__main__":
    # Run the full preprocessing pipeline only when executed as a script,
    # not when this module is imported.
    main()
