import copy
import numpy as np
import tensorflow as tf
from common.utils import *

import tensorflow.keras.layers as layers
from models.global_model import GlobalModel

class GlobalFedWeIT(GlobalModel):
	"""Server-side (global) model state for FedWeIT-style federated training.

	Holds the shared parameter tensors for one of a fixed set of base
	architectures (chosen by ``opt.base_architect``), aggregates the weights
	uploaded by clients each round, and keeps the clients' adaptive
	parameters exactly as they were returned.
	"""

	# Parameter shapes per ``opt.base_architect`` value. 4-D entries look like
	# conv kernels (kh, kw, in_ch, out_ch) and 2-D entries like dense matrices
	# (in, out) -- NOTE(review): layout inferred from the numbers; confirm
	# against the client-side model definition.
	_ARCHITECTURE_SHAPES = {
		0: [
			(5, 5, 3, 20),
			(5, 5, 20, 50),
			(3200, 800),
			(800, 500),
		],
		# A previously disabled alternative for architecture 1 replaced the
		# last two entries with (4096, 512) and (512, 512).
		1: [
			(4, 4, 3, 64),
			(3, 3, 64, 128),
			(2, 2, 128, 256),
			(1024, 1024),
			(1024, 1024),
		],
	}

	def __init__(self, opt):
		"""Select the architecture's parameter shapes and initialize weights.

		Args:
			opt: options object; this class reads ``base_architect``,
				``global_random_seed``, ``sparse_comm``, ``fed_method`` and
				``server_sparse_comm`` from it.

		Raises:
			ValueError: if ``opt.base_architect`` is not a known architecture id.
		"""
		super(GlobalFedWeIT, self).__init__(opt)
		self.opt = opt
		self.current_task = 0
		self.input_shape = (32, 32, 3)
		try:
			arch_shapes = self._ARCHITECTURE_SHAPES[self.opt.base_architect]
		except KeyError:
			# Previously an unknown id left ``self.shapes`` unset and failed
			# later with an unhelpful AttributeError.
			raise ValueError(
				"Unsupported base_architect %r; expected one of %s"
				% (self.opt.base_architect, sorted(self._ARCHITECTURE_SHAPES))
			) from None
		# Copy so per-instance changes never leak into the class-level table.
		self.shapes = list(arch_shapes)
		self.initialize_weights()

	def initialize_weights(self):
		"""Draw fresh shared weights for every shape and clear client adapts.

		Each tensor comes from a single ``VarianceScaling`` initializer seeded
		with ``opt.global_random_seed``. It is converted to nested Python lists,
		presumably so it can be pickled/serialized for clients (unconfirmed).
		"""
		self.client_adapts = []
		self.initializer = tf.keras.initializers.VarianceScaling(seed=self.opt.global_random_seed)
		# NOTE(review): in recent Keras versions a seeded initializer may return
		# identical values on every call with the same shape. That would make
		# the two (1024, 1024) layers of architecture 1 start out identical.
		# Confirm against the installed TF version.
		self.weights = [
			self.initializer(shape).numpy().tolist() for shape in self.shapes
		]

	def get_weights(self):
		"""Return the current shared (global) weights."""
		print("#### Global: get_weights")
		return self.weights

	def get_adapts(self):
		"""Return the client adaptive parameters stored by the last update."""
		print("#### Global: get_adapts")
		return self.client_adapts

	def set_weights(self, weights):
		"""Replace the shared weights with ``weights``, stored as-is with no copy."""
		print("#### Global: set_weights")
		self.weights = weights

	def set_adapts(self, client_adapts):
		"""Replace the stored client adaptive parameters, stored as-is with no copy."""
		print("#### Global: set_adapts")
		self.client_adapts = client_adapts

	def update_weights(self, responses):
		"""Aggregate one round of client uploads into the global state.

		Args:
			responses: iterable of per-client dicts. Keys read:
				``'client_both'`` (serialized object that decodes to a pair
				``(weights, adapts)``), ``'train_size'``, and, when
				``opt.sparse_comm`` is set, ``'client_masks'``.

		The aggregation rule depends on ``opt.fed_method``: 0 uses
		``apply_federated_average`` and 1 uses ``apply_federated_prox``. Both
		methods are defined outside this class, presumably on ``GlobalModel``.
		"""
		print("#### Global: update_weights")
		# SECURITY NOTE(review): ``pickle_string_to_obj`` presumably unpickles
		# client-supplied data. That is only safe if every client is trusted.
		client_both = [pickle_string_to_obj(resp['client_both']) for resp in responses]
		client_w = [cb[0] for cb in client_both]
		client_a = [cb[1] for cb in client_both]
		client_sizes = [resp['train_size'] for resp in responses]
		client_masks = [resp['client_masks'] for resp in responses] if self.opt.sparse_comm else []
		if self.opt.fed_method == 0:
			self.apply_federated_average(client_w, client_sizes, client_masks)
		elif self.opt.fed_method == 1:
			self.apply_federated_prox(client_w, client_sizes, client_masks)
		# NOTE(review): any other fed_method value silently skips aggregation.
		# Kept as-is because callers may rely on it; consider raising instead.

		if self.opt.server_sparse_comm:
			self.sparse_communication()
		self.calculate_comm_costs(self.get_weights())
		self.set_adapts(client_a)

	def get_info(self):
		"""Return the shapes, input shape and shared parameters as a dict."""
		return {
			'shapes': self.shapes,
			'input_shape': self.input_shape,
			'shared_params': self.weights
		}
