import numpy as np
import torch
from .s3dis import S3DIS
from .build import DATASETS


@DATASETS.register_module()
class S3DISHier(S3DIS):
    """S3DIS dataset that returns hierarchical labels `[coarse, fine]`.

    It reuses all logic from the original :class:`S3DIS`. After obtaining a
    sample, it converts the 1-D fine label array to a 2-column array. The
    first column is a manually specified coarse label (the default mapping is
    stated to be identical to the PointLiBR paper). The second column is the
    original fine label.

    Args:
        fine2coarse: Sequence where ``fine2coarse[i]`` is the coarse class of
            fine class ``i``. Defaults to the 13-class -> 4-class mapping
            below.
        **kwargs: Forwarded unchanged to :class:`S3DIS`.

    Raises:
        ValueError: If ``fine2coarse`` is not a non-empty 1-D sequence.
    """

    def __init__(self, fine2coarse=None, **kwargs):
        super().__init__(**kwargs)
        if fine2coarse is None:
            # 13-class → 4-class default mapping (static / opening / furniture / misc)
            fine2coarse = [0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 0, 3]
        f2c = np.asarray(fine2coarse, dtype=np.int64)
        if f2c.ndim != 1 or f2c.size == 0:
            raise ValueError(
                f"fine2coarse must be a non-empty 1-D sequence, got shape {f2c.shape}"
            )
        self.f2c = f2c

    # ---------------------------------------------------------------------
    def _stack_label(self, label_arr: np.ndarray) -> np.ndarray:
        """Return an ``[N, 2]`` int64 array ``[coarse, fine]`` for fine labels.

        Fine labels outside ``[0, len(self.f2c))`` have no coarse class. An
        ignore index such as -1 or 255 is one example. These labels are copied
        unchanged into the coarse column, so an ignore index stays an ignore
        index. Indexing ``self.f2c`` directly would instead wrap negative
        values around or raise IndexError for large ones.
        """
        fine = np.asarray(label_arr).astype(np.int64, copy=False)
        valid = (fine >= 0) & (fine < len(self.f2c))
        # Index with a safe placeholder (0) where invalid, then overwrite.
        coarse = np.where(valid, self.f2c[np.where(valid, fine, 0)], fine)
        return np.stack([coarse, fine], axis=1)  # shape [N,2]

    def _stack_label_tensor(self, y: torch.Tensor) -> torch.Tensor:
        """Tensor version of :meth:`_stack_label`; result stays on ``y``'s device."""
        # `.numpy()` only works on CPU tensors, so always move to CPU first.
        y_np = y.detach().cpu().numpy()
        return torch.from_numpy(self._stack_label(y_np)).to(y.device)

    def _convert_dict(self, d: dict) -> dict:
        """Replace ``d['y']`` (if present and not None) with stacked labels, in place.

        Returns the same dict object.
        """
        y = d.get('y')
        if y is not None:
            if torch.is_tensor(y):
                d['y'] = self._stack_label_tensor(y)
            else:
                d['y'] = self._stack_label(y)
        return d

    def _convert_tuple_item(self, s):
        """Convert one element of a tuple-shaped sample.

        Heuristic: any 1-D ``torch.long`` tensor is assumed to be the
        per-point label. NOTE(review): a 1-D long index tensor returned by
        :class:`S3DIS` would be misclassified as a label. Confirm against
        the parent's tuple layout.
        """
        if torch.is_tensor(s) and s.dtype == torch.long and s.dim() == 1:
            return self._stack_label_tensor(s)
        if isinstance(s, dict):
            return self._convert_dict(s)
        return s

    # ------------------------------------------------------------------
    def __getitem__(self, idx):
        """Fetch sample ``idx`` from :class:`S3DIS` and attach hierarchical labels.

        Dict samples have their ``'y'`` entry converted. Tuple samples have
        each element converted via :meth:`_convert_tuple_item`. Any other
        sample type is returned unchanged.
        """
        sample = super().__getitem__(idx)
        if isinstance(sample, dict):
            return self._convert_dict(sample)
        if isinstance(sample, tuple):
            return tuple(self._convert_tuple_item(s) for s in sample)
        return sample