
'''
Tabular dataset loaders: read preprocessed ``.npy`` train/test splits and
wrap them as PyTorch ``Dataset`` objects.
'''
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from base.torchvision_dataset import TorchvisionDataset

class CustomDataset(Dataset):
    """Map-style dataset over in-memory feature rows, labels and per-row values.

    Indexing returns a 4-tuple ``(features, label, index, pv)``:
    ``features`` is ``data[index]`` converted to a float32 tensor,
    ``label`` is ``labels[index]``, ``index`` is the (normalised) index
    itself and ``pv`` is ``pvs[index]``.

    NOTE(review): ``data`` must be a NumPy array (rows go through
    ``torch.from_numpy``); ``pvs`` looks like protected-attribute values —
    confirm against the data-preparation code.
    """

    def __init__(self, data, labels, pvs):
        self.data, self.labels, self.pvs = data, labels, pvs

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Samplers may hand over tensor indices; switch to plain Python ints/lists.
        key = idx.tolist() if torch.is_tensor(idx) else idx

        row = self.data[key]
        features = torch.from_numpy(row).float()
        return features, self.labels[key], key, self.pvs[key]


class TabularDataset(TorchvisionDataset):
    """Standardized train/test splits of a preprocessed tabular dataset.

    Files are read from
    ``<root>/<dataset_name>/processed/{balanced|imbalanced}[/contaminated/<fair_c_ratio>]/``,
    which must contain ``train_data.npy``, ``test_data.npy`` and
    ``test_label.npy``. Column 0 of each data file is split off as the
    per-row ``pvs`` value; the remaining columns are the features. Every
    training row is labelled 0. Features of both splits are standardized with
    the training split's per-column mean and std.

    Args:
        root: Base directory containing one sub-directory per dataset.
        dataset_name: Name of the dataset sub-directory. NOTE(review): the
            default ``None`` is formatted literally as ``'None'``, so a real
            name must be supplied in practice.
        balanced: Read the ``balanced`` variant instead of ``imbalanced``.
        fair_c_ratio: If not None, read the ``contaminated/<fair_c_ratio>``
            sub-directory of the chosen variant.

    Attributes set here:
        name: ``'tabular'``.
        n_classes: 2 (0: normal, 1: outlier).
        train_set, test_set: ``CustomDataset`` instances.
    """

    # Added to the std before dividing so constant columns do not divide by zero.
    _STD_EPS = 1e-6

    def __init__(self, root, dataset_name=None, balanced=False, fair_c_ratio=None):
        super().__init__(root)
        self.name = 'tabular'
        self.n_classes = 2  # 0: normal, 1: outlier

        if fair_c_ratio is not None:
            print(f'Contamination Ratio: [{fair_c_ratio}]')
        path = self._resolve_path(root, dataset_name, balanced, fair_c_ratio)

        train_pvs, train_data = self._load_split(path, 'train_data.npy')
        # Every training row gets label 0 (normal); only the test split ships labels.
        train_lab = np.zeros((train_data.shape[0]))
        test_pvs, test_data = self._load_split(path, 'test_data.npy')
        # NOTE: allow_pickle=True can execute code embedded in the file; only load trusted data.
        test_lab = np.load(os.path.join(path, 'test_label.npy'), allow_pickle=True)

        print('======== dataset info =========')
        print(f'train data: {train_data.shape}')
        print(f'test data: {test_data.shape}')
        print('======== dataset info =========')

        train_data, test_data = self._standardize(train_data, test_data)

        self.train_set = CustomDataset(train_data, train_lab, train_pvs)
        self.test_set = CustomDataset(test_data, test_lab, test_pvs)

    @staticmethod
    def _resolve_path(root, dataset_name, balanced, fair_c_ratio):
        """Return the directory holding the ``.npy`` files for this configuration."""
        variant = 'balanced' if balanced else 'imbalanced'
        # str() keeps the original '%s' formatting of dataset_name.
        path = os.path.join(root, str(dataset_name), 'processed', variant)
        if fair_c_ratio is not None:
            path = os.path.join(path, 'contaminated', str(fair_c_ratio))
        return path

    @staticmethod
    def _load_split(path, filename):
        """Load ``path/filename`` and return ``(pvs, float64 features)``.

        Column 0 is returned unchanged as ``pvs``; the remaining columns are
        cast to float64 (the file may hold an object array).
        """
        # NOTE: allow_pickle=True can execute code embedded in the file; only load trusted data.
        raw = np.load(os.path.join(path, filename), allow_pickle=True)
        return raw[:, 0], np.array(raw[:, 1:], dtype=np.float64)

    @classmethod
    def _standardize(cls, train_data, test_data):
        """Z-score both splits using per-column statistics of the training split."""
        mean = np.mean(train_data, 0)
        scale = np.std(train_data, 0) + cls._STD_EPS
        return (train_data - mean) / scale, (test_data - mean) / scale