from torch.utils.data import Sampler
import random


class SamplerA(Sampler):
    """Sample a fixed, randomly chosen half of a dataset's indices.

    On construction the indices ``0 .. len(data_source) - 1`` are shuffled.
    The first ``len(data_source) // 2`` of them are kept; odd lengths round
    down. Every iteration yields that same subset in the same shuffled order.
    The kept indices are exposed as ``samplerA_indices`` so a complementary
    sampler can exclude them. The full shuffled list is kept as ``indices``.

    Args:
        data_source: Any sized object; only ``len(data_source)`` is used here.
        seed: Optional seed for a private ``random.Random`` instance. When
            given, the split is reproducible and the global ``random`` state
            is left untouched. When ``None`` (the default), the module-level
            ``random.shuffle`` is used, exactly as before.
    """

    def __init__(self, data_source, *, seed=None):
        self.data_source = data_source
        self.num_samples = len(data_source) // 2
        self.indices = list(range(len(data_source)))
        # The default path stays on the global RNG. Callers that seed
        # ``random`` globally therefore get the same split as before.
        shuffle = random.shuffle if seed is None else random.Random(seed).shuffle
        shuffle(self.indices)
        self.samplerA_indices = self.indices[:self.num_samples]

    def __iter__(self):
        """Return an iterator over the selected indices in stored order."""
        return iter(self.samplerA_indices)

    def __len__(self):
        """Return the number of selected indices (``len(data_source) // 2``)."""
        return self.num_samples

class SamplerB(Sampler):
    """Sample every dataset index that is *not* in a given exclusion collection.

    Pass the indices chosen by another sampler (e.g. ``samplerA_indices``) to
    get the complementary indices. Indices are produced in ascending order,
    and the same order is used on every iteration.

    Args:
        data_source: Any sized object; only ``len(data_source)`` is used here.
        samplerA_indices: Iterable of indices to exclude. It is read exactly
            once, so a one-shot iterator or generator is accepted.
    """

    def __init__(self, data_source, samplerA_indices):
        self.data_source = data_source
        # Build the exclusion set once. Testing membership against a list
        # cost O(len(samplerA_indices)) per index, which made this quadratic.
        # NOTE(review): this assumes the indices are hashable plain ints, as
        # SamplerA produces. Tensor elements hash by identity and would not
        # match here.
        excluded = set(samplerA_indices)
        self.samplerB_indices = [
            idx for idx in range(len(data_source)) if idx not in excluded
        ]

    def __iter__(self):
        """Return an iterator over the non-excluded indices in ascending order."""
        return iter(self.samplerB_indices)

    def __len__(self):
        """Return the number of non-excluded indices."""
        return len(self.samplerB_indices)
