import os
import pdb
import glob
import json
import math
import time
import shutil
import random
import datetime

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K

from tensorflow.keras.datasets import mnist
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.datasets import cifar100
from data._generator import Cifar100Generator
from common.utils import *

class DataGenerator:
	"""Split the datasets in ``opt.dataset`` into task files and load them per client.

	``opt`` is an options object; the attributes read here are ``split_option``,
	``data_dir``, ``num_classes``, ``num_clients``, ``dataset``, ``manual_tasks``,
	``num_pick_tasks``, ``global_random_seed``, ``mixture_dir``,
	``mixture_filename``, ``c100_superclass`` and ``gen_num_tasks``.

	Supported ``split_option`` values:
	  * 'non_iid'    -- each task holds its own group of classes.
	  * 'overlapped' -- tasks may share classes; the instances of a shared class
	                    are divided into disjoint shares among those tasks.
	  * 'iid'        -- each client gets its own slice of all instances.
	"""

	def __init__(self, client_id, opt):
		"""Resolve the task directory and, for a real client, select its task files.

		Args:
			client_id: client index; for values <= -1 no task selection happens
				and ``self.tasks`` is not set.
			opt: options object (see the class docstring).
		"""
		self.opt = opt
		self.client_id = client_id
		self.seprate_ratio = (0.7, 0.2, 0.1) # train, test, valid
		split = self.opt.split_option
		if split in ('non_iid', 'overlapped'):
			self.base_dir = os.path.join(self.opt.data_dir, split, 'num_classes_'+str(self.opt.num_classes))
		elif split == 'iid':
			self.base_dir = os.path.join(self.opt.data_dir, split, 'num_clients_'+str(self.opt.num_clients))
		# NOTE(review): any other split_option leaves self.base_dir unset, so a
		# later generate/load call fails with AttributeError.
		self.info = {
			'num_tasks': 0,
			'datasets': ', '.join(map(get_dataset_name, self.opt.dataset))
		}
		if self.client_id > -1:
			self.tasks = []
			if len(self.opt.manual_tasks) > 0:
				self.tasks = self.opt.manual_tasks
				syslog(self.client_id, 'tasks are manually set: {}'.format(', '.join(self.opt.manual_tasks)))
			else:
				if split in ('non_iid', 'overlapped'):
					for d in self.opt.dataset:
						path = os.path.join(self.base_dir, get_dataset_name(d)+'_*')
						self.tasks += [os.path.basename(p) for p in glob.glob(path)]
					# Shuffle in the globally same order (same seed on every client), then
					# take this client's own window of num_pick_tasks task files.
					random_shuffle(self.opt.global_random_seed, self.tasks)
					offset = self.client_id*self.opt.num_pick_tasks
					self.tasks = self.tasks[offset:offset+self.opt.num_pick_tasks]
				elif split == 'iid':
					for d in self.opt.dataset:
						stem = get_dataset_name(d)+'_'+str(self.client_id)
						path = os.path.join(self.base_dir, stem+'*')
						# The glob alone would also match other clients' files,
						# e.g. 'name_1*' matches 'name_10.npy'; keep exact stems only.
						self.tasks += [os.path.basename(p) for p in glob.glob(path)
							if self._matches_stem(os.path.basename(p), stem)]
					# per-client order, e.g. client 0 -> ['fashion_mnist_0.npy', 'cifar100_0.npy', ...]
					random_shuffle(self.client_id, self.tasks)
				print(self.tasks)
			self.info['num_tasks'] = len(self.tasks)

	@staticmethod
	def _matches_stem(filename, stem):
		"""Return True when ``filename`` is ``stem`` itself or ``stem`` plus an extension."""
		return filename == stem or filename.startswith(stem + '.')

	def get_info(self):
		"""Return the info dict: 'num_tasks' and the comma-joined 'datasets' names."""
		return self.info

	def generate_data(self):
		"""Load (or build and cache) the dataset mixture, then write all task files.

		The mixture is cached at ``<mixture_dir>/saved/<mixture_filename>``.
		"""
		saved_mixture_filepath = os.path.join(self.opt.mixture_dir, 'saved', self.opt.mixture_filename)
		if os.path.exists(saved_mixture_filepath):
			syslog(self.client_id, 'loading mixture data: {}'.format(saved_mixture_filepath))
			# allow_pickle unpickles arbitrary objects: only point this at a cache
			# file written by this pipeline, never at untrusted data.
			mixture = np.load(saved_mixture_filepath, allow_pickle=True)
		else:
			syslog(self.client_id, 'downloading & processing mixture data')
			mixture = get(base_dir=self.opt.mixture_dir, fixed_order=True)
			np_save(os.path.join(self.opt.mixture_dir, 'saved'), self.opt.mixture_filename, mixture)
		self.generate_tasks(mixture)

	def generate_tasks(self, mixture):
		"""Generate task files for every dataset id in ``opt.dataset``.

		``mixture[0][did]`` must be a dict with 'train'/'test' entries holding
		'x' and 'y' arrays. CIFAR-100 superclass grouping is used when
		``opt.c100_superclass`` is truthy.
		"""
		syslog(self.client_id, 'generating tasks with the given options')
		self.task_cnt = -1
		generate = self._generate_c100_tasks if self.opt.c100_superclass else self._generate_tasks
		for did in self.opt.dataset:
			generate(did, mixture[0][did])

	def _generate_c100_tasks(self, did, data):
		"""Generate tasks whose classes are CIFAR-100 superclass groups.

		'non_iid' saves one task per superclass (20 tasks). 'overlapped' gives
		each of ``opt.num_clients`` clients 10 randomly drawn superclasses, so
		superclasses repeat across clients and their instances are shared out.
		'iid' is not supported here and produces nothing.
		"""
		x = np.concatenate([data['train']['x'], data['test']['x']])
		y = np.concatenate([data['train']['y'], data['test']['y']])
		labels_pair = self.cifar100_superclass_label_pair()
		label_map = None

		if self.opt.split_option == 'non_iid':
			labels_per_task = labels_pair
		elif self.opt.split_option == 'overlapped':
			labels_per_task = []
			for _ in range(self.opt.num_clients):
				# NOTE(review): hard-coded seed 1004, independent of global_random_seed.
				random_shuffle(1004, labels_pair)
				labels_per_task.append(labels_pair[:10])
			# flatten clients -> one list of 5-class superclass tasks
			labels_per_task = np.concatenate(labels_per_task).tolist()
			label_map = self._count_label_occurrences(labels_per_task)
			print('label_map%s'%label_map)

		if self.opt.split_option == 'iid':
			return
		self._save_class_tasks(did, x, y, labels_per_task, label_map)

	def _generate_tasks(self, did, data):
		"""Generate tasks for one dataset according to ``opt.split_option``.

		'non_iid' chunks the shuffled class list into groups of
		``opt.num_classes``. 'overlapped' samples ``opt.gen_num_tasks`` groups of
		``opt.num_classes`` classes (groups may share classes). 'iid' gives each
		of ``opt.num_clients`` clients one slice of all instances with every class.
		"""
		x = np.concatenate([data['train']['x'], data['test']['x']])
		y = np.concatenate([data['train']['y'], data['test']['y']])

		labels = np.unique(y)
		random_shuffle(self.opt.global_random_seed, labels)
		label_map = None

		if self.opt.split_option == 'non_iid':
			labels_per_task = [labels[i:i+self.opt.num_classes] for i in range(0, len(labels), self.opt.num_classes)]
		elif self.opt.split_option == 'overlapped':
			labels_per_task = [np.array(random.sample(labels.tolist(), self.opt.num_classes)) for _ in range(self.opt.gen_num_tasks)]
			label_map = self._count_label_occurrences(labels_per_task)

		if self.opt.split_option == 'iid':
			for cid in range(self.opt.num_clients):
				self.task_cnt += 1
				indices = np.arange(len(x))
				random_shuffle(self.opt.global_random_seed, indices) # globally same order
				# NOTE(review): rounding the slice width can drop (or shorten) the
				# last client's slice when len(x) is not divisible by num_clients.
				offset = round(len(indices)/self.opt.num_clients)
				idx = indices[cid*offset:(cid+1)*offset]
				print(cid, cid*offset, (cid+1)*offset)
				x_task = x[idx]
				y_task = y[idx]
				syslog(self.client_id, 'task: %d, dataset: %s, classes: %s, instances: %d'
					%(self.task_cnt, get_dataset_name(did),','.join(map(str, labels)), len(x_task)))
				self._save_task(x_task, y_task, labels, did, cid)
		else:
			self._save_class_tasks(did, x, y, labels_per_task, label_map)

	@staticmethod
	def _count_label_occurrences(labels_per_task):
		"""Return a dict mapping each class to the number of tasks that contain it."""
		label_map = {}
		for task in labels_per_task:
			for l in task:
				label_map[l] = label_map.get(l, 0) + 1
		return label_map

	@staticmethod
	def _take_class_share(idx, c, label_map, taken):
		"""Return the next disjoint share of ``idx`` (all instance indices of class ``c``).

		``label_map[c]`` is how many tasks still need a share of ``c``; it is
		decremented. ``taken[c]`` counts the instances already handed out and is
		advanced, so successive calls return consecutive, non-overlapping chunks
		and the last call receives whatever remains.
		"""
		start = taken.get(c, 0)
		remaining = idx[start:]
		share = remaining[:round(len(remaining)/label_map[c])]
		taken[c] = start + len(share)
		label_map[c] -= 1
		return share

	def _save_class_tasks(self, did, x, y, labels_per_task, label_map=None):
		"""Save one task per class group in ``labels_per_task`` ('non_iid' / 'overlapped').

		For 'non_iid' every instance of the task's classes is used. For
		'overlapped', ``label_map`` must map each class to the number of tasks
		containing it. Each such task gets its own disjoint share of that class's
		instances, and ``label_map`` is consumed in the process.
		"""
		taken = {}  # class -> instances of it already assigned ('overlapped' only)
		for t, task in enumerate(labels_per_task):
			self.task_cnt += 1
			if self.opt.split_option == 'non_iid':
				idx = np.concatenate([np.where(y[:]==c)[0] for c in task], axis=0)
			else: # overlapped
				idx = np.concatenate([self._take_class_share(np.where(y[:]==c)[0], c, label_map, taken) for c in task], axis=0)
			random_shuffle(self.opt.global_random_seed, idx)
			x_task = x[idx]
			y_task = y[idx]
			syslog(self.client_id, 'task: %d, dataset: %s, classes: %s'
				%(self.task_cnt, get_dataset_name(did),','.join(map(str, task))))
			self._save_task(x_task, y_task, task, did, t)

	def _save_task(self, x_task, y_task, labels, did, tid):
		"""Relabel, one-hot encode, split and save one task under ``self.base_dir``.

		Classes are renumbered 0..len(labels)-1 following the order of ``labels``
		(``y_task`` is modified in place before encoding). Instances are split
		in their current order by ``self.seprate_ratio``: train first, then test,
		with the remainder going to valid. The file name is ``<dataset>_<tid>``.
		"""
		# Compute every class mask before relabelling so a freshly assigned id
		# cannot be confused with an original label not yet processed.
		idx_list = [np.where(y_task[:]==c)[0] for c in labels]
		train_size_per_class = []
		for i, idx in enumerate(idx_list):
			y_task[idx] = i # reset classes id
			# NOTE: counts all instances of the class in this task, not only the train split.
			train_size_per_class.append(len(idx))
		y_task = tf.keras.utils.to_categorical(y_task, len(labels))
		pairs = list(zip(x_task, y_task))
		num_examples = len(pairs)
		num_train = int(num_examples*self.seprate_ratio[0]) # split according to ratio
		num_test = int(num_examples*self.seprate_ratio[1])  # split according to ratio

		filename = '{}_{}'.format(get_dataset_name(did), tid)
		_data = {
			'train': pairs[0:num_train],
			'test' : pairs[num_train:num_train+num_test],
			'valid': pairs[num_train+num_test:],
			'classes': labels,
			'name': filename,
			'train_size_per_class': train_size_per_class
		}
		save_task(base_dir=self.base_dir, filename=filename, data=_data)

	def get_task(self, task_id):
		"""Load task file ``self.tasks[task_id]`` from ``self.base_dir`` and return it.

		Requires a client instance (client_id > -1) so that ``self.tasks`` exists.
		``.item()`` unwraps the loaded value; presumably ``load_task`` returns a
		0-d object array holding the saved task dict — TODO confirm.
		"""
		task = load_task(self.base_dir, self.tasks[task_id])
		return task.item()

	def cifar100_superclass_label_pair(self):
		"""Return the 20 CIFAR-100 superclasses as lists of fine-label ids (5 each).

		Each fine label is placed in the superclass whose description string
		contains ``' <label>,'``. The leading space and trailing comma prevent
		partial matches such as 'man' inside 'woman'. The entry
		'computer keyboard' matches the label 'keyboard'.
		"""
		CIFAR100_LABELS_LIST = [
			'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
			'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
			'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
			'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
			'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
			'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
			'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
			'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
			'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
			'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
			'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
			'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
			'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
			'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
			'worm'
		]
		sclass = []
		sclass.append(' beaver, dolphin, otter, seal, whale,') 						#aquatic mammals
		sclass.append(' aquarium_fish, flatfish, ray, shark, trout,')				#fish
		sclass.append(' orchid, poppy, rose, sunflower, tulip,')					#flowers
		sclass.append(' bottle, bowl, can, cup, plate,')							#food
		sclass.append(' apple, mushroom, orange, pear, sweet_pepper,')				#fruit and vegetables
		sclass.append(' clock, computer keyboard, lamp, telephone, television,')	#household electrical devices
		sclass.append(' bed, chair, couch, table, wardrobe,')						#household furniture
		sclass.append(' bee, beetle, butterfly, caterpillar, cockroach,')			#insects
		sclass.append(' bear, leopard, lion, tiger, wolf,')							#large carnivores
		sclass.append(' bridge, castle, house, road, skyscraper,')					#large man-made outdoor things
		sclass.append(' cloud, forest, mountain, plain, sea,')						#large natural outdoor scenes
		sclass.append(' camel, cattle, chimpanzee, elephant, kangaroo,')			#large omnivores and herbivores
		sclass.append(' fox, porcupine, possum, raccoon, skunk,')					#medium-sized mammals
		sclass.append(' crab, lobster, snail, spider, worm,')						#non-insect invertebrates
		sclass.append(' baby, boy, girl, man, woman,')								#people
		sclass.append(' crocodile, dinosaur, lizard, snake, turtle,')				#reptiles
		sclass.append(' hamster, mouse, rabbit, shrew, squirrel,')					#small mammals
		sclass.append(' maple_tree, oak_tree, palm_tree, pine_tree, willow_tree,')	#trees
		sclass.append(' bicycle, bus, motorcycle, pickup_truck, train,')			#vehicles 1
		sclass.append(' lawn_mower, rocket, streetcar, tank, tractor,')				#vehicles 2
		labels_pair = [[jj for jj in range(100) if ' %s,'%CIFAR100_LABELS_LIST[jj] in sclass[kk]] for kk in range(20)]
		return labels_pair
