from k_level_policy_gradients.src.algorithms.agent import Agent
from k_level_policy_gradients.src.approximators.torch_approximator import (
    TorchApproximator,
)


class AbstractDQN(Agent):
    """
    Base class for DQN-style agents.

    It owns an online approximator and a target approximator (both
    ``TorchApproximator``), feeds a replay memory from joint multi-agent
    samples and keeps the target network in sync with the online one through
    either soft or hard updates.

    Subclasses are expected to provide ``_fit`` (called by ``fit``) and
    ``_next_q``; neither is implemented here.

    """

    def __init__(
        self,
        mdp_info,
        idx_agent,
        policy,
        batch_size,
        replay_memory,
        target_update_frequency,
        tau,
        warmup_replay_size,
        target_update_mode,
        approximator_params,
        grad_norm_clip,
        obs_last_action,
        primary_agent=None,
        use_mixer=False,
        use_cuda=False,
    ):
        """
        Constructor.

        Args:
            mdp_info: information about the MDP, forwarded to ``Agent``;
            idx_agent: index of this agent, forwarded to ``Agent``; it is
                later read as ``self._idx_agent`` to pick this agent's entries
                out of joint samples;
            policy: the policy of the agent; it is given the online
                approximator through ``set_approximator``;
            batch_size: stored for subclasses, not used in this class;
            replay_memory: the replay memory; ``fit`` calls its ``add``
                method with this agent's samples;
            target_update_frequency: number of ``fit`` calls between two
                hard target updates (only used when ``target_update_mode`` is
                ``"hard"``);
            tau: interpolation coefficient of the soft target update
                (only used when ``target_update_mode`` is ``"soft"``);
            warmup_replay_size: stored for subclasses, not used in this class;
            target_update_mode: ``"soft"`` or ``"hard"``; with any other
                value ``fit`` never updates the target approximator;
            approximator_params (dict): keyword arguments used to build both
                the online and the target ``TorchApproximator``;
            grad_norm_clip: stored for subclasses, not used in this class;
            obs_last_action: stored for subclasses, not used in this class;
            primary_agent (None): if given, the approximators of this agent
                are linked to the ones of ``primary_agent`` through
                ``set_primary_approximator``;
            use_mixer (bool, False): if True, ``fit`` neither stores samples
                nor calls ``_fit``;
            use_cuda (bool, False): stored for subclasses, not used in this
                class.

        """
        super().__init__(mdp_info, policy, idx_agent)

        self._batch_size = batch_size
        self._replay_memory = replay_memory
        self._target_update_frequency = target_update_frequency
        self._tau = tau
        self._warmup_replay_size = warmup_replay_size
        self._target_update_mode = target_update_mode
        self._grad_norm_clip = grad_norm_clip
        self._obs_last_action = obs_last_action
        self._primary_agent = primary_agent
        self._use_mixer = use_mixer
        self._use_cuda = use_cuda

        # Counts the calls to ``fit``; drives the hard target update schedule.
        self._n_updates = 0

        self.approximator = TorchApproximator(**approximator_params)
        self.target_approximator = TorchApproximator(**approximator_params)
        if primary_agent is None:
            # Start with the target network equal to the online one.
            self._update_target_hard()
        else:
            # Set this agent's approximator and target approximator to be the same as the primary agent
            self.approximator.set_primary_approximator(primary_agent.approximator)
            self.target_approximator.set_primary_approximator(
                primary_agent.target_approximator
            )
        self._optimizer = self.approximator._optimizer
        policy.set_approximator(self.approximator)

        # Every attribute read after construction must be registered here.
        # Presumably (as in MushroomRL's Serializable) only registered
        # attributes are restored on load, so a missing hyper-parameter such
        # as ``_target_update_mode`` would make ``fit`` fail on a loaded agent
        # -- confirm against ``Agent``.
        # NOTE(review): ``_primary_agent`` is intentionally not registered:
        # persisting it would store a copy of the primary agent with every
        # secondary agent. Re-linking after load has to be handled elsewhere
        # -- TODO confirm.
        self._add_save_attr(
            _batch_size="primitive",
            _target_update_frequency="primitive",
            _tau="primitive",
            _warmup_replay_size="primitive",
            _target_update_mode="primitive",
            _grad_norm_clip="primitive",
            _obs_last_action="primitive",
            _replay_memory="mushroom!",
            _n_updates="primitive",
            approximator="mushroom",
            target_approximator="mushroom",
            _optimizer="torch",
            _use_mixer="primitive",
            _use_cuda="primitive",
        )

    def draw_action(self, state, action_mask=None):
        """
        Return the action to execute in the given state, as drawn by the
        policy.

        Args:
            state (np.ndarray): the state where the agent is.
            action_mask (np.ndarray, None): the mask to apply to the action
                space; when None the policy is called with the state only.

        Returns:
            The action to be executed.

        """
        if action_mask is not None:
            return self.policy.draw_action(state, action_mask)
        else:
            return self.policy.draw_action(state)

    def fit(self, dataset):
        """
        Run one learning step and, when due, update the target approximator.

        Args:
            dataset (list): joint samples, in the format expected by
                ``split_dataset``.

        Returns:
            A tuple containing the loss twice. The loss is the value returned
            by ``_fit``, or 0 when ``use_mixer`` is set.

        """
        if self._use_mixer:
            loss = 0  # storage and fitting handled by mixer
        else:
            own_dataset = self.split_dataset(dataset)
            self._replay_memory.add(own_dataset)
            loss = self._fit()

        self._n_updates += 1
        # Only agent 0, or an agent without a primary agent, updates the
        # target approximator.
        if self._idx_agent == 0 or self._primary_agent is None:
            if self._target_update_mode == "soft":
                self._update_target_soft()
            elif self._target_update_mode == "hard":
                if self._n_updates % self._target_update_frequency == 0:
                    self._update_target_hard()
        return loss, loss

    def split_dataset(self, dataset):
        """
        Extract this agent's view from a dataset of joint samples.

        Args:
            dataset (list): samples given as dictionaries with keys ``obs``,
                ``actions``, ``rewards``, ``next_obs`` and
                ``next_action_masks`` (indexed by agent), plus ``absorbing``
                and ``last`` (copied as they are).

        Returns:
            A list with one dictionary per sample, with keys ``obs``,
            ``action``, ``reward``, ``next_obs``, ``next_action_mask``,
            ``absorbing`` and ``last``.

        """
        idx = self._idx_agent
        return [
            {
                "obs": sample["obs"][idx],
                "action": sample["actions"][idx],
                "reward": sample["rewards"][idx],
                "next_obs": sample["next_obs"][idx],
                "next_action_mask": sample["next_action_masks"][idx],
                "absorbing": sample["absorbing"],
                "last": sample["last"],
            }
            for sample in dataset
        ]

    def _update_target_hard(self):
        """
        Update the target network by copying the weights of the online one.

        """
        self.target_approximator.set_weights(self.approximator.get_weights())

    def _update_target_soft(self):
        """
        Update the target network by moving its weights towards the online
        ones: ``target <- tau * online + (1 - tau) * target``.

        """
        weights = self._tau * self.approximator.get_weights()
        weights += (1 - self._tau) * self.target_approximator.get_weights()
        self.target_approximator.set_weights(weights)

    def _next_q(self, next_state):
        """
        Args:
            next_state (np.ndarray): the states where next action has to be
                evaluated;

        Returns:
            Maximum action-value for each state in ``next_state``.

        """
        raise NotImplementedError

    def _post_load(self):
        """
        Give the online approximator back to the policy. Presumably called by
        the loading machinery once the attributes have been restored.

        """
        self.policy.set_approximator(self.approximator)

    def set_logger(self, logger, loss_filename="loss_Q"):
        """
        Setter that can be used to pass a logger to the algorithm

        Args:
            logger (Logger): the logger to be used by the algorithm;
            loss_filename (str, 'loss_Q'): optional string to specify the loss filename.

        """
        self.approximator.set_logger(logger, loss_filename)
