from expground.types import AgentID, DataArray, Dict, Any, Tuple
from expground.common.schedules import LinearSchedule
from expground.algorithms.base_trainer import Trainer
from expground.algorithms.loss_func import LossFunc
from expground.algorithms.base_policy import Policy

from .policy import SAC


class SACTrainer(Trainer):
    """Trainer specialization for the SAC policy.

    Adds two hooks around the loss computation: ``_before_loss`` unwraps the
    single agent's batch, and ``_after_loss`` updates the policy's target
    via ``policy.update_target``.

    NOTE(review): the hooks are presumably invoked by the base ``Trainer``
    before/after each loss evaluation -- confirm against ``base_trainer``.
    """

    def __init__(
        self,
        loss_func: LossFunc,
        training_config: Dict[str, Any],
        policy_instance: Policy,
    ) -> None:
        """Forward all arguments unchanged to the base ``Trainer``.

        Args:
            loss_func: Loss function object handed to the base trainer.
            training_config: Training options; ``_after_loss`` reads the
                ``"tau"`` entry from the trainer's stored config.
            policy_instance: Policy handed to the base trainer.
        """
        super().__init__(
            loss_func, training_config=training_config, policy_instance=policy_instance
        )

    def _before_loss(self, policy: SAC, batch) -> Tuple[Any, dict]:
        """Extract the only agent's data from ``batch``.

        Args:
            policy: The SAC policy being trained (not used here).
            batch: Mapping holding exactly one entry; presumably keyed by
                agent id -- verify against the caller.

        Returns:
            A 2-tuple of the single value stored in ``batch`` and an empty
            dict.

        Raises:
            AssertionError: If ``batch`` does not contain exactly one entry.
        """
        # Raised explicitly instead of via ``assert`` so the single-agent
        # guard still runs under ``python -O``; the exception type and
        # message are the same as the assert produced.
        if len(batch) != 1:
            raise AssertionError(
                "SAC supports only single agent mode now, batched agents should less than 2."
            )
        # Exactly one entry is guaranteed above, so take it without
        # materializing a list of all values.
        return next(iter(batch.values())), {}

    def _after_loss(self, policy: SAC, step_counter: int) -> None:
        """Update the policy's target using the configured ``tau``.

        Args:
            policy: The SAC policy whose ``update_target`` is called.
            step_counter: Current step count (not used here).

        Raises:
            KeyError: If the training config has no ``"tau"`` entry.
        """
        # ``self._training_config`` is presumably set by the base Trainer's
        # ``__init__`` from ``training_config`` -- TODO confirm.
        policy.update_target(tau=self._training_config["tau"])
