import pdb
import random
import torch.utils.data as data
import torch
import sys
from PIL import Image
import numpy as np
from torchvision import datasets, transforms
import os
from torch.utils.data import Dataset, DataLoader
from datasets import *


def partition_data(args):
    """Assign training-sample indices to ``args.n_clients`` clients.

    The training set and its labels come from ``load_data(args)`` (first and
    fourth elements of the returned 6-tuple).  The labels are compared with
    ``==`` and indexed with index arrays, so they are presumably a 1-D NumPy
    array -- confirm against ``load_data``.

    The scheme is selected by ``args.partition``:

    * ``"iid"``: a random permutation of all indices cut into near-equal
      chunks.
    * ``"Dirichlet"``: per class, client shares are drawn from a symmetric
      Dirichlet with concentration ``args.beta``; the draw is repeated until
      every client holds at least 10 samples.
    * ``"non_iid"``: the ``args.num_classes`` labels are shuffled and dealt
      out ``args.n_label`` per client; each client receives every sample of
      its labels.  Requires
      ``args.num_classes == args.n_clients * args.n_label``.
    * ``"non_iid<N>"``: each client holds ``N`` distinct labels out of 10,
      and the samples of a label are split evenly among its holders.

    Any other value triggers a warning and yields an empty mapping.

    Returns:
        A pair ``(net_dataidx_map, train_data_cls_counts)``: the mapping
        client index -> sample indices (a list for ``"Dirichlet"``, a NumPy
        array otherwise) and the per-client class histogram computed by
        ``record_net_data_stats``.

    Raises:
        ValueError: if the ``"non_iid<N>"`` suffix is not an integer or
            exceeds 10, or if ``"Dirichlet"`` cannot give every client its
            minimum number of samples.
    """
    train_dataset, _, _, y_train, _, _ = load_data(args)
    data_size = len(train_dataset)
    partition = args.partition

    if partition == "iid":
        net_dataidx_map = _partition_iid(data_size, args.n_clients)
    elif partition == "Dirichlet":
        net_dataidx_map = _partition_dirichlet(
            y_train, data_size, args.n_clients, args.num_classes, args.beta)
    elif partition == "non_iid":
        net_dataidx_map = _partition_label_shards(
            y_train, args.n_clients, args.num_classes, args.n_label)
    elif "non_iid0" < partition <= "non_iid9":
        # Lexicographic range test: it matches "non_iid1" .. "non_iid9" but
        # also multi-character suffixes such as "non_iid10", so the suffix is
        # parsed (with int, never eval) and range-checked by the helper.
        net_dataidx_map = _partition_k_labels(
            y_train, args.n_clients, int(partition[7:]))
    else:
        import warnings

        warnings.warn(
            "Unknown partition scheme {!r}; no samples were assigned to any "
            "client.".format(partition),
            stacklevel=2,
        )
        net_dataidx_map = {}

    train_data_cls_counts = record_net_data_stats(y_train, net_dataidx_map)
    return net_dataidx_map, train_data_cls_counts


def _partition_iid(data_size, n_clients):
    """Shuffle all sample indices and cut them into ``n_clients`` chunks."""
    shuffled = np.random.permutation(data_size)
    return dict(enumerate(np.array_split(shuffled, n_clients)))


def _partition_dirichlet(y_train, data_size, n_clients, num_classes, beta,
                         min_require_size=10):
    """Split each class among clients with Dirichlet(beta)-distributed shares."""
    if len(y_train) < min_require_size * n_clients:
        # The resampling loop below could never terminate in this case.
        raise ValueError(
            "Dirichlet partition needs at least {} samples per client, but "
            "only {} samples are available for {} clients".format(
                min_require_size, len(y_train), n_clients))

    while True:
        idx_batch = [[] for _ in range(n_clients)]
        for k in range(num_classes):
            idx_k = np.where(y_train == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(beta, n_clients))
            # A client that already holds its fair share of the data gets a
            # zero share of this class; the rest is renormalised.
            proportions = np.array([
                p * (len(idx_j) < data_size / n_clients)
                for p, idx_j in zip(proportions, idx_batch)
            ])
            proportions = proportions / proportions.sum()
            # Cumulative shares become the cut positions inside idx_k.
            cut_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [
                idx_j + part.tolist()
                for idx_j, part in zip(idx_batch, np.split(idx_k, cut_points))
            ]
        if min(len(idx_j) for idx_j in idx_batch) >= min_require_size:
            break

    for idx_j in idx_batch:
        np.random.shuffle(idx_j)
    return dict(enumerate(idx_batch))


def _partition_label_shards(y_train, n_clients, num_classes, n_label):
    """Give every client all samples of ``n_label`` labels owned by it alone."""
    all_label = np.arange(num_classes)
    np.random.shuffle(all_label)
    clients_label = all_label.reshape((n_clients, n_label))
    # Concatenating (rather than reshaping a 2-D array) keeps this working
    # when the classes do not all contain the same number of samples.
    return {
        client: np.concatenate(
            [np.where(y_train == lbl)[0] for lbl in clients_label[client]])
        for client in range(n_clients)
    }


def _partition_k_labels(y_train, n_clients, num, K=10):
    """Give each client ``num`` of ``K`` labels and split each label evenly."""
    if num > K:
        # Rejection sampling below could never find num distinct labels.
        raise ValueError(
            "Cannot give each client {} distinct labels out of {}".format(
                num, K))

    times = [0] * K   # number of clients holding each label
    contain = []      # labels held by each client
    for client in range(n_clients):
        current = [client % K]
        times[client % K] += 1
        # Draw further labels until the client holds `num` distinct ones.
        while len(current) < num:
            ind = random.randint(0, K - 1)
            if ind not in current:
                current.append(ind)
                times[ind] += 1
        contain.append(current)

    net_dataidx_map = {
        client: np.empty(0, dtype=np.int64) for client in range(n_clients)
    }
    for lbl in range(K):
        if times[lbl] == 0:
            # No client holds this label; array_split rejects zero sections.
            continue
        idx_k = np.where(y_train == lbl)[0]
        np.random.shuffle(idx_k)
        split = np.array_split(idx_k, times[lbl])
        ids = 0
        for client in range(n_clients):
            if lbl in contain[client]:
                net_dataidx_map[client] = np.append(
                    net_dataidx_map[client], split[ids])
                ids += 1
    return net_dataidx_map


def record_net_data_stats(y_train, net_dataidx_map):
    """Build a per-client histogram of class labels.

    Args:
        y_train: label array, indexed with each client's sample indices.
        net_dataidx_map: mapping client id -> indices into ``y_train``.

    Returns:
        A dict mapping each client id to ``{label: count}``, covering only
        the labels that actually occur among that client's samples.
    """
    stats = {}
    for client_id in net_dataidx_map:
        client_labels = y_train[net_dataidx_map[client_id]]
        classes, counts = np.unique(client_labels, return_counts=True)
        stats[client_id] = dict(zip(classes, counts))
    return stats


class DatasetSplit(Dataset):
    """Subset view of a dataset, restricted to a given list of sample indices.

    Args:
        dataset: an indexable object whose items are ``(image, label)`` pairs.
        idxs: iterable of positions in ``dataset`` that make up this subset;
            each entry is converted to a plain ``int``.

    Indexing returns the selected ``(image, label)`` pair with both elements
    as new tensors that do not share storage with the underlying dataset.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return self._as_new_tensor(image), self._as_new_tensor(label)

    @staticmethod
    def _as_new_tensor(value):
        """Return ``value`` as a fresh tensor detached from any autograd graph."""
        if torch.is_tensor(value):
            # torch.tensor(existing_tensor) warns on every call; this is the
            # copy-construct form that warning recommends.
            return value.detach().clone()
        return torch.tensor(value)