import argparse
import time
from FedVTC import run_FedVTC
from FedTGP import run_FedTGP
from Fedtype import run_Fedtype
from Fedgen import run_Fedgen
from CCVR import run_ccvr

# Experiment-wide settings. ROUNDS and Learning_rate are the CLI defaults for
# --iteration and --lr; the others are forwarded verbatim to the run_* calls.
NUM_SIMS = 1            # not referenced by the visible code of this script
ROUNDS = 100            # default number of global iterations (--iteration)
MFC = 10                # forwarded to every run_* entry point as ``P``
MAC = 100               # forwarded to every run_* entry point as ``M``
S = 0.5                 # not referenced by the visible code of this script
Learning_rate = 0.0001  # default learning rate (--lr); same value as 1e-4
Local_epoch = 5         # forwarded only to run_FedVTC as ``local_epoch``
BATCH = 16              # forwarded to every run_* entry point as ``Batch``

# Command-line interface. Arguments are parsed at import time because the
# ``__main__`` entry point of this script reads the module-level ``args``.
parser = argparse.ArgumentParser(description="Run a FL simulation")
# No ``default`` here: the option is required, so argparse would never use
# one, and with nargs="+" the parsed value is always a non-empty list of names
# (a string default would instead be iterated character by character).
parser.add_argument('-m', '--methods', type=str, metavar='', nargs="+", help='<Required> Baseline', required=True)
parser.add_argument('-I', '--iteration', type=int, metavar='', default=ROUNDS, help='Global iteration')
parser.add_argument('-l', '--lr', type=float, metavar="", default=Learning_rate, help="Learning rate")
parser.add_argument('-s', '--sd', type=int, metavar="", default=1, help="Random Seed")
# Only consumed by FedVTC (passed as ``extra_rounds``).
parser.add_argument('-e', '--er', type=int, metavar="", default=50, help="Extra Rounds")
args = parser.parse_args()

if __name__ == '__main__':
    # Unpack the parsed CLI options. ``rounds`` (not ``round``) avoids
    # shadowing the builtin.
    rounds = args.iteration
    methods = args.methods
    lr = args.lr
    seed = args.sd
    extra_rounds = args.er

    # Dispatch table: CLI method name -> zero-argument callable that runs the
    # method with the shared hyper-parameters. Only FedVTC additionally takes
    # ``extra_rounds`` and ``local_epoch``.
    runners = {
        'fedvtc': lambda: run_FedVTC(M=MAC, P=MFC, R=rounds, seed=seed, lr=lr,
                                     extra_rounds=extra_rounds,
                                     local_epoch=Local_epoch, Batch=BATCH),
        'fedtgp': lambda: run_FedTGP(M=MAC, P=MFC, R=rounds, seed=seed, lr=lr, Batch=BATCH),
        'fedtype': lambda: run_Fedtype(M=MAC, P=MFC, R=rounds, seed=seed, lr=lr, Batch=BATCH),
        'fedgen': lambda: run_Fedgen(M=MAC, P=MFC, R=rounds, seed=seed, lr=lr, Batch=BATCH),
        'ccvr': lambda: run_ccvr(M=MAC, P=MFC, R=rounds, seed=seed, lr=lr, Batch=BATCH),
    }

    if not methods:
        print("Error! No methods are executed!")
    else:
        # (method, elapsed seconds) for each method that actually ran. Keeping
        # the name with its timing means an unknown method in the list can no
        # longer shift the report onto the wrong names or cause an IndexError.
        timings = []
        for m in methods:
            runner = runners.get(m)
            if runner is None:
                print(f"Error: unknown method detected: {m}!")
                continue
            # perf_counter is monotonic, so the duration is unaffected by
            # wall-clock adjustments during long runs.
            start = time.perf_counter()
            runner()
            timings.append((m, time.perf_counter() - start))
        for m, seconds in timings:
            print(f"the time consumption of method {m} is {seconds/3600} hours.")