# FairFace dataset, loaded from the Hugging Face Hub.
from matplotlib import pyplot as plt
from datasets import load_dataset
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor
from .BaseDataset import BaseDataset


class FairFaceDataset(BaseDataset):
    """FairFace face-attribute dataset with binarized age/gender/race targets.

    The Hugging Face release used here (``HuggingFaceM4/FairFace``, config
    ``"0.25"``) is read through its ``train`` and ``validation`` splits only.
    Its ``validation`` split is used as this class's *test* split, and a
    seeded 10% slice of its ``train`` split is held out as the *valid* split.

    Each item is a dict with keys ``"image"``, ``"age"``, ``"gender"`` and
    ``"race"``; see ``__getitem__`` for how the labels are binarized.
    """

    prediction_targets = ["age", "gender", "race"]

    # Hugging Face Hub identifiers for the source data.
    _HF_PATH = "HuggingFaceM4/FairFace"
    _HF_CONFIG = "0.25"
    # Fraction of the HF "train" split held out as this class's "valid" split.
    _VALID_FRACTION = 0.1

    # Raw label ids that binarize to 1 in ``__getitem__`` (meanings as
    # annotated in the original code).
    _POSITIVE_AGES = frozenset({3, 4})  # 20-29, 30-39
    _POSITIVE_RACES = frozenset({1, 3, 5})  # 1: indian, 3: white, 5: latino

    def __init__(self, split="train", transform=None, seed=42, **kwargs):
        """Load the requested split and build a seeded index permutation.

        Args:
            split: ``"test"`` selects the HF ``validation`` split; ``"valid"``
                selects the held-out 10% of HF ``train``; any other value
                selects the remaining 90% of HF ``train``.
            transform: Optional callable applied to each image in
                ``__getitem__``.
            seed: Seed for the train/valid split and for the index shuffle.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        super().__init__(split, transform, len(self.prediction_targets), seed)

        dataset = load_dataset(self._HF_PATH, self._HF_CONFIG)
        if split == "test":
            # The release has no 'test' split; its 'validation' split acts as
            # our test set, so a real validation set is carved from 'train'.
            self.hf_dataset = dataset["validation"]
        else:
            parts = dataset["train"].train_test_split(
                test_size=self._VALID_FRACTION, seed=seed
            )
            self.hf_dataset = parts["test"] if split == "valid" else parts["train"]

        # Assign for every split: __getitem__ reads self.transform regardless
        # of split, so this must not be limited to the train branch.
        self.transform = transform

        # A private RandomState seeded with `seed` yields the same permutation
        # as np.random.seed(seed) + np.random.shuffle, without resetting
        # NumPy's process-wide random state as a side effect.
        self.annotations = list(range(len(self.hf_dataset)))
        np.random.RandomState(seed).shuffle(self.annotations)
        self.sub_dataset = self.annotations

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` of the shuffled index list.

        Labels are binarized: ``age`` is 1 for raw ids 3/4, ``race`` is 1 for
        raw ids 1/3/5, otherwise 0; ``gender`` is passed through unchanged.
        """
        item = self.hf_dataset[self.sub_dataset[idx]]
        image = item["image"]
        if self.transform is not None:
            image = self.transform(image)
        return {
            "image": image,
            "age": int(item["age"] in self._POSITIVE_AGES),
            "gender": item["gender"],
            "race": int(item["race"] in self._POSITIVE_RACES),
        }
