from utilities.Logger import Logger
from utilities.eval_speed import *
from utilities.GLOBAL_VALUE import SEEDS


def _select_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, then CPU.

    ``torch.backends.mps`` is looked up with ``getattr`` because PyTorch builds
    that predate MPS support (added in 1.12) do not have that attribute. Direct
    attribute access would raise ``AttributeError`` there instead of falling
    back to the CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def main():
    """Run the speed training job selected by the arguments from ``get_args()``.

    Picks a device (CUDA > MPS > CPU) and prints it. It then takes the first
    ``args["num_seed"]`` entries of ``SEEDS`` and dispatches on ``args["job"]``:
    ``"cl"`` calls ``train_cl_speed``, ``"mtl"`` calls ``train_mtl_speed``.
    Both are called as ``(args, device, seeds)``.

    Raises:
        ValueError: if ``args["num_seed"]`` exceeds ``len(SEEDS)`` or is less
            than 1, or if ``args["job"]`` is neither ``"cl"`` nor ``"mtl"``.
    """
    args = get_args()

    device = _select_device()
    # ANSI escape codes print the device line in green.
    print("\033[92m" + f"Device: {device}" + "\033[0m")

    num_seed = args["num_seed"]
    if num_seed > len(SEEDS):
        raise ValueError(
            "Number of seeds specified is greater than the seed list, append more custom seed into the list"
        )
    # Reject non-positive counts explicitly. Otherwise SEEDS[:0] yields no seeds
    # at all, and SEEDS[:-k] silently drops the last k seeds.
    if num_seed < 1:
        raise ValueError(f"Number of seeds must be at least 1, got {num_seed}.")
    seeds = SEEDS[:num_seed]

    if args["job"] == "cl":
        train_cl_speed(args, device, seeds)
    elif args["job"] == "mtl":
        train_mtl_speed(args, device, seeds)
    else:
        raise ValueError("Job not found. Please specify either 'cl' or 'mtl'.")

def visualize(*, num_tasks=3, seed=42):
    """Call ``visualize_speed_results`` once per task index.

    Keyword-only arguments (the defaults reproduce the previously hard-coded
    behavior):
        num_tasks: number of task indices to visualize, ``0 .. num_tasks - 1``.
            Defaults to 3.
        seed: value passed as the second positional argument of
            ``prepare_data`` (presumably the data-split seed; TODO confirm).
            Defaults to 42.
    """
    args = get_args()
    # Only the fourth return value (prediction_targets) is used here; the rest
    # are discarded. The third argument stays None as before (NOTE(review):
    # possibly the device slot; confirm against prepare_data).
    _, _, _, prediction_targets = prepare_data(args, seed, None, return_datasets=True)
    for task_index in range(num_tasks):
        visualize_speed_results(args, task_index=task_index, prediction_targets=prediction_targets)


if __name__ == "__main__":
    # Launch the training entry point only when run as a script, not on import.
    main()

