import torch
import numpy as np
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, CIFAR100, SVHN
from torch.utils.data import DataLoader


def get_data(dataset, root='../dataset/'):
    """Load the train and test splits of a supported image dataset.

    Every split is converted with ``ToTensor`` and then normalized with fixed
    mean/std constants; no data augmentation is applied.

    Args:
        dataset: One of ``"svhn"``, ``"cifar10"`` or ``"cifar100"``.
        root: Directory the torchvision datasets are stored in (and downloaded
            to if missing). Defaults to ``'../dataset/'``, resolved relative to
            the current working directory, as before.

    Returns:
        ``(train, test, num_examples)`` where ``num_examples`` is a dict with
        keys ``"trainset"`` and ``"testset"`` holding the split sizes.

    Raises:
        ValueError: If ``dataset`` is not a supported name. Previously this
            surfaced as an ``UnboundLocalError`` on ``train``.
    """
    if dataset == "svhn":
        # Normalize is given single-value mean/std, which it broadcasts over
        # every channel of the tensor. The train and test splits use
        # different constants. NOTE(review): confirm this is intentional.
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4496,), (0.1995,)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4560,), (0.2244,)),
        ])
        train = SVHN(root=root, split='train', download=True, transform=transform_train)
        test = SVHN(root=root, split='test', download=True, transform=transform_test)
    elif dataset == "cifar10":
        # Both splits share the same per-channel normalization.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        train = CIFAR10(root=root, train=True, download=True, transform=transform)
        test = CIFAR10(root=root, train=False, download=True, transform=transform)
    elif dataset == "cifar100":
        # Both splits share the same per-channel normalization.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
        ])
        train = CIFAR100(root=root, train=True, download=True, transform=transform)
        test = CIFAR100(root=root, train=False, download=True, transform=transform)
    else:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of "
            "'svhn', 'cifar10', 'cifar100'"
        )

    num_examples = {"trainset": len(train), "testset": len(test)}
    return train, test, num_examples

def get_client_data(dataset, part_strategy, partition, id=0, val_ratio=0.1, seed=0):
    """Build the train / validation / test subsets owned by one client.

    Args:
        dataset: Dataset name understood by ``get_data``.
        part_strategy: Partitioning-strategy label. It is used only to build
            the name of the precomputed index file.
        partition: Number of client partitions. It names the index file and
            splits the test set into equal contiguous slices.
        id: Index of this client. It selects the training indices from the
            loaded map and the test slice.
        val_ratio: Fraction of the client's training indices to hold out for
            validation. The hold-out is taken from the front of the client's
            index list.
        seed: Seed label. It is used only to build the index-file name.

    Returns:
        ``(train_partition, val_partition, test_partition)``, each a
        ``torch.utils.data.Subset``.
    """
    train, test, num = get_data(dataset)

    # The index file holds a pickled Python object inside a numpy array,
    # hence ``allow_pickle=True`` and ``.item()``. Presumably it is a dict
    # keyed by client id; verify against the script that writes these files.
    # SECURITY: unpickling can execute arbitrary code, so only load partition
    # files from a trusted source.
    npy_name = f"{dataset}-{part_strategy}-{partition}-{seed}.npy"
    train_id_map = np.load("../dataloader/npy/" + npy_name, allow_pickle=True).item()

    # Contiguous, equally sized test slice per client. Integer floor division
    # avoids float rounding; the last ``testset % partition`` samples are not
    # assigned to any client.
    n_test = num["testset"] // partition
    low = id * n_test
    high = min((id + 1) * n_test, num["testset"])

    client_train = torch.utils.data.Subset(train, train_id_map[id])
    n_valset = int(len(client_train) * val_ratio)
    # The first ``n_valset`` client samples become validation; the rest train.
    val_partition = torch.utils.data.Subset(client_train, range(0, n_valset))
    train_partition = torch.utils.data.Subset(client_train, range(n_valset, len(client_train)))
    test_partition = torch.utils.data.Subset(test, range(low, high))

    return train_partition, val_partition, test_partition

def get_server_test_dataloader(dataset, batch_size, root='../dataset/'):
    """Return a DataLoader over the full test split of ``dataset``.

    The normalization constants match the test-split transforms used by
    ``get_data``. The loader is built with only ``batch_size`` set, so it uses
    the ``DataLoader`` defaults, including no shuffling.

    Args:
        dataset: One of ``"svhn"``, ``"cifar10"`` or ``"cifar100"``.
        batch_size: Batch size passed to the ``DataLoader``.
        root: Directory the torchvision datasets are stored in (and downloaded
            to if missing). Defaults to ``'../dataset/'``, as before.

    Raises:
        ValueError: If ``dataset`` is not a supported name. Previously this
            surfaced as an ``UnboundLocalError`` on ``test``.
    """
    if dataset == "svhn":
        # Single-value mean/std are broadcast over every channel by Normalize.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4560,), (0.2244,)),
        ])
        test = SVHN(root=root, split='test', download=True, transform=transform)
    elif dataset == "cifar10":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        test = CIFAR10(root=root, train=False, download=True, transform=transform)
    elif dataset == "cifar100":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
        ])
        test = CIFAR100(root=root, train=False, download=True, transform=transform)
    else:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of "
            "'svhn', 'cifar10', 'cifar100'"
        )

    return DataLoader(test, batch_size=batch_size)

