#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Useful utility functions for point cloud classification
- Point cloud file reading (PLY)
- Object ID extraction from filenames
- Point cloud sampling/padding/normalization
"""
import os
import re
import numpy as np
from plyfile import PlyData


def extract_object_id(fname: str) -> str | None:
    """
    Extract object ID from PLY filename using regular expressions
    Args:
        fname: Full path to the PLY file
    Returns:
        Extracted object ID string or None if no match
    """
    basename = os.path.basename(fname)
    # First pattern: match view5_<obj_id>_Scale/Rigid/Deformable (case-insensitive)
    m = re.search(r'view5_([^_]+)_(?:Scale|Rigid|Deformable)', basename, flags=re.IGNORECASE)
    if m:
        return m.group(1)
    # Second pattern: match <letters><numbers> format as fallback
    m2 = re.search(r'([A-Za-z]+[0-9]+)', basename)
    if m2:
        return m2.group(1)
    return None


def read_ply_xyz(fname: str) -> np.ndarray:
    """
    Load the vertex positions stored in a PLY file
    Only the 'x', 'y' and 'z' properties of the 'vertex' element are read;
    any other element or property in the file is ignored.
    Args:
        fname: Full path to PLY file
    Returns:
        Numpy array of shape (N, 3) with float32 dtype (N: number of points)
    """
    vertices = PlyData.read(fname)['vertex']
    # One 1-D column per axis, joined along the last axis -> (N, 3)
    columns = [vertices[axis] for axis in ('x', 'y', 'z')]
    # float32 keeps the array directly usable as PyTorch input
    return np.stack(columns, axis=-1).astype(np.float32)


def sample_or_pad_points(pts: np.ndarray, num_points: int) -> np.ndarray:
    """
    Sample fixed number of points from point cloud or pad with repeated points if insufficient
    Randomness comes from the global ``np.random`` state, so results are
    reproducible after ``np.random.seed``.
    Args:
        pts: Input point cloud (N, 3)
        num_points: Target number of points to keep
    Returns:
        Resampled point cloud (num_points, 3). When N == num_points the input
        array itself is returned (not a copy); otherwise a new array is returned.
    Raises:
        ValueError: If the point cloud is empty but num_points > 0 (there is
            nothing to repeat), or if num_points is negative (raised by NumPy).
    """
    n_available = pts.shape[0]
    if n_available == num_points:
        return pts
    # Downsample without replacement if more points than target
    if n_available > num_points:
        keep = np.random.choice(n_available, num_points, replace=False)
        return pts[keep]
    # Padding needs at least one source point; fail with a clear message
    # instead of the opaque error np.random.choice raises for an empty range
    if n_available == 0:
        raise ValueError(
            f"cannot pad an empty point cloud to {num_points} points"
        )
    # Upsample: keep every original point first, then append random repeats
    extra = np.random.choice(n_available, num_points - n_available, replace=True)
    return np.concatenate([pts, pts[extra]], axis=0)


def augment_point_cloud(pts: np.ndarray) -> np.ndarray:
    """
    Augment point cloud with scale, Gaussian noise and Z-axis rotation
    The arithmetic is carried out in float64 (the dtype of the random noise),
    then the result is cast back to the input's floating dtype so that e.g.
    float32 clouds stay float32 instead of being silently promoted.
    Randomness comes from the global ``np.random`` state.
    Args:
        pts: Input point cloud (N, 3) (zero-mean normalized recommended)
    Returns:
        Augmented point cloud (N, 3) as a new array; the input is not modified.
        Floating inputs keep their dtype, any other dtype yields float64.
    """
    pts = np.asarray(pts)
    # Only floating dtypes can hold the augmented values; fall back to float64
    if np.issubdtype(pts.dtype, np.floating):
        out_dtype = pts.dtype
    else:
        out_dtype = np.float64
    # NOTE: the three random draws below keep a fixed order (scale, noise,
    # angle) so that seeded runs consume the random stream identically
    # Random scaling (0.9 ~ 1.1)
    scale = np.random.uniform(0.9, 1.1)
    # Random Gaussian noise (sigma=0.002); np.random.normal returns float64
    noise = np.random.normal(scale=0.002, size=pts.shape)
    # Random rotation around Z-axis (0 ~ 2π)
    theta = np.random.uniform(0, 2 * np.pi)
    c, s = np.cos(theta), np.sin(theta)
    rotation_mat = np.array([[c, -s, 0.0],
                             [s,  c, 0.0],
                             [0.0, 0.0, 1.0]], dtype=np.float64)
    # Points are row vectors, hence multiplication by the transposed matrix
    augmented = (pts * scale + noise).dot(rotation_mat.T)
    return augmented.astype(out_dtype, copy=False)


def pad_points_deterministic(pts: np.ndarray, num_points: int) -> np.ndarray:
    """
    Deterministic padding (no random) for test set: repeat points sequentially
    Clouds with at least num_points points are truncated to the first
    num_points rows; shorter clouds are cycled in their original order
    (0, 1, ..., N-1, 0, 1, ...) until num_points rows are reached.
    Args:
        pts: Input point cloud (N, 3)
        num_points: Target number of points to keep (must be >= 0)
    Returns:
        Deterministically padded point cloud (num_points, 3). Truncation
        returns a slice of the input; padding returns a new array.
    Raises:
        ValueError: If num_points is negative, or if the point cloud is empty
            but num_points > 0 (there is nothing to repeat).
    """
    n_available = pts.shape[0]
    # A negative target would otherwise silently slice from the end
    if num_points < 0:
        raise ValueError(f"num_points must be non-negative, got {num_points}")
    if n_available >= num_points:
        return pts[:num_points]
    # Guard the modulo below against division by zero
    if n_available == 0:
        raise ValueError(
            f"cannot pad an empty point cloud to {num_points} points"
        )
    # Cycling indices: whole repetitions of 0..N-1 followed by the remainder
    cycle_idx = np.arange(num_points) % n_available
    return pts[cycle_idx]