from utilities.Logger import Logger
from utilities.train import *
from utilities.GLOBAL_VALUE import SEEDS


def _select_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU.

    ``torch.backends.mps`` is looked up with ``getattr`` because it may not
    exist on older torch builds; in that case the MPS option is skipped
    instead of raising ``AttributeError``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _select_seeds(num_seed, seed_pool):
    """Return the first ``num_seed`` entries of ``seed_pool``.

    Args:
        num_seed: How many seeds to use; must be between 1 and
            ``len(seed_pool)`` inclusive.
        seed_pool: Sequence of available seeds (the caller passes ``SEEDS``).

    Returns:
        A slice of ``seed_pool`` of length ``num_seed``.

    Raises:
        ValueError: If ``num_seed`` is below 1 or above ``len(seed_pool)``.
    """
    # Without this guard a non-positive value slices to an empty list (0) or
    # silently drops seeds from the end (negative), hiding a config mistake.
    if num_seed < 1:
        raise ValueError(f"Number of seeds must be at least 1, got {num_seed}.")
    if num_seed > len(seed_pool):
        raise ValueError(
            "Number of seeds specified is greater than the seed list, append more custom seed into the list"
        )
    return seed_pool[:num_seed]


def main():
    """Set up a run and dispatch it to the requested training job.

    Reads the run configuration from ``get_args()`` and builds a ``Logger``
    from it. It then picks a device (CUDA > MPS > CPU), takes the first
    ``args["num_seed"]`` seeds from ``SEEDS``, and calls ``train_cl`` or
    ``train_mtl`` depending on ``args["job"]``.

    Raises:
        ValueError: If ``args["num_seed"]`` is outside ``1..len(SEEDS)``, or
            ``args["job"]`` is neither ``"cl"`` nor ``"mtl"``.
    """
    args = get_args()
    logger = Logger(args)

    device = _select_device()
    # ANSI escape codes print the device line in green.
    print("\033[92m" + f"Device: {device}" + "\033[0m")

    seeds = _select_seeds(args["num_seed"], SEEDS)
    # Prediction targets are defined in the prepare_dataset function, not here.

    if args["job"] == "cl":
        train_cl(args, device, seeds, logger)
    elif args["job"] == "mtl":
        train_mtl(args, device, seeds, logger)
    else:
        raise ValueError("Job not found. Please specify either 'cl' or 'mtl'.")


if __name__ == "__main__":
    # Run the launcher only when this file is executed directly, not on import.
    # main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())
