import torch
import Dataset
import model
import resnet


class Worker():
	'''
		A federated-learning client holding a local model, its own train/test
		data loaders and an SGD optimizer. A worker can be marked Byzantine
		(honest=False) and can optionally train on flipped labels.
	'''

	def __init__(self, id, dataset, initial_weights, config, honest = True, device = 'cuda'):
		'''
			id: int, client identifier (passed as client_id to the dataset)
			dataset: string, dataset name; only "femnist" is supported
			initial_weights: list of tensors, or None to keep the model's own initialization
			config: dict providing 'batch_size' and 'workers_lr'
			honest: bool, False marks the worker as Byzantine
			device: str or torch.device holding the model and batches (default 'cuda')

			Raises ValueError if the dataset is not supported.
		'''
		self.id = id
		self.honest = honest
		self.device = torch.device(device)

		# Fail fast: without a known dataset no dataloaders would exist and every
		# later training/evaluation call would crash with an AttributeError.
		if dataset.lower() != "femnist":
			raise ValueError("Unsupported dataset: {!r} (only 'femnist' is implemented)".format(dataset))

		# Datasets
		self.training_dataset = Dataset.workerFEMNIST(root='./data', train=True, download=False, transform=None, client_id=self.id)
		self.testing_dataset = Dataset.workerFEMNIST(root='./data', train=False, download=False, transform=None, client_id=self.id)

		# Dataloaders (evaluation order does not affect accuracy, so the test set is not shuffled)
		self.training_dataloader = torch.utils.data.DataLoader(self.training_dataset, batch_size=config['batch_size'], shuffle=True)
		self.testing_dataloader = torch.utils.data.DataLoader(self.testing_dataset, batch_size=config['batch_size'], shuffle=False)

		self.nb_labels = 62

		self.model = model.FemnistCNN()
		if initial_weights is not None:
			self.set_model_parameters(initial_weights)
		self.model = self.model.to(self.device)

		self.optimizer = torch.optim.SGD(self.model.parameters(), lr=config['workers_lr'], weight_decay = 1e-4)

		self.criterion = torch.nn.NLLLoss().to(self.device)

		self.label_flipping = False

		# Momentum buffer (one entry per model scalar) and its coefficient; not read
		# by any method of this class at present. Sized from the model itself so
		# construction also works when initial_weights is None.
		nb_params = sum(param.numel() for param in self.model.parameters())
		self.momentum = torch.zeros(nb_params, device=self.device)

		self.beta = 0.90

	def update_learning_rate(self, gamma):
		'''Multiply the learning rate of every optimizer parameter group by gamma.'''
		for group in self.optimizer.param_groups:
			group['lr'] *= gamma

	def enable_label_flipping(self):
		'''From now on, training batches use label (nb_labels - 1 - y) instead of y.'''
		self.label_flipping = True

	def get_id(self):
		return self.id

	def is_honest(self):
		return self.honest

	def is_byzantine(self):
		return not self.honest

	def flatten(self, list_of_tensor):
		'''Concatenate the tensors of list_of_tensor into a single 1-D tensor.'''
		return torch.cat(tuple(tensor.view(-1) for tensor in list_of_tensor))

	def unflatten(self, flat_tensor, list_of_tensor):
		'''
			Split flat_tensor into tensors shaped like those of list_of_tensor.

			The returned tensors are detached views sharing storage with
			flat_tensor; trailing entries of flat_tensor beyond the total
			size of list_of_tensor are ignored.
		'''
		returned_list = []
		offset = 0
		for tensor in list_of_tensor:
			count = tensor.numel()
			returned_list.append(flat_tensor[offset:offset + count].view(tensor.shape).detach())
			offset += count
		return returned_list

	def set_model_parameters(self, initial_weights):
		'''
			initial_weights: list of tensors, in the order of self.model.parameters()
		'''
		for j, param in enumerate(self.model.parameters()):
			param.data = initial_weights[j].data.clone().detach()

	def set_model_gradient(self, flat_gradient):
		'''
			Overwrite the gradient of every model parameter from a flat tensor.

			flat_gradient: flat tensor, in the order of self.model.parameters()
		'''
		params = list(self.model.parameters())
		# Shapes come from the parameters rather than param.grad, so this also
		# works before any backward pass (when param.grad is still None).
		gradients = self.unflatten(flat_gradient, params)

		for param, gradient in zip(params, gradients):
			new_grad = gradient.clone().to(param)
			if param.grad is None:
				param.grad = new_grad
			else:
				param.grad.data = new_grad

	def get_model_parameters(self):
		'''Return the model's parameters as a list (the live tensors, not copies).'''
		return list(self.model.parameters())

	def get_flatten_model_parameters(self):
		'''Return all model parameters concatenated into one 1-D tensor.'''
		return self.flatten(self.get_model_parameters())

	def _sample_batch(self):
		'''Draw one training batch on self.device, applying label flipping when enabled.'''
		# A fresh iterator is created on every call, so each call returns the
		# first batch of a new pass over the training dataloader.
		inputs, targets = next(iter(self.training_dataloader))
		inputs, targets = inputs.to(self.device), targets.to(self.device)

		if self.label_flipping:
			# Label-flipping attack: y -> (nb_labels - 1) - y.
			targets = (self.nb_labels - 1) - targets

		return inputs, targets

	def compute_gradient(self):
		'''
			Compute the loss gradient on one training batch without stepping.

			Returns the gradient of all parameters as one flat tensor; the
			gradients are also left in param.grad for a later optimizer_step().
		'''
		self.model.train()
		self.optimizer.zero_grad()

		inputs, targets = self._sample_batch()
		loss = self.criterion(self.model(inputs), targets)
		loss.backward()

		return self.flatten([param.grad.data for param in self.model.parameters()])

	def optimizer_step(self):
		self.optimizer.step()

	def do_local_steps(self, nb_local_steps):
		'''
			Run nb_local_steps SGD steps, each on a freshly sampled batch.

			Returns (parameters before - parameters after) as a detached flat tensor.
		'''
		old_model = self.get_flatten_model_parameters().detach()

		self.model.train()

		for _ in range(nb_local_steps):
			self.optimizer.zero_grad()
			inputs, targets = self._sample_batch()
			loss = self.criterion(self.model(inputs), targets)
			loss.backward()
			self.optimizer.step()

		return old_model - self.get_flatten_model_parameters().detach()

	def evaluate(self):
		'''
			Return the top-1 accuracy (in percent) on the local test set.

			The model runs in eval mode for the pass; its previous train/eval
			mode is restored afterwards. Raises ValueError if the test set is empty.
		'''
		was_training = self.model.training
		self.model.eval()
		total = 0
		correct = 0
		try:
			with torch.no_grad():
				for inputs, targets in self.testing_dataloader:
					inputs, targets = inputs.to(self.device), targets.to(self.device)
					outputs = self.model(inputs)
					_, predicted = torch.max(outputs, 1)
					total += targets.size(0)
					correct += (predicted == targets).sum().item()
		finally:
			self.model.train(was_training)

		if total == 0:
			raise ValueError("Worker {} has an empty testing dataset".format(self.id))
		return 100*(correct/total)










