import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from pathlib import Path

# Locate the data directory relative to this file rather than the current
# working directory, so the datasets below resolve image paths the same way
# regardless of where the process was launched from.
cur_filepath = Path(__file__).resolve()  # absolute path of this module
cur_dir = cur_filepath.parent  # directory containing this module
root_dir = cur_dir.parent.parent  # two levels above this module's directory
data_dir = root_dir / "data"  # image paths listed in the CSV files are joined onto this

class Chestxray14Dataset(Dataset):
    """Image/label dataset driven by a CSV index file.

    Column 0 of the CSV holds each image's path relative to the module-level
    ``data_dir``; every column from index 3 onward is used as the label
    vector. (Presumably these are the ChestX-ray14 finding columns -- confirm
    against the CSV schema.)

    Args:
        csv_path: Path to the CSV index, read with ``pandas.read_csv``.
        img_res: Side length, in pixels, of the square image fed to the model.
    """

    def __init__(self, csv_path, img_res: int = 512):
        data_info = pd.read_csv(csv_path)
        self.img_path_list = np.asarray(data_info.iloc[:, 0])
        self.class_list = np.asarray(data_info.iloc[:, 3:])

        # Per-channel mean / std (the standard ImageNet statistics).
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.transform = transforms.Compose([
            # Pass an explicit [height, width] pair: a bare int only scales the
            # shorter edge, so non-square inputs would yield tensors of varying
            # shape that cannot be stacked into a batch. Square inputs are
            # resized exactly as before, and this matches the other dataset
            # classes in this module.
            transforms.Resize([img_res, img_res], interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

    def __getitem__(self, index):
        """Load and preprocess the sample at ``index``.

        Returns:
            A dict with ``"image"`` (the transformed image tensor) and
            ``"label"`` (row ``index`` of ``class_list`` as a NumPy array).
        """
        img_path = data_dir / self.img_path_list[index]
        class_label = self.class_list[index]
        # Image.open is lazy; convert() forces the decode and returns an
        # independent image, so the file can be closed right away. The
        # context manager releases the handle even if decoding fails.
        with Image.open(img_path) as img:
            rgb_img = img.convert('RGB')
        image = self.transform(rgb_img)
        return {"image": image, "label": class_label}

    def __len__(self) -> int:
        """Return the number of samples listed in the CSV."""
        return len(self.img_path_list)


class ChexpertDataset(Dataset):
    """Image/label dataset driven by a CSV index file.

    Column 0 of the CSV holds each image's path relative to the module-level
    ``data_dir``. The label vector is built from CSV columns 9, 3, 7, 6 and
    11, in that order. (Presumably a fixed subset of CheXpert findings --
    verify the indices against the CSV header before changing them.)

    Args:
        csv_path: Path to the CSV index, read with ``pandas.read_csv``.
        img_res: Side length, in pixels, of the square image fed to the model.
    """

    def __init__(self, csv_path, img_res: int = 512):
        data_info = pd.read_csv(csv_path)
        self.img_path_list = np.asarray(data_info.iloc[:, 0])
        # The order of the indices defines the order of the label entries.
        self.class_list = np.asarray(data_info.iloc[:, [9, 3, 7, 6, 11]])

        # Per-channel mean / std (the standard ImageNet statistics).
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.transform = transforms.Compose([
            # Explicit [height, width]: every image is forced to a square of
            # img_res pixels regardless of its original aspect ratio.
            transforms.Resize([img_res, img_res], interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

    def __getitem__(self, index):
        """Load and preprocess the sample at ``index``.

        Returns:
            A dict with ``"image"`` (the transformed image tensor) and
            ``"label"`` (row ``index`` of ``class_list`` as a NumPy array).
        """
        img_path = data_dir / self.img_path_list[index]
        class_label = self.class_list[index]
        # Image.open is lazy; convert() forces the decode and returns an
        # independent image, so the file can be closed right away. The
        # context manager releases the handle even if decoding fails.
        with Image.open(img_path) as img:
            rgb_img = img.convert('RGB')
        image = self.transform(rgb_img)
        return {"image": image, "label": class_label}

    def __len__(self) -> int:
        """Return the number of samples listed in the CSV."""
        return len(self.img_path_list)


class Covidxcxr4Dataset(Dataset):
    """Image/label dataset driven by a CSV index file.

    Column 0 of the CSV holds each image's path relative to the module-level
    ``data_dir``; every column from index 2 onward is used as the label
    vector. (Presumably the COVIDx CXR-4 class columns -- confirm against the
    CSV schema.)

    Args:
        csv_path: Path to the CSV index, read with ``pandas.read_csv``.
        img_res: Side length, in pixels, of the square image fed to the model.
    """

    def __init__(self, csv_path, img_res: int = 512):
        data_info = pd.read_csv(csv_path)
        self.img_path_list = np.asarray(data_info.iloc[:, 0])
        self.class_list = np.asarray(data_info.iloc[:, 2:])

        # Per-channel mean / std (the standard ImageNet statistics).
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.transform = transforms.Compose([
            # Explicit [height, width]: every image is forced to a square of
            # img_res pixels regardless of its original aspect ratio.
            transforms.Resize([img_res, img_res], interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

    def __getitem__(self, index):
        """Load and preprocess the sample at ``index``.

        Returns:
            A dict with ``"image"`` (the transformed image tensor) and
            ``"label"`` (row ``index`` of ``class_list`` as a NumPy array).
        """
        img_path = data_dir / self.img_path_list[index]
        class_label = self.class_list[index]
        # Image.open is lazy; convert() forces the decode and returns an
        # independent image, so the file can be closed right away. The
        # context manager releases the handle even if decoding fails.
        with Image.open(img_path) as img:
            rgb_img = img.convert('RGB')
        image = self.transform(rgb_img)
        return {"image": image, "label": class_label}

    def __len__(self) -> int:
        """Return the number of samples listed in the CSV."""
        return len(self.img_path_list)
