import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from typing import Union
import re

from PIL import Image
import numpy as np
import pandas as pd
from torchvision.datasets import VisionDataset
import torch


def pil_loader(path: str) -> Image.Image:
    """Open the image file at *path* and return it converted to RGB.

    The RGB conversion happens inside the ``with`` block, so the pixel data
    is read before the file handle is closed (``Image.open`` reads lazily).
    """
    with open(path, "rb") as handle:
        rgb_image = Image.open(handle).convert("RGB")
    return rgb_image
    
# Folder names encode three attributes, e.g. "obj-country_bg-urban_co_occur_obj-country".
_FOLDER_NAME_PATTERN = re.compile(r'obj-(\w+)_bg-(\w+)_co_occur_obj-(\w+)')


def _parse_folder_label(folder: str) -> Optional[np.ndarray]:
    """Return the [obj, bg, co_occur] label encoded in *folder*, or None if it does not parse."""
    match = _FOLDER_NAME_PATTERN.match(folder)
    if match is None:
        return None
    # Each attribute becomes 1 when its value is 'country', otherwise 0.
    return np.array([int(value == 'country') for value in match.groups()], dtype=np.int64)


def _list_image_names(folder_path: str) -> List[str]:
    """Return file names whose stem (text before the first '.') is three decimal digits, sorted numerically."""
    names = []
    for name in os.listdir(folder_path):
        stem = name.split('.')[0]
        # isdecimal (not isdigit) guarantees int(stem) below cannot raise.
        if len(stem) == 3 and stem.isdecimal():
            names.append(name)
    return sorted(names, key=lambda name: int(name.split('.')[0]))


def process_folders(data_dir: str) -> List[Tuple[str, np.ndarray]]:
    """Collect ``(image_path, label)`` samples from the attribute folders in *data_dir*.

    Every sub-directory of *data_dir* whose name matches
    ``obj-<x>_bg-<y>_co_occur_obj-<z>`` contributes its files named with a
    three-digit numeric stem (e.g. ``007.jpg``), in numeric order. All images
    of one folder share a length-3 ``int64`` label ``[obj, bg, co_occur]``,
    where each entry is 1 if the attribute equals ``'country'`` and 0 otherwise.

    Folders are visited in sorted order so the sample list is reproducible.
    Folders whose names do not parse are reported and skipped. Non-directory
    entries in *data_dir* are ignored.

    Args:
        data_dir: Directory containing the attribute folders.

    Returns:
        List of ``(absolute-or-relative image path, label array)`` tuples.
    """
    samples: List[Tuple[str, np.ndarray]] = []
    for folder in sorted(os.listdir(data_dir)):
        folder_path = os.path.join(data_dir, folder)
        if not os.path.isdir(folder_path):
            continue

        label = _parse_folder_label(folder)
        if label is None:
            # Skip the folder instead of reusing a previous folder's label
            # (or hitting an unbound name on the first folder).
            print("Cannot parse folder name:", folder)
            continue

        samples.extend(
            (os.path.join(folder_path, name), label)
            for name in _list_image_names(folder_path)
        )
    return samples

class UrbancarsDataset(VisionDataset):
    """UrbanCars images read from ``root/<split>/<attribute folder>/<NNN>.<ext>``.

    Samples are collected by ``process_folders``. Each target is a length-3
    ``int64`` array ``[obj, bg, co_occur]`` decoded from the folder name,
    with 1 where the attribute is ``'country'`` and 0 otherwise.
    """

    def __init__(
        self,
        root: str,
        split: str,
        loader: Callable[[str], Any] = pil_loader,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """
        Args:
            root: Dataset root directory.
            split: Name of the sub-directory of *root* to load (e.g. a train/test split).
            loader: Callable mapping an image path to a loaded sample.
            transform: Optional callable applied to each loaded sample.
            target_transform: Optional callable applied to each target array.

        Raises:
            FileNotFoundError: If ``root/split`` is not an existing directory.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.loader = loader
        data_dir = os.path.join(root, split)
        # Explicit check rather than `assert`, which is stripped under `python -O`;
        # isdir (not exists) so a regular file is rejected up front.
        if not os.path.isdir(data_dir):
            raise FileNotFoundError(f"Split directory not found: {data_dir}")
        self.samples = process_folders(data_dir)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where sample is the loader output (after
            ``transform``, if set) and target is the ``[obj, bg, co_occur]``
            int64 label array (after ``target_transform``, if set).
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        """Return the number of (path, label) samples found in the split."""
        return len(self.samples)