import numpy as np

import torch
import torchvision.datasets as datasets

# Seed NumPy's global RNG once at import time. The np.random calls made by
# the dataset classes in this module draw from this global stream, so the
# seed fixes their results for a given sequence of calls. Note that this is
# an import-time side effect on the process-wide NumPy RNG.
np.random.seed(12345)


class MNISTRandomLabels(datasets.MNIST):
  """MNIST dataset, with support for randomly corrupted labels.

  Params
  ------
  corrupt_prob: float
    Default 0.0. The probability of a label being replaced with a
    random label. The replacement is drawn from all ``num_classes``
    classes, so it may coincide with the original label.
  num_classes: int
    Default 10. The number of classes in the dataset.
  **kwargs
    Forwarded unchanged to the ``torchvision.datasets.MNIST`` constructor.
  """
  def __init__(self, corrupt_prob=0.0, num_classes=10, **kwargs):
    super().__init__(**kwargs)
    self.n_classes = num_classes
    if corrupt_prob > 0:
      self.corrupt_labels(corrupt_prob)

  def corrupt_labels(self, corrupt_prob):
    """Replace each label with a random class with probability ``corrupt_prob``.

    Re-seeds NumPy's global RNG with 12345 so that the same labels are
    corrupted on every call. Afterwards ``self.targets`` is a plain list
    of Python ints (for either split).
    """
    # ``targets`` holds the labels for both the train and the test split;
    # the old ``test_labels`` attribute is only a deprecated alias of it.
    labels = np.array(self.targets)
    np.random.seed(12345)
    mask = np.random.rand(len(labels)) <= corrupt_prob
    labels[mask] = np.random.choice(self.n_classes, mask.sum())
    # we need to explicitly cast the labels from np.int64 to the
    # builtin int type, otherwise pytorch will fail...
    self.targets = [int(x) for x in labels]



class MNISTSubset_random(datasets.MNIST):
  """MNIST dataset, with support for random selection of a subset.

  Params
  ------
  num_samples_per_class: int
    Default 0. The nominal number of samples for each class. Only the
    total (num_samples_per_class x num_classes) is enforced; the class
    proportions of the subset are whatever the random draw yields.
    A value of 0 keeps the full dataset.
  num_classes: int
    Default 10. The number of classes in the dataset.
  **kwargs
    Forwarded unchanged to the ``torchvision.datasets.MNIST`` constructor.
  ------
  Return:
  An unbalanced subset (num_samples_per_class x num_classes samples) of
  MNIST randomly sampled without replacement. Only the training split is
  subsampled; the test split is left untouched.
  """
  def __init__(self, num_samples_per_class=0.0, num_classes=10, **kwargs):
    super().__init__(**kwargs)
    self.num_classes = num_classes
    self.num_samples_per_class = num_samples_per_class

    if self.num_samples_per_class > 0:
      self.get_subset()

  def get_subset(self):
    """Keep a random subset of the training samples, in their original order.

    Draws from NumPy's global RNG. On the test split nothing is modified.
    """
    # Slice bounds must be integers; the constructor default (0.0) and any
    # float argument would otherwise make the slice below raise TypeError.
    total_samples = int(self.num_classes * self.num_samples_per_class)

    # The permutation is drawn before the split check so the global RNG
    # stream advances the same way for the train and the test split.
    rand_loc = np.random.permutation(len(self.targets))

    if not self.train:
      return

    # Sort so the retained samples keep their original relative order.
    final_loc = np.sort(rand_loc[:total_samples])
    self.targets = self.targets[final_loc]
    self.data = self.data[final_loc]









class MNISTSubset_balanced(datasets.MNIST):
  """MNIST dataset, with support for random selection of a balanced subset
  (equal number of samples per class).

  Params
  ------
  num_samples_per_class: int
    Default 0. The number of samples for each class. A class with fewer
    samples than requested contributes all of its samples. A value of 0
    keeps the full dataset.
  num_classes: int
    Default 10. The number of classes in the dataset; labels are expected
    to be the integers 0 .. num_classes - 1.
  **kwargs
    Forwarded unchanged to the ``torchvision.datasets.MNIST`` constructor.
  ------
  Return:
  A balanced subset (num_samples_per_class x num_classes samples) of MNIST
  randomly sampled without replacement. Only the training split is
  subsampled; the test split is left untouched.
  """
  def __init__(self, num_samples_per_class=0.0, num_classes=10, **kwargs):
    super().__init__(**kwargs)
    self.num_classes = num_classes
    self.num_samples_per_class = num_samples_per_class

    if self.num_samples_per_class > 0:
      self.get_subset()

  def get_subset(self):
    """Keep ``num_samples_per_class`` random training samples of each class.

    Draws from NumPy's global RNG. The retained samples keep their original
    relative order. On the test split nothing is modified.
    """
    # ``targets`` holds the labels for both splits; the old ``test_labels``
    # attribute is only a deprecated alias of it.
    labels = np.array(self.targets)
    # Slice bounds must be integers (the constructor default is 0.0).
    per_class = int(self.num_samples_per_class)

    # For every class: locate its samples, shuffle them, keep the first few.
    # Drawn before the split check so the global RNG stream advances the
    # same way for the train and the test split.
    picked = [
      np.random.permutation(np.where(labels == cls)[0])[:per_class]
      for cls in range(self.num_classes)
    ]

    if not self.train:
      return

    # Concatenate rather than stack: classes with fewer samples than
    # requested yield shorter index arrays, which cannot form a 2-D array.
    if picked:
      final_loc = np.sort(np.concatenate(picked))
    else:
      final_loc = np.empty(0, dtype=np.int64)

    self.targets = [self.targets[i] for i in final_loc]
    self.data = self.data[final_loc]












        
