from typing import Dict
from expground.types import AgentID, DataArray, Any, Tuple
from expground.common.schedules import LinearSchedule
from expground.algorithms import misc
from expground.algorithms.base_trainer import Trainer
from expground.algorithms.loss_func import LossFunc
from expground.algorithms.base_policy import Policy

from .policy import DQN


class DQNTrainer(Trainer):
    """Trainer for a single DQN policy with scheduled epsilon exploration.

    Before each loss computation the trainer writes the current exploration
    rate to ``policy.eps``; after each loss computation it calls
    ``misc.soft_update`` on the policy's target critic.

    Keys read from ``training_config``:
        exploration_fraction, total_timesteps: their product (truncated to
            int) is passed to ``LinearSchedule`` as ``schedule_timesteps``.
        exploration_final_eps: passed to ``LinearSchedule`` as ``final_p``.
        pretrain_eps (optional): epsilon used while pretrain mode is on; when
            given it is also the schedule's ``initial_p`` (otherwise 1.0).
        param_noise (optional, default False): when truthy, the scheduled
            epsilon is replaced by 0.0.
        tau: forwarded to ``misc.soft_update`` in ``_after_loss``.
    """

    def __init__(
        self,
        loss_func: LossFunc,
        training_config: Dict[str, Any],
        policy_instance: Policy,
    ):
        """Initialize the trainer and build the exploration schedule.

        Args:
            loss_func (LossFunc): Loss function forwarded to the base trainer.
            training_config (Dict[str, Any]): Training configuration, see the
                class docstring for the keys read here.
            policy_instance (Policy): Policy forwarded to the base trainer.

        Raises:
            KeyError: If ``exploration_fraction``, ``total_timesteps`` or
                ``exploration_final_eps`` is missing from the configuration.
        """
        super().__init__(
            loss_func, training_config=training_config, policy_instance=policy_instance
        )
        exploration_fraction = self._training_config["exploration_fraction"]
        total_timesteps = self._training_config["total_timesteps"]
        exploration_final_eps = self._training_config["exploration_final_eps"]
        # None means "no dedicated pretrain epsilon was configured".
        self.fixed_eps = self._training_config.get("pretrain_eps")
        self.pretrain_mode = False

        self.exploration = LinearSchedule(
            schedule_timesteps=int(exploration_fraction * total_timesteps),
            initial_p=1.0 if self.fixed_eps is None else self.fixed_eps,
            final_p=exploration_final_eps,
        )

    def set_pretrain(self, pmode: bool = True) -> None:
        """Switch pretrain mode on or off.

        Args:
            pmode (bool): New value of the pretrain flag. Defaults to True.
        """
        self.pretrain_mode = pmode

    def _current_eps(self):
        """Return the exploration rate that applies at the current step.

        Single source of truth shared by ``_before_loss`` and ``get_eps`` so
        the applied and the reported epsilon cannot drift apart.
        """
        # A configured pretrain epsilon overrides everything while pretraining.
        if self.pretrain_mode and self.fixed_eps is not None:
            return self.fixed_eps
        # With parameter noise enabled, epsilon-greedy exploration is disabled.
        if self._training_config.get("param_noise", False):
            return 0.0
        return self.exploration.value(self.counter)

    def _before_loss(
        self, policy: DQN, batch: Dict[AgentID, Dict[str, DataArray]]
    ) -> Tuple[Dict, Dict]:
        """Batch preprocessing, single agent training requires only pure batched data without agent mapping.

        Also moves the policy to the configured device and refreshes its
        exploration rate.

        Args:
            policy (DQN): A dqn policy instance.
            batch (Dict[AgentID, Dict[str, DataArray]]): A batch data mapping from agent to batch dicts.

        Returns:
            Tuple[Dict, Dict]: The batch dict of the only agent in ``batch``
                and an empty dict.
        """

        assert len(batch) == 1, "DQNTrainer expects a batch for exactly one agent"

        policy = policy.to(
            "cuda" if policy.custom_config.get("use_cuda", False) else "cpu"
        )

        # set exploration rate for policy
        policy.eps = self._current_eps()

        # Exactly one entry is present, so take it without building a list.
        return next(iter(batch.values())), {}

    def get_eps(self) -> str:
        """Return a textual summary of the exploration state.

        Returns:
            str: ``"<step counter>, <schedule timesteps>, <current eps>"``.
        """
        eps = self._current_eps()
        # NOTE(review): the reported step is `_step_counter` while the schedule
        # is evaluated at `counter` - presumably the same value; confirm in the
        # base Trainer.
        return "{}, {}, {}".format(
            self._step_counter, self.exploration.schedule_timesteps, eps
        )

    def _after_loss(self, policy: DQN, step_counter: int) -> Dict[str, Any]:
        """Update target network here.

        Args:
            policy (DQN): A dqn policy instance.
            step_counter (int): Global step counter (not used by this method).

        Returns:
            Dict[str, Any]: A dict holding the policy's current ``eps``.
        """

        # presumably blends critic weights into target_critic at rate tau -
        # see misc.soft_update for the exact rule.
        misc.soft_update(
            policy.target_critic, policy.critic, tau=self._training_config["tau"]
        )

        return {"eps": policy.eps}
