import itertools
import random
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Sampler

class BalancedBatchSampler(Sampler):
    """Sampler that yields class-balanced batches of sample indices.

    Each iteration step yields a whole batch (a list of integer indices into
    ``labels``), not a single index. Every batch contains exactly
    ``batch_size // n_classes`` indices from each distinct label, so the
    effective batch size (``actual_batch_size``) may be smaller than the
    requested ``batch_size``.

    Within one pass no index is yielded twice. The number of batches per pass
    is limited by the smallest class; surplus indices of larger classes are
    simply not visited in that pass (a fresh permutation is drawn on every
    call to ``__iter__``).

    Args:
        labels: Sequence of per-sample labels; position ``i`` is the label of
            sample ``i``. Converted with ``np.array``.
        batch_size: Requested batch size; must be at least the number of
            distinct labels.

    Raises:
        ValueError: If ``labels`` is empty.
        AssertionError: If ``batch_size`` is smaller than the number of
            distinct labels.
    """

    def __init__(self, labels, batch_size):
        self.labels = np.array(labels)
        self.batch_size = batch_size
        self.classes = sorted(set(self.labels))
        self.n_classes = len(self.classes)

        # Without this guard an empty label list surfaces as a bare
        # ZeroDivisionError from the floor division below.
        if self.n_classes == 0:
            raise ValueError("labels must contain at least one sample")

        self.samples_per_class = batch_size // self.n_classes
        # Raised explicitly (instead of via an `assert` statement) so the check
        # is not stripped under `python -O`; the exception type and message
        # are kept as AssertionError for callers that already rely on them.
        if self.samples_per_class <= 0:
            raise AssertionError(
                f"Batch size {batch_size} too small for {self.n_classes} classes"
            )

        # Size of the batches actually produced (a multiple of n_classes).
        self.actual_batch_size = self.samples_per_class * self.n_classes

        # Map each label to the positions at which it occurs in `labels`.
        self.class_indices = defaultdict(list)
        for i, label in enumerate(self.labels):
            self.class_indices[label].append(i)

        # The rarest class determines how many full balanced batches fit.
        min_samples = min(len(indices) for indices in self.class_indices.values())
        self.n_batches = min_samples // self.samples_per_class

    def __iter__(self):
        """Yield ``n_batches`` shuffled, class-balanced lists of indices."""
        # Fresh per-class permutation for every pass; plain lists so that
        # entries can be consumed with pop() and come out as Python ints.
        class_indices = {
            label: np.random.permutation(indices).tolist()
            for label, indices in self.class_indices.items()
        }

        for _ in range(self.n_batches):
            batch = []
            for label in self.classes:
                # n_batches guarantees every class still has enough entries.
                pool = class_indices[label]
                batch.extend(pool.pop() for _ in range(self.samples_per_class))

            # Mix the classes so a batch is not grouped by label.
            np.random.shuffle(batch)
            yield batch

    def __len__(self):
        """Return the number of batches produced per pass."""
        return self.n_batches
    

class BalancedProxyDataset(Dataset):
    """Dataset drawn from a large labelled pool, restricted to reference labels.

    Only pool samples whose label occurs in ``reference_labels`` are kept.
    For each reference label (in sorted order) the first ``samples_per_class``
    matching images, in pool order, are selected. A class with fewer matching
    images contributes all it has, so the result is only strictly balanced
    when every class has at least ``samples_per_class`` images or when
    ``samples_per_class`` is ``None``.

    Args:
        reference_labels: Iterable of labels to keep; duplicates are ignored.
        large_imgs: Iterable of pool images, paired positionally with
            ``large_labels``.
        large_labels: Iterable of pool labels.
        samples_per_class: Maximum number of images taken per class. ``None``
            means "size of the smallest reference class in the pool".
        transform: Optional callable applied to the image in ``__getitem__``.
    """

    def __init__(self, reference_labels, large_imgs, large_labels, samples_per_class=8000, transform=None):
        self.reference_labels = sorted(set(reference_labels))
        self.transform = transform

        # Membership is tested once per pool sample; a set keeps that O(1)
        # instead of scanning the sorted label list every time.
        wanted_labels = set(self.reference_labels)

        # Group the pool images by label, preserving pool order.
        self.class_data = defaultdict(list)
        for img, label in zip(large_imgs, large_labels):
            if label in wanted_labels:
                self.class_data[label].append(img)

        if samples_per_class is None:
            # Use the rarest reference class as the per-class quota.
            samples_per_class = min(len(self.class_data[label]) for label in self.reference_labels)

        # Flat list of (img, label) pairs, grouped by label in sorted order.
        self.samples = [
            (img, label)
            for label in self.reference_labels
            for img in self.class_data[label][:samples_per_class]
        ]

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(img, label)`` for ``idx``, applying ``transform`` if set."""
        img, label = self.samples[idx]

        if self.transform:
            img = self.transform(img)

        return img, label

    def get_labels(self):
        """Return the labels of all selected samples, in dataset order."""
        return [label for _, label in self.samples]


class CustomDataset(Dataset):
    """Pair an indexable collection of images with a parallel label collection.

    Args:
        img: Indexable, sized collection of images.
        labels: Indexable collection of labels, aligned with ``img`` by index.
        transform: Optional callable applied to the image on access.
    """

    def __init__(self, img, labels, transform=None):
        self.transform = transform
        self.labels = labels
        self.img = img

    def __len__(self):
        """Return the number of images."""
        return len(self.img)

    def __getitem__(self, idx):
        """Return ``(img, label)`` at ``idx``, transforming the image if requested."""
        sample, target = self.img[idx], self.labels[idx]

        # No (truthy) transform configured: hand back the raw image.
        if not self.transform:
            return sample, target

        return self.transform(sample), target

