# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Double DQN (tuned) agent class."""

# pylint: disable=g-bad-import-order

from typing import Any, Callable, Mapping, Text

from absl import logging
import dm_env
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
import chex

from utils import parts
from utils import processors
from utils import replay as replay_lib

# Shorthand aliases for the chex array and scalar annotation types.
Array = chex.Array
Numeric = chex.Numeric
# Batch variant of double_q_learning: with jax.vmap's default in_axes, every
# argument is mapped over its leading axis.

_batch_double_q_learning = jax.vmap(rlax.double_q_learning)


class DoubleDqn(parts.Agent):
  """Double DQN (tuned) agent backed by a pair of stacked Q-networks.

  Two independently initialised copies of `network` are stacked along a new
  leading axis of every parameter leaf and evaluated together with `jax.vmap`.
  The training loss sums a clipped double-Q TD loss for each copy and adds
  `reg_weight` times a penalty on the squared distance between the online
  Q-values and the mean Q-values produced by the target parameters, taken at
  the actions stored in the sampled transitions. Actions are chosen
  epsilon-greedily from the first copy's Q-values.
  """

  def __init__(
      self,
      preprocessor: processors.Processor,
      sample_network_input: jnp.ndarray,
      network: parts.Network,
      optimizer: optax.GradientTransformation,
      transition_accumulator: Any,
      replay: replay_lib.TransitionReplay,
      batch_size: int,
      exploration_epsilon: Callable[[int], float],
      min_replay_capacity_fraction: float,
      learn_period: int,
      target_network_update_period: int,
      grad_error_bound: float,
      reg_weight: float,
      rng_key: parts.PRNGKey,
  ):
    """Builds the agent state and its jitted update / policy functions.

    Args:
      preprocessor: Called on every timestep in `step`; when it returns `None`
        the previous action is repeated.
      sample_network_input: Example network input without a batch axis; a
        leading axis is added before it is passed to `network.init`.
      network: Provides `init` and `apply`; the output of `apply` must expose
        a `q_values` attribute.
      optimizer: Optax gradient transformation applied to the stacked
        parameters.
      transition_accumulator: Object with `step(timestep, action)` yielding
        transitions and a `reset()` method.
      replay: Replay buffer; its `capacity`, `size`, `add`, `sample`,
        `get_state` and `set_state` members are used.
      batch_size: Number of transitions sampled per learning step.
      exploration_epsilon: Maps the current frame index to an epsilon value.
      min_replay_capacity_fraction: Fraction of `replay.capacity` that must be
        filled before learning and target updates start.
      learn_period: A learning step runs on frames divisible by this value.
      target_network_update_period: The target parameters are refreshed on
        frames divisible by this value.
      grad_error_bound: Symmetric bound handed to `rlax.clip_gradient` for the
        TD errors and for the prior loss.
      reg_weight: Multiplier of the prior loss inside the total loss.
      rng_key: PRNG key; split into the agent key and two network init keys.
    """
    # Auxiliary outputs of the most recent learning step (empty until then).
    self.my_train_stats = {}
    self.reg_weight = reg_weight
    self._preprocessor = preprocessor
    self._replay = replay
    self._transition_accumulator = transition_accumulator
    self._batch_size = batch_size
    self._exploration_epsilon = exploration_epsilon
    self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
    self._learn_period = learn_period
    self._target_network_update_period = target_network_update_period

    # Initialize network parameters and optimizer.
    self._rng_key, network_rng_key, network_rng_key2 = jax.random.split(
        rng_key, 3)
    batched_sample_input = sample_network_input[None, ...]
    online_params1 = network.init(network_rng_key, batched_sample_input)
    online_params2 = network.init(network_rng_key2, batched_sample_input)
    # Stack the two parameter trees leaf by leaf. `jax.tree_multimap` has been
    # removed from JAX; `jax.tree_util.tree_map` accepts several trees.
    self._online_params = jax.tree_util.tree_map(
        lambda p1, p2: jnp.stack((p1, p2)), online_params1, online_params2)
    self._target_params = self._online_params
    self._opt_state = optimizer.init(self._online_params)

    # Other agent state: last action, frame count, etc.
    self._action = None
    self._frame_t = -1  # Current frame index.
    self._statistics = {'state_value': np.nan}

    # Define jitted loss, update, and policy functions here instead of as
    # class methods, to emphasize that these are meant to be pure functions
    # and should not access the agent object's state via `self`.

    # Applies the network once per stacked parameter set (axis 0 of the
    # parameters) while sharing the rng key and the inputs.
    ensemble_apply = jax.vmap(network.apply, (0, None, None))

    def loss_fn(online_params, target_params, transitions, rng_key):
      """Calculates loss given network parameters and transitions.

      Returns:
        A `(loss, stats)` pair where `loss` is the batch mean of the total
        loss and `stats` holds the mean prior loss, the mean TD loss and the
        training loss.
      """
      _, *apply_keys = jax.random.split(rng_key, 4)
      q_tm1 = ensemble_apply(
          online_params, apply_keys[0], transitions.s_tm1).q_values
      q_t = ensemble_apply(
          online_params, apply_keys[1], transitions.s_t).q_values
      prior_q_tm1 = ensemble_apply(
          target_params, apply_keys[2], transitions.s_tm1).q_values

      def member_td_loss(member_q_tm1):
        """Clipped double-Q L2 loss for one member's Q-values at s_tm1."""
        # NOTE(review): both members use `q_t[0]` as the value estimate and
        # `q_t[1]` as the action selector; confirm this asymmetry is intended.
        td_errors = _batch_double_q_learning(
            member_q_tm1,
            transitions.a_tm1,
            transitions.r_t,
            transitions.discount_t,
            q_t[0],
            q_t[1],
        )
        td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                       grad_error_bound)
        return rlax.l2_loss(td_errors)

      td_loss = member_td_loss(q_tm1[0]) + member_td_loss(q_tm1[1])

      # Prior loss: squared distance of each member's Q-values from the mean
      # target-parameter Q-values, summed over members and read out at the
      # action stored in each transition.
      prior_q_tm1 = jnp.mean(prior_q_tm1, 0, keepdims=True)
      prior_loss = 0.5 * jnp.sum((prior_q_tm1 - q_tm1)**2, 0)
      prior_loss = jnp.take_along_axis(
          prior_loss, jnp.reshape(transitions.a_tm1, (-1, 1)), -1)
      prior_loss = jnp.reshape(prior_loss, td_loss.shape)
      prior_loss = rlax.clip_gradient(prior_loss, -grad_error_bound,
                                      grad_error_bound)

      # NOTE: `self.reg_weight` and `self._batch_size` are read while jit
      # traces this function, so later changes to those attributes do not
      # affect an already compiled update.
      losses = td_loss + self.reg_weight * prior_loss
      assert losses.shape == (self._batch_size,)
      loss = jnp.mean(losses)
      return loss, {'prior_loss': jnp.mean(prior_loss),
                    'td_loss': jnp.mean(td_loss),
                    'training_loss': loss}

    def update(rng_key, opt_state, online_params, target_params, transitions):
      """Computes learning update from batch of replay transitions."""
      rng_key, update_key = jax.random.split(rng_key)
      d_loss_d_params, stats = jax.grad(loss_fn, has_aux=True)(
          online_params, target_params, transitions, update_key)
      updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state)
      new_online_params = optax.apply_updates(online_params, updates)
      return rng_key, new_opt_state, new_online_params, stats

    self._update = jax.jit(update)

    def select_action(rng_key, network_params, s_t, exploration_epsilon):
      """Samples action from eps-greedy policy wrt Q-values at given state."""
      rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
      # The vmapped output carries a member axis followed by the singleton
      # batch axis added here. Index both so that the first member's Q-values
      # form a vector and the sampled action and state value are scalars.
      q_t = ensemble_apply(
          network_params, apply_key, s_t[None, ...]).q_values[0, 0]
      a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
      v_t = jnp.max(q_t, axis=-1)
      return rng_key, a_t, v_t

    self._select_action = jax.jit(select_action)

  def step(self, timestep: dm_env.TimeStep) -> parts.Action:
    """Selects action given timestep and potentially learns.

    Args:
      timestep: Raw environment timestep; it is passed through the
        preprocessor first.

    Returns:
      The newly selected action, or the previous action when the preprocessor
      returned `None` for this timestep.
    """
    self._frame_t += 1

    timestep = self._preprocessor(timestep)

    if timestep is None:  # Repeat action.
      action = self._action
    else:
      action = self._action = self._act(timestep)

      for transition in self._transition_accumulator.step(timestep, action):
        self._replay.add(transition)

    # Neither learning nor target updates happen until replay is full enough.
    if self._replay.size < self._min_replay_capacity:
      return action

    if self._frame_t % self._learn_period == 0:
      self._learn()

    if self._frame_t % self._target_network_update_period == 0:
      self._target_params = self._online_params

    return action

  def reset(self) -> None:
    """Resets the agent's episodic state such as frame stack and action repeat.

    This method should be called at the beginning of every episode.
    """
    self._transition_accumulator.reset()
    processors.reset(self._preprocessor)
    self._action = None

  def _act(self, timestep) -> parts.Action:
    """Selects action given timestep, according to epsilon-greedy policy.

    Also stores the greedy value of the evaluated Q-values under the
    `'state_value'` statistic.
    """
    s_t = timestep.observation
    self._rng_key, a_t, v_t = self._select_action(self._rng_key,
                                                  self._online_params, s_t,
                                                  self.exploration_epsilon)
    # Move results to host memory so statistics never hold device arrays.
    a_t, v_t = jax.device_get((a_t, v_t))
    self._statistics['state_value'] = v_t
    return parts.Action(a_t)

  def _learn(self) -> None:
    """Samples a batch of transitions from replay and learns from it."""
    logging.log_first_n(logging.INFO, 'Begin learning', 1)
    transitions = self._replay.sample(self._batch_size)
    self._rng_key, self._opt_state, self._online_params, stats = self._update(
        self._rng_key,
        self._opt_state,
        self._online_params,
        self._target_params,
        transitions,
    )
    # Stored exactly as returned by the jitted update (no device_get here).
    self.my_train_stats = stats

  @property
  def online_params(self) -> parts.NetworkParams:
    """Returns current (stacked) parameters of the Q-networks."""
    return self._online_params

  @property
  def statistics(self) -> Mapping[Text, float]:
    """Returns current agent statistics as a dictionary."""
    # Check for device arrays in values as this can be very slow.
    # `jnp.DeviceArray` no longer exists in current JAX releases, so prefer
    # `jax.Array` and only fall back to the legacy name when it is missing.
    device_array_type = getattr(jax, 'Array', None)
    if device_array_type is None:
      device_array_type = jnp.DeviceArray
    assert all(
        not isinstance(x, device_array_type)
        for x in self._statistics.values())
    return self._statistics

  @property
  def exploration_epsilon(self) -> float:
    """Returns epsilon value currently used by (eps-greedy) behavior policy."""
    return self._exploration_epsilon(self._frame_t)

  def get_state(self) -> Mapping[Text, Any]:
    """Retrieves agent state as a dictionary (e.g. for serialization)."""
    state = {
        'rng_key': self._rng_key,
        'frame_t': self._frame_t,
        'opt_state': self._opt_state,
        'online_params': self._online_params,
        'target_params': self._target_params,
        'replay': self._replay.get_state(),
    }
    return state

  def set_state(self, state: Mapping[Text, Any]) -> None:
    """Sets agent state from a (potentially de-serialized) dictionary."""
    self._rng_key = state['rng_key']
    self._frame_t = state['frame_t']
    self._opt_state = jax.device_put(state['opt_state'])
    self._online_params = jax.device_put(state['online_params'])
    self._target_params = jax.device_put(state['target_params'])
    self._replay.set_state(state['replay'])
