from xml.sax import default_parser_list

import numpy as np
import argparse
import random
import torch
import os.path
import importlib
import os
import utils.fmodule
import ujson
import time
import collections
import utils.system_simulator as ss
import logging
import copy
# Client-sampling strategy names (not referenced by the code in this section).
sample_list = [
    'uniform', 'md', 'full',
    'uniform_available', 'md_available', 'full_available',
]
# Allowed values for the --aggregate option in read_option().
agg_list = ['uniform', 'weighted_scale', 'weighted_com']
# Allowed values for the --optimizer option in read_option().
optimizer_list = ['SGD', 'Adam', 'RMSprop', 'Adagrad']
# Module-level logger placeholder; not assigned anywhere in the code shown here.
logger = None

def read_option(args=None):
    """Parse the command-line options for an FL / unlearning run.

    Args:
        args: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original
            command-line behaviour; passing a list is handy for tests and
            notebooks.

    Returns:
        dict mapping each option name (e.g. ``'algorithm'``, ``'C'``,
        ``'num_rounds'``, ``'gpu'``) to its parsed value.
    """
    parser = argparse.ArgumentParser()
    # basic settings
    parser.add_argument('--algorithm', help='name of algorithm;', type=str, default='fedavg')
    parser.add_argument('--model', help='name of model;', type=str, default='cnn')
    parser.add_argument('--C', help='number of clients', type=int, default=100)
    parser.add_argument('--P', help='proportion of clients per round', type=float, default=0.1)
    parser.add_argument('--N', help='number of cuts that one client has', type=int, default=2)
    # No default: initialize() builds the module name 'DataLoader_' + this value.
    parser.add_argument('--dataloader', help='dataloader name', type=str)
    parser.add_argument('--dir_a', help='alpha to guide dirichlet distribution', type=float, default=0.1)

    # methods of server side for sampling and aggregating
    parser.add_argument('--aggregate', help='methods for aggregating models', type=str, choices=agg_list, default='uniform')
    # hyper-parameters of training in server side
    parser.add_argument('--num_rounds', help='number of communication rounds', type=int, default=20)
    parser.add_argument('--lr_decay', help='learning rate decay for the training process;', type=float, default=0.98)
    parser.add_argument('--lr_scheduler', help='type of the global learning rate scheduler', type=int, default=0)
    parser.add_argument('--early_stop', help='stop training if there is no improvement for no smaller than the maximum rounds', type=int, default=-1)
    parser.add_argument('--global_momentum', help='decide to use global_momentum or not', action='store_true')
    # hyper-parameters of local training
    parser.add_argument('--num_epochs', help='number of epochs when clients train on data;', type=int, default=1)
    parser.add_argument('--num_steps', help='the number of local steps, which dominate num_epochs when setting num_steps>0', type=int, default=-1)
    parser.add_argument('--learning_rate', help='learning rate for inner solver;', type=float)
    # Integer default (was the string '50', which only worked because argparse
    # converts string defaults through `type`).
    parser.add_argument('--batch_size', help='batch size when clients train on data;', type=int, default=50)
    parser.add_argument('--optimizer', help='select the optimizer for gd', type=str, choices=optimizer_list, default='SGD')
    parser.add_argument('--momentum', help='momentum of local update', type=float, default=0.0)
    parser.add_argument('--weight_decay', help='weight decay for the training process', type=float, default=0)
    # realistic machine config
    parser.add_argument('--seed', help='seed for random initialization;', type=int, default=0)
    # A single GPU id; initialize() maps -1 (the default) to the CPU.
    parser.add_argument('--gpu', type=int, default=-1, help='GPU ID to use; -1 (default) means CPU')

    parser.add_argument('--eval_interval', help='evaluate every __ rounds;', type=int, default=1)
    # algorithm-dependent hyper-parameters
    parser.add_argument('--algo_para', help='algorithm-dependent hyper-parameters', nargs='*', type=float)

    # FedEBA
    parser.add_argument('--feda', type=float, default=0.9)
    parser.add_argument('--T', type=float, default=0.1)
    parser.add_argument('--fake', help='fake aggregation', action='store_true')
    parser.add_argument('--pmode', help='FedEBA mode (default: mean)', type=str, default='mean')

    # Unlearn
    parser.add_argument('--retrain', help='golden standard of FU', action='store_true')
    parser.add_argument('--unlearn', help='do FL training or unlearning', action='store_true')
    parser.add_argument('--u_rounds', help='number of unlearning rounds', type=int, default=20)
    parser.add_argument('--p_rounds', help='number of post-training rounds', type=int, default=20)
    parser.add_argument('--u_clients', help='number of unlearning clients', type=int, default=1)
    parser.add_argument('--split_num', help='number of unlearning par', type=int, default=1)
    parser.add_argument('--class_num', help='number of unlearning classes', type=int, default=7)

    # fedproto
    parser.add_argument('--infoNCET', help='infoNCET of Rethink FL', type=float, default=0.02)

    # backdoor
    parser.add_argument('--bd', help='backdoor', action='store_true')

    # Whether to load externally pretrained model weights (see initialize()).
    parser.add_argument('--pretrain', help='load pretrained neural network weights', action='store_true')

    try:
        option = vars(parser.parse_args(args))
    except IOError as msg:
        # parser.error() prints usage and exits, so `option` is never unbound below.
        parser.error(str(msg))
    return option

def setup_seed(seed):
    """Seed every RNG used by the run so results are reproducible.

    Each generator gets ``seed`` plus its own fixed offset (random: +1,
    numpy: +21, torch CPU: +12, torch CUDA: +123). cuDNN is disabled and
    put into deterministic mode.

    NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
    only affects child processes launched afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed + 1)
    np.random.seed(seed + 21)
    # Order matters: torch.manual_seed also seeds CUDA, so the explicit
    # CUDA seed below must come after it to take effect.
    torch.manual_seed(seed + 12)
    torch.cuda.manual_seed_all(seed + 123)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.enabled = False

def _resolve_device(gpu):
    """Return the torch.device for the ``gpu`` option; None or a negative id selects the CPU."""
    if gpu is None or gpu < 0:
        return torch.device('cpu')
    return torch.device('cuda:{}'.format(gpu))


def _load_checkpoint(model, checkpoint_path):
    """Load the state dict stored at *checkpoint_path* into *model* non-strictly."""
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine; load_state_dict then copies the values into the model.
    state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    # strict=False: keys missing from / extra in the checkpoint are ignored.
    model.load_state_dict(state_dict, strict=False)


def _prepare_pretrained_model(server_model, option, data_loader):
    """Optionally load pretrained weights, then resize the classifier head.

    Only 'mobilenetv2' and 'vgg16' are handled; any other model is left
    untouched. The new head outputs ``data_loader.target_class_num`` classes.
    Checkpoints are read from ``<project root>/models/...``, where the project
    root is the parent of the directory containing this file.
    """
    root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    models_dir = os.path.join(root_dir, 'models')

    if option['model'] == 'mobilenetv2':
        if option['pretrain']:
            _load_checkpoint(server_model, os.path.join(models_dir, 'mobilenetv2_96x96-ff0e83d8.pth'))
        num_features = server_model.classifier.in_features
        server_model.classifier = torch.nn.Linear(num_features, data_loader.target_class_num)
    elif option['model'] == 'vgg16':
        if option['pretrain']:
            _load_checkpoint(server_model, os.path.join(models_dir, 'hub', 'checkpoints', 'vgg16-397923af.pth'))
        server_model.classifier[-1] = torch.nn.Linear(server_model.classifier[-1].in_features, data_loader.target_class_num)


def initialize(option):
    """Build the server, its clients and the data loader described by *option*.

    Steps:
    1. Seed all RNGs.
    2. Instantiate ``models.<model>.Model`` and pick the device.
    3. Build ``utils.dataloaders.DataLoader_<dataloader>``.
    4. Optionally load pretrained weights and replace the classifier head.
    5. Create ``option['C']`` clients from ``algorithm.{FL|FU}.<algorithm>``
       and let the data loader allocate data to them.
    6. Construct that module's ``Server``.

    Args:
        option: dict of options, as returned by read_option().

    Returns:
        The ``Server`` instance from the selected algorithm module.

    Raises:
        ValueError: if ``option['dataloader']`` is missing or empty.
    """
    # Fail early with a clear message; --dataloader has no default.
    if not option.get('dataloader'):
        raise ValueError("option 'dataloader' is required (pass --dataloader NAME)")

    # seed for reproduction
    setup_seed(option['seed'])

    # init model
    # TODO: check that the server model's memory does not overlap with the clients' models.
    model_path = '.'.join(['models', option['model']])
    server_model = getattr(importlib.import_module(model_path), 'Model')()

    # init devices
    gpu_now = _resolve_device(option['gpu'])

    # init data_loader
    data_name = 'DataLoader_' + option['dataloader']
    loader_path = '.'.join(['utils', 'dataloaders', data_name])
    Dataloader = getattr(importlib.import_module(loader_path), data_name)
    data_loader = Dataloader(params=option, input_require_shape=server_model.input_require_shape)

    # pretrained weights / classifier head
    _prepare_pretrained_model(server_model, option, data_loader)

    # pick the algorithm package: FU (unlearning) or FL (plain training)
    if option['unlearn']:
        print('Unlearning Algorithm is Running.')
        family = 'FU'
    else:
        print('FL Algorithm is Running.')
        family = 'FL'
    algo_module = importlib.import_module('.'.join(['algorithm', family, option['algorithm']]))

    # init clients (created without a model), then allocate their data
    Client = getattr(algo_module, 'Client')
    num_clients = option['C']
    clients = [Client(option, id=cid, model=None) for cid in range(num_clients)]
    # TODO data statistics
    data_loader.allocate(clients)
    # Re-assign ids after allocation (presumably in case allocate() alters them — TODO confirm).
    for cid, c in enumerate(clients):
        c.id = cid

    # init server
    Server = getattr(algo_module, 'Server')
    return Server(option, server_model, clients, data_loader, device=gpu_now)