import os
import os.path as osp
import numpy as np
import random
import collections
import torch
import torchvision
from torch.utils import data
from PIL import Image
from torch.utils.data import Sampler
from collections import defaultdict
class DomainNetDataset(data.Dataset):
    """Image dataset driven by a plain-text list file.

    The list file is read from ``<root>/image_list/<image_list>``. Every
    non-blank line must contain at least three whitespace-separated fields::

        <image path relative to root> <class label> <domain label>

    Both labels are parsed with ``int``. Blank lines are skipped.

    Attributes:
        root: Directory that image paths and the ``image_list`` folder are
            resolved against.
        image_list: File name of the list inside ``<root>/image_list``.
        transform: Optional callable applied to the loaded RGB image.
        img_ids: Image paths (first field of each line).
        img_labels: Integer class labels (second field of each line).
        domain_labels: Integer domain labels (third field of each line).
        num_classes: Number of distinct values in ``img_labels``.
    """

    def __init__(self, root, image_list='', transform=None):
        """Parse the list file and store the per-sample metadata.

        Args:
            root: Dataset root directory.
            image_list: Name of the list file under ``<root>/image_list``.
            transform: Optional callable applied to each image in
                ``__getitem__``.

        Raises:
            OSError: If the list file cannot be opened.
            IndexError: If a non-blank line has fewer than three fields.
            ValueError: If a label field is not an integer.
        """
        self.root = root
        self.image_list = image_list
        self.transform = transform

        list_path = osp.join(self.root, 'image_list', self.image_list)
        print(list_path)

        self.img_ids = []
        self.img_labels = []
        self.domain_labels = []

        # Read the file a single time inside a context manager so the handle
        # is closed deterministically instead of being left to the GC.
        with open(list_path) as list_file:
            for line in list_file:
                # split() with no argument tolerates tabs / repeated spaces.
                fields = line.split()
                if not fields:
                    # Blank (often trailing) line: nothing to parse.
                    continue
                self.img_ids.append(fields[0])
                self.img_labels.append(int(fields[1]))
                self.domain_labels.append(int(fields[2]))

        self.num_classes = len(set(self.img_labels))
        print('Number of classes: ', self.num_classes)

    def __len__(self):
        """Return the number of samples listed in the list file."""
        return len(self.img_ids)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position of the sample in the list file order.

        Returns:
            A ``(image, label, domain_label)`` tuple. ``image`` is the RGB
            image, passed through ``transform`` when one was given.
        """
        name = self.img_ids[index]

        # convert() returns a new, fully loaded image, so the source file can
        # be closed as soon as the conversion is done.
        with Image.open(osp.join(self.root, name)) as source:
            image = source.convert('RGB')

        label = self.img_labels[index]
        domain_label = self.domain_labels[index]

        if self.transform:
            image = self.transform(image)

        return image, label, domain_label



class StratifiedDomainBatchSampler(Sampler):
    """Batch sampler that takes an equal share of every batch from each domain.

    Sample indices are grouped by domain label. Each batch is built from
    ``batch_size // num_domains`` indices of every domain, so all domains are
    equally represented. The number of batches is limited by the domain that
    can fill the fewest per-domain slices; leftover indices are dropped.

    Attributes:
        domain_labels: ``np.ndarray`` built from the labels passed in.
        batch_size: Total number of indices per batch.
        shuffle: Whether indices are shuffled (see ``__iter__``).
        domain_to_indices: Mapping from domain label to the list of sample
            indices carrying that label, in first-seen domain order.
        num_domains: Number of distinct domain labels.
        batch_size_per_domain: Indices drawn from each domain per batch.
        batches: The most recently built list of batches.
    """

    def __init__(self, domain_labels, batch_size, shuffle=True):
        """Group indices by domain, validate the batch size, build batches.

        Args:
            domain_labels: One domain label per sample, in dataset order.
            batch_size: Total batch size; must be a positive multiple of the
                number of distinct domains.
            shuffle: If true, shuffle within each domain and within each
                batch, and rebuild the batches on every ``__iter__`` call.

        Raises:
            ValueError: If ``domain_labels`` is empty or ``batch_size`` is not
                positive.
            AssertionError: If ``batch_size`` is not divisible by the number
                of domains.
        """
        self.domain_labels = np.array(domain_labels)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.domain_to_indices = defaultdict(list)

        for idx, domain in enumerate(self.domain_labels):
            self.domain_to_indices[domain].append(idx)

        self.num_domains = len(self.domain_to_indices)

        # Validate BEFORE any arithmetic on these values: the diagnostics
        # below divide by num_domains and by the per-domain share, which would
        # otherwise fail with an unhelpful ZeroDivisionError.
        if self.num_domains == 0:
            raise ValueError("domain_labels must contain at least one sample")
        if batch_size <= 0:
            raise ValueError(f"batch_size ({batch_size}) must be a positive integer")
        if batch_size % self.num_domains != 0:
            # Raised explicitly (rather than via `assert`) so the check is not
            # stripped under `python -O`; the exception type is kept for
            # backward compatibility.
            raise AssertionError(
                f"batch_size ({batch_size}) must be divisible by number of domains ({self.num_domains})"
            )

        self.batch_size_per_domain = batch_size // self.num_domains

        print(f"batch_size: {batch_size}, num_domains: {self.num_domains},can support {self.batch_size_per_domain} samples per domain")
        for domain, indices in self.domain_to_indices.items():
            print(f"Domain {domain}: {len(indices)} samples,can support {len(indices) // self.batch_size_per_domain} batches")

        self.batches = self._create_batches()

    def _create_batches(self):
        """Build the list of batches from the per-domain index lists.

        Returns:
            A list of batches; each batch is a list of sample indices holding
            ``batch_size_per_domain`` indices from every domain.
        """
        per_domain = self.batch_size_per_domain

        # Work on copies so the stored index lists are never reordered.
        domain_indices = {
            d: (random.sample(idxs, len(idxs)) if self.shuffle else idxs[:])
            for d, idxs in self.domain_to_indices.items()
        }

        # The smallest domain bounds how many full batches can be formed.
        min_batches = min(len(idxs) // per_domain for idxs in domain_indices.values())

        batches = []
        for i in range(min_batches):
            start = i * per_domain
            end = start + per_domain
            batch = []
            for idxs in domain_indices.values():
                batch.extend(idxs[start:end])
            if self.shuffle:
                # Mix domains inside the batch instead of leaving them in
                # contiguous per-domain runs.
                random.shuffle(batch)
            batches.append(batch)
        return batches

    def __iter__(self):
        """Iterate over batches, rebuilding them first when shuffling."""
        if self.shuffle:
            self.batches = self._create_batches()
        return iter(self.batches)

    def __len__(self):
        """Return the number of batches produced per iteration."""
        return len(self.batches)