from expground.types import Dict, Any

from expground.algorithms.base_trainer import Trainer
from expground.algorithms.loss_func import LossFunc
from expground.algorithms.base_policy import Policy


class RandomTrainer(Trainer):
    """Trainer whose hooks around the loss computation are no-ops.

    ``_before_loss`` hands the incoming batch back unchanged, paired with a
    freshly created empty dict, and ``_after_loss`` performs no work. All
    construction is delegated to the base ``Trainer``.

    NOTE(review): the name suggests this is meant for a non-learning (random)
    policy; that is an assumption, so confirm it against the callers.
    """

    def __init__(
        self,
        loss_func: LossFunc,
        training_config: Dict[str, Any] = None,
        policy_instance: Policy = None,
    ):
        # This subclass adds no state of its own, so forward every argument
        # to the base Trainer exactly as received.
        super().__init__(
            loss_func,
            training_config=training_config,
            policy_instance=policy_instance,
        )

    def _before_loss(self, policy, batch):
        """Return ``batch`` unchanged together with a new, empty dict.

        ``policy`` is accepted for interface compatibility but not used.
        """
        extras = {}  # a new dict on every call; never shared between calls
        return batch, extras

    def _after_loss(self, policy, step_counter: int):
        """Do nothing after the loss step; both arguments are ignored."""
        return None
