import random
import numpy as np
from functools import reduce

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split


class CIFAR10():
    """Convenience wrapper around torchvision's CIFAR10 train/test datasets.

    The constructor builds both datasets with one shared transform.
    ``split`` partitions the training set into random subsets and
    ``get_dataloaders`` wraps the datasets (or the subsets) in
    ``torch.utils.data.DataLoader`` objects.
    """

    def __init__(self, trans=None, root='DATASET_PATH'):
        """Create the train and test datasets.

        Args:
            trans: Transform applied to every sample of both datasets. When
                ``None``, a pipeline of ``ToTensor`` followed by ``Normalize``
                with mean 0.5 and std 0.5 on each of the three channels is
                used.
            root: Directory handed to torchvision as the dataset root. Both
                datasets are created with ``download=True``.
        """
        if trans is None:
            trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
        self.transforms = trans

        self.trainset = torchvision.datasets.CIFAR10(
            root=root,
            train=True,
            download=True,
            transform=self.transforms,
        )
        self.testset = torchvision.datasets.CIFAR10(
            root=root,
            train=False,
            download=True,
            transform=self.transforms,
        )

    def split(self, n=5, seed=123):
        """Randomly partition the training set into ``n`` subsets.

        The result is stored in ``self.split_trainsets``. Every training
        sample lands in exactly one subset. When the training-set size is not
        a multiple of ``n``, the first ``len(trainset) % n`` subsets receive
        one extra sample, so subset sizes differ by at most one.

        Args:
            n: Number of subsets; expected to be a positive integer.
            seed: Seed for the generator driving the random permutation, so
                the same ``(n, seed)`` pair reproduces the same partition.
        """
        base, extra = divmod(len(self.trainset), n)
        # random_split requires the lengths to sum to the dataset size, so the
        # remainder is spread over the leading subsets rather than dropped.
        lengths = [base + 1] * extra + [base] * (n - extra)
        self.split_trainsets = torch.utils.data.random_split(
            self.trainset,
            lengths,
            generator=torch.Generator().manual_seed(seed),
        )

    def get_dataloaders(self, batch_size=64, shuffle=True,
                        num_workers=2, split=False):
        """Build data loaders for the training and test data.

        ``batch_size``, ``shuffle`` and ``num_workers`` are applied to every
        loader returned, including the test loader.

        Args:
            batch_size: Samples per batch.
            shuffle: Whether each loader shuffles its dataset.
            num_workers: Worker count passed to each loader.
            split: When true, return one loader per subset created by
                ``split`` instead of a single training loader.

        Returns:
            ``(trainloader, testloader)`` when ``split`` is false, otherwise
            ``(list_of_split_loaders, testloader)``.

        Raises:
            AttributeError: If ``split`` is true but ``split()`` has not been
                called yet.
        """
        def make_loader(dataset):
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
            )

        testloader = make_loader(self.testset)
        if not split:
            return make_loader(self.trainset), testloader

        subsets = getattr(self, 'split_trainsets', None)
        if subsets is None:
            raise AttributeError(
                "split_trainsets is not set; call split() before "
                "get_dataloaders(split=True)"
            )
        return [make_loader(subset) for subset in subsets], testloader
        

class SVHN():
    """Convenience wrapper around torchvision's SVHN train/test datasets.

    The constructor builds both datasets with one shared transform.
    ``split`` carves fixed-size random subsets out of the training set (and
    truncates the test set), and ``get_dataloaders`` wraps the datasets (or
    the subsets) in ``torch.utils.data.DataLoader`` objects.
    """

    def __init__(self, trans=None, root='DATASET_PATH'):
        """Create the train and test datasets.

        Args:
            trans: Transform applied to every sample of both datasets. When
                ``None``, a pipeline of ``ToTensor`` followed by ``Normalize``
                with mean 0.5 and std 0.5 on each of the three channels is
                used.
            root: Directory handed to torchvision as the dataset root. Both
                datasets are created with ``download=True``.
        """
        if trans is None:
            trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
        self.transforms = trans

        self.trainset = torchvision.datasets.SVHN(
            root=root,
            transform=self.transforms,
            download=True,
            split="train",
        )
        self.testset = torchvision.datasets.SVHN(
            root=root,
            transform=self.transforms,
            download=True,
            split="test",
        )

    def split(self, n=5, seed=123, split_size=10000, test_size=10000):
        """Randomly carve ``n`` fixed-size subsets out of the training set.

        ``self.split_trainsets`` ends up holding ``n + 1`` subsets: ``n``
        subsets of ``split_size`` samples each, followed by one final subset
        containing whatever training samples are left over (possibly none).
        As a side effect ``self.testset`` is replaced by a ``Subset`` of its
        first ``test_size`` samples.

        Args:
            n: Number of ``split_size``-sample subsets to create.
            seed: Seed for the generator driving the random permutation, so
                the same arguments reproduce the same partition.
            split_size: Number of samples in each of the ``n`` subsets.
            test_size: Number of leading test samples to keep.

        Raises:
            ValueError: If the training set holds fewer than
                ``n * split_size`` samples, or the test set holds fewer than
                ``test_size`` samples.
        """
        train_total = len(self.trainset)
        leftover = train_total - n * split_size
        if leftover < 0:
            # A negative remainder still satisfies random_split's length-sum
            # check and would silently produce short or empty subsets.
            raise ValueError(
                f"cannot take {n} splits of {split_size} samples from a "
                f"training set of {train_total} samples"
            )
        test_total = len(self.testset)
        if test_size > test_total:
            # Subset does not validate indices up front; fail here instead of
            # at the first out-of-range access.
            raise ValueError(
                f"test_size={test_size} exceeds the test set size "
                f"({test_total})"
            )

        self.split_trainsets = torch.utils.data.random_split(
            self.trainset,
            [split_size] * n + [leftover],
            generator=torch.Generator().manual_seed(seed),
        )
        self.testset = torch.utils.data.Subset(
            self.testset, list(range(test_size)))

    def get_dataloaders(self, batch_size=64, shuffle=True,
                        num_workers=2, split=False):
        """Build data loaders for the training and test data.

        ``batch_size``, ``shuffle`` and ``num_workers`` are applied to every
        loader returned, including the test loader.

        Args:
            batch_size: Samples per batch.
            shuffle: Whether each loader shuffles its dataset.
            num_workers: Worker count passed to each loader.
            split: When true, return one loader per subset created by
                ``split`` (including the trailing leftover subset) instead of
                a single training loader.

        Returns:
            ``(trainloader, testloader)`` when ``split`` is false, otherwise
            ``(list_of_split_loaders, testloader)``.

        Raises:
            AttributeError: If ``split`` is true but ``split()`` has not been
                called yet.
        """
        def make_loader(dataset):
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
            )

        testloader = make_loader(self.testset)
        if not split:
            return make_loader(self.trainset), testloader

        subsets = getattr(self, 'split_trainsets', None)
        if subsets is None:
            raise AttributeError(
                "split_trainsets is not set; call split() before "
                "get_dataloaders(split=True)"
            )
        return [make_loader(subset) for subset in subsets], testloader