import os
from PIL import Image
from torch.utils.data import Dataset


class DomainNetDataset(Dataset):
    """Multi-domain image dataset driven by per-domain ``<split>.txt`` lists.

    For each name in ``domains`` this reads ``<root>/<domain>/<split>.txt``,
    where every non-blank line has the form ``<image_path> <integer_label>``.
    Relative image paths are resolved against ``<root>/<domain>``; absolute
    paths are used unchanged.

    Items are returned as ``(image, label, domain_id)``, where ``domain_id``
    is the index of the sample's domain within ``domains``.

    Attributes:
        samples: list of ``(image_path, label)`` tuples.
        targets: integer label of each sample, aligned with ``samples``.
        domain_ids: domain index of each sample, aligned with ``samples``.
        transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, root, domains, transform=None, split='train'):
        """Index every sample listed for ``split`` across ``domains``.

        Args:
            root: directory containing one sub-directory per domain.
            domains: iterable of domain names; its order defines ``domain_id``.
            transform: optional callable applied to each RGB PIL image.
            split: list-file stem; e.g. ``'train'`` reads ``train.txt``.

        Raises:
            FileNotFoundError: if a domain's list file does not exist.
            ValueError: if a non-blank line is not ``<path> <int label>``.
        """
        self.samples = []
        self.targets = []
        self.domain_ids = []
        self.transform = transform

        for domain_id, domain in enumerate(domains):
            domain_dir = os.path.join(root, domain)
            list_file = os.path.join(domain_dir, f"{split}.txt")

            if not os.path.exists(list_file):
                raise FileNotFoundError(f"{list_file} not exists")

            with open(list_file, "r", encoding="utf-8") as f:
                for lineno, line in enumerate(f, start=1):
                    fields = line.split()
                    # Tolerate blank lines (e.g. an empty trailing line at EOF)
                    # instead of failing on tuple unpacking.
                    if not fields:
                        continue
                    if len(fields) != 2:
                        raise ValueError(
                            f"{list_file}:{lineno}: expected '<path> <label>', "
                            f"got {line.strip()!r}"
                        )
                    path, label = fields[0], int(fields[1])

                    if not os.path.isabs(path):
                        path = os.path.join(domain_dir, path)

                    self.samples.append((path, label))
                    self.targets.append(label)
                    self.domain_ids.append(domain_id)

    def __len__(self):
        """Return the total number of indexed samples across all domains."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label, domain_id)``."""
        img_path, label = self.samples[idx]
        domain_id = self.domain_ids[idx]

        # The context manager releases the underlying file handle
        # deterministically; convert() returns a separate converted copy,
        # so `img` stays valid after the source image is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        return img, label, domain_id


class DomainNetDataset_10(Dataset):
    """Multi-domain image dataset restricted to ten fixed categories.

    For each name in ``domains`` this reads
    ``<root>/<domain>/<domain>_<split>.txt``, where every non-blank line is
    ``<image_path> <label>``. The label column in the file is ignored: the
    class name is taken from the image's parent directory name, samples whose
    class is not in ``CLASSES`` are dropped, and the label becomes that
    class's index in ``CLASSES``. Image paths are resolved against ``root``
    (not the domain directory).

    Items are returned as ``(image, label, domain_id)``, where ``domain_id``
    is the index of the sample's domain within ``domains``.

    Attributes:
        samples: list of image paths.
        targets: remapped integer label of each sample, aligned with ``samples``.
        domain_ids: domain index of each sample, aligned with ``samples``.
        classes: list of the retained class names, in label order.
        class_to_idx: mapping from class name to its label.
        transform: optional callable applied to the loaded RGB PIL image.
    """

    # Single source of truth for the retained classes; a class's label is its
    # position in this tuple.
    CLASSES = (
        'bird', 'feather', 'headphones', 'ice_cream', 'teapot',
        'tiger', 'whale', 'windmill', 'wine_glass', 'zebra',
    )

    def __init__(self, root, domains, transform=None, split="train"):
        """Index every sample of a retained class listed for ``split``.

        Args:
            root: dataset root; also the base for the paths in the list files.
            domains: iterable of domain names; its order defines ``domain_id``.
            transform: optional callable applied to each RGB PIL image.
            split: split name; e.g. ``'train'`` reads ``<domain>_train.txt``.

        Raises:
            FileNotFoundError: if a domain's list file does not exist.
            ValueError: if a non-blank line does not have exactly two fields.
        """
        self.samples = []
        self.targets = []
        self.domain_ids = []
        self.transform = transform
        self.classes = list(self.CLASSES)
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        for domain_id, domain in enumerate(domains):
            domain_dir = os.path.join(root, domain)
            txt_file = os.path.join(domain_dir, f"{domain}_{split}.txt")
            if not os.path.exists(txt_file):
                raise FileNotFoundError(f"Missing {txt_file}")

            with open(txt_file, "r", encoding="utf-8") as f:
                for lineno, line in enumerate(f, start=1):
                    fields = line.split()
                    # Tolerate blank lines (e.g. an empty trailing line at EOF)
                    # instead of failing on tuple unpacking.
                    if not fields:
                        continue
                    if len(fields) != 2:
                        raise ValueError(
                            f"{txt_file}:{lineno}: expected '<path> <label>', "
                            f"got {line.strip()!r}"
                        )
                    path = fields[0]
                    # The class is identified by the image's parent directory.
                    cls_name = os.path.basename(os.path.dirname(path))
                    label = self.class_to_idx.get(cls_name)
                    if label is None:
                        continue

                    self.samples.append(os.path.join(root, path))
                    self.targets.append(label)
                    self.domain_ids.append(domain_id)

    def __len__(self):
        """Return the total number of retained samples across all domains."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label, domain_id)``."""
        img_path = self.samples[idx]
        label = self.targets[idx]
        domain_id = self.domain_ids[idx]

        # The context manager releases the underlying file handle
        # deterministically; convert() returns a separate converted copy,
        # so `img` stays valid after the source image is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        return img, label, domain_id
