import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from typing import Union

from PIL import Image
import numpy as np
import pandas as pd
from torchvision.datasets import VisionDataset
import torch

"""
This module defines WaterbirdsDataset, a torchvision VisionDataset for loading the Waterbirds dataset.
Main features:
- Loads image paths and labels from the metadata.csv file under the dataset root.
- Selects the rows belonging to the requested split (train, valid, or test).
- Applies optional transforms to the images and to the (y, place) label arrays.
"""

def pil_loader(path: str) -> Image.Image:
    """Open the image file at ``path`` and return it converted to RGB.

    The file is opened explicitly in binary mode and closed on exit.
    The RGB conversion happens while the handle is still open.
    """
    with open(path, "rb") as handle:
        rgb_image = Image.open(handle).convert("RGB")
    return rgb_image

class WaterbirdsDataset(VisionDataset):
    """Waterbirds images paired with their ``(y, place)`` metadata labels.

    Reads ``<root>/metadata.csv``, keeps the rows whose ``split`` column
    matches the requested split, and indexes ``(image_path, target)`` pairs.
    ``target`` is an ``int64`` NumPy array ``[y, place]`` built from the
    ``y`` and ``place`` columns of the CSV.

    Args:
        root: Directory containing ``metadata.csv``. The ``img_filename``
            entries are joined onto this directory.
        split: One of ``'train'``, ``'valid'`` or ``'test'``. These map to
            the split codes 0, 1 and 2 in the CSV's ``split`` column.
        loader: Callable that loads an image given its file path.
        transform: Optional callable applied to each loaded image.
        target_transform: Optional callable applied to each target array.

    Raises:
        KeyError: If ``split`` is not one of the supported split names.
    """

    # Split names -> integer codes stored in the ``split`` column of metadata.csv.
    _SPLIT_IDS: Dict[str, int] = {'train': 0, 'valid': 1, 'test': 2}

    def __init__(
        self,
        root: str,
        split: str,
        loader: Callable[[str], Any] = pil_loader,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.loader = loader

        # Validate the split before touching the CSV so a typo fails fast
        # with a readable message (still a KeyError, as before).
        try:
            split_id = self._SPLIT_IDS[split]
        except KeyError:
            raise KeyError(
                f"Unknown split {split!r}; expected one of {list(self._SPLIT_IDS)}"
            ) from None

        metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
        metadata = metadata[metadata['split'] == split_id]

        # Iterate the needed columns directly. Calling ``iloc`` per row builds
        # a new row Series on every access, which is slow for large CSVs.
        self.samples: List[Tuple[str, np.ndarray]] = [
            (os.path.join(root, filename), np.array([y, place], dtype=np.int64))
            for filename, y, place in zip(
                metadata['img_filename'], metadata['y'], metadata['place']
            )
        ]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load one sample and its target.

        Args:
            index (int): Position in ``self.samples``.

        Returns:
            tuple: ``(sample, target)``. ``sample`` is the output of
            ``self.loader`` on the image path, passed through ``transform``
            if one is set. ``target`` is the ``int64`` array ``[y, place]``,
            passed through ``target_transform`` if one is set.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.samples)