from Clients import *
# from Server import *
from data.datasets import *
import os

# Give the root logger a default configuration and let INFO (and above) through.
# `logger` is the root logger and is shared by every function in this module.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Expose only the physical GPU with index 2 to CUDA, then address it through the
# generic "cuda" device.
# NOTE(review): the environment variable only has an effect if CUDA has not been
# initialised yet — confirm that the star imports above do not touch CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
device = torch.device("cuda")


def seed_all(seed=1029):
    """Seed every random number generator used by the script.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU, the current
    CUDA device and all CUDA devices) with the same value, then switches
    cuDNN to deterministic mode with auto-tuning disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def log_(log_file="/data/cyn/SNN_CIFAR10_Q2/test1.log"):
    """Reset the root logger so that it writes INFO-and-above records to a file.

    Every handler currently attached to the root logger is detached and
    closed, then a single file handler is installed through
    ``logging.basicConfig``. The root logger level ends up at ``INFO``.

    Args:
        log_file: Path of the log file. Its parent directory is created when
            it does not exist yet. Records are appended to an existing file.
    """
    root = logging.getLogger()

    # basicConfig() does nothing while the root logger still has handlers,
    # so the old ones have to go first. Close them so their streams are
    # released instead of merely being detached.
    for handler in root.handlers[:]:
        root.removeHandler(handler)
        handler.close()

    # basicConfig() opens the file immediately and fails if the directory
    # is missing.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(
        filename=log_file,
        format='%(asctime)s %(levelname)-8s %(message)s',
        datefmt='%m-%d %H:%M',
        level=logging.INFO,
        filemode='a',  # append to an existing log instead of truncating it
    )


def _snapshot_ann_state(model):
    """Collect the tensors of the not-yet-converted model that are exchanged with clients.

    Returns:
        A ``(bn_state, weight_state)`` pair of dicts keyed by
        ``"<module name>.<attribute>"``. ``bn_state`` holds ``weight``,
        ``bias``, ``running_mean`` and ``running_var`` of every module
        accepted by ``is_bn``; ``weight_state`` holds ``weight`` of every
        ``nn.Conv2d`` / ``nn.Linear`` module.
    """
    bn_state = {}
    weight_state = {}
    for module_name, module in model.named_modules():
        if is_bn(module):
            for attr in ('weight', 'bias', 'running_mean', 'running_var'):
                bn_state[module_name + '.' + attr] = getattr(module, attr).data.cuda()
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            weight_state[module_name + '.weight'] = module.weight.data.cuda()
    return bn_state, weight_state


def _zeros_like_all(tensors):
    """Return a dict with the same keys as *tensors* and zero tensors of matching shape/dtype."""
    return {name: torch.zeros_like(value) for name, value in tensors.items()}


def whole_net(n_parties, k_clients, i):
    """Run one federated training experiment and return the last global test accuracy.

    The function partitions the data, builds a ResNet-20, snapshots its
    batch-norm and Conv/Linear tensors, folds batch norm and wraps the network
    in a ``SpikeModel``. It then trains for ``args.global_epochs`` rounds. In
    every round each selected client trains locally, the returned updates are
    averaged over ``k_clients``, the averaged model is saved to the server
    checkpoint, evaluated, and the result is logged and appended to the
    result workbook.

    Args:
        n_parties: Number of clients to create; also forwarded to
            ``partition_data``.
        k_clients: Number of clients sampled (once, before the first round)
            to take part in every round; also the divisor of the averages.
        i: Forwarded to ``partition_data`` and embedded in the checkpoint
            names and the result sheet.

    Returns:
        The accuracy reported by ``eval`` for the last global epoch, or
        ``None`` when ``args.global_epochs`` is 0.

    Raises:
        ValueError: From ``random.sample`` when ``k_clients`` exceeds
            ``n_parties``.
    """
    run_tag = '_'.join([args.partition, str(i), str(n_parties), str(k_clients), args.dataset])
    server_save = run_tag + '-server.pth'
    clients_save = run_tag

    write_excel_xls_append(args.result, 'global', n_parties, k_clients, i)
    logger.info('n_parties= %s   k_clients=%s ', n_parties, k_clients)

    # Prepare data partition
    logger.info("Partitioning data")
    train_data, test_data, net_dataidx_map = partition_data(n_parties, i)
    eval_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=True
    )

    # Initial server
    num_classes = 10 if args.dataset == 'cifar10' else 100
    global_model = resnet20(use_bn=True, num_classes=num_classes).to(device)

    # diff1: batch-norm tensors, diff2: Conv/Linear weights. Both are taken
    # before batch norm is folded away below.
    diff1, diff2 = _snapshot_ann_state(global_model)

    search_fold_and_remove_bn(global_model)
    global_model = SpikeModel(model=global_model, sim_length=args.T, specials=res_specials)
    global_model.set_spike_state(use_spike=True)
    get_maximum_activation(eval_loader, model=global_model, momentum=0.9, iters=5, mse=True,
                           percentile=None,
                           sim_length=args.T, channel_wise=False)

    # Prepare Client
    clients = []
    for c in range(n_parties):
        clients.append(
            Clients(train_data, net_dataidx_map[c], device, c, clients_save, eval_loader))
        write_sheet_xls(args.result, str(c))
        write_excel_xls_append(args.result, str(c), n_parties, k_clients, i)

    # The participating clients are drawn once and reused in every round.
    candidates = random.sample(clients, k_clients)

    # Stays None when no global epoch is run, so the final return is always defined.
    acc = None

    # Global train
    for e in range(args.global_epochs):
        print("------------------------------" * 2)
        print("Global Epoch %d" % e)
        # The epoch number is passed as a logging argument; without it the
        # literal "%d" would end up in the log.
        logger.info("Global Epoch %d" + "-" * 50, e)

        print("  select client is:")
        logger.info("  select client is:")
        for c in candidates:
            print("\t", c.client_id)
            logger.info("  %s", str(c.client_id))

        weight_accumulator = _zeros_like_all(global_model.state_dict())  # SNN state
        weight_accumulator1 = _zeros_like_all(diff1)  # batch-norm tensors
        weight_accumulator2 = _zeros_like_all(diff2)  # Conv/Linear weights

        # Clients Train
        for c in candidates:
            # NOTE(review): diff1/diff2 are rebound here, so every client after
            # the first receives the previous client's returned values rather
            # than the server-side averages — confirm this is intended.
            diff, diff1, diff2 = c.local_train(server_save, e, diff1, diff2)

            # Out-of-place add: the accumulator may need to change dtype
            # (e.g. an integer buffer receiving a floating-point update).
            for name in weight_accumulator:
                weight_accumulator[name] = weight_accumulator[name] + diff[name]

            for name, value in diff1.items():
                weight_accumulator1[name].add_(value)

            for name, value in diff2.items():
                weight_accumulator2[name].add_(value)

        # Average the ANN-side tensors; they are handed to the clients in the next round.
        for name in diff1:
            diff1[name] = torch.div(weight_accumulator1[name], k_clients)

        for name in diff2:
            diff2[name] = torch.div(weight_accumulator2[name], k_clients)

        # Parameters aggregate and evaluate
        for name, data in global_model.state_dict().items():
            # copy_ overwrites in place and converts to the destination dtype.
            data.copy_(torch.div(weight_accumulator[name], k_clients))

        torch.save(global_model.state_dict(), server_save)

        acc = eval(global_model, eval_loader, device)

        print("\nGlobal Epoch = ", e, "   test acc =", acc)
        logger.info("Global epoch= %d  test acc=%s train acc=%s", e, str(acc), str(0))
        write_excel_xls_append(args.result, 'global', e, str(acc), 0)
    return acc


if __name__ == "__main__":
    # Point logging at the run's log file before anything is reported.
    log_()
    logger.info(device)

    # Fix the random seeds.
    seed_all(1029)

    # Prepare the 'global' sheet that whole_net() appends its results to.
    write_sheet_xls(args.result, 'global')

    # One experiment per value of whole_net()'s third argument, always with
    # 10 parties and 10 sampled clients.
    for run_id in (7, 1, 3, 5):
        acc1 = whole_net(10, 10, run_id)


