"""The trainer defined in this file allows for overwriting rewards with the output from a RewardModel."""
from typing import Optional, Type

from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.agents.ppo.ppo import (
    DEFAULT_CONFIG as PPO_DEFAULT_CONFIG,
    execution_plan,
    get_policy_class,
    validate_config,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict

from offline_rl.agents.reward_overwriting.reward_overwriting_policy import RewardOverwritingPolicy

# Start from a copy of PPO's defaults. A plain `DEFAULT_CONFIG = PPO_DEFAULT_CONFIG`
# would only alias the dict, so the keys added below would also be written into
# PPO's own module-level default config. A shallow copy is sufficient here because
# only new top-level keys are added; no nested PPO value is modified.
DEFAULT_CONFIG = PPO_DEFAULT_CONFIG.copy()
# Default reward-model section: a single "learned" entry with an empty settings dict.
DEFAULT_CONFIG["reward_models"] = {
    "learned": {},
}
# If set to True, effectively prevents training from having an impact.
# Used for evaluating a policy while still using RLlib's parallelized data collection.
DEFAULT_CONFIG["do_not_update_model"] = False


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    """Return the policy class the trainer should use.

    Args:
        config: The trainer configuration. It is accepted to match the
            callback signature but is not consulted.

    Returns:
        Always ``RewardOverwritingPolicy``.
    """
    policy_cls = RewardOverwritingPolicy
    return policy_cls


# Assemble the trainer class: config validation and the execution plan are the
# ones imported from PPO, while the policy and the default config are the
# reward-overwriting variants defined/imported in this module.
RewardOverwritingTrainer = build_trainer(
    name="RewardOverwriting",
    default_policy=RewardOverwritingPolicy,
    get_policy_class=get_policy_class,
    default_config=DEFAULT_CONFIG,
    validate_config=validate_config,
    execution_plan=execution_plan,
)
# Register "reward_models" as a key whose sub-keys are not checked, which allows
# for defining arbitrary reward models in the configuration file.
# NOTE(review): `append` mutates the list in place; if `_allow_unknown_subkeys`
# is inherited from a base class, that list is shared with other trainers - confirm
# this is intended.
RewardOverwritingTrainer._allow_unknown_subkeys.append("reward_models")
