import numpy as np
import os, sys, h5py
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data

class ScanObjectNN(Dataset):
    """ScanObjectNN point-cloud classification dataset read from its HDF5 files.

    The whole subset is loaded into memory at construction time. Each item is a
    ``torch_geometric.data.Data`` whose ``pos`` holds ``nbr_pts`` points drawn
    without replacement from the stored object (a new draw on every access) and
    whose ``y`` holds the class label as a one-element tensor.

    Args:
        data_dir: Root directory that contains the ``split`` sub-directory.
        subset: ``'train'`` or ``'test'``; selects the
            ``training_objectdataset{version}.h5`` or
            ``test_objectdataset{version}.h5`` file respectively.
        split: Name of the sub-directory holding the HDF5 files.
        version: File-name suffix selecting the dataset variant.
        transform: Optional callable applied to each ``Data`` before it is
            returned.
        nbr_pts: Number of points sampled per item; must lie in
            ``[1, number of points stored per object]``.

    Raises:
        ValueError: If ``subset`` is unknown, ``nbr_pts`` is out of range, the
            point and label counts differ, or the labels are not all within
            ``[0, nbr_classes - 1]`` with the top class present.
    """

    # HDF5 file-name prefix for each supported subset.
    _FILE_PREFIX = {'train': 'training', 'test': 'test'}

    def __init__(self, data_dir, subset='train', split='main_split', version='_augmentedrot_scale75', transform=None, nbr_pts=2048):
        super().__init__()
        self.root = data_dir
        self.subset = subset
        self.split = split
        self.version = version
        self.transform = transform
        self.nbr_pts = nbr_pts

        # Raises ValueError for an unknown subset before touching the filesystem.
        path = self._h5_path(self.root, self.split, self.subset, self.version)
        # The context manager closes the file even if one of the reads raises.
        with h5py.File(path, 'r') as h5:
            # `[()]` reads the full dataset into a fresh ndarray, so the cast
            # can reuse it when it is already float32.
            self.points = h5['data'][()].astype(np.float32, copy=False)
            self.labels = h5['label'][()].astype(int)

        if len(self.labels) != len(self.points):
            raise ValueError(
                f'ScanObjectNN file has {len(self.points)} point clouds but '
                f'{len(self.labels)} labels: {path}')

        # Axis 1 of `points` indexes the points stored for each object.
        nbr_stored = self.points.shape[1]
        if not 1 <= nbr_pts <= nbr_stored:
            raise ValueError(f'nbr_pts must be in [1, {nbr_stored}], got {nbr_pts}')

        self.nbr_classes = 15
        self._check_labels(self.labels, self.nbr_classes)

        print(f'Successfully loaded ScanObjectNN shape of {self.points.shape}')

    @classmethod
    def _h5_path(cls, root, split, subset, version):
        """Build the HDF5 path for ``subset``; raise ValueError if it is unknown."""
        try:
            prefix = cls._FILE_PREFIX[subset]
        except KeyError:
            raise ValueError(
                f"subset must be one of {sorted(cls._FILE_PREFIX)}, got {subset!r}") from None
        return os.path.join(root, split, f'{prefix}_objectdataset{version}.h5')

    @staticmethod
    def _check_labels(labels, nbr_classes):
        """Raise ValueError unless labels lie in [0, nbr_classes - 1] and reach the top class."""
        if labels.size == 0:
            raise ValueError('ScanObjectNN file contains no labels')
        lowest, highest = int(labels.min()), int(labels.max())
        # Requiring the top class to be present catches files built with a
        # different class count.
        if lowest < 0 or highest != nbr_classes - 1:
            raise ValueError(
                f'expected labels in [0, {nbr_classes - 1}] with the top class present, '
                f'got range [{lowest}, {highest}]')

    def __getitem__(self, idx):
        """Return sample ``idx`` with ``nbr_pts`` points drawn without replacement."""
        # A fresh random subset on every access, so repeated reads of the same
        # index see different point samples.
        pt_idxs = np.random.choice(self.points.shape[1], self.nbr_pts, replace=False)

        # Advanced indexing already returns a new float32 array, so no explicit
        # copy or dtype conversion is needed before wrapping it in a tensor.
        pos = torch.from_numpy(self.points[idx, pt_idxs])
        y = torch.tensor([self.labels[idx]])

        data = Data(pos=pos, y=y)
        return data if self.transform is None else self.transform(data)

    def __len__(self):
        """Return the number of objects in the loaded subset."""
        return len(self.points)