import os
import random
from PIL import Image
from torch.utils.data import Dataset


class OfficeCaltechDataset(Dataset):
    """Image dataset read from a ``root/<domain>/<class>/<image>`` tree.

    The class list is the sorted set of sub-directories found in the first
    entry of ``domains``; class directories missing from another domain are
    skipped for that domain. Only files whose names end in ``.jpg`` or
    ``.png`` (case-insensitive) are collected.

    When exactly four domains are given, each (domain, class) file list is
    sorted, shuffled with a fixed seed of 42, and divided so that the first
    5/6 of the files form the ``'train'`` split and the rest form the other
    split. With any other number of domains every image is loaded and
    ``split`` is ignored.
    NOTE(review): the four-domain rule presumably means "the complete
    Office-Caltech benchmark" -- confirm with the callers.

    Args:
        root: Directory that contains one sub-directory per domain.
        domains: Sequence of domain directory names; must be non-empty
            because the class list is read from ``domains[0]``.
        transform: Optional callable applied to the RGB ``PIL.Image``
            before it is returned by ``__getitem__``.
        split: ``'train'`` selects the first 5/6 of each shuffled file
            list; any other value selects the remainder. Only used when
            four domains are given.

    Attributes:
        samples: List of ``(image_path, label)`` tuples.
        targets: List of labels, parallel to ``samples``.
        domain_ids: Index into ``domains`` of each sample, parallel to
            ``samples``.
        classes: Sorted list of class names.
        class_to_idx: Mapping from class name to integer label.
    """

    # Filename suffixes (matched case-insensitively) that count as images.
    _IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, root, domains, transform=None, split='train'):
        super().__init__()
        self.samples = []
        self.targets = []
        self.domain_ids = []
        self.transform = transform

        first_domain_dir = os.path.join(root, domains[0])
        all_classes = sorted(
            d for d in os.listdir(first_domain_dir)
            if os.path.isdir(os.path.join(first_domain_dir, d))
        )
        self.classes = all_classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(all_classes)}

        # A private generator yields the same shuffles as seeding the global
        # ``random`` module with 42, without clobbering the global RNG state
        # that other code in the process may depend on.
        rng = random.Random(42)
        use_split = len(domains) == 4

        for domain_id, domain in enumerate(domains):
            domain_dir = os.path.join(root, domain)
            for cls in all_classes:
                cls_dir = os.path.join(domain_dir, cls)
                if not os.path.isdir(cls_dir):
                    continue

                image_list = self._list_images(cls_dir)
                if use_split:
                    rng.shuffle(image_list)
                    # First 5/6 of the shuffled list is train, rest is held out.
                    n_train = int(len(image_list) * 5 / 6)
                    if split == 'train':
                        image_list = image_list[:n_train]
                    else:
                        image_list = image_list[n_train:]

                label = self.class_to_idx[cls]
                for img_path in image_list:
                    self.samples.append((img_path, label))
                    self.targets.append(label)
                    self.domain_ids.append(domain_id)

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted full paths of the image files in ``directory``."""
        # os.listdir returns names in arbitrary order; sorting makes both the
        # sample order and the seeded shuffle reproducible across machines.
        return sorted(
            os.path.join(directory, fname)
            for fname in os.listdir(directory)
            if fname.lower().endswith(cls._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label, domain_id)``.

        The image is opened from disk, converted to RGB, and passed through
        ``self.transform`` when one was supplied.
        """
        img_path, label = self.samples[idx]
        domain_id = self.domain_ids[idx]

        # convert() loads the pixel data into a new image, so the underlying
        # file can be closed deterministically as soon as it returns.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        return img, label, domain_id

