import os
import statistics
from parser import Parser
from datetime import datetime
from misc.utils import *
from modules.multiprocs_new import ParentProcess
import warnings

# Install a catch-all "ignore" filter for every Warning category, process-wide.
# NOTE(review): this also hides DeprecationWarnings raised by dependencies.
warnings.simplefilter("ignore")


def main(args):
    """Finalize the run configuration and launch training for ``args.method``.

    ``set_config`` fills derived settings into *args*. For
    ``args.method == 'fedavg'``, the FedAvg ``Server``/``Client`` classes are
    imported and passed, together with the configuration, to
    ``ParentProcess``, which is then started.

    Args:
        args: parsed command-line namespace (as produced by
            ``Parser().parse()``); it is updated by ``set_config``.

    Exits the process with status 1 (message on stderr) when ``args.method``
    is not a supported method.
    """
    import sys  # only needed for the error exit below

    args = set_config(args)
    if args.method == 'fedavg':
        # Imported inside the branch so only the selected method's modules load.
        from models.fedavg.server import Server
        from models.fedavg.client import Client
    else:
        # Report the attribute that was actually checked (`method`). Use
        # sys.exit rather than os._exit(0): it signals failure (status 1),
        # prints the message to stderr, and lets the interpreter shut down
        # normally so buffered output is flushed.
        sys.exit('incorrect model was given: {}'.format(args.method))

    # Echo the full resolved configuration before launching.
    print(args)
    pp = ParentProcess(args, Server, Client)
    pp.start()

def set_config(args):
    """Populate *args* with derived hyper-parameters and output locations.

    Updates *args* in place and returns the same object.

    Hyper-parameters:
        ``base_lr`` is ``args.lr`` when given, otherwise ``1e-3``.
        ``vae_lr``, ``momentum_opt`` and ``weight_decay`` are fixed constants.
        ``n_clss`` comes from ``get_num_classes(args.dataset)``.
        ``n_clients`` defaults to 10 when unset.

    Paths (all relative to ``base_path = '.'``):
        ``data_path`` is ``./datasets``.
        ``checkpt_path`` and ``log_path`` are
        ``<root>/{checkpoints,logs}/<dataset>/clients_<n>/<timestamp>_<method>``.
        ``<root>`` is ``./debug`` when ``args.debug is True``, otherwise ``.``.
    """
    # --- optimisation settings (command-line value wins for the learning rate)
    if args.lr is None:
        args.base_lr = 1e-3
    else:
        args.base_lr = args.lr
    args.vae_lr = 1e-2
    args.momentum_opt = 0.9
    args.n_clss = get_num_classes(args.dataset)
    args.weight_decay = 5e-4
    if args.n_clients is None:
        args.n_clients = 10

    # --- per-run output locations, made unique by a second-resolution timestamp
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = f'{args.dataset}/clients_{args.n_clients}/{stamp}_{args.method}'
    args.base_path = '.'
    args.data_path = f'{args.base_path}/datasets'

    # Debug runs write into a separate `debug/` subtree.
    out_root = f'{args.base_path}/debug' if args.debug is True else args.base_path
    args.checkpt_path = f'{out_root}/checkpoints/{run_dir}'
    args.log_path = f'{out_root}/logs/{run_dir}'

    return args


if __name__ == '__main__':
    # Script entry point: parse the command line, then hand the namespace to main().
    cli_args = Parser().parse()
    main(cli_args)
