import os
from typing import Any, Tuple

from PIL import Image
from torch.utils.data import Dataset

from dataset.augmentations import DomainNetTransform

class DomainNet(Dataset):
    """PyTorch ``Dataset`` over the DomainNet image lists, with a train/test split.

    Samples are read from per-domain list files located in ``data/domainnet``
    (relative to the current working directory), named ``<domain>_train.txt``
    or ``<domain>_test.txt``. Each non-empty line has the form
    ``"domain/class/image.jpg label"``, where ``label`` is an integer.

    Attributes:
        domain_names: Domains whose list files are loaded, in load order.
        root_dir: Directory holding the list files; image paths are resolved
            against it as well.
        data: Image paths exactly as written in the list files.
        targets: Integer labels, index-aligned with ``data``.
        id_to_label: Maps a class name (the second component of the image
            path) to its integer label. NOTE: despite the name, the key is
            the class-name string and the value is the integer.
        label_to_id: Inverse of ``id_to_label`` (integer label -> class name).
        transform: Callable applied to every loaded image.
    """

    def __init__(self, train: bool = True, image_size: int = 256, transform=None, *args, **kwargs) -> None:
        """Read the list files of the selected split and index every sample.

        Args:
            train: Load the ``*_train.txt`` lists when true, otherwise the
                ``*_test.txt`` lists.
            image_size: Passed to ``DomainNetTransform`` when no ``transform``
                is supplied; unused otherwise.
            transform: Callable applied to each RGB image in ``__getitem__``.
                ``None`` selects ``DomainNetTransform(image_size)``.
            *args: Accepted for call-site compatibility; ignored.
            **kwargs: Accepted for call-site compatibility; ignored.

        Raises:
            OSError: If a list file cannot be opened.
            ValueError: If a non-empty line is not ``"<path> <int label>"``
                or its path has no class component.
        """
        super().__init__()
        self.domain_names = ['clipart', 'infograph', 'painting', 'quickdraw', 'real']
        self.root_dir = os.path.join('data', 'domainnet')
        self.data, self.targets = [], []
        self.id_to_label = {}
        self.transform = DomainNetTransform(image_size) if transform is None else transform

        split_suffix = '_train.txt' if train else '_test.txt'
        file_list_paths = [
            os.path.join(self.root_dir, domain_name + split_suffix)
            for domain_name in self.domain_names
        ]

        # Load file paths and extract labels.
        for file_list_path in file_list_paths:
            with open(file_list_path, 'r', encoding='utf-8') as f:
                for line_number, line in enumerate(f, start=1):
                    entry = line.strip()
                    if not entry:
                        # Tolerate blank lines (e.g. a trailing newline at EOF).
                        continue
                    try:
                        # Split on the LAST whitespace run so that image paths
                        # containing spaces stay intact.
                        image_path, label_text = entry.rsplit(maxsplit=1)
                        label = int(label_text)
                        # Paths look like "domain/class/image.jpg".
                        class_name = image_path.split('/')[1]
                    except (ValueError, IndexError) as err:
                        raise ValueError(
                            f"{file_list_path}:{line_number}: expected "
                            f"'domain/class/image label', got {entry!r}"
                        ) from err
                    self.id_to_label[class_name] = label
                    self.data.append(image_path)
                    self.targets.append(label)

        self.label_to_id = {label: class_name for class_name, label in self.id_to_label.items()}

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.targets)

    def __getitem__(self, idx: int) -> Tuple[Any, int]:
        """Load sample ``idx`` and return ``(image, label)``.

        The image is opened from ``root_dir``, converted to RGB, and passed
        through ``self.transform`` when that is truthy, so the type of the
        first element is whatever the transform returns (a PIL image when no
        transform is applied).
        """
        img_path, label = self.data[idx], self.targets[idx]
        # Construct the full path
        full_img_path = os.path.join(self.root_dir, img_path)
        # ``Image.open`` is lazy; the context manager guarantees the file is
        # released even if decoding fails. ``convert`` returns a new image
        # that stays valid after the source is closed.
        with Image.open(full_img_path) as source_image:
            image = source_image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label


