Module run

#!/usr/bin/env python3

import os
import sys
import json
import socket
import math
import signal

# Import PyTorch root package
import torch
from torch.utils.collect_env import get_pretty_env_info

import numpy as np
import time
import copy
import threading
import pickle

from utils import comm_socket
from utils import gpu_utils
from utils import execution_context
from utils import thread_pool
from utils import wandb_wrapper

from opts import parse_args

from utils.logger import Logger
from data_preprocess.data_loader import load_data, get_test_batch_size
from utils.model_funcs import get_training_elements, evaluate_model, get_lr_scheduler, get_optimiser, \
    run_one_communication_round, local_training, initialise_model

from utils.checkpointing import save_checkpoint
from utils.checkpointing import defered_eval_and_save_checkpoint

from utils.utils import create_model_dir, create_metrics_dict
from utils.fl_funcs import get_sampled_clients

from models import RNN_MODELS
from models import mutils

from utils import algorithms

from utils.buffer import Buffer


def main(args, raw_cmdline, extra_):
    # Init the W&B project and store its URL in args for use in plots
    projectWB = None
    if len(args.wandb_key) > 0:
        projectWB = wandb_wrapper.initWandbProject(
            args.wandb_key,
            args.wandb_project_name,
            f"{args.algorithm}-rounds-{args.rounds}-{args.model}@{args.dataset}-runid-{args.run_id}",
            args
        )
        if projectWB is not None:
            args.wandb_url = projectWB.url

    # In case of DataParallel, the first GPU in the list is the target for .to() calls.
    # This is the device used by the master for FL algorithms.
    args.device = gpu_utils.get_target_device_str(args.gpu[0]) if type(args.gpu) == list else args.gpu

    # Instantiate execution context
    exec_ctx = execution_context.initExecutionContext()
    exec_ctx.extra_ = extra_

    # Set all seeds
    if args.deterministic:
        exec_ctx.random.seed(args.manual_init_seed)
        exec_ctx.np_random.seed(args.manual_init_seed)

    # Setup extra options for experiments
    for option in args.algorithm_options.split(","):
        kv = option.split(":")
        if len(kv) == 1:
            exec_ctx.experimental_options[kv[0]] = True
        else:
            exec_ctx.experimental_options[kv[0]] = kv[1]
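    # A hypothetical example: --algorithm-options "use_steps:10,debug" yields
    # experimental_options == {"use_steps": "10", "debug": True} (values stay strings)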

    # Load train and test sets
    trainset, testset = load_data(exec_ctx, args.data_path, args.dataset, args, load_trainset=True, download=True)

    test_batch_size = get_test_batch_size(args.dataset, args.batch_size)

    # Test set dataloader
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=test_batch_size,
                                             num_workers=args.num_workers_test,
                                             shuffle=False,
                                             pin_memory=False,
                                             drop_last=False
                                             )

    # Reset all seeds
    if args.deterministic:
        exec_ctx.random.seed(args.manual_init_seed)
        exec_ctx.np_random.seed(args.manual_init_seed)

    if not args.evaluate:  # Training mode
        trainloader = torch.utils.data.DataLoader(
            trainset,
            batch_size=args.batch_size,
            num_workers=args.num_workers_train,
            shuffle=False,
            pin_memory=False,
            drop_last=False)

        # Whether the master's target device is a GPU
        is_target_dev_gpu = gpu_utils.is_target_dev_gpu(args.device)

        # Initialize local training threads when the multithreaded implementation is used
        exec_ctx.local_training_threads.adjust_num_workers(w=args.threadpool_for_local_opt, own_cuda_stream=True,
                                                           device_list=args.gpu)
        for th in exec_ctx.local_training_threads.threads:
            th.trainset_copy = copy.deepcopy(trainset)
            th.train_loader_copy = torch.utils.data.DataLoader(th.trainset_copy,
                                                               batch_size=args.batch_size,
                                                               num_workers=args.num_workers_train,
                                                               shuffle=False,
                                                               pin_memory=False,
                                                               drop_last=False)
            if hasattr(th.trainset_copy, "load_data"):
                th.trainset_copy.load_data()

        # Initialize remote clients: one remote connection per remote compute device
        external_devices = [d.strip() for d in args.external_devices.split(",") if len(d.strip()) > 0]
        exec_ctx.remote_training_threads.adjust_num_workers(w=len(external_devices), own_cuda_stream=False,
                                                            device_list=["cpu"])
        for th_i in range(len(exec_ctx.remote_training_threads.threads)):
            th = exec_ctx.remote_training_threads.threads[th_i]
            dev = external_devices[th_i]
            th.external_dev_long = dev
            th.external_ip, th.external_port, th.external_dev, th.external_dev_number = dev.split(":")
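            # A hypothetical device spec "192.168.0.10:9090:cuda:0" splits into
            # ip="192.168.0.10", port="9090", dev="cuda" and device number "0"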

            try:
                th.external_socket = comm_socket.CommSocket()
                th.external_socket.sock.connect((th.external_ip, int(th.external_port)))  # connect with remote side
                th.external_socket.rawSendString("execute_work")  # initiate work execution COMMAND
                th.external_socket.rawSendString(th.external_dev_long)  # provide device specification
                th.external_socket.rawSend(pickle.dumps(
                    raw_cmdline))  # send the original command line (the worker will adapt it to its own devices)
                th.external_socket_online = True
            except socket.error:
                th.external_socket_online = False

        # Initialize saver threads for information serialization (0 is acceptable)
        exec_ctx.saver_thread.adjust_num_workers(w=args.save_async_threads, own_cuda_stream=True, device_list=args.gpu)

        # Initialize evaluation threads for deferred model evaluation (0 is acceptable)
        exec_ctx.eval_thread_pool.adjust_num_workers(w=args.eval_async_threads, own_cuda_stream=True,
                                                     device_list=args.gpu)

        for th in exec_ctx.eval_thread_pool.threads:
            th.testset_copy = copy.deepcopy(testset)
            th.testloader_copy = torch.utils.data.DataLoader(th.testset_copy,
                                                             batch_size=test_batch_size,
                                                             num_workers=args.num_workers_test,
                                                             shuffle=False,
                                                             pin_memory=False,
                                                             drop_last=False
                                                             )
            if hasattr(th.testset_copy, "load_data"):
                th.testset_copy.load_data()

        if args.worker_listen_mode <= 0:
            # Execution path for local simulation
            init_and_train_model(args, raw_cmdline, trainloader, testloader, exec_ctx)
        else:
            # Execution path for assisting another simulation as a remote worker
            init_and_help_with_compute(args, raw_cmdline, trainloader, testloader, exec_ctx)

    else:  # Evaluation mode
        logger = Logger.get(args.run_id)
        # TODO: figure out how to exploit DataParallel. Currently - parallelism across workers
        model, criterion, current_round = get_training_elements(args.model, args.dataset, testloader.dataset, args,
                                                                args.resume_from, args.load_best, args.device,
                                                                args.loss,
                                                                args.turn_off_batch_normalization_and_dropout)

        metrics = evaluate_model(model, testloader, criterion, args.device, current_round,
                                 print_freq=10, metric_to_optim=args.metric,
                                 is_rnn=args.model in RNN_MODELS)

        metrics_dict = create_metrics_dict(metrics)
        logger.info(f'Validation metrics: {metrics_dict}')
    wandb_wrapper.finishProject(projectWB)


def init_and_help_with_compute(args, raw_cmdline, trainloader, testloader, exec_ctx):
    logger = Logger.get(args.run_id)

    model_dir = create_model_dir(args)
    # don't train if setup already exists
    if os.path.isdir(model_dir):
        logger.info(f"{model_dir} already exists.")
        logger.info("Skipping this setup.")
        return
    # create model directory
    os.makedirs(model_dir, exist_ok=True)
    # save used args as json to experiment directory
    with open(os.path.join(create_model_dir(args), 'args.json'), 'w') as f:
        json.dump(vars(args), f, indent=4)

    is_rnn = args.model in RNN_MODELS
    # TODO: figure out how to exploit DataParallel. Currently - parallelism across workers
    model, criterion, current_round = get_training_elements(args.model, args.dataset, trainloader.dataset, args,
                                                            args.resume_from, args.load_best, args.device, args.loss,
                                                            args.turn_off_batch_normalization_and_dropout)
    # ==================================================================================================================
    # Reset execution seeds for reproducible runtime behaviour
    if args.deterministic:
        exec_ctx.random.seed(args.manual_runtime_seed)
        exec_ctx.np_random.seed(args.manual_runtime_seed)
    # ==================================================================================================================
    mutils.print_models_info(model, args)
    logger.info(f"Number of parameters in the model: {mutils.number_of_params(model):,d}\n")
    gpu_utils.print_info_about_used_gpu(args.device, args.run_id)
    # ==================================================================================================================
    # Initialize server state
    # ==================================================================================================================
    trainloader.dataset.set_client(None)
    algorithms.evaluateGradient(None, model, trainloader, criterion, is_rnn, update_statistics=False,
                                evaluate_function=False, device=args.device, args=args)
    gradient_at_start = mutils.get_gradient(model)
    for p in model.parameters():
        p.grad = None
    H = algorithms.initializeServerState(args, model, trainloader.dataset.num_clients, gradient_at_start, exec_ctx)
    logger.info(f'D: {H["D"]} / D_include_frozen : {H["D_include_frozen"]}')

    # Append all launch arguments (setup and default)
    H["args"] = args
    H["raw_cmdline"] = raw_cmdline
    H["execution_context"] = exec_ctx
    H["total_clients"] = trainloader.dataset.num_clients
    H["comment"] = args.comment
    H["group-name"] = args.group_name
    # ==================================================================================================================
    local_optimiser = get_optimiser(model.parameters(), args.local_optimiser, args.local_lr, args.local_momentum,
                                    args.local_weight_decay)
    socket = exec_ctx.extra_
    state_dicts_thread_safe = Buffer()

    while True:
        # Main loop for obtaining commands from master
        cmd = socket.rawRecvString()
        if cmd != "finish_work" and cmd != "non_local_training":
            print("Unknown command: ", cmd)
            break

        if cmd == "finish_work":
            # Work termination
            break

        if cmd == "non_local_training":
            # Remote training. Param# 1- all params with local training
            msg = socket.rawRecv()
            serialized_args = pickle.loads(msg)

            (client_state, client_id, msg, model_dict_original, optimiser_dict_original,
             model_, train_loader_, criterion, local_optimiser_, device_, round, run_local_iters,
             number_of_local_steps, is_rnn, print_freq) = serialized_args
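            # Fields with a trailing underscore (model_, train_loader_, ...) arrive from the
            # master but are replaced below by this worker's own local equivalents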

            # 1. Update to client specific parameters - compute device, execution context
            # 2. Move all tensors into device
            device = args.device
            client_state['device'] = args.device

            client_state['H']["execution_context"] = H["execution_context"]

            for k, v in client_state['H'].items():
                if torch.is_tensor(v):
                    client_state['H'][k] = v.to(device=device)

            model_dict_original = model_dict_original.to(device=device)

            res = local_training(None, client_state, client_id, msg, model_dict_original, optimiser_dict_original,
                                 model, trainloader, criterion, local_optimiser, device, round, run_local_iters,
                                 number_of_local_steps, is_rnn, print_freq, state_dicts_thread_safe,
                                 client_state['H']["run_id"])
            # ==========================================================================================================
            # Training is finished - response
            socket.rawSendString("result_of_local_training")
            msg = state_dicts_thread_safe.popFront()
            # Temporarily remove the reference to the server state so it is not serialized
            h_ctx = msg['client_state']['H']
            del msg['client_state']['H']
            for k, v in msg.items():
                if torch.is_tensor(v):
                    msg[k] = v.cpu()
            socket.rawSend(pickle.dumps(msg))
            # Restore the reference to the server state
            msg['client_state']['H'] = h_ctx
            # ==========================================================================================================


def makeBackupOfServerState(H, round):
    args = H["args"]
    job_id = args.run_id
    fname = args.out

    if fname == "":
        fname = os.path.join(args.checkpoint_dir, f"{args.run_id}_{args.algorithm}_{args.global_lr}_backup.bin")
    else:
        fname = fname.replace(".bin", f"{args.run_id}_{args.algorithm}_{args.global_lr}_backup.bin")

    execution_context = H["execution_context"]
    H["execution_context"] = None
    exec_ctx_in_arg = ('exec_ctx' in H["args"])
    exec_ctx = None
    if exec_ctx_in_arg:
        exec_ctx = H["args"]['exec_ctx']
        H["args"]['exec_ctx'] = None

    with open(fname, "wb") as f:
        pickle.dump([(job_id, H)], f)
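    # A backup written this way can be inspected later with, e.g.:
    #   with open(fname, "rb") as f:
    #       (job_id, H) = pickle.load(f)[0]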

    H["execution_context"] = execution_context
    if exec_ctx_in_arg:
        H["args"]['exec_ctx'] = exec_ctx


def init_and_train_model(args, raw_cmdline, trainloader, testloader, exec_ctx):
    logger = Logger.get(args.run_id)
    model_dir = create_model_dir(args)
    # don't train if setup already exists
    if os.path.isdir(model_dir):
        logger.info(f"{model_dir} already exists.")
        logger.info("Skipping this setup.")
        return
    # create model directory
    os.makedirs(model_dir, exist_ok=True)
    # save used args as json to experiment directory
    with open(os.path.join(create_model_dir(args), 'args.json'), 'w') as f:
        json.dump(vars(args), f, indent=4)

    is_rnn = args.model in RNN_MODELS
    # TODO: figure out how to exploit DataParallel. Currently - parallelism across workers
    model, criterion, current_round = get_training_elements(args.model, args.dataset, trainloader.dataset, args,
                                                            args.resume_from, args.load_best, args.device, args.loss,
                                                            args.turn_off_batch_normalization_and_dropout)
    # ==================================================================================================================
    # Reset execution seeds for reproducible runtime behaviour
    if args.deterministic:
        exec_ctx.random.seed(args.manual_runtime_seed)
        exec_ctx.np_random.seed(args.manual_runtime_seed)
    # ==================================================================================================================

    local_optimiser = get_optimiser(model.parameters(), args.local_optimiser, args.local_lr,
                                    args.local_momentum, args.local_weight_decay)

    local_scheduler = get_lr_scheduler(local_optimiser, args.rounds, args.local_lr_type)

    global_optimiser = get_optimiser(model.parameters(), args.global_optimiser, args.global_lr,
                                     args.global_momentum, args.global_weight_decay)

    global_scheduler = get_lr_scheduler(global_optimiser, args.rounds, args.global_lr_type)

    metric_to_optim = args.metric
    train_time_meter = 0

    best_metric = -np.inf
    eval_metrics = {}

    mutils.print_models_info(model, args)
    logger.info(f"Number of parameters in the model: {mutils.number_of_params(model):,d}\n")

    gpu_utils.print_info_about_used_gpu(args.device, args.run_id)

    sampled_clients = get_sampled_clients(trainloader.dataset.num_clients, args, exec_ctx)

    # ============= Initialize worker threads===========================================================================
    for th in exec_ctx.local_training_threads.threads:
        th.model_copy = copy.deepcopy(model)
        # th.model_copy, _ = initialise_model(args.model, args.dataset, trainloader.dataset, args, args.resume_from,
        #                                     args.load_best)
        # th.model_copy.load_state_dict(copy.deepcopy(model.state_dict()))
        th.model_copy = th.model_copy.to(th.device)

        th.local_optimiser_copy = get_optimiser(th.model_copy.parameters(), args.local_optimiser, args.local_lr,
                                                args.local_momentum, args.local_weight_decay)
        th.local_scheduler = get_lr_scheduler(th.local_optimiser_copy, args.rounds, args.local_lr_type)

    for th in exec_ctx.eval_thread_pool.threads:
        th.model_copy = copy.deepcopy(model)
        # th.model_copy, _ = initialise_model(args.model, args.dataset, trainloader.dataset, args, args.resume_from,
        #                                     args.load_best)
        # th.model_copy.load_state_dict(copy.deepcopy(model.state_dict()))
        th.model_copy = th.model_copy.to(th.device)

    for th in exec_ctx.saver_thread.threads:
        th.model_copy = copy.deepcopy(model)
        # th.model_copy, _ = initialise_model(args.model, args.dataset, trainloader.dataset, args, args.resume_from,
        #                                     args.load_best)
        # th.model_copy.load_state_dict(copy.deepcopy(model.state_dict()))
        th.model_copy = th.model_copy.to(th.device)

    exec_ctx.saver_thread.best_metric = -np.inf
    exec_ctx.saver_thread.eval_metrics = {}
    exec_ctx.saver_thread.best_metric_lock = threading.Lock()
    # ==================================================================================================================
    optimiser_dict_original = copy.deepcopy(local_optimiser.state_dict())
    # ==================================================================================================================

    logger.info(f"Use separate {exec_ctx.local_training_threads.workers()} worker threads for make local optimization")

    # Initialize server state

    if args.initialize_shifts_policy == "full_gradient_at_start":
        trainloader.dataset.set_client(None)
        algorithms.evaluateGradient(None, model, trainloader, criterion, is_rnn, update_statistics=False,
                                    evaluate_function=False, device=args.device, args=args)
        gradient_at_start = mutils.get_gradient(model)
        for p in model.parameters():
            p.grad = None
    else:
        D = mutils.number_of_params(model)
        gradient_at_start = torch.zeros(D).to(args.device)

    H = algorithms.initializeServerState(args, model, trainloader.dataset.num_clients, gradient_at_start, exec_ctx)
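    # H is the server state: a plain dict carrying the model dimension ("D"), the start
    # iterate ("x0"), the per-round "history" and the bookkeeping fields appended below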
    logger.info(f'D: {H["D"]} / D_include_frozen : {H["D_include_frozen"]}')

    # ==================================================================================================================
    # Append all launch arguments (setup and default)
    H["args"] = args
    H["raw_cmdline"] = raw_cmdline
    H["execution_context"] = exec_ctx
    H["total_clients"] = trainloader.dataset.num_clients
    H["comment"] = args.comment
    H["group-name"] = args.group_name
    # H["sampled_clients"] = sampled_clients

    # Optionally pre-scale x0
    if algorithms.has_experiment_option(H, "x0_norm"):
        H["x0"] = (H["x0"] / H["x0"].norm()) * algorithms.get_experiment_option_f(H, "x0_norm")

    # For tasks where a good estimate of the smoothness constant L is available
    if "L" in dir(trainloader.dataset):
        H["L_compute"] = trainloader.dataset.L
    if "Li_all_clients" in dir(trainloader.dataset):
        H["Li_all_clients"] = trainloader.dataset.Li_all_clients
        H["Li_max"] = max(H["Li_all_clients"])

    # Initialize metrics
    H['best_metric'] = best_metric
    H['eval_metrics'] = copy.deepcopy(eval_metrics)

    # Initialize starting point of optimization
    mutils.set_params(model, H['x0'])

    is_target_gpu = gpu_utils.is_target_dev_gpu(args.device)

    if execution_context.simulation_start_fn:
        execution_context.simulation_start_fn(H)

    extra_opts = [a.strip() for a in args.extra_track.split(",")]

    fed_dataset_all_clients = np.arange(H['total_clients'])

    # Main Algorithm Optimization loop
    for i in range(args.rounds):
        # Check whether the simulation has been asked to stop early
        if execution_context.is_simulation_need_earlystop_fn is not None:
            if execution_context.is_simulation_need_earlystop_fn(H):
                break

        # Check that the previous round is numerically healthy
        if i > 0:
            prev_history_stats = H['history'][i - 1]
            forceTermination = False

            for elem in ["x_before_round", "grad_sgd_server_l2"]:
                if elem in prev_history_stats:
                    if math.isnan(prev_history_stats[elem]) or math.isinf(prev_history_stats[elem]):
                        logger.error(f"Force early stop due to numerical problems with {elem} at round {i}.")
                        forceTermination = True
                        break

            if forceTermination:
                break

        # Update information about current round
        H['current_round'] = i

        start = time.time()
        # Generate client state
        # TODO: metrics meter to be used for Tensorboard/WandB
        metrics_meter, fed_dataset_clients = run_one_communication_round(H, model, trainloader, criterion,
                                                                         local_optimiser,
                                                                         optimiser_dict_original, global_optimiser,
                                                                         args.device,
                                                                         current_round, args.run_local_steps,
                                                                         args.number_of_local_iters,
                                                                         is_rnn, sampled_clients)
        train_time = time.time() - start
        if H['algorithm'] == 'gradskip':
            # 'gradskip' reports its own simulated elapsed time instead of wall-clock time
            train_time = H['time']
        train_time_meter += train_time  # Track timings for an across-epochs average
        logger.debug(f'Epoch train time: {train_time}')
        # H['history'][current_round]["xi_after"] = mutils.get_params(model)
        H['history'][current_round]["train_time"] = train_time
        H['history'][current_round]["time"] = time.time() - exec_ctx.context_init_time
        if H['algorithm'] == 'gradskip':
            if current_round == 0:
                H['history'][current_round]["time"] = train_time
            else:
                H['history'][current_round]["time"] = H['history'][current_round - 1]["time"] + train_time
        H['last_round_elapsed_sec'] = train_time

        if (i % args.eval_every == 0 or i == (args.rounds - 1)) and metric_to_optim != "none":
            # Save results obtained so far into backup file
            # ==========================================================================================================
            # Serialize server state into the filesystem
            makeBackupOfServerState(H, i)
            # ==========================================================================================================
            # Evaluate model
            if args.eval_async_threads > 0:
                defered_eval_and_save_checkpoint(model, criterion, args, current_round, is_rnn=is_rnn,
                                                 metric_to_optim=metric_to_optim, exec_ctx=exec_ctx)

                # Update recent information about eval metrics
                exec_ctx.saver_thread.best_metric_lock.acquire()
                H['best_metric'] = max(H['best_metric'], exec_ctx.saver_thread.best_metric)
                H['eval_metrics'] = copy.deepcopy(exec_ctx.saver_thread.eval_metrics)
                exec_ctx.saver_thread.best_metric_lock.release()

            else:
                metrics = evaluate_model(model, testloader, criterion, args.device, current_round, print_freq=10,
                                         is_rnn=is_rnn, metric_to_optim=metric_to_optim)

                avg_metric = metrics[metric_to_optim].get_avg()

                cur_metrics = {"loss": metrics["loss"].get_avg(),
                               "top_1_acc": metrics["top_1_acc"].get_avg(),
                               "top_5_acc": metrics["top_5_acc"].get_avg(),
                               "neq_perplexity": metrics["neq_perplexity"].get_avg()
                               }

                eval_metrics.update({metrics['round']: cur_metrics})

                # Save model checkpoint
                model_filename = '{model}_{run_id}_checkpoint_{round:0>2d}.pth.tar'.format(model=args.model,
                                                                                           run_id=args.run_id,
                                                                                           round=current_round)

                is_best = avg_metric > best_metric
                save_checkpoint(model, model_filename, is_best=is_best, args=args, metrics=metrics,
                                metric_to_optim=metric_to_optim)

                if is_best:
                    best_metric = avg_metric

                # Update recent information about eval metrics
                H['best_metric'] = max(H['best_metric'], best_metric)
                H['eval_metrics'] = copy.deepcopy(eval_metrics)

                if np.isnan(metrics['loss'].get_avg()):
                    logger.critical('NaN loss detected, aborting training procedure.')
                #   return

                logger.info(f'Current lrs global:{global_scheduler.get_last_lr()}')
        # ==============================================================================================================
        # Add number of clients in that round
        H['history'][current_round]["number_of_client_in_round"] = len(fed_dataset_clients)

        if args.log_gpu_usage:
            H['history'][current_round]["memory_gpu_used"] = 0
            for dev in args.gpu:
                if gpu_utils.is_target_dev_gpu(dev):
                    memory_gpu_used = torch.cuda.memory_stats(args.device)['reserved_bytes.all.current']
                    logger.info(
                        f"GPU: Before round {i} we have used {memory_gpu_used / (1024.0 ** 2):.2f} MB from device {dev}")
                    H['history'][current_round]["memory_gpu_used"] += memory_gpu_used / (1024.0 ** 2)

        if current_round % args.eval_every == 0 or i == (args.rounds - 1):
            if "full_gradient_norm_train" in extra_opts and "full_objective_value_train" in extra_opts:
                fed_dataset = trainloader.dataset

                gradient_avg = mutils.get_zero_gradient_compatible_with_model(model)
                fvalue_avg = 0.0

                for c in fed_dataset_all_clients:
                    fed_dataset.set_client(c)
                    fvalue = algorithms.evaluateGradient(None, model, trainloader, criterion, is_rnn,
                                                         update_statistics=False, evaluate_function=True,
                                                         device=args.device, args=args)
                    g = mutils.get_gradient(model)

                    gradient_avg = (gradient_avg * c + g) / (c + 1)
                    fvalue_avg = (fvalue_avg * c + fvalue) / (c + 1)

                H['history'][current_round]["full_gradient_norm_train"] = mutils.l2_norm_of_vec(gradient_avg)
                H['history'][current_round]["full_objective_value_train"] = fvalue_avg

            if "full_gradient_norm_train" in extra_opts and "full_objective_value_train" not in extra_opts:
                fed_dataset = trainloader.dataset
                gradient_avg = mutils.get_zero_gradient_compatible_with_model(model)

                for c in fed_dataset_all_clients:
                    fed_dataset.set_client(c)
                    algorithms.evaluateGradient(None, model, trainloader, criterion, is_rnn, update_statistics=False,
                                                evaluate_function=False, device=args.device, args=args)
                    g = mutils.get_gradient(model)
                    gradient_avg = (gradient_avg * c + g) / (c + 1)

                H['history'][current_round]["full_gradient_norm_train"] = mutils.l2_norm_of_vec(gradient_avg)

            if "full_objective_value_train" in extra_opts and "full_gradient_norm_train" not in extra_opts:
                fed_dataset = trainloader.dataset
                fvalue_avg = 0.0
                for c in fed_dataset_all_clients:
                    fed_dataset.set_client(c)
                    fvalue = algorithms.evaluateFunction(None, model, trainloader, criterion, is_rnn,
                                                         update_statistics=False, device=args.device, args=args)
                    fvalue_avg = (fvalue_avg * c + fvalue) / (c + 1)
                H['history'][current_round]["full_objective_value_train"] = fvalue_avg
            # ==========================================================================================================
            if "full_gradient_norm_val" in extra_opts and "full_objective_value_val" in extra_opts:
                fvalue = algorithms.evaluateGradient(None, model, testloader, criterion, is_rnn,
                                                     update_statistics=False, evaluate_function=True,
                                                     device=args.device, args=args)
                H['history'][current_round]["full_gradient_norm_val"] = mutils.l2_norm_of_gradient_m(model)
                H['history'][current_round]["full_objective_value_val"] = fvalue

            if "full_gradient_norm_val" in extra_opts and "full_objective_value_val" not in extra_opts:
                algorithms.evaluateGradient(None, model, testloader, criterion, is_rnn, update_statistics=False,
                                            evaluate_function=False, device=args.device, args=args)
                H['history'][current_round]["full_gradient_norm_val"] = mutils.l2_norm_of_gradient_m(model)

            if "full_objective_value_val" in extra_opts and "full_gradient_norm_val" not in extra_opts:
                fvalue = algorithms.evaluateFunction(None, model, testloader, criterion, is_rnn,
                                                     update_statistics=False, device=args.device, args=args)
                H['history'][current_round]["full_objective_value_val"] = fvalue
            # ==========================================================================================================
            # Interpolate
            # ==========================================================================================================
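            # Full-batch metrics are computed only every eval_every rounds; the rounds in
            # between get linearly interpolated values so that every round has a data point.
            # E.g. with eval_every=4 and evaluations at rounds 4 and 8, rounds 5..7 receive
            # lerp(v_4, v_8, s/4) for s = 1, 2, 3.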
            if current_round > 0:
                look_behind = args.eval_every

                # ======================================================================================================
                if i % args.eval_every == 0:
                    look_behind = args.eval_every  # The previous evaluation happened exactly eval_every rounds ago
                elif i == args.rounds - 1:
                    # We are in the last round R-1; the previous evaluation
                    # happened i % eval_every rounds before this one
                    look_behind = i % args.eval_every
                else:
                    assert False, "Impossible case"
                # ======================================================================================================

                for s in range(1, look_behind):
                    alpha = s / float(look_behind)
                    prev_eval_round = current_round - look_behind
                    inter_round = prev_eval_round + s

                    # Linearly interpolate between a (alpha=0.0) and b(alpha=1.0)
                    def lerp(a, b, alpha):
                        return (1.0 - alpha) * a + alpha * b

                    if "full_gradient_norm_train" in extra_opts:
                        H['history'][inter_round]["full_gradient_norm_train"] = lerp(
                            H['history'][prev_eval_round]["full_gradient_norm_train"],
                            H['history'][current_round]["full_gradient_norm_train"], alpha)
                    if "full_objective_value_train" in extra_opts:
                        H['history'][inter_round]["full_objective_value_train"] = lerp(
                            H['history'][prev_eval_round]["full_objective_value_train"],
                            H['history'][current_round]["full_objective_value_train"], alpha)
                    if "full_gradient_norm_val" in extra_opts:
                        H['history'][inter_round]["full_gradient_norm_val"] = lerp(
                            H['history'][prev_eval_round]["full_gradient_norm_val"],
                            H['history'][current_round]["full_gradient_norm_val"], alpha)
                    if "full_objective_value_val" in extra_opts:
                        H['history'][inter_round]["full_objective_value_val"] = lerp(
                            H['history'][prev_eval_round]["full_objective_value_val"],
                            H['history'][current_round]["full_objective_value_val"], alpha)

        if current_round % args.eval_every == 0 or i == (args.rounds - 1):
            if execution_context.simulation_progress_steps_fn is not None:
                execution_context.simulation_progress_steps_fn(i / float(args.rounds) * 0.75, H)
            wandb_wrapper.logStatistics(H, current_round)

        # ==============================================================================================================
        # Save parameters of the last model
        # if i == args.rounds - 1:
        #    xfinal = mutils.get_params(model)
        #    H['xfinal'] = xfinal
        # ==============================================================================================================
        # Update schedulers
        if exec_ctx.local_training_threads.workers() > 0:
            # In this mode local optimizers for worker threads are used
            for th in exec_ctx.local_training_threads.threads:
                th.local_scheduler.step()
        else:
            # In this mode local optimizer is used
            local_scheduler.step()

        global_scheduler.step()

        # Increment round counter
        current_round += 1

        # Empty caches. Force cache cleanup after the round (helps with memory fragmentation issues)
        if args.per_round_clean_torch_cache:
            execution_context.torch_global_lock.acquire()
            torch.cuda.empty_cache()
            execution_context.torch_global_lock.release()

    xfinal = mutils.get_params(model)
    H['xfinal'] = xfinal

    logger.debug(f'Average epoch train time: {train_time_meter / args.rounds}')
    # ==================================================================================================================
    logger.info("Wait for threads from a training threadpool")
    exec_ctx.local_training_threads.stop()
    if execution_context.simulation_progress_steps_fn is not None:
        execution_context.simulation_progress_steps_fn(0.8, H)

    logger.info("Wait for threads from a evaluate threadpool")
    exec_ctx.eval_thread_pool.stop()
    if execution_context.simulation_progress_steps_fn is not None:
        execution_context.simulation_progress_steps_fn(0.95, H)

    logger.info("Wait for threads from a serialization threadpool")
    exec_ctx.saver_thread.stop()

    # Not strictly needed, but done to rule out any synchronization bugs inside PyTorch

    logger.info("Synchronize all GPU streams")
    if is_target_gpu:
        torch.cuda.synchronize(args.device)

    # Update metrics based on eval and save results
    H['best_metric'] = max(H['best_metric'], exec_ctx.saver_thread.best_metric)
    H['best_metric'] = max(H['best_metric'], best_metric)

    if len(eval_metrics) > 0:
        H['eval_metrics'] = copy.deepcopy(eval_metrics)
    else:
        H['eval_metrics'] = copy.deepcopy(exec_ctx.saver_thread.eval_metrics)

    if execution_context.simulation_progress_steps_fn is not None:
        execution_context.simulation_progress_steps_fn(1.0, H)
    # ==================================================================================================================
    # Final prune
    HKeys = list(H.keys())

    # Remove from the final solution any tensor field larger than this threshold (in MB)
    tensor_prune_mb_threshold = 0.0

    for field in HKeys:
        if torch.is_tensor(H[field]):
            size_in_mbytes = H[field].numel() * H[field].element_size() / (1024.0 ** 2)
            if size_in_mbytes > tensor_prune_mb_threshold:
                del H[field]
                continue
            else:
                # Move any retained tensor to CPU
                H[field] = H[field].cpu()

        if type(H[field]) == list:
            sz = len(H[field])
            i = 0

            while i < sz:
                if torch.is_tensor(H[field][i]):
                    size_in_mbytes = H[field][i].numel() * H[field][i].element_size() / (1024.0 ** 2)
                    if size_in_mbytes >= tensor_prune_mb_threshold:
                        del H[field][i]
                        sz = len(H[field])
                    else:
                        # Move any retained tensor to CPU
                        H[field][i] = H[field][i].cpu()
                        i += 1
                else:
                    i += 1

    # Remove local optimiser state, client_compressors, reference to server state
    not_to_remove_from_client_state = ["error"]

    for round, history_state in H['history'].items():
        clients_history = history_state['client_states']
        for client_id, client_state in clients_history.items():
            client_state = client_state['client_state']

            if 'H' in client_state:
                del client_state['H']
            if 'optimiser' in client_state:
                del client_state['optimiser']
            if 'buffers' in client_state:
                del client_state['buffers']
            if 'client_compressor' in client_state:
                del client_state['client_compressor']

            client_state_keys = list(client_state.keys())
            for field in client_state_keys:

                if field in not_to_remove_from_client_state:
                    continue

                if torch.is_tensor(client_state[field]):
                    size_in_mbytes = client_state[field].numel() * client_state[field].element_size() / (1024.0 ** 2)
                    if size_in_mbytes > tensor_prune_mb_threshold:
                        del client_state[field]
                        continue
                    else:
                        # Move any retained tensor to CPU
                        client_state[field] = client_state[field].cpu()
                        continue

                if type(client_state[field]) == list:
                    sz = len(client_state[field])
                    i = 0
                    while i < sz:
                        if torch.is_tensor(client_state[field][i]):
                            size_in_mbytes = client_state[field][i].numel() * client_state[field][i].element_size() / (
                                    1024.0 ** 2)
                            if size_in_mbytes >= tensor_prune_mb_threshold:
                                del client_state[field][i]
                                sz = len(client_state[field])
                            else:
                                client_state[field][i] = client_state[field][i].cpu()
                                i += 1
                        else:
                            i += 1

    # Remove references to the execution context
    del H["execution_context"]
    if 'exec_ctx' in H["args"]:
        del H["args"]['exec_ctx']

    # ==================================================================================================================
    # Cleanup execution context resources
    # ==================================================================================================================
    execution_context.resetExecutionContext(exec_ctx)
    # ==================================================================================================================
    if execution_context.simulation_finish_fn is not None:
        execution_context.simulation_finish_fn(H)


def runSimulation(cmdline, extra_=None):
    # ==================================================================================================================
    # ENTRY POINT FOR LAUNCHING A SIMULATION FROM THE GUI
    # ==================================================================================================================
    global CUDA_SUPPORT

    args = parse_args(cmdline)
    if args.hostname == "":
        args.hostname = socket.gethostname()

    Logger.setup_logging(args.loglevel, args.logfilter, logfile=args.logfile)
    logger = Logger.get(args.run_id)

    logger.info(f"CLI args original: {cmdline}")

    if torch.cuda.device_count():
        CUDA_SUPPORT = True
        # torch.backends.cudnn.benchmark = True
    else:
        logger.warning('CUDA unsupported!!')
        CUDA_SUPPORT = False

    if not CUDA_SUPPORT:
        args.gpu = "cpu"

    if args.deterministic:
        import torch.backends.cudnn as cudnn
        import os
        import random

        # TODO: These settings are not thread-safe when the CUDA backend is used from several threads.
        # The project uses the execution context's random generators for random and numpy in a thread-safe way.
        if CUDA_SUPPORT:
            cudnn.deterministic = args.deterministic
            cudnn.benchmark = not args.deterministic
            torch.cuda.manual_seed(args.manual_init_seed)
            torch.cuda.manual_seed_all(args.manual_init_seed)

            # Enable or disable TF32 Tensor Cores as requested
            torch.backends.cudnn.allow_tf32 = args.allow_use_nv_tensorcores
            torch.backends.cuda.matmul.allow_tf32 = args.allow_use_nv_tensorcores

        torch.manual_seed(args.manual_init_seed)
        random.seed(args.manual_init_seed)
        np.random.seed(args.manual_init_seed)

        os.environ['PYTHONHASHSEED'] = str(args.manual_init_seed)
        torch.backends.cudnn.deterministic = True

    main(args, cmdline, extra_)


# ======================================================================================================================
g_Terminate = False


# ======================================================================================================================
def terminationWithSignal(*args):
    global g_Terminate
    g_Terminate = True
    print("Request for terminate process has been obtained. Save results and terminate.")


def isSimulationNeedEarlyStopCmdLine(H):
    global g_Terminate
    return g_Terminate


# ======================================================================================================================

def saveResult(H):
    """Serialize server state into filesystem"""
    args = H["args"]

    job_id = args.run_id
    fname = args.out
    if fname == "":
        fname = os.path.join(args.checkpoint_dir, f"{args.algorithm}_simulation_history.bin")

    # Formally, save a list of tuples (job_id, server_state)
    with open(fname, "wb") as f:
        pickle.dump([(job_id, H)], f)
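    # The same [(job_id, server_state)] layout is used by the backups written in
    # makeBackupOfServerState, so both kinds of files can be read by the same tooling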

    print(f"Final server state for {job_id} is serialized into: '{fname}'")
    print(f"Program which uses '{args.algorithm}' algorithm for {args.rounds} rounds was finished")


class ClientThread(threading.Thread):
    def __init__(self, clientSocket, listenPort):
        threading.Thread.__init__(self)
        self.clientSocket = clientSocket  # OS-level socket used for communication
        self.comSocket = comm_socket.CommSocket(clientSocket)  # Lightweight wrapper on top of the socket
        self.listenPort = listenPort  # Listening port

    def run(self):
        cmd = self.comSocket.rawRecvString()  # Obtain command

        t = time.localtime()
        cur_time = time.strftime("%H:%M:%S", t)
        print(f"-- {cur_time}: The received command: {cmd}")

        if cmd == "list_of_gpus":  # List of resources in the system
            gpus = len(gpu_utils.get_available_gpus())
            self.comSocket.rawSendString(f"{gpus}")
        elif cmd == "execute_work":  # Execute work
            devConfig = self.comSocket.rawRecvString()  # 1st param: device configuration
            cmdLine = self.comSocket.rawRecv()  # 2nd param: command line
            cmdLine = pickle.loads(cmdLine)  # unpack command line

            ip, port, dev_name, dev_number = devConfig.split(":")

            # Remove unnecessary flags for worker
            i = len(cmdLine) - 1
            while i >= 0:
                if cmdLine[i] == "--metric":
                    cmdLine[i + 1] = "none"

                if cmdLine[i] == "--gpu":
                    cmdLine[i + 1] = dev_number  # update device number

                if cmdLine[i] in ['--wandb-key', '--eval-async-threads', '--save-async-threads',
                                  '--threadpool-for-local-opt', '--external-devices', '--evaluate']:
                    del cmdLine[i]  # remove flag
                    del cmdLine[i]  # remove value for flag (originally at index i+1)
                i -= 1

            # Tell the worker that it runs in listen mode on this port
            cmdLine.append("--worker-listen-mode")
            cmdLine.append(str(self.listenPort))
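            # A hypothetical master command line ["--gpu", "1", "--wandb-key", "XYZ"] thus
            # becomes ["--gpu", "0", "--worker-listen-mode", "3212"] on a worker whose device
            # spec ends in ":0" and which listens on port 3212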

            runSimulation(cmdLine, extra_=self.comSocket)
        else:
            print(f"The received command {cmd} is not valid")


def executeRemoteWorkerSupportMode(listenPort):
    """
    Listen on the specified port and wait for incoming connections. The maximum number of pending connections is 5.
    """
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    serversocket.bind(('0.0.0.0', listenPort))
    MAX_CONNECTIONS = 5
    serversocket.listen(MAX_CONNECTIONS)

    while True:
        t = time.localtime()
        cur_time = time.strftime("%H:%M:%S", t)
        print(f"-- {cur_time}: Waiting for commands from the master by worker {socket.gethostname()}:{listenPort}")
        # Accept connections from outside; each is handled by a ClientThread
        # (run() is invoked directly, so connections are processed one at a time)
        (clientsocket, address) = serversocket.accept()
        ct = ClientThread(clientsocket, listenPort)
        ct.run()


if __name__ == "__main__":
    # ==================================================================================================================
    # ENTRY POINT FOR CLI
    # ==================================================================================================================
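    # Typical invocations (hypothetical flag values):
    #   python run.py --rounds 100 --algorithm fedavg --model resnet18 --dataset cifar10
    #   python run.py --worker-listen-mode 3212      # run as a remote worker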
    # Worker listener mode
    for i in range(len(sys.argv)):
        if sys.argv[i] == "--worker-listen-mode":
            print(get_pretty_env_info())
            print("")
            executeRemoteWorkerSupportMode(int(sys.argv[i + 1]))
            sys.exit(0)
    # ==================================================================================================================
    # Usual mode

    # Set up signal handlers for CTRL+C (SIGINT) and SIGTERM
    signal.signal(signal.SIGINT, terminationWithSignal)
    signal.signal(signal.SIGTERM, terminationWithSignal)

    execution_context.is_simulation_need_earlystop_fn = isSimulationNeedEarlyStopCmdLine
    execution_context.simulation_finish_fn = saveResult
    runSimulation(sys.argv[1:])
    sys.exit(0)
    # ==================================================================================================================

Functions

def executeRemoteWorkerSupportMode(listenPort)

Listen on the specified port and wait for incoming connections. The maximum number of pending connections is 5.
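A minimal interaction sketch from the master's side, assuming a worker was started with
--worker-listen-mode 3212 on a hypothetical host gpuhost1; only calls that already appear
in this module are used:

from utils import comm_socket

s = comm_socket.CommSocket()
s.sock.connect(("gpuhost1", 3212))
s.rawSendString("list_of_gpus")   # ask the worker how many GPUs it exposes
print(s.rawRecvString())          # ClientThread replies with the GPU count as a string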

def init_and_help_with_compute(args, raw_cmdline, trainloader, testloader, exec_ctx)
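Run this process as a remote compute helper: initialize the model and server state, then
loop on the socket stored in exec_ctx.extra_, executing local_training for every
"non_local_training" command from the master and sending the resulting state back, until
"finish_work" arrives.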
def init_and_train_model(args, raw_cmdline, trainloader, testloader, exec_ctx)
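Run the full simulation locally: initialize the model, optimisers, schedulers, worker
thread pools and server state H, execute the main optimization loop over args.rounds
communication rounds with periodic evaluation and checkpointing, then prune large tensors
from H and pass it to simulation_finish_fn.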
Expand source code
def init_and_train_model(args, raw_cmdline, trainloader, testloader, exec_ctx):
    logger = Logger.get(args.run_id)
    model_dir = create_model_dir(args)
    # don't train if setup already exists
    if os.path.isdir(model_dir):
        logger.info(f"{model_dir} already exists.")
        logger.info("Skipping this setup.")
        return
    # create model directory
    os.makedirs(model_dir, exist_ok=True)
    # save used args as json to experiment directory
    with open(os.path.join(model_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, indent=4)

    is_rnn = args.model in RNN_MODELS
    # TODO: figure out how to exploit DataParallel. Currently - parallelism across workers
    model, criterion, current_round = get_training_elements(args.model, args.dataset, trainloader.dataset, args,
                                                            args.resume_from, args.load_best, args.device, args.loss,
                                                            args.turn_off_batch_normalization_and_dropout)
    # ==================================================================================================================
    # Reset execution seeds for tunable runtime behaviour
    if args.deterministic:
        exec_ctx.random.seed(args.manual_runtime_seed)
        exec_ctx.np_random.seed(args.manual_runtime_seed)
    # ==================================================================================================================

    local_optimiser = get_optimiser(model.parameters(), args.local_optimiser, args.local_lr,
                                    args.local_momentum, args.local_weight_decay)

    local_scheduler = get_lr_scheduler(local_optimiser, args.rounds, args.local_lr_type)

    global_optimiser = get_optimiser(model.parameters(), args.global_optimiser, args.global_lr,
                                     args.global_momentum, args.global_weight_decay)

    global_scheduler = get_lr_scheduler(global_optimiser, args.rounds, args.global_lr_type)

    metric_to_optim = args.metric
    train_time_meter = 0

    best_metric = -np.inf
    eval_metrics = {}

    mutils.print_models_info(model, args)
    logger.info(f"Number of parameters in the model: {mutils.number_of_params(model):,d}\n")

    gpu_utils.print_info_about_used_gpu(args.device, args.run_id)

    sampled_clients = get_sampled_clients(trainloader.dataset.num_clients, args, exec_ctx)

    # ============= Initialize worker threads===========================================================================
    for th in exec_ctx.local_training_threads.threads:
        th.model_copy = copy.deepcopy(model)
        # th.model_copy, _ = initialise_model(args.model, args.dataset, trainloader.dataset, args, args.resume_from,
        #                                     args.load_best)
        # th.model_copy.load_state_dict(copy.deepcopy(model.state_dict()))
        th.model_copy = th.model_copy.to(th.device)

        th.local_optimiser_copy = get_optimiser(th.model_copy.parameters(), args.local_optimiser, args.local_lr,
                                                args.local_momentum, args.local_weight_decay)
        th.local_scheduler = get_lr_scheduler(th.local_optimiser_copy, args.rounds, args.local_lr_type)

    for th in exec_ctx.eval_thread_pool.threads:
        th.model_copy = copy.deepcopy(model)
        # th.model_copy, _ = initialise_model(args.model, args.dataset, trainloader.dataset, args, args.resume_from,
        #                                     args.load_best)
        # th.model_copy.load_state_dict(copy.deepcopy(model.state_dict()))
        th.model_copy = th.model_copy.to(th.device)

    for th in exec_ctx.saver_thread.threads:
        th.model_copy = copy.deepcopy(model)
        # th.model_copy, _ = initialise_model(args.model, args.dataset, trainloader.dataset, args, args.resume_from,
        #                                     args.load_best)
        # th.model_copy.load_state_dict(copy.deepcopy(model.state_dict()))
        th.model_copy = th.model_copy.to(th.device)

    exec_ctx.saver_thread.best_metric = -np.inf
    exec_ctx.saver_thread.eval_metrics = {}
    exec_ctx.saver_thread.best_metric_lock = threading.Lock()
    # ==================================================================================================================
    optimiser_dict_original = copy.deepcopy(local_optimiser.state_dict())
    # ==================================================================================================================

    logger.info(f"Use separate {exec_ctx.local_training_threads.workers()} worker threads for make local optimization")

    # Initialize server state

    if args.initialize_shifts_policy == "full_gradient_at_start":
        trainloader.dataset.set_client(None)
        algorithms.evaluateGradient(None, model, trainloader, criterion, is_rnn, update_statistics=False,
                                    evaluate_function=False, device=args.device, args=args)
        gradient_at_start = mutils.get_gradient(model)
        for p in model.parameters():
            p.grad = None
    else:
        D = mutils.number_of_params(model)
        gradient_at_start = torch.zeros(D).to(args.device)

    H = algorithms.initializeServerState(args, model, trainloader.dataset.num_clients, gradient_at_start, exec_ctx)
    logger.info(f'D: {H["D"]} / D_include_frozen : {H["D_include_frozen"]}')

    # ==================================================================================================================
    # Append all launch arguments (setup and default)
    H["args"] = args
    H["raw_cmdline"] = raw_cmdline
    H["execution_context"] = exec_ctx
    H["total_clients"] = trainloader.dataset.num_clients
    H["comment"] = args.comment
    H["group-name"] = args.group_name
    # H["sampled_clients"] = sampled_clients

    # Possibly pre-scale x0
    if algorithms.has_experiment_option(H, "x0_norm"):
        H["x0"] = (H["x0"] / H["x0"].norm()) * algorithms.get_experiment_option_f(H, "x0_norm")

    # For tasks where we can obtain a good estimate of L
    if "L" in dir(trainloader.dataset):
        H["L_compute"] = trainloader.dataset.L
    if "Li_all_clients" in dir(trainloader.dataset):
        H["Li_all_clients"] = trainloader.dataset.Li_all_clients
        H["Li_max"] = max(H["Li_all_clients"])

    # Initialize metrics
    H['best_metric'] = best_metric
    H['eval_metrics'] = copy.deepcopy(eval_metrics)

    # Initialize starting point of optimization
    mutils.set_params(model, H['x0'])

    is_target_gpu = gpu_utils.is_target_dev_gpu(args.device)

    if execution_context.simulation_start_fn:
        execution_context.simulation_start_fn(H)

    extra_opts = [a.strip() for a in args.extra_track.split(",")]

    fed_dataset_all_clients = np.arange(H['total_clients'])

    # Main Algorithm Optimization loop
    for i in range(args.rounds):
        # Check requirement to force stop simulation
        if execution_context.is_simulation_need_earlystop_fn is not None:
            if execution_context.is_simulation_need_earlystop_fn(H):
                break

        # Check that all is ok with current run
        if i > 0:
            prev_history_stats = H['history'][i - 1]
            forceTermination = False

            for elem in ["x_before_round", "grad_sgd_server_l2"]:
                if elem in prev_history_stats:
                    if math.isnan(prev_history_stats[elem]) or math.isinf(prev_history_stats[elem]):
                        logger.error(f"Force early stop due to numerical problems with {elem} at round {i}.")
                        forceTermination = True
                        break

            if forceTermination:
                break

        # Update information about current round
        H['current_round'] = i

        start = time.time()
        # Generate client state
        # TODO: metrics meter to be used for Tensorboard/WandB
        metrics_meter, fed_dataset_clients = run_one_communication_round(H, model, trainloader, criterion,
                                                                         local_optimiser,
                                                                         optimiser_dict_original, global_optimiser,
                                                                         args.device,
                                                                         current_round, args.run_local_steps,
                                                                         args.number_of_local_iters,
                                                                         is_rnn, sampled_clients)
        train_time = time.time() - start
        if H['algorithm'] == 'gradskip':
            train_time = H['time']
        train_time_meter += train_time  # Track timings for the across-rounds average
        logger.debug(f'Epoch train time: {train_time}')
        # H['history'][current_round]["xi_after"] = mutils.get_params(model)
        H['history'][current_round]["train_time"] = train_time
        H['history'][current_round]["time"] = time.time() - exec_ctx.context_init_time
        if H['algorithm'] == 'gradskip':
            if current_round == 0:
                H['history'][current_round]["time"] = train_time
            else:
                H['history'][current_round]["time"] = H['history'][current_round - 1]["time"] + train_time
        H['last_round_elapsed_sec'] = train_time

        if (i % args.eval_every == 0 or i == (args.rounds - 1)) and metric_to_optim != "none":
            # Save results obtained so far into backup file
            # ==========================================================================================================
            # Serialize server state into the filesystem
            makeBackupOfServerState(H, i)
            # ==========================================================================================================
            # Evaluate model
            if args.eval_async_threads > 0:
                defered_eval_and_save_checkpoint(model, criterion, args, current_round, is_rnn=is_rnn,
                                                 metric_to_optim=metric_to_optim, exec_ctx=exec_ctx)

                # Update recent information about eval metrics
                exec_ctx.saver_thread.best_metric_lock.acquire()
                H['best_metric'] = max(H['best_metric'], exec_ctx.saver_thread.best_metric)
                H['eval_metrics'] = copy.deepcopy(exec_ctx.saver_thread.eval_metrics)
                exec_ctx.saver_thread.best_metric_lock.release()

            else:
                metrics = evaluate_model(model, testloader, criterion, args.device, current_round, print_freq=10,
                                         is_rnn=is_rnn, metric_to_optim=metric_to_optim)

                avg_metric = metrics[metric_to_optim].get_avg()

                cur_metrics = {"loss": metrics["loss"].get_avg(),
                               "top_1_acc": metrics["top_1_acc"].get_avg(),
                               "top_5_acc": metrics["top_5_acc"].get_avg(),
                               "neq_perplexity": metrics["neq_perplexity"].get_avg()
                               }

                eval_metrics.update({metrics['round']: cur_metrics})

                # Save model checkpoint
                model_filename = '{model}_{run_id}_checkpoint_{round:0>2d}.pth.tar'.format(model=args.model,
                                                                                           run_id=args.run_id,
                                                                                           round=current_round)

                is_best = avg_metric > best_metric
                save_checkpoint(model, model_filename, is_best=is_best, args=args, metrics=metrics,
                                metric_to_optim=metric_to_optim)

                if is_best:
                    best_metric = avg_metric

                # Update recent information about eval metrics
                H['best_metric'] = max(H['best_metric'], best_metric)
                H['eval_metrics'] = copy.deepcopy(eval_metrics)

                if np.isnan(metrics['loss'].get_avg()):
                    logger.critical('NaN loss detected.')
                    # Note: the early return is disabled, so training continues despite the NaN loss
                    # return

                logger.info(f'Current lrs global:{global_scheduler.get_last_lr()}')
        # ==============================================================================================================
        # Add number of clients in that round
        H['history'][current_round]["number_of_client_in_round"] = len(fed_dataset_clients)

        if args.log_gpu_usage:
            H['history'][current_round]["memory_gpu_used"] = 0
            for dev in args.gpu:
                if gpu_utils.is_target_dev_gpu(dev):
                    memory_gpu_used = torch.cuda.memory_stats(dev)['reserved_bytes.all.current']
                    logger.info(
                        f"GPU: Before round {i} we have used {memory_gpu_used / (1024.0 ** 2):.2f} MB from device {dev}")
                    H['history'][current_round]["memory_gpu_used"] += memory_gpu_used / (1024.0 ** 2)

        if current_round % args.eval_every == 0 or i == (args.rounds - 1):
            if "full_gradient_norm_train" in extra_opts and "full_objective_value_train" in extra_opts:
                fed_dataset = trainloader.dataset

                gradient_avg = mutils.get_zero_gradient_compatible_with_model(model)
                fvalue_avg = 0.0

                for c in fed_dataset_all_clients:
                    fed_dataset.set_client(c)
                    fvalue = algorithms.evaluateGradient(None, model, trainloader, criterion, is_rnn,
                                                         update_statistics=False, evaluate_function=True,
                                                         device=args.device, args=args)
                    g = mutils.get_gradient(model)

                    # Incremental mean; valid because c is the 0-based index of the current client
                    gradient_avg = (gradient_avg * c + g) / (c + 1)
                    fvalue_avg = (fvalue_avg * c + fvalue) / (c + 1)

                H['history'][current_round]["full_gradient_norm_train"] = mutils.l2_norm_of_vec(gradient_avg)
                H['history'][current_round]["full_objective_value_train"] = fvalue_avg

            if "full_gradient_norm_train" in extra_opts and "full_objective_value_train" not in extra_opts:
                fed_dataset = trainloader.dataset
                gradient_avg = mutils.get_zero_gradient_compatible_with_model(model)

                for c in fed_dataset_all_clients:
                    fed_dataset.set_client(c)
                    algorithms.evaluateGradient(None, model, trainloader, criterion, is_rnn, update_statistics=False,
                                                evaluate_function=False, device=args.device, args=args)
                    g = mutils.get_gradient(model)
                    gradient_avg = (gradient_avg * c + g) / (c + 1)

                H['history'][current_round]["full_gradient_norm_train"] = mutils.l2_norm_of_vec(gradient_avg)

            if "full_objective_value_train" in extra_opts and "full_gradient_norm_train" not in extra_opts:
                fed_dataset = trainloader.dataset
                fvalue_avg = 0.0
                for c in fed_dataset_all_clients:
                    fed_dataset.set_client(c)
                    fvalue = algorithms.evaluateFunction(None, model, trainloader, criterion, is_rnn,
                                                         update_statistics=False, device=args.device, args=args)
                    fvalue_avg = (fvalue_avg * c + fvalue) / (c + 1)
                H['history'][current_round]["full_objective_value_train"] = fvalue_avg
            # ==========================================================================================================
            if "full_gradient_norm_val" in extra_opts and "full_objective_value_val" in extra_opts:
                fvalue = algorithms.evaluateGradient(None, model, testloader, criterion, is_rnn,
                                                     update_statistics=False, evaluate_function=True,
                                                     device=args.device, args=args)
                H['history'][current_round]["full_gradient_norm_val"] = mutils.l2_norm_of_gradient_m(model)
                H['history'][current_round]["full_objective_value_val"] = fvalue

            if "full_gradient_norm_val" in extra_opts and "full_objective_value_val" not in extra_opts:
                algorithms.evaluateGradient(None, model, testloader, criterion, is_rnn, update_statistics=False,
                                            evaluate_function=False, device=args.device, args=args)
                H['history'][current_round]["full_gradient_norm_val"] = mutils.l2_norm_of_gradient_m(model)

            if "full_objective_value_val" in extra_opts and "full_gradient_norm_val" not in extra_opts:
                fvalue = algorithms.evaluateFunction(None, model, testloader, criterion, is_rnn,
                                                     update_statistics=False, device=args.device, args=args)
                H['history'][current_round]["full_objective_value_val"] = fvalue
            # ==========================================================================================================
            # Interpolate
            # ==========================================================================================================
            if current_round > 0:
                look_behind = args.eval_every

                # ======================================================================================================
                if i % args.eval_every == 0:
                    look_behind = args.eval_every  # The previous evaluation happened exactly eval_every rounds ago
                elif i == args.rounds - 1:
                    # We are in the last round "R-1"; the previous evaluation happened i % eval_every rounds ago
                    look_behind = i % args.eval_every
                else:
                    assert False, "Impossible case"
                # ======================================================================================================

                for s in range(1, look_behind):
                    alpha = s / float(look_behind)
                    prev_eval_round = current_round - look_behind
                    inter_round = prev_eval_round + s

                    # Linearly interpolate between a (alpha=0.0) and b(alpha=1.0)
                    def lerp(a, b, alpha):
                        return (1.0 - alpha) * a + alpha * b

                    if "full_gradient_norm_train" in extra_opts:
                        H['history'][inter_round]["full_gradient_norm_train"] = lerp(
                            H['history'][prev_eval_round]["full_gradient_norm_train"],
                            H['history'][current_round]["full_gradient_norm_train"], alpha)
                    if "full_objective_value_train" in extra_opts:
                        H['history'][inter_round]["full_objective_value_train"] = lerp(
                            H['history'][prev_eval_round]["full_objective_value_train"],
                            H['history'][current_round]["full_objective_value_train"], alpha)
                    if "full_gradient_norm_val" in extra_opts:
                        H['history'][inter_round]["full_gradient_norm_val"] = lerp(
                            H['history'][prev_eval_round]["full_gradient_norm_val"],
                            H['history'][current_round]["full_gradient_norm_val"], alpha)
                    if "full_objective_value_val" in extra_opts:
                        H['history'][inter_round]["full_objective_value_val"] = lerp(
                            H['history'][prev_eval_round]["full_objective_value_val"],
                            H['history'][current_round]["full_objective_value_val"], alpha)

        if current_round % args.eval_every == 0 or i == (args.rounds - 1):
            if execution_context.simulation_progress_steps_fn is not None:
                execution_context.simulation_progress_steps_fn(i / float(args.rounds) * 0.75, H)
            wandb_wrapper.logStatistics(H, current_round)

        # ==============================================================================================================
        # Save parameters of the last model
        # if i == args.rounds - 1:
        #    xfinal = mutils.get_params(model)
        #    H['xfinal'] = xfinal
        # ==============================================================================================================
        # Update schedulers
        if exec_ctx.local_training_threads.workers() > 0:
            # In this mode local optimizers for worker threads are used
            for th in exec_ctx.local_training_threads.threads:
                th.local_scheduler.step()
        else:
            # In this mode local optimizer is used
            local_scheduler.step()

        global_scheduler.step()

        # Increment round counter
        current_round += 1

        # Empty caches. Force cache cleanup after the round (helps with fragmentation issues)
        if args.per_round_clean_torch_cache:
            execution_context.torch_global_lock.acquire()
            torch.cuda.empty_cache()
            execution_context.torch_global_lock.release()

    xfinal = mutils.get_params(model)
    H['xfinal'] = xfinal

    logger.debug(f'Average epoch train time: {train_time_meter / args.rounds}')
    # ==================================================================================================================
    logger.info("Wait for threads from a training threadpool")
    exec_ctx.local_training_threads.stop()
    if execution_context.simulation_progress_steps_fn is not None:
        execution_context.simulation_progress_steps_fn(0.8, H)

    logger.info("Wait for threads from a evaluate threadpool")
    exec_ctx.eval_thread_pool.stop()
    if execution_context.simulation_progress_steps_fn is not None:
        execution_context.simulation_progress_steps_fn(0.95, H)

    logger.info("Wait for threads from a serialization threadpool")
    exec_ctx.saver_thread.stop()

    # Not strictly necessary, but we do it to rule out any pending-synchronization bugs inside PyTorch

    logger.info("Synchronize all GPU streams")
    if is_target_gpu:
        torch.cuda.synchronize(args.device)

    # Update metrics based on eval and save results
    H['best_metric'] = max(H['best_metric'], exec_ctx.saver_thread.best_metric)
    H['best_metric'] = max(H['best_metric'], best_metric)

    if len(eval_metrics) > 0:
        H['eval_metrics'] = copy.deepcopy(eval_metrics)
    else:
        H['eval_metrics'] = copy.deepcopy(exec_ctx.saver_thread.eval_metrics)

    if execution_context.simulation_progress_steps_fn is not None:
        execution_context.simulation_progress_steps_fn(1.0, H)
    # ==================================================================================================================
    # Final prune
    HKeys = list(H.keys())

    # Remove any tensor field larger than tensor_prune_mb_threshold MBytes from the final solution
    # (with a threshold of 0.0, effectively every non-empty tensor field is pruned)
    tensor_prune_mb_threshold = 0.0

    for field in HKeys:
        if torch.is_tensor(H[field]):
            size_in_mbytes = H[field].numel() * H[field].element_size() / (1024.0 ** 2)
            if size_in_mbytes > tensor_prune_mb_threshold:
                del H[field]
                continue
            else:
                # Move any retained tensor to the CPU
                H[field] = H[field].cpu()

        if type(H[field]) == list:
            sz = len(H[field])
            i = 0

            while i < sz:
                if torch.is_tensor(H[field][i]):
                    size_in_mbytes = H[field][i].numel() * H[field][i].element_size() / (1024.0 ** 2)
                    if size_in_mbytes >= tensor_prune_mb_threshold:
                        del H[field][i]
                        sz = len(H[field])
                    else:
                        # Move any retained tensor to the CPU
                        H[field][i] = H[field][i].cpu()
                        i += 1
                else:
                    i += 1

    # Remove local optimiser state, client_compressors, reference to server state
    not_to_remove_from_client_state = ["error"]

    for round, history_state in H['history'].items():
        clients_history = history_state['client_states']
        for client_id, client_state in clients_history.items():
            client_state = client_state['client_state']

            if 'H' in client_state:
                del client_state['H']
            if 'optimiser' in client_state:
                del client_state['optimiser']
            if 'buffers' in client_state:
                del client_state['buffers']
            if 'client_compressor' in client_state:
                del client_state['client_compressor']

            client_state_keys = list(client_state.keys())
            for field in client_state_keys:

                if field in not_to_remove_from_client_state:
                    continue

                if torch.is_tensor(client_state[field]):
                    size_in_mbytes = client_state[field].numel() * client_state[field].element_size() / (1024.0 ** 2)
                    if size_in_mbytes > tensor_prune_mb_threshold:
                        del client_state[field]
                        continue
                    else:
                        # Move any retained tensor to the CPU
                        client_state[field] = client_state[field].cpu()
                        continue

                if type(client_state[field]) == list:
                    sz = len(client_state[field])
                    i = 0
                    while i < sz:
                        if torch.is_tensor(client_state[field][i]):
                            size_in_mbytes = client_state[field][i].numel() * client_state[field][i].element_size() / (
                                    1024.0 ** 2)
                            if size_in_mbytes >= tensor_prune_mb_threshold:
                                del client_state[field][i]
                                sz = len(client_state[field])
                            else:
                                client_state[field][i] = client_state[field][i].cpu()
                                i += 1
                        else:
                            i += 1

    # Remove the reference to the execution context
    del H["execution_context"]
    if 'exec_ctx' in H["args"]:
        del H["args"]['exec_ctx']

    # ==================================================================================================================
    # Cleanup execution context resources
    # ==================================================================================================================
    execution_context.resetExecutionContext(exec_ctx)
    # ==================================================================================================================
    if execution_context.simulation_finish_fn is not None:
        execution_context.simulation_finish_fn(H)
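The interpolation step in the round loop fills rounds between two evaluations with linearly interpolated metric values. A self-contained illustration of the same lerp arithmetic (the numbers are made up):

def lerp(a, b, alpha):
    # Linearly interpolate between a (alpha=0.0) and b (alpha=1.0)
    return (1.0 - alpha) * a + alpha * b

eval_every = 4                      # metrics measured at rounds 0 and 4
metric_prev, metric_cur = 1.0, 0.2

for s in range(1, eval_every):
    alpha = s / float(eval_every)
    print(f"round {s}: {lerp(metric_prev, metric_cur, alpha):.2f}")
# round 1: 0.80
# round 2: 0.60
# round 3: 0.40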
def isSimulationNeedEarlyStopCmdLine(H)
Expand source code
def isSimulationNeedEarlyStopCmdLine(H):
    global g_Terminate
    return g_Terminate
def main(args, raw_cmdline, extra_)
Expand source code
def main(args, raw_cmdline, extra_):
    # Init project and save in args url for the plots
    projectWB = None
    if len(args.wandb_key) > 0:
        projectWB = wandb_wrapper.initWandbProject(
            args.wandb_key,
            args.wandb_project_name,
            f"{args.algorithm}-rounds-{args.rounds}-{args.model}@{args.dataset}-runid-{args.run_id}",
            args
        )
        if projectWB is not None:
            args.wandb_url = projectWB.url

    # Device used by the master for FL algorithms (and the target of DataParallel .to() calls)
    args.device = gpu_utils.get_target_device_str(args.gpu[0]) if type(args.gpu) == list else args.gpu

    # Instantiate execution context
    exec_ctx = execution_context.initExecutionContext()
    exec_ctx.extra_ = extra_

    # Set all seeds
    if args.deterministic:
        exec_ctx.random.seed(args.manual_init_seed)
        exec_ctx.np_random.seed(args.manual_init_seed)

    # Setup extra options for experiments,
    # e.g. "x0_norm:1.0,flag" -> {"x0_norm": "1.0", "flag": True}
    for option in args.algorithm_options.split(","):
        kv = option.split(":")
        if len(kv) == 1:
            exec_ctx.experimental_options[kv[0]] = True
        else:
            exec_ctx.experimental_options[kv[0]] = kv[1]

    # Load train and test sets
    trainset, testset = load_data(exec_ctx, args.data_path, args.dataset, args, load_trainset=True, download=True)

    test_batch_size = get_test_batch_size(args.dataset, args.batch_size)

    # Test set dataloader
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=test_batch_size,
                                             num_workers=args.num_workers_test,
                                             shuffle=False,
                                             pin_memory=False,
                                             drop_last=False
                                             )

    # Reset all seeds
    if args.deterministic:
        exec_ctx.random.seed(args.manual_init_seed)
        exec_ctx.np_random.seed(args.manual_init_seed)

    if not args.evaluate:  # Training mode
        trainloader = torch.utils.data.DataLoader(
            trainset,
            batch_size=args.batch_size,
            num_workers=args.num_workers_train,
            shuffle=False,
            pin_memory=False,
            drop_last=False)

        # Whether the master's target device is a GPU
        is_target_dev_gpu = gpu_utils.is_target_dev_gpu(args.device)

        # Initialize local training threads in case of using multithreading implementation
        exec_ctx.local_training_threads.adjust_num_workers(w=args.threadpool_for_local_opt, own_cuda_stream=True,
                                                           device_list=args.gpu)
        for th in exec_ctx.local_training_threads.threads:
            th.trainset_copy = copy.deepcopy(trainset)
            th.train_loader_copy = torch.utils.data.DataLoader(th.trainset_copy,
                                                               batch_size=args.batch_size,
                                                               num_workers=args.num_workers_train,
                                                               shuffle=False,
                                                               pin_memory=False,
                                                               drop_last=False)
            if hasattr(th.trainset_copy, "load_data"):
                th.trainset_copy.load_data()

        # Initialize remote clients - one remote connection per one remote computation device
        external_devices = [d.strip() for d in args.external_devices.split(",") if len(d.strip()) > 0]
        exec_ctx.remote_training_threads.adjust_num_workers(w=len(external_devices), own_cuda_stream=False,
                                                            device_list=["cpu"])
        for th_i in range(len(exec_ctx.remote_training_threads.threads)):
            th = exec_ctx.remote_training_threads.threads[th_i]
            dev = external_devices[th_i]
            th.external_dev_long = dev
            th.external_ip, th.external_port, th.external_dev, th.external_dev_dumber = dev.split(":")

            try:
                th.external_socket = comm_socket.CommSocket()
                th.external_socket.sock.connect((th.external_ip, int(th.external_port)))  # connect with remote side
                th.external_socket.rawSendString("execute_work")  # initiate work execution COMMAND
                th.external_socket.rawSendString(th.external_dev_long)  # provide device specification
                # Provide the original command line (the client adjusts it to match its own devices)
                th.external_socket.rawSend(pickle.dumps(raw_cmdline))
                th.external_socket_online = True
            except socket.error:
                th.external_socket_online = False

        # Initialize saver threads for information serialization (0 is acceptable)
        exec_ctx.saver_thread.adjust_num_workers(w=args.save_async_threads, own_cuda_stream=True, device_list=args.gpu)

        # Initialize evaluation threads in case of using the multithreaded implementation (0 is acceptable)
        exec_ctx.eval_thread_pool.adjust_num_workers(w=args.eval_async_threads, own_cuda_stream=True,
                                                     device_list=args.gpu)

        for th in exec_ctx.eval_thread_pool.threads:
            th.testset_copy = copy.deepcopy(testset)
            th.testloader_copy = torch.utils.data.DataLoader(th.testset_copy,
                                                             batch_size=test_batch_size,
                                                             num_workers=args.num_workers_test,
                                                             shuffle=False,
                                                             pin_memory=False,
                                                             drop_last=False
                                                             )
            if hasattr(th.testset_copy, "load_data"):
                th.testset_copy.load_data()

        if args.worker_listen_mode <= 0:
            # Execution path for a local simulation
            init_and_train_model(args, raw_cmdline, trainloader, testloader, exec_ctx)
        else:
            # Execution path for assisting a simulation as a remote worker
            init_and_help_with_compute(args, raw_cmdline, trainloader, testloader, exec_ctx)

    else:  # Evaluation mode
        # TODO: figure out how to exploit DataParallel. Currently - parallelism across workers
        # Note: passing testset in the dataset slot is assumed here, mirroring the training-mode call
        model, criterion, round = get_training_elements(args.model, args.dataset, testset, args, args.resume_from,
                                                        args.load_best, args.device, args.loss,
                                                        args.turn_off_batch_normalization_and_dropout)

        metrics = evaluate_model(model, testloader, criterion, args.device, round,
                                 print_freq=10, metric_to_optim=args.metric,
                                 is_rnn=args.model in RNN_MODELS)

        metrics_dict = create_metrics_dict(metrics)
        logger = Logger.get(args.run_id)  # main does not define logger elsewhere
        logger.info(f'Validation metrics: {metrics_dict}')
    wandb_wrapper.finishProject(projectWB)
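main parses --algorithm-options as a comma-separated list of key or key:value entries stored in exec_ctx.experimental_options. The same parsing in isolation; the helper name and the sample values are illustrative:

def parse_algorithm_options(spec):
    options = {}
    for option in spec.split(","):
        kv = option.split(":")
        if len(kv) == 1:
            options[kv[0]] = True   # bare flag
        else:
            options[kv[0]] = kv[1]  # key:value; the value stays a string
    return options

print(parse_algorithm_options("x0_norm:1.0,some_flag"))
# {'x0_norm': '1.0', 'some_flag': True}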
def makeBackupOfServerState(H, round)
Expand source code
def makeBackupOfServerState(H, round):
    args = H["args"]
    job_id = args.run_id
    fname = args.out

    if fname == "":
        fname = os.path.join(args.checkpoint_dir, f"{args.run_id}_{args.algorithm}_{args.global_lr}_backup.bin")
    else:
        fname = fname.replace(".bin", f"{args.run_id}_{args.algorithm}_{args.global_lr}_backup.bin")

    execution_context = H["execution_context"]
    H["execution_context"] = None
    exec_ctx_in_arg = ('exec_ctx' in H["args"])
    exec_ctx = None
    if exec_ctx_in_arg:
        exec_ctx = H["args"]['exec_ctx']
        H["execution_context"] = None

    with open(fname, "wb") as f:
        pickle.dump([(job_id, H)], f)

    H["execution_context"] = execution_context
    if exec_ctx_in_arg:
        H["args"]['exec_ctx'] = exec_ctx
def runSimulation(cmdline, extra_=None)
Expand source code
def runSimulation(cmdline, extra_=None):
    # ==================================================================================================================
    # ENTRY POINT FOR LAUNCH SIMULATION FROM GUI
    # ==================================================================================================================
    global CUDA_SUPPORT

    args = parse_args(cmdline)
    if args.hostname == "":
        args.hostname = socket.gethostname()

    Logger.setup_logging(args.loglevel, args.logfilter, logfile=args.logfile)
    logger = Logger.get(args.run_id)

    logger.info(f"CLI args original: {cmdline}")

    if torch.cuda.device_count():
        CUDA_SUPPORT = True
        # torch.backends.cudnn.benchmark = True
    else:
        logger.warning('CUDA unsupported!!')
        CUDA_SUPPORT = False

    if not CUDA_SUPPORT:
        args.gpu = "cpu"

    if args.deterministic:
        import torch.backends.cudnn as cudnn
        import os
        import random

        # TODO: These settings are not thread-safe when the CUDA backend is used from several threads.
        # The project uses the execution context's random generators for random and numpy in a thread-safe way.
        if CUDA_SUPPORT:
            cudnn.deterministic = args.deterministic
            cudnn.benchmark = not args.deterministic
            torch.cuda.manual_seed(args.manual_init_seed)
            torch.cuda.manual_seed_all(args.manual_init_seed)

            # Enable TF32 on Tensor Cores only if explicitly requested
            torch.backends.cudnn.allow_tf32 = args.allow_use_nv_tensorcores
            torch.backends.cuda.matmul.allow_tf32 = args.allow_use_nv_tensorcores

        torch.manual_seed(args.manual_init_seed)
        random.seed(args.manual_init_seed)
        np.random.seed(args.manual_init_seed)

        os.environ['PYTHONHASHSEED'] = str(args.manual_init_seed)
        torch.backends.cudnn.deterministic = True

    main(args, cmdline, extra_)
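runSimulation expects the command line as a list of argv tokens, exactly the form ClientThread forwards. A hedged example invocation; the flags shown appear elsewhere in this module, but the values are illustrative only (see parse_args in opts for the full flag set):

runSimulation([
    "--gpu", "0",                # device number, as rewritten by ClientThread
    "--metric", "top_1_acc",     # metric to optimize
    "--worker-listen-mode", "0"  # <= 0 selects the local-simulation path
])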
def saveResult(H)

Serialize server state into filesystem

Expand source code
def saveResult(H):
    """Serialize server state into filesystem"""
    args = H["args"]

    job_id = args.run_id
    fname = args.out
    if fname == "":
        fname = os.path.join(args.checkpoint_dir, f"{args.algorithm}_simulation_history.bin")

    # Save formally list of tuples (job_id, server_state)
    with open(fname, "wb") as f:
        pickle.dump([(job_id, H)], f)

    print(f"Final server state for {job_id} is serialized into: '{fname}'")
    print(f"Program which uses '{args.algorithm}' algorithm for {args.rounds} rounds was finished")
def terminationWithSignal(*args)
Expand source code
def terminationWithSignal(*args):
    global g_Terminate
    g_Terminate = True
    print("Request for terminate process has been obtained. Save results and terminate.")

Classes

class ClientThread (clientSocket, listenPort)

A class that represents a thread of control.

This class can be safely subclassed in a limited fashion. There are two ways to specify the activity: by passing a callable object to the constructor, or by overriding the run() method in a subclass.

This constructor should always be called with keyword arguments. Arguments are:

group should be None; reserved for future extension when a ThreadGroup class is implemented.

target is the callable object to be invoked by the run() method. Defaults to None, meaning nothing is called.

name is the thread name. By default, a unique name is constructed of the form "Thread-N" where N is a small decimal number.

args is the argument tuple for the target invocation. Defaults to ().

kwargs is a dictionary of keyword arguments for the target invocation. Defaults to {}.

If a subclass overrides the constructor, it must make sure to invoke the base class constructor (Thread.__init__()) before doing anything else to the thread.

Expand source code
class ClientThread(threading.Thread):
    def __init__(self, clientSocket, listenPort):
        threading.Thread.__init__(self)
        self.clientSocket = clientSocket  # Raw OS socket used for communication
        self.comSocket = comm_socket.CommSocket(clientSocket)  # Lightweight wrapper on top of the raw socket
        self.listenPort = listenPort  # Listening port

    def run(self):
        cmd = self.comSocket.rawRecvString()  # Obtain command

        t = time.localtime()
        cur_time = time.strftime("%H:%M:%S", t)
        print(f"-- {cur_time}: The received command: {cmd}")

        if cmd == "list_of_gpus":  # List of resources in the system
            gpus = len(gpu_utils.get_available_gpus())
            self.comSocket.rawSendString(f"{gpus}")
        elif cmd == "execute_work":  # Execute work
            devConfig = self.comSocket.rawRecvString()  # 1st param: device configuration
            cmdLine = self.comSocket.rawRecv()  # 2nd param: command line
            cmdLine = pickle.loads(cmdLine)  # unpack cmdline

            ip, port, dev_name, dev_number = devConfig.split(":")

            # Remove flags unnecessary for the worker; iterate backwards so deletions
            # do not shift the indices of entries not yet processed
            i = len(cmdLine) - 1
            while i >= 0:
                if cmdLine[i] == "--metric":
                    cmdLine[i + 1] = "none"

                if cmdLine[i] == "--gpu":
                    cmdLine[i + 1] = dev_number  # update device number

                if cmdLine[i] in ['--wandb-key', '--eval-async-threads', '--save-async-threads',
                                  '--threadpool-for-local-opt', '--external-devices', '--evaluate']:
                    del cmdLine[i]  # remove flag
                    del cmdLine[i]  # remove value for flag (originally at index i+1)
                i -= 1

            # Tell the worker the port on which it operates in listen mode
            cmdLine.append("--worker-listen-mode")
            cmdLine.append(str(self.listenPort))

            runSimulation(cmdLine, extra_=self.comSocket)
        else:
            print(f"The received command {cmd} is not valid")

Ancestors

  • threading.Thread

Methods

def run(self)

Method representing the thread's activity.

You may override this method in a subclass. The standard run() method invokes the callable object passed to the object's constructor as the target argument, if any, with sequential and keyword arguments taken from the args and kwargs arguments, respectively.

Expand source code
def run(self):
    cmd = self.comSocket.rawRecvString()  # Obtain command

    t = time.localtime()
    cur_time = time.strftime("%H:%M:%S", t)
    print(f"-- {cur_time}: The received command: {cmd}")

    if cmd == "list_of_gpus":  # List of resources in the system
        gpus = len(gpu_utils.get_available_gpus())
        self.comSocket.rawSendString(f"{gpus}")
    elif cmd == "execute_work":  # Execute work
        devConfig = self.comSocket.rawRecvString()  # 1st param: device configuration
        cmdLine = self.comSocket.rawRecv()  # 2nd param: command line
        cmdLine = pickle.loads(cmdLine)  # unpack cmdline

        ip, port, dev_name, dev_number = devConfig.split(":")

        # Remove flags unnecessary for the worker; iterate backwards so deletions
        # do not shift the indices of entries not yet processed
        i = len(cmdLine) - 1
        while i >= 0:
            if cmdLine[i] == "--metric":
                cmdLine[i + 1] = "none"

            if cmdLine[i] == "--gpu":
                cmdLine[i + 1] = dev_number  # update device number

            if cmdLine[i] in ['--wandb-key', '--eval-async-threads', '--save-async-threads',
                              '--threadpool-for-local-opt', '--external-devices', '--evaluate']:
                del cmdLine[i]  # remove flag
                del cmdLine[i]  # remove value for flag (originally at index i+1)
            i -= 1

        # Tell the worker the port on which it operates in listen mode
        cmdLine.append("--worker-listen-mode")
        cmdLine.append(str(self.listenPort))

        runSimulation(cmdLine, extra_=self.comSocket)
    else:
        print(f"The received command {cmd} is not valid")