from rlutils.rewards.interfaces import OracleInterface
from ..features.fastjet import *


def oracle(s, a, ns):
    """Return the oracle reward for the transition ``(s, a, ns)``.

    The reward is the negation of a weighted penalty sum:
    distance (weight 1), closing speed (weight 0.05) and up-error
    (weight 10). Larger penalty terms therefore give a more negative reward.

    ``dist``, ``closing_speed`` and ``up_error`` are not defined in this
    module. Presumably they come from the star import of
    ``..features.fastjet``; verify there for their exact semantics.

    Args:
        s: Current state (type determined by the feature functions).
        a: Action taken.
        ns: Next state.

    Returns:
        The negated weighted penalty, in whatever numeric type the feature
        functions return.
    """
    distance = dist(s, a, ns, None)
    # The already-computed distance is handed to closing_speed under the
    # "dist" key. Presumably this lets it reuse the value instead of
    # recomputing it (assumption: the 4th argument is a precomputed-feature
    # mapping or None).
    closing = closing_speed(s, a, ns, {"dist": distance})
    upright = up_error(s, a, ns, None)
    penalty = distance + 0.05 * closing + 10. * upright
    return -penalty

# Run configuration, keyed by method name. "pbrl" presumably stands for
# preference-based RL (assumption from the name; confirm with the runner
# that consumes P).
#   interface          -> OracleInterface class plus the `oracle` reward
#                         function defined above
#   save_path          -> directory for trained models of the "follow" task
#   offline_graph_path -> path to a pre-built offline dataset graph file
P = {
    "pbrl": dict(
        interface={"class": OracleInterface, "oracle": oracle},
        save_path="trained_models/follow",
        offline_graph_path="offline_datasets/follow/200e_1000p.graph",
    ),
}
