# -*- coding: utf-8 -*-
import PIL
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from pathlib import Path

# Resolve the project's data directory relative to where this module lives,
# so image lookups do not depend on the process's working directory.
cur_filepath = Path(__file__).resolve()
cur_dir = cur_filepath.parents[0]
# The project root is taken to be two levels above this module's directory.
root_dir = cur_dir.parent.parent
data_dir = root_dir.joinpath("data")


class Chestxray14Dataset(Dataset):
    """Image/label dataset driven by a ChestX-ray14 CSV index.

    The first CSV column supplies image paths relative to the module-level
    ``data_dir``; every column from the fourth one onwards is returned as the
    label vector of the sample.
    NOTE(review): presumably one label column per finding -- confirm against
    the CSV that is actually used.

    Args:
        csv_path: Anything ``pandas.read_csv`` accepts (path or file-like).
        transform: Optional callable applied to the RGB ``PIL.Image`` of each
            sample. When omitted, images are resized to 224x224, converted
            to a tensor and normalized with the standard ImageNet channel
            mean/std.
    """

    def __init__(self, csv_path, transform=None):
        data_info = pd.read_csv(csv_path)

        # Column 0: relative image path; columns 3+: label values.
        self.img_path_list = np.asarray(data_info.iloc[:, 0])
        self.class_list = np.asarray(data_info.iloc[:, 3:])

        if transform is None:
            # Default pipeline, identical to the previously hard-coded one.
            normalize = transforms.Normalize(
                (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            transform = transforms.Compose(
                [transforms.Resize([224, 224]), transforms.ToTensor(), normalize])
        self.transform = transform

    def __getitem__(self, index):
        """Return ``{"image": transformed image, "label": label row}``."""
        img_path = data_dir / self.img_path_list[index]
        class_label = self.class_list[index]

        # Open inside a context manager so the underlying file handle is
        # released deterministically; ``convert`` returns an independent
        # image that stays valid after the source file is closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        image = self.transform(img)

        return {"image": image, "label": class_label}

    def __len__(self):
        """Return the number of rows (samples) in the CSV index."""
        return len(self.img_path_list)


class Covidxcxr4Dataset(Dataset):
    """Image/label dataset driven by a COVIDx CXR-4 CSV index.

    The first CSV column supplies image paths relative to the module-level
    ``data_dir``; every column from the third one onwards is returned as the
    label vector of the sample.
    NOTE(review): the meaning of the label columns is not visible here --
    confirm against the CSV that is actually used.

    Args:
        csv_path: Anything ``pandas.read_csv`` accepts (path or file-like).
        transform: Optional callable applied to the RGB ``PIL.Image`` of each
            sample. When omitted, images are resized to 224x224, converted
            to a tensor and normalized with the standard ImageNet channel
            mean/std.
    """

    def __init__(self, csv_path, transform=None):
        data_info = pd.read_csv(csv_path)

        # Column 0: relative image path; columns 2+: label values.
        self.img_path_list = np.asarray(data_info.iloc[:, 0])
        self.class_list = np.asarray(data_info.iloc[:, 2:])

        if transform is None:
            # Default pipeline, identical to the previously hard-coded one.
            normalize = transforms.Normalize(
                (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            transform = transforms.Compose(
                [transforms.Resize([224, 224]), transforms.ToTensor(), normalize])
        self.transform = transform

    def __getitem__(self, index):
        """Return ``{"image": transformed image, "label": label row}``."""
        img_path = data_dir / self.img_path_list[index]
        class_label = self.class_list[index]

        # Open inside a context manager so the underlying file handle is
        # released deterministically; ``convert`` returns an independent
        # image that stays valid after the source file is closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        image = self.transform(img)

        return {"image": image, "label": class_label}

    def __len__(self):
        """Return the number of rows (samples) in the CSV index."""
        return len(self.img_path_list)


class ChexpertDataset(Dataset):
    """Image/label dataset driven by a CheXpert CSV index.

    The first CSV column supplies image paths relative to the module-level
    ``data_dir``. The label vector is built from CSV columns 9, 3, 7, 6 and
    11, in that order.
    NOTE(review): which findings these five column positions correspond to
    depends on the CSV layout, which is not visible here -- confirm.

    Args:
        csv_path: Anything ``pandas.read_csv`` accepts (path or file-like).
        transform: Optional callable applied to the RGB ``PIL.Image`` of each
            sample. When omitted, images are resized to 224x224, converted
            to a tensor and normalized with the standard ImageNet channel
            mean/std.
    """

    def __init__(self, csv_path, transform=None):
        data_info = pd.read_csv(csv_path)

        # Column 0: relative image path; labels: a fixed, ordered column pick.
        self.img_path_list = np.asarray(data_info.iloc[:, 0])
        self.class_list = np.asarray(data_info.iloc[:, [9, 3, 7, 6, 11]])

        if transform is None:
            # Default pipeline, identical to the previously hard-coded one.
            normalize = transforms.Normalize(
                (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            transform = transforms.Compose(
                [transforms.Resize([224, 224]), transforms.ToTensor(), normalize])
        self.transform = transform

    def __getitem__(self, index):
        """Return ``{"image": transformed image, "label": label row}``."""
        img_path = data_dir / self.img_path_list[index]
        class_label = self.class_list[index]

        # Open inside a context manager so the underlying file handle is
        # released deterministically; ``convert`` returns an independent
        # image that stays valid after the source file is closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        image = self.transform(img)

        return {"image": image, "label": class_label}

    def __len__(self):
        """Return the number of rows (samples) in the CSV index."""
        return len(self.img_path_list)
