import torch
import numpy as np
from k_level_policy_gradients.src.algorithms.value.abstract_dqn import AbstractDQN


class DQN(AbstractDQN):
    """
    Deep Q-Network algorithm.
    "Human-Level Control Through Deep Reinforcement Learning".
    Mnih V. et al., 2015.

    """

    def _fit(self):
        """
        Run one optimisation step of the online approximator.

        Nothing is done until the replay memory holds more than
        ``self._warmup_replay_size`` transitions. Otherwise a batch of
        ``self._batch_size`` transitions is sampled, the one-step TD target
        ``r + gamma * max_a' Q_target(s', a')`` is built (with the bootstrap
        term zeroed for absorbing transitions) and the approximator's loss
        between ``Q(s, a)`` and that target is minimised with a single
        optimizer step.

        """
        if self._replay_memory.size <= self._warmup_replay_size:
            return

        # The last element returned by the replay memory is not used here.
        obs, actions, rewards, next_obs, next_action_masks, absorbing, _ = (
            self._replay_memory.get(self._batch_size)
        )

        obs_t = torch.tensor(obs, dtype=torch.float32)
        actions_t = torch.tensor(actions, dtype=torch.long)
        rewards_t = torch.tensor(rewards, dtype=torch.float32)
        next_obs_t = torch.tensor(next_obs, dtype=torch.float32)
        next_action_masks_t = torch.tensor(next_action_masks, dtype=torch.bool)
        absorbings_t = torch.tensor(absorbing, dtype=torch.bool)

        q_obs = self.approximator.predict(obs_t)
        # Squeeze only the gathered action axis so a batch of size one keeps
        # its batch dimension.
        # NOTE(review): assumes actions are stored as (batch, 1) -- confirm.
        q_hat = q_obs.gather(1, actions_t).squeeze(1)

        # The target is a constant w.r.t. the optimisation problem: do not
        # build an autograd graph through the target network.
        with torch.no_grad():
            q_next = self._next_q(next_obs_t, next_action_masks_t)
            # Use masked_fill instead of multiplying by zero: a fully masked
            # absorbing state yields -inf, and -inf * 0 would be NaN.
            q_next = q_next.masked_fill(absorbings_t, 0.0)
            q_target = rewards_t + self.mdp_info.gamma * q_next

        loss = self.approximator._loss(q_hat, q_target)
        self.approximator._optimizer.zero_grad()
        loss.backward()
        self.approximator._optimizer.step()

    def _next_q(self, next_state, next_action_mask=None):
        """
        Compute the maximum target-network action value of the next states.

        Args:
            next_state: batch of next states passed to
                ``self.target_approximator.predict``.
            next_action_mask: optional boolean mask with the same shape as
                the predicted Q-values; ``True`` marks the actions that may
                be selected. When ``None`` all actions are considered.

        Returns:
            A tensor with, for each next state, the maximum Q-value over
            the (allowed) actions. A state whose mask is entirely ``False``
            yields ``-inf``.

        """
        q_next = self.target_approximator.predict(next_state)
        if next_action_mask is not None:
            allowed = torch.as_tensor(
                next_action_mask, dtype=torch.bool, device=q_next.device
            )
            q_next = q_next.masked_fill(~allowed, float("-inf"))
        # Tensor.max(dim) returns a (values, indices) pair; only the values
        # are needed to build the TD target.
        return q_next.max(1).values