import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as func
from torch.utils.data import DataLoader, SubsetRandomSampler
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision
from torchsummary import summary

def get_indices_gen(dataset, class_name):
    """Return the positions in ``dataset.targets`` whose label equals ``class_name``.

    Args:
        dataset: Object exposing a ``targets`` sequence of per-sample labels.
        class_name: Label value to select. Despite the name, it is compared
            with ``==`` against each entry of ``dataset.targets``, so it must
            be of a type comparable to those entries (callers in this file
            pass an integer class index).

    Returns:
        A list of integer indices, in ascending order, of the samples whose
        target equals ``class_name``. Empty if no sample matches.
    """
    return [
        index
        for index, target in enumerate(dataset.targets)
        if target == class_name
    ]


def data_load_gen(args, dataset, batch_size=256):
    """Build one ``DataLoader`` per class, each restricted to that class's samples.

    Args:
        args: Namespace-like object; only ``args.num_classes`` is read.
        dataset: Dataset exposing a ``targets`` sequence (see
            ``get_indices_gen``) and usable as a ``DataLoader`` source.
        batch_size: Batch size used for every per-class loader. Defaults to
            256, the value that was previously hard-coded.

    Returns:
        A list of ``args.num_classes`` loaders. Element ``c`` draws, through
        a ``SubsetRandomSampler``, only from the indices whose target equals
        the integer ``c``.
    """
    return [
        DataLoader(
            dataset,
            batch_size=batch_size,
            # Restrict this loader to the samples labelled with class_index.
            sampler=SubsetRandomSampler(get_indices_gen(dataset, class_index)),
        )
        for class_index in range(args.num_classes)
    ]


def get_the_data(args):
    """Load the requested CIFAR dataset and split both splits into per-class loaders.

    Args:
        args: Namespace-like object. ``args.dataset`` selects the dataset
            ("CIFAR10" or "CIFAR20"); ``args`` is also forwarded unchanged to
            ``data_load_gen``, which reads ``args.num_classes``.

    Returns:
        A ``(train_loaders, test_loaders)`` tuple, where each element is the
        result of ``data_load_gen`` for the corresponding split.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    # Resolve the dataset class first so an unsupported name fails with a
    # clear message instead of an UnboundLocalError further down.
    if args.dataset == "CIFAR10":
        dataset_cls = datasets.CIFAR10
    elif args.dataset == "CIFAR20":
        # NOTE(review): "CIFAR20" is backed by torchvision's CIFAR100 class;
        # presumably the 20-class grouping is handled via args.num_classes or
        # elsewhere -- confirm this is the intended label set.
        dataset_cls = datasets.CIFAR100
    else:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; "
            "expected 'CIFAR10' or 'CIFAR20'"
        )

    train_dataset = dataset_cls(
        root='dataset/', train=True, transform=transforms.ToTensor(), download=True
    )
    test_dataset = dataset_cls(
        root='dataset/', train=False, transform=transforms.ToTensor(), download=True
    )

    # Split each dataset into one loader per class.
    split_train_dataset_gen = data_load_gen(args, train_dataset)
    split_test_dataset_gen = data_load_gen(args, test_dataset)

    return split_train_dataset_gen, split_test_dataset_gen

