from __future__ import annotations
import os, sys, copy, yaml
import numpy as np
import torch
from torch.utils.data import Dataset

from .build import DATASETS  # OpenPoints registry

import glob

def _local_create_scene_dataset(cfg_dir: str, set_name: str):
    class LocalSceneDataset:
        def __init__(self, cfg_dir, set_name):
            txt_path = os.path.join(cfg_dir, f"{set_name.lower()}_list.txt")
            alt_txt = os.path.join(cfg_dir, "data_list.txt")
            npy_path = os.path.join(cfg_dir, f"{set_name.lower()}_list.npy")
            if os.path.isfile(txt_path):
                with open(txt_path, 'r') as f:
                    files = [l.strip() for l in f if l.strip()]
            elif os.path.isfile(npy_path):
                files = list(np.load(npy_path))
            elif os.path.isfile(alt_txt):
                with open(alt_txt, 'r') as f:
                    files = [l.strip() for l in f if l.strip()]
            else:
                cand = glob.glob(os.path.join(cfg_dir, 'raw', '*.npy')) + glob.glob(os.path.join(cfg_dir, '*.npy'))
                files = [os.path.abspath(p) for p in sorted(cand)]
            self.files = files
        def __len__(self):
            return len(self.files)
        def __getitem__(self, idx):
            path = self.files[idx]
            arr = np.load(path).astype(np.float32)
            pos = arr[:, :3]
            feats = arr[:, 3:-1] if arr.shape[1] > 4 else arr[:, 3:6] if arr.shape[1] >= 7 else np.zeros((arr.shape[0], 0), dtype=np.float32)
            labels = arr[:, -1].astype(np.int64)
            return {'points': pos, 'feats': feats, 'labels': labels}
    return LocalSceneDataset(cfg_dir, set_name)

@DATASETS.register_module()
class Urban3DSegBase(Dataset):
    """OpenPoints-style segmentation dataset over per-scene ``.npy`` arrays.

    Scenes are read through ``_local_create_scene_dataset``, a minimal local
    stand-in for the PointSegBase ``SceneDataset``. Each item is a dict with
    ``'pos'`` (float32 coordinates), ``'x'`` (float32 features) and ``'y'``
    (int64 labels), optionally passed through ``transform``. If a hierarchy
    YAML is found in ``cfg_dir``, 1-D leaf labels are expanded into an
    ``(N, L)`` array with one label column per hierarchy matrix.
    """

    # Split names accepted by ``set_name`` (compared after upper-casing).
    _VALID_SETS = frozenset({'TRAIN', 'VALIDATION', 'TEST'})
    # YAML files listing hierarchy CSVs, searched in this order inside cfg_dir.
    _HIERARCHY_YAMLS = ('h_matrix_list.yaml', 'matrix_file_list.yaml', 'h_list.yaml')
    # Number of fetch attempts __getitem__ makes before giving up.
    _MAX_FETCH_ATTEMPTS = 3

    def __init__(self,
                 cfg_dir: str,
                 set_name: str = 'TRAIN',
                 transform=None,
                 **kwargs):
        """Parameters
        ----------
        cfg_dir : str
            Directory holding the scene list files / ``.npy`` scenes and,
            optionally, a hierarchy YAML (one of ``_HIERARCHY_YAMLS``).
        set_name : str, default 'TRAIN'
            One of 'TRAIN', 'VALIDATION', 'TEST' (case-insensitive).
        transform : callable, optional
            OpenPoints transform pipeline applied to each output dict.
        **kwargs
            Accepted and ignored; presumably lets the registry builder pass
            extra config keys through. TODO confirm.

        Raises
        ------
        ValueError
            If ``set_name`` is not an accepted split name.
        FileNotFoundError
            If ``cfg_dir`` is not an existing directory.
        """
        super().__init__()
        set_name = set_name.upper()
        # Raise rather than assert so the check survives `python -O`.
        if set_name not in self._VALID_SETS:
            raise ValueError(f"Invalid set_name {set_name}")
        self.transform = transform
        # NOTE(review): presumably signals variable point counts per cloud to OpenPoints; confirm.
        self.variable = True

        if not os.path.isdir(cfg_dir):
            raise FileNotFoundError(f'cfg_dir not found: {cfg_dir}')

        # Fixed class count; it is not derived from the data.
        self.num_classes = 13
        self.classes = [str(i) for i in range(self.num_classes)]
        self.inner_ds = _local_create_scene_dataset(cfg_dir, set_name)
        self.inner_ds.voxelized = False

        # One leaf->coarse lookup array per hierarchy level, or None if unavailable.
        self._gather_ids = self._load_gather_ids(cfg_dir)

    def _load_gather_ids(self, cfg_dir):
        """Load leaf->coarse index arrays from the hierarchy CSVs, if any.

        The first existing file from ``_HIERARCHY_YAMLS`` is read. Its
        ``file_list`` key names CSV matrices, with relative paths resolved
        against ``cfg_dir``. For each matrix, the row index of the column-wise
        maximum is taken as the coarse label of that column's leaf class.

        Returns a list of int64 arrays, or None when there is no YAML, the list
        is empty, or loading fails. Failures print a warning instead of raising
        (deliberate best-effort).
        """
        candidates = (os.path.join(cfg_dir, name) for name in self._HIERARCHY_YAMLS)
        h_list_yaml = next((p for p in candidates if os.path.isfile(p)), None)
        if h_list_yaml is None:
            return None
        try:
            with open(h_list_yaml, 'r') as yf:
                # safe_load suffices for a plain list of paths; `or {}` covers an empty file.
                file_list = (yaml.safe_load(yf) or {}).get('file_list', [])
            gather_ids = []
            for f in file_list:
                f_abs = f if os.path.isabs(f) else os.path.join(cfg_dir, f)
                if not os.path.isfile(f_abs):
                    raise FileNotFoundError(f'Hierarchy CSV not found: {f_abs}')
                # ndmin=2 keeps a single-row / single-column CSV 2-D so that
                # argmax(axis=0) still yields one entry per column.
                m = np.loadtxt(f_abs, delimiter=',', ndmin=2)
                gather_ids.append(np.argmax(m, axis=0).astype(np.int64))
        except Exception as e:
            print(f'[Urban3DSegBase] Warning: cannot load hierarchy matrices ({e}). Proceed without stacking.')
            return None
        # An empty list would make np.stack fail in __getitem__; treat it as "no hierarchy".
        return gather_ids or None

    # -------------------- standard torch Dataset API --------------------
    def __len__(self):
        return len(self.inner_ds)

    def __getitem__(self, idx: int):
        """Fetch one sample as ``{'pos', 'x', 'y'}``, then apply ``transform``.

        If the underlying sample is None, or lacks 'points'/'labels', a random
        index (via ``np.random``) is tried instead, up to
        ``_MAX_FETCH_ATTEMPTS`` attempts in total.

        Raises
        ------
        RuntimeError
            If no valid sample was obtained within the allowed attempts.
        """
        original_idx = idx
        for _ in range(self._MAX_FETCH_ATTEMPTS):
            sample = self.inner_ds[idx]
            if (sample is not None
                    and sample.get('points') is not None
                    and sample.get('labels') is not None):
                break
            idx = np.random.randint(len(self))
        else:
            raise RuntimeError(f"Failed to fetch a valid sample after retries (original idx={original_idx}).")

        pos_np = sample['points']
        label_np = sample['labels']
        feat_np = sample.get('feats')
        if feat_np is None:
            # No feature columns: use an empty (N, 0) block so 'x' is always present.
            feat_np = np.zeros((pos_np.shape[0], 0), dtype=np.float32)

        # Expand 1-D leaf labels into one column per hierarchy level.
        # NOTE(review): labels index the gather arrays directly, so negative
        # values (e.g. an ignore index) wrap around and values past the leaf
        # count raise IndexError. Confirm how ignore labels are encoded.
        if self._gather_ids and label_np.ndim == 1:
            leaf = label_np.astype(np.int64)
            label_np = np.stack([gid[leaf] for gid in self._gather_ids], axis=1)

        out = {
            'pos': pos_np.astype(np.float32),
            'x': feat_np.astype(np.float32),
            'y': label_np.astype(np.int64),
        }

        if self.transform is not None:
            out = self.transform(out)
        return out

    # -------------------- OpenPoints style collate --------------------
    def collate_fn(self, batch):
        """Concatenate a list of samples into flat tensors with point offsets.

        Each sample's 'pos', 'x' and 'y' (NumPy arrays or tensors) are
        concatenated along dim 0. 'offset' is the cumulative sum of per-sample
        point counts, so sample ``i`` occupies rows ``offset[i-1]:offset[i]``
        (starting at 0 for ``i == 0``).
        """
        # torch.as_tensor returns an existing tensor unchanged and converts arrays.
        coords = [torch.as_tensor(b['pos']) for b in batch]
        feats = [torch.as_tensor(b['x']) for b in batch]
        labels = [torch.as_tensor(b['y']) for b in batch]
        counts = torch.IntTensor([c.shape[0] for c in coords])
        return {
            'pos': torch.cat(coords, dim=0),
            'x': torch.cat(feats, dim=0),
            'y': torch.cat(labels, dim=0),
            'offset': torch.cumsum(counts, dim=0),
        }


