from http import client
from imghdr import tests
from tracemalloc import is_tracing
from PIL import Image
import os
import numpy as np
import torch
from torchvision import datasets, transforms
from yaml import DirectiveToken
from util.sampling import iid_sampling, non_iid_dirichlet_sampling
import torch.utils
from util.imbalance_cifar import IMBALANCECIFAR10, IMBALANCECIFAR100, IMBALANCE_IMAGENET
import pdb
from util.ImageNet_LT import *
from util.ignat_loader import *
from collections import Counter
class myDataset:
    """Build per-client (federated) splits of long-tailed image datasets.

    ``get_imbalanced_dataset`` loads a long-tailed dataset and partitions its
    training set across ``args.num_users`` clients; ``get_balanced_dataset``
    splits balanced CIFAR-10 IID and then trims every client's share into a
    long-tailed class distribution. Both methods write into ``args``
    (``args.device`` always; ``args.num_classes`` and, for iNaturalist, a few
    path/size fields once the dataset name is recognised).
    """

    # Size of the notional test set that per-client local test sizes are
    # scaled to (see ``_summarize_distributions``). 30690 == 30 * 1023 classes.
    _LOCAL_TEST_BUDGET = {
        'cifar10': 10000,
        'cifar100': 10000,
        'imagenet': 50000,
        'inat': 30690,
    }

    def __init__(self, args):
        self.m_args = args

    def get_args(self):
        """Return the argument namespace passed to the constructor."""
        return self.m_args

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------
    def get_imbalanced_dataset(self, args):
        """Load a long-tailed dataset and partition its training set across clients.

        Args:
            args: namespace read for ``gpu``, ``dataset`` (``'cifar10'``,
                ``'cifar100'``, ``'imagenet'`` or ``'inat'``), ``iid``,
                ``num_users``, ``seed``; ``IF`` for the CIFAR datasets;
                ``non_iid_prob_class`` / ``alpha_dirichlet`` for non-IID
                splits; ``local_bs`` for ``'inat'``. ``device`` and
                ``num_classes`` are written back.

        Returns:
            ``(train_set, test_set, dict_users, dict_localtest)``.
            ``dict_users`` maps client id -> collection of training indices.
            ``dict_localtest`` maps client id -> set of test indices for
            ``'cifar10'`` and is ``None`` for every other dataset.

        Side effects:
            Sets ``training_set_distribution``, ``local_test_distribution``,
            ``global_test_distribution`` and ``global_train_distribution``.

        Raises:
            SystemExit: if ``args.dataset`` is not recognised.
        """
        self._set_device(args)
        dataset_train, dataset_test, n_train, y_train, y_test = self._load_imbalanced(args)

        if args.iid:
            print("Into iid sampling")
            dict_users = iid_sampling(n_train, args.num_users, args.seed)
        else:
            print("Into non-iid sampling")
            dict_users = non_iid_dirichlet_sampling(
                y_train, args.num_classes, args.non_iid_prob_class,
                args.num_users, args.seed, args.alpha_dirichlet)
        clients_sizes = [len(dict_users[i]) for i in range(args.num_users)]
        print("clients_sizes:{}".format(clients_sizes))

        alist, testsizes = self._summarize_distributions(
            y_train, dict_users, args.num_users, args.num_classes,
            self._LOCAL_TEST_BUDGET[args.dataset])

        dict_localtest = None
        if args.dataset == 'cifar10':
            # Only CIFAR-10 materialises per-client local test index sets.
            dict_localtest = self._sample_local_testsets(
                y_test, testsizes, args.num_users, args.num_classes,
                expected_per_class=1000)
            assert len(dict_users) == len(dict_localtest) == args.num_users

        self.training_set_distribution = alist
        self.local_test_distribution = testsizes
        self.global_test_distribution = np.sum(testsizes, axis=0)
        self.global_train_distribution = np.sum(alist, axis=0)

        if args.dataset == 'imagenet':
            # NOTE(review): the test split returned is the train loader's
            # ``val_dataset``, not ``dataset_test``; kept as in the original.
            return dataset_train.dataset, dataset_train.val_dataset, dict_users, dict_localtest
        if args.dataset == 'inat':
            # The iNat train/test objects are DataLoaders; return their datasets.
            return dataset_train.dataset, dataset_test.dataset, dict_users, dict_localtest
        return dataset_train, dataset_test, dict_users, dict_localtest

    def get_balanced_dataset(self, args):
        """Split balanced CIFAR-10 IID, then make every client long-tailed.

        Each client's IID share is trimmed with ``_trim_clients_to_long_tail``
        using imbalance factor ``args.IF``; the head class rotates with the
        client id. A class-stratified local test set is then drawn per client.

        Args:
            args: namespace read for ``gpu``, ``dataset`` (only ``'cifar10'``
                is supported), ``num_users``, ``seed`` and ``IF``.
                ``device`` and ``num_classes`` are written back.

        Returns:
            ``(dataset_train, dataset_test, dict_users, dict_localtest)`` where
            ``dict_users`` / ``dict_localtest`` map client id -> set of
            train / test indices.

        Raises:
            SystemExit: if ``args.dataset`` is not ``'cifar10'``.
        """
        self._set_device(args)
        if args.dataset != 'cifar10':
            raise SystemExit('Error: unrecognized dataset')

        data_path = './cifar_lt/'
        args.num_classes = 10
        trans_train, trans_val = self._cifar_transforms(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        dataset_train = datasets.CIFAR10(data_path, train=True, download=True, transform=trans_train)
        dataset_test = datasets.CIFAR10(data_path, train=False, download=True, transform=trans_val)
        y_train = np.array(dataset_train.targets)
        y_test = np.array(dataset_test.targets)

        dict_users = iid_sampling(len(dataset_train), args.num_users, args.seed)
        self._trim_clients_to_long_tail(
            y_train, dict_users, args.num_users, args.num_classes, args.IF)

        # Every client is summarised (previously hard-coded to 40 clients).
        _, testsizes = self._summarize_distributions(
            y_train, dict_users, args.num_users, args.num_classes,
            self._LOCAL_TEST_BUDGET['cifar10'])
        dict_localtest = self._sample_local_testsets(
            y_test, testsizes, args.num_users, args.num_classes,
            expected_per_class=1000)
        assert len(dict_users) == len(dict_localtest) == args.num_users
        return dataset_train, dataset_test, dict_users, dict_localtest

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _set_device(args):
        """Store ``cuda:<args.gpu>`` in ``args.device`` if CUDA is usable and ``args.gpu != -1``, else CPU."""
        use_cuda = torch.cuda.is_available() and args.gpu != -1
        args.device = torch.device('cuda:{}'.format(args.gpu) if use_cuda else 'cpu')

    @staticmethod
    def _cifar_transforms(mean, std):
        """Return ``(train, val)`` transforms for 32x32 CIFAR images.

        Training applies a padded random crop and horizontal flip; both
        pipelines convert to a tensor and normalise with ``mean`` / ``std``.
        """
        trans_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        trans_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        return trans_train, trans_val

    @staticmethod
    def _load_imbalanced(args):
        """Instantiate the long-tailed train/test data named by ``args.dataset``.

        Sets ``args.num_classes`` (and, for ``'inat'``, ``rootdir``,
        ``train_file``, ``im_size_train`` and ``im_size_test``).

        Returns:
            ``(dataset_train, dataset_test, n_train, y_train, y_test)`` where
            the label arrays are NumPy arrays. For ``'inat'`` the first two
            items are ``DataLoader`` objects.

        Raises:
            SystemExit: if ``args.dataset`` is not recognised.
        """
        if args.dataset == 'cifar10':
            data_path = './cifar_lt/'
            args.num_classes = 10
            trans_train, trans_val = myDataset._cifar_transforms(
                [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            dataset_train = IMBALANCECIFAR10(data_path, imb_factor=args.IF, train=True,
                                             download=True, transform=trans_train)
            dataset_test = datasets.CIFAR10(data_path, train=False, download=True, transform=trans_val)
            return (dataset_train, dataset_test, len(dataset_train),
                    np.array(dataset_train.targets), np.array(dataset_test.targets))

        if args.dataset == 'cifar100':
            data_path = './cifar_lt/'
            args.num_classes = 100
            trans_train, trans_val = myDataset._cifar_transforms(
                [0.5070751592371323, 0.48654887331495095, 0.4409178433670343],
                [0.2673342858792401, 0.2564384629170883, 0.27615047132568404])
            dataset_train = IMBALANCECIFAR100(data_path, imb_factor=args.IF, train=True,
                                              download=True, transform=trans_train)
            dataset_test = datasets.CIFAR100(data_path, train=False, download=True, transform=trans_val)
            return (dataset_train, dataset_test, len(dataset_train),
                    np.array(dataset_train.targets), np.array(dataset_test.targets))

        if args.dataset == 'imagenet':
            args.num_classes = 1000
            dataset_train = ImageNetLTDataLoader(training=True)
            dataset_test = ImageNetLTDataLoader(training=False)
            return (dataset_train, dataset_test, len(dataset_train),
                    np.array(dataset_train.dataset.targets),
                    np.array(dataset_test.dataset.targets))

        if args.dataset == 'inat':
            args.rootdir = ''
            args.train_file = ''
            args.im_size_train = 299
            args.im_size_test = 299
            args.num_classes = 1023
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            dataset_train = torch.utils.data.DataLoader(
                IGNAT_Loader(args.rootdir, args.train_file,
                             transforms.Compose([
                                 # RandomResizedCrop / Resize replace the removed
                                 # RandomSizedCrop / Scale aliases (same semantics).
                                 transforms.RandomResizedCrop(args.im_size_train),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 normalize,
                             ]), is_train=True),
                batch_size=args.local_bs, shuffle=True,
                num_workers=2, pin_memory=True)
            # NOTE(review): the test loader is also built from
            # ``args.train_file`` (with ``is_train=False``); confirm intended.
            dataset_test = torch.utils.data.DataLoader(
                IGNAT_Loader(args.rootdir, args.train_file,
                             transforms.Compose([
                                 transforms.Resize(int(args.im_size_test / 0.875)),
                                 transforms.CenterCrop(args.im_size_test),
                                 transforms.ToTensor(),
                                 normalize,
                             ]), is_train=False),
                batch_size=args.local_bs, shuffle=False,
                num_workers=2, pin_memory=True)
            return (dataset_train, dataset_test, len(dataset_train.dataset),
                    np.array(dataset_train.dataset.classes),
                    np.array(dataset_test.dataset.classes))

        raise SystemExit('Error: unrecognized dataset')

    @staticmethod
    def _client_class_counts(y, dict_idx, num_users, num_classes):
        """Return a ``(num_users, num_classes)`` array of label counts.

        Entry ``[i][j]`` is how many indices in ``dict_idx[i]`` have label
        ``j`` in ``y``.
        """
        rows = []
        for i in range(num_users):
            labels = y[list(dict_idx[i])]
            rows.append([np.sum(labels == j) for j in range(num_classes)])
        return np.array(rows)

    @staticmethod
    def _summarize_distributions(y_train, dict_users, num_users, num_classes, test_total):
        """Print per-client train counts and derived local test sizes.

        The local test size for client ``i`` and class ``j`` is
        ``int(alist[i][j] / alist.sum() * test_total)``, i.e. the client's
        share of the whole training set scaled to ``test_total``.

        Returns:
            ``(alist, testsizes)``, both ``(num_users, num_classes)`` arrays.
        """
        alist = myDataset._client_class_counts(y_train, dict_users, num_users, num_classes)
        total_train = sum(alist.sum(0))
        print("training set distribution:")
        print(alist)
        print("Total size of training set")
        print(total_train)
        testsizes = (alist / total_train * test_total).astype(int)
        print("local test distribution:")
        print(testsizes)
        print("Total size of testing set")
        print(sum(testsizes.sum(0)))
        print(testsizes.sum(0))
        return alist, testsizes

    @staticmethod
    def _sample_local_testsets(y_test, testsizes, num_users, num_classes, expected_per_class=None):
        """Draw a class-stratified local test set for every client.

        For client ``i`` and class ``j``, ``testsizes[i][j]`` indices of class
        ``j`` are drawn without replacement via ``np.random.choice`` (global
        NumPy RNG). Clients sample independently, so their sets may overlap.

        Args:
            expected_per_class: if given, assert every class has exactly this
                many test samples.

        Returns:
            ``{client_id: set_of_test_indices}``.
        """
        class_to_idxs = {c: np.flatnonzero(y_test == c) for c in range(num_classes)}
        if expected_per_class is not None:
            assert all(len(v) == expected_per_class for v in class_to_idxs.values())

        dict_localtest = {}
        for i in range(num_users):
            idxs = []
            for j in range(num_classes):
                idxs.extend(np.random.choice(class_to_idxs[j], testsizes[i][j], replace=False))
            dict_localtest[i] = set(idxs)

        # Classes are disjoint and draws are without replacement, so the
        # realised per-class counts must match the requested sizes exactly
        # (the previous ``a.all() == b.all()`` check compared two booleans).
        blist = myDataset._client_class_counts(y_test, dict_localtest, num_users, num_classes)
        assert np.array_equal(blist, testsizes)
        return dict_localtest

    @staticmethod
    def _trim_clients_to_long_tail(y_train, dict_users, num_users, num_classes, imb_factor):
        """Drop samples from ``dict_users`` in place to make each client long-tailed.

        For client ``i`` the ``k``-th class in the rotated order
        ``(i + k) % num_classes`` gets target size
        ``int(len(dict_users[i]) / num_classes * imb_factor ** (k / (num_classes - 1)))``,
        where the client size is taken before any trimming. Classes already at
        or below target are left alone; otherwise the surplus is removed,
        taking the first matching indices in the set's iteration order.
        """
        for i in range(num_users):
            img_max = len(dict_users[i]) / num_classes
            lt_sizes = [int(img_max * (imb_factor ** (k / (num_classes - 1.0))))
                        for k in range(num_classes)]
            head = i % num_classes
            for k, target in enumerate(lt_sizes):
                cur_cls = (head + k) % num_classes
                members = list(dict_users[i])
                hits = y_train[members] == cur_cls
                indices = [idx for idx, hit in zip(members, hits) if hit]
                cur_size = np.sum(hits)
                assert len(indices) == cur_size
                assert all(y_train[idx] == cur_cls for idx in indices)
                if target >= cur_size:
                    print('the current class doesnt need dropout')
                    continue
                drop = set(indices[:cur_size - target])
                # Rebuild from an order-preserving list so the resulting set
                # matches what repeated ``list.remove`` calls would produce.
                kept = [idx for idx in members if idx not in drop]
                dict_users[i] = set(kept)