
import os
import pdb
import json
import math
import time
import random
import datetime

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K

from tensorflow.keras.datasets import mnist
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.datasets import cifar100

class Cifar100Generator:
    """Generate per-client, per-task train/test/valid splits of CIFAR-100.

    The train and test portions returned by ``cifar100.load_data()`` are merged
    into one pool. The distinct labels are shuffled with a seed derived from
    ``client_id`` and cut into consecutive groups of ``num_classes_per_task``
    classes; each group is one task. Labels inside a task are re-indexed to
    ``0..k-1`` and one-hot encoded (multi-head setting).
    """

    def __init__(self, client_id, opt):
        """Load the dataset and set up the task configuration.

        Args:
            client_id: Identifier of the client; also used as the NumPy random
                seed when the class order is shuffled.
            opt: Accepted for interface compatibility; not read by this class.
        """
        self.client_id = client_id
        self.num_tasks = 10
        self.max_num_examples = -1  # -1 disables per-class sub-sampling
        self.classes_per_task = []  # filled lazily on the first task request

        # Only the first two ratios are used for slicing; the validation split
        # receives whatever remains after train and test.
        self.seprate_ratio = (0.7, 0.2, 0.1)
        self.MAX_NUM_CLASSES_PER_CLIENT = 5 # only when overlapped used

        (x_train, y_train), (x_test, y_test) = cifar100.load_data() # load dataset
        self.x = np.concatenate([x_train, x_test]).astype('float') / 255. # merge
        self.y = np.concatenate([y_train, y_test])
        self.classes = np.unique(self.y)
        # NOTE(review): this inserts a singleton axis. The images presumably
        # already carry a channel axis, so this looks like a grayscale leftover;
        # it is kept because the later reshape to (32, 32, 3) absorbs it and
        # external readers of ``self.x`` may rely on the current shape.
        if K.image_data_format() == 'channels_first':
            self.x = np.expand_dims(self.x, axis=1) # (n, c, w, h)
        else:
            self.x = np.expand_dims(self.x, axis=3) # (n, w, h, c)

        self.num_classes_per_task = math.floor(len(self.classes)/self.num_tasks)

    def get_task(self, task_id):
        """Return the data dictionary for ``task_id`` (see
        ``get_evenly_split_task_multihead``)."""
        return self.get_evenly_split_task_multihead(task_id) # self.get_evenly_split_task(task_id)

    def get_evenly_split_task_multihead(self, task_id):
        """Build the train/test/valid sets for one task.

        Args:
            task_id: Task number. The class group is looked up at index
                ``task_id - 1``, i.e. ids are treated as 1-based.
                NOTE(review): ``task_id == 0`` silently selects the last
                group via negative indexing — confirm callers never pass 0.

        Returns:
            dict with keys ``'train'``, ``'test'``, ``'valid'`` (lists of
            ``(image, one_hot_label)`` tuples, images reshaped to
            ``(32, 32, 3)``), ``'name'`` (empty string), ``'classes'`` (the
            original class ids of this task) and ``'train_size_per_class'``
            (empty list).
        """
        self.task_id = task_id
        start_time = time.time()

        if not self.classes_per_task:
            shuffled_classes = self.classes.tolist()
            # Seeding with the client id makes every process that builds the
            # same client derive the same class order. This reseeds NumPy's
            # global generator.
            np.random.seed(self.client_id) # multiprocess produce same random
            np.random.shuffle(shuffled_classes)
            step = self.num_classes_per_task
            self.classes_per_task = [
                shuffled_classes[begin:begin + step]
                for begin in range(0, len(self.classes), step)
            ]

        classes_on_task = self.classes_per_task[task_id-1]
        # Row indices are gathered class by class (in task order) so that the
        # subsequent shuffle sees the same input order for a given seed.
        idx_on_task = np.concatenate(
            [np.where(self.y[:] == c)[0] for c in classes_on_task], axis=0)
        x = self.x[idx_on_task]
        y = self.y[idx_on_task]
        self.separate_into_train_test_valid_multihead(x, y, num_classes=len(classes_on_task))

        train_set = self._as_example_pairs(self.x_train, self.y_train)
        test_set = self._as_example_pairs(self.x_test, self.y_test)
        valid_set = self._as_example_pairs(self.x_valid, self.y_valid)

        print('[%s][client:%d] data generated. (%d seconds.)' %(datetime.datetime.now().strftime("%Y%m%d-%H:%M:%S"), self.client_id, time.time()-start_time))
        return {
            'train': train_set,
            'test': test_set,
            'valid': valid_set,
            'name': '',
            'classes': classes_on_task,
            'train_size_per_class': []
        }

    @staticmethod
    def _as_example_pairs(images, labels):
        """Pair every image (reshaped to ``(32, 32, 3)``) with its label."""
        return [(np.reshape(image, (32, 32, 3)), label)
                for image, label in zip(images, labels)]

    @staticmethod
    def _to_task_local_labels(y, classes=None):
        """Map raw class ids in ``y`` to their rank among ``classes``.

        Args:
            y: Integer label array of any shape.
            classes: Sorted unique class ids present in ``y``; computed with
                ``np.unique(y)`` when omitted.

        Returns:
            A new integer array with the same shape as ``y`` whose values lie
            in ``0..len(classes)-1``. ``y`` itself is not modified.
        """
        if classes is None:
            classes = np.unique(y)
        # ``classes`` is sorted and contains every value of ``y``, so the
        # insertion point of a label is exactly its task-local index.
        return np.searchsorted(classes, y)

    def separate_into_train_test_valid_multihead(self, x, y, num_classes=None):
        """Shuffle (or sub-sample) the task data, relabel it and split it.

        Results are stored on the instance as ``x_train``/``y_train``,
        ``x_test``/``y_test`` and ``x_valid``/``y_valid``; nothing is returned.
        Reads ``self.task_id`` and ``self.client_id`` for the log line.

        Args:
            x: Samples of the task, indexed along the first axis.
            y: Raw class ids aligned with ``x``.
            num_classes: Width of the one-hot encoding. It is also multiplied
                with ``max_num_examples`` when sub-sampling is enabled, so it
                must be given in that case.
        """
        if self.max_num_examples > -1:
            # random sampling (without replacement, via Python's ``random``)
            num_examples = self.max_num_examples * num_classes
            picked = random.sample(np.arange(len(x)).tolist(), num_examples)
        else:
            # shuffle
            num_examples = x.shape[0]
            picked = np.arange(num_examples)
            np.random.shuffle(picked)
        x = x[picked]
        y = y[picked]

        # categorical labels
        classes_on_task = np.unique(y)
        print('[%s][client:%d] task:%d contains classes %s'
            %(datetime.datetime.now().strftime("%Y%m%d-%H:%M:%S"), self.client_id, self.task_id, ','.join(map(str, classes_on_task))))
        # The earlier argwhere-based relabelling produced (row, col) pairs for
        # (N, 1)-shaped labels; using them as an index also hit row 0 on every
        # pass, leaving sample 0 with the last class index. The rank lookup
        # below works for any label shape.
        y = self._to_task_local_labels(y, classes_on_task)
        y = tf.keras.utils.to_categorical(y, num_classes)

        num_train = int(num_examples*self.seprate_ratio[0]) # split according to ratio
        num_test = int(num_examples*self.seprate_ratio[1])  # split according to ratio
        test_end = num_train + num_test
        self.x_train = x[0:num_train]
        self.y_train = y[0:num_train]
        self.x_test  = x[num_train:test_end]
        self.y_test  = y[num_train:test_end]
        # Validation takes the remainder, absorbing any rounding loss.
        self.x_valid = x[test_end:]
        self.y_valid = y[test_end:]

    
