#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
# Pytorch version: 1.1.0 or above


## All experiments on CIFAR10: optimizer=SGD, momentum=0.9, learning rate=0.05 (small nets) or 0.01 (large nets), weight decay=0.95 per epoch
## number of iterations = [58650/150=391(cifar10), 70350/150=469(MNIST)],
# sample: num_sample/epoch = 10, [ceil(391/40)=10(cifar10), ceil(469/47)=10(MNIST)]
import os
import logging
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
import math

import cmd_args_cc
from utils_cc import get_data_loaders, fl_get_train_data_loaders, get_model, setup_logging, seed_everything, accuracy, accuracy_encoder_decoder, average_weights

################################################################## BASELINE: GENERAL SGD & USGD & FedAvg ########################################################################
def SGD_train_model(args, UA_train_loader, test_loader, model, exp_dir, exp_model_dir, criterion, start_epoch=None, epochs=None):
    """Baseline: centralised SGD over the merged U+A training data.

    The model is evaluated on ``UA_train_loader`` and ``test_loader`` once
    before training (logged as round 0) and again after every epoch of
    ``train_local_update_general`` on ``UA_train_loader``. The four metric
    curves are written as text files to ``exp_dir`` and the final weights to
    ``exp_model_dir/SGD_model.pth``.

    Args:
        args: parsed command-line namespace (``args.epochs`` is the default
            end epoch).
        UA_train_loader: loader used for both training and train metrics.
        test_loader: loader used for the test metrics.
        model: network to train.
        exp_dir: directory for the metric text files.
        exp_model_dir: directory for the saved state dict.
        criterion: loss module.
        start_epoch: first epoch index; a falsy value means 0.
        epochs: exclusive end epoch index; a falsy value means ``args.epochs``.
    """
    cudnn.benchmark = True
    first_epoch = start_epoch or 0
    stop_epoch = epochs or args.epochs
    round_fmt = ('Round: %d, Train-loss-final: %6.4f, Test-loss-final: %6.4f, '
                 'Train-acc-final: %6.4f, Test-acc-final: %6.4f')
    # Series order: train loss, test loss, train acc, test acc.
    history = ([], [], [], [])

    def _record(round_no, net):
        # Train metrics are measured before test metrics, as in every round.
        tr_loss, tr_acc = eval_loss(args, UA_train_loader, net, criterion)
        te_loss, te_acc = eval_loss(args, test_loader, net, criterion)
        for series, value in zip(history, (tr_loss, te_loss, tr_acc, te_acc)):
            series.append(value)
        logging.info(round_fmt, round_no, tr_loss, te_loss, tr_acc, te_acc)

    _record(0, model)
    for epoch in range(first_epoch, stop_epoch):
        model = train_local_update_general(args, UA_train_loader, model, criterion)
        _record(epoch + 1, model)

    file_names = ("final_train_loss.txt", "final_test_loss.txt",
                  "final_train_acc.txt", "final_test_acc.txt")
    for file_name, series in zip(file_names, history):
        np.savetxt(os.path.join(exp_dir, file_name), np.array(series), fmt='%6.4f')
    torch.save(model.state_dict(), os.path.join(exp_model_dir, "SGD_model.pth"))


def USGD_train_model(args, U_train_loader, test_loader, model, exp_dir, exp_model_dir, criterion, start_epoch=None, epochs=None):
    """Baseline: SGD on the U (learner) training data only.

    Identical in structure to the centralised SGD baseline, but both training
    and train-metric evaluation use ``U_train_loader``. The untrained model
    is logged as round 0, and each epoch is logged as ``epoch + 1``. Metric
    curves go to ``exp_dir`` and the final weights to
    ``exp_model_dir/USGD_model.pth``.

    Args:
        args: parsed command-line namespace (``args.epochs`` is the default
            end epoch).
        U_train_loader: loader used for both training and train metrics.
        test_loader: loader used for the test metrics.
        model: network to train.
        exp_dir: directory for the metric text files.
        exp_model_dir: directory for the saved state dict.
        criterion: loss module.
        start_epoch: first epoch index; a falsy value means 0.
        epochs: exclusive end epoch index; a falsy value means ``args.epochs``.
    """
    cudnn.benchmark = True
    begin = start_epoch or 0
    end = epochs or args.epochs
    round_fmt = ('Round: %d, Train-loss-final: %6.4f, Test-loss-final: %6.4f, '
                 'Train-acc-final: %6.4f, Test-acc-final: %6.4f')

    def evaluate(net):
        """Return (train_loss, test_loss, train_acc, test_acc) for ``net``."""
        train_metrics = eval_loss(args, U_train_loader, net, criterion)
        test_metrics = eval_loss(args, test_loader, net, criterion)
        return train_metrics[0], test_metrics[0], train_metrics[1], test_metrics[1]

    snapshot = evaluate(model)
    columns = [[value] for value in snapshot]
    logging.info(round_fmt, 0, *snapshot)

    for epoch in range(begin, end):
        model = train_local_update_general(args, U_train_loader, model, criterion)
        snapshot = evaluate(model)
        for column, value in zip(columns, snapshot):
            column.append(value)
        logging.info(round_fmt, epoch + 1, *snapshot)

    names = ("final_train_loss.txt", "final_test_loss.txt",
             "final_train_acc.txt", "final_test_acc.txt")
    for name, column in zip(names, columns):
        np.savetxt(os.path.join(exp_dir, name), np.asarray(column), fmt='%6.4f')
    torch.save(model.state_dict(), os.path.join(exp_model_dir, "USGD_model.pth"))


def fedavg_train_model(args, U_train_loader, A_train_loader, UA_train_loader, test_loader, model, exp_dir, exp_model_dir, criterion, valid_user_data_size, start_epoch=None, epochs=None):
    """Baseline: two-client FedAvg between the U client and the A client.

    Each round, both clients start from a private copy of the current global
    model and make one pass of ``train_local_update_general`` over their own
    loader. Their weights are then combined by ``average_weights`` with the
    per-client sizes ``valid_user_data_size`` (ordered ``[U, A]``), and the
    result is loaded into ``model``. The global model is evaluated on
    ``UA_train_loader`` and ``test_loader`` before training (round 0) and after
    every round. Metric curves are written to ``exp_dir`` and the final weights
    to ``exp_model_dir/fedavg_model.pth``.

    Args:
        args: parsed command-line namespace (``args.epochs`` is the default
            end epoch).
        U_train_loader: loader of the U client.
        A_train_loader: loader of the A client.
        UA_train_loader: loader over both clients' data, used for train metrics.
        test_loader: loader used for the test metrics.
        model: global network; updated in place every round.
        exp_dir: directory for the metric text files.
        exp_model_dir: directory for the saved state dict.
        criterion: loss module.
        valid_user_data_size: aggregation weights, ``[U size, A size]``.
        start_epoch: first round index; a falsy value means 0.
        epochs: exclusive end round index; a falsy value means ``args.epochs``.
    """
    cudnn.benchmark = True
    start_epoch = start_epoch or 0
    epochs = epochs or args.epochs

    init_train_loss, init_train_acc = eval_loss(args, UA_train_loader, model, criterion)
    init_test_loss, init_test_acc = eval_loss(args, test_loader, model, criterion)
    final_model_train_loss_lst, final_model_test_loss_lst = [init_train_loss], [init_test_loss]
    final_model_train_acc_lst, final_model_test_acc_lst = [init_train_acc], [init_test_acc]
    logging.info('Round: %d, Train-loss-final: %6.4f, Test-loss-final: %6.4f, Train-acc-final: %6.4f, Test-acc-final: %6.4f', 0, init_train_loss, init_test_loss, init_train_acc, init_test_acc)

    for epoch in range(start_epoch, epochs):
        # Each client trains its own private copy of the current global model.
        U_model = train_local_update_general(args, U_train_loader, copy.deepcopy(model), criterion)
        A_model = train_local_update_general(args, A_train_loader, copy.deepcopy(model), criterion)
        # The client models are throwaway copies, so their state dicts can be
        # aggregated without another deep copy; the order must match
        # valid_user_data_size ([U, A]). load_state_dict copies the values in.
        global_weights = average_weights([U_model.state_dict(), A_model.state_dict()], valid_user_data_size)
        model.load_state_dict(global_weights)
        final_model_train_loss, final_model_train_acc = eval_loss(args, UA_train_loader, model, criterion)
        final_model_test_loss, final_model_test_acc = eval_loss(args, test_loader, model, criterion)
        final_model_train_loss_lst.append(final_model_train_loss)
        final_model_test_loss_lst.append(final_model_test_loss)
        final_model_train_acc_lst.append(final_model_train_acc)
        final_model_test_acc_lst.append(final_model_test_acc)
        logging.info('Round: %d, Train-loss-final: %6.4f, Test-loss-final: %6.4f, Train-acc-final: %6.4f, Test-acc-final: %6.4f', epoch+1, final_model_train_loss, final_model_test_loss, final_model_train_acc, final_model_test_acc)

    np.savetxt(os.path.join(exp_dir, "final_train_loss.txt"), np.array(final_model_train_loss_lst), fmt='%6.4f')
    np.savetxt(os.path.join(exp_dir, "final_test_loss.txt"), np.array(final_model_test_loss_lst), fmt='%6.4f')
    np.savetxt(os.path.join(exp_dir, "final_train_acc.txt"), np.array(final_model_train_acc_lst), fmt='%6.4f')
    np.savetxt(os.path.join(exp_dir, "final_test_acc.txt"), np.array(final_model_test_acc_lst), fmt='%6.4f')
    torch.save(model.state_dict(), os.path.join(exp_model_dir, "fedavg_model.pth"))


def train_local_update_general(args, train_loader, model, criterion):
    """Run one optimisation pass of ``model`` over ``train_loader``.

    A new optimizer is built on every call, so optimizer state (e.g. momentum
    buffers) is not carried across calls. The optimizer is chosen as follows,
    always with learning rate ``args.learning_rate``:

    - SGD with momentum ``args.mo`` when ``args.algorithm == 'sgd_m'``;
    - Adam when it is ``'adam'``;
    - plain SGD otherwise.

    For the reconstruction architectures (``conv_ae``, ``ae``, ``unet``) the
    loss compares the output with the input batch; otherwise it compares the
    output with the target.

    Args:
        args: parsed command-line namespace (reads ``learning_rate``,
            ``algorithm``, ``mo`` and ``arch``).
        train_loader: iterable of ``(input, target)`` batches.
        model: the ``nn.Module`` to train; updated in place.
        criterion: loss module called as ``criterion(output, reference)``.

    Returns:
        The same ``model`` object, after training.
    """
    lr = args.learning_rate
    model.train()  # switch to train mode
    if args.algorithm == 'sgd_m':
        optimizer = torch.optim.SGD(model.parameters(), lr, momentum=args.mo)
    elif args.algorithm == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr)
    reconstruct = args.arch in ('conv_ae', 'ae', 'unet')
    use_cuda = torch.cuda.is_available()
    for input_batch, target in train_loader:
        if use_cuda:
            target = target.cuda(non_blocking=True)
            input_batch = input_batch.cuda()
        # Plain tensors are used as inputs. Only parameter gradients are needed,
        # so autograd must not also back-propagate into the data batch.
        output = model(input_batch)
        loss = criterion(output, input_batch if reconstruct else target)
        # compute gradient and do one optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model

################################################################## Single-Agent Assisted-SGD ####################################################################
def _save_iteration_snapshots(model_lst, prefix, round_no, iterations, local_interval, exp_model_dir):
    """Save the entries of ``model_lst`` that correspond to the given 1-based iterations.

    Entry ``i`` of ``model_lst`` is the snapshot taken after iteration
    ``i * local_interval + 1``. Each requested iteration is written to
    ``<exp_model_dir>/<prefix>_model_<round_no>_<iteration>.pth``. An
    iteration with no recorded snapshot is skipped with a warning instead
    of aborting training.
    """
    for iteration in iterations:
        index, offset = divmod(iteration - 1, local_interval)
        if offset != 0 or index >= len(model_lst):
            logging.warning('No %s snapshot for iteration %d (local_interval=%d, %d snapshots); not saved',
                            prefix, iteration, local_interval, len(model_lst))
            continue
        saved_fn = os.path.join(exp_model_dir, "{}_model_{}_{}.pth".format(prefix, round_no, iteration))
        torch.save(model_lst[index].state_dict(), saved_fn)


def SAASGD_train_model(args, U_train_loader, A_train_loader, UA_train_loader, test_loader, model, exp_dir, exp_model_dir, criterion, start_epoch=None, epochs=None):
    """Single-Agent Assisted SGD: alternate local training between U and A.

    Each round:
      1. ``U_model`` runs ``train_local_update`` on ``U_train_loader``, keeping
         a snapshot every ``args.local_interval`` steps.
      2. Each U snapshot is evaluated on ``UA_train_loader``. The one with the
         lowest loss is copied into ``A_model``.
      3. ``A_model`` is trained the same way on ``A_train_loader``. The best A
         snapshot (again by loss on ``UA_train_loader``) becomes the next
         ``U_model``.

    Each round records the selected A snapshot's UA-train metrics and its test
    metrics. Output files:

    - ``exp_dir``: the metric curves, plus the 1-based iteration of every
      selected snapshot (``U_argmin_indices.txt`` / ``A_argmin_indices.txt``).
    - ``exp_model_dir/SAASGD_model.pth``: the final model.
    - When ``args.plot_hist == 0``: the U snapshots at iterations 1/51/151 and
      the A snapshots at iterations 1/501/1501 of the first and last round.

    Raises:
        ValueError: if the U or A side would have no snapshot slots (its
            number of local steps, derived from ``args.local_steps_init`` and
            ``args.rou``, is not positive).
    """
    cudnn.benchmark = True
    start_epoch = start_epoch or 0
    epochs = epochs or args.epochs
    interval = args.local_interval
    U_local_steps = round(args.local_steps_init*args.rou)
    A_local_steps = args.local_steps_init - U_local_steps
    U_model = copy.deepcopy(model)
    # One snapshot slot per `interval` local steps (steps 0, interval, 2*interval, ...).
    U_model_lst = [copy.deepcopy(model) for i in range(U_local_steps) if i % interval == 0]
    A_model_lst = [copy.deepcopy(model) for i in range(A_local_steps) if i % interval == 0]
    if not U_model_lst or not A_model_lst:
        raise ValueError('SAASGD needs at least one local step for both U and A '
                         '(got U={}, A={})'.format(U_local_steps, A_local_steps))
    U_argmin_index_lst, A_argmin_index_lst = [], []

    init_train_loss, init_train_acc = eval_loss(args, UA_train_loader, model, criterion)
    init_test_loss, init_test_acc = eval_loss(args, test_loader, model, criterion)
    final_model_train_loss_lst, final_model_test_loss_lst = [init_train_loss], [init_test_loss]
    final_model_train_acc_lst, final_model_test_acc_lst = [init_train_acc], [init_test_acc]
    logging.info('Round: %d, Train-loss-final: %6.4f, Test-loss-final: %6.4f, Train-acc-final: %6.4f, Test-acc-final: %6.4f', 0, init_train_loss, init_test_loss, init_train_acc, init_test_acc)

    for epoch in range(start_epoch, epochs):
        save_snapshots = (args.plot_hist == 0) and (epoch == 0 or epoch == epochs - 1)

        # Learner (U) phase: train, then pick the snapshot with the lowest UA loss.
        U_model_lst = train_local_update(args, U_train_loader, U_model, U_model_lst, criterion)
        U_losses = [eval_loss(args, UA_train_loader, m, criterion)[0] for m in U_model_lst]
        best = int(np.argmin(U_losses))
        U_argmin_index_lst.append(best * interval + 1)  # 1-based iteration of the snapshot
        A_model = copy.deepcopy(U_model_lst[best])
        if save_snapshots:
            _save_iteration_snapshots(U_model_lst, 'U', epoch + 1, (1, 51, 151), interval, exp_model_dir)

        # Provider (A) phase: continue from the best U snapshot.
        A_model_lst = train_local_update(args, A_train_loader, A_model, A_model_lst, criterion)
        A_metrics = [eval_loss(args, UA_train_loader, m, criterion) for m in A_model_lst]
        best = int(np.argmin([loss for loss, _ in A_metrics]))
        A_argmin_index_lst.append(best * interval + 1)
        U_model = copy.deepcopy(A_model_lst[best])
        if save_snapshots:
            _save_iteration_snapshots(A_model_lst, 'A', epoch + 1, (1, 501, 1501), interval, exp_model_dir)

        final_model_train_loss, final_model_train_acc = A_metrics[best]
        final_model_test_loss, final_model_test_acc = eval_loss(args, test_loader, U_model, criterion)
        final_model_train_loss_lst.append(final_model_train_loss)
        final_model_test_loss_lst.append(final_model_test_loss)
        final_model_train_acc_lst.append(final_model_train_acc)
        final_model_test_acc_lst.append(final_model_test_acc)
        logging.info('Round: %d, Train-loss-final: %6.4f, Test-loss-final: %6.4f, Train-acc-final: %6.4f, Test-acc-final: %6.4f', epoch+1, final_model_train_loss, final_model_test_loss, final_model_train_acc, final_model_test_acc)

    np.savetxt(os.path.join(exp_dir, "final_train_loss.txt"), np.array(final_model_train_loss_lst), fmt='%6.4f')
    np.savetxt(os.path.join(exp_dir, "final_test_loss.txt"), np.array(final_model_test_loss_lst), fmt='%6.4f')
    np.savetxt(os.path.join(exp_dir, "final_train_acc.txt"), np.array(final_model_train_acc_lst), fmt='%6.4f')
    np.savetxt(os.path.join(exp_dir, "final_test_acc.txt"), np.array(final_model_test_acc_lst), fmt='%6.4f')
    np.savetxt(os.path.join(exp_dir, "U_argmin_indices.txt"), np.array(U_argmin_index_lst), fmt='%d')
    np.savetxt(os.path.join(exp_dir, "A_argmin_indices.txt"), np.array(A_argmin_index_lst), fmt='%d')
    torch.save(U_model.state_dict(), os.path.join(exp_model_dir, "SAASGD_model.pth"))


def train_local_update(args, train_loader, model, model_lst, criterion):
    """Run one pass of local updates over ``train_loader``, snapshotting ``model``.

    After local step ``s`` (0-based), if ``s % args.local_interval == 0``, a
    deep copy of ``model`` is stored in ``model_lst[s // args.local_interval]``.
    Entry ``i`` therefore holds the model after iteration
    ``i * local_interval + 1``, and ``model_lst`` needs at least
    ``ceil(len(train_loader) / local_interval)`` slots.

    Optimizer choice and loss target follow the same rules as the general
    trainer:

    - SGD with momentum ``args.mo`` for ``'sgd_m'``, Adam for ``'adam'``,
      plain SGD otherwise.
    - For ``conv_ae``/``ae``/``unet`` the input is the reconstruction target.

    When ``args.eps`` is not ``None``, each step's gradient is clipped and
    perturbed before the optimizer step:

    1. Compute the global L2 norm of all parameter gradients.
    2. Take the median ``C`` of every norm seen so far in this call.
    3. Scale the gradient by ``1 / max(1, norm / C)``.
    4. Add Gaussian noise with std ``sigma * C / args.U_batch_size``, where
       ``sigma = sqrt(2 ln(1/delta)) / args.eps`` and ``delta = 1e-5``.

    Parameters without a gradient (frozen or unused) are left untouched.

    Args:
        args: parsed command-line namespace.
        train_loader: iterable of ``(input, target)`` batches.
        model: network to train; updated in place.
        model_lst: pre-sized list that receives the snapshots.
        criterion: loss module.

    Returns:
        ``model_lst`` (the same list object, filled in place).
    """
    grad_lst = []
    lr = args.learning_rate
    use_dp = args.eps is not None
    if use_dp:
        delta = 1e-5
        sigma = math.sqrt(2*math.log(1/delta)) / args.eps # Strong Composition Theorem
    model.train()  # switch to train mode
    if args.algorithm == 'sgd_m':
        optimizer = torch.optim.SGD(model.parameters(), lr, momentum=args.mo)
    elif args.algorithm == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr)
    reconstruct = args.arch in ('conv_ae', 'ae', 'unet')
    use_cuda = torch.cuda.is_available()
    for local_step, (input_batch, target) in enumerate(train_loader):
        if use_cuda:
            target = target.cuda(non_blocking=True)
            input_batch = input_batch.cuda()
        # Plain tensors are used as inputs: only parameter gradients are needed.
        output = model(input_batch)
        loss = criterion(output, input_batch if reconstruct else target)
        optimizer.zero_grad()  # Initialize the gradients to 0
        loss.backward()  # Backward to compute the gradients using losses
        if use_dp:
            # Frozen/unused parameters have no gradient and are skipped.
            params = [p for p in model.parameters() if p.grad is not None]
            batch_grad_norm = torch.tensor(0.0)
            for p in params:
                batch_grad_norm = batch_grad_norm + (torch.norm(p.grad))**2
            grad_norm = torch.sqrt(batch_grad_norm).item()
            # Store plain floats so np.median also works when gradients are on the GPU.
            grad_lst.append(grad_norm)
            C = float(np.median(grad_lst))
            # A zero threshold means zero noise std and an undefined clip ratio;
            # the gradient is then passed through unchanged.
            if C > 0:
                # NOTE(review): the noise scale uses args.U_batch_size even when
                # train_loader is not the U loader -- confirm this is intended.
                normal_dist = torch.distributions.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([sigma*C/args.U_batch_size]))
                clip_factor = max(1.0, grad_norm / C)
                for p in params:
                    p.grad = p.grad / clip_factor  # Clip gradient
                    noise = normal_dist.sample((p.grad.numel(),)).reshape(p.grad.size())
                    p.grad = p.grad + noise.to(p.grad.device)  # Add noise on the gradient's device
        optimizer.step()  # Update the parameters using gradients
        if local_step % args.local_interval == 0:
            model_lst[local_step // args.local_interval] = copy.deepcopy(model)
    return model_lst


def eval_loss(args, loader, model, criterion):
    """Evaluate ``model`` on ``loader`` and return ``(mean_loss, mean_accuracy)``.

    The model is switched to eval mode, and every batch (forward pass, loss
    and accuracy) is computed under ``torch.no_grad()``. The loss reference
    and accuracy metric depend on the architecture:

    - ``conv_ae``/``ae``/``unet``: the loss compares the output with the input,
      and accuracy comes from ``accuracy_encoder_decoder``.
    - otherwise: the loss uses the target, and accuracy is top-1 ``accuracy``.

    Both returned values are unweighted means of the per-batch values, so a
    smaller last batch counts as much as a full one. An empty loader yields
    NaN for both.

    Args:
        args: parsed command-line namespace (reads ``arch``).
        loader: iterable of ``(input, target)`` batches.
        model: network to evaluate (left in eval mode).
        criterion: loss module; must return a scalar per batch.

    Returns:
        ``(mean_loss, mean_accuracy)`` as Python floats.
    """
    batch_losses, batch_accs = [], []
    model.eval()  # switch to evaluation mode
    reconstruct = args.arch in ('conv_ae', 'ae', 'unet')
    use_cuda = torch.cuda.is_available()
    for input_batch, target in loader:
        if use_cuda:
            target = target.cuda(non_blocking=True)
            input_batch = input_batch.cuda()
        # The forward pass itself must run under no_grad; otherwise autograd
        # keeps a graph (and its activations) for every evaluation batch.
        with torch.no_grad():
            output = model(input_batch)
            if reconstruct:
                loss = criterion(output, input_batch)
                prec1 = accuracy_encoder_decoder(input_batch, output)
            else:
                loss = criterion(output, target)
                prec1 = accuracy(output, target, topk=(1,))[0]
        batch_losses.append(loss.item())
        batch_accs.append(prec1.item())
    return float(np.mean(batch_losses)), float(np.mean(batch_accs))

################################################################## MAIN ############################################################################
def main():
    seed_everything()
    args = cmd_args_cc.parse_args()
    print(args)
    exp_dir = os.path.join(args.script_path, 'files')
    exp_dir = os.path.join(exp_dir, args.exp_name)
    if not os.path.isdir(exp_dir):
        os.makedirs(exp_dir)
    exp_model_dir = os.path.join(args.script_path, 'models')
    exp_model_dir = os.path.join(exp_model_dir, args.exp_name)
    if not os.path.isdir(exp_model_dir):
        os.makedirs(exp_model_dir)
    setup_logging(args, exp_dir)
    model = get_model(args)
    logging.info('Number of parameters: %d', sum([p.data.nelement() for p in model.parameters()]))

    if (args.plot_hist != 1):
        if args.loss == 'MSE':
            criterion = nn.MSELoss()
        elif args.loss == 'KL_divergence':
            criterion = nn.KLDivLoss(reduction='none')
        else:
            criterion = nn.CrossEntropyLoss()
        if torch.cuda.is_available():
            criterion = criterion.cuda()

        train_dataset, test_loader, U_user_groups, A_user_groups, U_user_data_size, A_user_data_size = get_data_loaders(args)

        if args.command == 'train':
            U_idxs = U_user_groups[0]
            A_idxs = A_user_groups[0]
            UA_idxs = np.concatenate((U_idxs, A_idxs), axis=None)
            UA_idxs = np.random.choice(UA_idxs, UA_idxs.size, replace=False)
            U_train_loader = fl_get_train_data_loaders(args, train_dataset, list(U_idxs), 'U_training')
            A_train_loader = fl_get_train_data_loaders(args, train_dataset, list(A_idxs), 'A_training')
            UA_train_loader = fl_get_train_data_loaders(args, train_dataset, list(UA_idxs), 'UA_training')

            if args.learn_type == 'SAASGD':
                SAASGD_train_model(args, U_train_loader, A_train_loader, UA_train_loader, test_loader, model, exp_dir, exp_model_dir, criterion)
            elif args.learn_type == 'fedavg':
                valid_user_data_size = [U_user_data_size[0], A_user_data_size[0]]
                fedavg_train_model(args, U_train_loader, A_train_loader, UA_train_loader, test_loader, model, exp_dir, exp_model_dir, criterion, valid_user_data_size)
            elif args.learn_type == 'USGD':
                USGD_train_model(args, U_train_loader, test_loader, model, exp_dir, exp_model_dir, criterion)
            else:
                SGD_train_model(args, UA_train_loader, test_loader, model, exp_dir, exp_model_dir, criterion)
        else:
            device = torch.device('cpu')
            U_model_1 = copy.deepcopy(model)
            A_model_0 = copy.deepcopy(model)
            A_model_1 = copy.deepcopy(model)
            saved_U_model = os.path.join(exp_model_dir, "U_model_1_1.pth")
            model.load_state_dict(torch.load(saved_U_model, map_location=device))
            saved_U_model_1 = os.path.join(exp_model_dir, "U_model_1_51.pth")
            U_model_1.load_state_dict(torch.load(saved_U_model_1, map_location=device))
            saved_A_model = os.path.join(exp_model_dir, "A_model_1_1.pth")
            A_model_0.load_state_dict(torch.load(saved_A_model, map_location=device))
            saved_A_model_1 = os.path.join(exp_model_dir, "A_model_1_501.pth")
            A_model_1.load_state_dict(torch.load(saved_A_model_1, map_location=device))

            normal_dist = torch.distributions.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([args.s]))
            for p in model.parameters():
                t = normal_dist.sample((p.view(-1).size())).reshape(p.size())
                with torch.no_grad():
                    p.add_(t)
            for p in U_model_1.parameters():
                t = normal_dist.sample((p.view(-1).size())).reshape(p.size())
                with torch.no_grad():
                    p.add_(t)
            for p in A_model_0.parameters():
                t = normal_dist.sample((p.view(-1).size())).reshape(p.size())
                with torch.no_grad():
                    p.add_(t)
            for p in A_model_1.parameters():
                t = normal_dist.sample((p.view(-1).size())).reshape(p.size())
                with torch.no_grad():
                    p.add_(t)

            U_model_0_eval_loss, U_model_0_eval_acc = eval_loss(args, test_loader, model, criterion)
            U_model_1_eval_loss, U_model_1_eval_acc = eval_loss(args, test_loader, U_model_1, criterion)
            A_model_0_eval_loss, A_model_0_eval_acc = eval_loss(args, test_loader, A_model_0, criterion)
            A_model_1_eval_loss, A_model_1_eval_acc = eval_loss(args, test_loader, A_model_1, criterion)
            logging.info('Test-loss-U-model-1-1: %6.4f, Test-acc-U-model-1-1: %6.4f', U_model_0_eval_loss, U_model_0_eval_acc)
            logging.info('Test-loss-U-model-1-51: %6.4f, Test-acc-U-model-1-51: %6.4f', U_model_1_eval_loss, U_model_1_eval_acc)
            logging.info('Test-loss-A-model-1-1: %6.4f, Test-acc-A-model-1-1: %6.4f', A_model_0_eval_loss, A_model_0_eval_acc)
            logging.info('Test-loss-A-model-1-501: %6.4f, Test-acc-A-model-1-501: %6.4f', A_model_1_eval_loss, A_model_1_eval_acc)
    else:
        device = torch.device('cpu')
        """
        U_model_1 = copy.deepcopy(model)
        U_model_2 = copy.deepcopy(model)
        U_model_3 = copy.deepcopy(model)
        U_model_4 = copy.deepcopy(model)
        U_model_5 = copy.deepcopy(model)
        A_model_0 = copy.deepcopy(model)
        A_model_1 = copy.deepcopy(model)
        A_model_2 = copy.deepcopy(model)
        A_model_3 = copy.deepcopy(model)
        A_model_4 = copy.deepcopy(model)
        A_model_5 = copy.deepcopy(model)
        """

        saved_U_model = os.path.join(exp_model_dir, "U_model_1_1.pth")
        model.load_state_dict(torch.load(saved_U_model, map_location=device))
        """
        saved_U_model_1 = os.path.join(exp_model_dir, "U_model_1_51.pth")
        U_model_1.load_state_dict(torch.load(saved_U_model_1, map_location=device))
        saved_U_model_2 = os.path.join(exp_model_dir, "U_model_1_151.pth")
        U_model_2.load_state_dict(torch.load(saved_U_model_2, map_location=device))
        saved_U_model_3 = os.path.join(exp_model_dir, "U_model_10_1.pth")
        U_model_3.load_state_dict(torch.load(saved_U_model_3, map_location=device))
        saved_U_model_4 = os.path.join(exp_model_dir, "U_model_10_51.pth")
        U_model_4.load_state_dict(torch.load(saved_U_model_4, map_location=device))
        saved_U_model_5 = os.path.join(exp_model_dir, "U_model_10_151.pth")
        U_model_5.load_state_dict(torch.load(saved_U_model_5, map_location=device))
        U_model_np, U_model_1_np, U_model_2_np, U_model_3_np, U_model_4_np, U_model_5_np = [], [], [], [], [], []
        """
        U_model_np = []
        for p in model.parameters():
            U_model_np.append(p.data.detach().cpu().numpy())
        U_model_np = np.asarray(U_model_np)
        """
        for p in U_model_1.parameters():
            U_model_1_np.append(p.data.detach().cpu().numpy())
        U_model_1_np = np.asarray(U_model_1_np)
        for p in U_model_2.parameters():
            U_model_2_np.append(p.data.detach().cpu().numpy())
        U_model_2_np = np.asarray(U_model_2_np)
        for p in U_model_3.parameters():
            U_model_3_np.append(p.data.detach().cpu().numpy())
        U_model_3_np = np.asarray(U_model_3_np)
        for p in U_model_4.parameters():
            U_model_4_np.append(p.data.detach().cpu().numpy())
        U_model_4_np = np.asarray(U_model_4_np)
        for p in U_model_5.parameters():
            U_model_5_np.append(p.data.detach().cpu().numpy())
        U_model_5_np = np.asarray(U_model_5_np)
        """
        # the histogram of the data
        plt.rcParams.update({'font.size': 18})
        plt.xlabel('Weight Values', fontsize=18)
        plt.ylabel('Counts', fontsize=18)
        plt.grid(True)
        U_n, U_bins, _ = plt.hist(U_model_np, 100)
        #plt.title('Histogram of Learner Model of Iteration 1 in Round 1')
        plt.xlim(-0.1, 0.1)
        plt.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
        plt.tight_layout()
        plt.draw() #plt.show()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/U_model_1_1_new.pdf'))
        """
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        U_n_1, U_bins_1, _ = plt.hist(U_model_1_np, 100)
        plt.title('Histogram of Learner Model of Iteration 51 in Round 1')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/U_model_1_51.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        U_n_2, U_bins_2, _ = plt.hist(U_model_2_np, 100)
        plt.title('Histogram of Learner Model of Iteration 151 in Round 1')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/U_model_1_151.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        U_n_3, U_bins_3, _ = plt.hist(U_model_3_np, 100)
        plt.title('Histogram of Learner Model of Iteration 1 in Round 10')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/U_model_10_1.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        U_n_4, U_bins_4, _ = plt.hist(U_model_4_np, 100)
        plt.title('Histogram of Learner Model of Iteration 51 in Round 10')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/U_model_10_51.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        U_n_5, U_bins_5, _ = plt.hist(U_model_5_np, 100)
        plt.title('Histogram of Learner Model of Iteration 151 in Round 10')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/U_model_10_151.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        plt.hist(U_model_1_np - U_model_np, 100)
        plt.title('Histogram of Learner Difference between Iteration 51 & 1 in Round 1')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/U_model_1_51_d_1_1.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        plt.hist(U_model_2_np - U_model_1_np, 100)
        plt.title('Histogram of Learner Difference between Iteration 151 & 51 in Round 1')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/U_model_1_151_d_1_51.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        plt.hist(U_model_2_np - U_model_np, 100)
        plt.title('Histogram of Learner Difference between Iteration 151 & 1 in Round 1')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/U_model_1_151_d_1_1.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        plt.hist(U_model_4_np - U_model_3_np, 100)
        plt.title('Histogram of Learner Difference between Iteration 51 & 1 in Round 10')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/U_model_10_51_d_10_1.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        plt.hist(U_model_5_np - U_model_4_np, 100)
        plt.title('Histogram of Learner Difference between Iteration 151 & 51 in Round 10')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/U_model_10_151_d_10_51.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        plt.hist(U_model_5_np - U_model_3_np, 100)
        plt.title('Histogram of Learner Difference between Iteration 151 & 1 in Round 10')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/U_model_10_151_d_10_1.pdf'))

        saved_A_model = os.path.join(exp_model_dir, "A_model_1_1.pth")
        A_model_0.load_state_dict(torch.load(saved_A_model, map_location=device))
        saved_A_model_1 = os.path.join(exp_model_dir, "A_model_1_501.pth")
        A_model_1.load_state_dict(torch.load(saved_A_model_1, map_location=device))
        saved_A_model_2 = os.path.join(exp_model_dir, "A_model_1_1501.pth")
        A_model_2.load_state_dict(torch.load(saved_A_model_2, map_location=device))
        saved_A_model_3 = os.path.join(exp_model_dir, "A_model_10_1.pth")
        A_model_3.load_state_dict(torch.load(saved_A_model_3, map_location=device))
        saved_A_model_4 = os.path.join(exp_model_dir, "A_model_10_501.pth")
        A_model_4.load_state_dict(torch.load(saved_A_model_4, map_location=device))
        saved_A_model_5 = os.path.join(exp_model_dir, "A_model_10_1501.pth")
        A_model_5.load_state_dict(torch.load(saved_A_model_5, map_location=device))
        A_model_0_np, A_model_1_np, A_model_2_np, A_model_3_np, A_model_4_np, A_model_5_np = [], [], [], [], [], []
        for p in A_model_0.parameters():
            A_model_0_np.append(p.data.detach().cpu().numpy())
        A_model_0_np = np.asarray(A_model_0_np)
        for p in A_model_1.parameters():
            A_model_1_np.append(p.data.detach().cpu().numpy())
        A_model_1_np = np.asarray(A_model_1_np)
        for p in A_model_2.parameters():
            A_model_2_np.append(p.data.detach().cpu().numpy())
        A_model_2_np = np.asarray(A_model_2_np)
        for p in A_model_3.parameters():
            A_model_3_np.append(p.data.detach().cpu().numpy())
        A_model_3_np = np.asarray(A_model_3_np)
        for p in A_model_4.parameters():
            A_model_4_np.append(p.data.detach().cpu().numpy())
        A_model_4_np = np.asarray(A_model_4_np)
        for p in A_model_5.parameters():
            A_model_5_np.append(p.data.detach().cpu().numpy())
        A_model_5_np = np.asarray(A_model_5_np)
        # the histogram of the data
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        A_n_0, A_bins_0, _ = plt.hist(A_model_0_np, 100)
        plt.title('Histogram of Provider Model of Iteration 1 in Round 1')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/A_model_1_1.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        A_n_1, A_bins_1, _ = plt.hist(A_model_1_np, 100)
        plt.title('Histogram of Provider Model of Iteration 501 in Round 1')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/A_model_1_501.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        A_n_2, A_bins_2, _ = plt.hist(A_model_2_np, 100)
        plt.title('Histogram of Provider Model of Iteration 1501 in Round 1')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/A_model_1_1501.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        A_n_3, A_bins_3, _ = plt.hist(A_model_3_np, 100)
        plt.title('Histogram of Provider Model of Iteration 1 in Round 10')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/A_model_10_1.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        A_n_4, A_bins_4, _ = plt.hist(A_model_4_np, 100)
        plt.title('Histogram of Provider Model of Iteration 501 in Round 10')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/A_model_10_501.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        A_n_5, A_bins_5, _ = plt.hist(A_model_5_np, 100)
        plt.title('Histogram of Provider Model of Iteration 1501 in Round 10')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/A_model_10_1501.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        plt.hist(A_model_1_np - A_model_0_np, 100)
        plt.title('Histogram of Provider Difference between Iteration 501 & 1 in Round 1')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/A_model_1_501_d_1_1.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        plt.hist(A_model_2_np - A_model_1_np, 100)
        plt.title('Histogram of Provider Difference between Iteration 1501 & 501 in Round 1')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/A_model_1_1501_d_501_1.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        plt.hist(A_model_2_np - A_model_0_np, 100)
        plt.title('Histogram of Provider Difference between Iteration 1501 & 1 in Round 1')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/A_model_1_1501_d_1_1.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        plt.hist(A_model_4_np - A_model_3_np, 100)
        plt.title('Histogram of Provider Difference between Iteration 501 & 1 in Round 10')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/A_model_10_501_d_10_1.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        plt.hist(A_model_5_np - A_model_4_np, 100)
        plt.title('Histogram of Provider Difference between Iteration 1501 & 501 in Round 10')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/A_model_10_1501_d_10_501.pdf'))
        plt.xlabel('Weight Values')
        plt.ylabel('Counts')
        plt.grid(True)
        plt.hist(A_model_5_np - A_model_3_np, 100)
        plt.title('Histogram of Provider Difference between Iteration 1501 & 1 in Round 10')
        plt.draw()
        plt.savefig(os.path.join(exp_model_dir, 'hist_imgs/A_model_10_1501_d_10_1.pdf'))
        """


# Script entry point: dispatch to main() only when this file is executed
# directly (e.g. `python <this_file>.py ...`), not when it is imported.
_is_script_run = __name__ == '__main__'
if _is_script_run:
    main()
del _is_script_run
