import os
import pickle
import pickle

import numpy as np
import torch
from PIL import Image

from datasets.BaseDataset import BaseDataset


class ADNI(BaseDataset):
    """Dataset class representing the ADNI dataset.

    Each row of ``dataframe`` describes one sample. Columns read by this class:
    ``Path`` (image path relative to ``path_to_images``), ``label``, ``Sex``,
    ``Age_binary`` and, when ``sens_classes != 2``, ``Age_multi``.
    Images are loaded from ``.npy`` files whose names are derived from ``Path``
    by replacing everything from the first ``.nii`` onward with ``.npy``.
    """

    def __init__(self, dataframe, path_to_images, sens_name, sens_classes, transform, no_return_idx=False):
        """Build the dataset and precompute sensitive-attribute and target arrays.

        Args:
            dataframe: table of samples; accessed via column lookup and ``.iloc``
                (presumably a pandas DataFrame — verify against caller).
            path_to_images: root directory prepended to each row's ``Path``.
            sens_name: sensitive attribute to use, ``'Sex'`` or ``'Age'``.
            sens_classes: number of sensitive groups; for ``'Age'``, 2 selects
                ``Age_binary`` in ``__getitem__``, any other value ``Age_multi``.
            transform: callable applied to each loaded numpy array.
            no_return_idx: if True, ``__getitem__`` omits the sample index.
        """
        super().__init__(dataframe, path_to_images, sens_name, sens_classes, transform)
        self.no_return_idx = no_return_idx

        # Per-sample sensitive attribute as floats (1.0 = 'F', or Age_binary == 1).
        # For 'Age' this always uses the binary grouping, regardless of sens_classes.
        if self.sens_name == 'Sex':
            self.A = np.asarray(self.dataframe['Sex'].values == 'F').astype('float')
        elif self.sens_name == 'Age':
            self.A = np.asarray(self.dataframe['Age_binary'].values.astype('int') == 1).astype('float')

        # Binarized target: any label > 0 becomes 1.0.
        self.Y = (np.asarray(self.dataframe['label'].values) > 0).astype('float')
        self.AY_proportion = None

    def __getitem__(self, idx):
        """Load and return one sample.

        Args:
            idx: positional row index into ``dataframe``.

        Returns:
            ``(idx, img, label, sensitive)``, or ``(img, label, sensitive)`` when
            ``no_return_idx`` is True. ``img`` is the transformed array,
            ``label`` is a 1-element ``torch.FloatTensor`` of the row's raw
            ``label`` value, and ``sensitive`` is an int group index.

        Raises:
            ValueError: if ``sens_name`` is neither ``'Age'`` nor ``'Sex'``.
        """
        item = self.dataframe.iloc[idx]

        # Swap the '.nii' suffix on the relative path only; splitting the joined
        # path would truncate at any '.nii' occurring inside path_to_images.
        npy_rel_path = item["Path"].split('.nii')[0] + '.npy'
        img = np.load(os.path.join(self.path_to_images, npy_rel_path))
        img = self.transform(img)

        # NOTE(review): this is the raw label, whereas self.Y is binarized (> 0);
        # confirm this is intended if labels can exceed 1.
        label = torch.FloatTensor([item['label']])

        if self.sens_name == 'Age':
            age_column = 'Age_binary' if self.sens_classes == 2 else 'Age_multi'
            sensitive = int(item[age_column])
        elif self.sens_name == 'Sex':
            # Same encoding as self.A in __init__: 'F' -> 1, anything else -> 0.
            sensitive = 1 if item['Sex'] == 'F' else 0
        else:
            raise ValueError(
                "Unsupported sens_name {!r}; expected 'Age' or 'Sex'".format(self.sens_name)
            )

        if self.no_return_idx:
            return img, label, sensitive
        return idx, img, label, sensitive