import argparse
from FedLF_gpu import run_FL_FedLF
from feddrop import run_FL_dropout
from fjord import run_FL_fjord
from cocofl import run_cocoFL
from SLT import run_SLT
from FedLF_approx import run_LFaprox
import time

# Simulation-wide settings shared by every baseline this script can launch.
NUM_SIMS: int = 1  # NOTE(review): not referenced in the visible code — confirm it is still needed
ROUNDS: int = 500  # default for --iteration (number of global rounds, passed as R=)
MFC: int = 10  # forwarded as P= to each run_* launcher — presumably clients per round, TODO confirm
MAC: int = 100  # forwarded as M= to each run_* launcher — presumably total clients, TODO confirm
S: float = 0.5  # default for --rate (used as SLT `scaler` and FedLF_approx `approx_rate`)
Learning_rate: float = 0.001  # default for --lr

# Command-line interface. Defaults for the optional flags come from the
# module-level constants ROUNDS, S and Learning_rate.
# NOTE(review): parse_args() runs at import time (outside the __main__ guard),
# so importing this module consumes sys.argv; the __main__ block reads `args`.
parser = argparse.ArgumentParser(description="Run a FL simulation")
# `-m` is required, so it takes no default. The previous default='fedlf' was
# never used, and it was a str rather than a list for nargs="+". If `required`
# were ever dropped, the method loop would have iterated it character by character.
parser.add_argument('-m', '--methods', type=str, metavar='', nargs="+", help='<Required> Baseline', required=True)
# Forwarded as `scaler` to SLT and as `approx_rate` to FedLF_approx.
parser.add_argument('-r', '--rate', type=float, metavar='', default=S, help='Scaler')
parser.add_argument('-I', '--iteration', type=int, metavar='', default=ROUNDS, help='Global iteration')
parser.add_argument('-l', '--lr', type=float, metavar="", default=Learning_rate, help="Learning rate")
parser.add_argument('-s', '--sd', type=int, metavar="", default=1, help="Random Seed")
args = parser.parse_args()

if __name__ == '__main__':
    # Run each requested baseline in order and report its wall-clock duration.
    n_rounds = args.iteration  # named to avoid shadowing the builtin round()
    methods = args.methods
    rate = args.rate
    lr = args.lr
    seed = args.sd

    # CLI method name -> zero-argument launcher. Every launcher uses the parsed
    # round count. Previously 'fedlfap' hard-coded ROUNDS and ignored -I/--iteration.
    runners = {
        # Federated Dropout
        'dropout': lambda: run_FL_dropout(M=MAC, P=MFC, R=n_rounds, seed=seed, lr=lr),
        # FjORD
        'fjord': lambda: run_FL_fjord(M=MAC, P=MFC, R=n_rounds, seed=seed, lr=lr),
        # SLT
        'slt': lambda: run_SLT(M=MAC, P=MFC, R=n_rounds, scaler=rate, seed=seed, lr=lr),
        # CoCoFL
        'cocofl': lambda: run_cocoFL(M=MAC, P=MFC, R=n_rounds, seed=seed, lr=lr),
        # FedOLF
        'fedlf': lambda: run_FL_FedLF(M=MAC, P=MFC, R=n_rounds, seed=seed, lr=lr),
        # FedOLF + TA
        'fedlfap': lambda: run_LFaprox(M=MAC, P=MFC, R=n_rounds, seed=seed, approx_rate=rate, lr=lr),
    }

    if not methods:
        print("Error! No methods are executed!")
    else:
        # (method, seconds) pairs, recorded only for methods that actually ran.
        # Unknown names are skipped here, so reported timings can never be
        # misattributed or indexed out of range.
        timings = []
        for m in methods:
            runner = runners.get(m)
            if runner is None:
                print(f"Error: unknown method detected: {m}!")
                continue
            start = time.perf_counter()  # monotonic clock, suited to measuring durations
            runner()
            timings.append((m, time.perf_counter() - start))

        ### Get running time ###
        for m, elapsed in timings:
            print(f"the time consumption of method {m} is {elapsed/3600} hours.")