import torch
from torch.utils.data import DataLoader, Dataset, Sampler
import numpy as np


class DualSampler(Sampler):
    """Sampler mixing weighted ("tail") draws with uniform ("head") draws.

    Every epoch yields exactly ``len(weights)`` dataset indices:

    * ``int(fraction_tail * len(weights))`` of them are drawn *with
      replacement* with probability proportional to ``weights``;
    * the remaining ones are drawn uniformly, *with replacement*, from the
      indices that were NOT picked by the weighted draw.

    The two groups are concatenated and shuffled before being yielded.

    Args:
        weights: Non-negative per-index weights (list, NumPy array or
            tensor); they need not be normalized. When the weighted draw is
            non-empty they must contain at least one positive entry,
            otherwise ``torch.multinomial`` raises.
        fraction_tail: Share of each epoch drawn according to ``weights``.
            Must lie in ``[0, 1]``.
        generator: Optional ``torch.Generator`` used for all random draws,
            allowing reproducible epochs. ``None`` uses torch's global RNG.

    Raises:
        ValueError: If ``fraction_tail`` is outside ``[0, 1]``.
    """

    def __init__(self, weights, fraction_tail=0.5, generator=None):
        if not 0.0 <= fraction_tail <= 1.0:
            # Outside this range the sampler would yield a number of indices
            # different from ``__len__``.
            raise ValueError(
                f"fraction_tail must be in [0, 1], got {fraction_tail!r}"
            )
        self.indices = np.arange(len(weights))
        self.num_samples = len(self.indices)
        self.fraction = fraction_tail
        self.weights = weights
        self.generator = generator

    def __iter__(self):
        """Yield ``len(self)`` indices (Python ints) for one epoch."""
        tail_num_samples = int(self.fraction * self.num_samples)
        head_num_samples = self.num_samples - tail_num_samples

        # torch.multinomial needs a floating-point tensor. as_tensor also
        # accepts lists, integer arrays and tensors; torch.from_numpy only
        # accepted ndarrays. Keep it on CPU so it can index the CPU mask.
        weights = torch.as_tensor(self.weights, dtype=torch.float64, device="cpu")

        # Weighted ("tail") draw, with replacement.
        if tail_num_samples > 0:
            tail_indices = torch.multinomial(
                weights, tail_num_samples, replacement=True, generator=self.generator
            )
        else:
            tail_indices = torch.empty(0, dtype=torch.long)

        # Exclude every index already chosen by the weighted draw. At most
        # tail_num_samples distinct indices were drawn, so at least
        # head_num_samples (>= 1 whenever the head draw runs) remain eligible
        # and the mask is never all-zero below.
        head_mask = torch.ones(self.num_samples, dtype=torch.bool)
        head_mask[tail_indices] = False

        # Uniform ("head") draw over the remaining indices, with replacement.
        if head_num_samples > 0:
            head_indices = torch.multinomial(
                head_mask.float(), head_num_samples, replacement=True,
                generator=self.generator,
            )
        else:
            head_indices = torch.empty(0, dtype=torch.long)

        combined = torch.cat([tail_indices, head_indices])
        combined = combined[torch.randperm(len(combined), generator=self.generator)]

        yield from combined.tolist()

    def __len__(self):
        """Return the number of indices yielded per epoch (``len(weights)``)."""
        return self.num_samples