# coding=utf-8
# Copyright 2022 The Variance Double Down Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""QR-DQN agent in Jax with variance reporting."""

import functools
from dopamine.jax.agents.quantile import quantile_agent
from dopamine.metrics import statistics_instance
import gin
import jax
import jax.numpy as jnp
import numpy as onp
import optax


@functools.partial(jax.jit, static_argnames=('network_def',
                                             'optimizer',
                                             'kappa',
                                             'num_atoms',
                                             'cumulative_gamma',
                                             'noise_scale'))
def train(network_def, online_params, target_params, optimizer, optimizer_state,
          states, actions, next_states, rewards, terminals, kappa, num_atoms,
          cumulative_gamma, noise_scale, rng):
  """Run a training step.

  Computes the QR-DQN quantile Huber loss of the online network against the
  target distribution built from the target network, applies one optimizer
  update and reports the mean gradient of every parameter tensor.

  Args:
    network_def: network module (static). `network_def.apply(params, state)`
      must return an object exposing `.logits`.
    online_params: parameters of the online network.
    target_params: parameters of the target network.
    optimizer: optax gradient transformation (static).
    optimizer_state: current state of `optimizer`.
    states: batch of preprocessed states.
    actions: batch of actions taken in `states`.
    next_states: batch of preprocessed successor states.
    rewards: batch of rewards.
    terminals: batch of terminal flags.
    kappa: Huber loss threshold (static).
    num_atoms: number of quantiles per action (static).
    cumulative_gamma: discount factor passed to
      `quantile_agent.target_distribution` (static).
    noise_scale: standard deviation of Gaussian noise added to the target, or
      None to leave the target untouched (static).
    rng: JAX PRNG key.

  Returns:
    A dict with keys:
      'optimizer_state': updated optimizer state.
      'online_params': updated online parameters.
      'loss': per-sample quantile Huber loss.
      'mean_loss': batch mean of 'loss'.
      'grad_means': {layer: {component: mean of that gradient tensor}}.
      'rng': PRNG key to pass to the next call.
  """
  def loss_fn(params, target):
    def q_online(state):
      return network_def.apply(params, state)

    logits = jax.vmap(q_online)(states).logits
    # Keep the batch axis explicit: a bare `jnp.squeeze` would also drop it
    # when a mini-batch holds a single transition (e.g. batch_divisor equal to
    # the batch size), which breaks the per-sample action indexing below.
    # Presumed layout afterwards: (batch, num_actions, num_atoms).
    logits = logits.reshape((logits.shape[0], -1, num_atoms))
    # Fetch the logits for its selected action. We use vmap to perform this
    # indexing across the batch.
    chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions)
    bellman_errors = (target[:, None, :] -
                      chosen_action_logits[:, :, None])  # Input `u' of Eq. 9.
    # Eq. 9 of paper.
    huber_loss = (
        (jnp.abs(bellman_errors) <= kappa).astype(jnp.float32) *
        0.5 * bellman_errors ** 2 +
        (jnp.abs(bellman_errors) > kappa).astype(jnp.float32) *
        kappa * (jnp.abs(bellman_errors) - 0.5 * kappa))

    tau_hat = ((jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) /
               num_atoms)  # Quantile midpoints.  See Lemma 2 of paper.
    # Eq. 10 of paper.
    tau_bellman_diff = jnp.abs(
        tau_hat[None, :, None] - (bellman_errors < 0).astype(jnp.float32))
    quantile_huber_loss = tau_bellman_diff * huber_loss
    # Sum over tau dimension, average over target value dimension.
    loss = jnp.sum(jnp.mean(quantile_huber_loss, 2), 1)
    return jnp.mean(loss), loss

  def q_target(state):
    return network_def.apply(target_params, state)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  target = quantile_agent.target_distribution(
      q_target,
      next_states,
      rewards,
      terminals,
      cumulative_gamma)
  if noise_scale is not None:
    # Only consume randomness when noise is requested.
    rng, noise_rng = jax.random.split(rng)
    target += jax.random.normal(noise_rng, target.shape) * noise_scale
  (mean_loss, loss), grad = grad_fn(online_params, target)
  # Reduce every gradient tensor to its scalar mean and copy the result into
  # plain nested dicts: {layer: {component: mean}}.
  grad_avg = jax.tree_util.tree_map(jnp.mean, grad)
  grad_means = {
      layer: {
          layer_comp: grad_avg['params'][layer][layer_comp]
          for layer_comp in grad_avg['params'][layer]
      }
      for layer in grad_avg['params']
  }
  updates, optimizer_state = optimizer.update(grad, optimizer_state,
                                              params=online_params)
  online_params = optax.apply_updates(online_params, updates)
  return {
      'optimizer_state': optimizer_state,
      'online_params': online_params,
      'loss': loss,
      'mean_loss': mean_loss,
      'grad_means': grad_means,
      'rng': rng,
  }


@functools.partial(jax.jit, static_argnames=('network_def',))
def compute_policy_churn(network_def, online_params, prev_params, states):
  """Fraction of `states` on which two parameter sets pick different actions.

  Both parameter sets are evaluated with `network_def.apply`, and their greedy
  (argmax over `q_values`) actions are compared state by state.

  Args:
    network_def: network module (static). `network_def.apply(params, state)`
      must return an object exposing `.q_values`.
    online_params: current network parameters.
    prev_params: earlier network parameters to compare against.
    states: batch of preprocessed states; axis 0 is the batch axis.

  Returns:
    Number of disagreeing greedy actions divided by the batch size.
  """

  def greedy_actions(params):
    outputs = jax.vmap(lambda s: network_def.apply(params, s))(states)
    return jnp.argmax(jnp.squeeze(outputs.q_values), axis=-1)

  num_states = states.shape[0]
  current = greedy_actions(online_params)
  previous = greedy_actions(prev_params)
  # A nonzero difference marks a state where the greedy actions disagree; the
  # normalized count is the total variation distance between the two
  # deterministic policies.
  return jnp.count_nonzero(current - previous) / num_states


@gin.configurable
class VariAnQuantileAgent(quantile_agent.JaxQuantileAgent):
  """Variance analysis QR-DQN Agent.

  Each sampled replay batch is split into `batch_divisor` mini-batches, and
  each mini-batch gets its own gradient step. When summaries are written, the
  agent reports:
    * 'Loss': average of the mini-batch mean losses of the current step.
      With prioritized replay, these are the importance-weighted means.
    * 'AggrLossVariance': variance of the last `max_accumulation` mini-batch
      mean losses.
    * 'AggrGradVariance_<layer>_<component>': variance of the last
      `max_accumulation` mean gradients of that parameter tensor.
    * 'PolicyChurn': fraction of sampled replay states whose greedy action
      differs between the current online params and the params saved before
      the most recent gradient step.
  """

  def __init__(self, num_actions, summary_writer=None, batch_divisor=1,
               max_accumulation=500, target_noise_scale=None,
               policy_churn_batch_size=32):
    """Initializes the agent.

    Args:
      num_actions: int, number of actions, forwarded to the base agent.
      summary_writer: forwarded to the base agent. Statistics are only written
        when this is not None.
      batch_divisor: int >= 1, number of mini-batches each replay batch is
        split into. There is one gradient step per mini-batch.
      max_accumulation: int, number of most recent mini-batch statistics kept
        for the variance estimates.
      target_noise_scale: float or None, standard deviation of Gaussian noise
        added to the target distribution. None disables the noise.
      policy_churn_batch_size: int, number of replay states sampled to
        estimate policy churn.

    Raises:
      ValueError: if `batch_divisor` is smaller than 1.
    """
    if batch_divisor < 1:
      raise ValueError(f'batch_divisor must be >= 1, got {batch_divisor}.')
    self._batch_divisor = batch_divisor
    # Rolling histories used for the variance statistics.
    self._accumulated_losses = []
    self._accumulated_grads = {}
    self._max_accumulation = max_accumulation
    self._target_noise_scale = target_noise_scale
    # Online params from just before the most recent gradient step.
    self._previous_params = None
    # Previously hard-coded to 32, silently ignoring the argument.
    self._policy_churn_batch_size = policy_churn_batch_size
    super().__init__(num_actions, summary_writer=summary_writer)

  def _sample_policy_churn_states(self):
    """Samples raw (not yet preprocessed) replay states for policy churn.

    Returns:
      The 'state' element of a freshly sampled replay batch of size
      `self._policy_churn_batch_size`.

    Raises:
      ValueError: if the replay transitions have no 'state' element.
    """
    samples = self._replay.sample_transition_batch(
        batch_size=self._policy_churn_batch_size)
    types = self._replay.get_transition_elements(
        batch_size=self._policy_churn_batch_size)
    for element, element_type in zip(samples, types):
      if element_type.name == 'state':
        return element
    raise ValueError('Replay transitions contain no "state" element.')

  def _accumulate_statistics(self, mean_loss, grad_means):
    """Appends one mini-batch's loss and mean gradients to the histories.

    Only the most recent `self._max_accumulation` entries of each history are
    kept.

    Args:
      mean_loss: scalar mean loss of the mini-batch.
      grad_means: {layer: {component: mean gradient}} as returned by `train`.
    """
    self._accumulated_losses.append(mean_loss)
    self._accumulated_losses = self._accumulated_losses[
        -self._max_accumulation:]
    for layer in grad_means:
      layer_history = self._accumulated_grads.setdefault(layer, {})
      for layer_comp in grad_means[layer]:
        history = layer_history.setdefault(layer_comp, [])
        history.append(onp.asarray(grad_means[layer][layer_comp]))
        layer_history[layer_comp] = history[-self._max_accumulation:]

  def _write_summaries(self, average_loss, policy_churn, grad_means):
    """Writes loss, variance and policy-churn statistics to the collectors.

    Does nothing when the agent has no `collector_dispatcher` attribute.

    Args:
      average_loss: mean of the mini-batch mean losses of this training step.
      policy_churn: policy churn measured at the start of this training step.
      grad_means: {layer: {component: ...}} from the last mini-batch. Only its
        keys are used, to select the accumulated gradient histories.
    """
    if not hasattr(self, 'collector_dispatcher'):
      return
    step = self.training_steps
    stats = [
        statistics_instance.StatisticsInstance(
            'Loss', onp.asarray(average_loss), step=step),
        statistics_instance.StatisticsInstance(
            'AggrLossVariance', onp.var(self._accumulated_losses), step=step),
        statistics_instance.StatisticsInstance(
            'PolicyChurn', onp.asarray(policy_churn), step=step),
    ]
    for layer in grad_means:
      for layer_comp in grad_means[layer]:
        stats.append(
            statistics_instance.StatisticsInstance(
                f'AggrGradVariance_{layer}_{layer_comp}',
                onp.var(self._accumulated_grads[layer][layer_comp]),
                step=step))
    self.collector_dispatcher.write(
        stats,
        collector_allowlist=self._collector_allowlist)

  def _train_step(self):
    """Runs a single training step.

    Runs training if both:
      (1) A minimum number of frames have been added to the replay buffer.
      (2) `training_steps` is a multiple of `update_period`.

    A training run samples one replay batch and splits it into
    `batch_divisor` mini-batches, taking one gradient step per mini-batch.

    Also, syncs weights from online_params to target_network_params if training
    steps is a multiple of target update period.
    """
    if self._replay.add_count > self.min_replay_history:
      if self.training_steps % self.update_period == 0:
        # Churn is measured before this step's updates. It compares the current
        # online params with those saved before the previous gradient step.
        policy_churn = 0.0
        if self._previous_params is not None:
          policy_churn = compute_policy_churn(
              self.network_def, self.online_params, self._previous_params,
              self.preprocess_fn(self._sample_policy_churn_states()))
        self._sample_from_replay_buffer()
        mini_batches = {
            key: onp.array_split(self.replay_elements[key],
                                 self._batch_divisor)
            for key in ('state', 'action', 'next_state', 'reward', 'terminal',
                        'sampling_probabilities', 'indices')
        }
        mean_mean_losses = 0.0
        for i in range(self._batch_divisor):
          train_returns = train(
              self.network_def,
              self.online_params,
              self.target_network_params,
              self.optimizer,
              self.optimizer_state,
              self.preprocess_fn(mini_batches['state'][i]),
              mini_batches['action'][i],
              self.preprocess_fn(mini_batches['next_state'][i]),
              mini_batches['reward'][i],
              mini_batches['terminal'][i],
              self._kappa,
              self._num_atoms,
              self.cumulative_gamma,
              self._target_noise_scale,
              self._rng)
          self.optimizer_state = train_returns['optimizer_state']
          # Keep the pre-update params so the next step can measure churn.
          self._previous_params = self.online_params
          self.online_params = train_returns['online_params']
          loss = train_returns['loss']
          mean_loss = train_returns['mean_loss']
          self._rng = train_returns['rng']
          grad_means = train_returns['grad_means']
          if self._replay_scheme == 'prioritized':
            # The original prioritized experience replay uses a linear exponent
            # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of
            # 0.5 on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders)
            # suggested a fixed exponent actually performs better, except on
            # Pong.
            probs = mini_batches['sampling_probabilities'][i]
            loss_weights = 1.0 / jnp.sqrt(probs + 1e-10)
            loss_weights /= jnp.max(loss_weights)

            # Rainbow and prioritized replay are parametrized by an exponent
            # alpha, but in both cases it is set to 0.5 - for simplicity's sake
            # we leave it as is here, using the more direct sqrt(). Taking the
            # square root "makes sense", as we are dealing with a squared loss.
            # Add a small nonzero value to the loss to avoid 0 priority items.
            # While technically this may be okay, setting all items to 0
            # priority will cause troubles, and also result in 1.0 / 0.0 = NaN
            # correction terms.
            self._replay.set_priority(mini_batches['indices'][i],
                                      jnp.sqrt(loss + 1e-10))

            # Weight the loss by the inverse priorities.
            loss = loss_weights * loss
            mean_loss = jnp.mean(loss)
          mean_mean_losses += mean_loss
          self._accumulate_statistics(mean_loss, grad_means)
        if (self.summary_writer is not None and
            self.training_steps > 0 and
            self.training_steps % self.summary_writing_frequency == 0):
          self._write_summaries(mean_mean_losses / self._batch_divisor,
                                policy_churn, grad_means)
      if self.training_steps % self.target_update_period == 0:
        self._sync_weights()

    self.training_steps += 1
