"""ScanObjectNN Dataset for Inference Only"""

import os
import h5py
import torch
import numpy as np
from torch.utils.data import Dataset


class ScanObjectNNDataset(Dataset):
    """ScanObjectNN point-cloud dataset for inference.

    The selected ``.h5`` file is read fully into memory at construction
    time. Each item is resampled to ``num_points`` points, centred on its
    centroid and scaled by its largest point-to-centroid distance.

    Args:
        data_root: Directory containing the ScanObjectNN ``.h5`` files.
        variant: Dataset variant. ``'PB_T50_RS'`` selects
            ``{split}_objectdataset_augmentedrot_scale75.h5``; any other
            value selects ``main_split_{split}.h5``.
        split: Split name interpolated into the file name.
        num_points: Number of points returned per object.

    Raises:
        FileNotFoundError: If the resolved ``.h5`` file does not exist.
    """

    # Category names; presumably indexed by the integer label stored in the
    # .h5 file -- verify against the dataset release before relying on it.
    CLASSES = [
        'bag', 'bin', 'box', 'cabinet', 'chair',
        'desk', 'display', 'door', 'shelf', 'table',
        'bed', 'pillow', 'sink', 'sofa', 'toilet'
    ]

    def __init__(
        self,
        data_root,
        variant='PB_T50_RS',
        split='test',
        num_points=1024
    ):
        super().__init__()
        self.data_root = data_root
        self.variant = variant
        self.split = split
        self.num_points = num_points

        self.data, self.labels = self._load_data()

        print(f"ScanObjectNN {variant} ({split}): {len(self.data)} objects")

    def _load_data(self):
        """Read the ``data`` and ``label`` arrays from the split's .h5 file.

        Returns:
            Tuple ``(data, labels)`` where ``data`` is cast to float32 and
            ``labels`` to int64.

        Raises:
            FileNotFoundError: If the expected file is missing.
        """
        if self.variant == 'PB_T50_RS':
            file_name = f'{self.split}_objectdataset_augmentedrot_scale75.h5'
        else:
            file_name = f'main_split_{self.split}.h5'
        h5_file = os.path.join(self.data_root, file_name)

        if not os.path.exists(h5_file):
            raise FileNotFoundError(f"Data file not found: {h5_file}")

        # Slicing with [:] copies the datasets into memory, so the arrays
        # stay valid after the file handle is closed.
        with h5py.File(h5_file, 'r') as f:
            data = f['data'][:].astype(np.float32)
            labels = f['label'][:].astype(np.int64)

        return data, labels

    def __len__(self):
        """Return the number of objects loaded from the file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return one resampled, normalized object.

        Args:
            idx: Index of the object.

        Returns:
            Dict with ``'coords'`` (float tensor with ``num_points`` rows)
            and ``'label'`` (int64 tensor built from ``labels[idx]``).

        Note:
            Resampling draws from NumPy's global RNG, so repeated reads of
            the same index differ unless the caller seeds ``np.random``.
        """
        points = self.data[idx]
        label = self.labels[idx]

        # Sample to a fixed number of points: without replacement when the
        # cloud is larger than requested, with replacement when smaller.
        available = len(points)
        if available != self.num_points:
            indices = np.random.choice(
                available,
                self.num_points,
                replace=available < self.num_points,
            )
            points = points[indices]

        points = self._normalize_pc(points)

        coords = torch.from_numpy(points).float()
        label = torch.tensor(label).long()

        return {
            'coords': coords,
            'label': label
        }

    def _normalize_pc(self, points):
        """Centre a point cloud and scale it into the unit sphere.

        Args:
            points: Array whose rows are points (assumed 2-D, one point per
                row).

        Returns:
            New array with the centroid subtracted and every coordinate
            divided by the largest point-to-centroid distance. If that
            distance is zero (all points coincide) the centred cloud is
            returned unscaled instead of dividing by zero.
        """
        centroid = np.mean(points, axis=0)
        points = points - centroid

        max_dist = np.max(np.sqrt(np.sum(points**2, axis=1)))
        # A degenerate cloud has zero radius; dividing would produce NaNs
        # that silently poison downstream inference.
        if max_dist > 0:
            points = points / max_dist

        return points


def collate_fn_scanobjectnn(batch):
    """Merge per-object sample dicts into one batched dict.

    Args:
        batch: Sequence of dicts, each holding a ``'coords'`` tensor and a
            ``'label'`` tensor.

    Returns:
        Dict with ``'coords'`` and ``'labels'``, each formed by stacking the
        corresponding per-sample tensors along a new leading dimension.
    """
    def _gather(key):
        # Stack the tensors stored under ``key`` across all samples.
        return torch.stack(tuple(sample[key] for sample in batch), dim=0)

    collated = {}
    collated['coords'] = _gather('coords')
    collated['labels'] = _gather('label')
    return collated
