import os
import h5py
import random, math
from torch.utils.data import Dataset
import torch
from config import processed_data_path
from var import get_mix_prob
# Turn off HDF5 file locking for every .h5 file this process opens.
# NOTE(review): presumably needed so shards can be read from filesystems or
# by concurrent readers where HDF5's lock fails — confirm.
os.environ.update(HDF5_USE_FILE_LOCKING="FALSE")

class ModelNet40(Dataset):
    """ModelNet40 point-cloud dataset read from pre-processed HDF5 shards.

    Every ``ply_data_{partition}*.h5`` file under ``processed_data_path`` is
    loaded eagerly into memory. In the ``'train'`` partition each sample is
    augmented (random per-channel scaling and point shuffling). With
    probability ``get_mix_prob()`` it is also mixed with a second random
    sample, and the returned label is a soft distribution over the classes.
    Other partitions return the un-augmented cloud and the integer label.
    """

    NUM_CLASSES = 40
    # Probability mass spread uniformly over all classes (label smoothing);
    # the remaining ``1 - LABEL_SMOOTHING`` goes to the source class(es).
    LABEL_SMOOTHING = 0.2

    def __init__(self, num_points=1024, partition='train'):
        """Load every shard of ``partition`` into memory.

        Args:
            num_points: number of leading points kept from each cloud.
            partition: shard-name infix, e.g. ``'train'`` or ``'test'``;
                only ``'train'`` enables augmentation and mixing.

        Raises:
            FileNotFoundError: if no ``ply_data_{partition}*.h5`` file exists.
        """
        # Sort so the index -> sample mapping does not depend on the
        # filesystem's directory-listing order.
        paths = sorted(processed_data_path.glob(f"ply_data_{partition}*.h5"))
        if not paths:
            raise FileNotFoundError(
                f"no ply_data_{partition}*.h5 files found in {processed_data_path}"
            )
        data, label = [], []
        for p in paths:
            # The context manager closes the file even if a read fails.
            with h5py.File(p, 'r') as f:
                data.append(torch.from_numpy(f['data'][:]).float())
                label.append(torch.from_numpy(f['label'][:]).long())
        self.data = torch.cat(data)
        # Flatten to 1-D. Unlike ``squeeze()``, this stays 1-D even when
        # there is only a single sample.
        self.label = torch.cat(label).reshape(-1)
        self.num_points = num_points
        self.partition = partition

    def get(self, idx):
        """Return ``(points * 40, label)`` for sample ``idx``.

        Keeps the first ``num_points`` points of the cloud. In the
        ``'train'`` partition the cloud is also multiplied by a random
        per-channel scale in [2/3, 3/2), and its points are randomly
        permuted. ``label`` is the sample's integer class label.
        """
        pc = self.data[idx][:self.num_points]
        label = self.label[idx]
        if self.partition == 'train':
            # One random factor per last-dim channel (presumably x/y/z).
            scale = torch.rand((3,)) * (3 / 2 - 2 / 3) + 2 / 3
            pc = pc * scale
            pc = pc[torch.randperm(pc.shape[0])]
        # NOTE(review): fixed x40 coordinate rescaling; purpose not visible here.
        return pc * 40, label

    @staticmethod
    def _mix_weights(crop, n):
        """Return normalised label weights ``(w_a, w_b)`` for a mixed cloud.

        ``crop`` of the ``n`` points come from sample b and the rest from
        sample a. Each part's weight is its point fraction times a sigmoid
        that suppresses parts forming only a small share of the cloud (the
        sigmoid is 0.5 at a fraction of 1/8). The two weights sum to 1.
        """
        if n <= 0:
            return 1.0, 0.0
        frac_b = crop / n
        frac_a = 1.0 - frac_b
        w_a = frac_a / (1.0 + math.exp(-(16.0 * frac_a - 2.0)))
        w_b = frac_b / (1.0 + math.exp(-(16.0 * frac_b - 2.0)))
        total = w_a + w_b
        return w_a / total, w_b / total

    def __getitem__(self, idx):
        """Return ``(points, label)`` for sample ``idx``.

        Outside ``'train'`` this is ``get(idx)``. In ``'train'`` a second
        random sample b is drawn. With probability ``get_mix_prob()``, the
        last ``crop`` points of sample a are replaced by the first ``crop``
        points of b, where ``crop`` is uniform in ``[0, n]`` and ``n`` is
        the cloud size. The label is then a float tensor of shape
        ``(NUM_CLASSES,)``: ``LABEL_SMOOTHING / NUM_CLASSES`` everywhere,
        plus the remaining mass split between a's and b's classes by
        ``_mix_weights``.
        """
        if self.partition != 'train':
            return self.get(idx)

        pca, lba = self.get(idx)
        pcb, lbb = self.get(random.randrange(0, len(self)))
        n = pca.shape[0]
        # The crop range and the label weights both use the actual cloud size
        # n, so each class's label share matches its share of the points.
        if random.random() < get_mix_prob():
            crop = random.randint(0, n)
        else:
            crop = 0
        pc = torch.cat([pca[crop:], pcb[:crop]], dim=0)

        w_a, w_b = self._mix_weights(crop, n)
        keep = 1.0 - self.LABEL_SMOOTHING
        label = torch.full(
            (self.NUM_CLASSES,), self.LABEL_SMOOTHING / self.NUM_CLASSES
        )
        label[int(lba)] += keep * w_a
        label[int(lbb)] += keep * w_b
        return pc, label

    def __len__(self):
        """Return the number of samples loaded from all shards."""
        return self.data.shape[0]
