"""S3DIS Dataset for Inference Only"""

import os
import torch
import numpy as np
from torch.utils.data import Dataset


class S3DISDataset(Dataset):
    """S3DIS dataset serving preprocessed scenes for inference.

    Every scene is a ``.pth`` file inside ``<data_root>/Area_<area>`` that
    holds a dict with ``coords``, ``features`` (or ``colors``) and, optionally,
    ``labels`` (or ``sem_labels``). Values may be torch tensors or array-likes.
    """

    CLASSES = [
        'ceiling', 'floor', 'wall', 'beam', 'column',
        'window', 'door', 'table', 'chair', 'sofa',
        'bookcase', 'board', 'clutter'
    ]

    # Extension that identifies a scene file inside the area directory.
    _SCENE_SUFFIX = '.pth'

    def __init__(self, data_root, area=5, voxel_size=0.04):
        """Index the scene files of one area.

        Args:
            data_root: Directory containing the ``Area_<n>`` sub-directories.
            area: Area identifier used to build the ``Area_<area>`` folder name.
            voxel_size: Voxel edge length used for downsampling; ``0``,
                a negative value or ``None`` disables downsampling.

        Raises:
            FileNotFoundError: If the area directory does not exist.
            ValueError: If the area directory contains no ``.pth`` file.
        """
        super().__init__()
        self.data_root = data_root
        self.area = area
        self.voxel_size = voxel_size
        self.data_list = self._load_file_list()

        print(f"S3DIS Area {area}: {len(self.data_list)} scenes loaded")

    def _load_file_list(self):
        """Return one ``{'filepath', 'scene_name'}`` dict per scene, sorted by filename."""
        area_dir = os.path.join(self.data_root, f'Area_{self.area}')

        # isdir (not exists): a regular file with that name is not a usable area.
        if not os.path.isdir(area_dir):
            raise FileNotFoundError(f"Area directory not found: {area_dir}")

        suffix = self._SCENE_SUFFIX
        data_list = [
            {
                'filepath': os.path.join(area_dir, filename),
                # Strip only the trailing extension; str.replace would also
                # remove any '.pth' occurring in the middle of the name.
                'scene_name': filename[:-len(suffix)],
            }
            for filename in sorted(os.listdir(area_dir))
            if filename.endswith(suffix)
        ]

        if not data_list:
            raise ValueError(f"No .pth files found in {area_dir}")

        return data_list

    def __len__(self):
        """Return the number of indexed scenes."""
        return len(self.data_list)

    @staticmethod
    def _to_numpy(array):
        """Convert a tensor or array-like to a numpy array; ``None`` passes through."""
        if array is None:
            return None
        if isinstance(array, torch.Tensor):
            # detach/cpu so tensors that require grad or live on a GPU convert too.
            return array.detach().cpu().numpy()
        return np.asarray(array)

    def __getitem__(self, idx):
        """Load, centre, normalise and optionally downsample one scene.

        Returns:
            Dict with ``coords`` (float tensor), ``features`` (float tensor),
            ``labels`` (long tensor, or ``None`` when the scene file carries no
            labels) and ``scene_name`` (str).

        Raises:
            KeyError: If the scene file lacks ``coords`` or both
                ``features`` and ``colors``.
        """
        item = self.data_list[idx]
        # NOTE(review): torch.load unpickles the file; only load scene files
        # from a trusted source. map_location keeps GPU-saved tensors loadable
        # on CPU-only machines.
        data = torch.load(item['filepath'], map_location='cpu')

        coords = self._to_numpy(data['coords'])

        features = data.get('features')
        if features is None:
            features = data.get('colors')
        if features is None:
            raise KeyError(
                f"Scene {item['filepath']} has neither 'features' nor 'colors'"
            )
        features = self._to_numpy(features)

        # Labels are optional: inference data may be unlabelled.
        labels = data.get('labels')
        if labels is None:
            labels = data.get('sem_labels')
        labels = self._to_numpy(labels)

        # Center coordinates
        coords = coords - coords.mean(axis=0)

        # Normalize RGB to [0,1]; the size guard avoids max() on an empty scene.
        if features.size and features.max() > 1.0:
            features = features / 255.0

        # Voxel downsampling
        if self.voxel_size and self.voxel_size > 0:
            coords, features, labels = self._voxel_downsample(
                coords, features, labels, self.voxel_size
            )

        return {
            'coords': torch.from_numpy(coords).float(),
            'features': torch.from_numpy(features).float(),
            'labels': None if labels is None else torch.from_numpy(labels).long(),
            'scene_name': item['scene_name']
        }

    def _voxel_downsample(self, coords, features, labels, voxel_size):
        """Keep one point per occupied voxel.

        Points are binned by ``floor(coords / voxel_size)`` and the first point
        found in each voxel is kept. The output follows the sorted order of the
        voxel indices, not the input order. ``labels`` may be ``None``.
        """
        voxel_coords = np.floor(coords / voxel_size).astype(np.int32)
        _, unique_indices = np.unique(voxel_coords, axis=0, return_index=True)

        coords = coords[unique_indices]
        features = features[unique_indices]
        if labels is not None:
            labels = labels[unique_indices]

        return coords, features, labels


def collate_fn_s3dis(batch):
    """Merge per-scene sample dicts into one batch dict.

    Args:
        batch: Non-empty list of sample dicts with ``coords``, ``features``,
            ``labels`` (tensor or ``None``) and ``scene_name``.

    Returns:
        Dict with concatenated ``coords`` and ``features``, ``labels``
        (concatenated, or ``None`` if any scene has no labels),
        ``scene_names`` (list of str) and ``offsets`` (cumulative point counts
        with a leading 0, so scene ``i`` spans
        ``offsets[i]:offsets[i + 1]``). A single-scene batch additionally keeps
        every key of the original sample (e.g. ``scene_name``), and its
        tensors are passed through without copying.
    """
    if len(batch) == 1:
        sample = batch[0]
        # Shallow copy: keep the sample's own keys and tensors, and add the
        # batch-level keys so callers see one schema for every batch size.
        merged = dict(sample)
        merged.setdefault('scene_names', [sample['scene_name']])
        merged.setdefault('offsets', torch.tensor([0, len(sample['coords'])]))
        return merged

    # Concatenate multiple scenes
    coords = torch.cat([item['coords'] for item in batch], dim=0)
    features = torch.cat([item['features'] for item in batch], dim=0)

    # A partially labelled batch cannot be aligned with coords, so it is
    # reported as unlabelled rather than concatenated.
    label_parts = [item.get('labels') for item in batch]
    if any(part is None for part in label_parts):
        labels = None
    else:
        labels = torch.cat(label_parts, dim=0)

    scene_names = [item['scene_name'] for item in batch]

    offsets = [0]
    for item in batch:
        offsets.append(offsets[-1] + len(item['coords']))

    return {
        'coords': coords,
        'features': features,
        'labels': labels,
        'scene_names': scene_names,
        'offsets': torch.tensor(offsets)
    }
