import torch
import os
import scipy.io
import numpy as np

def pc_normalize(pc):
    """Center a point cloud on its mean and scale it into the unit ball.

    Args:
        pc: Array with one point per row (the centroid is the mean over
            ``axis=0`` and each point's norm is taken over ``axis=1``).

    Returns:
        ``(normalized, centroid, scale)`` such that
        ``normalized == (pc - centroid) / scale``. ``centroid`` is the
        per-coordinate mean and ``scale`` is the largest distance of any
        point from the centroid, so the farthest point ends up at norm 1.
        If all points coincide that distance is 0; ``scale`` is then 1.0
        so that neither this division nor a caller reusing
        ``centroid``/``scale`` produces NaNs.
    """
    centroid = np.mean(pc, axis=0)
    centered = pc - centroid
    scale = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    if scale == 0:
        # Degenerate cloud (all points identical): avoid 0/0 -> NaN.
        scale = 1.0
    return centered / scale, centroid, scale

class ModelNetDataset(torch.utils.data.Dataset):
    """Point clouds stored as ``.mat`` files, normalized and randomly subsampled.

    Each ``*.mat`` file in ``data_folder`` must contain a variable ``'pc'``
    holding the cloud (it is indexed as ``pc[:n, :]``, so presumably
    ``(num_points, dims)`` — confirm against the data).

    Args:
        data_folder: Directory scanned (non-recursively) for ``*.mat`` files.
        sample_num: Number of points returned per item.
        gt_folder: Accepted for backward compatibility but currently unused;
            the second element of every item is ``np.zeros(1)``.
    """

    def __init__(self, data_folder, sample_num=1024, gt_folder=None):
        super().__init__()
        # Match the extension exactly (a substring test would also accept
        # e.g. "x.mat.bak") and sort so index -> file is reproducible.
        self.paths = [
            os.path.join(data_folder, name)
            for name in sorted(os.listdir(data_folder))
            if name.endswith('.mat')
        ]
        self.sample_num = sample_num
        self.size = len(self.paths)

    def __getitem__(self, index):
        """Return ``(points, np.zeros(1))`` for ``index`` (wrapped modulo the size).

        ``points`` is a float array with ``sample_num`` rows: the cloud is
        centered and scaled by ``pc_normalize``, shuffled, then truncated.
        A cloud with fewer than ``sample_num`` points is cycled (every point
        appears before any repeats) so all items share one shape and can be
        batched.
        """
        fpath = self.paths[index % self.size]
        pc = scipy.io.loadmat(fpath)['pc']
        pc = pc_normalize(pc)[0]
        pc = np.random.permutation(pc)
        num_points = pc.shape[0]
        if 0 < num_points < self.sample_num:
            pc = pc[np.arange(self.sample_num) % num_points]
        return pc[:self.sample_num, :].astype(float), np.zeros(1)

    def __len__(self):
        """Number of ``.mat`` files found in ``data_folder``."""
        return self.size


class partialDataset(torch.utils.data.Dataset):
    """Point clouds stored as pickled dicts written with ``np.save``.

    Every entry of ``data_folder`` is loaded with
    ``np.load(path, allow_pickle=True).item()`` and must provide the keys
    ``'pc'`` (the cloud, indexed as ``pc[:n, :]``) and ``'rgt'`` (returned
    unchanged alongside the points).

    Args:
        data_folder: Directory whose entries are all treated as samples.
        sample_num: Number of points returned per item.
    """

    def __init__(self, data_folder, sample_num=1024):
        super().__init__()
        # Sorted so that index -> file is reproducible across runs/platforms.
        self.paths = [os.path.join(data_folder, name)
                      for name in sorted(os.listdir(data_folder))]
        self.sample_num = sample_num
        self.size = len(self.paths)
        print(self.size)

    def __getitem__(self, index):
        """Return ``(points, rgt)`` for ``index`` (wrapped modulo the size).

        ``points`` is a float array with ``sample_num`` rows taken from the
        shuffled cloud. A cloud with fewer points is cycled so every distinct
        point appears before any point is repeated.

        Raises:
            ValueError: If the stored cloud has no points (it cannot be padded).
        """
        fpath = self.paths[index % self.size]
        # NOTE(security): allow_pickle=True unpickles the file contents, which
        # can execute arbitrary code; only point this at trusted data.
        sample = np.load(fpath, allow_pickle=True).item()
        pc = sample['pc']
        rgt = sample['rgt']
        pc = np.random.permutation(pc)
        num_points = pc.shape[0]
        if num_points == 0:
            raise ValueError(f"empty point cloud in {fpath}")
        if num_points < self.sample_num:
            pc = pc[np.arange(self.sample_num) % num_points]
        return pc[:self.sample_num, :].astype(float), rgt

    def __len__(self):
        """Number of entries found in ``data_folder``."""
        return self.size

