import os, sys
import numpy as np
import torch
from torchvision import datasets, transforms
from sklearn.utils import shuffle
from torchvision import datasets,transforms
import json
from torch.utils.data import Dataset
from sklearn.utils import shuffle
from PIL import Image

def get(seed=0, fixed_order=False, pc_valid=0, sim_ntasks=10, dataset_size='small'):
    """Build the CelebA task sequence described by a pre-generated order file.

    The order is read from the file ``'mixceleba_random_' + str(sim_ntasks * 2)``
    in the current working directory. The whitespace-separated names on its
    fourth line are regrouped so that every name containing ``'celeba'`` comes
    before every name containing ``'cifar100'``; names matching neither are
    dropped. Tasks are then assigned consecutive ids in that order, stopping
    after ``sim_ntasks`` names or at the first ``'cifar100'`` name, so only
    CelebA tasks are ever returned.

    Args:
        seed: Forwarded to ``read_celeba``.
        fixed_order: Accepted for interface compatibility; not used here.
        pc_valid: Accepted for interface compatibility. It is NOT forwarded,
            so ``read_celeba`` runs with its own ``pc_valid`` default.
        sim_ntasks: Number of CelebA tasks to load; also determines the name
            of the order file.
        dataset_size: Forwarded to ``read_celeba`` as ``data_size``.

    Returns:
        ``(data, taskcla, size)`` where ``data[task_id]`` is the per-task dict
        produced by ``read_celeba``, ``taskcla`` is a list of
        ``(task_id, ncla)`` tuples and ``size`` is ``[3, 32, 32]``.

    Raises:
        FileNotFoundError: If the order file does not exist.
        IndexError: If the order file has fewer than four lines.
        ValueError: If a CelebA name from the order file is not among the
            task names returned by ``read_celeba``.
    """
    size = [3, 32, 32]

    data_celeba, _, _ = read_celeba(seed=seed, sim_ntasks=sim_ntasks, data_size=dataset_size)
    all_celeba = [data_celeba[x]['name'] for x in range(sim_ntasks)]

    f_name = 'mixceleba_random_' + str(sim_ntasks * 2)
    with open(f_name, 'r') as f_random_seq:
        random_sep = f_random_seq.readlines()[3].split()
    print(random_sep)

    # Regroup the sequence: CelebA names first, CIFAR-100 names afterwards.
    # A name containing both substrings is treated as a CelebA name here,
    # and names containing neither are discarded.
    celeba_names = [name for name in random_sep if 'celeba' in name]
    cifar100_names = [
        name for name in random_sep
        if 'celeba' not in name and 'cifar100' in name
    ]
    random_sep = celeba_names + cifar100_names
    print(random_sep)

    data = {}
    taskcla = []
    # Slicing (instead of indexing up to sim_ntasks) keeps a short order file
    # from raising IndexError; the loop simply ends with fewer tasks.
    for task_id, name in enumerate(random_sep[:sim_ntasks]):
        if 'cifar100' in name:
            # CIFAR-100 tasks are not loaded by this module; stop at the first one.
            break
        # Every remaining name contains 'celeba' by construction of random_sep.
        celeba_id = all_celeba.index(name)
        print(celeba_id)
        data[task_id] = data_celeba[celeba_id]
        taskcla.append((task_id, data_celeba[celeba_id]['ncla']))

    print(taskcla)
    return data, taskcla, size

def read_celeba(seed=0,pc_valid=0.10,sim_ntasks=10, data_size='full'):
    """Load per-user CelebA tasks, building the on-disk tensor cache on first use.

    When the cache directory ``./data/<data_type>_binary_celeba/<sim_ntasks>``
    does not exist, the raw data is read through ``CELEBATrain`` /
    ``CELEBATest``, the ``sim_ntasks`` users with the most training images are
    kept (one task per user), and each task's train/test tensors are saved
    there. Afterwards (and on every later call) the tensors are read back from
    that directory, the task order is shuffled with ``seed`` and a validation
    split is carved out of every training set.

    Args:
        seed: ``random_state`` for both the task-order shuffle and the
            train/validation split.
        pc_valid: Fraction of each task's training samples moved to the
            validation split; at least one sample is always moved.
        sim_ntasks: Number of tasks (users) to build and load.
        data_size: String containing ``'small'`` or ``'full'`` (``'small'`` is
            checked first); selects the raw-data and cache directories.

    Returns:
        ``(data, taskcla, size)``. ``data[i]`` for ``i in range(sim_ntasks)``
        is a dict with keys ``'name'`` (``'celeba-<cache index>'``), ``'ncla'``
        (number of distinct training labels, counted before the validation
        split), ``'train'``, ``'valid'`` and ``'test'`` (each a dict with
        tensors ``'x'`` and ``'y'``). ``data['ncla']`` holds the sum of all
        per-task ``ncla`` values. ``taskcla`` is a list of ``(i, ncla)``
        tuples and ``size`` is ``[3, 32, 32]``.

    Raises:
        ValueError: If ``data_size`` contains neither ``'small'`` nor ``'full'``.
    """
    # Validate first so a bad argument fails clearly and before any I/O.
    if 'small' in data_size:
        data_type = 'small'
    elif 'full' in data_size:
        data_type = 'full'
    else:
        raise ValueError(
            "data_size must contain 'small' or 'full', got {!r}".format(data_size)
        )

    size = [3, 32, 32]
    num_task = sim_ntasks
    splits = ['train', 'test']
    cache_dir = os.path.expanduser('./data/' + data_type + '_binary_celeba/' + str(num_task))

    def cache_file(index, split, kind):
        # e.g. <cache_dir>/data3trainx.bin
        return os.path.join(cache_dir, 'data' + str(index) + split + kind + '.bin')

    if not os.path.isdir(cache_dir):
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
        transform = transforms.Compose([
            transforms.Resize(size=(32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        img_dir = './data/celeba/data/raw/img_align_celeba/'

        dat = {
            'train': CELEBATrain(root_dir='./data/celeba/' + data_type + '/iid/train/',
                                 img_dir=img_dir, transform=transform),
            'test': CELEBATest(root_dir='./data/celeba/' + data_type + '/iid/test/',
                               img_dir=img_dir, transform=transform),
        }

        # Count how many training images each user has.
        user_img_num = {}
        for user, _, _ in torch.utils.data.DataLoader(dat['train'], batch_size=1, shuffle=True):
            user_img_num[user[0]] = user_img_num.get(user[0], 0) + 1

        # Sort by image count (largest first) and keep the top num_task user ids.
        user_img_num_sorted = sorted(user_img_num.items(), key=lambda x: x[1], reverse=True)
        users = [x[0] for x in user_img_num_sorted[:num_task]]
        print('users: ', users)
        print('users length: ', len(users))

        # One task per selected user; the dict avoids a linear search per sample.
        task_of_user = {user: task_id for task_id, user in enumerate(users)}
        collected = {
            task_id: {s: {'x': [], 'y': []} for s in splits}
            for task_id in range(len(users))
        }
        for s in splits:
            loader = torch.utils.data.DataLoader(dat[s], batch_size=1, shuffle=True)
            for user, image, target in loader:
                task_id = task_of_user.get(user[0])
                if task_id is None:
                    continue  # sample belongs to a user that was not selected
                collected[task_id][s]['x'].append(image)
                collected[task_id][s]['y'].append(target.numpy()[0])

        # "Unify" and save. The directory is created only now, so a failure
        # while reading the raw data does not leave an empty cache directory
        # that later calls would mistake for a finished cache.
        os.makedirs(cache_dir)
        for n in range(len(users)):
            for s in splits:
                x = torch.stack(collected[n][s]['x']).view(-1, size[0], size[1], size[2])
                y = torch.LongTensor(np.array(collected[n][s]['y'], dtype=int)).view(-1)
                torch.save(x, cache_file(n, s, 'x'))
                torch.save(y, cache_file(n, s, 'y'))

    # Load binary files; task i is read from cache index ids[i].
    data = {}
    ids = list(shuffle(np.arange(num_task), random_state=seed))
    print('Task order =', ids)
    for i in range(num_task):
        task = {}
        for s in splits:
            task[s] = {
                'x': torch.load(cache_file(ids[i], s, 'x')),
                'y': torch.load(cache_file(ids[i], s, 'y')),
            }
        task['ncla'] = len(np.unique(task['train']['y'].numpy()))
        task['name'] = 'celeba-' + str(ids[i])
        data[i] = task

    # Validation: move a seeded random subset of every training set to 'valid'.
    for t in data.keys():
        r = np.arange(data[t]['train']['x'].size(0))
        r = np.array(shuffle(r, random_state=seed), dtype=int)
        print('len r: ', len(r))
        nvalid = max(int(pc_valid * len(r)), 1)  # always hold out at least one sample
        ivalid = torch.LongTensor(r[:nvalid])
        print('ivalid: ', len(ivalid))
        itrain = torch.LongTensor(r[nvalid:])
        data[t]['valid'] = {
            'x': data[t]['train']['x'][ivalid].clone(),
            'y': data[t]['train']['y'][ivalid].clone(),
        }
        data[t]['train']['x'] = data[t]['train']['x'][itrain].clone()
        data[t]['train']['y'] = data[t]['train']['y'][itrain].clone()

    # Others: per-task class counts plus their total under the 'ncla' key.
    taskcla = []
    n = 0
    for t in data.keys():
        taskcla.append((t, data[t]['ncla']))
        n += data[t]['ncla']
    data['ncla'] = n

    return data, taskcla, size


########################################################################################################################



# customize dataset class

class _CelebaUserDataset(Dataset):
    """Per-user CelebA images loaded eagerly from JSON shard files.

    Every file in ``root_dir`` is parsed as JSON with a top-level
    ``'user_data'`` mapping of ``user id -> {'x': [...], 'y': [...]}``, where
    ``'x'`` lists image file names relative to ``img_dir`` and ``'y'`` lists
    the matching labels. All images are read into memory at construction time.

    Attributes set by ``__init__``:
        transform: Optional callable applied to the PIL image in ``__getitem__``.
        size: Expected per-image layout ``[218, 178, 3]``.
        x: Tensor of all images, viewed as ``(-1, 218, 178, 3)``.
        y: 1-D int64 numpy array of labels.
        user: List with one user id per image, in the same order as ``x``.
    """

    def __init__(self, root_dir,img_dir, transform=None):
        self.transform = transform
        self.size=[218, 178, 3]

        images = []
        labels = []
        self.user = []
        for file_name in os.listdir(root_dir):
            # os.path.join works whether or not root_dir ends with a separator.
            with open(os.path.join(root_dir, file_name)) as json_file:
                shard = json.load(json_file)
            for user_id, record in shard['user_data'].items():
                img_names = record.get('x', [])
                for img_name in img_names:
                    # np.array copies the pixels, so the file can be closed right away.
                    with Image.open(os.path.join(img_dir, img_name)) as im:
                        images.append(torch.from_numpy(np.array(im)))
                labels.extend(record.get('y', []))
                # One user entry per image, independent of the key order in the record.
                self.user.extend([user_id] * len(img_names))

        if images:
            self.x = torch.stack(images).view(-1, *self.size)
        else:
            # No images found: expose an empty dataset instead of failing in torch.
            self.x = torch.empty((0, *self.size), dtype=torch.uint8)
        self.y = np.asarray(labels, dtype=np.int64).reshape(-1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(user, image, label)`` for sample ``idx``.

        The image is converted to a PIL image and, when a transform was
        given, passed through it.
        """
        user = self.user[idx]
        y = self.y[idx]

        x = Image.fromarray(self.x[idx].numpy())
        if self.transform:
            x = self.transform(x)
        return user,x,y


class CELEBATrain(_CelebaUserDataset):
    """CelebA dataset for the training split (the split is chosen by ``root_dir``)."""


class CELEBATest(_CelebaUserDataset):
    """CelebA dataset for the test split (the split is chosen by ``root_dir``)."""
    
if __name__ == "__main__":
    # Smoke test: load the task sequence and print the tensor sizes of every task.
    data, taskcla, inputsize = get(seed=0)
    banner = '*' * 100
    for t, ncla in taskcla:
        print(banner)
        print('Task {:2d} ({:s})'.format(t, data[t]['name']))
        print(banner)
        xtrain = data[t]['train']['x']
        xvalid = data[t]['valid']['x']
        xtest = data[t]['test']['x']
        # Each size is printed under the name of the split it belongs to
        # (the validation size used to be labelled 'test').
        print('  Input size:', inputsize, 'train', xtrain.size(),
              'valid', xvalid.size(), 'test', xtest.size(), 'ncla', ncla)

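# NOTE(review): legacy debugging snippet, kept commented out. Its single-channel
# mean/std and ncla=47 look like leftovers from an MNIST/EMNIST loader rather
# than CelebA settings -- confirm before re-enabling.
# if __name__ == '__main__':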
#     mean=(0.1307,) # Mean and std including the padding
#     std=(0.3081,)
#     data={}
#     data['train']= CELEBATrain(root_dir='./data/celeba/small/iid/train/',img_dir='./data/celeba/data/raw/img_align_celeba/',transform=transforms.Compose([transforms.Resize(size=(32,32)),transforms.ToTensor(),transforms.Normalize(mean,std)]))
#     data['test']= CELEBATest(root_dir='./data/celeba/small/iid/test/',img_dir='./data/celeba/data/raw/img_align_celeba/',transform=transforms.Compose([transforms.Resize(size=(32,32)),transforms.ToTensor(),transforms.Normalize(mean,std)]))
    
#     data['name']='celeba'
#     data['ncla']=47
#     class_num = set()
#     for s in ['train','test']:
#         loader=torch.utils.data.DataLoader(data[s],batch_size=1,shuffle=False)
#         data[s]={'x': [],'y': []}
#         for user,image,target in loader:
#             print(target)
#             data[s]['x'].append(image)
#             data[s]['y'].append(target)
#             class_num.add(target.numpy()[0])
#             print(image.shape)

#         print('='*20)
#     print(len(class_num))
