import torch
import numpy as np


class Subset(torch.utils.data.Dataset):
    """View of ``dataset`` restricted to ``indices`` with an optional transform.

    Args:
        dataset: Indexable dataset exposing a ``targets`` sequence and
            returning ``(sample, target)`` pairs.
        indices: Positions in ``dataset`` that make up this subset.
        transform: Callable applied to each sample on access; skipped when
            falsy (e.g. ``None``).

    Attributes:
        targets: NumPy array with the labels of the selected samples.
    """

    def __init__(self, dataset, indices, transform):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform
        all_labels = np.array(dataset.targets)
        self.targets = all_labels[self.indices]

    def __getitem__(self, idx):
        sample, label = self.dataset[self.indices[idx]]
        # Only the sample is transformed; the label is passed through as-is.
        return (self.transform(sample) if self.transform else sample), label

    def __len__(self):
        return len(self.indices)


class SingleClassSubset(torch.utils.data.Dataset):
    """All samples of ``dataset`` whose label equals ``target_class``.

    Args:
        dataset: Indexable dataset exposing a ``targets`` sequence and
            returning ``(sample, target)`` pairs.
        target_class: Label value to keep.

    Attributes:
        indices: NumPy array of matching positions in ``dataset``.
        targets: NumPy array of the labels at ``indices``; every entry
            equals ``target_class``.
    """

    def __init__(self, dataset, target_class):
        self.dataset = dataset
        labels = np.array(dataset.targets)
        self.indices = np.where(labels == target_class)[0]
        self.targets = labels[self.indices]
        self.target_class = target_class

    def __getitem__(self, idx):
        image, label = self.dataset[self.indices[idx]]
        return image, label

    def __len__(self):
        return len(self.indices)


class ClassSubset(torch.utils.data.Dataset):
    """Samples of ``dataset`` whose label is in ``target_classes``, relabeled.

    Labels returned by ``__getitem__`` are replaced by the position of the
    original label in ``target_classes`` (so they run ``0..k-1``).

    Args:
        dataset: Indexable dataset exposing a ``targets`` sequence and
            returning ``(sample, target)`` pairs.
        target_classes: Ordered collection of original labels to keep.

    Attributes:
        indices: NumPy array of selected positions in ``dataset``.
        targets: NumPy array of the ORIGINAL (unmapped) labels at
            ``indices``. NOTE(review): this differs from the remapped labels
            that ``__getitem__`` returns.
        target_mapping: Dict from original label to its new index.
    """

    def __init__(self, dataset, target_classes):
        self.dataset = dataset
        labels = np.array(dataset.targets)
        keep = np.isin(labels, np.array(target_classes))
        self.indices = np.where(keep)[0]
        self.targets = labels[self.indices]
        self.target_classes = target_classes
        mapping = {}
        for position, cls in enumerate(target_classes):
            # A repeated class keeps its last position.
            mapping[cls] = position
        self.target_mapping = mapping

    def __getitem__(self, idx):
        image, original_label = self.dataset[self.indices[idx]]
        return image, self.target_mapping[original_label]

    def __len__(self):
        return len(self.indices)


class FixedSampleSizeSubset(torch.utils.data.Dataset):
    """Subset holding exactly ``samples_per_class`` samples of every class.

    For each distinct label in ``dataset.targets`` (in sorted order),
    ``samples_per_class`` indices are drawn without replacement. A class with
    fewer samples than requested is drawn WITH replacement instead, so its
    entries may repeat.

    Args:
        dataset: Indexable dataset exposing a ``targets`` sequence and
            returning ``(sample, target)`` pairs.
        samples_per_class: Number of samples drawn per class.
        seed: Seed of a private ``np.random.RandomState`` used for drawing.
            The global NumPy RNG state is left untouched.

    Attributes:
        indices: List of selected positions in ``dataset``, grouped by class.
        targets: NumPy array of the labels at ``indices``.
        transform: ``dataset.transform`` when the dataset has one, else
            ``None``. It is only mirrored here; this class never applies it.
    """

    def __init__(self, dataset, samples_per_class, seed=0):
        self.dataset = dataset
        # getattr lets this wrap datasets without a ``transform`` attribute
        # (e.g. ClassSubset / SingleClassSubset).
        self.transform = getattr(dataset, 'transform', None)
        self.samples_per_class = samples_per_class
        self.seed = seed
        self._generate_indices()

    def _generate_indices(self):
        """Populate ``self.indices`` and ``self.targets`` deterministically."""
        # A private RandomState produces the same MT19937 stream that
        # ``np.random.seed(self.seed)`` would, so results match the previous
        # implementation without mutating process-wide RNG state.
        rng = np.random.RandomState(self.seed)
        # Convert once instead of once per class.
        labels = np.array(self.dataset.targets)
        indices = []
        for c in np.unique(labels):
            class_indices = np.where(labels == c)[0]
            # Sample with replacement only when the class is too small to
            # provide ``samples_per_class`` distinct examples.
            replace = len(class_indices) < self.samples_per_class
            indices.extend(
                rng.choice(class_indices,
                           self.samples_per_class,
                           replace=replace))
        self.indices = indices
        self.targets = labels[indices]

    def __getitem__(self, idx):
        im, target = self.dataset[self.indices[idx]]
        return im, target

    def __len__(self):
        return len(self.indices)
