"""CRR learner implementation."""

import time
from typing import Callable, Dict, Iterator, List, NamedTuple, Optional, Tuple

import acme
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
import tree
from acme import types
from acme.jax import networks as networks_lib
from acme.jax import utils
from acme.types import Transition
from acme.utils import loggers

from rosmo.agent.crr_discrete.network import CRRNetworks
from rosmo.agent.muzero.utils import (
  inv_value_transform,
  logits_to_scalar,
  scalar_to_two_hot,
  value_transform,
)

_PMAP_AXIS_NAME = "data"


class Params(NamedTuple):
  """Learnable parameters of the CRR networks, one tree per sub-network.

  Grouped into a single structure so that one optimizer state and one target
  copy cover all three networks.
  """

  # Shared observation encoder; its output feeds both the policy and critic.
  torso_params: networks_lib.Params
  # Policy head parameters.
  policy_params: networks_lib.Params
  # Critic (Q-value) head parameters.
  critic_params: networks_lib.Params


class TrainingState(NamedTuple):
  """Contains training state for the learner.

  Threaded through every SGD step; the learner keeps one replica of it per
  local device.
  """

  # Online parameters updated by the optimizer.
  params: Params
  # Periodically refreshed copy of `params` used to compute bootstrap targets.
  target_params: Params
  # Optimizer state for `params`.
  opt_state: optax.OptState
  # Number of SGD updates applied so far.
  steps: int


class CRRLearner(acme.Learner):
  """Critic Regularized Regression (CRR) learner.

  This is the learning component of a CRR agent as described in
  https://arxiv.org/abs/2006.15134.

  A shared torso, a policy head and a critic head are optimized jointly by a
  single optimizer. The policy maximizes the advantage-weighted log-likelihood
  of the dataset actions; the critic is trained by TD regression against a
  periodically copied target network and is distributional when
  ``num_bins > 1``. Updates are ``jax.pmap``-ed over all local devices with
  gradients averaged across devices.
  """

  _state: TrainingState

  def __init__(
    self,
    networks: CRRNetworks,
    random_key: networks_lib.PRNGKey,
    discount: float,
    target_update_period: int,
    num_bins: int,
    batch_size: int,
    iterator: Iterator[types.Transition],
    optimizer: optax.GradientTransformation,
    num_action_samples: int,
    beta: float,
    coeff_fn: Optional[Callable] = None,
    grad_updates_per_batch: int = 1,
    use_sarsa_target: bool = False,
    logger: Optional[loggers.Logger] = None,
    log_interval: int = 50,
  ):
    """Builds the learner, its pmapped update and the replicated state.

    Args:
      networks: torso / policy / critic networks plus the ``sample`` and
        ``log_prob`` helpers for the policy distribution.
      random_key: PRNG key for parameter initialization and update randomness.
      discount: discount factor, multiplied with each transition's discount.
      target_update_period: number of SGD steps between target-network copies.
      num_bins: critic output bins; values > 1 select a distributional critic
        trained on two-hot encodings of value-transformed targets.
      batch_size: global batch size yielded by ``iterator``; it is split evenly
        across local devices, so it must be divisible by their number.
      iterator: source of training batches.
      optimizer: optax transformation applied to all parameters jointly.
      num_action_samples: sampled actions per state used to estimate V(s).
      beta: temperature of the advantage weighting.
      coeff_fn: optional replacement for ``policy_loss_coeff_advantage_exp``;
        called with the same arguments.
      grad_updates_per_batch: forwarded to ``utils.process_multiple_batches``.
      use_sarsa_target: bootstrap from ``extras["next_action"]`` instead of an
        action sampled from the target policy.
      logger: metrics sink; defaults to an asynchronous acme logger.
      log_interval: write metrics whenever the step count is a multiple of it.

    Raises:
      ValueError: if ``batch_size`` is not divisible by the number of local
        devices.
    """
    torso_network = networks.torso_network
    critic_network = networks.critic_network
    policy_network = networks.policy_network
    _batch_categorical_cross_entropy = jax.vmap(rlax.categorical_cross_entropy)
    distributional = num_bins > 1  # Distributional critic.
    # Resolve the policy-weighting function once, outside the traced loss.
    _coeff_fn = (
      coeff_fn if coeff_fn is not None else policy_loss_coeff_advantage_exp
    )

    def loss(
      params: Params,
      target_params: Params,
      transition: types.Transition,
      key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
      """Returns policy loss + critic loss and a dict of diagnostics."""
      policy_key, value_key = jax.random.split(key)
      # === Shared Encoding ===
      # Online torso encodes the current observation; target torso encodes the
      # next observation for bootstrapping.
      torso_out = torso_network.apply(
        params.torso_params, transition.observation
      )
      target_torso_out = torso_network.apply(
        target_params.torso_params, transition.next_observation
      )

      # === Policy Loss ===
      # The weights are treated as constants: no gradient flows back through
      # the advantage estimate.
      coeff = _coeff_fn(
        networks,
        params.policy_params,
        params.critic_params,
        torso_out,
        transition.action,
        policy_key,
        num_action_samples=num_action_samples,
        beta=beta,
        num_bins=num_bins,
      )
      coeff = jax.lax.stop_gradient(coeff)
      dist_params = policy_network.apply(params.policy_params, torso_out)
      logp_action = networks.log_prob(dist_params, transition.action)
      policy_loss = -jnp.mean(logp_action * coeff)

      # === Value Loss ===
      # Next action: the logged one (SARSA target) or a target-policy sample.
      if use_sarsa_target:
        if "next_action" not in transition.extras:
          raise ValueError(
            "next actions should be given as extras for one step RL."
          )
        next_action = transition.extras["next_action"]
      else:
        next_dist_params = policy_network.apply(
          target_params.policy_params, target_torso_out
        )
        next_action = networks.sample(next_dist_params, value_key)
      # Calculate the value of the next state and action.
      next_q_maybe_logits = critic_network.apply(
        target_params.critic_params, target_torso_out, next_action
      )
      if distributional:
        next_q = logits_to_scalar(next_q_maybe_logits, num_bins)
        next_q = inv_value_transform(next_q)
      else:
        next_q = next_q_maybe_logits
      target_q = transition.reward + transition.discount * discount * next_q
      target_q = jax.lax.stop_gradient(target_q)
      q_maybe_logits = critic_network.apply(
        params.critic_params, torso_out, transition.action
      )
      if distributional:
        # Cross-entropy against the two-hot encoding (a probability vector)
        # of the value-transformed scalar target.
        target_q_probs = scalar_to_two_hot(value_transform(target_q), num_bins)
        critic_loss = jnp.mean(
          _batch_categorical_cross_entropy(target_q_probs, q_maybe_logits)
        )
        predict_q = logits_to_scalar(q_maybe_logits, num_bins)
        predict_q = inv_value_transform(predict_q)
      else:
        predict_q = q_maybe_logits
        q_error = predict_q - target_q
        critic_loss = 0.5 * jnp.mean(jnp.square(q_error))
      total_loss = jnp.mean(policy_loss + critic_loss)
      # NOTE(review): assumes `policy_network.apply` returns a distribution
      # object exposing `.entropy()` -- confirm against CRRNetworks.
      entropy = dist_params.entropy()
      log = {
        "target_q_mean": jnp.mean(target_q),
        "target_q_std": jnp.std(target_q),
        "predict_q_mean": jnp.mean(predict_q),
        "predict_q_std": jnp.std(predict_q),
        "crr_coeff_mean": jnp.mean(coeff),
        "crr_coeff_std": jnp.std(coeff),
        "entropy": jnp.mean(entropy),
        "policy_loss": jnp.mean(policy_loss),
        "critic_loss": jnp.mean(critic_loss),
      }
      return total_loss, log

    def sgd_step(
      state: TrainingState,
      transitions: types.Transition,
      rng_key: networks_lib.PRNGKey,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      """Applies one gradient update; runs per device under `jax.pmap`."""
      grads, log = jax.grad(loss, has_aux=True)(
        state.params, state.target_params, transitions, rng_key
      )
      # Average gradients over devices so every replica applies the same
      # update and the replicated parameters stay in sync.
      grads = jax.lax.pmean(grads, axis_name=_PMAP_AXIS_NAME)
      network_updates, opt_state = optimizer.update(
        grads, state.opt_state, state.params
      )
      params = optax.apply_updates(state.params, network_updates)
      log.update(
        {
          "grad_norm": optax.global_norm(grads),
          "update_norm": optax.global_norm(network_updates),
          "param_norm": optax.global_norm(params),
        }
      )
      steps = state.steps + 1
      # Copy online params into the target every `target_update_period` steps.
      target_params = optax.periodic_update(
        params, state.target_params, steps, target_update_period
      )
      new_state = TrainingState(
        params=params,
        target_params=target_params,
        opt_state=opt_state,
        steps=steps,
      )
      return new_state, log

    self._devices = jax.local_devices()
    self._num_devices = len(self._devices)
    # `step` reshapes each batch to [num_devices, batch_size // num_devices],
    # which only works when the split is exact; fail early with a clear error.
    if batch_size % self._num_devices != 0:
      raise ValueError(
        f"batch_size ({batch_size}) must be divisible by the number of local "
        f"devices ({self._num_devices})."
      )
    self._batch_size = batch_size

    # Logger.
    self._logger = logger or loggers.make_default_logger(
      "learner", asynchronous=True, serialize_fn=utils.fetch_devicearray
    )

    # Batches are consumed on the host as-is (no device prefetching); `step`
    # shards them across devices by reshaping.
    self._iterator = iterator

    # JIT compiler.
    # NOTE(review): `sgd_step` takes an extra `rng_key` argument; acme's
    # `process_multiple_batches` presumably wraps a `(state, data)` function,
    # so grad_updates_per_batch > 1 may not be supported here -- confirm.
    sgd_step = utils.process_multiple_batches(sgd_step, grad_updates_per_batch)
    self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)

    # Create and initialize optimizer.
    key_torso, key_policy, key_critic, self._rng_key = jax.random.split(
      random_key, 4
    )
    params = Params(
      torso_params=torso_network.init(key_torso),
      policy_params=policy_network.init(key_policy),
      critic_params=critic_network.init(key_critic),
    )
    opt_state = optimizer.init(params)

    # Create initial state; the target network starts as a copy of `params`.
    self._state = TrainingState(
      params=params,
      target_params=params,
      opt_state=opt_state,
      steps=0,
    )
    # Replicate the state onto every local device for pmap.
    self._state = jax.device_put_replicated(self._state, self._devices)

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
    self._elapsed = 0
    self._log_interval = log_interval

  def step(self, transform: bool = False):
    """Runs one pmapped update on the next batch and periodically logs.

    Args:
      transform: if True, the batch is first converted by
        `_add_next_action_extras` into transitions that carry
        ``extras["next_action"]``.
    """
    transitions = next(self._iterator)
    if transform:
      transitions = _add_next_action_extras(transitions)

    # Split the leading batch axis into [num_devices, per_device_batch, ...] so
    # that pmap hands each local device its own shard.
    per_device_batch = self._batch_size // self._num_devices
    transitions = tree.map_structure(
      lambda x: x.reshape(self._num_devices, per_device_batch, *x.shape[1:]),
      transitions,
    )
    update_key, self._rng_key = jax.random.split(self._rng_key)
    update_keys = jax.random.split(update_key, self._num_devices)

    self._state, metrics = self._sgd_step(self._state, transitions, update_keys)

    timestamp = time.time()
    # The first call has no previous timestamp and contributes 0 seconds.
    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
    self._timestamp = timestamp
    self._elapsed += elapsed_time
    # The step counter is replicated across devices; read device 0's copy.
    num_steps = jax.device_get(self._state.steps[0])
    if num_steps % self._log_interval == 0:
      # pmap outputs carry a leading device axis; report device 0's values.
      metrics = jax.device_get(jax.tree_util.tree_map(lambda t: t[0], metrics))
      self._logger.write(
        {
          **metrics,
          "step": num_steps,
          "walltime": elapsed_time,
          "elapsed_time": self._elapsed,
          # The 1e-6 avoids dividing by zero when elapsed_time is 0.
          "learner_fps": 1 / (elapsed_time + 1e-6) * self._batch_size,
        }
      )

  def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
    """Returns the requested parameter trees from the saved (device-0) state.

    Valid names are "torso", "policy" and "critic"; other names raise KeyError.
    """
    state = self.save()
    variables = {
      "torso": state.params.torso_params,
      "policy": state.params.policy_params,
      "critic": state.params.critic_params,
    }
    return [variables[name] for name in names]

  def save(self) -> TrainingState:
    """Returns device 0's copy of the replicated state, fetched to host."""
    return utils.fetch_devicearray(
      jax.tree_util.tree_map(lambda t: t[0], self._state)
    )

  def restore(self, state: TrainingState):
    """Replicates ``state`` (e.g. the output of `save`) onto local devices."""
    self._state = jax.device_put_replicated(state, self._devices)


def _compute_advantage(
  networks: CRRNetworks,
  policy_params: networks_lib.Params,
  critic_params: networks_lib.Params,
  torso_out: jnp.ndarray,
  action: jnp.ndarray,
  key: networks_lib.PRNGKey,
  num_action_samples: int = 4,
  num_bins: int = 1,
) -> jnp.ndarray:
  """Returns the advantage Q(s, a) - V(s) for the given actions.

  V(s) is estimated as the mean of Q(s, a') over ``num_action_samples`` actions
  a' sampled from the current policy. Taking the maximum instead would give
  CRR(max); see table 1 in the CRR paper.

  Args:
    networks: supplies ``policy_network``, ``critic_network`` and ``sample``.
    policy_params: policy-head parameters used to sample the actions a'.
    critic_params: critic-head parameters used for every Q evaluation.
    torso_out: encoded observations (output of the torso network).
    action: actions whose advantage is computed.
    key: PRNG key, split into one key per sampled action.
    num_action_samples: number of sampled actions used to estimate V(s).
    num_bins: critic output bins. When > 1 the critic returns logits that are
      decoded with ``logits_to_scalar`` then ``inv_value_transform``;
      otherwise its output is used as-is.

  Returns:
    Q(s, action) minus the sample-mean estimate of V(s).
  """
  distributional = num_bins > 1

  def _as_scalar_q(q_maybe_logits: jnp.ndarray) -> jnp.ndarray:
    # Decode distributional critic logits into scalar Q-values when needed.
    if distributional:
      return inv_value_transform(logits_to_scalar(q_maybe_logits, num_bins))
    return q_maybe_logits

  def _sample(policy_params, critic_params, torso_out, key):
    # Q-value of one action sampled from the policy.
    dist_params = networks.policy_network.apply(policy_params, torso_out)
    actions = networks.sample(dist_params, key)
    q_actions = networks.critic_network.apply(
      critic_params, torso_out, actions
    )
    return _as_scalar_q(q_actions)

  # One copy of the encodings per sample along a new leading axis; vmap maps
  # over that axis and the per-sample keys while sharing the parameters.
  replicated_torso_out = jnp.broadcast_to(
    torso_out, (num_action_samples,) + torso_out.shape
  )
  sample_keys = jax.random.split(key, num_action_samples)
  q_actions = jax.vmap(_sample, in_axes=(None, None, 0, 0))(
    policy_params, critic_params, replicated_torso_out, sample_keys
  )

  # Mean over the sample axis is the state-value estimate.
  q_estimate = jnp.mean(q_actions, axis=0)
  q = _as_scalar_q(
    networks.critic_network.apply(critic_params, torso_out, action)
  )
  return q - q_estimate


def policy_loss_coeff_advantage_exp(
  networks: CRRNetworks,
  policy_params: networks_lib.Params,
  critic_params: networks_lib.Params,
  torso_out: jnp.ndarray,
  action: jnp.ndarray,
  key: networks_lib.PRNGKey,
  num_action_samples: int = 8,
  beta: float = 0.7,
  ratio_upper_bound: float = 20.0,
  num_bins: int = 1,
) -> jnp.ndarray:
  """Exponential advantage weighting; see equation (4) in the CRR paper.

  Computes ``min(exp(A / beta), ratio_upper_bound)`` where A is the advantage
  from `_compute_advantage`. The upper bound caps the weight of large
  advantages, including the case where ``exp`` overflows to infinity.

  Args:
    networks: CRR networks forwarded to `_compute_advantage`.
    policy_params: policy-head parameters.
    critic_params: critic-head parameters.
    torso_out: encoded observations.
    action: actions whose weights are computed.
    key: PRNG key for the action sampling inside the advantage estimate.
    num_action_samples: sampled actions used to estimate V(s).
    beta: temperature; smaller values sharpen the weighting.
    ratio_upper_bound: maximum returned weight.
    num_bins: critic output bins (> 1 means a distributional critic).

  Returns:
    Non-negative weights with the same shape as the advantage.
  """
  advantage = _compute_advantage(
    networks,
    policy_params,
    critic_params,
    torso_out,
    action,
    key,
    num_action_samples=num_action_samples,
    num_bins=num_bins,
  )
  return jnp.minimum(jnp.exp(advantage / beta), ratio_upper_bound)

def _add_next_action_extras(steps: types.NestedArray) -> Transition:
  """Turns pairs of consecutive steps into one-step transitions.

  ``steps`` must expose ``observation``, ``action``, ``reward`` and
  ``discount`` attributes (it is not a dict). Each field is indexed at dim 1:
  index 0 is the current step and index 1 the next one. The layout is assumed
  to be [batch, time, ...] with time >= 2 -- TODO confirm against the dataset.

  Args:
    steps: a structure of consecutive steps as described above.

  Returns:
    A ``Transition`` built from time index 0, with ``next_observation`` and
    ``extras["next_action"]`` taken from time index 1 (the latter is what the
    SARSA-style target in the learner bootstraps from).
  """
  return Transition(
    observation=steps.observation[:, 0],
    action=steps.action[:, 0],
    reward=steps.reward[:, 0],
    discount=steps.discount[:, 0],
    next_observation=steps.observation[:, 1],
    extras={"next_action": steps.action[:, 1]},
  )
