#!/usr/bin/env python3

# Acknowledgement: parts of this script are taken from Fast Downward's driver
# script (https://fast-downward.org).

import os
import resource
import socket
import subprocess
import sys

# Absolute path of the binary this script wraps: "build/police" inside the
# directory that contains this script (independent of the working directory).
executable = os.path.join(os.path.dirname(os.path.abspath(__file__)), "build", "police")


def set_time_limit(time_limit):
    """Restrict the CPU time of the current process via ``RLIMIT_CPU``.

    *time_limit* is a number of CPU seconds used as the soft limit, or None
    to leave the limit unchanged. In this script it is invoked from the
    ``preexec_fn`` hook, i.e. in the forked child right before the wrapped
    executable starts.

    Raises ValueError if even the fallback limit cannot be installed.
    """
    if time_limit is not None:
        # Soft limit reached -> SIGXCPU, which the program may catch to shut
        # down gracefully. Hard limit reached -> SIGKILL, which cannot be
        # caught. The preferred setting leaves one second between the two so
        # the process still dies if the graceful shutdown hangs or is slow.
        preferred, fallback = (
            (time_limit, time_limit + 1),
            (time_limit, time_limit),
        )
        try:
            resource.setrlimit(resource.RLIMIT_CPU, preferred)
        except ValueError:
            # The extra second is not always grantable, e.g. when an external
            # hard limit equal to time_limit already exists, so settle for
            # soft == hard.
            resource.setrlimit(resource.RLIMIT_CPU, fallback)


def set_memory_limit(memory):
    """Cap the address space of the current process via ``RLIMIT_AS``.

    *memory* is a size in bytes, installed as both the soft and the hard
    limit, or None to leave the limit unchanged.
    """
    if memory is not None:
        bounds = (memory, memory)
        resource.setrlimit(resource.RLIMIT_AS, bounds)


def _get_preexec_function(time_limit, memory_limit):
    """Build the zero-argument hook passed to ``subprocess`` as ``preexec_fn``.

    The returned callable applies the CPU-time limit and then the memory
    limit (either may be None). If applying a limit raises, the error is
    printed to stderr and ``sys.exit(1)`` is called.

    NOTE(review): the hook runs in the forked child before exec, so
    ``sys.exit`` there raises SystemExit inside the child. CPython's
    subprocess appears to report that to the parent as an exception
    (SubprocessError) rather than as a child exit status of 1. Confirm
    before relying on the exit code.
    """
    # Applied in this order: time limit first, then memory limit.
    appliers = (
        lambda: set_time_limit(time_limit),
        lambda: set_memory_limit(memory_limit),
    )

    def set_limits():
        for apply_limit in appliers:
            try:
                apply_limit()
            except Exception as err:
                print(err, file=sys.stderr)
                sys.exit(1)

    return set_limits


def check_call(cmd, time_limit=None, memory_limit=None):
    """Run *cmd* under optional CPU-time / memory limits.

    *time_limit* is in CPU seconds and *memory_limit* in bytes; None means
    "no limit". The child gets a copy of this process's environment in which
    ``GRB_LICENSE_FILE`` defaults to ``licenses/<hostname>/gurobi.lic`` unless
    the variable is already set.

    Returns whatever ``subprocess.check_call`` returns (0 on success) and
    propagates its exceptions, e.g. CalledProcessError on a non-zero exit.
    """
    child_env = dict(os.environ)
    # Relative path: presumably resolved against the working directory,
    # not this script's directory. An existing GRB_LICENSE_FILE wins.
    child_env.setdefault(
        "GRB_LICENSE_FILE",
        os.path.join("licenses", socket.gethostname(), "gurobi.lic"),
    )

    # Flush our own buffered output first so it is not emitted after the
    # child's output.
    sys.stdout.flush()
    return subprocess.check_call(
        cmd,
        preexec_fn=_get_preexec_function(time_limit, memory_limit),
        env=child_env,
    )


def parse_time_string(tstr):
    """Convert a human-readable time limit into whole seconds.

    Accepted forms:
    - a bare number (seconds), or
    - a number followed by ``s`` (seconds), ``m`` (minutes) or ``h`` (hours).

    Examples: ``"30"``, ``"1.5s"``, ``"10m"``, ``"2h"``. Suffixes are
    case-insensitive and surrounding whitespace is ignored. Fractional
    results are truncated towards zero.

    On malformed input (including ``inf``/``nan`` and negative values),
    prints ``invalid time string`` to stderr and exits with status 1.
    """
    units = {"s": 1, "m": 60, "h": 60 * 60}
    text = tstr.strip().lower()
    factor = 1
    if text[-1:] in units:  # text[-1:] is "" for empty input, never a unit
        factor = units[text[-1]]
        text = text[:-1]

    try:
        try:
            # Integer input is scaled exactly; float arithmetic is only used
            # for fractional or exponent notation such as "1.5m" or "1e2s".
            seconds = int(text) * factor
        except ValueError:
            seconds = int(float(text) * factor)
    except (ValueError, OverflowError):
        # OverflowError: int(float("inf")) for inputs like "infm".
        seconds = None

    if seconds is None or seconds < 0:
        print("invalid time string", file=sys.stderr)
        sys.exit(1)
    return seconds


def parse_memory_string(mstr):
    """Convert a human-readable memory limit into a number of bytes.

    Accepted forms:
    - a bare number (bytes), or
    - a number followed by ``k``/``kb`` (x 1024), ``m``/``mb`` (x 1024**2)
      or ``g``/``gb`` (x 1024**3).

    Examples: ``"1048576"``, ``"512m"``, ``"1.5gb"``. Suffixes are
    case-insensitive and surrounding whitespace is ignored. Fractional byte
    counts are truncated towards zero.

    On malformed input (including ``inf``/``nan`` and negative values),
    prints ``invalid memory string`` to stderr and exits with status 1.
    """
    units = {
        "k": 1024,
        "kb": 1024,
        "m": 1024 ** 2,
        "mb": 1024 ** 2,
        "g": 1024 ** 3,
        "gb": 1024 ** 3,
    }
    text = mstr.strip().lower()
    factor = 1
    for suffix, multiplier in units.items():
        if text.endswith(suffix):
            factor = multiplier
            text = text[: -len(suffix)]
            break

    try:
        try:
            # Integer input is scaled exactly; float arithmetic is only used
            # for fractional or exponent notation such as "1.5g" or "1e3k".
            num_bytes = int(text) * factor
        except ValueError:
            num_bytes = int(float(text) * factor)
    except (ValueError, OverflowError):
        # OverflowError: int(float("inf")) for inputs like "infg".
        num_bytes = None

    if num_bytes is None or num_bytes < 0:
        print("invalid memory string", file=sys.stderr)
        sys.exit(1)
    return num_bytes


def parse_driver_args(args):
    """Split the driver's own options out of the command-line arguments.

    Recognizes ``--time-limit VALUE`` and ``--memory-limit VALUE`` anywhere
    in *args*. The token following the option is always consumed as its
    value. VALUE is stripped, lower-cased and parsed with
    ``parse_time_string`` / ``parse_memory_string``. A later occurrence
    overrides an earlier one.

    Returns ``(time_limit, memory_limit, remaining)``:
    - *time_limit* and *memory_limit* are None when the option is absent;
    - *remaining* holds every other argument, in its original order, for
      the wrapped executable.

    Prints an error and exits with status 1 if an option has no value.
    """
    time_limit = None
    memory_limit = None
    remaining = []
    tokens = iter(args)
    for token in tokens:
        if token == "--time-limit":
            value = next(tokens, None)
            if value is None:
                print("time limit expected after --time-limit", file=sys.stderr)
                sys.exit(1)
            time_limit = parse_time_string(value.strip().lower())
        elif token == "--memory-limit":
            value = next(tokens, None)
            if value is None:
                print("memory limit expected after --memory-limit", file=sys.stderr)
                sys.exit(1)
            memory_limit = parse_memory_string(value.strip().lower())
        else:
            remaining.append(token)
    return time_limit, memory_limit, remaining


if __name__ == "__main__":
    time_limit, memory_limit, args = parse_driver_args(sys.argv[1:])
    try:
        check_call([executable] + args, time_limit=time_limit, memory_limit=memory_limit)
    except subprocess.CalledProcessError as err:
        # Propagate the child's status instead of dumping a traceback.
        # subprocess reports "killed by signal N" as -N; map it to the shell
        # convention 128 + N (e.g. 137 for SIGKILL from the hard CPU limit).
        code = err.returncode
        sys.exit(128 - code if code < 0 else code)
    except subprocess.SubprocessError as err:
        # E.g. a resource limit could not be applied in the preexec hook;
        # the hook has already printed the underlying reason.
        print(err, file=sys.stderr)
        sys.exit(1)
