import os
import argparse
import rootutils

from src.utils.validation_utils_imgs import examine_model, extract_path

# Set up the project root. os.path.abspath("") resolves to the current working
# directory, so the root search starts from wherever the script is launched.
# pythonpath=True asks rootutils to add the detected root to the import path.
# NOTE(review): the `src` import above executes before this call, so it
# presumably only succeeds when the project root is already importable (e.g.
# the script is run from the root) — confirm whether the order should change.
rootutils.setup_root(os.path.abspath(""), pythonpath=True)

def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the validation script.

    Building the parser is kept separate from parsing so that importing this
    module never touches ``sys.argv``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Run model validation.")

    parser.add_argument("--name", type=str, required=True, help="Name of the experiment.")
    parser.add_argument("--date", type=str, default="latest", help="Date of the experiment in format YYYY-MM-DD_hh-mm-ss or 'latest' for latest date.")
    parser.add_argument("--NFEs", type=int, nargs="+", default=[5, 35, 100], help="List of NFE values to test")
    parser.add_argument("--save_images", type=int, default=100, help="How many images to save during validation")
    parser.add_argument("--bsz_limit", type=int, default=15, help="Batch size limit during validation")
    parser.add_argument("--bsz", type=int, default=0, help="Batch size during validation (0 for the default from the checkpoint)")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device name")
    parser.add_argument("--checkpoint_name", type=str, default="best", help="Checkpoint name (either best or last)")
    parser.add_argument("--ds_root", type=str, default="null", help="Path to dataset root directory")

    return parser


def build_examine_kwargs(args: argparse.Namespace) -> dict:
    """Translate parsed CLI arguments into keyword arguments for ``examine_model``.

    The CLI uses sentinel values for "not overridden": ``--bsz 0`` and
    ``--ds_root null``. Both are converted to ``None`` here; every other
    option is forwarded unchanged.

    Args:
        args: Namespace produced by the parser from ``build_parser``.

    Returns:
        A dict of keyword arguments (everything except the experiment path).
    """
    return {
        "NFEs": args.NFEs,
        "bsz_limit": args.bsz_limit,
        "save_images": args.save_images,
        "examined_checkpoint": args.checkpoint_name,
        "device": args.device,
        # 0 is the CLI sentinel for "keep the batch size from the checkpoint".
        "new_bsz": None if args.bsz == 0 else args.bsz,
        # The literal string 'null' is the CLI sentinel for "no dataset root given".
        "ds_root": None if args.ds_root == "null" else args.ds_root,
    }


def main(argv=None) -> None:
    """Parse command-line arguments and run model validation.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    experiment_root_path = extract_path(args.name, args.date)
    examine_model(experiment_root_path, **build_examine_kwargs(args))


if __name__ == "__main__":
    # Arguments are parsed only when executed as a script, so importing this
    # module has no command-line side effects.
    main()
