#!/usr/bin/env python3

import argparse
import os
import shlex
import socket
import sys
from datetime import datetime

from run import check_call as check_call_
from run import executable, parse_memory_string, parse_time_string


def check_call(cmd, *args, **kwargs):
    """Print a banner describing *cmd*, then run it via ``run.check_call``.

    The banner records the current date, the host name and the shell-quoted
    command line. Extra positional/keyword arguments (e.g. ``time_limit`` and
    ``memory_limit`` as used by ``Runner``) are forwarded unchanged to
    ``run.check_call``.
    """
    print("############################################")
    print("### Date", datetime.today().strftime("%Y-%m-%d %H:%M:%S"))
    print("### Host", socket.gethostname())
    # shlex.join quotes each argument so the line can be pasted into a shell.
    print("### Command", shlex.join(cmd))
    print("")
    # Flush before running the command: when stdout is redirected to a file or
    # pipe Python buffers it, while a child process writes straight to the
    # shared file descriptor, so the banner could otherwise appear after the
    # command's own output. (Assumes run.check_call spawns a subprocess that
    # inherits our stdout — TODO confirm.)
    sys.stdout.flush()
    check_call_(cmd, *args, **kwargs)
    print("")


class Runner:
    """Ground an ASNets policy for one JANI instance and verify it with police.

    The pipeline (see ``run``) writes its intermediate artifacts under fixed
    relative file names, i.e. into the current working directory, so two
    runners must not share a working directory at the same time.
    """

    def __init__(
        self,
        domain: str,
        jani: str,
        properties: str,
        policy: str,
        time_limit: int,
        memory_limit: int,
        police: str,
        cpddl: str,
        args: list[str],
        cleanup: bool,
    ):
        """Store the run configuration.

        Args:
            domain: path to the PDDL domain file.
            jani: path to the JANI instance; the path also selects the
                domain-specific translator (see ``_generate_pddl_interface``).
            properties: path to the additional JANI properties file.
            policy: path to the ASNets policy passed to pddl-asnets.
            time_limit: forwarded to the grounding and verification calls
                (``main`` passes None when ``--time-limit`` is not given).
            memory_limit: forwarded likewise (None when not given).
            police: path to the police executable; its directory must
                contain ``utils/police-<domain>`` translators.
            cpddl: path to the pddl-asnets executable.
            args: extra arguments appended to the police command line.
            cleanup: whether to delete the intermediate files afterwards.
        """
        self.domain = domain
        self.jani = jani
        self.policy = policy
        self.properties = properties
        self.time_limit = time_limit
        self.memory_limit = memory_limit
        self.police = police
        self.cpddl = cpddl
        self.args = args
        # Intermediate artifacts, relative to the current working directory.
        self.ground_policy = "asnets.cg"
        self.policy_interface = "asnets.jani2nnet"
        self.problem_pddl = "problem.pddl"
        self.problem_jani2pddl = "problem.jani2pddl"
        self.cleanup = cleanup

    def _generate_pddl_interface(self):
        """Translate the JANI instance into a PDDL problem plus mapping files.

        The ``police-<name>`` translator is chosen by a substring match on the
        whole JANI path (directory names count too).

        Raises:
            ValueError: if no known domain name occurs in the JANI path.
        """
        if "blocks" in self.jani:
            exe = "blocks"
        elif "npuzzle" in self.jani:
            exe = "npuzzle"
        elif "transport" in self.jani or "linetrack" in self.jani:
            # Transport instances are handled by the linetrack translator.
            exe = "linetrack"
        else:
            raise ValueError("could not determine domain")
        check_call(
            [
                os.path.join(os.path.dirname(self.police), "utils", f"police-{exe}"),
                "--jani",
                self.jani,
                "--jani-additional-properties",
                self.properties,
                "--out-pddl",
                self.problem_pddl,
                "--out-jani2pddl",
                self.problem_jani2pddl,
                "--out-jani2nnet",
                self.policy_interface,
            ]
        )

    def _ground_policy(self):
        """Ground the ASNets policy on the generated PDDL problem via pddl-asnets."""
        check_call(
            [
                self.cpddl,
                "ground",
                self.policy,
                self.domain,
                self.problem_pddl,
                self.problem_jani2pddl,
                self.ground_policy,
            ],
            time_limit=self.time_limit,
            memory_limit=self.memory_limit,
        )

    def _verify_policy(self):
        """Run police on the JANI instance with the grounded policy and adapter."""
        check_call(
            [
                self.police,
                "--jani",
                self.jani,
                "--jani-additional-properties",
                self.properties,
                "--policy",
                self.ground_policy,
                "--policy-adapter",
                self.policy_interface,
            ]
            + self.args,
            time_limit=self.time_limit,
            memory_limit=self.memory_limit,
        )

    def _cleanup(self):
        """Delete the intermediate files if cleanup was requested.

        Files that were never produced (e.g. because an earlier step failed)
        are skipped rather than raising, so cleanup is safe after a failure.
        """
        if not self.cleanup:
            return
        for path in (
            self.ground_policy,
            self.policy_interface,
            self.problem_pddl,
            self.problem_jani2pddl,
        ):
            try:
                os.remove(path)
            except FileNotFoundError:
                # Deliberately best-effort: a missing file needs no removal.
                pass

    def run(self):
        """Run the pipeline: translate the instance, ground, then verify.

        Cleanup runs even if a step raises, so a failed run does not leave
        stale intermediate files behind. Use ``cleanup=False`` to keep them.
        """
        try:
            self._generate_pddl_interface()
            self._ground_policy()
            self._verify_policy()
        finally:
            self._cleanup()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("domain", help="Path to domain.pddl")
    parser.add_argument("jani", help="Path to jani instance")
    parser.add_argument("properties", help="Path to properties specification")
    parser.add_argument("policy", help="Path to asnets policy")
    parser.add_argument("--time-limit", type=parse_time_string, help="Time limit")
    parser.add_argument("--memory-limit", type=parse_memory_string, help="Memory limit")
    parser.add_argument(
        "--police", default=executable, help="Path to police executable"
    )
    parser.add_argument(
        "--cpddl",
        default=os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "ext",
            "cpddl",
            "bin",
            "pddl-asnets",
        ),
        help="Path to pddl-asnets",
    )
    parser.add_argument(
        "--args",
        default=[],
        nargs="*",
        help="Arguments to be passed on to the police executable",
    )
    parser.add_argument("--no-cleanup", action="store_true", default=False)
    sys_args = []
    for i, arg in enumerate(sys.argv):
        if arg == "--args":
            break
        sys_args.append(arg)
    args = parser.parse_args(sys_args[1:])

    if not os.path.exists(args.police):
        print("Cannot access police executable at", args.police)
        sys.exit(1)

    if not os.path.exists(args.cpddl):
        print("Cannot access pddl-asnets executable at", args.cpddl)
        sys.exit(1)

    runner = Runner(
        args.domain,
        args.jani,
        args.properties,
        args.policy,
        args.time_limit,
        args.memory_limit,
        args.police,
        args.cpddl,
        sys.argv[i + 1 :],
        not args.no_cleanup,
    )

    runner.run()


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
