import torch
from torchvision.datasets import CIFAR10, CIFAR100, SVHN
import torchvision.transforms as transforms
import numpy as np
import random
import os
import sys
sys.path.append('..')

def get_data(dataset):
    """Load the train and test splits of a supported image-classification dataset.

    Each split is created with ``download=True`` under ``../dataset/`` and
    wrapped with a ``ToTensor`` + ``Normalize`` transform using fixed
    per-dataset statistics.

    Args:
        dataset: One of ``"svhn"``, ``"cifar10"`` or ``"cifar100"``.

    Returns:
        ``(train, test, num_examples)`` where ``num_examples`` is
        ``{"trainset": len(train), "testset": len(test)}``.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names. Without
            this check, the function used to fail later with an opaque
            ``UnboundLocalError``.
    """
    root = '../dataset/'
    if dataset == "svhn":
        # SVHN uses a single mean/std value (broadcast over the image channels),
        # with different statistics for the train and test splits.
        transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4496,), (0.1995,))])
        train = SVHN(root=root, split='train', download=True, transform=transform_train)
        transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4560,), (0.2244,))])
        test = SVHN(root=root, split='test', download=True, transform=transform_test)
    elif dataset == "cifar10":
        transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
        transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
        train = CIFAR10(root=root, train=True, download=True, transform=transform_train)
        test = CIFAR10(root=root, train=False, download=True, transform=transform_test)
    elif dataset == "cifar100":
        transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))])
        transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))])
        train = CIFAR100(root=root, train=True, download=True, transform=transform_train)
        test = CIFAR100(root=root, train=False, download=True, transform=transform_test)
    else:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'svhn', 'cifar10' or 'cifar100'"
        )

    num_examples = {"trainset": len(train), "testset": len(test)}
    return train, test, num_examples

def record_net_data_stats(y_train, net_dataidx_map):
    """Tally the label distribution held by each client, print it, and return it.

    Args:
        y_train: Label array; it is indexed with each client's index collection.
        net_dataidx_map: Dict mapping client id -> indices into ``y_train``.

    Returns:
        Dict mapping client id -> ``{label: count}``. Only labels that are
        actually present on that client appear.
    """
    stats = {}
    for client_id, sample_idxs in net_dataidx_map.items():
        labels, counts = np.unique(y_train[sample_idxs], return_counts=True)
        stats[client_id] = dict(zip(labels, counts))

    print('Data statistics: %s' % str(stats))
    return stats

def partition_train_data(dataset, y, part_strategy, partition, seed=0):
    """Split training-sample indices among ``partition`` clients and save the split.

    Args:
        dataset: Dataset name. ``"cifar100"`` selects 100 label classes;
            anything else selects 10.
        y: 1-D NumPy array with one integer label per training sample
            (assumed to lie in ``range(K)`` — labels outside it are never assigned).
        part_strategy: ``"iid"`` shuffles all indices and cuts them into
            near-equal chunks. ``"labeldir<beta>"`` (e.g. ``"labeldir0.5"``)
            spreads each class over the clients with Dirichlet(beta)
            proportions, resampling until every client has at least 10 samples.
        partition: Number of clients.
        seed: Seed applied to NumPy, torch and ``random`` before partitioning.

    Returns:
        Dict mapping client id ``0..partition-1`` to that client's sample
        indices (NumPy arrays for ``"iid"``, lists for ``"labeldir"``).

    Side effects:
        Prints per-client label counts via ``record_net_data_stats``. Saves the
        dict with ``np.save`` to ``./npy/<dataset>-<part_strategy>-<partition>-<seed>.npy``.
        NOTE: a dict is stored as a pickled object array, so loading it needs
        ``np.load(..., allow_pickle=True)``.

    Raises:
        ValueError: for an unknown ``part_strategy``, a non-numeric beta, or a
            ``labeldir`` split that can never give every client enough samples.
    """
    # Seed every RNG so a given (dataset, strategy, partition, seed) is reproducible.
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    N = y.shape[0]
    K = 100 if dataset == 'cifar100' else 10

    if part_strategy == "iid":
        idxs = np.random.permutation(N)
        batch_idxs = np.array_split(idxs, partition)
        net_dataidx_map = {i: batch_idxs[i] for i in range(partition)}

    elif part_strategy.startswith("labeldir"):
        min_size = 0
        min_require_size = 10
        # Parse the Dirichlet concentration with float() instead of eval(), so
        # the strategy string can never execute arbitrary code.
        beta = float(part_strategy[len("labeldir"):])
        # With fewer than min_require_size * partition samples the condition
        # below can never be satisfied, and the resampling loop would spin forever.
        if N < min_require_size * partition:
            raise ValueError(
                f"cannot give each of {partition} clients at least "
                f"{min_require_size} samples out of {N}"
            )
        # Redraw the whole split until the smallest client has enough samples.
        while min_size < min_require_size:
            idx_batch = [[] for _ in range(partition)]
            for k in range(K):
                idx_k = np.where(y == k)[0]
                np.random.shuffle(idx_k)
                proportions = np.random.dirichlet(np.repeat(beta, partition))
                # Clients already holding >= N / partition samples get no more
                # of this class, which keeps client sizes from exploding.
                proportions = np.array([p * (len(idx_j) < N / partition) for p, idx_j in zip(proportions, idx_batch)])
                proportions = proportions / proportions.sum()
                # Convert the proportions into cut points inside this class's indices.
                cuts = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, cuts))]
            min_size = min(len(idx_j) for idx_j in idx_batch)

        net_dataidx_map = {}
        for j in range(partition):
            np.random.shuffle(idx_batch[j])
            net_dataidx_map[j] = idx_batch[j]

    else:
        raise ValueError(
            f"Unknown part_strategy {part_strategy!r}; expected 'iid' or 'labeldir<beta>'"
        )

    record_net_data_stats(y, net_dataidx_map)

    npy_dir = "./npy/"
    os.makedirs(npy_dir, exist_ok=True)
    npy_name = f"{dataset}-{part_strategy}-{partition}-{seed}.npy"
    np.save(os.path.join(npy_dir, npy_name), net_dataidx_map)

    return net_dataidx_map

def generate_train_npy(dataset, part_strategy, partition, seed=0):
    """Load ``dataset``'s training split, partition it among clients, and save the result.

    Args:
        dataset: Dataset name accepted by ``get_data``.
        part_strategy: ``"iid"`` or ``"labeldir<beta>"`` (see ``partition_train_data``).
        partition: Number of clients.
        seed: RNG seed for the partitioning.

    Returns:
        The client -> indices dict produced by ``partition_train_data``. It is
        also written to ``./npy/`` as a side effect of that call.
    """
    train, _, _ = get_data(dataset)
    # torchvision exposes labels under different attribute names per dataset.
    if "cifar" in dataset:
        labels = np.array(train.targets)
    elif "svhn" in dataset:
        labels = np.array(train.labels)
    else:
        # Fallback for datasets whose ``targets`` is a tensor. get_data in this
        # file currently never returns such a dataset.
        labels = train.targets.data.numpy()
    return partition_train_data(dataset, labels, part_strategy, partition, seed)


if __name__ == "__main__":
    # Pre-compute partition files for every dataset x strategy x client-count combination.
    seed = 0
    datasets = ["svhn", "cifar10", "cifar100"]
    part_strategies = ["iid", "labeldir0.5", "labeldir0.1"]
    partitions = [10]
    jobs = [
        (dataset, part_strategy, partition)
        for dataset in datasets
        for part_strategy in part_strategies
        for partition in partitions
    ]
    for dataset, part_strategy, partition in jobs:
        generate_train_npy(dataset, part_strategy, partition, seed)

                