# coding=utf-8
# Copyright 2022 The Multi Task Atari Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Compact implementation of an offline multi-task DQN agent in JAX."""

from absl import logging
from multi_task_atari import multi_task_fixed_replay as multi_game_fixed_replay
from multi_task_atari import multi_task_dqn_agent
from multi_task_atari import multi_task_offline_dqn_agent

from multi_task_atari import multi_task_tfds

from multi_task_atari import large_networks
from multi_task_atari import networks
from JaxCQL import jax_utils

import gin
from dopamine.jax import losses
from dopamine.jax.agents.rainbow import rainbow_agent
import numpy as onp

import jax
import jax.numpy as jnp
import optax
import tensorflow as tf

import collections
import functools


################ DATA AUGMENTATIONS ###################


@functools.partial(jax.vmap, in_axes=(0, 0, 0, None))
def _crop_with_indices(img, x, y, cropped_shape):
  cropped_image = (jax.lax.dynamic_slice(img, [x, y, 0], cropped_shape[1:]))
  return cropped_image


def _per_image_random_crop(key, img, cropped_shape):
  """Random crop an image."""
  batch_size, width, height = cropped_shape[:-1]
  key_x, key_y = jax.random.split(key, 2)
  x = jax.random.randint(
      key_x, shape=(batch_size,), minval=0, maxval=img.shape[1] - width)
  y = jax.random.randint(
      key_y, shape=(batch_size,), minval=0, maxval=img.shape[2] - height)
  return _crop_with_indices(img, x, y, cropped_shape)


def _intensity_aug(key, x, scale=0.05):
  """Follows the code in Schwarzer et al. (2020) for intensity augmentation."""
  r = jax.random.normal(key, shape=(x.shape[0], 1, 1, 1))
  noise = 1.0 + (scale * jnp.clip(r, -2.0, 2.0))
  return x * noise

@functools.partial(jax.pmap, axis_name='pmap', static_broadcasted_argnums=(2,))
def drq_image_augmentation(key, obs, img_pad):
  """DrQ augmentation: edge-pad, random-crop back to size, jitter intensity.

  Runs under pmap, so `key` and `obs` carry a leading device axis and
  `img_pad` is a static Python int.

  Args:
    key: per-device PRNG key.
    obs: observations whose last three axes form one image (spatial axes
      first, channels last); any leading axes are flattened and restored.
    img_pad: number of pixels padded on each side of both spatial axes.

  Returns:
    Augmented observations with the same shape as `obs`.
  """
  original_shape = obs.shape
  images = obs.reshape(-1, *original_shape[-3:])
  crop_shape = images.shape
  pad_spec = [(0, 0), (img_pad, img_pad), (img_pad, img_pad), (0, 0)]
  intensity_key, crop_key = jax.random.split(key, num=2)
  # The DrQ reference pads with PyTorch's ReplicationPad2d, which has no direct
  # JAX counterpart; 'edge' mode replicates the border pixels the same way.
  padded = jnp.pad(images, pad_spec, 'edge')
  cropped = _per_image_random_crop(crop_key, padded, crop_shape)
  jittered = _intensity_aug(intensity_key, cropped)
  return jittered.reshape(*original_shape)


def preprocess_inputs_with_augmentation(x, data_augmentation=False, rng=None):
  """Scale pixel inputs to [0, 1] and optionally apply DrQ augmentation.

  Args:
    x: raw observations (pixel values in [0, 255]). When augmenting, the
      leading axis must be the local device axis expected by the pmapped
      augmentation.
    data_augmentation: whether to apply pad/crop + intensity augmentation.
    rng: PRNG key; required when `data_augmentation` is True.

  Returns:
    float32 array with the same shape as `x`.

  Raises:
    ValueError: if augmentation is requested without an `rng`.
  """
  normalized = x.astype(jnp.float32) / 255.
  if not data_augmentation:
    return normalized
  if rng is None:
    raise ValueError('Pass rng when using data augmentation')
  # One key per local device, matching the pmapped augmentation.
  device_keys = jax.random.split(rng, len(jax.local_devices()))
  return drq_image_augmentation(device_keys, normalized, 4)
#######################################################


@functools.partial(jax.jit, static_argnums=(0, 4, 5, 6, 7, 8, 10, 11, 14,))
def select_action(network_def, params, state, rng, num_actions, eval_mode,
                  epsilon_eval, epsilon_train, epsilon_decay_period,
                  training_steps, min_replay_history, epsilon_fn,
                  game_index=None, game_valid_actions=None,
                  use_single_game_action_space=False):
  """Select an epsilon-greedy action from the set of available actions.

  With probability epsilon a random action is sampled; otherwise the action
  with the highest Q-value from `network_def.apply(...)` is taken. Epsilon is
  `epsilon_eval` in eval mode and `epsilon_fn(epsilon_decay_period,
  training_steps, min_replay_history, epsilon_train)` otherwise.

  Args:
    network_def: Linen Module to use for inference.
    params: Linen params (frozen dict) to use for inference.
    state: input state to use for inference.
    rng: Jax random number generator.
    num_actions: int, number of actions (static_argnum).
    eval_mode: bool, whether we are in eval mode (static_argnum).
    epsilon_eval: float, epsilon value to use in eval mode (static_argnum).
    epsilon_train: float, epsilon value to use in train mode (static_argnum).
    epsilon_decay_period: float, decay period for epsilon value for certain
      epsilon functions, such as linearly_decaying_epsilon, (static_argnum).
    training_steps: int, number of training steps so far.
    min_replay_history: int, minimum number of steps in replay buffer
      (static_argnum).
    epsilon_fn: function used to calculate epsilon value (static_argnum).
    game_index: optional game/task identifier, forwarded to
      `network_def.apply` as its third positional argument (not static).
    game_valid_actions: optional per-action array, only used when
      `game_index` is given: `1000 * game_valid_actions` is added to the
      Q-values, and its softmax is the sampling distribution for random
      actions. NOTE(review): presumably a valid-action mask for the current
      game; confirm its encoding (e.g. 1/0 vs. 0/-1) with the caller.
    use_single_game_action_space: bool (static_argnum). When True the network
      is called without `game_index`, and random actions are uniform over
      `num_actions`.

  Returns:
    rng: a new Jax key produced by the internal JaxRNG wrapper.
    action: int, the selected action.
  """
  epsilon = jnp.where(eval_mode,
                      epsilon_eval,
                      epsilon_fn(epsilon_decay_period,
                                 training_steps,
                                 min_replay_history,
                                 epsilon_train))

  # rng1: exploration coin flip; rng2: random-action draw; rng seeds the
  # wrapper that provides the dropout key and the returned key.
  rng, rng1, rng2 = jax.random.split(rng, num=3)
  rng_generator = jax_utils.JaxRNG(rng)
  p = jax.random.uniform(rng1)
  if game_index is not None and not use_single_game_action_space:
    # Under jax.jit these prints run only while tracing.
    print ('Using game index for action selection')
    q_out = network_def.apply(params, state, game_index, rngs={'dropout': rng_generator()}).q_values
    if game_valid_actions is not None:
      game_valid = game_valid_actions
      # Shift Q-values by the scaled mask so argmax favours entries with the
      # largest mask value; softmax of the same term concentrates random
      # actions on those entries.
      q_out = q_out + 1000 * game_valid
      valid_prob = jax.nn.softmax(1000 * game_valid)
    else:
      # Uniform over actions. NOTE(review): assumes q_out has shape
      # (num_actions,) so `p` matches `a` in jax.random.choice — confirm.
      valid_prob = jax.nn.softmax(jnp.ones_like(q_out))
    action = jnp.argmax(q_out)
  else:
    print ('Not using game index for action selection')
    action = jnp.argmax(network_def.apply(params, state, rngs={'dropout': rng_generator()}).q_values)
    # Uniform distribution over all num_actions actions.
    valid_prob = jax.nn.softmax(jnp.ones(shape=(num_actions,)))

  random_action = jax.random.choice(
      rng2, a=num_actions, shape=(), p=valid_prob)
  return rng_generator(), jnp.where(p <= epsilon, random_action, action)


@functools.partial(jax.pmap, axis_name='pmap',
                   static_broadcasted_argnums=(0, 8, 10))
def compute_trainability_metrics(network_def, online_params, target_params,
                                 states, actions, next_states, rewards,
                                 terminals, cumulative_gamma,
                                 next_actions, sarsa_backups,
                                 task_ids=None, rng=None):
  """Compute the metrics to indicate trainability of the learned model.

  Builds the regularized Gram matrix of the online representations,
  phi phi^T + 0.001 I (batch x batch). It solves this matrix against the
  Bellman target y computed with the target network and reports
  y^T (phi phi^T)^-1 y ("Difficulty"), both via least squares and via a
  direct solve. The same quantities scaled by the largest singular value of
  the Gram matrix are also reported. Every metric is averaged across devices.

  Args:
    network_def: Linen module, applied as `apply(params, state, task,
      rngs=...)` (static).
    online_params: per-device online network parameters.
    target_params: per-device target network parameters.
    states: per-device batch of states.
    actions: per-device batch of actions (not used).
    next_states: per-device batch of next states.
    rewards: per-device batch of rewards.
    terminals: per-device batch of terminal flags.
    cumulative_gamma: discount for the bootstrapped value (static).
    next_actions: next actions, used when `sarsa_backups` is True.
    sarsa_backups: bool, bootstrap from `next_actions` (static).
    task_ids: per-sample task identifiers forwarded to the network.
    rng: optional per-device PRNG keys for the networks' 'dropout' rng. When
      omitted a fixed key is used, so the metrics stay deterministic.

  Returns:
    Dict of scalar metrics averaged over the 'pmap' axis.
  """
  # target_q vmaps its network over (state, task, key) and draws the key from
  # an rng_generator, so a generator is needed even for pure evaluation.
  if rng is None:
    rng = jax.random.PRNGKey(0)
  rng_generator = jax_utils.JaxRNG(rng)

  def q_online(state, task, key):
    return network_def.apply(online_params, state, task, rngs={'dropout': key})

  model_output = jax.vmap(q_online, in_axes=(0, 0, None))(
      states, task_ids, rng_generator())
  representations = jnp.squeeze(model_output.representation)
  # Gram matrix of per-sample features (batch x batch), regularized so that it
  # is invertible.
  covariance_matrix = jnp.matmul(representations,
                                 jnp.transpose(representations, axes=(1, 0)))
  covariance_matrix = covariance_matrix + 0.001 * jnp.eye(
      covariance_matrix.shape[0])
  # Add stop gradient on covariance matrix
  covariance_matrix = jax.lax.stop_gradient(covariance_matrix)
  # Singular values are returned in descending order.
  max_singular_value = jnp.linalg.svd(covariance_matrix, compute_uv=False)[0]

  def q_target(state, task, key):
    return network_def.apply(target_params, state, task, rngs={'dropout': key})

  bellman_target = target_q(
      q_target, next_states, rewards, terminals, cumulative_gamma,
      next_actions=next_actions, sarsa_backups=sarsa_backups,
      task_ids=task_ids, without_stop_gradient=True,
      rng_generator=rng_generator)

  out = jnp.linalg.lstsq(covariance_matrix, bellman_target)[0]
  out_linalg = jnp.linalg.solve(covariance_matrix, bellman_target)

  final_out = jnp.sum(out * bellman_target)
  final_out_linalg = jnp.sum(out_linalg * bellman_target)

  normalized_out = final_out * max_singular_value
  normalized_out_linalg = final_out_linalg * max_singular_value

  ret_dict = dict()
  ret_dict['Difficulty'] = jax.lax.pmean(final_out, axis_name='pmap')
  ret_dict['Difficulty_Linalg'] = jax.lax.pmean(final_out_linalg,
                                                axis_name='pmap')
  ret_dict['Normalized_Difficulty'] = jax.lax.pmean(
      normalized_out, axis_name='pmap')
  ret_dict['Normalized_Difficulty_Linalg'] = jax.lax.pmean(
      normalized_out_linalg, axis_name='pmap')
  ret_dict['Max Singular value'] = jax.lax.pmean(
      max_singular_value, axis_name='pmap')

  return ret_dict


@functools.partial(jax.pmap, axis_name='pmap', static_broadcasted_argnums=(0,))
def compute_gradient_metrics(network_def, online_params, states, actions,
                             next_states, task_ids=None):
  """Compute per-sample gradient (Jacobian) statistics of the Q-function.

  For every sample, the gradient w.r.t. all parameters is flattened into one
  vector for three quantities:
  Q(s, a) for the replayed action, max_a Q(s, a), and max_a Q(s', a).

  Args:
    network_def: Linen module, applied as `apply(params, state, task)`
      (static).
    online_params: per-device network parameters.
    states: per-device batch of states.
    actions: per-device batch of replayed actions.
    next_states: per-device batch of next states.
    task_ids: per-sample task identifiers forwarded to the network.

  Returns:
    Dict, averaged over the 'pmap' axis, with:
      gradient_norm_s_pi: mean norm of grad max_a Q(s, a).
      gradient_norm_s_a: mean norm of grad Q(s, a_data).
      gradient_dot_product: mean of <grad Q(s, a_data), grad max_a Q(s', a)>.
  """

  def batched_q_values(inp_params, state, task):
    """Q-values for a batch of (state, task) pairs."""
    def single(s, t):
      return network_def.apply(inp_params, s, t)
    return jax.vmap(single)(state, task).q_values

  @jax.jacrev
  def greedy_q_jacobian(inp_params, state, task):
    """Per-sample Jacobian of max_a Q(state, a) w.r.t. the parameters."""
    return jnp.max(batched_q_values(inp_params, state, task), 1)

  @jax.jacrev
  def data_q_jacobian(inp_params, state, action, task):
    """Per-sample Jacobian of Q(state, action) w.r.t. the parameters."""
    q_vals = batched_q_values(inp_params, state, task)
    return jax.vmap(lambda x, y: x[y])(q_vals, action)

  def flatten_jacobian(jacobian):
    """Concatenate a per-sample Jacobian pytree into a (batch, n_params) array."""
    flat_leaves = jax.tree_util.tree_map(
        lambda x: x.reshape(x.shape[0], -1), jacobian)
    return jax.tree_util.tree_reduce(
        lambda x, y: jnp.concatenate((x, y), axis=-1), flat_leaves)

  jacobian_sa = flatten_jacobian(
      data_q_jacobian(online_params, states, actions, task_ids))
  jacobian_s_pi = flatten_jacobian(
      greedy_q_jacobian(online_params, states, task_ids))
  jacobian_ns_pi = flatten_jacobian(
      greedy_q_jacobian(online_params, next_states, task_ids))

  gradient_norm_s_pi = jnp.mean(jnp.linalg.norm(jacobian_s_pi, axis=-1))
  gradient_norm_sa = jnp.mean(jnp.linalg.norm(jacobian_sa, axis=-1))
  gradient_dot_products = jnp.mean(
      jnp.sum(jacobian_sa * jacobian_ns_pi, axis=-1))

  ret_dict = dict()
  ret_dict['gradient_norm_s_pi'] = jax.lax.pmean(
      gradient_norm_s_pi, axis_name='pmap')
  ret_dict['gradient_norm_s_a'] = jax.lax.pmean(
      gradient_norm_sa, axis_name='pmap')
  ret_dict['gradient_dot_product'] = jax.lax.pmean(
      gradient_dot_products, axis_name='pmap')
  return ret_dict


def compute_feature_norms(state_representations, next_state_representations,
                          prefix=''):
  """Summarize the L2 norms of current- and next-state features.

  Args:
    state_representations: features of the current states, one row per sample
      (norms are taken over the last axis).
    next_state_representations: features of the next states, same layout.
    prefix: string prepended to every key of the returned dict.

  Returns:
    Dict with the mean current-state norm, the mean next-state norm, their
    ratio next / (current + 1e-6) ("concentrability"), and their difference
    next - current.
  """
  def _mean_l2_norm(features):
    return jnp.mean(jnp.linalg.norm(features, axis=-1))

  current_mean = _mean_l2_norm(state_representations)
  next_mean = _mean_l2_norm(next_state_representations)
  return {
      prefix + '_current_feature_norm': current_mean,
      prefix + '_next_feature_norm': next_mean,
      prefix + '_concentrability': next_mean / (current_mean + 1e-6),
      prefix + '_concentrability_diff': next_mean - current_mean,
  }


def project_distribution(supports, weights, target_support):
  """Project scalar target value(s) onto a fixed, evenly spaced support.

  Each atom z_i receives mass clip(1 - |clip(s, z_0, z_last) - z_i| / dz, 0, 1),
  with dz the atom spacing. For a single scalar `s`, this puts all mass on the
  (at most two) atoms neighbouring `s`. Values outside the support are clipped
  to its ends first.

  Args:
    supports: value(s) to project; broadcast against `target_support`.
    weights: unused; the projected value carries no distribution of its own.
    target_support: 1-D array of evenly spaced atom locations.

  Returns:
    The projected mass per atom, with singleton dimensions squeezed.
  """
  del weights  # Not needed: the projected target is a point value.
  low = target_support[0]
  high = target_support[-1]
  spacing = (high - low) / (target_support.shape[0] - 1)
  distance = jnp.abs(jnp.clip(supports, low, high) - target_support)
  mass = jnp.clip(1 - (distance / spacing), 0, 1)
  return jnp.squeeze(mass)



@functools.partial(
    jax.pmap,
    axis_name='pmap',
    static_broadcasted_argnums=(0, 3, 10, 11, 12, 13, 15, 16, 17, 18, 21, 22, 25, 28))
def distributional_train(
    network_def, online_params, target_params, optimizer, optimizer_state,
    states, actions, next_states, rewards, terminals, cumulative_gamma,
    dr3_coefficient, cql_coefficient, sarsa_backups, next_actions,
    difficulty_loss_coefficient, adv_loss_version=0,
    use_spectral_normalization=False, use_parameter_version=False,
    task_ids=None, valid_game_actions=None,
    use_game_action_space=False,
    use_single_game_action_space=False, loss_weights=None,
    reward_scaling=None, use_bc_loss_only=False,
    support=None, rng=None,
    xent_but_not_distributional=False):
  """Run one pmapped training step for the distributional agent.

  First computes a projected target distribution with
  `distributional_target_q`. Then takes one optimizer step on the online
  parameters, minimizing the softmax cross-entropy TD loss plus optional DR3
  and CQL terms (or only the CQL term when `use_bc_loss_only`). Loss,
  gradients and statistics are averaged across devices with `pmean`.

  The following arguments are accepted but not read in this function,
  presumably for signature parity with `train`:
  difficulty_loss_coefficient, adv_loss_version, use_spectral_normalization,
  use_parameter_version, use_game_action_space and reward_scaling.

  Returns:
    Tuple of (optimizer_state, online_params, loss averaged across devices,
    cross-device variance of the loss (returned as loss_std), dict of pmeaned
    statistics, a new key from the internal JaxRNG wrapper).
  """
  # Keys below are drawn from this wrapper in call order: target network
  # first, then the online passes on states and next states.
  rng_generator = jax_utils.JaxRNG(rng)

  def loss_fn(params, bellman_target, argmax_action):
    # argmax_action is not used inside the loss.
    ret_dict = dict()

    def q_online(state, task, rng):
      return network_def.apply(params, state, task, support=support, rngs={'dropout': rng})

    # One dropout key is shared by the whole vmapped batch (in_axes None).
    model_output = jax.vmap(q_online, in_axes=(0, 0, None))(states, task_ids, rng_generator())
    representations = jnp.squeeze(model_output.representation)
    logits = jnp.squeeze(model_output.logits)
    chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions)
    chosen_action_q = jax.vmap(lambda x, y: x[y])(
        jnp.squeeze(model_output.q_values), actions)

    # Cross-entropy between the projected target distribution and the online
    # logits of the replayed actions.
    bellman_loss = jnp.mean(
        jax.vmap(losses.softmax_cross_entropy_loss_with_logits)(
          bellman_target, chosen_action_logits))

    # Second online pass on next states, used for DR3 and feature statistics.
    next_states_model_output = jax.vmap(q_online, in_axes=(0, 0, None))(
        next_states, task_ids, rng_generator())
    next_state_representations = jnp.squeeze(
        next_states_model_output.representation)
    dr3_loss = compute_dr3_loss(representations, next_state_representations)
    norm_statistics_features = compute_feature_norms(
        representations, next_state_representations,
        prefix='features')

    # only image representations
    state_image_representation = jnp.squeeze(
        model_output.representation_before_task)
    next_state_image_representation = jnp.squeeze(
        next_states_model_output.representation_before_task)
    dr3_representation = compute_dr3_loss(
        state_image_representation, next_state_image_representation)
    norm_statistics_image = compute_feature_norms(
        state_image_representation, next_state_image_representation,
        prefix='image')

    # CQL loss, computed on the expected Q-values (model_output.q_values):
    # logsumexp over actions minus the Q-value of the replayed action.
    q_values = jnp.squeeze(model_output.q_values)
    cql_loss = jnp.mean(
        jax.scipy.special.logsumexp(
            q_values, axis=-1)) - jnp.mean(chosen_action_q)
    orig_cql_loss = cql_loss

    # Placeholder; no adversarial term is computed here.
    adversarial_loss = 0.0
    adv_statistics = {}

    if use_bc_loss_only:
      loss = cql_loss
    else:
      loss = (
          bellman_loss + dr3_coefficient * dr3_loss + cql_coefficient * cql_loss
          + adversarial_loss
      )

    # This works out since one call comes from one game at a time only
    loss = loss * jnp.mean(loss_weights)

    q_vals_data = jnp.mean(chosen_action_q)
    q_vals_pi = jnp.mean(jnp.max(q_values, axis=-1))

    ret_dict = dict()

    # Return a dictionary so it is easy to add new validation statistics
    ret_dict['dr3_loss_features'] = dr3_loss
    ret_dict['td_loss'] = bellman_loss
    ret_dict['cql_loss'] = cql_loss
    ret_dict['q_vals_data'] = q_vals_data
    ret_dict['q_vals_pi'] = q_vals_pi
    ret_dict['adversarial_loss'] = adversarial_loss
    ret_dict['dr3_loss_image_rep'] = dr3_representation
    ret_dict.update(norm_statistics_features)
    ret_dict.update(norm_statistics_image)
    ret_dict.update(adv_statistics)
    return jnp.mean(loss), ret_dict

  # NOTE(review): unlike q_online, this does not pass `support=support`;
  # confirm the network's default support matches `support`.
  def q_target(state, task, rng):
    return network_def.apply(target_params, state, task, rngs={'dropout': rng})

  # Doesn't need reward scaling since this is C51 and it would use logits
  # rewards = rewards / reward_scaling

  # Compute actual losses
  bellman_target, argmax_action = distributional_target_q(
      q_target, next_states, rewards,
      terminals, cumulative_gamma,
      sarsa_backups=sarsa_backups,
      next_actions=next_actions,
      td_backups=False,
      policy_probs=None,
      ret_argmax_action=True,
      task_ids=task_ids,
      valid_game_actions=valid_game_actions,
      use_single_game_action_space=use_single_game_action_space,
      support=support,
      rng_generator=rng_generator,
      xent_but_not_distributional=xent_but_not_distributional)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, component_losses), grad = grad_fn(
      online_params, bellman_target, argmax_action)

  # pmean the losses and gradients here
  ret_loss = jax.lax.pmean(loss, axis_name='pmap')
  # NOTE: despite the name, this is the cross-device variance (no sqrt).
  loss_std = jax.lax.pmean((loss - ret_loss) ** 2, axis_name='pmap')
  grad = jax.lax.pmean(grad, axis_name='pmap')
  component_losses = jax.lax.pmean(component_losses, axis_name='pmap')

  updates, optimizer_state = optimizer.update(grad, optimizer_state,
                                              params=online_params)
  online_params = optax.apply_updates(online_params, updates)
  return (optimizer_state, online_params, ret_loss, loss_std,
          component_losses, rng_generator())


def distributional_target_q(
    q_target, next_states, rewards, terminals, cumulative_gamma,
    sarsa_backups=False, next_actions=None, td_backups=False,
    policy_probs=None, ret_argmax_action=False, task_ids=None,
    valid_game_actions=None, use_single_game_action_space=False,
    support=None, rng_generator=None,
    xent_but_not_distributional=False):
  """Compute the projected target distribution for distributional TD learning.

  Args:
    q_target: callable(state, task, key) returning a network output with
      `q_values` and `probabilities`; it is vmapped over the batch.
    next_states: batch of next states.
    rewards: batch of rewards.
    terminals: batch of terminal flags.
    cumulative_gamma: discount for the bootstrapped value.
    sarsa_backups: bootstrap from `next_actions` instead of the greedy target
      action.
    next_actions: actions used for SARSA backups.
    td_backups: expected-value backups are not supported for distributional
      targets; passing True raises NotImplementedError.
    policy_probs: unused (would only be meaningful with `td_backups`).
    ret_argmax_action: ignored; the bootstrap action is always returned.
    task_ids: per-sample task identifiers forwarded to `q_target`.
    valid_game_actions: optional action mask; combined with `task_ids` and
      added as `1000 * mask` to the target Q-values before picking the greedy
      action. NOTE(review): the sum over axis 1 suggests shapes
      (num_games, num_actions) and task_ids (batch, num_games) — confirm.
    use_single_game_action_space: disables the valid-action masking.
    support: 1-D array of atom locations.
    rng_generator: callable returning PRNG keys for the target network's
      'dropout' rng; a fixed key is used when None.
    xent_but_not_distributional: if True, project the scalar target
      r + gamma * Q(s', a') onto `support` with `project_distribution` instead
      of projecting the shifted target distribution.

  Returns:
    (target, argmax_action): stop-gradient target probabilities over
    `support`, and the action used for bootstrapping.

  Raises:
    NotImplementedError: if `td_backups` is True.
  """
  if td_backups:
    raise NotImplementedError(
        'td_backups is not supported by distributional_target_q.')

  is_terminal_multiplier = 1. - terminals.astype(jnp.float32)
  # Incorporate terminal state to discount factor.
  gamma_with_terminal = cumulative_gamma * is_terminal_multiplier

  target_key = (rng_generator() if rng_generator is not None
                else jax.random.PRNGKey(0))
  target_network_dist = jax.vmap(q_target, in_axes=(0, 0, None))(
      next_states, task_ids, target_key)

  # Compute the set of target q values
  target_q_values = jnp.squeeze(target_network_dist.q_values)

  # Action used to pick the bootstrapped next-state distribution.
  if sarsa_backups:
    argmax_action = next_actions
  elif valid_game_actions is not None and not use_single_game_action_space:
    print ('Valid game actions / task_ids', valid_game_actions.shape, task_ids.shape)
    valid_game_actions_temp = jnp.sum(
        jnp.expand_dims(
            valid_game_actions, 0) * jnp.expand_dims(task_ids, -1),
        1)
    argmax_action = jnp.argmax(
        target_q_values + 1000 * valid_game_actions_temp, 1)
  else:
    argmax_action = jnp.argmax(target_q_values, 1)

  probabilities = jnp.squeeze(target_network_dist.probabilities)
  next_probabilities = jax.vmap(lambda x, y: x[y])(
      probabilities, argmax_action)

  if xent_but_not_distributional:
    target_values_to_use = jnp.squeeze(
        jax.vmap(lambda x, y: x[y])(target_q_values, argmax_action))
    target_support = rewards[:, None] + gamma_with_terminal[:, None] * target_values_to_use[:, None]
    print ('Using cross entropy but not distributional')
    target = jax.vmap(project_distribution, in_axes=(0, 0, None))(
        target_support, next_probabilities, support)
  else:
    target_support = rewards[:, None] + gamma_with_terminal[:, None] * support[None, :]
    print ('Using distributional....')
    target = jax.vmap(rainbow_agent.project_distribution, in_axes=(0, 0, None))(
        target_support, next_probabilities, support)

  return jax.lax.stop_gradient(target), argmax_action



@functools.partial(
    jax.pmap,
    axis_name='pmap',
    static_broadcasted_argnums=(0, 3, 10, 11, 12, 13, 15, 16, 17, 18, 21, 22, 25))
def train(network_def, online_params, target_params, optimizer, optimizer_state,
          states, actions, next_states, rewards, terminals, cumulative_gamma,
          dr3_coefficient, cql_coefficient, sarsa_backups, next_actions,
          difficulty_loss_coefficient, adv_loss_version=0,
          use_spectral_normalization=False, use_parameter_version=False,
          task_ids=None, valid_game_actions=None,
          use_game_action_space=False,
          use_single_game_action_space=False, loss_weights=None,
          reward_scaling=None, use_bc_loss_only=False, rng=None):
  """Run one pmapped Q-learning step with optional CQL / DR3 regularizers.

  Computes TD targets with `target_q` (rewards divided by `reward_scaling`).
  Takes one optimizer step on a Huber TD loss plus DR3 and CQL terms (only
  the CQL term when `use_bc_loss_only`). Averages loss, gradients and
  statistics across devices with `pmean`.

  Args (non-obvious ones):
    use_spectral_normalization: static; if True, the network is applied with a
      mutable 'spectral_stats' collection. The statistics produced for the
      last batch element are written back into `online_params`.
    task_ids: per-sample task identifiers forwarded to the network.
    valid_game_actions: optional action mask forwarded to `target_q`.
    loss_weights: the total loss is multiplied by their mean.
    use_bc_loss_only: static; optimize only the CQL term.
    rng: per-device PRNG key for the networks' 'dropout' rng.

  difficulty_loss_coefficient, adv_loss_version, use_parameter_version and
  use_game_action_space are accepted but not read here.

  Returns:
    Tuple of (optimizer_state, online_params, loss averaged across devices,
    cross-device variance of the loss (returned as loss_std), dict of pmeaned
    statistics, a new key from the internal JaxRNG wrapper).
  """
  rng_generator = jax_utils.JaxRNG(rng)

  def loss_fn(params, bellman_target, argmax_action):
    """Total loss and statistics; `argmax_action` is not used."""

    def q_online(state, task, rng):
      if use_spectral_normalization:
        # With a mutable collection, apply returns (outputs, new_variables).
        return network_def.apply(params, state, task,
                                 mutable=['spectral_stats'],
                                 training=True,
                                 rngs={'dropout': rng})
      return network_def.apply(params, state, task, rngs={'dropout': rng})

    batched_q_online = jax.vmap(q_online, in_axes=(0, 0, None))

    # Both branches draw one dropout key for the current states and then one
    # for the next states, matching q_online's (state, task, rng) signature.
    if use_spectral_normalization:
      model_output, new_model_spectral_state = batched_q_online(
          states, task_ids, rng_generator())
      next_states_model_output, _ = batched_q_online(
          next_states, task_ids, rng_generator())
    else:
      model_output = batched_q_online(states, task_ids, rng_generator())
      next_states_model_output = batched_q_online(
          next_states, task_ids, rng_generator())
      new_model_spectral_state = {}

    # Feature-based DR3 on the final representation.
    representations = jnp.squeeze(model_output.representation)
    next_state_representations = jnp.squeeze(
        next_states_model_output.representation)
    dr3_loss = compute_dr3_loss(representations, next_state_representations)
    norm_statistics_features = compute_feature_norms(
        representations, next_state_representations,
        prefix='features')

    # Same statistics on the image-only representation.
    state_image_representation = jnp.squeeze(
        model_output.representation_before_task)
    next_state_image_representation = jnp.squeeze(
        next_states_model_output.representation_before_task)
    dr3_representation = compute_dr3_loss(
        state_image_representation, next_state_image_representation)
    norm_statistics_image = compute_feature_norms(
        state_image_representation, next_state_image_representation,
        prefix='image')

    # Q-learning (Huber) loss on the replayed actions.
    q_values = jnp.squeeze(model_output.q_values)
    replay_chosen_q = jax.vmap(lambda x, y: x[y])(q_values, actions)
    bellman_loss = jnp.mean(
        jax.vmap(losses.huber_loss)(bellman_target, replay_chosen_q))

    # CQL regularizer: logsumexp over actions minus Q of the replayed action.
    cql_loss = jnp.mean(
        jax.scipy.special.logsumexp(
            q_values, axis=-1)) - jnp.mean(replay_chosen_q)

    # Placeholder; no adversarial term is computed here.
    adversarial_loss = 0.0

    if use_bc_loss_only:
      loss = cql_loss
    else:
      loss = (
          bellman_loss + dr3_coefficient * dr3_loss + cql_coefficient * cql_loss
          + adversarial_loss
      )

    # This works out since one call comes from one game at a time only
    loss = loss * jnp.mean(loss_weights)

    # Return a dictionary so it is easy to add new validation statistics
    ret_dict = {
        'dr3_loss_features': dr3_loss,
        'td_loss': bellman_loss,
        'cql_loss': cql_loss,
        'q_vals_data': jnp.mean(replay_chosen_q),
        'q_vals_pi': jnp.mean(jnp.max(q_values, axis=1)),
        'adversarial_loss': adversarial_loss,
        'dr3_loss_image_rep': dr3_representation,
    }
    ret_dict.update(norm_statistics_features)
    ret_dict.update(norm_statistics_image)
    return jnp.mean(loss), (ret_dict, new_model_spectral_state)

  def q_target(state, task, rng):
    if use_spectral_normalization:
      return network_def.apply(target_params, state, task,
                               mutable=['spectral_stats'], training=False,
                               rngs={'dropout': rng})[0]
    else:
      return network_def.apply(target_params, state, task,
                               rngs={'dropout': rng})

  # Scale rewards here
  rewards = rewards / reward_scaling

  # Compute actual losses
  bellman_target, argmax_action = target_q(q_target, next_states, rewards,
                            terminals, cumulative_gamma,
                            sarsa_backups=sarsa_backups,
                            next_actions=next_actions,
                            td_backups=False,
                            policy_probs=None,
                            ret_argmax_action=True,
                            task_ids=task_ids,
                            valid_game_actions=valid_game_actions,
                            use_single_game_action_space=use_single_game_action_space,
                            rng_generator=rng_generator)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (component_losses, spectral_state)), grad = grad_fn(
      online_params, bellman_target, argmax_action)

  # pmean the losses and gradients here
  ret_loss = jax.lax.pmean(loss, axis_name='pmap')
  # NOTE: despite the name, this is the cross-device variance (no sqrt).
  loss_std = jax.lax.pmean((loss - ret_loss) ** 2, axis_name='pmap')
  grad = jax.lax.pmean(grad, axis_name='pmap')
  component_losses = jax.lax.pmean(component_losses, axis_name='pmap')
  spectral_state = jax.lax.pmean(spectral_state, axis_name='pmap')

  updates, optimizer_state = optimizer.update(grad, optimizer_state,
                                              params=online_params)
  online_params = optax.apply_updates(online_params, updates)
  if use_spectral_normalization:
    # Keep the spectral statistics produced for the last batch element.
    new_spectral_state = jax.tree_util.tree_map(
        lambda x: x[-1, ...], spectral_state)
    # Rebuild the params mapping with its own container type (FrozenDict or
    # dict), so the pytree structure matches the optimizer state and this
    # module needs no flax import.
    online_params = type(online_params)({
        **online_params,
        'spectral_stats': new_spectral_state['spectral_stats'],
    })
  return (optimizer_state, online_params, ret_loss, loss_std,
          component_losses, rng_generator())


def target_q(target_network, next_states, rewards, terminals, cumulative_gamma,
             task_ids=None,
             sarsa_backups=None, next_actions=None,
             td_backups=False, policy_probs=None,
             ret_argmax_action=False,
             without_stop_gradient=False,
             valid_game_actions=None,
             use_single_game_action_space=False,
             rng_generator=None):
  """Compute the one-step bootstrapped target r + gamma * V(s') * (1 - done).

  V(s') is one of:
    - Q(s', a') for the logged `next_actions` when `sarsa_backups`;
    - the expectation of Q(s', .) under `policy_probs` when `td_backups`;
    - max_a Q(s', a), with `1000 * mask` added to the Q-values first when
      `valid_game_actions` is given and `use_single_game_action_space` is
      False.

  Args:
    target_network: callable(state, task, key) returning an output with
      `q_values`; vmapped over the batch with a shared key.
    next_states: batch of next states.
    rewards: batch of rewards.
    terminals: batch of terminal flags.
    cumulative_gamma: discount applied to the bootstrapped value.
    task_ids: per-sample task identifiers forwarded to `target_network`.
    sarsa_backups: use SARSA backups (see above).
    next_actions: logged next actions for SARSA backups.
    td_backups: use expected backups under `policy_probs`.
    policy_probs: per-sample action probabilities for `td_backups`.
    ret_argmax_action: also return the bootstrap action.
    without_stop_gradient: skip the stop_gradient on the target.
    valid_game_actions: optional action mask combined with `task_ids`.
      NOTE(review): the mask is also part of the bootstrapped value, so this
      is only unbiased if valid actions have mask value 0 — confirm encoding.
    use_single_game_action_space: disables the valid-action masking.
    rng_generator: callable returning PRNG keys for the target network's
      'dropout' rng; a fixed key is used when None.

  Returns:
    The target values, or (target values, bootstrap action) when
    `ret_argmax_action` is True.

  Raises:
    ValueError: if `ret_argmax_action` is combined with `td_backups`, which
      does not bootstrap from a single action.
  """
  if td_backups and ret_argmax_action:
    raise ValueError(
        'ret_argmax_action is not supported with td_backups: expected backups '
        'do not bootstrap from a single action.')

  target_key = (rng_generator() if rng_generator is not None
                else jax.random.PRNGKey(0))
  q_vals = jax.vmap(target_network, in_axes=(0, 0, None))(
      next_states, task_ids, target_key).q_values
  q_vals = jnp.squeeze(q_vals)

  argmax_action = None
  if sarsa_backups:
    replay_next_qt_max = jax.vmap(lambda x, y: x[y])(q_vals, next_actions)
    argmax_action = next_actions
  elif td_backups:
    replay_next_qt_max = jnp.sum(q_vals * policy_probs, axis=-1)
  else:
    if valid_game_actions is not None and not use_single_game_action_space:
      print ('Valid game actions / task_ids', valid_game_actions.shape, task_ids.shape)
      valid_game_actions_temp = jnp.sum(
          jnp.expand_dims(
              valid_game_actions, 0) * jnp.expand_dims(task_ids, -1),
          1)
      selection_q = q_vals + 1000 * valid_game_actions_temp
    else:
      selection_q = q_vals
    replay_next_qt_max = jnp.max(selection_q, 1)
    argmax_action = jnp.argmax(selection_q, 1)

  target_val = rewards + cumulative_gamma * replay_next_qt_max * (1. - terminals)
  if not without_stop_gradient:
    target_val = jax.lax.stop_gradient(target_val)

  if ret_argmax_action:
    return target_val, argmax_action
  return target_val


def compute_dr3_loss(state_representations, next_state_representations):
  """DR3 regularizer: mean absolute inner product of consecutive features.

  Args:
    state_representations: 2-D array, one feature row per sample for s.
    next_state_representations: 2-D array of the same shape for s'.

  Returns:
    Scalar mean over the batch of |<phi(s), phi(s')>|; adding it to the loss
    penalizes large feature co-adaptation between s and s'.
  """
  per_sample_dot = jnp.einsum(
      'ij,ij->i', state_representations, next_state_representations)
  magnitude = jnp.abs(per_sample_dot)
  return jnp.mean(magnitude)


@gin.configurable
class OfflineMultiTaskJaxCQLAgent(multi_task_offline_dqn_agent.OfflineMultiTaskJaxDQNAgent):
  """A JAX implementation of an offline multi-task DQN agent with CQL.

  On top of `OfflineMultiTaskJaxDQNAgent` this agent adds:

  * optional data augmentation of sampled states;
  * CQL / DR3 / difficulty-loss coefficients and SARSA / behavioral-cloning
    switches, forwarded to the module-level `train` function;
  * an optional distributional variant trained with `distributional_train`
    over a fixed support of `num_atoms` evenly spaced values in
    `[vmin, vmax]`.
  """

  # The trainability-metrics summaries in `_train_step` are disabled. Set this
  # to True (e.g. in a subclass) to re-enable them.
  # NOTE(review): that path is not exercised by default; confirm the
  # `compute_trainability_metrics` signature before turning it on.
  _log_trainability_metrics = False

  def __init__(self,
               num_actions,
               replay_data_dir,
               dr3_coefficient=0.0,
               cql_coefficient=0.0,
               data_augmentation=False,
               difficulty_loss_coefficient=0.0,
               adv_loss_feature_version=0,
               use_sarsa_backups=False,
               summary_writer=None,
               replay_buffer_builder=None,
               preprocess_fn=None,
               network=None,
               game_names=('Asterix',),
               num_devices=1,
               use_game_action_space=False,
               use_tfds_data=False,
               use_single_game_action_space=False,
               network_type='IMPALA',
               override_num_games=-1,
               override_game_index=-1,
               with_task_ids=False,
               use_bc_loss_only=False,
               use_distributional=False,
               num_atoms=51,
               vmax=20, vmin=None,
               use_xent_not_distributional=False):
    """Initializes the agent and constructs the necessary components.

    Args:
      num_actions: int, number of actions the agent can take at any state.
      replay_data_dir: str, log directory from which to load the replay buffer.
      dr3_coefficient: float, DR3 coefficient forwarded to the training
        function.
      cql_coefficient: float, CQL coefficient forwarded to the training
        function.
      data_augmentation: bool, passed to `preprocess_inputs_with_augmentation`
        through `train_preprocess_fn` for sampled training states.
      difficulty_loss_coefficient: float, forwarded to the training function.
      adv_loss_feature_version: int, forwarded to the training function.
      use_sarsa_backups: bool, forwarded to the training function together
        with the sampled next actions.
      summary_writer: SummaryWriter object for outputting training statistics.
      replay_buffer_builder: Callable object that takes "self" as an argument
        and returns a replay buffer to use for training offline. If None,
        `_build_replay_buffer` is used.
      preprocess_fn: Ignored. `preprocess_inputs_with_augmentation` is always
        passed to the parent class.
      network: Ignored. The network class is looked up from `network_type`.
      game_names: tuple of str, names of the games to train on.
      num_devices: int, number of devices. Used to tile per-device constants
        and forwarded to the parent class and the replay buffer.
      use_game_action_space: bool, forwarded to the training function.
      use_tfds_data: bool, whether to use the TFDS-backed replay buffer. Its
        iterator is the only sampling path `_sample_from_replay_buffer`
        supports.
      use_single_game_action_space: Ignored. False is always used.
      network_type: str, key into the network table below. `'_distributional'`
        is appended when `use_distributional` is True.
      override_num_games: int, forwarded to the parent class.
      override_game_index: int, forwarded to the parent class.
      with_task_ids: bool, forwarded to the parent class.
      use_bc_loss_only: bool, forwarded to the training function.
      use_distributional: bool, whether to train with `distributional_train`
        and a distributional network variant.
      num_atoms: int, number of support atoms (distributional only).
      vmax: float, upper end of the value support (distributional only).
      vmin: float or None, lower end of the value support. Defaults to
        `-vmax` (distributional only).
      use_xent_not_distributional: bool, forwarded to
        `distributional_train`.

    Raises:
      ValueError: if the (possibly suffixed) `network_type` is unknown.
    """
    logging.info('Creating %s agent with the following parameters:',
                 self.__class__.__name__)
    logging.info('\t replay directory: %s', replay_data_dir)
    logging.info('\t CQL coefficient: %s', cql_coefficient)
    logging.info('\t dr3_coefficient: %s', dr3_coefficient)
    logging.info('\t SARSA backups: %s', use_sarsa_backups)
    logging.info('\t Data augmentation: %s', data_augmentation)
    logging.info('\t Adversarial loss feature version: %s',
                 adv_loss_feature_version)
    logging.info('\t Difficulty Loss coefficient: %s',
                 difficulty_loss_coefficient)
    logging.info('\t Use tfds data: %s', use_tfds_data)
    logging.info('\t Use BC loss only: %s', use_bc_loss_only)
    logging.info('\t Use distributional losses: %s', use_distributional)
    logging.info('\t Use distributional, not xent: %s', use_xent_not_distributional)

    self.replay_data_dir = replay_data_dir
    print('Replay data dir for the CQL agent: ', self.replay_data_dir)
    if replay_buffer_builder is not None:
      # A callable stored on the instance is not bound as a method, so supply
      # `self` explicitly to honour the documented builder signature.
      self._build_replay_buffer = functools.partial(replay_buffer_builder, self)

    self._data_augmentation = data_augmentation
    self._difficulty_loss_coefficient = difficulty_loss_coefficient
    self._adv_loss_version = adv_loss_feature_version
    self._dr3_coefficient = dr3_coefficient
    self._cql_coefficient = cql_coefficient
    self._sarsa_backups = use_sarsa_backups
    self._use_game_action_space = use_game_action_space
    self._use_tfds_data = use_tfds_data
    self._use_bc_loss_only = use_bc_loss_only
    self._use_distributional = use_distributional
    self._xent_not_distributional = use_xent_not_distributional

    self.train_preprocess_fn = functools.partial(
        preprocess_inputs_with_augmentation,
        data_augmentation=data_augmentation)

    self._network_type = network_type

    if self._use_distributional:
      self._num_atoms = num_atoms
      self._vmin = vmin if vmin is not None else -vmax
      self._vmax = vmax
      self._support = jnp.linspace(self._vmin, self._vmax, num_atoms)
      print('Using distributional agent......')
      print('Num atoms: ', self._num_atoms, self._vmin, self._vmax)
      print('Xent but not distributional', self._xent_not_distributional)

      self._network_type = network_type + '_distributional'

    network_table = {
        'IMPALA': networks.ImpalaNetworkWithRepresentations,
        'MultiHeadIMPALA': networks.MultiHeadIMPALA,
        'ResNet18': large_networks.ResNet18,
        'ResNet34': large_networks.ResNet34,
        'ResNet50': large_networks.ResNet50,
        'ResNet101': large_networks.ResNet101,
        'ResNet152': large_networks.ResNet152,
        'ViTSmall': large_networks.ViTSmall,
        'ViTBase': large_networks.ViTBase,
        'ViTLarge': large_networks.ViTLarge,
        'ViTHuge': large_networks.ViTHuge,
        'small': networks.JAXDQNNetworkWithRepresentations,
        'MultiHeadResNet18': large_networks.MultiHeadResNet18,
        'MultiHeadResNet34': large_networks.MultiHeadResNet34,
        'MultiHeadResNet50': large_networks.MultiHeadResNet50,
        'MultiHeadResNet101': large_networks.MultiHeadResNet101,
        'MultiHeadResNet152': large_networks.MultiHeadResNet152,
        'MultiHeadResNet18Aux': large_networks.MultiHeadResNet18Aux,
        'MultiHeadResNet34Aux': large_networks.MultiHeadResNet34Aux,
        'MultiHeadResNet50Aux': large_networks.MultiHeadResNet50Aux,
        'MultiHeadResNet18Att': large_networks.MultiHeadResNet18Att,
        'MultiHeadResNet34Att': large_networks.MultiHeadResNet34Att,
        'MultiHeadResNet50Att': large_networks.MultiHeadResNet50Att,
        'MultiHeadResNet101Att': large_networks.MultiHeadResNet101Att,
        'MultiHeadResNet152Att': large_networks.MultiHeadResNet152Att,
        'MultiHeadResNet18Single': large_networks.MultiHeadResNet18Single,
        'MultiHeadResNet34Single': large_networks.MultiHeadResNet34Single,
        'MultiHeadResNet50Single': large_networks.MultiHeadResNet50Single,
        'MultiHeadResNet101Single': large_networks.MultiHeadResNet101Single,
        'MultiHeadResNet152Single': large_networks.MultiHeadResNet152Single,
        'MultiHeadResNet34SingleWide': large_networks.MultiHeadResNet34SingleWide,
        'MultiHeadResNet50SingleWide': large_networks.MultiHeadResNet50SingleWide,

        # ViT ones
        'MultiHeadViTSmall': large_networks.MultiHeadViTSmall,
        'MultiHeadViTBase': large_networks.MultiHeadViTBase,
        'MultiHeadViTSmallSingle': large_networks.MultiHeadViTSmallSingle,
        'MultiHeadViTBaseSingle': large_networks.MultiHeadViTBaseSingle,
        'MultiHeadViTLargeSingle': large_networks.MultiHeadViTLargeSingle,

        # Discretized networks
        'MultiHeadResNet18Single_distributional': large_networks.DistributionalMultiHeadResNet18Single,
        'MultiHeadResNet34Single_distributional': large_networks.DistributionalMultiHeadResNet34Single,
        'MultiHeadResNet50Single_distributional': large_networks.DistributionalMultiHeadResNet50Single,
        'MultiHeadResNet101Single_distributional': large_networks.DistributionalMultiHeadResNet101Single,
        'MultiHeadViTSmallSingle_distributional': large_networks.DistributionalMultiHeadViTSmallSingle,
        'MultiHeadViTBaseSingle_distributional': large_networks.DistributionalMultiHeadViTBaseSingle,
        'MultiHeadViTLargeSingle_distributional': large_networks.DistributionalMultiHeadViTLargeSingle,

        'MultiHeadResNet101SingleWide_distributional': large_networks.DistributionalMultiHeadResNet101SingleWide,
        'MultiHeadResNet101SingleWider_distributional': large_networks.DistributionalMultiHeadResNet101SingleWider,

        'MultiHeadResNet152Single_distributional': large_networks.DistributionalMultiHeadResNet152Single,

        # Old multi-head resnets
        'OldMultiHeadResNet34Single': large_networks.OldMultiHeadResNet34Single,
        'OldMultiHeadResNet50Single': large_networks.OldMultiHeadResNet50Single,
    }
    # Keep the `try` narrow: only the lookup itself can raise KeyError.
    try:
      network = network_table[self._network_type]
    except KeyError:
      raise ValueError(
          'Network type not found! Got: %r' % (self._network_type,)) from None

    self._game_names = game_names
    self._num_devices = num_devices
    print('Multi task offline CQL: ', self._game_names, self._num_devices)
    super().__init__(
        num_actions, summary_writer=summary_writer,
        # Don't pass none here since, then the network would try to preprocess
        # inputs on its own which will interfere with data augmentation
        preprocess_fn=preprocess_inputs_with_augmentation,
        replay_data_dir=replay_data_dir,
        game_names=self._game_names,
        num_devices=self._num_devices,
        network=network,
        use_single_game_action_space=False,
        override_game_index=override_game_index,
        override_num_games=override_num_games,
        with_task_ids=with_task_ids)

    # Per-device copy of log(game_valid_actions), sharded across devices for
    # use inside the pmapped training function.
    valid_actions_np = onp.tile(
        onp.array(onp.log(
            self.game_valid_actions))[None], (self._num_devices, 1, 1))
    self._valid_actions_jnp = self.put_device_sharded(tensors=(
        valid_actions_np,))[0]

    if self._use_distributional:
      # Replicate the value support once per device, then shard it.
      tile_support = jnp.tile(self._support[None, :], (self._num_devices, 1))
      self._support = self.put_device_sharded(tensors=(tile_support,))[0]

  def _build_replay_buffer(self):
    """Creates the fixed replay buffer used by the agent.

    With `use_tfds_data` this builds the TFDS-backed buffer and also stores
    its JAX-parallel iterator in `self._replay_iterator`. Otherwise it builds
    a `JaxMultiTaskFixedReplayBuffer` reading from `self.replay_data_dir`.
    """
    if self._use_tfds_data:
      replay_buffer = multi_task_tfds.JaxMultiTaskFixedReplayBufferTFDS(
          observation_shape=self.observation_shape,
          stack_size=self.stack_size,
          update_horizon=self.update_horizon,
          gamma=self.gamma,
          observation_dtype=self.observation_dtype,
          game_names=self._game_names,
          num_devices=self._num_devices,
          use_single_game_action_space=False,
          num_games=len(self._game_names),
      )
      self._replay_iterator = replay_buffer.get_iterator_with_jax_parallel()
      return replay_buffer
    return multi_game_fixed_replay.JaxMultiTaskFixedReplayBuffer(
        data_dir=self.replay_data_dir,
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype,
        game_names=self._game_names,
        num_devices=self._num_devices,
        use_single_game_action_space=False)

  def reload_data(self):
    """Delegates to the replay buffer's `reload_data`."""
    self._replay.reload_data()

  def _sample_from_replay_buffer(self):
    """Stores the next batch from the TFDS iterator in `replay_elements`.

    Raises:
      RuntimeError: if the agent was not configured with `use_tfds_data`;
        sampling from the non-TFDS buffer is not supported by this agent.
    """
    if not self._use_tfds_data:
      raise RuntimeError("Using old replay buffer.... unexpected.")
    self.replay_elements = next(self._replay_iterator)

  def append_to_summary_list(self, ret_dict, step, prefix=''):
    """Writes every entry of `ret_dict` as a TF scalar summary.

    Args:
      ret_dict: mapping from summary name to value. With more than one device
        only `value[0]` is logged for each entry.
      step: int, the summary step.
      prefix: str, each summary is named `prefix + '/' + key`.
    """
    with self.summary_writer.as_default():
      for key, value in ret_dict.items():
        if self._num_devices > 1:
          value = value[0]
        tf.summary.scalar(prefix + '/' + key, onp.squeeze(value), step=step)

  def _train_step(self):
    """Runs a single training step.

    Every `update_period` steps this does the following:

    * samples a batch;
    * preprocesses (and optionally augments) states and next states with
      independent RNG keys;
    * runs one update through `train`, or through `distributional_train`;
    * periodically writes summaries.

    Every `target_update_period` steps the target parameters are re-synced
    from the online parameters.
    """
    if self.training_steps % self.update_period == 0:
      self._sample_from_replay_buffer()
      self._rng, rng1, rng2 = jax.random.split(self._rng, num=3)

      # Can also distribute this right at the beginning, especially when
      # data augmentation is used.
      states = self.train_preprocess_fn(
          self.replay_elements['state'], rng=rng1)
      next_states = self.train_preprocess_fn(
          self.replay_elements['next_state'], rng=rng2)

      if self._use_tfds_data:
        # The TFDS buffer already pre-fetches the data onto the JAX devices.
        actions = self.replay_elements['action']
        rewards = self.replay_elements['reward']
        terminals = self.replay_elements['terminal']
        task_ids = self.replay_elements['task_id']
        next_actions = self.replay_elements['next_action']
        loss_weights = self.replay_elements['loss_weight']
        reward_scaling = self.replay_elements['reward_scaling']
      else:
        # NOTE(review): unreachable with the current `_sample_from_replay_buffer`,
        # which raises for non-TFDS data. `loss_weights` and `reward_scaling`
        # are never assigned on this path, so it would fail below if reached.
        (states, actions, next_states, rewards,
         terminals, task_ids, next_actions) = self.put_device_sharded(
             tensors=(
                 states,
                 self.replay_elements['action'],
                 next_states,
                 self.replay_elements['reward'],
                 self.replay_elements['terminal'],
                 self.replay_elements['task_id'],
                 self.replay_elements['next_action']))

      # Positional arguments shared, in this order, by `train` and
      # `distributional_train`; only the trailing arguments differ.
      shared_args = (
          self.network_def,
          self.online_params,
          self.target_network_params,
          self.optimizer,
          self.optimizer_state,
          states,
          actions,
          next_states,
          rewards,
          terminals,
          self.cumulative_gamma,
          self._dr3_coefficient,
          self._cql_coefficient,
          self._sarsa_backups,
          next_actions,
          self._difficulty_loss_coefficient,
          self._adv_loss_version,
          False, False,
          task_ids,
          self._valid_actions_jnp,
          self._use_game_action_space,
          False,
          loss_weights,
          reward_scaling,
          self._use_bc_loss_only,
      )
      if self._use_distributional:
        outputs = distributional_train(
            *shared_args,
            self._support,
            self._sharded_rng,
            self._xent_not_distributional)
      else:
        outputs = train(*shared_args, self._sharded_rng)
      (self.optimizer_state, self.online_params, loss, var_loss, ret_dict,
       self._sharded_rng) = outputs

      if (self.summary_writer is not None and
          self.training_steps > 0 and
          self.training_steps % (20 * self.summary_writing_frequency) == 0):
        # These reductions only feed the summaries, so compute them here
        # rather than on every step. `loss_std` is the spread of the returned
        # loss values (presumably one per device): a pmap sanity check.
        loss_std = jnp.std(loss)
        var_loss = jnp.mean(var_loss)
        is_chief = jax.process_index() == 0

        if is_chief:
          self.append_to_summary_list(ret_dict,
                                      step=self.training_steps,
                                      prefix='Training')

        if (self._log_trainability_metrics and
            self.training_steps % (100 * self.summary_writing_frequency) == 0):
          training_difficulty_dict = compute_trainability_metrics(
              self.network_def,
              self.online_params,
              self.target_network_params,
              states,
              actions,
              next_states,
              rewards,
              terminals,
              self.cumulative_gamma,
              next_actions,
              self._sarsa_backups,
              task_ids)
          if is_chief:
            self.append_to_summary_list(training_difficulty_dict,
                                        step=self.training_steps,
                                        prefix='Training')

        if is_chief:
          self.append_to_summary_list(
              {'pmap_loss_correctness': [loss_std,]},
              step=self.training_steps, prefix='Training'
          )
          self.append_to_summary_list(
              {'variance_of_loss': [var_loss,]},
              step=self.training_steps, prefix='Training'
          )

    if self.training_steps % self.target_update_period == 0:
      self.target_network_params = multi_task_dqn_agent._sync_weights_local(
          self.online_params)

    self.training_steps += 1

  def _select_action_for_current_state(self):
    """Selects an action for `self.state` and stores it in `self.action`.

    Updates `self._rng` with the key returned by `select_action` and returns
    the chosen action as a numpy array.
    """
    state = self.preprocess_fn(self.state)
    # log(game_valid_actions) weighted by `self.game_index` and summed over
    # games; presumably this selects the current game's row when
    # `game_index` is one-hot (TODO confirm).
    game_valid_actions_local = onp.sum(
        onp.array(onp.log(self.game_valid_actions)) *
        self.game_index[Ellipsis, None],
        axis=0)
    self._rng, self.action = select_action(
        self.network_def, self.online_params, state, self._rng,
        self.num_actions, self.eval_mode, self.epsilon_eval, self.epsilon_train,
        self.epsilon_decay_period, self.training_steps, self.min_replay_history,
        self.epsilon_fn, game_index=self.game_index_for_network,
        game_valid_actions=game_valid_actions_local,
        use_single_game_action_space=False)
    self.action = onp.asarray(self.action)
    return self.action

  def step(self, reward, observation, game_index=None):
    """Records `observation` and returns the agent's next action.

    Args:
      reward: float, the reward received from the agent's most recent action.
        Not used by this method.
      observation: numpy array, the most recent observation.
      game_index: forwarded to `_record_observation`.

    Returns:
      The selected action, as a numpy array.
    """
    self._record_observation(observation, game_index=game_index)
    return self._select_action_for_current_state()

  def begin_episode(self, observation, game_index=None):
    """Resets the agent state, records `observation`, returns first action.

    Args:
      observation: numpy array, the first observation of the episode.
      game_index: forwarded to `_record_observation`.

    Returns:
      The selected action, as a numpy array.
    """
    self._reset_state()
    self._record_observation(observation, game_index=game_index)
    return self._select_action_for_current_state()

  def train_step(self):
    """Exposes the train step for offline learning."""
    self._train_step()
