from expground.types import AgentID, DataArray, Dict, Any, Tuple
from expground.common.schedules import LinearSchedule
from expground.algorithms.base_trainer import Trainer
from expground.algorithms.loss_func import LossFunc
from expground.algorithms.base_policy import Policy

from .policy import DDPG


class DDPGTrainer(Trainer):
    """Single-agent DDPG trainer.

    Unwraps the single-agent batch before the loss is computed and calls the
    policy's target-network update after every loss step.
    """

    def __init__(
        self,
        loss_func: LossFunc,
        training_config: Dict[str, Any],
        policy_instance: Policy,
    ):
        """Create a DDPG trainer.

        Args:
            loss_func (LossFunc): Loss function handed to the base ``Trainer``.
            training_config (Dict[str, Any]): Training configuration handed to the
                base ``Trainer``. ``_after_loss`` reads its ``"tau"`` entry
                (presumably exposed as ``self._training_config`` by ``Trainer``).
            policy_instance (Policy): Policy instance handed to the base ``Trainer``.
        """
        super().__init__(
            loss_func, training_config=training_config, policy_instance=policy_instance
        )

    def _before_loss(
        self, policy: DDPG, batch: Dict[AgentID, Dict[str, DataArray]]
    ) -> Tuple[Dict[str, DataArray], Dict]:
        """Strip the agent mapping from a single-agent batch.

        Single-agent training only needs the plain batched data, so the one
        agent's batch dict is returned unchanged.

        Args:
            policy (DDPG): A DDPG policy instance (not used here).
            batch (Dict[AgentID, Dict[str, DataArray]]): Mapping from agent ID to
                that agent's batch dict; must contain exactly one agent.

        Returns:
            Tuple[Dict[str, DataArray], Dict]: The single agent's batch dict and an
            empty second dict (no extra data is produced here).

        Raises:
            ValueError: If ``batch`` does not contain exactly one agent.
        """
        # Raise rather than ``assert`` so the check survives ``python -O``;
        # otherwise a multi-agent batch would silently be reduced to its
        # first agent's data.
        if len(batch) != 1:
            raise ValueError(
                "DDPGTrainer expects a batch for exactly one agent, "
                f"got {len(batch)}: {list(batch)}"
            )
        return next(iter(batch.values())), {}

    def _after_loss(self, policy: DDPG, step_counter: int) -> None:
        """Update the policy's target network after a loss step.

        The update runs on every call, using ``tau`` from the training config.
        ``step_counter`` is accepted to match the hook's signature but is not
        used.

        Args:
            policy (DDPG): A DDPG policy instance.
            step_counter (int): Global step counter (unused).
        """
        # NOTE(review): presumably a soft (Polyak) update weighted by ``tau`` —
        # confirm against ``DDPG.update_target``.
        policy.update_target(tau=self._training_config["tau"])


class MADDPGTrainer(Trainer):
    """Trainer for multi-agent DDPG.

    Both hooks currently forward straight to the base ``Trainer``
    implementation; this class adds no behaviour of its own yet.
    """

    def _before_loss(self, policy, batch):
        """Delegate batch preprocessing to the base ``Trainer`` and return its result."""
        parent = super()
        return parent._before_loss(policy, batch)

    def _after_loss(self, policy, step_counter: int):
        """Delegate post-loss processing to the base ``Trainer`` and return its result."""
        parent = super()
        return parent._after_loss(policy, step_counter)
