import os
import sys
from statistics import stdev, variance, median, mean
import time
import numpy as np
sys.path.append('./simulation')

from run_baselines import run_msm, run_rmsn, run_crn, run_gnet, run_ct, run_edct, run_edts

def main(args):
    """Parse ``label=value`` command-line arguments and run one benchmark.

    Expected form: ``prog model=<name> coeff=<int> [gpu=<int>] [seed=<int>]``.

    Args:
        args: argv-style list; ``args[0]`` (the program name) is skipped.

    Returns:
        None. Invalid or missing arguments are reported with ``print`` and the
        function returns early without running anything.
    """
    if len(args) < 3:
        print('Arguments are too short')
        return

    model = coeff = gpu = seed = None
    for arg in args[1:]:
        label, sep, value = arg.partition('=')
        if not sep:
            # Without this check a bare token such as "model" would be split
            # at find('=') == -1 into a meaningless label.
            print(f"{arg} is not of the form label=value")
            return
        try:
            if label == "model":
                model = value
            elif label == "coeff":
                coeff = int(value)
            elif label == "gpu":
                gpu = int(value)
            elif label == "seed":
                seed = int(value)
            else:
                print(f"{label} not in variables")
                return
        except ValueError:
            # Report bad integers like every other argument problem instead
            # of aborting with a traceback.
            print(f"{label} must be an integer, got {value!r}")
            return

    if model is None or coeff is None:
        print("missing parameters")
        return

    if gpu is None:
        gpu = 0

    if model not in ["msm", "crn", "gnet", "ct", "edct", "edts", "rmsn"]:
        print("model Error")
        return
    if coeff not in [0, 1, 2, 3, 4]:
        print("Coeff Error")
        return
    if gpu not in [0, 1]:
        print("gpu Error")
        return

    # NOTE(review): `no` is given the same value as `coeff`; confirm intended.
    # NOTE(review): when no seed is given, None is passed explicitly, so the
    # seed=10 default of run_seed_set never applies; confirm runners accept None.
    run_benchmark(gpu   = gpu,
                  model = model,
                  coeff = coeff,
                  no    = coeff,
                  seed  = seed)

def run_benchmark(gpu, model, coeff, no, seed):
    """Run one benchmark through run_seed_set and print how long it took.

    The duration is printed as HH:MM:SS; hours are not wrapped at 24.

    Args:
        gpu, model, coeff, no, seed: forwarded unchanged to run_seed_set.

    Returns:
        Whatever run_seed_set returns.
    """
    # A monotonic clock is unaffected by system clock updates, so the
    # measured duration cannot be skewed (or go negative) mid-run the way
    # a time.time() difference can.
    start = time.monotonic()

    result = run_seed_set(gpu, model, coeff, no, seed)

    elapsed = int(time.monotonic() - start)
    hours, remainder = divmod(elapsed, 3600)
    minutes, seconds = divmod(remainder, 60)

    print(f"{hours:02d}:{minutes:02d}:{seconds:02d}")

    return result

# ------------------------------------------------------------------------------------------
# run seed
# ------------------------------------------------------------------------------------------
def run_seed_set(gpu, model, coeff, no, seed = 10):
    """Run one baseline model, print its RMSE summary and return it.

    Args:
        gpu: passed as the first positional argument to every runner except
            "msm", whose runner is called without it.
        model: "crn", "rmsn", "edct" or "edts" (runners that return encoder
            and decoder results) or "gnet", "ct" or "msm" (encoder results only).
        coeff: forwarded unchanged to the runner.
        no: forwarded as ``no`` to every runner except "msm".
        seed: forwarded unchanged to the runner.

    Returns:
        dict with "encoder_valid_one" (list) and "encoder_test_multi"
        (dict step -> list). Encoder-decoder models additionally get
        "decoder_valid_multi" and "decoder_test_multi" (dict step -> list).
        Steps 1-6 are always present; any other step reported by a runner is
        added. Each list holds the value(s) produced by this run.

    Raises:
        NotImplementedError: if ``model`` is not one of the names above.
    """
    # Runners returning (encoder_rmse, decoder_rmse, encoder_loss, decoder_loss).
    encoder_decoder_runners = {"crn": run_crn, "rmsn": run_rmsn,
                               "edct": run_edct, "edts": run_edts}
    # Runners returning (encoder_rmse, encoder_loss).
    encoder_only_runners = {"gnet": run_gnet, "ct": run_ct}

    decoder_rmse = None
    if model in encoder_decoder_runners:
        # Losses are not reported; only the RMSE figures are kept.
        encoder_rmse, decoder_rmse, _, _ = encoder_decoder_runners[model](
            gpu, no = no, coeff = coeff, seed = seed)
    elif model in encoder_only_runners:
        encoder_rmse, _ = encoder_only_runners[model](
            gpu, no = no, coeff = coeff, seed = seed)
    elif model == "msm":
        encoder_rmse = run_msm(coeff = coeff, seed = seed)
    else:
        raise NotImplementedError()

    display_result(encoder_rmse, decoder_rmse)

    def record(per_step_lists, per_step_values):
        # Append each step's value, creating the list for unexpected steps.
        for tau, value in per_step_values.items():
            per_step_lists.setdefault(tau, []).append(value)

    steps = range(1, 7)
    result = {"encoder_valid_one" : [],
              "encoder_test_multi": {tau: [] for tau in steps}}
    if model in encoder_decoder_runners:
        result["decoder_valid_multi"] = {tau: [] for tau in steps}
        result["decoder_test_multi"] = {tau: [] for tau in steps}

    # Fill the result from exactly the keys display_result has just read, so
    # collecting the result adds no new way for this function to fail.
    if decoder_rmse is None:
        result["encoder_valid_one"].append(encoder_rmse["valid_one"])
        record(result["encoder_test_multi"], encoder_rmse["test_multi"])
    else:
        # NOTE(review): encoder entries of encoder-decoder models stay empty;
        # the layout of their encoder_rmse is not visible from this file.
        record(result["decoder_valid_multi"], decoder_rmse["valid_multi"])
        record(result["decoder_test_multi"], decoder_rmse["test_multi"])

    return result
                          
def display_result(encoder_rmse, decoder_rmse = None):
    """Print normalized RMSE values for one run.

    Args:
        encoder_rmse: mapping with "valid_one" (a single number) and
            "test_multi" (mapping step -> number). It is only read when
            ``decoder_rmse`` is None.
        decoder_rmse: optional mapping with "valid_multi" and "test_multi",
            each a mapping step -> number. When given, only these decoder
            multi-step values are printed.
    """
    def print_steps(title, table, key):
        # Title is printed before the lookup, so a missing key still shows it.
        print(title)
        for step, rmse in table[key].items():
            print("\t{}-step: {:<.3f}".format(step, rmse))

    test_title = "[Normalized RMSE (test) for multi-step]"

    if decoder_rmse is None:
        print("[Normalized RMSE (validation) for one-step]")
        print("\t{:<.3f}".format(encoder_rmse["valid_one"]))
        print_steps(test_title, encoder_rmse, "test_multi")
        return

    print_steps("[Normalized RMSE (validation) for multi-step]", decoder_rmse, "valid_multi")
    print_steps(test_title, decoder_rmse, "test_multi")
            
if __name__ == '__main__':
    # Hand the raw argv (program name included) to the argument parser.
    main(sys.argv)
