#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Point cloud dataset class for PyTorch DataLoader
- Custom Dataset implementation for PLY point cloud files
- Integrates data augmentation, zero-mean centering and resampling to a fixed number of points
"""
import glob
import os
import warnings

import numpy as np
import torch
from torch.utils.data import Dataset

from utils import read_ply_xyz, extract_object_id, augment_point_cloud
from utils import sample_or_pad_points, pad_points_deterministic


class PointCloudDataset(Dataset):
    """
    PyTorch Dataset for point cloud classification (PLY files)
    Dataset structure: root_dir/class_name/*.ply

    The label of a sample is the position of its class name in ``classes``.
    File discovery happens once, at construction time; point clouds themselves
    are read lazily in ``__getitem__``.
    """
    def __init__(self, root: str, classes: list, num_points: int = 1024, augmentation: bool = False):
        """
        Initialize PointCloudDataset
        Args:
            root: Root directory of the dataset
            classes: List of class names (e.g., ['Strategy_A', 'Strategy_B']);
                the index of a name in this list is used as its integer label
            num_points: Fixed number of points to resample for each point cloud
            augmentation: Whether to apply data augmentation (for training only)
        """
        super().__init__()
        self.root = root
        self.classes = classes
        self.num_points = num_points
        self.augmentation = augmentation
        self.samples = self._load_samples()  # List of (file_path, label_index)

    def _load_samples(self) -> list:
        """
        Load all PLY file paths and their corresponding class labels

        Class directories that do not exist are skipped, but a warning is
        emitted so that a misspelled class name or wrong root does not go
        unnoticed.

        Returns:
            Sorted list of tuples (file_path, label_index)
        """
        samples = []
        for label_idx, cls in enumerate(self.classes):
            cls_dir = os.path.join(self.root, cls)
            if not os.path.isdir(cls_dir):
                # Keep the best-effort behaviour (skip), but make it visible.
                # stacklevel=3 attributes the warning to the code constructing the dataset.
                warnings.warn(
                    f"Class directory not found, skipping class '{cls}': {cls_dir}",
                    stacklevel=3,
                )
                continue
            # Escape the directory part so that glob metacharacters in the root
            # or class name (e.g. '[', ']', '*', '?') are matched literally;
            # only the '*.ply' suffix is meant to act as a pattern.
            pattern = os.path.join(glob.escape(cls_dir), '*.ply')
            samples.extend((fpath, label_idx) for fpath in glob.glob(pattern))
        # glob order is filesystem-dependent; sort for deterministic loading
        samples.sort()
        return samples

    def __len__(self) -> int:
        """Return total number of samples in the dataset"""
        return len(self.samples)

    def __getitem__(self, idx: int) -> tuple:
        """
        Load single sample by index
        Args:
            idx: Sample index
        Returns:
            Tuple of (point_cloud, label, file_path, object_id)
            - point_cloud: float32 Torch tensor, channel-first and C-contiguous
              (expected shape (3, num_points), assuming the resampling helpers
              return an (num_points, 3) array)
            - label: Integer class label
            - file_path: String path to PLY file
            - object_id: String object ID (or __UNKNOWN__{idx} if not extracted)
        """
        fpath, label = self.samples[idx]
        # Read raw point cloud (expected (N, 3))
        pts = read_ply_xyz(fpath)
        # Zero-mean normalization: subtract the per-axis centroid
        pts = pts - np.mean(pts, axis=0, keepdims=True)

        if self.augmentation:
            # Training path: augment first, then resample to num_points
            pts = augment_point_cloud(pts)
            pts = sample_or_pad_points(pts, self.num_points)
        else:
            # Evaluation path: resample via the helper named as deterministic
            pts = pad_points_deterministic(pts, self.num_points)

        # Convert to channel-first float32. ascontiguousarray forces C order:
        # a plain `.T.astype(...)` keeps the transposed (Fortran) layout, which
        # would yield a non-contiguous tensor from torch.from_numpy.
        pts = np.ascontiguousarray(pts.T, dtype=np.float32)

        # Extract object ID; fall back to a per-index placeholder when absent
        obj_id = extract_object_id(fpath)
        if obj_id is None:
            obj_id = f"__UNKNOWN__{idx}"

        # torch.from_numpy shares memory with the freshly created array above
        return torch.from_numpy(pts), int(label), fpath, obj_id