from http import client
from imghdr import tests
from PIL import Image
import os
import numpy as np
import torch
from torchvision import datasets, transforms
from yaml import DirectiveToken
from util.sampling import iid_sampling, non_iid_dirichlet_sampling
import torch.utils
from util.imbalance_cifar import IMBALANCECIFAR10, IMBALANCECIFAR100, IMBALANCE_IMAGENET
import pdb
from util.ImageNet_LT import *
class myDataset:
    """Build (long-tailed) federated datasets and per-client index partitions.

    ``get_imbalanced_dataset`` and ``get_balanced_dataset`` write ``device``
    and ``num_classes`` onto the ``args`` namespace they receive and return
    ``(train_set, test_set, dict_users, dict_localtest)``, where
    ``dict_users[i]`` holds the training-set indices owned by client ``i``.
    """

    # Normalisation statistics used by the CIFAR-10 / CIFAR-100 transforms.
    _CIFAR10_MEAN = [0.485, 0.456, 0.406]
    _CIFAR10_STD = [0.229, 0.224, 0.225]
    _CIFAR100_MEAN = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
    _CIFAR100_STD = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]

    # Totals that the per-client test sizes are scaled to, and the number of
    # CIFAR-10 test images expected per class.
    _CIFAR_TEST_SIZE = 10000
    _IMAGENET_TEST_SIZE = 50000
    _CIFAR10_TEST_PER_CLASS = 1000

    def __init__(self, args):
        self.m_args = args

    def get_args(self):
        """Return the ``args`` namespace passed to the constructor."""
        return self.m_args

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _set_device(args):
        """Set ``args.device`` to ``cuda:<args.gpu>`` when CUDA is usable, else CPU."""
        use_cuda = torch.cuda.is_available() and args.gpu != -1
        args.device = torch.device('cuda:{}'.format(args.gpu) if use_cuda else 'cpu')

    @staticmethod
    def _cifar_transforms(mean, std):
        """Return ``(train, val)`` transforms for 32x32 CIFAR images.

        Training adds random crop (padding 4) and horizontal flip; both
        convert to a tensor and normalise with ``mean`` / ``std``.
        """
        trans_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=list(mean), std=list(std)),
        ])
        trans_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=list(mean), std=list(std)),
        ])
        return trans_train, trans_val

    @staticmethod
    def _client_class_counts(labels, index_sets, num_clients, num_classes):
        """Return a ``(num_clients, num_classes)`` array of label counts.

        Entry ``[i][c]`` is how many indices in ``index_sets[i]`` have
        ``labels[idx] == c``.
        """
        return np.array([
            [np.sum(labels[list(index_sets[i])] == c) for c in range(num_classes)]
            for i in range(num_clients)
        ])

    @staticmethod
    def _report_and_scale(alist, test_total):
        """Print the training distribution and return per-client test sizes.

        Entry ``[i][c]`` of the result is
        ``int(alist[i][c] / alist.sum() * test_total)``, i.e. the client's
        share of the whole training set scaled to ``test_total`` samples.
        """
        print("training set distribution:")
        print(alist)
        print("Total size of training set")
        total = int(alist.sum())
        print(total)
        testsizes = (alist / total * test_total).astype(int)
        print("local test distribution:")
        print(testsizes)
        print("Total size of testing set")
        print(int(testsizes.sum()))
        print(testsizes.sum(0))
        return testsizes

    @staticmethod
    def _test_indices_by_class(y_test, num_classes, per_class):
        """Map each class id to the ascending list of test indices with that label.

        Asserts that every class has exactly ``per_class`` test samples.
        """
        map_testset = {}
        for c in range(num_classes):
            idxs = np.flatnonzero(y_test == c).tolist()
            assert per_class == len(idxs)
            map_testset[c] = idxs
        return map_testset

    @staticmethod
    def _sample_local_test(map_testset, testsizes, num_clients, num_classes):
        """Draw ``testsizes[i][c]`` distinct test indices of class ``c`` for each client.

        Uses the global ``np.random`` state; clients and classes are visited
        in ascending order so results are reproducible under a fixed seed.
        """
        dict_localtest = {}
        for i in range(num_clients):
            idxs = []
            for c in range(num_classes):
                idxs.extend(np.random.choice(map_testset[c], testsizes[i][c], replace=False))
            dict_localtest[i] = set(idxs)
        return dict_localtest

    @staticmethod
    def _make_client_long_tailed(dict_users, client, labels, num_classes, imb_factor):
        """Subsample ``dict_users[client]`` in place so its classes decay exponentially.

        The k-th class after rotating by ``client % num_classes`` is capped at
        ``int(len(shard) / num_classes * imb_factor ** (k / (num_classes - 1)))``.
        Classes already at or below their cap are left untouched.
        """
        img_max = len(dict_users[client]) / num_classes
        lt_sizes = [int(img_max * (imb_factor ** (k / (num_classes - 1.0))))
                    for k in range(num_classes)]
        head = client % num_classes
        for k in range(num_classes):
            cur_cls = (head + k) % num_classes
            target_cls_size = lt_sizes[k]
            # Snapshot the current shard once (the old code rebuilt this list
            # for every matching sample, which was quadratic).
            members = list(dict_users[client])
            hits = labels[members] == cur_cls
            indices = [idx for idx, hit in zip(members, hits) if hit]
            cur_cls_size = len(indices)
            if target_cls_size >= cur_cls_size:
                print('the current class doesnt need dropout')
                continue
            drop = set(indices[:cur_cls_size - target_cls_size])
            # Rebuild from the remaining members in their original order so the
            # new set (and the next class's drop choice) matches the old logic.
            dict_users[client] = set([idx for idx in members if idx not in drop])

    # ------------------------------------------------------------------ #
    # Public entry points
    # ------------------------------------------------------------------ #
    def get_imbalanced_dataset(self, args):
        """Load a long-tailed dataset and split its training set across clients.

        Sets ``args.device`` and ``args.num_classes``. Partitions the training
        set with ``iid_sampling`` when ``args.iid`` is truthy, otherwise with
        ``non_iid_dirichlet_sampling``. Stores the per-client training counts
        and scaled local test sizes on ``self``.

        Returns:
            ``(train_set, test_set, dict_users, dict_localtest)``.
            ``dict_localtest`` maps client id to a set of test indices for
            CIFAR-10 and is ``None`` for CIFAR-100 and ImageNet. For ImageNet
            the sets returned are ``dataset_train.dataset`` and
            ``dataset_train.val_dataset``.

        Raises:
            SystemExit: if ``args.dataset`` is not 'cifar10', 'cifar100' or
                'imagenet'.
        """
        self._set_device(args)
        if args.dataset == 'cifar10':
            data_path = './cifar_lt/'
            args.num_classes = 10
            trans_train, trans_val = self._cifar_transforms(self._CIFAR10_MEAN, self._CIFAR10_STD)
            dataset_train = IMBALANCECIFAR10(data_path, imb_factor=args.IF, train=True,
                                             download=True, transform=trans_train)
            # torchvision's CIFAR10 selects the test split with ``train=False``;
            # it has no ``split`` keyword (passing one raises TypeError).
            dataset_test = datasets.CIFAR10(data_path, train=False, download=True,
                                            transform=trans_val)
            n_train = len(dataset_train)
            y_train = np.array(dataset_train.targets)
            y_test = np.array(dataset_test.targets)
        elif args.dataset == 'cifar100':
            data_path = './cifar_lt/'
            args.num_classes = 100
            trans_train, trans_val = self._cifar_transforms(self._CIFAR100_MEAN, self._CIFAR100_STD)
            dataset_train = IMBALANCECIFAR100(data_path, imb_factor=args.IF, train=True,
                                              download=True, transform=trans_train)
            dataset_test = datasets.CIFAR100(data_path, train=False, download=True,
                                             transform=trans_val)
            n_train = len(dataset_train)
            y_train = np.array(dataset_train.targets)
        elif args.dataset == 'imagenet':
            args.num_classes = 1000
            dataset_train = ImageNetLTDataLoader(training=True)
            # NOTE(review): for ImageNet only dataset_train.dataset/.val_dataset
            # are returned; this loader is built but unused — confirm it is needed.
            dataset_test = ImageNetLTDataLoader(training=False)
            # NOTE(review): if ImageNetLTDataLoader is a torch DataLoader, len()
            # counts batches, not samples (len(y_train)) — confirm for iid mode.
            n_train = len(dataset_train)
            y_train = np.array(dataset_train.dataset.targets)
        else:
            raise SystemExit('Error: unrecognized dataset')

        if args.iid:
            print("Into iid sampling")
            dict_users = iid_sampling(n_train, args.num_users, args.seed)
        else:
            print("Into non-iid sampling")
            dict_users = non_iid_dirichlet_sampling(
                y_train, args.num_classes, args.non_iid_prob_class,
                args.num_users, args.seed, args.alpha_dirichlet)
        clients_sizes = [len(dict_users[i]) for i in range(args.num_users)]
        print("clients_sizes:{}".format(clients_sizes))

        if args.dataset == 'cifar10':
            map_testset = self._test_indices_by_class(
                y_test, args.num_classes, self._CIFAR10_TEST_PER_CLASS)
            alist = self._client_class_counts(y_train, dict_users, args.num_users, args.num_classes)
            testsizes = self._report_and_scale(alist, self._CIFAR_TEST_SIZE)
            dict_localtest = self._sample_local_test(
                map_testset, testsizes, args.num_users, args.num_classes)
            blist = self._client_class_counts(
                y_test, dict_localtest, args.num_users, args.num_classes)
            # Each client's drawn test set must match the requested class sizes.
            assert np.array_equal(testsizes, blist)
            assert len(dict_users) == len(dict_localtest) == args.num_users
        else:
            # CIFAR-100 / ImageNet: only the distributions are recorded; no
            # per-client test index sets are sampled.
            test_total = (self._IMAGENET_TEST_SIZE if args.dataset == 'imagenet'
                          else self._CIFAR_TEST_SIZE)
            alist = self._client_class_counts(y_train, dict_users, args.num_users, args.num_classes)
            testsizes = self._report_and_scale(alist, test_total)
            dict_localtest = None

        self.training_set_distribution = alist
        self.local_test_distribution = testsizes
        self.global_test_distribution = np.sum(testsizes, axis=0)
        if args.dataset == 'imagenet':
            self.global_train_distribution = np.sum(alist, axis=0)
            return dataset_train.dataset, dataset_train.val_dataset, dict_users, dict_localtest
        return dataset_train, dataset_test, dict_users, dict_localtest

    def get_balanced_dataset(self, args):
        """Load balanced CIFAR-10 and make every client's IID shard long-tailed.

        Sets ``args.device`` and ``args.num_classes``. Each client's shard from
        ``iid_sampling`` is subsampled so the k-th class (rotated by client id)
        is scaled by ``args.IF ** (k / (num_classes - 1))``. A local test set is
        then drawn for each client from the CIFAR-10 test split, with sizes
        proportional to the client's share of the training set (scaled to 10000).

        Returns:
            ``(dataset_train, dataset_test, dict_users, dict_localtest)``.

        Raises:
            SystemExit: if ``args.dataset`` is not 'cifar10'.
        """
        self._set_device(args)
        if args.dataset == 'cifar10':
            data_path = './cifar_lt/'
            args.num_classes = 10
            trans_train, trans_val = self._cifar_transforms(self._CIFAR10_MEAN, self._CIFAR10_STD)
            dataset_train = datasets.CIFAR10(data_path, train=True, download=True,
                                             transform=trans_train)
            dataset_test = datasets.CIFAR10(data_path, train=False, download=True,
                                            transform=trans_val)
            n_train = len(dataset_train)
            y_train = np.array(dataset_train.targets)
            y_test = np.array(dataset_test.targets)
        else:
            raise SystemExit('Error: unrecognized dataset')

        num_classes = args.num_classes
        dict_users = iid_sampling(n_train, args.num_users, args.seed)
        for i in range(args.num_users):
            self._make_client_long_tailed(dict_users, i, y_train, num_classes, args.IF)

        # Previously hard-coded to 40 clients, which broke any other num_users.
        alist = self._client_class_counts(y_train, dict_users, args.num_users, num_classes)
        testsizes = self._report_and_scale(alist, self._CIFAR_TEST_SIZE)
        map_testset = self._test_indices_by_class(
            y_test, num_classes, self._CIFAR10_TEST_PER_CLASS)
        dict_localtest = self._sample_local_test(
            map_testset, testsizes, args.num_users, num_classes)
        blist = self._client_class_counts(y_test, dict_localtest, args.num_users, num_classes)
        # Each client's drawn test set must match the requested class sizes.
        assert np.array_equal(testsizes, blist)
        assert len(dict_users) == len(dict_localtest) == args.num_users
        return dataset_train, dataset_test, dict_users, dict_localtest