"""Agent learner."""
import functools
import time
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple

import acme
import chex
import distrax
import haiku as hk
import jax
import jax.numpy as jnp
import mctx
import optax
import rlax
import tree
from absl import logging
from acme.jax import networks as networks_lib
from acme.jax import utils
from acme.utils import loggers

from rosmo.agent.muzero.network import Networks
from rosmo.agent.muzero.types import AgentOutput, Params
from rosmo.agent.muzero.utils import (
  inv_value_transform,
  logits_to_scalar,
  scalar_to_two_hot,
  scale_gradient,
  value_transform,
)
from rosmo.types import ActorOutput, Array

# Debugging switch. When True, the model-based loss returns a dict of
# individual loss terms instead of their sum, and the update step
# back-propagates selected terms one at a time (via `jax.vjp`) so that
# per-term gradient/update norms can be logged.
_DEBUG_LOSS = False


class TrainingState(NamedTuple):
  """Everything the learner carries from one update step to the next."""

  # State of the optax optimizer (returned by `optimizer.init` / `update`).
  optimizer_state: optax.OptState
  # Parameters that receive gradient updates.
  params: Params
  # Parameters used to build learning targets; the learner overwrites them
  # with `params` every `target_update_interval` steps.
  target_params: Params


class MuZeroLearner(acme.core.Learner):
  """MuZero learner."""

  def __init__(
    self,
    networks: Networks,
    demonstrations: Iterator[ActorOutput],
    config: Dict,
    logger: Optional[loggers.Logger] = None,
  ) -> None:
    """Builds the loss, the pmapped update step and the initial state.

    One of three per-trajectory losses is selected from `config`:
    behavior cloning (`use_bc`), a learnt-Q policy target (`use_qf`), or the
    default model-based loss (reward + value + policy, with optional
    behavior and pessimism terms).

    Args:
      networks: Bundle providing `representation_network`,
        `transition_network` and `prediction_network`.
      demonstrations: Iterator of trajectory batches; `step` consumes one
        batch per call.
      config: Hyperparameters. Required keys: `discount_factor`,
        `weight_decay`, `value_coef`, `behavior_coef`, `policy_coef`,
        `unroll_steps`, `td_steps`, `clipping_threshold`,
        `target_update_interval`, `log_interval`, `sampling`, `pessimism`,
        `num_pessimism_samples`, `pessimism_weight`, `batch_size`, `use_bc`,
        `use_qf`, `max_grad_norm`, `num_simulations`, `search_depth`,
        `improvement_op`, `safe_q_values`, `dynamics_noise`, `num_bins`,
        `seed`, `learning_rate`, `warmup_steps`, `learning_rate_decay`.
        Optional keys: `behavior` (default `""`) and `value_no_search`
        (default `False`).
      logger: Metrics sink. Defaults to an asynchronous acme logger named
        "learner".

    Raises:
      KeyError: If a required key is missing from `config`.
      AssertionError: If `batch_size` is not divisible by the number of
        local devices.
    """
    discount_factor = config["discount_factor"]
    weight_decay = config["weight_decay"]
    value_coef = config["value_coef"]
    behavior_coef = config["behavior_coef"]
    policy_coef = config["policy_coef"]
    unroll_steps = config["unroll_steps"]
    td_steps = config["td_steps"]
    clipping_threshold = config["clipping_threshold"]
    target_update_interval = config["target_update_interval"]
    log_interval = config["log_interval"]
    sampling_method = config["sampling"]
    pessimism = config["pessimism"]
    num_pessimism_samples = config["num_pessimism_samples"]
    # Not consumed by the current losses; still read so that a config without
    # this key keeps failing early with a KeyError.
    pessimism_weight = config["pessimism_weight"]
    del pessimism_weight
    batch_size = config["batch_size"]
    use_bc = config["use_bc"]
    use_qf = config["use_qf"]
    max_grad_norm = config["max_grad_norm"]
    num_simulations = config["num_simulations"]
    search_depth = config["search_depth"]
    improvement_op = config["improvement_op"]
    safe_q_values = config["safe_q_values"]
    dynamics_noise = config["dynamics_noise"]
    num_bins = config["num_bins"]
    behavior = config.get("behavior", "")
    value_no_search = config.get("value_no_search", False)

    _batch_categorical_cross_entropy = jax.vmap(rlax.categorical_cross_entropy)

    if dynamics_noise:
      logging.info(f"[Learning] Injecting dynamics noise {dynamics_noise}.")
    if use_bc:

      def loss(
        params: Params,
        target_params: Params,
        trajectory: ActorOutput,
        rng_key: networks_lib.PRNGKey,
      ):
        """Behavior cloning: cross-entropy between policy and data actions."""
        del target_params, rng_key
        logging.info("[Learning] Use behavior cloning.")
        state = networks.representation_network.apply(
          params.representation, trajectory.observation
        )  # [T, S]
        learner_root: AgentOutput = root_unroll(
          networks, params, num_bins, state
        )
        num_actions = learner_root.policy_logits.shape[-1]
        logits = jnp.squeeze(learner_root.policy_logits)
        target = jnp.squeeze(trajectory.action)

        policy_loss = jnp.mean(
          _batch_categorical_cross_entropy(
            jax.nn.one_hot(target, num_actions), logits
          )
        )
        total_loss = policy_coef * policy_loss
        log = {
          "policy_entropy":
            -rlax.entropy_loss(logits, jnp.ones(logits.shape[:-1])),
          "policy_loss":
            policy_loss,
          "total_loss":
            total_loss,
        }
        return total_loss, log

    elif use_qf:

      def loss(
        params: Params,
        target_params: Params,
        trajectory: ActorOutput,
        rng_key: networks_lib.PRNGKey,
      ):
        """One-step Q-learning loss plus a Q-improved policy target."""
        logging.info("[Learning] Use learnt Q for policy target.")
        # Encode obs via learning and target networks, [T, S]
        state = networks.representation_network.apply(
          params.representation, trajectory.observation
        )
        target_state = networks.representation_network.apply(
          target_params.representation, trajectory.observation
        )

        learner_out: AgentOutput = root_unroll(
          networks, params, num_bins, state
        )

        def dynamics(
          params_transition: Array, state: Array, action: Array
        ) -> AgentOutput:
          # NOTE To maintain similar capacity and reuse networks,
          # we simply use value prediction of next state to estimate q value.
          next_state = networks.transition_network.apply(
            params_transition, action[None], state
          )
          next_state = jax.tree_util.tree_map(lambda s: s[None], next_state)
          # NOTE(review): the prediction head always uses the online
          # `params.prediction`, also when `params_transition` comes from the
          # target network -- confirm this is intended.
          (
            _,
            _,
            q_val_logits,
          ) = networks.prediction_network.apply(params.prediction, next_state)
          q_val = logits_to_scalar(q_val_logits, num_bins)
          q_val = inv_value_transform(q_val)

          return AgentOutput(
            state=next_state,
            policy_logits=None,
            reward_logits=None,
            reward=None,
            value_logits=jnp.squeeze(q_val_logits),
            value=jnp.squeeze(q_val),
          )

        # Q(s_{t-1}, a_{t-1}) from the online transition network.
        s_tm1 = state[:-1]
        a_tm1 = trajectory.action[:-1]
        out_tm1: AgentOutput = jax.vmap(dynamics, (None, 0, 0))(
          params.transition,
          s_tm1,
          a_tm1,
        )
        # Q(s_t, a_t) from the target transition network (bootstrap).
        s_t = target_state[1:]
        a_t = trajectory.action[1:]
        target_out_t: AgentOutput = jax.vmap(dynamics, (None, 0, 0))(
          target_params.transition,
          s_t,
          a_t,
        )

        q_val = out_tm1.value
        q_val_logits = out_tm1.value_logits

        # Construct q value target.
        reward = trajectory.reward[:-1]

        next_q_val = target_out_t.value
        # True from the first `is_last` step onwards: no bootstrap there.
        zero_return_mask = jnp.cumprod(1.0 - trajectory.is_last) == 0.0
        next_q_val_masked = jax.lax.select(
          zero_return_mask[:-1],
          jnp.zeros_like(next_q_val),
          next_q_val,
        )

        q_target = reward + discount_factor * next_q_val_masked
        q_target_transformed = value_transform(q_target)
        q_target_logits = scalar_to_two_hot(q_target_transformed, num_bins)

        # Construct policy target.
        num_actions = learner_out.policy_logits.shape[-1]

        def _compute_q_policy_target(model_root: AgentOutput):
          tau = 0.1
          all_actions = jnp.arange(num_actions)
          # `model_simulate` requires a noise scale and a PRNG key. A scale
          # of 0.0 keeps the one-step lookahead noise-free, so the key only
          # satisfies the signature and does not affect the result.
          model_one_step_out: AgentOutput = model_simulate(
            networks,
            target_params,
            num_bins,
            model_root.state,
            all_actions,
            0.0,
            rng_key,
          )
          pi_prior = jax.nn.softmax(model_root.policy_logits)
          q_vals = model_one_step_out.value
          chex.assert_equal_shape([pi_prior, q_vals])
          chex.assert_rank(q_vals, 1)
          q_vals = q_vals / jnp.sum(q_vals)
          pi_improved = pi_prior * jnp.exp(q_vals / tau)
          pi_improved = pi_improved / jnp.sum(pi_improved)
          return pi_improved

        target_out: AgentOutput = root_unroll(
          networks, target_params, num_bins, target_state
        )
        policy_target = jax.vmap(_compute_q_policy_target, (0))(target_out)
        # From the first `is_last` step onwards the target is uniform.
        uniform_policy = jnp.ones_like(policy_target) / num_actions
        random_policy_mask = jnp.cumprod(1.0 - trajectory.is_last) == 0.0
        random_policy_mask = jnp.broadcast_to(
          random_policy_mask[:, None], policy_target.shape
        )
        policy_target = jax.lax.select(
          random_policy_mask, uniform_policy, policy_target
        )
        policy_target = jax.lax.stop_gradient(policy_target)

        # Compute losses.
        q_val_loss = jnp.mean(
          _batch_categorical_cross_entropy(q_target_logits, q_val_logits)
        )
        policy_loss = jnp.mean(
          _batch_categorical_cross_entropy(
            policy_target, learner_out.policy_logits
          )
        )
        total_loss = q_val_loss + policy_loss

        policy_entropy = jax.vmap(
          lambda l: distrax.Categorical(logits=l).entropy()
        )(
          learner_out.policy_logits
        )

        log = {
          "q_val_target": q_target,
          "q_val_prediction": q_val,
          "q_val_loss": q_val_loss,
          "policy_entropy": policy_entropy,
          "policy_loss": policy_loss,
          "total_loss": total_loss,
        }

        return total_loss, log

    else:

      def loss(
        params: Params,
        target_params: Params,
        trajectory: ActorOutput,
        rng_key: networks_lib.PRNGKey,
      ):
        """Model-based loss: reward, value and policy (+ optional terms)."""
        logging.info("[Learning] Use Model-based RL.")
        # Encode obs via learning and target networks, [T, S]
        state = networks.representation_network.apply(
          params.representation, trajectory.observation
        )
        target_state = networks.representation_network.apply(
          target_params.representation, trajectory.observation
        )

        # 1) Model unroll, sampling and estimation.
        root_state = jax.tree_util.tree_map(lambda t: t[:1], state)
        learner_root = root_unroll(networks, params, num_bins, root_state)
        learner_root: AgentOutput = jax.tree_util.tree_map(
          lambda t: t[0], learner_root
        )

        unroll_trajectory: ActorOutput = jax.tree_util.tree_map(
          lambda t: t[:unroll_steps + 1], trajectory
        )
        # True from the first `is_first` flag after the root onwards; those
        # steps are unrolled with random actions instead of the data actions.
        random_action_mask = (
          jnp.cumprod(1.0 - unroll_trajectory.is_first[1:]) == 0.0
        )
        action_sequence = unroll_trajectory.action[:unroll_steps]
        num_actions = learner_root.policy_logits.shape[-1]
        rng_key, action_key = jax.random.split(rng_key)
        random_actions = jax.random.choice(
          action_key, num_actions, action_sequence.shape, replace=True
        )
        simulate_action_sequence = jax.lax.select(
          random_action_mask, random_actions, action_sequence
        )

        model_out: AgentOutput = model_unroll(
          networks,
          params,
          num_bins,
          learner_root.state,
          simulate_action_sequence,
        )

        # Model predictions: root prediction followed by the unrolled ones.
        policy_logits = jnp.concatenate(
          [
            learner_root.policy_logits[None],
            model_out.policy_logits,
          ],
          axis=0,
        )

        value_logits = jnp.concatenate(
          [
            learner_root.value_logits[None],
            model_out.value_logits,
          ],
          axis=0,
        )

        # 2) Model learning targets.
        # a) Reward.
        rewards = trajectory.reward
        reward_target = jax.lax.select(
          random_action_mask,
          jnp.zeros_like(rewards[:unroll_steps]),
          rewards[:unroll_steps],
        )
        reward_target_transformed = value_transform(reward_target)
        reward_logits_target = scalar_to_two_hot(
          reward_target_transformed, num_bins
        )

        # b) Policy.
        target_roots: AgentOutput = root_unroll(
          networks, target_params, num_bins, target_state
        )
        search_roots: AgentOutput = jax.tree_util.tree_map(
          lambda t: t[:unroll_steps + 1], target_roots
        )
        rng_key, improve_key = jax.random.split(rng_key)

        improve_adv = 0.0
        if improvement_op in ["mcts", "mcts_mpo"]:
          logging.info("[Improvement] Use Monte-Carlo Tree Search.")
          # Search is run on every root (not only the unrolled prefix)
          # because the value target below indexes roots up to
          # `unroll_steps + td_steps`.
          mcts_out = mcts_improve(
            networks,
            improve_key,
            target_params,
            num_bins,
            target_roots,
            discount_factor,
            num_simulations,
            search_depth,
            dynamics_noise,
          )
          if improvement_op == "mcts":
            policy_target = mcts_out.action_weights[:unroll_steps + 1]
          else:
            logging.info(
              "[Improvement] Compute mpo-style policy target using Qs."
            )
            node_index = mctx.Tree.ROOT_INDEX
            mcts_mpo_tau = 0.1

            def _compute_q_policy_target(
              mctx_policy_out: mctx.PolicyOutput, node_index: chex.Numeric
            ) -> Array:
              normalized_qs = qtransform_by_parent_and_siblings(
                mctx_policy_out.search_tree, node_index, safe_q_values
              )
              prior_logits = (
                mctx_policy_out.search_tree.children_prior_logits[node_index]
              )
              chex.assert_equal_shape([normalized_qs, prior_logits])
              pi_prior = jax.nn.softmax(prior_logits)
              pi_improved = pi_prior * jnp.exp(normalized_qs / mcts_mpo_tau)
              pi_improved = pi_improved / jnp.sum(pi_improved)
              return pi_improved

            policy_target = jax.vmap(_compute_q_policy_target, (0, None))(
              jax.tree_util.tree_map(
                lambda t: t[:unroll_steps + 1], mcts_out
              ),
              node_index,
            )
        else:
          logging.info("[Improvement] Use One-step lookahead.")
          improve_keys = jax.random.split(
            improve_key, search_roots.state.shape[0]
          )
          policy_target, improve_adv = jax.vmap(
            one_step_improve,
            (None, 0, None, 0, None, 0, None, None, None, None, None),
          )(
            networks,
            improve_keys,
            target_params,
            search_roots,
            num_bins,
            unroll_trajectory.action,
            discount_factor,
            clipping_threshold,
            num_simulations,
            sampling_method,
            dynamics_noise,
          )
        # From the first `is_last` step onwards the target is uniform.
        uniform_policy = jnp.ones_like(policy_target) / num_actions
        random_policy_mask = jnp.cumprod(
          1.0 - unroll_trajectory.is_last
        ) == 0.0
        random_policy_mask = jnp.broadcast_to(
          random_policy_mask[:, None], policy_target.shape
        )
        policy_target = jax.lax.select(
          random_policy_mask, uniform_policy, policy_target
        )
        policy_target = jax.lax.stop_gradient(policy_target)

        # c) Value.
        discounts = (1.0 - trajectory.is_last[1:]) * discount_factor

        if not value_no_search and improvement_op in ["mcts", "mcts_mpo"]:
          node_values = mcts_out.search_tree.node_values
          v_bootstrap = node_values[:, mctx.Tree.ROOT_INDEX]
        else:
          if improvement_op in ["mcts", "mcts_mpo"]:
            logging.info("[Learning] Value target is not from search.")
          v_bootstrap = target_roots.value

        def n_step_return(i):
          """Discounted `td_steps` rewards from step i plus bootstrap."""
          bootstrap_value = jax.tree_util.tree_map(
            lambda t: t[i + td_steps], v_bootstrap
          )
          _rewards = jnp.concatenate(
            [rewards[i:i + td_steps], bootstrap_value[None]], axis=0
          )
          _discounts = jnp.concatenate(
            [jnp.ones((1,)),
             jnp.cumprod(discounts[i:i + td_steps])],
            axis=0,
          )
          return jnp.sum(_rewards * _discounts)

        returns = jnp.stack(
          [n_step_return(i) for i in range(unroll_steps + 1)]
        )
        # Value targets for the absorbing state and the states after are 0.
        zero_return_mask = jnp.cumprod(1.0 - unroll_trajectory.is_last) == 0.0
        value_target = jax.lax.select(
          zero_return_mask, jnp.zeros_like(returns), returns
        )
        value_target_transformed = value_transform(value_target)
        value_logits_target = scalar_to_two_hot(
          value_target_transformed, num_bins
        )
        value_logits_target = jax.lax.stop_gradient(value_logits_target)

        # 2.5a) Behavior loss
        behavior_loss = 0.0
        if behavior:
          logging.info("[Learning] Use behavior loss.")
          in_sample_action = trajectory.action[:unroll_steps + 1]
          log_prob = jax.nn.log_softmax(policy_logits)
          action_log_prob = log_prob[jnp.arange(unroll_steps + 1),
                                     in_sample_action]

          # One-step advantage estimated from target-network predictions.
          _target_value = target_roots.value[:unroll_steps + 1]
          _target_reward = target_roots.reward[1:unroll_steps + 1 + 1]
          _target_value_prime = target_roots.value[1:unroll_steps + 1 + 1]
          _target_adv = (
            _target_reward + discount_factor * _target_value_prime -
            _target_value
          )
          if "exp" in behavior:
            logging.info("[Behavior] Use exponential.")
            behavior_loss = -action_log_prob * jnp.minimum(
              jnp.exp(_target_adv), 5.0
            )
          elif "bin" in behavior:
            logging.info("[Behavior] Use binary.")
            behavior_loss = -action_log_prob * jnp.heaviside(_target_adv, 0.)
          else:
            raise ValueError(f"{behavior} not supported")

          behavior_loss = jnp.mean(behavior_loss) * behavior_coef

        # 2.5b) (experimentally) Conservative dynamics
        pessimism_loss = 0.0
        if pessimism:
          logging.info("[Learning] Use pessimism loss.")

          def _pessimism_loss(target, estimate, expectile=0.1):
            diff = target - estimate

            diff_oh = scalar_to_two_hot(diff, num_bins)
            _u = (jnp.arange(num_bins) - (num_bins // 2)) / (num_bins // 2)
            _weight = jnp.where(_u > 0, expectile, (1 - expectile)) * (diff**2)
            _weight = _weight / jnp.sum(_weight, axis=-1, keepdims=True)
            _weight = jnp.expand_dims(_weight, 1)
            _target = scalar_to_two_hot(jnp.zeros_like(diff), num_bins)

            _loss = _target * jax.nn.log_softmax(diff_oh)
            _weighted_ce = jnp.sum(_loss * _weight, -1)
            return _weighted_ce

          rng_key, sample_key, noise_key = jax.random.split(rng_key, 3)
          noise_keys = jax.random.split(noise_key, len(policy_logits))

          sample_logits = jax.lax.stop_gradient(policy_logits)
          sample_act = distrax.Categorical(logits=sample_logits).sample(
            seed=sample_key, sample_shape=num_pessimism_samples
          )  # (num_pessimism_samples, unroll_length+1)
          sample_act = jax.lax.transpose(sample_act, (1, 0))
          sample_out: AgentOutput = jax.vmap(
            model_simulate, (None, None, None, 0, 0, None, 0)
          )(
            networks,
            params,
            num_bins,
            jnp.concatenate([learner_root.state[None], model_out.state]),
            sample_act,
            0.,
            noise_keys,
          )
          # Non-zero where the sampled action differs from the data action.
          ood_action_mask = 1 - (
            sample_act == trajectory.action[:unroll_steps + 1, None]
          )
          in_sample_out_tp1: AgentOutput = jax.tree_util.tree_map(
            lambda t: t[1:unroll_steps + 2], target_roots
          )
          in_sample_reward = jax.lax.stop_gradient(
            in_sample_out_tp1.reward[:, None]
          )
          in_sample_value = jax.lax.stop_gradient(
            in_sample_out_tp1.value[:, None]
          )

          _loss_r = _pessimism_loss(in_sample_reward, sample_out.reward)
          _loss_v = _pessimism_loss(in_sample_value, sample_out.value)
          _loss = _loss_r + _loss_v * value_coef

          pessimism_loss = jnp.mean(
            jax.lax.select(
              ood_action_mask,
              _loss,
              jnp.zeros_like(_loss),
            )
          )
          # NOTE: an earlier, commented-out variant (fixed
          # `pessimism_weight` target for out-of-sample actions) was removed
          # from here; see version control if it is needed again.

        # 3) Compute the losses.
        reward_loss = jnp.mean(
          _batch_categorical_cross_entropy(
            reward_logits_target, model_out.reward_logits
          )
        )

        value_loss = jnp.mean(
          _batch_categorical_cross_entropy(value_logits_target, value_logits)
        ) * value_coef

        policy_loss = jnp.mean(
          _batch_categorical_cross_entropy(policy_target, policy_logits)
        ) * policy_coef

        if "only" in behavior:
          logging.info("[Learning] Disable policy distillation loss.")
          policy_loss = 0.0
          behavior_loss = behavior_loss / behavior_coef

        total_loss = (
          reward_loss + value_loss + policy_loss + pessimism_loss +
          behavior_loss
        )

        if num_simulations > 0 and sampling_method not in ["naive", "exact"]:
          # Unnormalized.
          entropy_fn = lambda p: distrax.Categorical(logits=p).entropy()
        else:
          entropy_fn = lambda p: distrax.Categorical(probs=p).entropy()
        policy_target_entropy = jax.vmap(entropy_fn)(policy_target)
        policy_entropy = jax.vmap(
          lambda l: distrax.Categorical(logits=l).entropy()
        )(
          policy_logits
        )

        log = {
          "reward_target": reward_target,
          "reward_prediction": model_out.reward,
          "value_target": value_target,
          "value_prediction": model_out.value,
          "policy_entropy": policy_entropy,
          "policy_target_entropy": policy_target_entropy,
          "reward_loss": reward_loss,
          "value_loss": value_loss,
          "policy_loss": policy_loss,
          "pessimism_loss": pessimism_loss,
          "behavior_loss": behavior_loss,
          "improve_advantage": improve_adv,
          "total_loss": total_loss,
        }

        # `value_loss` and `policy_loss` already include their coefficients,
        # so they are reported as-is (scaling again would square the
        # coefficient in the debug gradients).
        losses = {
          "reward_loss": reward_loss,
          "value_loss": value_loss,
          "policy_loss": policy_loss,
          "pessimism_loss": pessimism_loss,
          "behavior_loss": behavior_loss,
        }
        if _DEBUG_LOSS:
          return losses, log
        else:
          return total_loss, log

    def batch_loss(
      params: Params,
      target_params: Params,
      trajectory: ActorOutput,
      rng_key: networks_lib.PRNGKey,
    ):
      """Vmaps `loss` over the batch and averages losses and metrics."""
      bs = len(trajectory.reward)
      rng_keys = jax.random.split(rng_key, bs)
      losses, log = jax.vmap(loss, (None, None, 0, 0))(
        params,
        target_params,
        trajectory,
        rng_keys,
      )
      log_mean = {f"{k}_mean": jnp.mean(v) for k, v in log.items()}
      # Standard deviations are only reported for these keys, when present.
      std_keys = [
        "reward_target",
        "reward_prediction",
        "q_val_target",
        "q_val_prediction",
        "value_target",
        "value_prediction",
        "improve_advantage",
      ]
      log_std = {f"{k}_std": jnp.std(log[k]) for k in std_keys if k in log}
      log_mean.update(log_std)
      # `losses` is a scalar per example, or a dict of them in debug mode.
      return tree.map_structure(jnp.mean, losses), log_mean

    def update_step(
      state: TrainingState,
      trajectory: ActorOutput,
      rng_key: networks_lib.PRNGKey,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      """Applies one optimizer update; runs under `pmap` with axis "i"."""
      params = state.params
      optimizer_state = state.optimizer_state
      if _DEBUG_LOSS:
        primals, f_vjp, log = jax.vjp(
          functools.partial(
            batch_loss,
            target_params=state.target_params,
            trajectory=trajectory,
            rng_key=rng_key,
          ),  # make argnums=0
          state.params,
          has_aux=True,
        )
        for loss_name in ("reward_loss", "value_loss", "policy_loss"):
          # The cotangent must mirror the structure of the loss dict, so it
          # is built from `primals`: one for the selected term, zero for
          # every other term.
          cotangent = {
            k: jnp.ones_like(v) if k == loss_name else jnp.zeros_like(v)
            for k, v in primals.items()
          }
          grads = f_vjp(cotangent)[0]
          grads = jax.lax.pmean(grads, axis_name="i")
          network_updates, optimizer_state = optimizer.update(
            grads, optimizer_state, params
          )
          params = optax.apply_updates(params, network_updates)

          log.update(
            {
              f"grad_norm_{loss_name}": optax.global_norm(grads),
              f"update_norm_{loss_name}": optax.global_norm(network_updates),
            }
          )
        log.update({
          "param_norm": optax.global_norm(params),
        })
      else:
        grads, log = jax.grad(
          batch_loss, has_aux=True
        )(state.params, state.target_params, trajectory, rng_key)
        # Average gradients across devices before the optimizer step.
        grads = jax.lax.pmean(grads, axis_name="i")
        network_updates, optimizer_state = optimizer.update(
          grads, optimizer_state, params
        )
        params = optax.apply_updates(params, network_updates)
        log.update(
          {
            "grad_norm": optax.global_norm(grads),
            "update_norm": optax.global_norm(network_updates),
            "param_norm": optax.global_norm(params),
          }
        )
      new_state = TrainingState(
        optimizer_state=optimizer_state,
        params=params,
        target_params=state.target_params,
      )
      return new_state, log

    # Logger.
    self._logger = logger or loggers.make_default_logger(
      "learner", asynchronous=True, serialize_fn=utils.fetch_devicearray
    )

    # Iterator on demonstration transitions.
    self._demonstrations = demonstrations

    # JIT compiler.
    self._batch_size = batch_size
    # `pmap` and `device_put_replicated` below operate on this process's
    # local devices, so the batch is split over the local device count.
    self._num_devices = jax.local_device_count()
    assert self._batch_size % self._num_devices == 0
    self._update_step = jax.pmap(update_step, axis_name="i")

    # Create initial state.
    random_key = jax.random.PRNGKey(config["seed"])
    key_r, key_d, key_p, self._rng_key = jax.random.split(random_key, 4)
    representation_params = networks.representation_network.init(key_r)
    transition_params = networks.transition_network.init(key_d)
    prediction_params = networks.prediction_network.init(key_p)

    # Create and initialize optimizer.
    params = Params(
      representation_params,
      transition_params,
      prediction_params,
    )

    def _is_weight(module_name, name, value) -> bool:
      """Selects parameters named "w" for weight decay."""
      del module_name, value
      return name == "w"

    weight_decay_mask = Params(
      representation=hk.data_structures.map(_is_weight, params.representation),
      transition=hk.data_structures.map(_is_weight, params.transition),
      prediction=hk.data_structures.map(_is_weight, params.prediction),
    )
    learning_rate = optax.warmup_exponential_decay_schedule(
      init_value=0.0,
      peak_value=config["learning_rate"],
      warmup_steps=config["warmup_steps"],
      transition_steps=100_000,
      decay_rate=config["learning_rate_decay"],
      staircase=True,
    )
    optimizer = optax.adamw(
      learning_rate=learning_rate,
      weight_decay=weight_decay,
      mask=weight_decay_mask,
    )
    if max_grad_norm:
      optimizer = optax.chain(
        optax.clip_by_global_norm(max_grad_norm), optimizer
      )
    optimizer_state = optimizer.init(params)
    target_params = params

    # Learner state.
    self._state = TrainingState(
      optimizer_state=optimizer_state,
      params=params,
      target_params=target_params,
    )
    self._target_update_interval = target_update_interval

    self._state = jax.device_put_replicated(self._state, jax.local_devices())

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online
    # and fill the replay buffer.
    self._timestamp = None
    self._elapsed = 0
    self._step = 0
    self._log_interval = log_interval
    self._unroll_steps = unroll_steps

  def step(self):
    """Runs one update; periodically refreshes targets and writes metrics."""
    self._step += 1

    # Fresh randomness: one key per device for the pmapped update.
    update_key, self._rng_key = jax.random.split(self._rng_key)
    device_keys = jax.random.split(update_key, self._num_devices)

    # Split the leading batch axis into [num_devices, per_device_batch].
    batch: ActorOutput = next(self._demonstrations)

    def _shard(x):
      return x.reshape(
        self._num_devices, self._batch_size // self._num_devices, *x.shape[1:]
      )

    batch = tree.map_structure(_shard, batch)

    self._state, metrics = self._update_step(self._state, batch, device_keys)

    # Wall-clock bookkeeping; the very first step contributes zero time.
    now = time.time()
    if self._timestamp:
      elapsed = now - self._timestamp
    else:
      elapsed = 0
    self._timestamp = now
    self._elapsed += elapsed

    # Periodically copy the online parameters into the target parameters.
    if self._step % self._target_update_interval == 0:
      current: TrainingState = self._state
      self._state = current._replace(target_params=current.params)

    if self._step % self._log_interval != 0:
      return

    # Metrics are replicated across devices; report the first replica.
    metrics = jax.device_get(
      jax.tree_util.tree_map(lambda t: t[0], metrics)
    )
    learner_fps = (
      1 / (elapsed + 1e-6) * self._batch_size * (self._unroll_steps + 1)
    )
    self._logger.write(
      {
        **metrics,
        "step": self._step,
        "walltime": elapsed,
        "elapsed_time": self._elapsed,
        "learner_fps": learner_fps,
      }
    )

  def get_variables(self, names: List[str]) -> List[Any]:
    """Returns the parameter collections requested in `names`, in order.

    Valid names are "representation", "dynamics" (the transition network
    parameters) and "prediction"; any other name raises `KeyError`.
    """
    params = self.save().params
    by_name = {
      "representation": params.representation,
      "dynamics": params.transition,
      "prediction": params.prediction,
    }
    selected = []
    for name in names:
      selected.append(by_name[name])
    return selected

  def save(self) -> TrainingState:
    """Returns the training state of the first replica, fetched from device.

    The learner keeps its state replicated along a leading device axis; index
    0 of every leaf is taken and then passed through
    `utils.fetch_devicearray`.
    """
    # `jax.tree_util.tree_map` is the stable API (the bare `jax.tree_map`
    # alias is deprecated) and matches what `step` uses.
    first_replica = jax.tree_util.tree_map(lambda t: t[0], self._state)
    return utils.fetch_devicearray(first_replica)

  def restore(self, state: TrainingState, steps: int):
    """Replaces the learner state and step counter with the given values.

    Args:
      state: Un-replicated training state; it is copied to every local
        device.
      steps: Value to set the step counter to.
    """
    local_devices = jax.local_devices()
    replicated_state = jax.device_put_replicated(state, local_devices)
    self._state = replicated_state
    self._step = steps


def root_unroll(
  networks: Networks,
  params: Params,
  num_bins: int,
  state: Array,
) -> AgentOutput:
  """Applies the prediction network to `state` and decodes its outputs.

  Reward and value logits are converted to scalars with `logits_to_scalar`
  and mapped back through `inv_value_transform`. The input `state` is
  returned unchanged in the output.
  """
  prediction = networks.prediction_network.apply(params.prediction, state)
  policy_logits, reward_logits, value_logits = prediction

  def _decode(logits: Array) -> Array:
    # Categorical logits -> transformed scalar -> original scale.
    return inv_value_transform(logits_to_scalar(logits, num_bins))

  decoded_reward = _decode(reward_logits)
  decoded_value = _decode(value_logits)
  return AgentOutput(
    state=state,
    policy_logits=policy_logits,
    reward_logits=reward_logits,
    reward=decoded_reward,
    value_logits=value_logits,
    value=decoded_value,
  )


def model_unroll(
  networks: Networks,
  params: Params,
  num_bins: int,
  state: Array,
  action_sequence: Array,
) -> AgentOutput:
  """Unroll the transition network along `action_sequence` and predict.

  The input `state` and `action_sequence` are assumed to be [S] and [T].

  Args:
    networks: provides `transition_network` and `prediction_network`.
    params: provides the `transition` and `prediction` parameters.
    num_bins: passed to `logits_to_scalar` when decoding the logits.
    state: initial carry of the scan.
    action_sequence: actions scanned over, one transition per action.

  Returns:
    An `AgentOutput` for the scanned sequence of states (the initial `state`
    itself is not part of the sequence).
  """

  def step(carry: Array, action: Array):
    successor = networks.transition_network.apply(
      params.transition, action[None], carry
    )
    # The state is passed through `scale_gradient` with factor 0.5 before
    # being carried to the next step.
    successor = scale_gradient(successor, 0.5)
    return successor, successor

  _, state_sequence = jax.lax.scan(step, state, action_sequence)

  prediction = networks.prediction_network.apply(
    params.prediction, state_sequence
  )
  policy_logits, reward_logits, value_logits = prediction
  reward = inv_value_transform(logits_to_scalar(reward_logits, num_bins))
  value = inv_value_transform(logits_to_scalar(value_logits, num_bins))

  return AgentOutput(
    state=state_sequence,
    policy_logits=policy_logits,
    reward_logits=reward_logits,
    reward=reward,
    value_logits=value_logits,
    value=value,
  )


def model_simulate(
  networks: Networks,
  params: Params,
  num_bins: int,
  state: Array,
  actions_to_simulate: Array,
  dynamics_noise: float,
  key_noise: networks_lib.PRNGKey,
) -> AgentOutput:
  """Imagine one transition from `state` for each action and predict.

  The input `state` and `actions_to_simulate` are assumed to be [S] and [T].

  Args:
    networks: provides `transition_network` and `prediction_network`.
    params: provides the `transition` and `prediction` parameters.
    num_bins: passed to `logits_to_scalar` when decoding the logits.
    state: the single state every action is applied to.
    actions_to_simulate: actions to try; one imagined next state per entry.
    dynamics_noise: multiplier on the next state that sets the standard
      deviation of the Gaussian noise added to it.
    key_noise: PRNG key, split into one key per action.

  Returns:
    An `AgentOutput` for the imagined next states.
  """

  def simulate_one(root: Array, action: Array, key: networks_lib.PRNGKey):
    predicted = networks.transition_network.apply(
      params.transition, action[None], root
    )
    # Multiplicative noise: the standard deviation scales with the state.
    scale = dynamics_noise * predicted
    return predicted + jax.random.normal(key, predicted.shape) * scale

  num_actions = actions_to_simulate.shape[0]
  per_action_keys = jax.random.split(key_noise, num_actions)
  # Same root state for every action; one action and one key per lane.
  states_imagined = jax.vmap(simulate_one, in_axes=(None, 0, 0))(
    state, actions_to_simulate, per_action_keys
  )

  prediction = networks.prediction_network.apply(
    params.prediction, states_imagined
  )
  policy_logits, reward_logits, value_logits = prediction
  reward = inv_value_transform(logits_to_scalar(reward_logits, num_bins))
  value = inv_value_transform(logits_to_scalar(value_logits, num_bins))

  return AgentOutput(
    state=states_imagined,
    policy_logits=policy_logits,
    reward_logits=reward_logits,
    reward=reward,
    value_logits=value_logits,
    value=value,
  )


# Values `one_step_improve` accepts for `sampling_method` when it samples;
# "exact" means no sampling (every action is evaluated).
SAMPLING_METHOD = ["naive", "sarsa", "muesli", "exact"]


def one_step_improve(
  networks: Networks,
  rng_key: networks_lib.PRNGKey,
  params: Params,
  model_root: AgentOutput,
  num_bins: int,
  sarsa_action: Array,
  discount_factor: float,
  clipping_threshold: float,
  num_simulations: int = -1,
  sampling_method: str = "",
  dynamics_noise: float = 0.0,
) -> Tuple[Array, Array]:
  """Obtain the one-step lookahead target policy.

  The prior policy is `softmax(model_root.policy_logits)`. One-step
  advantages `reward + discount_factor * value - model_root.value` are
  computed with `model_simulate`, either for every action (exact) or for a
  subset chosen by `sampling_method`.

  Args:
    networks: provides `environment_specs` and the networks used by
      `model_simulate`.
    rng_key: split into a key for action sampling and a key for dynamics
      noise.
    params: network parameters forwarded to `model_simulate`.
    model_root: root output; its `policy_logits`, `value` and `state` are
      read.
    num_bins: forwarded to `model_simulate`.
    sarsa_action: the in-sample action; only used by the "sarsa" method.
    discount_factor: discount applied to the one-step value.
    clipping_threshold: not referenced by this function.
    num_simulations: number of sampled actions. A non-positive value selects
      the exact branch regardless of `sampling_method`.
    sampling_method: one of `SAMPLING_METHOD`; "exact" evaluates all actions.
    dynamics_noise: forwarded to `model_simulate`.

  Returns:
    `(pi_improved, adv)`. `pi_improved` has the shape of the prior policy.
    `adv` is returned for logging: per sampled action for "naive" and
    "muesli", a per-action vector with only the in-sample entry filled for
    "sarsa", and per action for the exact branch.

  Raises:
    ValueError: if sampling is requested with an unknown method (only
      reachable when the assertion below is disabled).
  """
  key_sample, key_noise = jax.random.split(rng_key)
  environment_specs = networks.environment_specs

  pi_prior = jax.nn.softmax(model_root.policy_logits)
  value_prior = model_root.value

  if num_simulations > 0 and sampling_method != "exact":
    assert sampling_method in SAMPLING_METHOD
    logging.info(
      f"[Sample] Using {num_simulations} samples to estimate improvement."
    )
    if sampling_method == "naive":
      # 1. Naively sample from prior policy and set to zero elsewhere.
      logging.info("[Sample] Naive.")
      pi_sample = distrax.Categorical(probs=pi_prior)
      sample_acts = pi_sample.sample(
        seed=key_sample, sample_shape=num_simulations
      )
      sample_one_step_out: AgentOutput = model_simulate(
        networks, params, num_bins, model_root.state, sample_acts,
        dynamics_noise, key_noise
      )
      sample_adv = (
        sample_one_step_out.reward +
        discount_factor * sample_one_step_out.value - value_prior
      )
      adv = sample_adv  # for log
      coeff = jnp.zeros_like(pi_prior)
      sample_exp_adv = jnp.exp(sample_adv)

      def body(i, val):
        # Write the sampled action's exp-advantage into its slot. Deriving
        # the update from `val` and then adding it back to `val` would double
        # every previously written entry on each iteration.
        return val.at[sample_acts[i]].set(sample_exp_adv[i])

      # A sequential loop keeps the outcome deterministic when the same
      # action is sampled more than once (the last sample wins).
      exp_adv = jax.lax.fori_loop(0, num_simulations, body, coeff)
      pi_improved = pi_prior * exp_adv
      pi_improved = pi_improved / jnp.sum(pi_improved)
    elif sampling_method == "sarsa":
      # 2. Use insample action to construct crr-like gradients.
      logging.info("[Sample] SARSA CRR-like.")
      sarsa_action = sarsa_action[None]
      sarsa_one_step_out: AgentOutput = model_simulate(
        networks, params, num_bins, model_root.state, sarsa_action,
        dynamics_noise, key_noise
      )
      sarsa_adv = (
        sarsa_one_step_out.reward +
        discount_factor * sarsa_one_step_out.value - value_prior
      )
      # In-sample single-point estimate.
      exp_adv = jnp.minimum(jnp.exp(sarsa_adv / 0.5), 20.0)  # Same as crr.
      zeros = jnp.zeros_like(pi_prior)
      adv = zeros.at[sarsa_action].set(sarsa_adv)  # for log
      # Only the in-sample action gets a non-zero weight.
      pi_improved = zeros.at[sarsa_action].set(exp_adv)
    elif sampling_method == "muesli":
      logging.info("[Sample] Muesli-like.")
      pi_sample = distrax.Categorical(probs=pi_prior)
      sample_acts = pi_sample.sample(
        seed=key_sample, sample_shape=num_simulations
      )
      sample_one_step_out: AgentOutput = model_simulate(
        networks, params, num_bins, model_root.state, sample_acts,
        dynamics_noise, key_noise
      )
      sample_adv = (
        sample_one_step_out.reward +
        discount_factor * sample_one_step_out.value - value_prior
      )
      adv = sample_adv  # for log
      sample_exp_adv = jnp.exp(sample_adv)
      normalizer_raw = (jnp.sum(sample_exp_adv) + 1) / num_simulations
      coeff = jnp.zeros_like(pi_prior)

      def body(i, val):
        # Leave-one-out normalizer: sample i is excluded from its own
        # normalizing sum.
        normalizer_i = normalizer_raw - sample_exp_adv[i] / num_simulations
        delta = jnp.zeros_like(val)
        delta = delta.at[sample_acts[i]].set(sample_exp_adv[i] / normalizer_i)
        return val + delta

      coeff = jax.lax.fori_loop(0, num_simulations, body, coeff)
      pi_improved = coeff / num_simulations
    else:
      raise ValueError("Unknown sampling method.")
  else:
    # Exact: evaluate every action and reweight the prior by exp(advantage).
    all_actions = jnp.arange(environment_specs.actions.num_values)
    model_one_step_out: AgentOutput = model_simulate(
      networks, params, num_bins, model_root.state, all_actions,
      dynamics_noise, key_noise
    )
    chex.assert_equal_shape([model_one_step_out.reward, pi_prior])
    chex.assert_equal_shape([model_one_step_out.value, pi_prior])
    adv = (
      model_one_step_out.reward + discount_factor * model_one_step_out.value -
      value_prior
    )
    pi_improved = pi_prior * jnp.exp(adv)
    pi_improved = pi_improved / jnp.sum(pi_improved)

  chex.assert_equal_shape([pi_improved, pi_prior])
  # pi_improved here might not sum to 1, in which case we use CE
  # to conveniently calculate the policy gradients.
  return pi_improved, adv


def mcts_improve(
  networks: Networks,
  rng_key: networks_lib.PRNGKey,
  params: Params,
  num_bins: int,
  model_root: AgentOutput,
  discount_factor: float,
  num_simulations: int,
  search_depth: int,
  dynamics_noise: float = 0.0,
) -> mctx.PolicyOutput:
  """Obtain the Monte-Carlo Tree Search target policy.

  Args:
    networks: provides `transition_network` and `prediction_network`.
    rng_key: PRNG key handed to `mctx.muzero_policy`.
    params: network parameters, passed through the search to `recurrent_fn`.
    num_bins: passed to `logits_to_scalar` when decoding the logits.
    model_root: root output; its `policy_logits`, `value` and `state` become
      the search root's prior logits, value and embedding.
    discount_factor: constant discount reported for every simulated step.
    num_simulations: number of simulations for the search.
    search_depth: forwarded as `max_depth`.
    dynamics_noise: multiplier on the next state that sets the standard
      deviation of the Gaussian noise added to it.

  Returns:
    The `mctx.PolicyOutput` produced by `mctx.muzero_policy`.
  """

  # Batch size of [T].

  def recurrent_fn(
    params: Params, rng_key: networks_lib.PRNGKey, action: Array, state: Array
  ) -> Tuple[mctx.RecurrentFnOutput, Array]:

    def fn(state: Array, action: Array, noise_key: networks_lib.PRNGKey):
      next_state = networks.transition_network.apply(
        params.transition, action[None], state
      )
      sigmas = dynamics_noise * next_state
      next_state += jax.random.normal(noise_key, next_state.shape) * sigmas
      return next_state

    # One key per batch element, as `model_simulate` does. Closing over the
    # single `rng_key` inside the vmapped function would give every element
    # the same noise draw.
    noise_keys = jax.random.split(rng_key, action.shape[0])
    next_state = jax.vmap(fn, (0, 0, 0))(state, action, noise_keys)

    (
      policy_logits,
      reward_logits,
      value_logits,
    ) = networks.prediction_network.apply(params.prediction, next_state)
    reward = logits_to_scalar(reward_logits, num_bins)
    reward = inv_value_transform(reward)
    value = logits_to_scalar(value_logits, num_bins)
    value = inv_value_transform(value)
    recurrent_fn_output = mctx.RecurrentFnOutput(
      reward=reward,
      discount=jnp.full_like(value, fill_value=discount_factor),
      prior_logits=policy_logits,
      value=value,
    )
    return recurrent_fn_output, next_state

  root = mctx.RootFnOutput(
    prior_logits=model_root.policy_logits,
    value=model_root.value,
    embedding=model_root.state,
  )

  return mctx.muzero_policy(
    params,
    rng_key,
    root,
    recurrent_fn,
    num_simulations,
    max_depth=search_depth,
  )


def qtransform_by_parent_and_siblings(
  tree: mctx.Tree,
  node_index: chex.Numeric,
  safe_q_values: bool = False,
  epsilon: chex.Numeric = 1e-8,
) -> chex.Array:
  """Returns qvalues normalized by min, max over a reference value and qvalues.

  Args:
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the parent node.
    safe_q_values: if True, the reference value is
      `tree.node_values[node_index]`; if False, it is 0. The reference value
      stands in for unvisited actions when computing the min/max and also
      takes part in the min/max itself.
    epsilon: the minimum denominator for the normalization.

  Returns:
    Q-values normalized to be from the [0, 1] interval. The unvisited actions
    will have zero Q-value. Shape `[num_actions]`.
  """
  chex.assert_shape(node_index, ())
  child_q = tree.qvalues(node_index)
  child_visits = tree.children_visits[node_index]
  chex.assert_rank([child_q, child_visits, node_index], [1, 1, 0])

  reference = tree.node_values[node_index] if safe_q_values else 0
  visited = child_visits > 0

  # Bounds are taken over the visited Q-values and the reference value.
  filled_q = jnp.where(visited, child_q, reference)
  chex.assert_equal_shape([filled_q, child_q])
  lower = jnp.minimum(reference, jnp.min(filled_q, axis=-1))
  upper = jnp.maximum(reference, jnp.max(filled_q, axis=-1))

  # Unvisited actions are completed with the lower bound, so they map to 0.
  completed = jnp.where(visited, child_q, lower)
  span = jnp.maximum(upper - lower, epsilon)
  normalized = (completed - lower) / span
  chex.assert_equal_shape([normalized, child_q])
  return normalized
