import os
from parser import Parser
from datetime import datetime

from misc.utils import *
from modules.multiprocs import ParentProcess

def main(args):
    """Configure the run and launch federated training for the selected model.

    Args:
        args: parsed command-line namespace (from ``Parser().parse()``);
            ``set_config`` fills in derived settings before the model is chosen.

    Exits the process with status 1 (via ``SystemExit``) when ``args.model``
    is neither ``'fedavg'`` nor ``'fedpub'``.
    """
    args = set_config(args)

    # Server/Client are imported inside the branch, so only the selected
    # model's modules are loaded.
    if args.model == 'fedavg':
        from models.fedavg.server import Server
        from models.fedavg.client import Client
    elif args.model == 'fedpub':
        from models.fedpub.server import Server
        from models.fedpub.client import Client
    else:
        import sys
        print('incorrect model was given: {}'.format(args.model))
        # sys.exit (unlike os._exit) flushes stdout so the message above is
        # not lost, and a non-zero status reports the failure to the caller.
        sys.exit(1)

    pp = ParentProcess(args, Server, Client)
    pp.start()

def set_config(args):
    """Fill in optimizer defaults, task-specific settings and output paths.

    ``args`` is modified in place and also returned.

    Set unconditionally: ``base_lr``, ``min_lr``, ``momentum_opt``,
    ``weight_decay``, ``warmup_epochs``, ``base_momentum``, ``final_momentum``.

    For task ``'Cora_CC_total_0.2_HET'``: sets ``dist``, ``dataset`` (the task
    name without ``'_HET'``), ``backbone``, ``n_feat`` and ``n_clss``;
    ``n_clients`` defaults to 10 and ``base_lr`` to 0.01 unless ``n_clients`` /
    ``lr`` were given.

    ``data_path``, ``checkpt_path`` and ``log_path`` are derived from
    ``base_path`` and a timestamped trial name. When ``debug`` is set, the
    checkpoint/log paths move under ``<base_path>/debug/`` and ``n_clients``,
    ``n_workers`` and ``frac`` are forced to 5, 5 and 1.0.

    Args:
        args: namespace-like object. Reads ``task``, ``n_clients``, ``lr``,
            ``folder``, ``trial``, ``base_path``, ``debug``, ``model`` and,
            for tasks other than the one above, ``dataset``.

    Returns:
        The same ``args`` object.
    """
    args.base_lr = 1e-3
    args.min_lr = 1e-3
    args.momentum_opt = 0.9
    args.weight_decay = 1e-6
    args.warmup_epochs = 10
    args.base_momentum = 0.99
    args.final_momentum = 1.0

    if args.task in ['Cora_CC_total_0.2_HET']:
        args.n_clients = 10 if args.n_clients is None else args.n_clients
        args.dist = 'heterogeneous'
        args.dataset = args.task.replace('_HET', '')
        args.backbone = 'GCN'
        # Cora's input feature dimension and number of classes.
        args.n_feat = 1433
        args.n_clss = 7
        args.base_lr = 0.01 if args.lr is None else args.lr

    # NOTE(review): tasks not handled above rely on args.dataset already being
    # set (presumably by the parser) — otherwise the next lines raise
    # AttributeError. Confirm.
    now = datetime.now().strftime("%Y%m%d_%H%M%S")
    trial_suffix = '' if args.trial is None else f'_{args.trial}'
    trial = (f'{args.dataset}/clients_{args.n_clients}/'
             f'{args.folder}{now}_{args.task}_{args.model}{trial_suffix}')

    args.data_path = f'{args.base_path}/data'
    args.checkpt_path = f'{args.base_path}/checkpoints/{trial}'
    args.log_path = f'{args.base_path}/logs/{trial}'

    if args.debug:
        args.checkpt_path = f'{args.base_path}/debug/checkpoints/{trial}'
        args.log_path = f'{args.base_path}/debug/logs/{trial}'
        # NOTE(review): `trial` was built before this override, so debug paths
        # still carry the original client count in their name — confirm intended.
        args.n_clients = 5
        args.n_workers = 5
        args.frac = 1.0

    return args

if __name__ == '__main__':
    # Script entry point: parse the command line, then hand off to main().
    cli_args = Parser().parse()
    main(cli_args)










