import os
import argparse
import yaml
import numpy as np
import matplotlib.pyplot as plt

import torch

from utils import set_seed, setup_files, plot_results, save_results_in_to_files

from src.clients.client_scheme1_baseline import Client
from src.server.server_scheme1_baseline import Server

from src.datasets.distribute_fmnist import distributed_fmnist
from src.datasets.distribute_mnist import distributed_mnist
from src.datasets.distribute_adult_income import distributed_adult_income
from src.datasets.dataset_utils import read_client_data



def create_clients(num_clients, local_epochs, dataset_name, epsilons_distribution, delta, num_rounds, lr, bs, rate, cprint_rounds, device):
    ''' Load each client's train/test split and wrap it in a ``Client`` object.

    Client ``i`` receives ``epsilons_distribution[i]`` as its privacy budget.
    Pairing uses ``zip``, so only ``min(num_clients, len(epsilons_distribution))``
    clients are created. All remaining arguments are forwarded unchanged to
    every ``Client`` constructor.

    Returns:
        list of ``Client`` instances ordered by client id.
    '''
    def _make_client(client_id, client_epsilon):
        # The train split is read before the test split, matching the per-client I/O order.
        train_split = read_client_data(dataset_name, client_id, is_train=True)
        test_split = read_client_data(dataset_name, client_id, is_train=False)
        return Client(client_id, train_split, test_split, local_epochs, client_epsilon, delta,
                      num_rounds, num_clients, lr, bs, rate, dataset_name, cprint_rounds, device)

    return [_make_client(cid, eps)
            for cid, eps in zip(range(num_clients), epsilons_distribution)]



def _build_epsilons_distribution(num_clients, epsilons):
    ''' Return a per-client list of privacy budgets split roughly 34% / 43% / rest.

    The first ``round(0.34 * num_clients)`` entries are ``epsilons[0]``, the next
    ``round(0.43 * num_clients)`` are ``epsilons[1]``, and the remainder are
    ``epsilons[2]``. ``np.round`` rounds halves to even. The total length is
    always ``num_clients``.
    '''
    num_clients_eps1 = int(np.round(0.34 * num_clients))
    num_clients_eps2 = int(np.round(0.43 * num_clients))
    num_clients_eps3 = num_clients - (num_clients_eps1 + num_clients_eps2)
    return ([epsilons[0]] * num_clients_eps1 +
            [epsilons[1]] * num_clients_eps2 +
            [epsilons[2]] * num_clients_eps3)


def _distribute_dataset(config, dataset_name, num_clients, flag_is_niid):
    ''' Partition ``dataset_name`` for ``num_clients`` clients via its ``distributed_*`` helper.

    Reads ``balance``, ``partition`` and ``data_points_per_client`` from
    ``config``, but only for a recognised dataset. An unrecognised
    ``dataset_name`` is skipped without error, as in the original ``if``/``elif`` chain.
    NOTE(review): presumably this relies on pre-existing partitions for
    ``read_client_data``. Confirm before turning it into an error.
    '''
    distributors = {
        'fmnist': distributed_fmnist,
        'mnist': distributed_mnist,
        'adult_income': distributed_adult_income,
    }
    distribute = distributors.get(dataset_name)
    if distribute is None:
        return
    distribute(num_clients,
               flag_is_niid,
               config['balance'],
               config['partition'],
               config['data_points_per_client'])


def _to_cpu_numpy(values):
    ''' Convert ``values`` to a NumPy array by way of a CPU ``torch`` tensor.

    Used for the norm and train-loss series returned by the server.
    NOTE(review): presumably these may contain torch values; confirm the exact type.
    '''
    return np.array(torch.tensor(values).to(torch.device("cpu")))


def main(config_path):
    ''' Facilitate the federated learning process by creating clients, initializing server,
    performing federated learning and evaluating global test accuracy of the server model.

    Reads a YAML config from ``config_path``, partitions the chosen dataset, then
    performs ``config['runs']`` independent federated-learning runs. Each run's
    results are saved with ``save_results_in_to_files``. Every metric series is
    averaged across runs and handed to ``plot_results``.

    Raises:
        ValueError: if ``config['runs']`` is less than 1. Without runs there is
            nothing to average, and the original failed later with an unbound-variable error.
    '''
    set_seed()
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    flag_is_niid = config['niid']
    num_clients = config['num_clients']
    num_rounds = config['rounds']
    num_epochs = config['epochs']
    sample_rate = config['sample_rate']
    clip_value = config['clip_value']
    sigma = config['sigma']
    epsilons = [config['epsilon1'], config['epsilon2'], config['epsilon3']]
    dataset_name = config['dataset']
    runs = config['runs']
    # Fail fast, before any data is distributed, if there would be nothing to average.
    if runs < 1:
        raise ValueError(f"config 'runs' must be >= 1, got {runs!r}")

    epsilons_distribution = _build_epsilons_distribution(num_clients, epsilons)

    _distribute_dataset(config, dataset_name, num_clients, flag_is_niid)

    # Running per-metric average: each run contributes values / runs.
    averages = {}
    for run in range(runs):
        # The shuffle is in place, so each run permutes the previous run's order.
        np.random.shuffle(epsilons_distribution)
        clients = create_clients(num_clients,
                                 num_epochs,
                                 dataset_name,
                                 epsilons_distribution,
                                 config['delta'],
                                 num_rounds,
                                 config['lr'],
                                 config['bs'],
                                 config['rate'],
                                 config['cprint_rounds'],
                                 config['device'])

        server = Server(clients,
                        num_clients,
                        num_rounds,
                        sample_rate,
                        clip_value,
                        config['delta'],
                        config['lr'],
                        config['bs'],
                        config['rate'],
                        run,
                        dataset_name,
                        config['sprint_rounds'],
                        config['device'])

        (plot_clients_test_acc_local_model_local_data,
         plot_clients_test_acc_global_model_local_data,
         plot_clients_test_acc_local_model_global_data,
         plot_clients_norm,
         plot_clients_train_loss,
         plot_test_acc,
         plot_local_clip_values,
         plot_global_clip_values,
         plot_clients_privacy_spent,
         plot_clients_per_round_privacy_spent,
         plot_clients_per_round_noise_multiplier,
         plot_clients_per_round_rdp_privacy_spent,
         plot_clients_per_round_best_alpha) \
            = server.federated_learning()

        save_results_in_to_files(plot_clients_test_acc_local_model_local_data,
                                 plot_clients_test_acc_global_model_local_data,
                                 plot_clients_test_acc_local_model_global_data,
                                 plot_test_acc,
                                 plot_clients_norm,
                                 plot_global_clip_values,
                                 plot_local_clip_values,
                                 plot_clients_train_loss,
                                 plot_clients_privacy_spent,
                                 plot_clients_per_round_privacy_spent,
                                 plot_clients_per_round_noise_multiplier,
                                 plot_clients_per_round_rdp_privacy_spent,
                                 plot_clients_per_round_best_alpha,
                                 num_clients,
                                 clip_value,
                                 epsilons,
                                 sigma,
                                 sample_rate,
                                 '1',
                                 dataset_name,
                                 run,
                                 num_rounds,
                                 num_epochs)

        run_arrays = {
            'clients_test_acc_local_model_local_data': np.array(plot_clients_test_acc_local_model_local_data),
            'clients_test_acc_global_model_local_data': np.array(plot_clients_test_acc_global_model_local_data),
            'clients_test_acc_local_model_global_data': np.array(plot_clients_test_acc_local_model_global_data),
            'clients_norm': _to_cpu_numpy(plot_clients_norm),
            'clients_train_loss': _to_cpu_numpy(plot_clients_train_loss),
            'test_acc': np.array(plot_test_acc),
            'local_clip_values': np.array(plot_local_clip_values),
            'global_clip_values': np.array(plot_global_clip_values),
            'clients_privacy_spent': np.array(plot_clients_privacy_spent),
            'clients_per_round_privacy_spent': np.array(plot_clients_per_round_privacy_spent),
            'clients_per_round_noise_multiplier': np.array(plot_clients_per_round_noise_multiplier),
            'clients_per_round_rdp_privacy_spent': np.array(plot_clients_per_round_rdp_privacy_spent),
            'clients_per_round_best_alpha': np.array(plot_clients_per_round_best_alpha),
        }
        for key, values in run_arrays.items():
            scaled = values / runs
            if run == 0:
                averages[key] = scaled
            else:
                averages[key] += scaled

    (file_clients_test_acc_local_model_local_data,
     file_clients_test_acc_global_model_local_data,
     file_clients_test_acc_local_model_global_data,
     file_clients_norm,
     file_clients_train_loss,
     file_test_acc,
     file_clients_per_round_privacy_spent,
     file_clients_per_round_noise_multiplier,
     file_clients_per_round_rdp_privacy_spent,
     file_clients_per_round_best_alpha,
     file_clients_privacy_spent) = setup_files(num_clients,
                                               runs,
                                               num_rounds,
                                               num_epochs,
                                               sigma,
                                               clip_value,
                                               epsilons,
                                               sample_rate,
                                               '1',
                                               flag_is_niid,
                                               dataset_name)

    plot_results(averages['clients_test_acc_local_model_local_data'],
                 averages['clients_test_acc_global_model_local_data'],
                 averages['clients_test_acc_local_model_global_data'],
                 averages['test_acc'],
                 averages['clients_norm'],
                 averages['global_clip_values'],
                 averages['local_clip_values'],
                 averages['clients_train_loss'],
                 averages['clients_privacy_spent'],
                 averages['clients_per_round_privacy_spent'],
                 averages['clients_per_round_noise_multiplier'],
                 averages['clients_per_round_rdp_privacy_spent'],
                 averages['clients_per_round_best_alpha'],
                 file_clients_test_acc_local_model_local_data,
                 file_clients_test_acc_global_model_local_data,
                 file_clients_test_acc_local_model_global_data,
                 file_clients_norm,
                 file_clients_train_loss,
                 file_test_acc,
                 file_clients_privacy_spent,
                 file_clients_per_round_privacy_spent,
                 file_clients_per_round_noise_multiplier,
                 file_clients_per_round_rdp_privacy_spent,
                 file_clients_per_round_best_alpha,
                 num_clients,
                 clip_value,
                 epsilons,
                 sigma,
                 sample_rate,
                 '1',
                 dataset_name)

if __name__ == "__main__":
    # Command-line entry point: the only option is the path to the YAML config.
    cli = argparse.ArgumentParser(description='Federated Learning Main')
    cli.add_argument('--config', type=str, default='config.yaml', help='Path to the configuration file')
    main(cli.parse_args().config)

