import torch
import pickle
import numpy as np
from PIL import Image
import pickle
from datasets.BaseDataset import BaseDataset


class Fitz17k(BaseDataset):
    """Fitzpatrick-17k skin-image dataset backed by a single pickled image collection.

    Images are read from one pickle file; targets and sensitive attributes are
    read from columns of ``dataframe``. Only the ``'skin_type'`` sensitive
    attribute is supported.

    Args:
        dataframe: Table providing ``binary_label``, ``skin_binary`` and
            ``skin_type`` columns; rows are accessed positionally (``.iloc``).
        path_to_pickles: Path to a pickle file holding the images. Images are
            indexed with the same position as the dataframe row, so this
            presumably stores images in dataframe row order -- TODO confirm
            against the preprocessing that produced the pickle.
        sens_name: Sensitive-attribute name; must be ``'skin_type'``.
        sens_classes: Number of sensitive groups. When ``2``, ``__getitem__``
            returns ``skin_binary``; otherwise it returns ``skin_type``.
        transform: Callable applied to each PIL image.
        no_return_idx: If True, ``__getitem__`` omits the sample index from
            its return tuple.

    Raises:
        ValueError: If ``sens_name`` is not ``'skin_type'``.
    """

    def __init__(self, dataframe, path_to_pickles, sens_name, sens_classes,
                 transform, no_return_idx=False):
        super().__init__(dataframe, path_to_pickles, sens_name, sens_classes, transform)

        self.no_return_idx = no_return_idx

        # Validate before the (potentially large) pickle load so a bad
        # configuration fails fast.
        if self.sens_name != 'skin_type':
            raise ValueError('Please check the sensitive attributes.')

        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only point this at trusted, locally generated pickles.
        with open(path_to_pickles, 'rb') as f:
            self.tol_images = pickle.load(f)

        # A: float array, 1.0 where skin_binary != 0. It is always derived from
        # skin_binary, even when sens_classes != 2.
        self.A = np.asarray(self.dataframe['skin_binary'].values != 0).astype('float')
        # Y: float array, 1.0 where binary_label > 0.
        self.Y = (np.asarray(self.dataframe['binary_label'].values) > 0).astype('float')
        self.AY_proportion = None

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            ``(idx, img, label, sensitive)``, or ``(img, label, sensitive)``
            when ``no_return_idx`` is True. ``img`` is the transformed image;
            ``label`` is a one-element float32 tensor holding
            ``int(binary_label)``; ``sensitive`` is an ``int`` taken from
            ``skin_binary`` when ``sens_classes == 2``, else from ``skin_type``.

        Raises:
            ValueError: If ``sens_name`` is not ``'skin_type'``.
        """
        item = self.dataframe.iloc[idx]

        # Assumes each stored image is an array PIL can wrap (e.g. uint8
        # HxWxC) -- TODO confirm against the pickle's contents.
        img = self.transform(Image.fromarray(self.tol_images[idx]))

        # NOTE(review): this uses the raw integer value, whereas self.Y
        # thresholds binary_label at > 0; the two agree only if the column is 0/1.
        label = torch.tensor([int(item['binary_label'])], dtype=torch.float32)

        # Re-checked because sens_name could be changed after construction.
        if self.sens_name != 'skin_type':
            raise ValueError('Please check the sensitive attributes.')
        sens_column = 'skin_binary' if self.sens_classes == 2 else 'skin_type'
        sensitive = int(item[sens_column])

        if self.no_return_idx:
            return img, label, sensitive
        return idx, img, label, sensitive