from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
from torch.utils.data import Dataset

# Filesystem anchors derived from this module's own location, so they do not
# depend on the process's current working directory.
cur_filepath = Path(__file__).resolve()  # absolute path of this file, symlinks resolved
cur_dir = cur_filepath.parent  # directory that contains this file
root_dir = cur_dir.parent.parent  # two directory levels above cur_dir
data_dir = root_dir.joinpath("data")  # base directory the CSV image paths are joined onto


class _CsvImageLabelDataset(Dataset):
    """Map-style dataset of ``{"img_path", "label"}`` samples read from a CSV.

    Column 0 of the CSV holds an image path, which is joined onto a base
    directory. The columns selected by ``label_columns`` (a positional
    ``DataFrame.iloc`` column indexer) form each row's label values.
    No image data is loaded here; samples carry only the path.

    Subclasses set ``label_columns``. This base class is not meant to be
    instantiated directly.
    """

    # Positional column indexer passed to ``DataFrame.iloc``: a slice or a
    # *list* of ints (a tuple would be read by ``iloc`` as a multi-axis key).
    label_columns = None

    def __init__(self, csv_path: Path, *, root: Optional[Path] = None):
        """Read the CSV and cache image paths and labels as numpy arrays.

        Args:
            csv_path: CSV to read; anything ``pd.read_csv`` accepts works.
            root: Base directory the column-0 paths are joined onto.
                Defaults to the module-level ``data_dir``.

        Raises:
            NotImplementedError: if the class does not define ``label_columns``.
        """
        if self.label_columns is None:
            raise NotImplementedError(f"{type(self).__name__} must define label_columns")
        base_dir = data_dir if root is None else Path(root)
        data_info = pd.read_csv(csv_path)
        # One Path per CSV row.
        # NOTE(review): this holds Python objects, so forked DataLoader workers
        # touch their refcounts. If worker memory grows, consider storing plain
        # strings (numpy str dtype); __getitem__ already wraps with Path().
        self.img_path_list = np.array([base_dir / img_path for img_path in data_info.iloc[:, 0]])
        # 2-D array: one row per CSV row, one column per selected label column.
        self.class_list = np.asarray(data_info.iloc[:, self.label_columns])

    def __getitem__(self, index):
        """Return ``{"img_path": Path, "label": class_list[index]}`` for *index*.

        For an integer index, ``label`` is a 1-D numpy array holding that row's
        selected label values.
        """
        return {"img_path": Path(self.img_path_list[index]), "label": self.class_list[index]}

    def __len__(self):
        """Return the number of CSV rows (samples)."""
        return len(self.img_path_list)


class Chestxray14Dataset(_CsvImageLabelDataset):
    """ChestX-ray14 CSV split: image path in column 0, labels in columns 3 onward."""

    # NOTE(review): columns 1-2 are skipped; their contents aren't visible here.
    label_columns = slice(3, None)


class Covidxcxr4Dataset(_CsvImageLabelDataset):
    """COVIDx CXR-4 CSV split: image path in column 0, labels in columns 2 onward."""

    # NOTE(review): column 1 is skipped; its content isn't visible here.
    label_columns = slice(2, None)


class ChexpertDataset(_CsvImageLabelDataset):
    """CheXpert CSV split: image path in column 0, labels from five picked columns."""

    # Picks five columns and reorders them into this order.
    # NOTE(review): presumably these are the label columns of a preprocessed
    # CSV, arranged in a fixed class order; confirm the mapping.
    label_columns = [9, 3, 7, 6, 11]
