"""CRR adapted from Acme implementation."""

import dataclasses
import functools
import time
from typing import Callable, Dict, Iterator, List, NamedTuple, Optional, Tuple

import acme
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
from absl import logging
from acme import specs, types
from acme.agents.jax import actor_core as actor_core_lib
from acme.agents.jax import actors
from acme.jax import networks as networks_lib
from acme.jax import utils, variable_utils
from acme.types import Transition
from acme.utils import counting, loggers
from ml_collections import ConfigDict

from rosmo.agent.base import AgentBuilder
from rosmo.types import ActorOutput


@dataclasses.dataclass
class CRRNetworks:
  """Network and pure functions for the CRR agent."""

  # Maps observations to a state embedding; the learner feeds this embedding
  # to both the policy and the critic.
  encoder_network: networks_lib.FeedForwardNetwork
  # Maps an encoded state to the (distribution) output of the policy.
  policy_network: networks_lib.FeedForwardNetwork
  # Maps an (encoded state, action) pair to a Q-value estimate.
  critic_network: networks_lib.FeedForwardNetwork
  # (policy output, actions) -> log-probability of those actions.
  log_prob: networks_lib.LogProbFn
  # (policy output, PRNG key) -> sampled actions; used during training.
  sample: networks_lib.SampleFn
  # (policy output, PRNG key) -> actions used by the evaluator
  # (`make_networks` in this file uses the distribution's mode).
  sample_eval: networks_lib.SampleFn


# Signature of a policy-loss coefficient function: given the networks, the
# encoder / policy / critic parameters (in that order), a batch of transitions
# and a PRNG key, it returns the weights applied to the policy's
# log-likelihood loss.
PolicyLossCoeff = Callable[
  [
    CRRNetworks,
    networks_lib.Params,  # Encoder parameters.
    networks_lib.Params,  # Policy parameters.
    networks_lib.Params,  # Critic parameters.
    types.Transition,
    networks_lib.PRNGKey,
  ],
  jnp.ndarray,
]


def make_networks(
  spec: specs.EnvironmentSpec,
  encoder_layer_sizes: Tuple[int, ...] = (256, 256),
  policy_layer_sizes: Tuple[int, ...] = (256, 256),
  critic_layer_sizes: Tuple[int, ...] = (256, 256),
  activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
  **_,
) -> CRRNetworks:
  """Creates the encoder, policy and critic networks used by the agent.

  The encoder maps observations to a state embedding; both the policy and the
  critic take that embedding (not the raw observation) as input.

  NOTE: Haiku derives parameter names from the module structure built below
  (the `hk.Sequential` wrappers and layer order). Changing that structure
  changes the parameter tree, which breaks previously saved parameters.

  Args:
    spec: environment spec. Its observation and action specs determine the
      dummy inputs used for parameter initialization and the action size.
    encoder_layer_sizes: output sizes of the encoder MLP layers; the last one
      is the size of the state embedding.
    policy_layer_sizes: output sizes of the policy MLP torso (all layers are
      activated) that precedes the `NormalTanhDistribution` head.
    critic_layer_sizes: hidden sizes of the critic MLP; a final layer with a
      single output unit is appended.
    activation: activation function used inside the MLPs.
    **_: any other keyword arguments (e.g. unrelated config entries) are
      ignored.

  Returns:
    A `CRRNetworks` bundling the three feed-forward networks together with the
    log-prob / sampling helpers for the policy's distribution.
  """
  num_actions = np.prod(spec.actions.shape, dtype=int)

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

  def _encoder_fn(obs: jnp.ndarray) -> jnp.ndarray:
    # MLP over the observation, followed by an ELU on its output.
    encoder = hk.Sequential(
      [
        hk.nets.MLP(
          list(encoder_layer_sizes),
          w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
          activation=activation,
        ),
        jax.nn.elu,
      ]
    )
    return encoder(obs)

  encoder = hk.without_apply_rng(hk.transform(_encoder_fn))
  encoder_network = networks_lib.FeedForwardNetwork(
    lambda key: encoder.init(key, dummy_obs), encoder.apply
  )

  def dummy_state(key):
    # Runs a throwaway, freshly initialized encoder on the dummy observation,
    # only to obtain a correctly shaped state for initializing the policy and
    # critic parameters.
    encoder_params = encoder.init(key, dummy_obs)
    return encoder.apply(encoder_params, dummy_obs)

  def _policy_fn(state: jnp.ndarray) -> jnp.ndarray:
    # Returns the output of `NormalTanhDistribution`, i.e. a distribution
    # object rather than a plain array (the annotation above is loose).
    network = hk.Sequential(
      [
        hk.nets.MLP(
          list(policy_layer_sizes),
          w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
          activation=activation,
          activate_final=True,
        ),
        networks_lib.NormalTanhDistribution(num_actions),
      ]
    )
    return network(state)

  policy = hk.without_apply_rng(hk.transform(_policy_fn))
  policy_network = networks_lib.FeedForwardNetwork(
    lambda key: policy.init(key, dummy_state(key)), policy.apply
  )

  def _critic_fn(state, action):
    # Q(state, action) from the concatenated inputs. The final layer has a
    # single unit, so the output keeps a trailing axis of size 1; callers
    # mixing it with per-sample (B,) tensors must squeeze that axis.
    network = hk.Sequential(
      [
        hk.nets.MLP(
          list(critic_layer_sizes) + [1],
          w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
          activation=activation,
        ),
      ]
    )
    data = jnp.concatenate([state, action], axis=-1)
    return network(data)

  critic = hk.without_apply_rng(hk.transform(_critic_fn))
  critic_network = networks_lib.FeedForwardNetwork(
    lambda key: critic.init(key, dummy_state(key), dummy_action), critic.apply
  )

  # The helpers below operate on the distribution returned by the policy
  # network; evaluation uses its mode and ignores the key.
  return CRRNetworks(
    encoder_network=encoder_network,
    policy_network=policy_network,
    critic_network=critic_network,
    log_prob=lambda params, actions: params.log_prob(actions),
    sample=lambda params, key: params.sample(seed=key),
    sample_eval=lambda params, key: params.mode(),
  )


def _compute_advantage(
  networks: CRRNetworks,
  encoder_params: networks_lib.Params,
  policy_params: networks_lib.Params,
  critic_params: networks_lib.Params,
  transition: types.Transition,
  key: networks_lib.PRNGKey,
  num_action_samples: int = 4,
) -> jnp.ndarray:
  """Estimates the advantage A(s, a) = Q(s, a) - V(s) of the dataset actions.

  V(s) is approximated by the mean critic value over `num_action_samples`
  actions drawn from the current policy, i.e. CRR(mean). Taking the maximum
  instead would give CRR(max); see table 1 in the CRR paper.

  Returns:
    One advantage per transition, with the same shape as the critic's output
    for the batch.
  """
  embedded = networks.encoder_network.apply(
    encoder_params, transition.observation
  )
  # Q-value of the action that was actually taken in the data.
  q_taken = networks.critic_network.apply(
    critic_params, embedded, transition.action
  )

  # Tile the state along a new leading axis, one copy per policy sample.
  tiled = jnp.broadcast_to(embedded, (num_action_samples, *embedded.shape))
  policy_out = networks.policy_network.apply(policy_params, tiled)
  sampled = networks.sample(policy_out, key)
  q_sampled = networks.critic_network.apply(critic_params, tiled, sampled)

  baseline = jnp.mean(q_sampled, axis=0)
  return q_taken - baseline


def policy_loss_coeff_advantage_exp(
  networks: CRRNetworks,
  encoder_params: networks_lib.Params,
  policy_params: networks_lib.Params,
  critic_params: networks_lib.Params,
  transition: types.Transition,
  key: networks_lib.PRNGKey,
  num_action_samples: int = 4,
  beta: float = 1.0,
  ratio_upper_bound: float = 20.0,
) -> jnp.ndarray:
  """Exponential advantage weighting; see equation (4) in the CRR paper.

  Computes `min(exp(A / beta), ratio_upper_bound)`, where A is the estimated
  advantage of each dataset action. `beta` acts as a temperature, and the
  upper bound caps individual weights.
  """
  adv = _compute_advantage(
    networks,
    encoder_params,
    policy_params,
    critic_params,
    transition,
    key,
    num_action_samples=num_action_samples,
  )
  raw_weight = jnp.exp(adv / beta)
  return jnp.minimum(raw_weight, ratio_upper_bound)


def policy_loss_coeff_advantage_indicator(
  networks: CRRNetworks,
  encoder_params: networks_lib.Params,
  policy_params: networks_lib.Params,
  critic_params: networks_lib.Params,
  transition: types.Transition,
  key: networks_lib.PRNGKey,
  num_action_samples: int = 4,
) -> jnp.ndarray:
  """Indicator advantage weighting; see equation (3) in the CRR paper.

  Each sample gets weight 1 when its estimated advantage is positive and 0
  when it is negative or exactly zero.
  """
  adv = _compute_advantage(
    networks,
    encoder_params,
    policy_params,
    critic_params,
    transition,
    key,
    num_action_samples=num_action_samples,
  )
  # Heaviside step; the second argument is the value used at adv == 0.
  at_zero = 0.0
  return jnp.heaviside(adv, at_zero)


def policy_loss_coeff_constant(
  networks: CRRNetworks,
  encoder_params: networks_lib.Params,
  policy_params: networks_lib.Params,
  critic_params: networks_lib.Params,
  transition: types.Transition,
  key: networks_lib.PRNGKey,
  value: float = 1.0,
) -> jnp.ndarray:
  """Constant weights: every sample gets the same coefficient `value`.

  With the default of 1.0, the weighted policy loss reduces to plain behavior
  cloning. All other arguments exist only to match `PolicyLossCoeff`.
  """
  del networks, encoder_params, policy_params, critic_params, transition, key
  return value


def add_next_action_extras(double_transitions: ActorOutput) -> Transition:
  """Turn a batch of 2-step trajectories into SARSA-style transitions.

  Step 0 (axis 1) provides the observation, action, reward and termination
  flag. Step 1 provides the next observation and, in
  `extras["next_action"]`, the next action.
  """
  now, nxt = 0, 1
  obs = double_transitions.observation
  act = double_transitions.action
  return Transition(
    observation=obs[:, now],
    action=act[:, now],
    reward=double_transitions.reward[:, now],
    discount=1.0 - double_transitions.is_last[:, now],
    next_observation=obs[:, nxt],
    extras={"next_action": act[:, nxt]},
  )


class Params(NamedTuple):
  """Parameters of the three CRR networks (used for online and target sets)."""

  # Parameters of `CRRNetworks.encoder_network`.
  encoder_params: networks_lib.Params
  # Parameters of `CRRNetworks.policy_network`.
  policy_params: networks_lib.Params
  # Parameters of `CRRNetworks.critic_network`.
  critic_params: networks_lib.Params


class TrainingState(NamedTuple):
  """Contains training state for the learner."""

  # Online parameters, updated by the optimizer on every SGD step.
  params: Params
  # Target parameters, periodically copied from `params` by the learner.
  target_params: Params
  # Optimizer state, initialized over all of `params` jointly.
  opt_state: optax.OptState
  # Number of SGD steps taken so far.
  steps: int
  # PRNG key, split on every SGD step.
  key: networks_lib.PRNGKey


class CRRLearner(acme.Learner):
  """Critic Regularized Regression (CRR) learner.

  This is the learning component of a CRR agent as described in
  https://arxiv.org/abs/2006.15134. The encoder, policy and critic parameters
  are updated jointly by a single optimizer. The objective is the sum of the
  weighted policy loss and, unless `use_bc` is set, the critic's TD loss.
  """

  _state: TrainingState

  def __init__(
    self,
    networks: CRRNetworks,
    random_key: networks_lib.PRNGKey,
    discount: float,
    target_update_period: int,
    policy_loss_coeff_fn: PolicyLossCoeff,
    demonstrations: Iterator[ActorOutput],
    optimizer: optax.GradientTransformation,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    grad_updates_per_batch: int = 1,
    use_sarsa_target: bool = False,
    use_bc: bool = False,
    **_,
  ):
    """Initializes the CRR learner.

    Args:
      networks: CRR networks (encoder, policy and critic).
      random_key: a key for random number generation.
      discount: discount to use for TD updates.
      target_update_period: number of SGD steps between copies of the online
        parameters into the target parameters.
      policy_loss_coeff_fn: computes the per-sample weights of the policy's
        log-likelihood loss (see the `policy_loss_coeff_*` functions).
      demonstrations: an iterator over batches of `ActorOutput` trajectories.
        Steps 0 and 1 (axis 1) of each batch are turned into transitions by
        `add_next_action_extras`.
      optimizer: optimizer applied jointly to the encoder, policy and critic
        parameters.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by the learner. If `None`, metrics are
        not written.
      grad_updates_per_batch: how many gradient updates given a sampled batch.
      use_sarsa_target: compute the on-policy target using the dataset's next
        actions rather than actions sampled from the target policy.
        Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf).
        When set to `True`, `target_policy_params` are unused.
      use_bc: learn a behavior cloning agent (the critic loss is skipped).
    """

    encoder_network = networks.encoder_network
    critic_network = networks.critic_network
    policy_network = networks.policy_network

    def critic_value(
      params: networks_lib.Params,
      state: jnp.ndarray,
      action: jnp.ndarray,
    ) -> jnp.ndarray:
      # The critic's output layer has a single unit (see `make_networks`), so
      # its raw output carries a trailing singleton axis, e.g. (B, 1). Drop it
      # so Q-values line up with the per-sample (B,) rewards and discounts.
      # Otherwise the TD target and error silently broadcast to (B, B), and
      # every Q-value is regressed toward the batch-mean target.
      return jnp.squeeze(critic_network.apply(params, state, action), axis=-1)

    def policy_loss(
      encoder_params: networks_lib.Params,
      policy_params: networks_lib.Params,
      critic_params: networks_lib.Params,
      transition: types.Transition,
      key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
      # Compute the loss coefficients; they are treated as constants.
      coeff = policy_loss_coeff_fn(
        networks, encoder_params, policy_params, critic_params, transition, key
      )
      coeff = jax.lax.stop_gradient(coeff)
      # Return the weighted loss.
      state = encoder_network.apply(encoder_params, transition.observation)
      dist_params = policy_network.apply(policy_params, state)
      logp_action = networks.log_prob(dist_params, transition.action)
      # Advantage-based coefficients come straight from the critic and may keep
      # its trailing singleton axis. Align them with the per-sample
      # log-probabilities so each sample is weighted by its own coefficient,
      # instead of the product broadcasting to (B, B).
      coeff = jnp.asarray(coeff)
      if coeff.ndim > logp_action.ndim:
        coeff = jnp.squeeze(coeff, axis=-1)
      return -jnp.mean(logp_action * coeff), coeff

    def critic_loss(
      encoder_params: networks_lib.Params,
      critic_params: networks_lib.Params,
      target_encoder_params: networks_lib.Params,
      target_policy_params: networks_lib.Params,
      target_critic_params: networks_lib.Params,
      transition: types.Transition,
      key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.ndarray, ...]:
      target_next_state = encoder_network.apply(
        target_encoder_params, transition.next_observation
      )
      state = encoder_network.apply(encoder_params, transition.observation)
      # Choose the next action used for bootstrapping.
      if use_sarsa_target:
        # Checked at trace time; raised explicitly so it survives `python -O`.
        if "next_action" not in transition.extras:
          raise ValueError(
            "next actions should be given as extras for one step RL."
          )
        next_action = transition.extras["next_action"]
      else:
        next_dist_params = policy_network.apply(
          target_policy_params, target_next_state
        )
        next_action = networks.sample(next_dist_params, key)
      # Calculate the value of the next state and action.
      next_q = critic_value(target_critic_params, target_next_state, next_action)
      target_q = transition.reward + transition.discount * discount * next_q
      target_q = jax.lax.stop_gradient(target_q)

      q = critic_value(critic_params, state, transition.action)
      q_error = q - target_q
      return 0.5 * jnp.mean(jnp.square(q_error)), target_q

    def total_loss(
      params: Params,
      target_params: Params,
      transition: types.Transition,
      policy_key: networks_lib.PRNGKey,
      critic_key: networks_lib.PRNGKey,
    ):
      encoder_params = params.encoder_params
      policy_params = params.policy_params
      critic_params = params.critic_params
      # The bootstrap target is computed entirely with the periodically-updated
      # target networks, including the target encoder.
      target_encoder_params = target_params.encoder_params
      target_policy_params = target_params.policy_params
      target_critic_params = target_params.critic_params

      loss_policy, coeff = policy_loss(
        encoder_params, policy_params, critic_params, transition, policy_key
      )
      total_loss = loss_policy

      loss_critic = 0
      target_q = 0
      if not use_bc:
        loss_critic, target_q = critic_loss(
          encoder_params,
          critic_params,
          target_encoder_params,
          target_policy_params,
          target_critic_params,
          transition,
          critic_key,
        )
        total_loss += loss_critic

      return total_loss, {
        "policy_loss": loss_policy,
        "critic_loss": loss_critic,
        "coeff_mean": jnp.mean(coeff),
        "coeff_std": jnp.std(coeff),
        "target_q_mean": jnp.mean(target_q),
        "target_q_std": jnp.std(target_q),
      }

    def sgd_step(
      state: TrainingState,
      transitions: types.Transition,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

      key, key_policy, key_critic = jax.random.split(state.key, 3)

      # Compute losses and their gradients.
      grads, log = jax.grad(
        total_loss, has_aux=True
      )(
        state.params, state.target_params, transitions, key_policy, key_critic
      )

      # Get optimizer updates and state.
      updates, opt_state = optimizer.update(grads, state.opt_state)

      # Apply optimizer updates to parameters.
      params = optax.apply_updates(state.params, updates)

      steps = state.steps + 1

      # Periodically update target networks.
      target_params = optax.periodic_update(
        params, state.target_params, steps, target_update_period
      )

      new_state = TrainingState(
        params=params,
        target_params=target_params,
        opt_state=opt_state,
        steps=steps,
        key=key,
      )

      log.update(
        {
          "grad_norm": optax.global_norm(grads),
          "update_norm": optax.global_norm(updates),
          "params_norm": optax.global_norm(params),
        }
      )

      return new_state, log

    sgd_step = utils.process_multiple_batches(sgd_step, grad_updates_per_batch)
    self._sgd_step = jax.jit(sgd_step)

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger

    # Batches of demonstrations consumed by `step`.
    self._iterator = demonstrations

    # Create the network parameters and copy into the target network parameters.
    key, key_encoder, key_policy, key_critic = jax.random.split(random_key, 4)
    initial_encoder_params = encoder_network.init(key_encoder)
    initial_policy_params = policy_network.init(key_policy)
    initial_critic_params = critic_network.init(key_critic)
    initial_params = Params(
      encoder_params=initial_encoder_params,
      policy_params=initial_policy_params,
      critic_params=initial_critic_params,
    )
    initial_target_params = initial_params

    # Initialize optimizers.
    initial_opt_state = optimizer.init(initial_params)

    # Create initial state.
    self._state = TrainingState(
      params=initial_params,
      target_params=initial_target_params,
      opt_state=initial_opt_state,
      steps=0,
      key=key,
    )

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None

  def step(self):
    """Runs the jitted update on the next batch and records its metrics."""
    timesteps = next(self._iterator)
    transitions = add_next_action_extras(timesteps)

    self._state, metrics = self._sgd_step(self._state, transitions)

    # Compute elapsed time (0 on the first step).
    timestamp = time.time()
    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
    self._timestamp = timestamp

    # Increment counts and record the current time
    counts = self._counter.increment(steps=1, walltime=elapsed_time)

    # `logger` is optional in the constructor; only write when one was given.
    if self._logger is not None:
      self._logger.write({**metrics, **counts})

  def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
    # We only expose the variables for the learned policy and critic. The target
    # policy and critic are internal details.
    variables = {
      "encoder": self._state.params.encoder_params,
      "policy": self._state.params.policy_params,
      "critic": self._state.params.critic_params,
    }
    return [variables[name] for name in names]

  def save(self) -> TrainingState:
    return self._state

  def restore(self, state: TrainingState):
    self._state = state


class CRRBuilder(AgentBuilder):
  """Builds CRR networks, learners and evaluators.

  With `use_bc=True`, the default config selects constant policy-loss weights
  and the learner skips the critic loss, i.e. behavior cloning.
  """

  def __init__(self, use_bc: bool = False) -> None:
    super().__init__()
    self._use_bc = use_bc

  def make_default_configs(self) -> ConfigDict:
    config = ConfigDict()
    config.encoder_layer_sizes = [256, 256]
    config.policy_layer_sizes = [256, 256]
    config.critic_layer_sizes = [256, 256]
    config.discount = 0.99
    config.batch_size = 256
    config.learning_rate = 1e-4
    config.target_update_period = 100
    config.grad_updates_per_batch = 1
    config.use_sarsa_target = True
    config.policy_loss_coeff = "exp"
    config.beta = 1 / 3  # IQL's inv temp = 3.
    config.trajectory_length = 1
    if self._use_bc:
      config.policy_loss_coeff = "constant"
    return config

  def make_networks(
    self, env_spec: specs.EnvironmentSpec, **kwargs
  ) -> CRRNetworks:
    return make_networks(env_spec, **kwargs)

  def make_learner(self, **kwargs) -> CRRLearner:
    """Creates a `CRRLearner`, choosing the policy-loss weighting from config.

    `kwargs["policy_loss_coeff"]` selects the weighting:
    "exp" (exponential advantage, uses `kwargs["beta"]`), "bin" (advantage
    indicator) or "constant" (behavior cloning). An Adam optimizer with
    `kwargs["learning_rate"]` is created for the learner.

    Raises:
      ValueError: if `policy_loss_coeff` is not one of the names above.
    """
    n = kwargs["policy_loss_coeff"]
    if n == "exp":
      fun = functools.partial(
        policy_loss_coeff_advantage_exp, beta=kwargs["beta"]
      )
    elif n == "bin":
      fun = policy_loss_coeff_advantage_indicator
    elif n == "constant":
      logging.info("Use Behavior Cloning")
      fun = policy_loss_coeff_constant
    else:
      # Previously fell through and failed later with an UnboundLocalError.
      raise ValueError(
        f"Unknown policy_loss_coeff {n!r}; "
        "expected one of 'exp', 'bin' or 'constant'."
      )
    return CRRLearner(
      **kwargs,
      policy_loss_coeff_fn=fun,
      optimizer=optax.adam(kwargs["learning_rate"]),
      use_bc=self._use_bc,
    )

  def make_evaluator(
    self, networks: CRRNetworks, learner: CRRLearner,
    rng_key: networks_lib.PRNGKey
  ) -> acme.Actor:

    # `jnp.ndarray` rather than the removed `jnp.DeviceArray`: annotations are
    # evaluated at definition time, so the old name fails on recent JAX.
    def evaluator_network(
      params: hk.Params, key: jnp.ndarray, observation: jnp.ndarray
    ) -> jnp.ndarray:
      # `params` follows the order requested from the variable client below:
      # [encoder params, policy params].
      state = networks.encoder_network.apply(params[0], observation)
      dist_params = networks.policy_network.apply(params[1], state)
      return networks.sample_eval(dist_params, key)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network
    )
    variable_client = variable_utils.VariableClient(
      learner, ["encoder", "policy"], device="cpu"
    )
    evaluator = actors.GenericActor(
      actor_core, rng_key, variable_client, backend="cpu"
    )
    return evaluator
