import numpy as np

import torch
from torch.utils.data import SubsetRandomSampler, Sampler


class MislabelDataset:
    """Dataset wrapper that replaces the labels of a random subset of samples.

    At construction time, ``int(len(dataset) * mislabel_ratio)`` distinct
    sample indices are chosen. Each chosen index is assigned a replacement
    label drawn uniformly from ``[0, num_class)``. ``__getitem__`` returns the
    wrapped sample with that label substituted for the original target.

    Note: the replacement label is drawn from *all* classes, so it can
    coincide with the original label. The effective noise rate may
    therefore be lower than ``mislabel_ratio``.

    Args:
        dataset: Indexable, sized dataset whose items are ``(img, target)`` or
            ``(img, target, index)`` tuples.
        num_class: Number of classes; replacement labels lie in
            ``[0, num_class)``.
        mislabel_ratio: Fraction of samples to relabel (the count is
            truncated to an int).
        mislabel_seed: Seed for the private RNG that picks indices and labels.
    """

    def __init__(self, dataset, num_class, mislabel_ratio=0.1, mislabel_seed=0):
        self.dataset = dataset
        self.num_class = num_class
        self.mislabel_ratio = mislabel_ratio
        self.seed = mislabel_seed
        self.mislabeled_indices = []
        self.mislabeled_targets = []
        self._generate_mislabeled_data()
        # Build the index -> replacement-label map once, so __getitem__ does
        # an O(1) lookup instead of scanning the index array per item.
        # Built here (not in _generate_mislabeled_data) so subclasses that
        # override the generator still get a consistent map.
        self._mislabel_map = dict(
            zip(
                np.asarray(self.mislabeled_indices).tolist(),
                np.asarray(self.mislabeled_targets).tolist(),
            )
        )

    def _generate_mislabeled_data(self):
        """Choose which indices to relabel and their replacement labels.

        Uses a private legacy ``RandomState`` seeded with ``self.seed`` rather
        than reseeding NumPy's global RNG. Constructing the dataset therefore
        no longer clobbers global random state. The draws are the same ones
        the previous ``np.random.seed`` + ``np.random.*`` sequence produced.
        """
        rng = np.random.RandomState(self.seed)
        num_samples = len(self.dataset)
        num_mislabeled = int(num_samples * self.mislabel_ratio)
        self.mislabeled_indices = rng.choice(num_samples, num_mislabeled, replace=False)
        self.mislabeled_targets = rng.randint(0, self.num_class, size=num_mislabeled)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the wrapped sample, with its label replaced if selected.

        Returns ``(img, label)`` or ``(img, label, index)``, matching the
        arity of the wrapped item. When the wrapped item carries its own
        index, that index (not the one passed in) decides whether the sample
        is mislabeled. Replacement labels are Python ``int``s; samples that
        are not mislabeled keep the wrapped dataset's original target object.
        """
        data = self.dataset[index]
        if len(data) == 3:
            img, target, index = data
        else:
            img, target = data
        # int() normalizes NumPy / 0-d tensor indices to a hashable dict key.
        false_target = self._mislabel_map.get(int(index), target)
        if len(data) == 3:
            return img, false_target, index
        return img, false_target


class SubsetSampler(Sampler):
    """Sampler that yields a fixed sequence of indices in their stored order.

    This is the deterministic counterpart of ``SubsetRandomSampler``: no
    shuffling is performed, and every iteration yields ``indices`` front to
    back.

    Args:
        indices: Sequence of dataset indices to yield.
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # Iterate the sequence directly rather than via range(len(...)) indexing.
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)