import argparse
import yaml
import numpy as np
import pandas as pd

import torch

from utils import set_seed, setup_files, plot_results, save_results_in_to_files

from src.clients.client_scheme2_time_adaptive_dpfl import Client
from src.server.server_scheme2_time_adaptive_dpfl import Server

from src.datasets.distribute_fmnist import distributed_fmnist
from src.datasets.distribute_mnist import distributed_mnist
from src.datasets.distribute_adult_income import distributed_adult_income
from src.datasets.dataset_utils import read_client_data



def create_clients(num_clients, local_epochs, dataset_name,
                   epsilons_distribution, delta, num_rounds, lr, bs, rate,
                   cprint_rounds, privacy_states, saving_sample_rates_distribution, device):
    """Instantiate one ``Client`` per client id in ``range(num_clients)``.

    Client ``i`` receives ``epsilons_distribution[i]``, ``privacy_states[i]`` and
    ``saving_sample_rates_distribution[i]``, plus its own train/test split loaded
    through ``read_client_data``. Every other argument is shared by all clients.

    Returns:
        list: the constructed clients, ordered by client id.
    """
    def build_client(client_id):
        # Per-client settings are looked up before any data is read.
        client_epsilon = epsilons_distribution[client_id]
        client_schedule = privacy_states[client_id]
        client_saving_rate = saving_sample_rates_distribution[client_id]
        train_split = read_client_data(dataset_name, client_id, is_train=True)
        test_split = read_client_data(dataset_name, client_id, is_train=False)
        return Client(client_id, train_split, test_split, local_epochs, client_epsilon, delta,
                      num_rounds, num_clients, lr, bs, rate, dataset_name, cprint_rounds,
                      client_schedule, client_saving_rate, device)

    return [build_client(client_id) for client_id in range(num_clients)]

def _split_save_spend_rounds(privacy_spending, rounds):
    """Return ``(spend_rounds, save_rounds)`` for a spending fraction over ``rounds`` rounds.

    ``spend_rounds`` is ``np.round(privacy_spending * rounds)``. ``np.round`` rounds
    halves to the nearest even integer, e.g. 0.5 over 5 rounds gives 2 spend rounds.

    Raises:
        ValueError: if ``rounds`` is negative or ``privacy_spending`` is not in ``[0, 1]``.
    """
    if rounds < 0:
        raise ValueError(f"rounds must be non-negative, got {rounds}")
    # The negated chained comparison also rejects NaN.
    if not 0 <= privacy_spending <= 1:
        raise ValueError(
            f"privacy_spending must be a fraction in [0, 1], got {privacy_spending}")
    spend_rounds = int(np.round(privacy_spending * rounds))
    return spend_rounds, rounds - spend_rounds


def set_privacy_state_random_save(privacy_spending, rounds):
    """Build a per-round schedule whose 'save' rounds are chosen uniformly at random.

    Args:
        privacy_spending: fraction of rounds (in ``[0, 1]``) labelled ``'spend'``.
        rounds: total number of rounds.

    Returns:
        list[str]: ``rounds`` entries, each ``'spend'`` or ``'save'``.

    The save rounds are drawn from NumPy's global RNG (``np.random.choice``). The
    result is therefore reproducible only if that RNG is seeded beforehand. The
    chosen indices are printed.

    Raises:
        ValueError: if ``privacy_spending`` is outside ``[0, 1]`` or ``rounds`` < 0.
    """
    _, save_rounds = _split_save_spend_rounds(privacy_spending, rounds)
    # Same RNG call as before, so seeded experiments draw identical indices.
    save_rounds_indices = np.random.choice(np.arange(rounds), size=save_rounds, replace=False)
    # A plain list avoids the float64 dtype that np.array([]) gets when rounds == 0.
    privacy_state = ['spend'] * rounds
    for idx in save_rounds_indices:
        privacy_state[idx] = 'save'
    print(f"save round indices: {save_rounds_indices}")
    return privacy_state


def set_privacy_state_first_save(privacy_spending, rounds):
    """Build a per-round schedule that saves in the earliest rounds and spends afterwards.

    Args:
        privacy_spending: fraction of rounds (in ``[0, 1]``) labelled ``'spend'``.
        rounds: total number of rounds.

    Returns:
        list[str]: the ``'save'`` rounds first, followed by the ``'spend'`` rounds.
        The save-round indices are printed.

    Raises:
        ValueError: if ``privacy_spending`` is outside ``[0, 1]`` or ``rounds`` < 0.
            Previously, values above 1 were silently treated as "always spend".
    """
    _, save_rounds = _split_save_spend_rounds(privacy_spending, rounds)
    save_rounds_indices = np.arange(0, save_rounds)
    privacy_state = ['save'] * save_rounds + ['spend'] * (rounds - save_rounds)
    print(f"save round indices: {save_rounds_indices}")
    return privacy_state


def _distribute_dataset(dataset_name, num_clients, flag_is_niid, config):
    """Partition ``dataset_name`` across ``num_clients`` clients using the config's split settings.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported datasets.
    """
    distributors = {
        'fmnist': distributed_fmnist,
        'mnist': distributed_mnist,
        'adult_income': distributed_adult_income,
    }
    if dataset_name not in distributors:
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; expected one of {sorted(distributors)}")
    distributors[dataset_name](num_clients,
                               flag_is_niid,
                               config['balance'],
                               config['partition'],
                               config['data_points_per_client'])


def _expand_by_group(group_values, group_sizes):
    """Repeat ``group_values[i]`` ``group_sizes[i]`` times, concatenating the groups in order.

    The same object is repeated rather than copied, matching ``[value] * size``.
    """
    return [value for value, size in zip(group_values, group_sizes) for _ in range(size)]


def main(config_path):
    """Run ``config['runs']`` federated-learning experiments and plot their average.

    The function:
    1. Reads the YAML config at ``config_path``.
    2. Partitions the chosen dataset across clients.
    3. Splits the clients into three groups. Each group has its own epsilon,
       sample rate and save/spend privacy schedule.
    4. For each run, reshuffles which client gets which group, trains via
       ``Server.federated_learning``, and passes the run's results to
       ``save_results_in_to_files``.
    5. Passes the per-run mean of every returned metric to ``plot_results``.

    Raises:
        ValueError: if ``when_save`` is not ``'first'`` or ``'random'``, the dataset
            is unsupported, or ``runs`` is smaller than 1.
    """
    set_seed()
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    flag_is_niid = config['niid']
    num_clients = config['num_clients']
    num_rounds = config['rounds']
    num_epochs = config['epochs']
    sample_rate = config['sample_rate']
    clip_value = config['clip_value']
    sigma = config['sigma']
    epsilons = [config['epsilon1'], config['epsilon2'], config['epsilon3']]
    dataset_name = config['dataset']
    when_save = config['when_save']
    spending_rates = [config['privacy_spending1'],
                      config['privacy_spending2'],
                      config['privacy_spending3']]
    runs = config['runs']

    # Fail fast, before the potentially slow dataset partitioning. Previously a bad
    # when_save surfaced later as an UnboundLocalError, and runs < 1 crashed at plotting.
    if when_save not in ('first', 'random'):
        raise ValueError(f"when_save must be 'first' or 'random', got {when_save!r}")
    if runs < 1:
        raise ValueError(f"runs must be at least 1, got {runs}")

    _distribute_dataset(dataset_name, num_clients, flag_is_niid, config)

    # Three client groups: ~34%, ~43%, and the remainder.
    group_sizes = [int(np.round(0.34 * num_clients)), int(np.round(0.43 * num_clients))]
    group_sizes.append(num_clients - sum(group_sizes))

    epsilons_distribution = _expand_by_group(epsilons, group_sizes)

    sample_rates = [config['srate1'], config['srate2'], config['srate3']]
    title = "4_{}_{:.2f}_{:.1f}_{:.1f}_{:.1f}".format(when_save, spending_rates[0], *sample_rates)
    sample_rates_distribution = _expand_by_group(sample_rates, group_sizes)

    # One schedule per group, generated in group order. The 'random' strategy
    # therefore consumes the global RNG in the same sequence as before.
    make_schedule = (set_privacy_state_first_save if when_save == 'first'
                     else set_privacy_state_random_save)
    group_schedules = [make_schedule(rate, num_rounds) for rate in spending_rates]
    # NOTE(review): clients in one group share the same schedule list object. This
    # only matters if Client/Server mutate it in place — confirm.
    privacy_states = _expand_by_group(group_schedules, group_sizes)

    indices = np.arange(len(sample_rates_distribution))
    averages = {}
    for run in range(runs):
        # Reassign groups to clients each run. The same permutation is applied to all
        # three lists so they stay aligned. It is applied to the previous run's order,
        # so successive permutations compound.
        np.random.shuffle(indices)
        sample_rates_distribution = [sample_rates_distribution[idx] for idx in indices]
        epsilons_distribution = [epsilons_distribution[idx] for idx in indices]
        privacy_states = [privacy_states[idx] for idx in indices]

        clients = create_clients(num_clients,
                                 num_epochs,
                                 dataset_name,
                                 epsilons_distribution,
                                 config['delta'],
                                 num_rounds,
                                 config['lr'],
                                 config['bs'],
                                 config['rate'],
                                 config['cprint_rounds'],
                                 privacy_states,
                                 sample_rates_distribution,
                                 config['device'])

        server = Server(clients,
                        num_clients,
                        num_rounds,
                        sample_rate,
                        clip_value,
                        config['delta'],
                        config['lr'],
                        config['bs'],
                        config['rate'],
                        run,
                        dataset_name,
                        config['sprint_rounds'],
                        privacy_states,
                        sample_rates_distribution,
                        config['device'])

        (plot_clients_test_acc_local_model_local_data,
         plot_clients_test_acc_global_model_local_data,
         plot_clients_test_acc_local_model_global_data,
         plot_clients_norm,
         plot_clients_train_loss,
         plot_test_acc,
         plot_local_clip_values,
         plot_global_clip_values,
         plot_clients_privacy_spent,
         plot_clients_per_round_privacy_spent,
         plot_clients_per_round_noise_multiplier,
         plot_clients_per_round_rdp_privacy_spent,
         plot_clients_per_round_best_alpha) = server.federated_learning()

        save_results_in_to_files(plot_clients_test_acc_local_model_local_data,
                                 plot_clients_test_acc_global_model_local_data,
                                 plot_clients_test_acc_local_model_global_data,
                                 plot_test_acc,
                                 plot_clients_norm,
                                 plot_global_clip_values,
                                 plot_local_clip_values,
                                 plot_clients_train_loss,
                                 plot_clients_privacy_spent,
                                 plot_clients_per_round_privacy_spent,
                                 plot_clients_per_round_noise_multiplier,
                                 plot_clients_per_round_rdp_privacy_spent,
                                 plot_clients_per_round_best_alpha,
                                 num_clients,
                                 clip_value,
                                 epsilons,
                                 sigma,
                                 sample_rate,
                                 title,
                                 dataset_name,
                                 run,
                                 num_rounds,
                                 num_epochs)

        # Running mean over runs: each run contributes value / runs.
        # Norms and losses go through torch so that device tensors are moved to the CPU first.
        contributions = {
            'clients_test_acc_local_model_local_data':
                np.array(plot_clients_test_acc_local_model_local_data) / runs,
            'clients_test_acc_global_model_local_data':
                np.array(plot_clients_test_acc_global_model_local_data) / runs,
            'clients_test_acc_local_model_global_data':
                np.array(plot_clients_test_acc_local_model_global_data) / runs,
            'clients_norm':
                np.array(torch.tensor(plot_clients_norm).to(torch.device("cpu"))) / runs,
            'clients_train_loss':
                np.array(torch.tensor(plot_clients_train_loss).to(torch.device("cpu"))) / runs,
            'test_acc': np.array(plot_test_acc) / runs,
            'local_clip_values': np.array(plot_local_clip_values) / runs,
            'global_clip_values': np.array(plot_global_clip_values) / runs,
            'clients_privacy_spent': np.array(plot_clients_privacy_spent) / runs,
            'clients_per_round_privacy_spent':
                np.array(plot_clients_per_round_privacy_spent) / runs,
            'clients_per_round_noise_multiplier':
                np.array(plot_clients_per_round_noise_multiplier) / runs,
            'clients_per_round_rdp_privacy_spent':
                np.array(plot_clients_per_round_rdp_privacy_spent) / runs,
            'clients_per_round_best_alpha':
                np.array(plot_clients_per_round_best_alpha) / runs,
        }
        for key, contribution in contributions.items():
            if key in averages:
                averages[key] += contribution
            else:
                averages[key] = contribution

    (file_clients_test_acc_local_model_local_data,
     file_clients_test_acc_global_model_local_data,
     file_clients_test_acc_local_model_global_data,
     file_clients_norm,
     file_clients_train_loss,
     file_test_acc,
     file_clients_per_round_privacy_spent,
     file_clients_per_round_noise_multiplier,
     file_clients_per_round_rdp_privacy_spent,
     file_clients_per_round_best_alpha,
     file_clients_privacy_spent) = setup_files(num_clients,
                                               runs,
                                               num_rounds,
                                               num_epochs,
                                               sigma,
                                               clip_value,
                                               epsilons,
                                               sample_rate,
                                               title,
                                               flag_is_niid,
                                               dataset_name)

    plot_results(averages['clients_test_acc_local_model_local_data'],
                 averages['clients_test_acc_global_model_local_data'],
                 averages['clients_test_acc_local_model_global_data'],
                 averages['test_acc'],
                 averages['clients_norm'],
                 averages['global_clip_values'],
                 averages['local_clip_values'],
                 averages['clients_train_loss'],
                 averages['clients_privacy_spent'],
                 averages['clients_per_round_privacy_spent'],
                 averages['clients_per_round_noise_multiplier'],
                 averages['clients_per_round_rdp_privacy_spent'],
                 averages['clients_per_round_best_alpha'],
                 file_clients_test_acc_local_model_local_data,
                 file_clients_test_acc_global_model_local_data,
                 file_clients_test_acc_local_model_global_data,
                 file_clients_norm,
                 file_clients_train_loss,
                 file_test_acc,
                 file_clients_privacy_spent,
                 file_clients_per_round_privacy_spent,
                 file_clients_per_round_noise_multiplier,
                 file_clients_per_round_rdp_privacy_spent,
                 file_clients_per_round_best_alpha,
                 num_clients,
                 clip_value,
                 epsilons,
                 sigma,
                 sample_rate,
                 title,
                 dataset_name)

if __name__ == "__main__":
    # Command-line entry point: the only option is the path to the YAML config.
    cli = argparse.ArgumentParser(description='Federated Learning Main')
    cli.add_argument('--config',
                     type=str,
                     default='config.yaml',
                     help='Path to the configuration file')
    main(cli.parse_args().config)
