import os
import numpy as np
# from torch.utils.data.dataset import Dataset
# from PIL import ImageFile
# import pandas as pd
# from torchvision.transforms import v2
# ImageFile.LOAD_TRUNCATED_IMAGES = True

# class FlickrDataset(Dataset):
#     def __init__(self, X, Y, data_path, transform=None):
#         self.X = X
#         self.Y = Y
#         self.transform = transform
#         self.data_path = data_path

#     def __getitem__(self, index):
#         x = self.X[index]
#         x = self.transform(x)
#         y = self.Y[index]
#         return x, y

#     def __len__(self):
#         return len(self.X)
    
# class TwitterDataset(Dataset):
#     def __init__(self, X, Y, data_path, transform=None):
#         self.X = X
#         self.Y = Y
#         self.transform = transform
#         self.data_path = data_path

#     def __getitem__(self, index):
#         x = self.X[index]
#         x = self.transform(x)
#         y = self.Y[index]
#         return x, y

#     def __len__(self):
#         return len(self.X)
    

# class FBP5500Dataset(Dataset):
#     def __init__(self, X, Y, data_path, transform=None):
#         self.X = X
#         self.Y = Y
#         self.transform = transform
#         self.data_path = data_path

#     def __getitem__(self, index):
#         x = self.X[index]
#         x = self.transform(x)
#         y = self.Y[index]
#         return x, y

#     def __len__(self):
#         return len(self.X)
    

# class RAFDataset(Dataset):
#     def __init__(self, X, Y, data_path, transform=None):
#         self.X = X
#         self.Y = Y
#         self.transform = transform
#         self.data_path = data_path

#     def __getitem__(self, index):
#         x = self.X[index]
#         x = self.transform(x)
#         y = self.Y[index]
#         return x, y

#     def __len__(self):
#         return len(self.X)
    

# class Emotion6Dataset(Dataset):
#     def __init__(self, X, Y, data_path, transform=None):
#         self.X = X
#         self.Y = Y
#         self.transform = transform
#         self.data_path = data_path

#     def __getitem__(self, index):
#         x = self.X[index]
#         x = self.transform(x)
#         y = self.Y[index]
#         return x, y

#     def __len__(self):
#         return len(self.X)
    


# HANDLER_DICT = {
#     'flickr': FlickrDataset,
#     'twitter': TwitterDataset,
#     'raf': RAFDataset,
#     'emotion6': Emotion6Dataset,
#     'fbp5500': FBP5500Dataset
# }


# def load_data(args):
#     data = {}
#     phases = ['train_label', 'train_unlabel', 'val', 'test']
#     dataset_dir = args.dataset_dir
#     for phase in phases:
#         raw_data = pd.read_csv(os.path.join(dataset_dir, phase+'_data.csv'))
#         data[phase] = {}
#         data[phase]['labels'] = raw_data.iloc[:, 1:].values
#         data[phase]['images'] = raw_data.iloc[:, 0].values
#     return data


# def get_datasets(args):
#     transform_train = v2.Compose([
#         v2.ToTensor()]
#     )

#     val_transform = v2.Compose([
#         v2.ToTensor()]
#     )
#     data = load_data(args)
#     data_handler = HANDLER_DICT[args.dataset_name]
#     train_label_dataset = data_handler(data['train_label']['images'], data['train_label']['labels'], args.dataset_dir, transform=transform_train)
#     train_unlabel_dataset = data_handler(data['train_unlabel']['images'], data['train_unlabel']['labels'], args.dataset_dir, transform=transform_train)
#     val_dataset = data_handler(data['val']['images'], data['val']['labels'], args.dataset_dir, transform=val_transform)
#     test_dataset = data_handler(data['test']['images'], data['test']['labels'], args.dataset_dir, transform=val_transform)
    
#     return train_label_dataset, train_unlabel_dataset, val_dataset, test_dataset

def get_datasets(args, *, mmap_mode=None):
    """Load the pre-split ``.npy`` arrays for one dataset.

    Every file is read from ``args.dataset_dir`` and is named
    ``{args.dataset_name}_{split}_{kind}.npy``, where ``split`` is one of
    ``train_label``, ``train_unlabel``, ``val``, ``test`` and ``kind`` is
    ``data`` or ``label``.

    Args:
        args: Object exposing ``dataset_dir`` (directory holding the files)
            and ``dataset_name`` (file-name prefix, a ``str``).
        mmap_mode: Forwarded unchanged to :func:`numpy.load`. ``None``
            (the default) reads each array fully into memory; a mode such
            as ``'r'`` memory-maps the file instead.

    Returns:
        An 8-tuple ``(train_label_data, train_label_label,
        train_unlabel_data, train_unlabel_label, val_data, val_label,
        test_data, test_label)``.

    Raises:
        FileNotFoundError: If any of the eight files is missing. Files are
            loaded in the order above, so the first missing one is reported.
    """
    splits = ('train_label', 'train_unlabel', 'val', 'test')
    kinds = ('data', 'label')

    def _load(split, kind):
        # Single place that builds the file name, so the eight paths cannot
        # drift apart through copy-paste typos.
        filename = args.dataset_name + '_' + split + '_' + kind + '.npy'
        return np.load(os.path.join(args.dataset_dir, filename), mmap_mode=mmap_mode)

    # Outer loop over splits, inner over kinds: yields data/label pairs per
    # split, matching the documented return order.
    return tuple(_load(split, kind) for split in splits for kind in kinds)