"""CQL learner implementation."""

import dataclasses
import time
from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple

import acme
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
import tree
from acme import types
from acme.agents.jax.dqn import learning_lib
from acme.agents.jax.dqn.learning_lib import LossExtra, LossFn
from acme.agents.jax.dqn.losses import QrDqn
from acme.jax import networks as networks_lib
from acme.jax import utils
from acme.types import Transition
from acme.utils import loggers

from rosmo.agent.cql_discrete.network import CQLNetworks

# Name of the device axis: `jax.pmap` maps the learner's SGD step over it and
# `jax.lax.pmean` averages gradients across it.
_PMAP_AXIS_NAME = "data"


class Params(NamedTuple):
  """Trainable parameters of the CQL learner.

  This is the pytree handed to the optimizer, and it is stored twice (online
  and target copies) inside `TrainingState`.
  """

  # Q-network parameters, as produced by `q_network.init`.
  q_params: networks_lib.Params


class TrainingState(NamedTuple):
  """Everything the learner needs to resume training.

  Fields:
    params: online parameters, updated on every gradient step.
    target_params: target parameters, overwritten with `params` every
      `target_update_period` gradient steps (via `optax.periodic_update`).
    opt_state: optimizer state associated with `params`.
    steps: number of gradient steps applied so far.
  """

  params: Params
  target_params: Params
  opt_state: optax.OptState
  steps: int


@dataclasses.dataclass
class ConservativeQLearning(QrDqn):
  """Quantile-regression DQN loss plus the discrete-action CQL regularizer.

  The TD part is the QR-DQN quantile Huber loss computed with
  `rlax.quantile_q_learning` (no double Q-learning: the target network both
  selects and evaluates the bootstrap action). The CQL part, scaled by
  `min_q_weight`, is

      mean_b[logsumexp_a Q(s_b, a)] - mean_b[Q(s_b, a_b)]

  where `a_b` is the dataset action (https://arxiv.org/abs/2006.04779).

  Attributes:
    num_actions: number of discrete actions. A non-positive value (the
      default) means "use the size of the last axis of the network's
      `q_value` output".
    min_q_weight: multiplier applied to the CQL regularizer.
  """

  num_actions: int = -1
  min_q_weight: float = 1.0

  def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    transitions: types.Transition,
    key: networks_lib.PRNGKey,
  ) -> Tuple[jnp.ndarray, LossExtra]:
    """Computes the total loss on a batch of transitions.

    Args:
      network: network whose `apply(params, observation)` returns a dict with
        "q_dist" (used by the quantile loss) and "q_value" (used by the CQL
        term). NOTE(review): "q_dist" is passed per example to
        `rlax.quantile_q_learning` without swapping axes, so it presumably is
        laid out as (batch, num_atoms, num_actions) — confirm against the
        network definition.
      params: online network parameters.
      target_params: target network parameters, used for the bootstrap.
      transitions: batch of transitions.
      key: unused; kept for the `LossFn` call signature.

    Returns:
      `(total_loss, extra)` where `extra.metrics` holds "qr_loss",
      "min_q_loss", "q_value_mean" and "q_value_std".
    """
    del key  # Unused.

    network_out_tm1 = network.apply(params, transitions.observation)
    network_out_target_t = network.apply(
      target_params, transitions.next_observation
    )

    # Quantile-regression TD loss.
    dist_q_tm1 = network_out_tm1["q_dist"]
    dist_q_target_t = network_out_target_t["q_dist"]
    # Midpoints of `num_atoms` equal-width bins on [0, 1].
    quantiles = (
      jnp.arange(self.num_atoms, dtype=jnp.float32) + 0.5
    ) / self.num_atoms
    batch_quantile_q_learning = jax.vmap(
      rlax.quantile_q_learning, in_axes=(0, None, 0, 0, 0, 0, 0, None)
    )
    losses = batch_quantile_q_learning(
      dist_q_tm1,
      quantiles,
      transitions.action,
      transitions.reward,
      transitions.discount,
      dist_q_target_t,  # No double Q-learning here.
      dist_q_target_t,
      self.huber_param,
    )
    qr_loss = jnp.mean(losses)

    # CQL regularizer: push down logsumexp over all actions, push up the
    # Q-value of the action actually taken in the dataset.
    q_value = network_out_tm1["q_value"]
    # A non-positive `num_actions` (the field default) would give a zero-width
    # one-hot and a shape error, so fall back to the network's action axis.
    num_actions = (
      self.num_actions if self.num_actions > 0 else q_value.shape[-1]
    )
    dataset_action_one_hot = hk.one_hot(transitions.action, num_actions)
    dataset_chosen_q = jnp.sum(q_value * dataset_action_one_hot, axis=1)
    dataset_expec = jnp.mean(dataset_chosen_q)
    negative_sampling = jnp.mean(jax.scipy.special.logsumexp(q_value, axis=1))
    min_q_loss = (negative_sampling - dataset_expec) * self.min_q_weight

    total_loss = qr_loss + min_q_loss

    extra = learning_lib.LossExtra(
      metrics={
        "qr_loss": qr_loss,
        "min_q_loss": min_q_loss,
        "q_value_mean": jnp.mean(q_value),
        "q_value_std": jnp.std(q_value),
      }
    )
    return total_loss, extra


class CQLLearner(acme.Learner):
  """Conservative Q-Learning (CQL) learner.

  This is the learning component of a CQL agent as described in
  https://arxiv.org/abs/2006.04779. The Q-network is trained with the loss in
  `ConservativeQLearning` (quantile-regression TD loss plus the CQL
  regularizer). The training state is replicated on every local device, each
  update runs under `jax.pmap`, and gradients are averaged across devices.
  """

  _state: TrainingState

  def __init__(
    self,
    networks: CQLNetworks,
    random_key: networks_lib.PRNGKey,
    target_update_period: int,
    num_atoms: int,
    huber_param: float,
    minq_weight: float,
    batch_size: int,
    iterator: Iterator[types.Transition],
    optimizer: optax.GradientTransformation,
    grad_updates_per_batch: int = 1,
    logger: Optional[loggers.Logger] = None,
    log_interval: int = 50,
  ):
    """Builds the replicated training state and the pmapped update function.

    Args:
      networks: container providing `q_network` and
        `environment_specs.actions.num_values`.
      random_key: PRNG key used to initialise the Q-network and to derive
        the per-update keys.
      target_update_period: the target parameters are overwritten with the
        online parameters every `target_update_period` gradient steps.
      num_atoms: number of quantiles predicted by the Q-network.
      huber_param: Huber threshold of the quantile-regression loss.
      minq_weight: weight of the CQL regularizer.
      batch_size: global batch size expected by `step`; it is split evenly
        across the local devices.
      iterator: data iterator; stored on the learner but not read by this
        class (`step` receives its batch explicitly).
      optimizer: optax optimizer for the Q-network parameters.
      grad_updates_per_batch: number of sequential gradient updates per call
        to `step`, each on an equal slice of the per-device batch.
      logger: metrics logger; an asynchronous default logger if None.
      log_interval: log metrics every `log_interval` gradient steps.

    Raises:
      ValueError: if `grad_updates_per_batch` is smaller than 1.
    """
    if grad_updates_per_batch < 1:
      raise ValueError(
        f"grad_updates_per_batch must be >= 1, got {grad_updates_per_batch}."
      )

    q_network = networks.q_network
    num_actions = networks.environment_specs.actions.num_values
    # Keyword arguments keep this independent of the field order of the
    # `QrDqn` base dataclass.
    loss_fn = ConservativeQLearning(
      num_atoms=num_atoms,
      huber_param=huber_param,
      num_actions=num_actions,
      min_q_weight=minq_weight,
    )

    def loss(
      params: Params,
      target_params: Params,
      transition: types.Transition,
      key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
      """Returns the scalar loss and a dict of metrics for one batch."""
      total_loss, extra = loss_fn(
        q_network, params.q_params, target_params.q_params, transition, key
      )
      log = {
        "total_loss": total_loss,
        **extra.metrics,
      }
      return total_loss, log

    def sgd_step(
      state: TrainingState,
      transitions: types.Transition,
      rng_key: networks_lib.PRNGKey,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      """Applies one device-averaged gradient update and returns metrics."""
      grads, log = jax.grad(
        loss, has_aux=True
      )(state.params, state.target_params, transitions, rng_key)
      grads = jax.lax.pmean(grads, axis_name=_PMAP_AXIS_NAME)
      network_updates, opt_state = optimizer.update(
        grads, state.opt_state, state.params
      )
      params = optax.apply_updates(state.params, network_updates)
      log.update(
        {
          "grad_norm": optax.global_norm(grads),
          "update_norm": optax.global_norm(network_updates),
          "param_norm": optax.global_norm(params),
        }
      )
      steps = state.steps + 1
      target_params = optax.periodic_update(
        params, state.target_params, steps, target_update_period
      )
      new_state = TrainingState(
        params=params,
        target_params=target_params,
        opt_state=opt_state,
        steps=steps,
      )
      return new_state, log

    def multi_sgd_step(
      state: TrainingState,
      transitions: types.Transition,
      rng_key: networks_lib.PRNGKey,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      """Runs `grad_updates_per_batch` sequential `sgd_step`s.

      `utils.process_multiple_batches` calls its step function as
      `f(state, data)`, so it cannot forward the per-device PRNG key that
      `sgd_step` takes. This does the same thing (split the batch into equal
      consecutive slices, scan over them, average the metrics) while also
      splitting the key.
      """
      if grad_updates_per_batch == 1:
        return sgd_step(state, transitions, rng_key)
      minibatches = jax.tree_util.tree_map(
        lambda x: jnp.reshape(x, (grad_updates_per_batch, -1, *x.shape[1:])),
        transitions,
      )
      step_keys = jax.random.split(rng_key, grad_updates_per_batch)

      def scan_body(carry, inputs):
        minibatch, step_key = inputs
        return sgd_step(carry, minibatch, step_key)

      state, logs = jax.lax.scan(scan_body, state, (minibatches, step_keys))
      return state, jax.tree_util.tree_map(jnp.mean, logs)

    self._devices = jax.local_devices()
    self._num_devices = len(self._devices)
    self._batch_size = batch_size
    self._grad_updates_per_batch = grad_updates_per_batch

    # Logger.
    self._logger = logger or loggers.make_default_logger(
      "learner", asynchronous=True, serialize_fn=utils.fetch_devicearray
    )

    # Stored as given (no prefetching); not read by this class.
    self._iterator = iterator

    # JIT-compile the update across local devices.
    self._sgd_step = jax.pmap(multi_sgd_step, axis_name=_PMAP_AXIS_NAME)

    # Create and initialize optimizer.
    key_qf, self._rng_key = jax.random.split(random_key)
    initial_q_params = q_network.init(key_qf)
    params = Params(initial_q_params)
    opt_state = optimizer.init(params)

    # Create initial state, replicated on every local device.
    self._state = jax.device_put_replicated(
      TrainingState(
        params=params,
        target_params=params,
        opt_state=opt_state,
        steps=0,
      ),
      self._devices,
    )

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
    self._elapsed = 0
    self._log_interval = log_interval

  def step(self, data, transform=False):
    """Runs one pmapped update on `data` and periodically logs metrics.

    Args:
      data: global batch whose leaves have leading dimension `batch_size`
        (assumed divisible by the number of local devices). When `transform`
        is True, leaves are step pairs shaped `(batch_size, 2, ...)`.
      transform: if True, convert step pairs into transitions with
        `_add_next_action_extras` first.
    """
    transitions = _add_next_action_extras(data) if transform else data

    # Shard the global batch: one equal slice per local device.
    per_device_batch = self._batch_size // self._num_devices
    transitions = tree.map_structure(
      lambda x: x.reshape(self._num_devices, per_device_batch, *x.shape[1:]),
      transitions,
    )
    update_key, self._rng_key = jax.random.split(self._rng_key)
    update_keys = jax.random.split(update_key, self._num_devices)

    self._state, metrics = self._sgd_step(
      self._state, transitions, update_keys
    )

    timestamp = time.time()
    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
    self._timestamp = timestamp
    self._elapsed += elapsed_time

    step = jax.device_get(self._state.steps[0])
    # The counter advances by `grad_updates_per_batch` per call, so log when a
    # multiple of `log_interval` was reached or crossed during this call. With
    # one update per call this is exactly `step % log_interval == 0`.
    previous_step = step - self._grad_updates_per_batch
    if step // self._log_interval > previous_step // self._log_interval:
      # Metrics of the first device only.
      metrics = jax.device_get(
        jax.tree_util.tree_map(lambda t: t[0], metrics)
      )
      self._logger.write(
        {
          **metrics,
          "step": step,
          "walltime": elapsed_time,
          "elapsed_time": self._elapsed,
          "learner_fps": 1 / (elapsed_time + 1e-6) * self._batch_size,
        }
      )

  def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
    """Returns host copies of the requested variables (only "qf" exists)."""
    # Fetch only the Q-network parameters of the first replica instead of
    # copying the whole state (target params, optimizer state) to host.
    q_params = utils.fetch_devicearray(
      jax.tree_util.tree_map(lambda t: t[0], self._state.params.q_params)
    )
    variables = {
      "qf": q_params,
    }
    return [variables[name] for name in names]

  def save(self) -> TrainingState:
    """Returns a host copy of the training state of the first device."""
    return utils.fetch_devicearray(
      jax.tree_util.tree_map(lambda t: t[0], self._state)
    )

  def restore(self, state: TrainingState):
    """Replicates `state` onto the devices this learner runs on."""
    self._state = jax.device_put_replicated(state, self._devices)


def _add_next_action_extras(steps: Transition) -> Transition:
  """Turns a batch of consecutive step pairs into a `Transition`.

  Every field of `steps` is indexed along axis 1: index 0 is the current step
  and index 1 the following step. The following step's action is exposed as
  `extras["next_action"]`.

  Args:
    steps: object with `observation`, `action`, `reward` and `discount`
      attributes (typically a `Transition`) whose array leaves are shaped
      `(batch, 2, ...)`. Nested observations (e.g. dicts of arrays) are
      supported.

  Returns:
    A `Transition` describing the first step of each pair.
  """

  def current(x):
    return tree.map_structure(lambda leaf: leaf[:, 0], x)

  def following(x):
    return tree.map_structure(lambda leaf: leaf[:, 1], x)

  return Transition(
    observation=current(steps.observation),
    action=current(steps.action),
    reward=current(steps.reward),
    discount=current(steps.discount),
    next_observation=following(steps.observation),
    extras={"next_action": following(steps.action)},
  )
