import numpy as np
import torch
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from data.pytorch_datasets import get_dataset
import argparse

# Command-line interface: -d/--data is a string option that defaults to 'mnist'.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-d',
    '--data',
    type=str,
    default='mnist',
)

class RandomSubset:
    """
        Select a random subset of size K from the dataset.

        Randomness comes from NumPy's global generator (``np.random``), so
        call ``np.random.seed`` beforehand for reproducible draws.

        Args:
            dataset: Object supporting ``len()`` and integer indexing.
    """
    def __init__(self, dataset):
        self.dataset = dataset

    def get_sample(self, K):
        """
            Draw ``min(K, len(dataset))`` distinct items uniformly at random.

            Args:
                K: Requested sample size; clamped to the dataset length.

            Returns:
                A ``Subset`` of the dataset indexed by a list of Python ints.

            Raises:
                ValueError: If ``K`` is negative.
        """
        if K < 0:
            raise ValueError(f"K must be non-negative, got {K}")
        idx_list = np.arange(len(self.dataset))
        choice = np.random.choice(idx_list, min(K, len(self.dataset)), replace=False).tolist()
        return Subset(self.dataset, choice)

    def split_ds(self, K):
        """
            Randomly partition the dataset into two disjoint subsets.

            Args:
                K: Size of the first subset; if it exceeds the dataset length
                    the first subset holds everything and the second is empty.

            Returns:
                Tuple ``(attack_set, rem_set)``: the first ``K`` shuffled
                indices and the remaining ones, each wrapped in a ``Subset``.

            Raises:
                ValueError: If ``K`` is negative (a negative slice bound would
                    otherwise silently produce a wrong-sized split).
        """
        if K < 0:
            raise ValueError(f"K must be non-negative, got {K}")
        idx_list = np.arange(len(self.dataset))
        np.random.shuffle(idx_list)
        # Convert to plain Python ints so both methods hand Subset the same
        # index type (get_sample already does this via .tolist()).
        attack_idxs = idx_list[:K].tolist()
        rem_idxs = idx_list[K:].tolist()
        attack_set = Subset(self.dataset, attack_idxs)
        rem_set = Subset(self.dataset, rem_idxs)
        return (attack_set, rem_set)


class ClassRandomSubset:
    """
        Select a random subset of size K from the dataset so that all points have the same class.

        Randomness comes from NumPy's global generator (``np.random``).

        Args:
            dataset: Object with a ``targets`` attribute (one label per item)
                and a ``class_to_idx`` mapping, supporting integer indexing.
            clss: Position of the wanted class within
                ``dataset.class_to_idx.values()``. When those values are
                ``0..C-1`` in order this is simply the class label.
    """
    def __init__(self, dataset, clss) -> None:
        self.dataset = dataset
        self.clss = clss
        self.get_class_dataset()

    def get_class_dataset(self):
        """
            Build ``self.cls_dataset``: the ``Subset`` of items whose target
            equals the selected class.

            Raises:
                IndexError: If ``clss`` is not a valid position in
                    ``dataset.class_to_idx.values()``.
        """
        # Resolve the class label positionally, as before, but only build the
        # index list for that one class instead of for every class.
        class_labels = list(self.dataset.class_to_idx.values())
        label = class_labels[self.clss]
        # as_tensor lets `targets` be a tensor, a list or an ndarray; a plain
        # list compared with `==` would yield a bare bool and break torch.where.
        targets = torch.as_tensor(self.dataset.targets)
        class_idxs = torch.where(targets == label)[0].tolist()
        self.cls_dataset = Subset(self.dataset, class_idxs)

    def get_sample(self, K):
        """
            Draw ``min(K, len(cls_dataset))`` distinct items of the selected
            class uniformly at random.

            Args:
                K: Requested sample size; clamped to the class size.

            Returns:
                A ``Subset`` over ``self.cls_dataset``.

            Raises:
                ValueError: If ``K`` is negative.
        """
        if K < 0:
            raise ValueError(f"K must be non-negative, got {K}")
        idx_list = np.arange(len(self.cls_dataset))
        choice = np.random.choice(idx_list, min(K, len(self.cls_dataset)), replace=False).tolist()
        return Subset(self.cls_dataset, choice)


if __name__ == "__main__":
    # Smoke test: load the dataset named on the command line (first element of
    # get_dataset's return value) and exercise both samplers on it.
    args = parser.parse_args()
    dataset = get_dataset(args)[0]

    # --- RandomSubset: draw K items and tally how many land in each class ---
    K = 20
    print("Random subset with sample size ", K)
    random_sampler = RandomSubset(dataset)
    subset1 = random_sampler.get_sample(K)

    # Start every known class at zero so classes absent from the sample are
    # still listed in the report.
    class_dict = {label: 0 for label in dataset.class_to_idx.values()}
    for x, y in subset1:
        class_dict[y] = class_dict[y] + 1

    print("Class stats from random subset:")
    for k, v in class_dict.items():
        print(f"{k} : {v}")

    # --- ClassRandomSubset: the printed labels should all be the same ---
    clss = 2
    print(f"Class Random subset with class {clss} and sample size {K}")
    cls_random_sampler = ClassRandomSubset(dataset, clss)
    subset2 = cls_random_sampler.get_sample(K)
    for x, y in subset2:
        print(y)
