#!/usr/bin/env python3

import argparse
import os
import sys

from run import check_call, executable, parse_memory_string, parse_time_string


class Runner:
    """Ground an ASNets policy with pddl-asnets, then verify it with police.

    Both tools are launched through ``check_call`` (imported from the ``run``
    module) under the same time and memory limits.  The intermediate file
    names below are relative paths, so they are resolved against the current
    working directory.
    """

    def __init__(
        self,
        domain: str,
        problem: str,
        policy: str,
        start: str,
        time_limit: int,
        memory_limit: int,
        police: str,
        cpddl: str,
        args: list[str],
        cleanup: bool,
    ):
        """Store the configuration for one ground-and-verify run.

        Args:
            domain: Path to the PDDL domain file.
            problem: Path to the PDDL problem file.
            policy: Path to the ASNets policy to ground.
            start: Path to the start-condition file (passed as ``--pddl-init``).
            time_limit: Limit forwarded to ``check_call`` for each tool.
                ``main`` passes ``None`` when ``--time-limit`` is omitted;
                presumably ``check_call`` treats that as "no limit" (confirm).
            memory_limit: Like ``time_limit``, but for memory.
            police: Path to the police executable.
            cpddl: Path to the pddl-asnets executable.
            args: Extra arguments appended verbatim to the police command.
            cleanup: Whether to delete the intermediate files after a
                successful run.
        """
        self.domain = domain
        self.problem = problem
        self.policy = policy
        self.start = start
        self.time_limit = time_limit
        self.memory_limit = memory_limit
        self.police = police
        self.cpddl = cpddl
        self.args = args
        # Intermediate files shared between the two steps (cwd-relative).
        self.ground_policy = "asnets.cg"
        self.output_sas = "output.sas"
        self.policy_interface = "asnets.jani2nnet"
        self.cleanup = cleanup

    def _ground_policy(self):
        """Run ``<cpddl> ground`` to write the grounded policy to ``self.ground_policy``."""
        check_call(
            [
                self.cpddl,
                "ground",
                self.policy,
                self.domain,
                self.problem,
                self.ground_policy,
            ],
            time_limit=self.time_limit,
            memory_limit=self.memory_limit,
        )

    def _verify_policy(self):
        """Run police on the SAS task with the grounded policy and adapter.

        ``self.args`` is appended after the fixed options.
        """
        check_call(
            [
                self.police,
                "--sas",
                self.output_sas,
                "--pddl-init",
                self.start,
                "--policy",
                self.ground_policy,
                "--policy-adapter",
                self.policy_interface,
            ]
            + self.args,
            time_limit=self.time_limit,
            memory_limit=self.memory_limit,
        )

    def _cleanup(self):
        """Delete the intermediate files if cleanup is enabled.

        Files that do not exist are skipped.  Which tool writes ``output.sas``
        and ``asnets.jani2nnet`` is not visible here, so one missing file must
        not abort removal of the others.  It also must not turn a successful
        verification into a crash.
        """
        if not self.cleanup:
            return
        for path in (self.ground_policy, self.output_sas, self.policy_interface):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass  # never created or already gone: nothing to remove

    def run(self):
        """Ground the policy, verify it, then clean up.

        If a step raises, the later steps, including cleanup, are skipped.
        The intermediate files are then left in place for inspection.
        (Presumably ``check_call`` raises on a non-zero exit status, like
        ``subprocess.check_call``; confirm.)
        """
        self._ground_policy()
        self._verify_policy()
        self._cleanup()


def main():
    """Command-line entry point: parse arguments and run the pipeline.

    Command-line layout::

        <script> DOMAIN PROBLEM START POLICY [options] [--args POLICE_ARG ...]

    Every token after the first bare ``--args`` is forwarded verbatim to the
    police executable, so police's own flags never reach argparse.  Exits with
    status 1 if the police or pddl-asnets executable path does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("domain", help="Path to domain.pddl")
    parser.add_argument("problem", help="Path to problem.pddl")
    parser.add_argument("start", help="Path to start condition")
    parser.add_argument("policy", help="Path to asnets policy")
    parser.add_argument("--time-limit", type=parse_time_string, help="Time limit")
    parser.add_argument("--memory-limit", type=parse_memory_string, help="Memory limit")
    parser.add_argument(
        "--police", default=executable, help="Path to police executable"
    )
    parser.add_argument(
        "--cpddl",
        default=os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "ext",
            "cpddl",
            "bin",
            "pddl-asnets",
        ),
        help="Path to pddl-asnets",
    )
    # Declared mainly so it appears in --help.  The usual "--args a b c" form
    # is split off below before argparse runs.  argparse itself only sees the
    # "--args=a" and abbreviated ("--arg a") spellings.
    parser.add_argument(
        "--args",
        default=[],
        nargs="*",
        help="Arguments to be passed on to the police executable",
    )
    parser.add_argument(
        "--no-cleanup",
        action="store_true",
        default=False,
        help="Keep intermediate files instead of deleting them after the run",
    )

    # Split argv at the first bare "--args": everything before it is ours,
    # everything after it belongs to police and must not be parsed here.
    argv = sys.argv[1:]
    try:
        split = argv.index("--args")
    except ValueError:
        own_argv, police_args = argv, []
    else:
        own_argv, police_args = argv[:split], argv[split + 1 :]
    args = parser.parse_args(own_argv)
    # Anything argparse collected for --args itself (e.g. "--args=x") used to
    # be silently dropped; forward it to police as well.
    police_args = args.args + police_args

    if not os.path.exists(args.police):
        print("Cannot access police executable at", args.police, file=sys.stderr)
        sys.exit(1)

    if not os.path.exists(args.cpddl):
        print("Cannot access pddl-asnets executable at", args.cpddl, file=sys.stderr)
        sys.exit(1)

    runner = Runner(
        domain=args.domain,
        problem=args.problem,
        policy=args.policy,
        start=args.start,
        time_limit=args.time_limit,
        memory_limit=args.memory_limit,
        police=args.police,
        cpddl=args.cpddl,
        args=police_args,
        cleanup=not args.no_cleanup,
    )

    runner.run()


if __name__ == "__main__":
    # Run the command-line pipeline only when executed as a script, not on import.
    main()
