import torch
import pickle
import numpy as np
from PIL import Image
import os
from collections import Counter
import pickle
from datasets.BaseDataset import BaseDataset


class CXP(BaseDataset):
    """Dataset class for the CheXpert dataset.

    Each item pairs one pre-pickled image with its multi-label finding vector
    (``PRED_LABEL`` order) and an integer-encoded sensitive attribute.
    """

    # Column read for sens_name == 'Age', keyed by the requested number of classes.
    _AGE_COLUMNS = {2: 'Age_binary', 4: 'Age_multi4', 5: 'Age_multi'}

    def __init__(self, dataframe, path_to_pickles, sens_name, sens_classes, transform, no_return_idx=False):
        """Build the dataset from split metadata and a pickled image collection.

        Args:
            dataframe: Metadata table for one split (accessed via ``.iloc`` and
                column names, presumably a pandas DataFrame). Must contain
                'No Finding', every column in ``PRED_LABEL`` and the column(s)
                needed by ``sens_name``.
            path_to_pickles: Path to a pickle file. Its entries are looked up
                positionally (``tol_images[idx]``) and each one is passed to
                ``PIL.Image.fromarray``.
            sens_name: Sensitive attribute to use: 'Sex', 'Age' or 'Race'.
            sens_classes: Number of sensitive classes. For 'Age' it selects the
                column: 2 -> 'Age_binary', 4 -> 'Age_multi4', 5 -> 'Age_multi'.
            transform: Callable applied to every RGB PIL image.
            no_return_idx: If True, ``__getitem__`` returns
                ``(img, label, sensitive)``. Otherwise it returns
                ``(idx, img, label, sensitive)``.

        Raises:
            ValueError: If ``sens_name`` is not 'Sex', 'Age' or 'Race'.
        """
        super().__init__(dataframe, path_to_pickles, sens_name, sens_classes, transform)
        # self.dataframe / self.sens_name / self.sens_classes / self.transform are
        # read below but never assigned here. They are presumably set by
        # BaseDataset.__init__ (not visible in this file) — confirm.

        self.no_return_idx = no_return_idx

        # Binary group indicator per row (1.0 = the non-reference group).
        # Computed before loading images so a bad sens_name fails fast.
        if self.sens_name == 'Sex':
            self.A = np.asarray(self.dataframe['Sex'].values != 'M').astype('float')
        elif self.sens_name == 'Age':
            # NOTE(review): A == 1 where Age_binary == 0, whereas __getitem__ returns
            # sensitive == Age_binary, so the two encodings are inverted for Age
            # (for Sex and Race they agree). Confirm which convention is intended.
            self.A = np.asarray(self.dataframe['Age_binary'].values.astype('int') == 0).astype('float')
        elif self.sens_name == 'Race':
            self.A = np.asarray(self.dataframe['Race'].values != 'White').astype('float')
        else:
            raise ValueError(
                f"Unsupported sens_name {self.sens_name!r}; expected 'Sex', 'Age' or 'Race'."
            )

        # Binary target: 1.0 where 'No Finding' > 0 (NaN compares False -> 0.0).
        self.Y = (np.asarray(self.dataframe['No Finding'].values) > 0).astype('float')
        self.AY_proportion = None

        # Order of the entries in the label tensor returned by __getitem__.
        self.PRED_LABEL = [
            'No Finding',
            'Enlarged Cardiomediastinum',
            'Cardiomegaly',
            'Lung Opacity',
            'Lung Lesion',
            'Edema',
            'Consolidation',
            'Pneumonia',
            'Atelectasis',
            'Pneumothorax',
            'Pleural Effusion',
            'Pleural Other',
            'Fracture',
            'Support Devices',
        ]

        # NOTE(review): pickle.load can execute arbitrary code. Only point this
        # at trusted, locally produced files.
        # Assumes the pickle's entries are in the same order as dataframe rows — verify.
        with open(path_to_pickles, 'rb') as f:
            self.tol_images = pickle.load(f)

    def __getitem__(self, idx):
        """Return one sample.

        Returns:
            ``(idx, img, label, sensitive)``, or ``(img, label, sensitive)`` when
            ``no_return_idx`` is True.
            - ``img`` is ``transform`` applied to the RGB image.
            - ``label`` is a float32 tensor with one entry per ``PRED_LABEL``
              column. Values > 0 are kept as-is; 0, negative and NaN become 0.
            - ``sensitive`` is an int.

        Raises:
            ValueError: If the sensitive-attribute configuration is unsupported.
        """
        item = self.dataframe.iloc[idx]

        img = Image.fromarray(self.tol_images[idx]).convert('RGB')
        img = self.transform(img)

        # Read all finding columns from the fetched row in one go. Only strictly
        # positive values survive; NaN > 0 is False, so NaN also maps to 0.
        values = item[self.PRED_LABEL].to_numpy(dtype=float)
        label = torch.tensor(np.where(values > 0, values, 0.0), dtype=torch.float32)

        sensitive = self._sensitive_value(item)

        if self.no_return_idx:
            return img, label, sensitive
        return idx, img, label, sensitive

    def _sensitive_value(self, item):
        """Encode the sensitive attribute of one metadata row as an int.

        Sex:  0 for 'M', 1 otherwise.
        Race: 0 for 'White', 1 otherwise.
        Age:  integer value of the column chosen by ``sens_classes``.
        """
        if self.sens_name == 'Sex':
            return 0 if item['Sex'] == 'M' else 1
        if self.sens_name == 'Race':
            return 0 if item['Race'] == 'White' else 1
        if self.sens_name == 'Age':
            column = self._AGE_COLUMNS.get(self.sens_classes)
            if column is None:
                raise ValueError(
                    f"Unsupported sens_classes {self.sens_classes!r} for 'Age'; "
                    f"expected one of {sorted(self._AGE_COLUMNS)}."
                )
            return int(item[column])
        raise ValueError(
            f"Unsupported sens_name {self.sens_name!r}; expected 'Sex', 'Age' or 'Race'."
        )