from torch.utils.data import Dataset
import numpy as np
import os

class PolyDataset2D(Dataset):
    """Map-style dataset over polygon arrays stored as ``.npy`` files.

    Loads ``poly_np.npy``, ``polypos_np.npy`` and ``polyinfo_np.npy`` from
    ``data_path``. It keeps either the leading ``split_ratio`` fraction of
    samples (``train=True``) or the remainder (``train=False``). The split is
    a contiguous cut along the first axis; no shuffling is performed.

    Each item is the tuple ``(data, datapos, datainfo)`` taken at the same
    index from the three arrays.
    """

    def __init__(self, data_path, train=True, split_ratio=0.8):
        """Load the arrays and keep the requested split.

        Args:
            data_path: Directory containing the ``.npy`` files.
            train: If True keep the first part of the samples, otherwise
                keep the remaining part.
            split_ratio: Fraction of samples assigned to the training part;
                must lie in ``[0, 1]``.

        Raises:
            ValueError: If ``split_ratio`` is outside ``[0, 1]``, or if the
                loaded arrays do not all have the same length. Mismatched
                lengths would make the split cut each array at a different
                sample and silently misalign the returned tuples.
        """
        if not 0.0 <= split_ratio <= 1.0:
            raise ValueError(
                f"split_ratio must be in [0, 1], got {split_ratio!r}")

        self.data_path = os.path.join(data_path, 'poly_np.npy')
        self.datapos_path = os.path.join(data_path, 'polypos_np.npy')
        self.datainfo_path = os.path.join(data_path, 'polyinfo_np.npy')
        self.datas = np.load(self.data_path)
        self.datapos = np.load(self.datapos_path)
        self.datainfo = np.load(self.datainfo_path)
        self.split_ratio = split_ratio

        lengths = {
            'poly_np.npy': len(self.datas),
            'polypos_np.npy': len(self.datapos),
            'polyinfo_np.npy': len(self.datainfo),
        }
        if len(set(lengths.values())) > 1:
            raise ValueError(
                f"arrays in {data_path!r} must have the same length, "
                f"got {lengths}")

        # A single cut index shared by every array keeps samples aligned.
        cut = int(len(self.datas) * self.split_ratio)
        part = slice(0, cut) if train else slice(cut, None)
        self.datas = self.datas[part]
        self.datapos = self.datapos[part]
        self.datainfo = self.datainfo[part]
        if train:
            print('train_len(self.datas): ',len(self.datas))
        else:
            print('test_len(self.datas): ',len(self.datas))

    def __getitem__(self, index):
        """Return ``(data, datapos, datainfo)`` for sample ``index``."""
        return self.datas[index], self.datapos[index], self.datainfo[index]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.datas)

class PolyDatasetRoad(Dataset):
    """Map-style dataset over polygon arrays plus a road array (``.npy``).

    Loads ``poly_np.npy``, ``polypos_np.npy``, ``polyinfo_np.npy`` and
    ``polyroad_np.npy`` from ``data_path``. It keeps either the leading
    ``split_ratio`` fraction of samples (``train=True``) or the remainder
    (``train=False``). The split is a contiguous cut along the first axis;
    no shuffling is performed.

    Each item is the tuple ``(data, datapos, datainfo, dataroad)`` taken at
    the same index from the four arrays.
    """

    def __init__(self, data_path, train=True, split_ratio=0.8):
        """Load the arrays and keep the requested split.

        Args:
            data_path: Directory containing the ``.npy`` files.
            train: If True keep the first part of the samples, otherwise
                keep the remaining part.
            split_ratio: Fraction of samples assigned to the training part;
                must lie in ``[0, 1]``.

        Raises:
            ValueError: If ``split_ratio`` is outside ``[0, 1]``, or if the
                loaded arrays do not all have the same length. Mismatched
                lengths would make the split cut each array at a different
                sample and silently misalign the returned tuples.
        """
        if not 0.0 <= split_ratio <= 1.0:
            raise ValueError(
                f"split_ratio must be in [0, 1], got {split_ratio!r}")

        self.data_path = os.path.join(data_path, 'poly_np.npy')
        self.datapos_path = os.path.join(data_path, 'polypos_np.npy')
        self.datainfo_path = os.path.join(data_path, 'polyinfo_np.npy')
        self.dataroad_path = os.path.join(data_path, 'polyroad_np.npy')
        self.datas = np.load(self.data_path)
        self.datapos = np.load(self.datapos_path)
        self.datainfo = np.load(self.datainfo_path)
        self.dataroad = np.load(self.dataroad_path)
        self.split_ratio = split_ratio

        lengths = {
            'poly_np.npy': len(self.datas),
            'polypos_np.npy': len(self.datapos),
            'polyinfo_np.npy': len(self.datainfo),
            'polyroad_np.npy': len(self.dataroad),
        }
        if len(set(lengths.values())) > 1:
            raise ValueError(
                f"arrays in {data_path!r} must have the same length, "
                f"got {lengths}")

        # A single cut index shared by every array keeps samples aligned.
        cut = int(len(self.datas) * self.split_ratio)
        part = slice(0, cut) if train else slice(cut, None)
        self.datas = self.datas[part]
        self.datapos = self.datapos[part]
        self.datainfo = self.datainfo[part]
        self.dataroad = self.dataroad[part]
        if train:
            print('train_len(self.datas): ',len(self.datas))
        else:
            print('test_len(self.datas): ',len(self.datas))

    def __getitem__(self, index):
        """Return ``(data, datapos, datainfo, dataroad)`` for ``index``."""
        return (self.datas[index], self.datapos[index],
                self.datainfo[index], self.dataroad[index])

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.datas)
    
class PolyDataset3D(Dataset):
    """Map-style dataset over polygon arrays plus an ``h`` array (``.npy``).

    Loads ``poly_np.npy``, ``polypos_np.npy``, ``polyinfo_np.npy`` and
    ``polyh_np.npy`` from ``data_path``. It keeps either the leading
    ``split_ratio`` fraction of samples (``train=True``) or the remainder
    (``train=False``). The split is a contiguous cut along the first axis;
    no shuffling is performed.

    Each item is the tuple ``(data, datapos, datainfo, datah)`` taken at the
    same index from the four arrays.
    """

    def __init__(self, data_path, train=True, split_ratio=0.8):
        """Load the arrays and keep the requested split.

        Args:
            data_path: Directory containing the ``.npy`` files.
            train: If True keep the first part of the samples, otherwise
                keep the remaining part.
            split_ratio: Fraction of samples assigned to the training part;
                must lie in ``[0, 1]``.

        Raises:
            ValueError: If ``split_ratio`` is outside ``[0, 1]``, or if the
                loaded arrays do not all have the same length. Mismatched
                lengths would make the split cut each array at a different
                sample and silently misalign the returned tuples.
        """
        if not 0.0 <= split_ratio <= 1.0:
            raise ValueError(
                f"split_ratio must be in [0, 1], got {split_ratio!r}")

        self.data_path = os.path.join(data_path, 'poly_np.npy')
        self.datapos_path = os.path.join(data_path, 'polypos_np.npy')
        self.datainfo_path = os.path.join(data_path, 'polyinfo_np.npy')
        self.datah_path = os.path.join(data_path, 'polyh_np.npy')
        self.datas = np.load(self.data_path)
        self.datapos = np.load(self.datapos_path)
        self.datainfo = np.load(self.datainfo_path)
        self.datah = np.load(self.datah_path)
        self.split_ratio = split_ratio

        lengths = {
            'poly_np.npy': len(self.datas),
            'polypos_np.npy': len(self.datapos),
            'polyinfo_np.npy': len(self.datainfo),
            'polyh_np.npy': len(self.datah),
        }
        if len(set(lengths.values())) > 1:
            raise ValueError(
                f"arrays in {data_path!r} must have the same length, "
                f"got {lengths}")

        # A single cut index shared by every array keeps samples aligned.
        cut = int(len(self.datas) * self.split_ratio)
        part = slice(0, cut) if train else slice(cut, None)
        self.datas = self.datas[part]
        self.datapos = self.datapos[part]
        self.datainfo = self.datainfo[part]
        self.datah = self.datah[part]
        if train:
            print('train_len(self.datas): ',len(self.datas))
        else:
            print('test_len(self.datas): ',len(self.datas))

    def __getitem__(self, index):
        """Return ``(data, datapos, datainfo, datah)`` for sample ``index``."""
        return (self.datas[index], self.datapos[index],
                self.datainfo[index], self.datah[index])

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.datas)
    
class PolyDatasetClassification(Dataset):
    """Map-style dataset over polygon arrays plus a class array (``.npy``).

    Loads ``poly_np.npy``, ``polypos_np.npy``, ``polyinfo_np.npy`` and
    ``polyclass_np.npy`` from ``data_path``. It keeps either the leading
    ``split_ratio`` fraction of samples (``train=True``) or the remainder
    (``train=False``). The split is a contiguous cut along the first axis;
    no shuffling is performed.

    Each item is the tuple ``(data, datapos, datainfo, dataclass)`` taken at
    the same index from the four arrays.
    """

    def __init__(self, data_path, train=True, split_ratio=0.8):
        """Load the arrays and keep the requested split.

        Args:
            data_path: Directory containing the ``.npy`` files.
            train: If True keep the first part of the samples, otherwise
                keep the remaining part.
            split_ratio: Fraction of samples assigned to the training part;
                must lie in ``[0, 1]``.

        Raises:
            ValueError: If ``split_ratio`` is outside ``[0, 1]``, or if the
                loaded arrays do not all have the same length. Mismatched
                lengths would make the split cut each array at a different
                sample and silently misalign the returned tuples.
        """
        if not 0.0 <= split_ratio <= 1.0:
            raise ValueError(
                f"split_ratio must be in [0, 1], got {split_ratio!r}")

        self.data_path = os.path.join(data_path, 'poly_np.npy')
        self.datapos_path = os.path.join(data_path, 'polypos_np.npy')
        self.datainfo_path = os.path.join(data_path, 'polyinfo_np.npy')
        self.dataclass_path = os.path.join(data_path, 'polyclass_np.npy')
        self.datas = np.load(self.data_path)
        self.datapos = np.load(self.datapos_path)
        self.datainfo = np.load(self.datainfo_path)
        self.dataclass = np.load(self.dataclass_path)
        self.split_ratio = split_ratio

        lengths = {
            'poly_np.npy': len(self.datas),
            'polypos_np.npy': len(self.datapos),
            'polyinfo_np.npy': len(self.datainfo),
            'polyclass_np.npy': len(self.dataclass),
        }
        if len(set(lengths.values())) > 1:
            raise ValueError(
                f"arrays in {data_path!r} must have the same length, "
                f"got {lengths}")

        # A single cut index shared by every array keeps samples aligned.
        cut = int(len(self.datas) * self.split_ratio)
        part = slice(0, cut) if train else slice(cut, None)
        self.datas = self.datas[part]
        self.datapos = self.datapos[part]
        self.datainfo = self.datainfo[part]
        self.dataclass = self.dataclass[part]
        if train:
            print('train_len(self.datas): ',len(self.datas))
        else:
            print('test_len(self.datas): ',len(self.datas))

    def __getitem__(self, index):
        """Return ``(data, datapos, datainfo, dataclass)`` for ``index``."""
        return (self.datas[index], self.datapos[index],
                self.datainfo[index], self.dataclass[index])

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.datas)
