import argparse
import math
import numpy as np
import numpy.random as random
import sys
from utils.tensor_utils import flatten_named_tensor

import utils.data_utils as data_utils
import utils.utils as utils

from fl.attacker import attack
from fl.aggregate import aggregate
from fl.client import local_update_batch

def process_cmd() -> argparse.Namespace:
    """Parse the command line and attach derived experiment settings.

    Options are read from ``sys.argv`` via ``argparse``. After parsing, the
    following attributes are added to (or overridden on) the namespace:

    * ``n_classes`` -- number of classes of the selected dataset.
    * ``n_clients`` -- forced to 3400 when the dataset is ``femnist``.
    * ``n_attackers`` -- ``floor(n_clients * at_ratio)``.
    * ``n_sample_clients`` -- ``ceil(n_clients * client_sample_ratio)``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    # data
    parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'cifar100', 'imagenet12', 'femnist'], help='dataset')
    parser.add_argument('--root', default='', type=str, help='root directory of dataset where directory the dataset exists or will be saved to')
    parser.add_argument('--n_clients', default=50, type=int, help='the number of clients, set to 3400 if the dataset is femnist')
    parser.add_argument('--partition', default='dirichlet', type=str, choices=['iid', 'dirichlet'], help='partition')
    parser.add_argument('--beta', default=0.5, type=float, help='the concentration parameter for dirichlet partition')
    parser.add_argument('--min_n_samples', default=10, type=int, help='minimum number of samples of each party for dirichlet partition')
    # attack
    parser.add_argument('--at_ratio', default=0.0, type=float)
    parser.add_argument('--at_type', default=None, type=str, choices=[None, 'bit_flip', 'ipm', 'label_flip', 'lie', 'min_max', 'min_sum'])
    parser.add_argument('--lie_z', default=1.5, type=float, help='lie attack strength')
    parser.add_argument('--ipm_evals', default=2, type=int)
    parser.add_argument('--dev_type', default='std', type=str, help='for min-max and min-sum attack', choices=['unit_vec', 'sign', 'std'])
    # server
    parser.add_argument('--arch', default='alexnet', type=str, choices=['alexnet', 'resnet18', 'squeezenet', 'femnistnet'])
    parser.add_argument('--n_rounds', default=200, type=int, help='number of communication rounds')
    parser.add_argument('--client_sample_ratio', default=0.1, type=float)
    parser.add_argument('--scheduler', default=None, type=str, choices=[None, 'step', 'exponential', 'multi_step'])
    parser.add_argument('--step_size', default=0, type=int, help='period of learning rate decay')
    parser.add_argument('--milestones', default=[], type=int, nargs='*', help='list of epoch indices')
    parser.add_argument('--gamma', default=1.0, type=float, help='multiplicative factor of learning rate decay')
    # defense
    parser.add_argument('--agg_type', default='mean', type=str, choices=['dnc', 'gmedian', 'mean', 'median', 'rbtm', 'krum', 'bulyan'])
    parser.add_argument('--dnc_filter', default=4.0, type=float, help='filtering fraction of dnc')
    parser.add_argument('--dnc_n_sample', default=10000, type=int, help='number of sampled coordinates of dnc')
    parser.add_argument('--rfa_budget', default=3, type=int)
    # gain
    parser.add_argument('--n_subvectors', default=0, type=int, help='number of sub-vectors for GAIN, 0 when GAIN is disabled')
    # local update
    parser.add_argument('--n_epochs', default=5, type=int, help='number of local epochs')
    parser.add_argument('--batch_size', default=64, type=int, help='batch size for local update')
    parser.add_argument('--local_lr', default=0.1, type=float, help='learning rate for local update')
    parser.add_argument('--momentum', default=0.5, type=float, help='momentum factor')
    parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay (L2 penalty)')
    # Gradient clipping is on by default; --no_grad_clip is the switch that
    # actually changes it (both flags write to the same `grad_clip` dest).
    parser.add_argument('--grad_clip', default=True, action='store_true', help='enable gradient clipping')
    parser.add_argument('--no_grad_clip', action='store_false', dest='grad_clip', help='disable gradient clipping')
    parser.add_argument('--max_norm', default=2.0, type=float, help='max norm of the gradients')
    # seed
    parser.add_argument('--seed', default=0, type=int)

    args = parser.parse_args()

    # properties: every value allowed by --dataset has an entry here, so the
    # lookup cannot fail once argparse has validated the choice.
    n_classes_by_dataset = {
        'cifar10': 10,
        'cifar100': 100,
        'imagenet12': 12,
        'femnist': 62,
    }
    args.n_classes = n_classes_by_dataset[args.dataset]
    if args.dataset == 'femnist':
        # femnist always uses 3400 clients, regardless of --n_clients
        args.n_clients = 3400

    # useful variables
    args.n_attackers = math.floor(args.n_clients * args.at_ratio)
    args.n_sample_clients = math.ceil(args.n_clients * args.client_sample_ratio)

    return args

def main():
    """Run the federated training experiment configured on the command line.

    Each communication round samples clients, splits them into attackers
    (client index below ``args.n_attackers``) and regular users, collects the
    users' local updates and the attackers' crafted updates, aggregates them
    with the configured rule, adds the aggregate to the global model and
    evaluates on the test set. The distance between the aggregate and the
    default aggregate of the users' updates alone is printed every round.
    """
    utils.output_current_time()

    args = process_cmd()
    print(f'Args: {args}', flush=True)

    # fix seed
    utils.setup_seed(args.seed)
    # load data
    tr_data, tr_label, te_data, te_label = data_utils.load_data(
        dataset=args.dataset, root=args.root,
        args=args,
    )
    # initialize net
    fed_model = utils.init_net(args.arch, n_classes=args.n_classes)
    fed_model.cuda()

    best_te_acc = 0
    # --n_rounds is the number of communication rounds, so run exactly that
    # many (rounds 0 .. n_rounds - 1). The previous `<=` bound ran one extra.
    for communication_round in range(args.n_rounds):
        print(f'========== communication round {communication_round:4} ==========')
        # sample client
        sampled_clients = random.choice(args.n_clients, args.n_sample_clients, replace=False)
        print(f'sampled clients: {sampled_clients.tolist()}')
        # local update: clients with index >= n_attackers are regular users
        sampled_users = sampled_clients[sampled_clients >= args.n_attackers]
        sampled_user_datas = [tr_data[i] for i in sampled_users]
        sampled_user_labels = [tr_label[i] for i in sampled_users]
        sampled_user_updates, sampled_n_user_samples = local_update_batch(
            fed_model, communication_round,
            sampled_user_datas, sampled_user_labels,
            args,
        )
        # attack: clients with index < n_attackers are attackers; the attack
        # is also handed the users' updates and sample counts
        sampled_attackers = sampled_clients[sampled_clients < args.n_attackers]
        sampled_attacker_data = [tr_data[i] for i in sampled_attackers]
        sampled_attacker_label = [tr_label[i] for i in sampled_attackers]
        sampled_at_updates, sampled_n_attacker_samples = attack(
            sampled_attacker_data, sampled_attacker_label,
            fed_model, communication_round,
            sampled_user_updates, sampled_n_user_samples,
            args,
        )
        # aggregate: user updates come first, attacker updates last
        sampled_updates_reordered = sampled_user_updates + sampled_at_updates
        sampled_n_samples_reordered = sampled_n_user_samples + sampled_n_attacker_samples
        n_sampled_attackers = len(sampled_attackers)
        agg_update, _ = aggregate(
            client_updates=sampled_updates_reordered, n_samples=sampled_n_samples_reordered, n_attackers=n_sampled_attackers,
            agg_type=args.agg_type, n_subvectors=args.n_subvectors,
            budget=args.rfa_budget,
            filtering_fraction=args.dnc_filter, n_sampled_coordinates=args.dnc_n_sample,
        )
        # reference point: aggregate of the users' updates only, using
        # aggregate()'s default settings
        # NOTE(review): if a round samples no regular users this receives an
        # empty list -- confirm aggregate() handles that.
        optimal_update, _ = aggregate(client_updates=sampled_user_updates)
        flat_optimal_update, _ = flatten_named_tensor(optimal_update)
        flat_agg_update, _ = flatten_named_tensor(agg_update)
        # L2 distance between the two flattened updates, and L2 norm of the reference
        deviation = (flat_agg_update - flat_optimal_update).square().sum().sqrt().item()
        optimal_norm = flat_optimal_update.square().sum().sqrt().item()
        print(f'gradient deviation: {deviation}')
        print(f'optimal gradient norm: {optimal_norm}')
        # step: add the aggregated update to every entry of the model state
        model_state = fed_model.state_dict()
        for name in model_state:
            model_state[name] += agg_update[name]
        fed_model.load_state_dict(model_state)

        # test
        te_acc, te_loss = utils.inference(fed_model, te_data, te_label)
        best_te_acc = max(best_te_acc, te_acc)

        # output (flushed once per round so logs appear as training progresses)
        print(
            f'test loss: {te_loss:8.4f} test accuracy: {te_acc:6.2f} best test accuracy: {best_te_acc:6.2f}',
            flush=True,
        )

    print('========== end of training ==========')


if __name__ == '__main__':
    main()