import os
import json
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K

import datetime
from common.utils import *

# Module-level integer constants; none of them is referenced in the visible part of this file.
# NOTE(review): the values 1/2/4 look like distinct bit flags (e.g. for selecting
# which tasks to report on) — confirm against the callers before relying on that.
TASK1 = 1
CURRENT_TASK = 2
CUMULATIVE = 4

class LocalModel(object):
	"""Client-side base model: bookkeeping plus the train / validate / test loops.

	Subclasses must implement ``loss``, ``get_optimizer``, ``get_weights``,
	``set_weights``, ``initialize``, ``build_model`` and ``init_on_new_task``.
	The loops also read ``self.models`` and ``self.trainable_variables``, which
	are not assigned in this class, and ``self.optimizer``, which this class
	only re-creates when the learning rate is dropped.
	"""

	def __init__(self, client_id, data_info, opt):
		"""Limit memory on the first GPU (if any) and reset all bookkeeping.

		Args:
			client_id: identifier used in log lines and in the result filename.
			data_info: stored as-is and written out together with the results.
			opt: options object whose attributes are read throughout the class.
		"""
		gpus = tf.config.experimental.list_physical_devices('GPU')
		if gpus:
			# Expose the first GPU as a single virtual device with a memory cap.
			tf.config.experimental.set_virtual_device_configuration(
				gpus[0],
				[tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024*2.5)])

		self.opt = opt
		self.client_id = client_id
		self.data_info = data_info

		# Per-task bookkeeping; the lists grow by one entry per set_task() call.
		self.log_dir = ''
		self.tasks = []
		self.mem_ratio = []
		self.comm_ratio = []
		self.num_train_list = []
		self.num_test_list = []
		self.num_valid_list = []
		self.x_test_list = []
		self.y_test_list = []
		self.num_classes = []

		# Test accuracies keyed by task index, filled by evaluate().
		self.performance_epoch = {}
		self.performance_watch = {}
		self.performance_final = {}

		self.options = {}
		self.early_stop = False
		self.test_batch_step = 10
		self.filename = 'client-'+str(self.client_id)+get_setting(opt)+'.txt'
		# Running metrics; measure_performance() reads and resets them in pairs.
		self.metrics = {
			'valid_lss': tf.keras.metrics.Mean(name='valid_lss'),
			'train_lss': tf.keras.metrics.Mean(name='train_lss'),
			'test_lss' : tf.keras.metrics.Mean(name='test_lss'),
			'valid_acc': tf.keras.metrics.CategoricalAccuracy(name='valid_acc'),
			'train_acc': tf.keras.metrics.CategoricalAccuracy(name='train_acc'),
			'test_acc' : tf.keras.metrics.CategoricalAccuracy(name='test_acc'),
		}

	def loss(self):
		"""Compute the training loss; must be implemented in the child class.

		train_step() and test_step() call it as ``self.loss(model, y, y_pred)``.
		"""
		raise NotImplementedError()

	def get_optimizer(self, lr):
		"""Build an optimizer for learning rate ``lr``; must be implemented in the child class."""
		raise NotImplementedError()

	def get_weights(self):
		"""Must be implemented in the child class."""
		raise NotImplementedError()

	def set_weights(self):
		"""Must be implemented in the child class."""
		raise NotImplementedError()

	def initialize(self, model_info):
		"""Must be implemented in the child class."""
		raise NotImplementedError()

	def build_model(self, model_info):
		"""Must be implemented in the child class."""
		raise NotImplementedError()

	def init_on_new_task(self):
		"""Hook run by set_task() for every task id > 0; must be implemented in the child class."""
		raise NotImplementedError()

	@staticmethod
	def _scale_gradients(gradients, factor):
		"""Return a new list with every non-None gradient multiplied by ``factor``."""
		return [grad if grad is None else grad * factor for grad in gradients]

	def train_step(self, model, x, y):
		"""Run one optimisation step on a single batch.

		Args:
			model: callable applied to ``x`` to obtain predictions.
			x: batch of inputs.
			y: batch of targets.

		Returns:
			Tuple ``(y_pred, loss)`` for the batch.
		"""
		tf.keras.backend.set_learning_phase(1)
		with tf.GradientTape() as tape:
			y_pred = model(x)
			loss = self.loss(model, y, y_pred)
		gradients = tape.gradient(loss, self.trainable_variables)
		if self.opt.model == 1:
			# Scale each gradient by 10 (from the original source). This has to be
			# element-wise: ``gradients *= 10`` on the list only repeats the list,
			# and zip() below would then silently drop the repeats again.
			gradients = self._scale_gradients(gradients, 10)
		self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
		return y_pred, loss

	def _train_one_epoch(self, model):
		"""Feed the training set to train_step() in mini-batches, honouring opt.num_examples."""
		batch_size = self.opt.batch_size
		max_examples = self.opt.num_examples
		for start in range(0, len(self.x_train), batch_size):
			self.current_batch += 1
			stop = start + batch_size
			# A positive num_examples caps how many samples are used per epoch:
			# the batch that reaches the cap is truncated and ends the epoch.
			reached_cap = max_examples > 0 and stop >= max_examples
			if reached_cap:
				stop = max_examples
			self.train_step(model, self.x_train[start:stop], self.y_train[start:stop])
			if reached_cap:
				break

	def _adapt_learning_rate(self, vlss):
		"""Update the patience-based LR schedule; return True when training must stop early."""
		if vlss < self.lowest_lss:
			self.lowest_lss = vlss
			self.current_lr_patience = self.opt.lr_patience
			return False

		self.current_lr_patience -= 1
		if self.current_lr_patience > 0:
			return False

		# Patience exhausted: drop the learning rate.
		self.current_lr /= self.opt.lr_factor
		syslog(self.client_id, 'task:%d, round:%d (cnt:%d), drop lr => %.10f'
				%(self.current_task, self.current_round, self.count_rounds, self.current_lr))
		if self.current_lr < self.opt.lr_min:
			syslog(self.client_id, 'task:%d, round:%d (cnt:%d), early stop, reached minium lr (%.10f)'
				 %(self.current_task, self.current_round, self.count_rounds, self.current_lr))
			self.early_stop = True
			return True
		self.current_lr_patience = self.opt.lr_patience
		self.optimizer = self.get_optimizer(self.current_lr)
		return False

	def train_one_round(self, current_round, count_rounds, is_last=False):
		"""Train the current task's model for ``opt.num_epochs`` epochs.

		After every epoch the model is validated, the learning-rate schedule is
		updated and all known tasks are evaluated. When the learning rate falls
		below ``opt.lr_min`` the round is aborted, ``self.early_stop`` is set and
		the evaluation of that epoch is skipped.

		Args:
			current_round: round index, stored and used in log lines.
			count_rounds: round counter, stored and used in log lines.
			is_last: when True, the evaluation after the final epoch is also
				recorded in ``performance_final``.
		"""
		self.current_round = current_round
		self.count_rounds = count_rounds
		model = self.models[self.current_task]
		last_epoch = self.opt.num_epochs - 1
		for epoch in range(self.opt.num_epochs):
			self.current_epoch = epoch
			self.current_batch = 0
			self._train_one_epoch(model)

			# validation (accuracy is not needed for the schedule)
			vlss, _ = self.validate(self.current_task)
			if self._adapt_learning_rate(vlss):
				break

			# Epoch-level evaluation
			self.evaluate(is_last=(is_last and epoch == last_epoch))

	def test_step(self, model, x, y):
		"""Run ``model`` on one batch in inference phase and return ``(y_pred, loss)``."""
		tf.keras.backend.set_learning_phase(0)
		y_pred = model(x)
		loss = self.loss(model, y, y_pred)
		return y_pred, loss

	def validate(self, head_idx):
		"""Return ``(loss, accuracy)`` of the model for ``head_idx`` on the validation set."""
		model = self.get_model(head_idx)
		batch_size = self.opt.batch_size
		for start in range(0, len(self.x_valid), batch_size):
			x_batch = self.x_valid[start:start+batch_size]
			y_batch = self.y_valid[start:start+batch_size]
			y_pred, loss = self.test_step(model, x_batch, y_batch)
			self.add_performance('valid_lss', 'valid_acc', loss, y_batch, y_pred)
		return self.measure_performance('valid_lss', 'valid_acc')

	def evaluate(self, is_last=False):
		"""Evaluate every stored test set with its own head and record the accuracies.

		Accuracies always go to ``performance_watch``; the current task's accuracy
		also goes to ``performance_epoch``; with ``is_last`` every accuracy is
		additionally appended to ``performance_final``.
		"""
		for tid, (x_test, y_test) in enumerate(zip(self.x_test_list, self.y_test_list)):
			lss, acc = self._evaluate(x_test, y_test, head_idx=tid)
			self.performance_watch.setdefault(tid, []).append(acc)
			if tid == self.current_task:
				self.performance_epoch.setdefault(tid, []).append(acc)
			if is_last:
				self.performance_final.setdefault(tid, []).append(acc)
			syslog(self.client_id, 'task:%d (%s), round:%d (cnt:%d), epoch:%d, test_lss:%f, test_acc:%f (task_%d)'
				 %(self.current_task, self.tasks[tid], self.current_round, self.count_rounds, self.current_epoch, lss, acc, tid))

	def _evaluate(self, x_test, y_test, head_idx):
		"""Return ``(loss, accuracy)`` of the model for ``head_idx`` on the given test data."""
		model = self.get_model(head_idx, test=True)
		batch_size = self.opt.batch_size
		for start in range(0, len(x_test), batch_size):
			x_batch = x_test[start:start+batch_size]
			y_batch = y_test[start:start+batch_size]
			y_pred, loss = self.test_step(model, x_batch, y_batch)
			self.add_performance('test_lss', 'test_acc', loss, y_batch, y_pred)
		return self.measure_performance('test_lss', 'test_acc')

	def add_performance(self, lss_name, acc_name, loss, y_true, y_pred,):
		"""Accumulate one batch into the loss and accuracy metrics named by the keys."""
		self.metrics[lss_name](loss)
		self.metrics[acc_name](y_true, y_pred)

	def measure_performance(self, lss_name, acc_name):
		"""Return the accumulated ``(loss, accuracy)`` as floats and reset both metrics."""
		lss_metric = self.metrics[lss_name]
		acc_metric = self.metrics[acc_name]
		lss = float(lss_metric.result()) # tensor to float
		acc = float(acc_metric.result()) # tensor to float
		lss_metric.reset_states()
		acc_metric.reset_states()
		return lss, acc

	def set_task(self, task_id, data):
		"""Switch to a new task: load its data and reset the LR / early-stop state.

		Args:
			task_id: index of the task; for ids > 0 ``init_on_new_task`` is called.
			data: dict with ``'train'``, ``'test'`` and ``'valid'`` (iterables of
				pairs whose first item is the input and second the target),
				``'name'``, ``'classes'`` and ``'train_size_per_class'``.
		"""
		train = data['train']
		test = data['test']
		valid = data['valid']

		self.x_train = np.array([pair[0] for pair in train])
		self.y_train = np.array([pair[1] for pair in train])
		self.x_test = np.array([pair[0] for pair in test])
		self.y_test = np.array([pair[1] for pair in test])
		self.x_valid = np.array([pair[0] for pair in valid])
		self.y_valid = np.array([pair[1] for pair in valid])

		# Test sets of all tasks seen so far are kept for evaluate().
		self.x_test_list.append(self.x_test)
		self.y_test_list.append(self.y_test)

		# Every task starts with a fresh learning-rate schedule.
		self.early_stop = False
		self.lowest_lss = np.inf
		self.current_lr = self.opt.lr
		self.current_lr_patience = self.opt.lr_patience

		self.current_task = task_id
		self.tasks.append(data['name'])
		self.num_classes.append(len(data['classes']))

		self.train_size_per_class = data['train_size_per_class']
		self.num_train_list.append(len(self.x_train))
		self.num_test_list.append(len(self.x_test))
		self.num_valid_list.append(len(self.x_valid))

		if self.current_task > 0:
			self.init_on_new_task()

	def get_model(self, head_idx, test=False):
		"""Return the model stored for ``head_idx``; ``test`` is accepted but unused here."""
		return self.models[head_idx]

	def get_test_size(self, task_idx):
		"""Return the test-set size of ``task_idx``, or the total over all tasks for -1."""
		if task_idx != -1:
			return len(self.x_test_list[task_idx])
		return np.sum([len(x) for x in self.x_test_list])

	def  get_options(self):
		"""Return the attributes of ``self.opt`` as a dict, cached in ``self.options``."""
		if not self.options:
			self.options = dict(vars(self.opt))
		return self.options

	def write_current_performances(self):
		"""Write the collected results and settings to ``self.filename`` in ``self.log_dir``."""
		write_file(self.log_dir, self.filename, {
			'client_id'         : self.client_id,
			'performance'       : self.performance_epoch,
			'performance_watch' : self.performance_watch,
			'performance_final' : self.performance_final,
			'mem_ratio'         : self.mem_ratio,
			'comm_ratio'        : self.comm_ratio,
			'options'           : self.get_options(),
			'data_info'         : self.data_info,
			'task_info'         : self.tasks,
			'num_examples'      : {
				'train': self.num_train_list,
				'test': self.num_test_list,
				'valid': self.num_valid_list,
			}
		})
