# @Author  : Peizhao Li
# @Contact : peizhaoli05@gmail.com

from PIL import Image
import numpy as np

import torch
from torch.utils import data
from torchvision import transforms
import pandas as pd
import os.path as path
from sklearn.preprocessing import scale, StandardScaler, MaxAbsScaler

# Shared feature scaler that TabDataset fits on its non-sensitive columns.
# NOTE: despite the name, this is a MaxAbsScaler, not a MinMaxScaler.
min_max_scaler = MaxAbsScaler()

# DataLoader options shared by every digit loader that get_digital builds.
kwargs = dict(shuffle=True, num_workers=16, pin_memory=True)


class digital(data.Dataset):
    """Digit-image dataset read from a plain-text index file.

    Each non-blank line of ``./data/<subset>.txt`` holds an image path
    followed by an integer label. Items are ``(img, label)``. When ``aug`` is
    truthy, items are ``(img1, img2, label)``: ``img1`` comes from
    ``transform`` and ``img2`` from ``transform_aug``. A transform that is
    ``None`` leaves the PIL image as it is.
    """
    # with size in 32x32
    def __init__(self, subset, transform=None, transform_aug = None, aug = False):
        file_dir = "./data/{}.txt".format(subset)
        # Read the index eagerly and close the file. Skip blank lines (e.g. a
        # trailing empty line at EOF) so every entry is a real sample.
        with open(file_dir) as index_file:
            self.data_dir = [line for line in index_file if line.strip()]
        self.transform = transform
        self.transform_aug = transform_aug
        self.aug = aug

    def __getitem__(self, index):
        # Treat only the last whitespace-separated token as the label, so
        # image paths that contain spaces stay intact.
        img_dir, label = self.data_dir[index].rsplit(None, 1)
        # Copy the decoded image so the file handle is closed now rather
        # than whenever the lazily loaded image is garbage-collected.
        with Image.open(img_dir) as src:
            img = src.copy()
        label = torch.tensor(np.int64(label)).long()

        if self.aug:
            # Two views of the same image: base transform and augmented one.
            img1 = self.transform(img) if self.transform is not None else img
            img2 = self.transform_aug(img) if self.transform_aug is not None else img
            return img1, img2, label

        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data_dir)

def get_digital(args, subset, reverse = False, aug = False):
    """Build a shuffled DataLoader over one digit subset (MNIST/USPS family).

    Args:
        args: namespace providing ``bs``, the batch size.
        subset: index-file stem under ``./data/``. It must contain
            ``'mnist'`` or ``'usps'``; both currently use the same
            preprocessing pipeline.
        reverse: currently unused here; kept so existing callers such as
            ``mnist_reverse`` keep working.
        aug: when True, each sample also carries a randomly affine-augmented
            second view (see ``digital``).

    Returns:
        A ``torch.utils.data.DataLoader`` over a ``digital`` dataset,
        configured with the module-level ``kwargs``.

    Raises:
        ValueError: if ``subset`` names neither an MNIST nor a USPS subset.
            Previously this case failed with an UnboundLocalError.
    """
    if 'mnist' not in subset and 'usps' not in subset:
        raise ValueError(
            "Unknown digital subset {!r}: expected 'mnist' or 'usps' in the "
            "name".format(subset))

    # Both domains are normalized with the same (MNIST) statistics. An
    # earlier, disabled variant normalized USPS with its own stats
    # (mean 0.2592, std 0.3751) instead.
    normalize = transforms.Normalize((0.1341,), (0.3026,))

    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        normalize,
    ])

    transform_aug = None
    if aug:
        transform_aug = transforms.Compose([
            transforms.Resize(32),
            transforms.RandomAffine(degrees=25, translate=(0.1, 0.1), scale=(0.8, 1.1)),
            transforms.ToTensor(),
            normalize,
        ])

    # Named `dataset` to avoid shadowing the module-level `data` alias.
    dataset = digital(subset, transform=transform, transform_aug=transform_aug, aug=aug)

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=args.bs,
        **kwargs
    )
    
    
def mnist_usps(args, aug = False):
    """Return the MNIST and USPS training loaders as ``[mnist, usps]``.

    ``aug`` is forwarded unchanged to ``get_digital`` for both domains.
    """
    domains = ("train_mnist", "train_usps")
    return [get_digital(args, domain, aug = aug) for domain in domains]


def mnist_reverse(args):
    """Return ``[mnist_loader, reverse_mnist_loader]`` training DataLoaders.

    The second loader reads the ``train_reverse_mnist`` index. Its name
    contains ``'mnist'``, so ``get_digital`` gives it the MNIST pipeline.
    """
    source_loader = get_digital(args, "train_mnist")
    target_loader = get_digital(args, "train_reverse_mnist", reverse = True)
    return [source_loader, target_loader]


class FaceLandmarksDataset(data.Dataset):
    """MTFL face dataset yielding ``(image, sensitive, label)`` triples.

    The sensitive attribute is the ``#wearing glasses`` column and the label
    is the ``#gender`` column of the annotation file; both are returned as
    int64 tensors.
    """

    def __init__(self, csv_file, root_dir='./data/MTFL/', transform=None):
        """
        Args:
            csv_file (string): Path to the space-separated annotation file
                (no header row).
            root_dir (string): Directory that the image paths in
                ``csv_file`` are relative to. Defaults to ``'./data/MTFL/'``.
            transform (callable, optional): Transform applied to each RGB PIL
                image. Defaults to: resize to 224x224, convert to a tensor,
                and subtract fixed per-channel means (std 1, i.e. no scaling).
        """
        self.root_dir = root_dir
        # NOTE(review): the leading space in " #smile" looks accidental but is
        # kept because it is the column key other code may index by.
        self.data = pd.read_csv(csv_file, sep=" ", header=None,
                                names=["#image path", "#x1", "#x2", "#x3", "#x4", "#x5",
                                       "#y1", "#y2", "#y3", "#y4", "#y5", "#gender",
                                       " #smile", "#wearing glasses", "#head pose"])

        if transform is None:
            transform = transforms.Compose(
                [transforms.Resize((224, 224)),
                 transforms.ToTensor(),
                 transforms.Normalize((89.93/255, 99.5/255, 119.78/255), (1., 1., 1.))])
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        # Annotation paths may use Windows-style backslashes; normalize them.
        img_dir = path.join(self.root_dir, row['#image path'].replace('\\', '/'))
        sens, label = row['#wearing glasses'], row['#gender']

        # convert() returns a new image, so the file can be closed right away.
        with Image.open(img_dir) as src:
            img = src.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        label = torch.tensor(np.int64(label)).long()
        sens = torch.tensor(np.int64(sens)).long()
        return img, sens, label

class TabDataset(data.Dataset):
    """Tabular dataset yielding ``(features, sensitive_group, label)`` triples.

    The sensitive column(s) are removed from the feature matrix. The
    remaining features are scaled with the module-level ``min_max_scaler``.
    Each distinct sensitive row is mapped to an integer group id.
    """

    def __init__(self, dataset, sens_idx, fit_scaler=True):
        """
        Args:
            dataset: object exposing ``features`` (2-D array, samples x
                columns) and ``labels``. Labels are squeezed on the last axis,
                presumably shaped (n, 1); verify against the caller.
            sens_idx: column index, or list/tuple of column indices, of the
                sensitive attribute(s) within ``dataset.features``.
            fit_scaler: if True (default, the historical behavior), refit
                the shared ``min_max_scaler`` on this dataset's features.
                Pass False to reuse the scaler as already fitted (e.g. on the
                training split), so evaluation data is scaled with the same
                statistics. The scaler must then already be fitted.
        """
        self.label = dataset.labels.squeeze(-1).astype(int)

        # Width of the raw feature matrix, sensitive columns included.
        self.feature_size = dataset.features.shape[1]
        sens_loc = np.zeros(self.feature_size, dtype=bool)
        # numpy indexing marks either a single column or a list of columns.
        if isinstance(sens_idx, (list, tuple)):
            sens_idx = list(sens_idx)
        sens_loc[sens_idx] = True

        features = dataset.features[:, ~sens_loc]  # data without sensitive columns
        if fit_scaler:
            self.feature = min_max_scaler.fit_transform(features)
        else:
            self.feature = min_max_scaler.transform(features)

        self.sensitive = dataset.features[:, sens_loc]
        # Map each distinct sensitive row (keyed by its str() form) to a
        # contiguous group id, in np.unique's sorted order.
        self.enc = {str(row): i
                    for i, row in enumerate(np.unique(self.sensitive, axis=0))}

    def __getitem__(self, idx):
        y = self.label[idx]
        # A tiny Gaussian jitter (std 1e-6) is added on every access, so
        # repeated reads of the same index differ slightly. The purpose is
        # not stated here; presumably it is to break exact ties.
        x = self.feature[idx] + np.random.normal(0, 1e-6, self.feature.shape[1])
        a = self.enc[str(self.sensitive[idx])]

        return x, a, y

    def __len__(self):
        return len(self.label)