__author__ = ''
__date__ = '2023/07/26'

'''
Load the Celeb-A image data (train/test splits) used for training and evaluation.
'''



from os import path as osp
from base.torchvision_dataset import TorchvisionDataset
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch
import numpy as np
from preparation.data_preprocess import global_contrast_normalization

class CelebADataset(TorchvisionDataset):
    """Celeb-A train/test splits wrapped as ``CustomDataset`` objects.

    Reads ``train_data.npy``, ``test_data.npy`` and ``test_label.npy`` from
    ``<root>/balanced/`` or ``<root>/imbalanced/``. In each data array the
    last column is split off as the per-sample ``pvs`` value and the remaining
    columns are the flattened image. Every training sample gets label 0
    (normal); test labels are read from ``test_label.npy``.

    Args:
        root: Directory containing the ``balanced/`` and ``imbalanced/`` splits.
        balanced: If truthy, use the balanced split and its GCN min/max;
            otherwise use the imbalanced split.
    """

    # Pre-computed min and max values (after applying GCN) from train data,
    # keyed by the (bool-coerced) ``balanced`` flag.
    _GCN_MIN_MAX = {
        True: (-7.36940097, 7.02800990),
        False: (-6.01566553, 6.37051535),
    }

    def __init__(self, root, balanced=False):
        # functools.partial is used instead of a lambda so the transform (and
        # therefore the dataset) stays picklable.
        from functools import partial

        super(CelebADataset, self).__init__(root)
        self.name = 'Celeb-A'
        self.n_classes = 2  # 0: normal, 1: outlier

        balanced = bool(balanced)
        path = osp.join(root, 'balanced/' if balanced else 'imbalanced/')

        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only load these files from a trusted data directory.
        train_data = np.load(osp.join(path, 'train_data.npy'), allow_pickle=True)
        train_lab = np.zeros((train_data.shape[0]))  # all training samples are normal (0)
        train_pvs = train_data[:, -1]
        train_img = train_data[:, :-1]
        test_data = np.load(osp.join(path, 'test_data.npy'), allow_pickle=True)
        test_lab = np.load(osp.join(path, 'test_label.npy'), allow_pickle=True)
        test_pvs = test_data[:, -1]
        test_img = test_data[:, :-1]

        print('======== dataset info =========')
        print(f'train data: {train_img.shape}')
        print(f'test data: {test_img.shape}')
        print('======== dataset info =========')

        _min, _max = self._GCN_MIN_MAX[balanced]

        # ToTensor -> global contrast normalization (L1) -> Normalize with
        # mean=_min and std=(_max - _min), i.e. (x - _min) / (_max - _min).
        # A partial (unlike a lambda) can be pickled, which DataLoader needs
        # when worker processes use the 'spawn' start method.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(partial(global_contrast_normalization, scale='l1')),
            transforms.Normalize([_min] * 3, [_max - _min] * 3),
        ])

        self.train_set = CustomDataset(train_img, train_lab, train_pvs, transform=transform)
        self.test_set = CustomDataset(test_img, test_lab, test_pvs, transform=transform)



class CustomDataset(Dataset):
    """Map-style dataset over flattened images with labels and per-sample ``pvs``.

    Each item is the tuple ``(img, label, idx, pvs)``:

    - ``img`` is row ``idx`` of ``data`` reshaped to ``image_shape``.
    - If a transform is given, ``img`` is passed through it and then cast to
      ``torch.float32``.

    Args:
        data: Array with one flattened image per row.
        labels: Per-sample labels, indexable by ``idx``.
        pvs: Per-sample auxiliary values, indexable by ``idx``.
        transform: Optional callable applied to the reshaped image. Its result
            must support ``.type(torch.float32)`` (e.g. a tensor).
        image_shape: Shape each row is reshaped to. Defaults to ``(64, 64, 3)``.
    """

    def __init__(self, data, labels, pvs, transform=None, image_shape=(64, 64, 3)):
        self.data = data
        self.labels = labels
        self.pvs = pvs
        self.transform = transform
        self.image_shape = tuple(image_shape)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # A tensor index is converted to a plain Python value before indexing.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img = self.data[idx].reshape(self.image_shape)
        if self.transform is not None:
            img = self.transform(img).type(torch.float32)

        return img, self.labels[idx], idx, self.pvs[idx]

