"""
Run an evaluation procedure on a trained model: log its inferences on a
validation dataset (the real domain by default). For example:

python eval_trainer.py \
    --resume_path some/path \
    --val_r_json some/other/path \
    -m -a --tasks m s d
"""
print("Imports...", end="")

import os
import sys

# The default --val_r_json path is built from this group name, so the script
# refuses to run without it.
GROUP = os.environ.get("CLIMATEGAN_GROUP")
if GROUP is None:
    # Terminate the pending "Imports..." line (printed with end="") first.
    print()
    print("CLIMATEGAN_GROUP is unknown. Please set env variable", file=sys.stderr)
    # Non-zero status so callers (shells, job schedulers) see the failure.
    sys.exit(1)

from argparse import ArgumentParser
from pathlib import Path

from addict import Dict
from comet_ml import Experiment  # noqa: F401 -> keep even if unused

from climategan.data import get_loader
from climategan.trainer import Trainer
from climategan.utils import flatten_opts

print("Ok.")


def parsed_args(argv=None):
    """Parse and return the command-line arguments.

    Args:
        argv (list[str], optional): arguments to parse instead of
            ``sys.argv[1:]``. Defaults to None, i.e. the real command line.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--resume_path", required=True, type=str, help="Path to the trainer to resume"
    )
    parser.add_argument(
        "--image_domain",
        default="r",
        type=str,
        # Same set the script validates after parsing; rejecting anything else
        # here yields a clear usage error instead of a later failure.
        choices=["r", "s", "rf", "kitti"],
        help="Domain of the images to evaluate: 'r', 's', 'rf' or 'kitti'",
    )
    parser.add_argument(
        "--val_r_json",
        # GROUP is read from the CLIMATEGAN_GROUP environment variable at
        # module import time.
        default=f"/network/tmp1/{GROUP}/data/climategan/base/"
        + "val_r_full_with_labelbox.json",
        type=str,
        help="The json file where you want to evaluate for real domain.",
    )
    parser.add_argument(
        "-t",
        "--tasks",
        nargs="+",
        help="list of tasks to eval. eg: `-t m s`",
        default=["m"],
    )
    parser.add_argument(
        "-m",
        "--minimal",
        action="store_true",
        default=False,
        help="Only log smooth mask",
    )
    parser.add_argument(
        "-a",
        "--all_only",
        action="store_true",
        default=False,
        help="Passed as the `all_only` flag to the logger's log_comet_images",
    )

    return parser.parse_args(argv)


if __name__ == "__main__":
    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------

    args = parsed_args()
    print("Args:\n" + "\n".join([f"    {k:20}: {v}" for k, v in vars(args).items()]))

    # Validate user input explicitly: `assert` statements are stripped when
    # Python runs with -O, which would let bad values through silently.
    resume_path = Path(args.resume_path).expanduser().resolve()
    if not resume_path.exists():
        sys.exit(f"--resume_path does not exist: {resume_path}")

    image_domain = args.image_domain
    valid_domains = {"r", "s", "rf", "kitti"}
    if image_domain not in valid_domains:
        sys.exit(
            f"--image_domain must be one of {sorted(valid_domains)}, "
            f"got {image_domain!r}"
        )

    # ---------------------------------
    # -----  Build opts overrides  -----
    # ---------------------------------

    overrides = Dict()
    overrides.data.loaders.batch_size = 1
    overrides.comet.rows_per_log = 1
    overrides.tasks = args.tasks
    # An empty --val_r_json keeps the val file from the trainer's own opts.
    # NOTE(review): the json is registered for `image_domain`, whatever it is,
    # although the option name and its default target the real domain "r".
    if args.val_r_json:
        val_r_json_path = Path(args.val_r_json).expanduser().resolve()
        if not val_r_json_path.exists():
            sys.exit(f"--val_r_json does not exist: {val_r_json_path}")
        overrides.data.files.val[image_domain] = str(val_r_json_path)

    # ------------------------------------
    # -----  Resume trainer and log  -----
    # ------------------------------------

    trainer = Trainer.resume_from_path(
        resume_path, overrides=overrides, inference=True, new_exp=True
    )
    trainer.exp.log_parameters(flatten_opts(trainer.opts))
    # Replace all loaders with a single validation loader for the requested
    # domain: it is the only one the logging call below uses.
    trainer.all_loaders = {
        "val": {image_domain: get_loader("val", image_domain, trainer.opts)}
    }
    trainer.set_display_images(True)
    trainer.logger.log_comet_images("val", image_domain, args.minimal, args.all_only)
