__author__ = 'Qi'
# Created by Qi on 4/11/22.
# from torchtext import data
import pandas as pd
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch
from myDataset import preprocess_adult_data
from collections import Counter
import os


class CelebaDataset(Dataset):
    """Custom Dataset for loading CelebA face images.

    The CSV at ``csv_path`` is read with its first column as the index; that
    index holds the image file names (joined onto ``img_dir``), and the
    remaining columns are per-image attribute values. Each sample is
    ``(img, label, attribute, index)``, where ``label`` and ``attribute`` are
    the attribute values at column positions ``label_index`` and
    ``attribute_index``.
    """

    def __init__(self, csv_path, img_dir, transform=None, label_index=33, attribute_index=20):
        """
        Args:
            csv_path: CSV file whose first column is the image file name.
            img_dir: Directory containing the image files.
            transform: Optional callable applied to each loaded PIL image.
            label_index: Attribute column used as the class label. The default
                33 is Wavy_Hair according to the original author's note.
            attribute_index: Attribute column used as the sensitive attribute.
                The default 20 is Male according to the original author's note.
        """
        df = pd.read_csv(csv_path, index_col=0)
        self.img_dir = img_dir
        self.csv_path = csv_path
        self.img_names = df.index.values
        self.y = df.values
        self.transform = transform
        self.label_index = label_index
        self.attribute_index = attribute_index

    def __getitem__(self, index):
        """Return ``(img, label, attribute, index)`` for one image."""
        path = os.path.join(self.img_dir, self.img_names[index])
        # Image.open is lazy; decode the pixels inside the context manager so
        # the file handle is closed now rather than whenever (or if) the image
        # is decoded later -- e.g. never, when no transform is given.
        with Image.open(path) as img:
            img.load()

        if self.transform is not None:
            img = self.transform(img)
        label = self.y[index][self.label_index]
        attribute = self.y[index][self.attribute_index]
        return img, label, attribute, index

    def __len__(self):
        """Return the number of rows in the attribute CSV."""
        return self.y.shape[0]


class Dataset(Dataset):
    """Map-style dataset over in-memory features, labels and sensitive attributes.

    NOTE: this class rebinds the module-level name ``Dataset``, shadowing the
    ``torch.utils.data.Dataset`` imported at the top of the file for all code
    below this point. It still inherits from the torch class because the base
    is evaluated before the new name is bound.

    Each item is ``(x[index], int(label), int(sensitive_attribute), index)``.
    """

    def __init__(self, x, labels, sensitive_attribute):
        """
        Args:
            x: Indexable feature container; rows are returned unchanged.
            labels: Indexable class labels, one per sample; its length defines
                ``len(self)``.
            sensitive_attribute: Indexable protected-attribute values, one per
                sample.
        """
        self.x = x
        self.labels = labels
        self.sensitive_attribute = sensitive_attribute

    def __len__(self):
        """Return the number of samples (the length of ``labels``)."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(features, label, sensitive_attribute, index)`` for one sample."""
        return (
            self.x[index],
            self._as_int(self.labels[index]),
            self._as_int(self.sensitive_attribute[index]),
            index,
        )

    @staticmethod
    def _as_int(value):
        """Convert a scalar-like value (Python/NumPy scalar, size-1 array, tensor) to int."""
        # If labels are stored as an (N, 1) array, each row is a size-1 ndarray;
        # int() on such an array is deprecated since NumPy 1.25, so unwrap it
        # with .item() first. Plain Python values fall through to int().
        item = getattr(value, "item", None)
        return int(item() if callable(item) else value)




def adult_dataloader(batch_size, seed=0):
    """Build train/val/test DataLoaders for the Adult dataset.

    Features, labels and sensitive attributes come from
    ``myDataset.preprocess_adult_data``; only the training loader shuffles.

    Args:
        batch_size: Batch size used by all three loaders.
        seed: Forwarded to ``preprocess_adult_data`` (previously hard-coded
            to 0). Presumably it controls the split -- confirm in myDataset.

    Returns:
        ``(training_generator, val_generator, test_generator)``.
    """
    (train_x, val_x, test_x,
     train_label, val_label, test_label,
     train_sensitive_label, val_sensitive_label,
     test_sensitive_label) = preprocess_adult_data(seed=seed)

    training_set = Dataset(train_x, train_label, train_sensitive_label)
    training_generator = DataLoader(training_set, batch_size=batch_size, shuffle=True)

    val_set = Dataset(val_x, val_label, val_sensitive_label)
    val_generator = DataLoader(val_set, batch_size=batch_size)

    test_set = Dataset(test_x, test_label, test_sensitive_label)
    test_generator = DataLoader(test_set, batch_size=batch_size)

    print("data length of Adults: ", len(train_label), len(val_label), len(test_label))
    # Flatten before counting so Counter sees scalars even if the labels are
    # stored as an (N, 1) array (whose rows would be unhashable).
    print("Class Label: ", Counter(train_label.reshape(-1)), "Attributes Label: ", Counter(train_sensitive_label.reshape(-1)))

    return training_generator, val_generator, test_generator


def _meps_split(csv_path, batch_size, label_index, sensitive_label_index, **loader_kwargs):
    """Load one MEPS CSV split; return ``(loader, labels, sensitive_labels)``.

    Features are every column except the last one. ``loader_kwargs`` are passed
    straight to ``DataLoader``. The label arrays are returned for logging.
    """
    values = pd.read_csv(csv_path).values
    # NOTE(review): x still includes column ``sensitive_label_index`` (the
    # protected attribute) as a feature -- confirm this is intended.
    x = values[:, :-1]
    labels = values[:, label_index]
    sensitive_labels = values[:, sensitive_label_index]
    loader = DataLoader(Dataset(x, labels, sensitive_labels), batch_size=batch_size, **loader_kwargs)
    return loader, labels, sensitive_labels


def meps_dataloader(batch_size, data_dir="./meps/"):
    """Build train/val/test DataLoaders for the MEPS CSV splits.

    Reads ``meps_train.csv``, ``meps_val.csv`` and ``meps_test.csv`` from
    ``data_dir``. The ground-truth label is the last column; column 1 holds
    the protected attribute (race, per the original author's note). The
    training loader shuffles and drops the last incomplete batch; the
    validation and test loaders do neither.

    Args:
        batch_size: Batch size used by all three loaders.
        data_dir: Directory containing the three CSV files (defaults to the
            previously hard-coded ``./meps/``, relative to the working dir).

    Returns:
        ``(training_generator, val_generator, test_generator)``.
    """
    label_index = -1
    sensitive_label_index = 1

    training_generator, train_label, train_sensitive_label = _meps_split(
        os.path.join(data_dir, "meps_train.csv"), batch_size, label_index, sensitive_label_index,
        shuffle=True, drop_last=True)
    val_generator, val_label, _ = _meps_split(
        os.path.join(data_dir, "meps_val.csv"), batch_size, label_index, sensitive_label_index)
    test_generator, test_label, _ = _meps_split(
        os.path.join(data_dir, "meps_test.csv"), batch_size, label_index, sensitive_label_index)

    print("data length of MEPS: ", len(train_label), len(val_label), len(test_label))
    print("Class Label: ", Counter(train_label), "Attributes Label: ",Counter(train_sensitive_label))

    return training_generator, val_generator, test_generator


def _default_celeba_dir():
    """Return the CelebA root directory for the current machine, chosen by hostname substring."""
    # platform.node() reports the same host name as os.uname()[1] on Unix but,
    # unlike os.uname(), also exists on Windows.
    import platform

    host = platform.node()
    if 'amax' in host:
        return '/data/qiqi/celebA/'
    if 'test-X11DPG-OT' in host:
        return '/home/qiuzh/qiqi/celebA/'
    return '/dual_data/not_backed_up/CelebA/'


def celeba_dataloader(batch_size, data_dir=None, num_workers=8):
    """Build train/val/test DataLoaders for the CelebA CSV splits.

    Args:
        batch_size: Training batch size. Validation and test loaders use a
            quarter of it, clamped to at least 1.
        data_dir: Root directory containing ``celeba_attr_{train,val,test}.csv``
            and ``img_align_celeba/img_align_celeba/``. When None, it is chosen
            from the machine's hostname.
        num_workers: Worker processes used by each DataLoader.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles and applies random horizontal flips.
    """
    if data_dir is None:
        data_dir = _default_celeba_dir()
    img_dir = os.path.join(data_dir, 'img_align_celeba', 'img_align_celeba')

    # The commonly used ImageNet per-channel mean/std values.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = CelebaDataset(os.path.join(data_dir, 'celeba_attr_train.csv'), img_dir, train_transform)
    val_dataset = CelebaDataset(os.path.join(data_dir, 'celeba_attr_val.csv'), img_dir, eval_transform)
    test_dataset = CelebaDataset(os.path.join(data_dir, 'celeba_attr_test.csv'), img_dir, eval_transform)

    # batch_size // 4 is 0 for batch_size < 4, which DataLoader rejects.
    eval_batch_size = max(1, batch_size // 4)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(
        val_dataset, batch_size=eval_batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(
        test_dataset, batch_size=eval_batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True)
    return train_loader, val_loader, test_loader