import torch
import torchvision
import numpy as np
from worker import Worker
from server import Server
import random
from numpy import savetxt
import os

# `sys` was used here without ever being imported, which raised NameError
# before any experiment could start; import it where it is needed.
import sys

# Extend the module search path with the sibling ByzLibrary directory
# (relative to the current working directory) so that the imports that
# follow -- presumably provided by that library -- can be resolved.
sys.path.append('../ByzLibrary')

from robust_aggregators import RobustAggregator
from byz_attacks import ByzantineAttack

# A single seed drives every random number generator used by the script
# (PyTorch, the stdlib `random` module and NumPy, in that order).
random_seed = 1

for seed_rng in (torch.manual_seed, random.seed, np.random.seed):
	seed_rng(random_seed)

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Experiment grid. List-valued entries are iterated over by the main loop
# (one experiment per combination); scalar entries are shared by all of them.
config = {
	# Number of server steps, total number of workers, number of Byzantine workers.
	'T': [500],
	'n': [150],
	'f': [15],
	# Learning rates handed to the server and to the workers.
	'server_lr': [1],
	'workers_lr': [0.1],
	# Steps at which every worker's learning rate is updated with `gamma`.
	'milestones': [300, 400, 460],
	'gamma': [0.2],
	# Byzantine attacks and aggregation rules to evaluate.
	'attack': ['SF', 'FOE', 'ALIE', 'mimic'],
	'agg': ['trmean'],
	'batch_size': [8],
	'nb_local_steps': [10],
	'device': device,
	# Number of repeated runs per configuration.
	'nb_run': 10,
	'seed': random_seed,
	# Enters the bounds below as log(.../(1 - p)); presumably the target
	# success probability -- confirm against the accompanying paper.
	'p': 0.99,
}

# Derive how many workers to sample per step. Only the FIRST entries of the
# 'f', 'n' and 'T' lists are used here.
beta = config['f'][0]/config['n'][0]

# Same form as the KL divergence between Bernoulli(1/2) and Bernoulli(beta).
half = 0.5
C_beta = (half)*np.log(half/beta) + (1 - half)*np.log((1-half)/(1-beta))

# ceil(log(4T/(1-p)) / C_beta) + 2, capped by the total number of workers.
failure_margin = 1-config['p']
needed_workers = int(np.ceil(np.log(4*config['T'][0]/failure_margin)/C_beta)+2)
n_ = np.min([config['n'][0], needed_workers])
config['n_sampled'] = [n_]

# Root directory under which every experiment writes its results.
save_folder = './save_fig3'

# exist_ok=True makes this a no-op when the directory is already there and
# avoids the check-then-create race of isdir() followed by mkdir() (e.g. when
# several experiment scripts are launched at the same time).
os.makedirs(save_folder, exist_ok=True)

import itertools


def _sampled_byzantine_bound(n, f, n_sampled, T, p):
	"""Return the Byzantine count assumed among ``n_sampled`` sampled workers.

	With ``beta = f/n``, scans a grid of 99 candidates ``beta_`` evenly spaced
	in ``[beta, 0.5)`` and keeps the smallest one satisfying

		beta_*log(beta_/beta) + (1-beta_)*log((1-beta_)/(1-beta)) >= log(T/(1-p))/n_sampled

	If no candidate qualifies the fraction defaults to 0.5. The result is
	``int(best_beta * n_sampled)``.
	"""
	beta = f/n
	# The right-hand side does not depend on the candidate: compute it once.
	threshold = np.log(T/(1-p))/n_sampled
	best_beta = 0.5
	for beta_ in np.linspace(beta, 0.5, 100)[:-1]:
		C_beta = (beta_)*np.log(beta_/beta) + (1 - beta_)*np.log((1-beta_)/(1-beta))
		if C_beta >= threshold and best_beta > beta_:
			best_beta = beta_
			# Once a candidate is accepted no later one can satisfy
			# `best_beta > beta_`, so the scan can stop here.
			break
	return int(best_beta*n_sampled)


# One experiment per combination of the list-valued config entries.
# itertools.product iterates with the right-most entry varying fastest,
# i.e. in the same order as the equivalent nested for-loops.
_experiment_grid = itertools.product(
	config['T'], config['n'], config['f'], config['n_sampled'],
	config['server_lr'], config['workers_lr'], config['gamma'],
	config['attack'], config['agg'], config['batch_size'],
	config['nb_local_steps'],
)

for (T, n, f, n_sampled, server_lr, workers_lr, gamma, attack, agg,
		batch_size, nb_local_steps) in _experiment_grid:

	# Scalar configuration for this experiment. It is handed to Server and
	# Worker, and its key order also determines the output directory name.
	config_tmp = {
		'T' : T,
		'n' : n,
		'f': f,
		'n_sampled': n_sampled,
		'server_lr' : server_lr,
		'workers_lr' : workers_lr,
		'milestones' : config['milestones'],
		'gamma' : gamma,
		'attack' : attack,
		'agg' : agg,
		'batch_size' : batch_size,
		'nb_local_steps' : nb_local_steps,
		'device' : device,
		'nb_run' : config['nb_run'],
		'seed' : random_seed
	}

	title = '_'.join([key+'_'+str(value) for key, value in config_tmp.items()])

	# <save_folder>/<title>/{data,plots}; makedirs also creates the missing
	# parents and does not fail when the directories already exist.
	save_folder_tmp = save_folder+'/'+title+'/'
	os.makedirs(os.path.join(save_folder_tmp, 'data'), exist_ok=True)
	os.makedirs(os.path.join(save_folder_tmp, 'plots'), exist_ok=True)

	# These only depend on the experiment configuration, not on the run,
	# so they are computed once per configuration.
	nb_honest = n - f
	f_sampled = _sampled_byzantine_bound(n, f, n_sampled, T, config['p'])
	milestones = config['milestones']

	tab_accuracy = [[] for _ in range(config['nb_run'])]
	for run in range(config['nb_run']):
		index_milestone = 0

		server = Server(config_tmp)
		server_weights = server.get_model_parameters()

		# Workers with id < nb_honest are created honest, the last f are not.
		workers = [
			Worker(id = i, dataset = 'femnist', initial_weights = server_weights, honest = i < nb_honest, config = config_tmp)
			for i in range(n)
		]

		# NOTE(review): the meaning of the literal arguments ('nnm', 1, 0, 200)
		# is defined by ByzLibrary and is not visible from this script.
		aggregator = RobustAggregator('nnm', agg, 1, f_sampled, server.model_size, config['device'])
		byz_attack = ByzantineAttack(attack, 0, server.model_size, config['device'], 200, aggregator)

		for t in range(T):
			# At each milestone step every worker updates its learning rate
			# with gamma. The index stops advancing at the last milestone.
			# An empty milestone list simply disables the schedule.
			if milestones and t == milestones[index_milestone]:
				for worker in workers:
					worker.update_learning_rate(gamma)
				if index_milestone < len(milestones)-1:
					index_milestone = index_milestone + 1

			# Sample the workers taking part in this step.
			batch_of_workers = random.sample(workers, n_sampled)
			server_weights = server.get_model_parameters()

			# Honest sampled workers start from the server model and do local
			# steps; Byzantine ones are only counted here.
			params = []
			c = 0
			for worker in batch_of_workers:
				if worker.is_byzantine():
					c = c+1
					continue

				worker.set_model_parameters(server_weights)
				params.append(worker.do_local_steps(nb_local_steps))

			if c > f_sampled:
				# More Byzantine workers were sampled than the aggregator was
				# configured for: the server parameters are overwritten with
				# zeros instead of aggregating.
				print('More byzantine than allowed in a subbatch of workers:', c)
				server.set_model_parameters_with_flat_tensor(torch.zeros(server.model_size).to(config_tmp['device']))
			else:
				# Generate c Byzantine vectors from the honest ones, aggregate
				# all of them and apply the result as the server gradient.
				byz_attack.nb_real_byz = c
				byz_params = byz_attack.generate_byzantine_vectors(params, None, t)

				params = params + byz_params

				fed_avr_param = aggregator.aggregate(params)

				server.set_model_gradient_with_flat_tensor(fed_avr_param)
				server.step()

			# Evaluate the updated server model on every honest worker and
			# record the mean accuracy for this step.
			acc = []
			server_weights = server.get_model_parameters()
			for worker in workers:
				if worker.is_byzantine():
					continue
				worker.set_model_parameters(server_weights)
				acc.append(worker.evaluate())
			mean_accuracy = np.mean(acc)
			tab_accuracy[run].append(mean_accuracy)
			print('Step:', t, '- Accuracy:', mean_accuracy, '- Number of Byzantine in the batch:', c)

		# One CSV of per-step mean accuracies per run.
		savetxt(save_folder_tmp+'data/Accuracies_'+str(run)+'.csv', tab_accuracy[run], delimiter=',')








































