import torch
import torchvision
import torchvision.transforms as transforms
import Dataset
import copy
import torch.multiprocessing as mp
import model
import numpy as np
import time
from worker import Worker
from server import Server
import random
import attacks
import sys
import aggregators
import visualization
from numpy import savetxt
import os
import json

sys.path.append('../ByzLibrary')

from robust_aggregators import RobustAggregator
from byz_attacks import ByzantineAttack


random_seed = 1

# Seed every random number generator this script imports (torch, the stdlib
# `random` module and numpy) with the same value.
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
	_seed_rng(random_seed)
del _seed_rng

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Experiment grid. List-valued entries are swept over by the nested loops
# further down (one experiment per combination); 'milestones', 'device',
# 'nb_run' and 'seed' are shared by every combination.
config = dict(
	T=[500],                          # number of training rounds
	n=[150],                          # total number of workers
	f=[0, 30],                        # workers [n - f, n) are created as non-honest
	n_sampled=list(range(1, 57, 5)),  # workers sampled per round: 1, 6, ..., 56
	server_lr=[1],
	workers_lr=[0.1],
	milestones=[300, 400, 460],       # rounds at which learning rates are scaled by gamma
	gamma=[0.2],
	attack=['SF'],
	agg=['trmean'],
	batch_size=[8],
	nb_local_steps=[10],
	device=device,
	nb_run=10,                        # independent repetitions per combination
	seed=random_seed,
)

# Root folder that receives one sub-folder per experiment configuration.
save_folder = './save_fig2'

# exist_ok=True makes the call a no-op when the folder is already there,
# without the check-then-create race of isdir() followed by mkdir().
os.makedirs(save_folder, exist_ok=True)

# Main experiment driver.
#
# For every combination of the list-valued entries of `config`, run
# `config['nb_run']` independent trainings on the 'femnist' dataset. In each
# round a random subset of workers is sampled, the honest ones run local steps
# from the current server model, Byzantine vectors are generated for the
# sampled Byzantine workers, everything is robustly aggregated and applied on
# the server. The mean accuracy over honest workers is recorded after every
# round and written to one CSV file per run.

# Local import so this section stays self-contained: the grid is flattened
# with itertools.product instead of eleven nested loops.
import itertools

# Swept hyper-parameters, in the order the nested loops iterated them
# (itertools.product varies the last one fastest, like the innermost loop).
grid_keys = ['T', 'n', 'f', 'n_sampled', 'server_lr', 'workers_lr', 'gamma', 'attack', 'agg', 'batch_size', 'nb_local_steps']

for (T, n, f, n_sampled, server_lr, workers_lr, gamma, attack, agg,
		batch_size, nb_local_steps) in itertools.product(*(config[key] for key in grid_keys)):

	# Scalar configuration of this single experiment. Key order matters: it
	# defines the folder name built below.
	config_tmp = {
		'T' : T,
		'n' : n,
		'f': f,
		'n_sampled': n_sampled,
		'server_lr' : server_lr,
		'workers_lr' : workers_lr,
		'milestones' : config['milestones'],
		'gamma' : gamma,
		'attack' : attack,
		'agg' : agg,
		'batch_size' : batch_size,
		'nb_local_steps' : nb_local_steps,
		'device' : device,
		'nb_run' : config['nb_run'],
		'seed' : random_seed
	}

	# The folder name encodes the full configuration as key_value pairs.
	title = '_'.join([key+'_'+str(value) for key, value in config_tmp.items()])
	save_folder_tmp = save_folder+'/'+title+'/'

	# makedirs creates the experiment folder together with its sub-folders
	# and is a no-op for folders that already exist.
	os.makedirs(save_folder_tmp+'data', exist_ok=True)
	os.makedirs(save_folder_tmp+'plots', exist_ok=True)

	with open(save_folder_tmp+'config.json', 'w') as fp:
		# Dump a copy in which the device is replaced by its string form, so
		# config_tmp itself keeps the real device object.
		json.dump({**config_tmp, 'device': str(device)}, fp, indent=4)

	# Quantities that do not change from one run to the next.
	nb_honest = n - f
	# Maximum number of Byzantine workers accepted in a sampled batch.
	f_sampled = max(0, n_sampled//2-1)
	milestones = config_tmp['milestones']

	tab_accuracy = [[] for _ in range(config['nb_run'])]
	for run in range(config['nb_run']):
		print(f_sampled, attack)
		index_milestone = 0

		server = Server(config_tmp)
		server_weights = server.get_model_parameters()

		# The first nb_honest workers are honest, the remaining f are not.
		workers = [Worker(id = i, dataset = 'femnist', initial_weights = server_weights, honest = i < nb_honest, config = config_tmp) for i in range(n)]

		aggregator = RobustAggregator('nnm', agg, 1, f_sampled, server.model_size, config['device'])
		byz_attack = ByzantineAttack(attack, 0, server.model_size, config['device'], 200, agg)

		# Tracked separately from the grid variable `server_lr` so that the
		# decay below cannot leak into the configuration of a later experiment.
		# NOTE(review): this value is never handed to `server` in this script;
		# confirm whether Server decays its own learning rate.
		current_server_lr = config_tmp['server_lr']

		for t in range(T):
			# Learning-rate schedule: at each milestone round every worker
			# scales its learning rate by gamma, and the server rate is
			# scaled once (not once per worker).
			if index_milestone < len(milestones) and t == milestones[index_milestone]:
				for worker in workers:
					worker.update_learning_rate(gamma)
				current_server_lr = current_server_lr * gamma
				index_milestone = index_milestone + 1

			params = []
			batch_of_workers = random.sample(workers, n_sampled)
			# Number of Byzantine workers in the sampled batch.
			c = 0

			server_weights = server.get_model_parameters()

			for worker in batch_of_workers:
				if worker.is_byzantine():
					c = c+1
					continue

				# Honest workers start from the current server model and
				# contribute the vector returned by their local steps.
				worker.set_model_parameters(server_weights)
				params.append(worker.do_local_steps(nb_local_steps))

			if c > f_sampled:
				# More Byzantine workers than the aggregator was configured
				# for: the server model is overwritten with zeros.
				print('More byzantine than allowed in a subbatch of workers:', c)
				server.set_model_parameters_with_flat_tensor(torch.zeros(server.model_size).to(config_tmp['device']))
			else:
				byz_attack.nb_real_byz = c
				byz_params = byz_attack.generate_byzantine_vectors(params, None, t)

				params = params + byz_params

				fed_avr_param = aggregator.aggregate(params)

				server.set_model_gradient_with_flat_tensor(fed_avr_param)
				server.step()

			# Evaluate the updated server model on every honest worker.
			acc = []
			server_weights = server.get_model_parameters()
			for worker in workers:
				if worker.is_byzantine():
					continue
				worker.set_model_parameters(server_weights)
				acc.append(worker.evaluate())
			mean_acc = np.mean(acc)
			tab_accuracy[run].append(mean_acc)
			print('Step:', t, '- Accuracy:', mean_acc, '- Number of Byzantine in the batch:', c)

		# One CSV per run, holding the accuracy after each round.
		savetxt(save_folder_tmp+'data/Accuracies_'+str(run)+'.csv', tab_accuracy[run], delimiter=',')









































