"""Model-based onestep lookahead."""

import dataclasses
import time
from typing import (
  Any,
  Callable,
  Dict,
  Iterable,
  Iterator,
  List,
  NamedTuple,
  Optional,
  Tuple,
)

import acme
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
import tensorflow_probability
from acme import specs
from acme.agents.jax import actor_core as actor_core_lib
from acme.agents.jax import actors
from acme.jax import networks as networks_lib
from acme.jax import utils, variable_utils
from acme.utils import counting, loggers
from ml_collections import ConfigDict

from rosmo.agent.base import AgentBuilder
from rosmo.types import ActorOutput

# Aliases for the JAX substrate of TensorFlow Probability and its
# distributions namespace (used for type annotations below).
tfp = tensorflow_probability.substrates.jax
tfd = tfp.distributions


@dataclasses.dataclass
class OSNetworks:
  """Networks and pure functions for the Onestep agent.

  The per-field comments describe the instances built by `make_networks` in
  this module.
  """

  # Maps observations to latent states.
  encoder_network: networks_lib.FeedForwardNetwork
  # Maps a (latent state, action) pair to the next latent state.
  transition_network: networks_lib.FeedForwardNetwork
  # Maps a latent state to a `PredictionOutput` (policy, value, reward).
  prediction_network: networks_lib.FeedForwardNetwork
  # Log-probability of actions under a policy distribution.
  log_prob: networks_lib.LogProbFn
  # Draws samples of the requested shape from a policy distribution.
  sample: Callable[[jnp.ndarray, jnp.ndarray, networks_lib.PRNGKey],
                   jnp.ndarray]
  # Evaluation-time action selection (the distribution mode).
  sample_eval: networks_lib.SampleFn


class PredictionOutput(NamedTuple):
  """Raw outputs of the prediction network for a latent state."""

  # The latent state the prediction was computed from (passed through).
  state: jnp.ndarray
  # Policy distribution produced by the policy head.
  policy: tfd.Distribution
  # Value head output; callers in this module squeeze its last axis.
  value: jnp.ndarray
  # Reward head output; callers in this module squeeze its last axis.
  reward: jnp.ndarray


class AgentOutput(NamedTuple):
  """Per-step agent outputs derived from a `PredictionOutput`."""

  # Latent state the outputs were computed from.
  state: jnp.ndarray
  # Log-probability of the evaluated action under the predicted policy.
  logp_action: jnp.ndarray
  # Predicted value with the trailing axis squeezed out.
  value: jnp.ndarray
  # Predicted reward with the trailing axis squeezed out.
  reward: jnp.ndarray


class MlpLn(hk.nets.MLP):
  """A multi-layer perceptron module interleaved with LayerNorm.

  `output_sizes` is a sequence whose entries are either an `int` (adds an
  `hk.Linear` layer of that width) or the string `"ln"` (adds an
  `hk.LayerNorm` over the last axis with learned scale and offset). The
  resulting layer stack is run by the call logic inherited from
  `hk.nets.MLP`.
  """

  def __init__(
    self,
    output_sizes: Iterable[Any],
    w_init: Optional[hk.initializers.Initializer] = None,
    b_init: Optional[hk.initializers.Initializer] = None,
    with_bias: bool = True,
    activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
    activate_final: bool = False,
    name: Optional[str] = None,
  ):
    """Constructs the layer stack.

    Args:
      output_sizes: Sequence of layer specs; each entry is an `int` width for
        a linear layer or the string `"ln"` for a LayerNorm.
      w_init: Initializer for the linear layers' weights.
      b_init: Initializer for the linear layers' biases. Must be `None` when
        `with_bias` is `False`.
      with_bias: Whether the linear layers use a bias.
      activation: Activation function stored on the module. Defaults to
        `jax.nn.relu` (previously an `...` placeholder, which is not
        callable).
      activate_final: Stored on the module; controls whether the inherited
        call logic activates the final layer's output.
      name: Optional module name.

    Raises:
      ValueError: If `b_init` is given with `with_bias=False`, or if an entry
        of `output_sizes` is neither an `int` nor `"ln"`.
    """
    if not with_bias and b_init is not None:
      raise ValueError("When with_bias=False b_init must not be set.")

    # Initialise the parent with no layers; they are built below so that
    # LayerNorm entries can be interleaved with the linear layers.
    super().__init__([], name=name)
    self.with_bias = with_bias
    self.w_init = w_init
    self.b_init = b_init
    self.activation = activation
    self.activate_final = activate_final
    layers = []
    output_sizes = tuple(output_sizes)
    for index, output_size in enumerate(output_sizes):
      if isinstance(output_size, int):
        layers.append(
          hk.Linear(
            output_size=output_size,
            w_init=w_init,
            b_init=b_init,
            with_bias=with_bias,
            name="linear_%d" % index,
          )
        )
      elif output_size == "ln":
        layers.append(
          hk.LayerNorm(axis=(-1), create_scale=True, create_offset=True)
        )
      else:
        raise ValueError("Invalid output_size.")

    self.layers = tuple(layers)
    # LayerNorm keeps the feature width, so the module's output width is the
    # width of the last linear layer. Looking only at `output_sizes[-1]` would
    # fail with `int("ln")` when the stack ends in a LayerNorm.
    linear_sizes = [size for size in output_sizes if isinstance(size, int)]
    self.output_size = linear_sizes[-1] if linear_sizes else None


def make_networks(
  spec: specs.EnvironmentSpec,
  encoder_layer_sizes: Tuple[Any, ...] = (256, "ln", 256),
  transition_layer_sizes: Tuple[Any, ...] = (256, "ln", 256),
  prediction_layer_sizes: Tuple[Any, ...] = (256, "ln", 256, 256),
  head_layer_sizes: Tuple[Any, ...] = (128, 64),
  activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
  **_,
) -> OSNetworks:
  """Builds the encoder, transition and prediction networks.

  Args:
    spec: Environment spec; its observation and action specs provide the
      dummy inputs used for parameter initialisation, and the action shape
      determines the policy head's output dimension.
    encoder_layer_sizes: `MlpLn` layer spec for the observation encoder.
    transition_layer_sizes: `MlpLn` layer spec for the latent dynamics model.
    prediction_layer_sizes: `MlpLn` layer spec for the prediction torso.
    head_layer_sizes: `MlpLn` layer spec shared by the policy, value and
      reward heads (value and reward heads append a final width-1 layer).
    activation: Activation function used by every `MlpLn`.
    **_: Extra keyword arguments are accepted and ignored.

  Returns:
    An `OSNetworks` bundle of the three networks plus policy helper functions.
  """
  action_dim = np.prod(spec.actions.shape, dtype=int)

  # Batched all-zero inputs; only used to create network parameters.
  action_example = utils.add_batch_dim(utils.zeros_like(spec.actions))
  obs_example = utils.add_batch_dim(utils.zeros_like(spec.observations))

  def _mlp(sizes, activate_final: bool) -> MlpLn:
    # Every MLP in this agent shares the same weight initialiser/activation.
    return MlpLn(
      sizes,
      w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
      activation=activation,
      activate_final=activate_final,
    )

  # --- Encoder: observation -> latent state. ---
  def _encoder_fn(obs: jnp.ndarray) -> jnp.ndarray:
    torso = hk.Sequential([_mlp(list(encoder_layer_sizes), True)])
    return torso(obs)

  encoder = hk.without_apply_rng(hk.transform(_encoder_fn))

  def _init_encoder(key):
    return encoder.init(key, obs_example)

  encoder_network = networks_lib.FeedForwardNetwork(
    _init_encoder, encoder.apply
  )

  def _latent_example(key):
    # Encode the dummy observation with freshly initialised encoder params to
    # obtain a latent of the right shape for the downstream networks.
    return encoder.apply(encoder.init(key, obs_example), obs_example)

  # --- Transition: (latent state, action) -> next latent state. ---
  def _transition_fn(state: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
    dynamics = hk.Sequential([_mlp(list(transition_layer_sizes), True)])
    state_action = jnp.concatenate([state, action], axis=-1)
    return dynamics(state_action)

  transition = hk.without_apply_rng(hk.transform(_transition_fn))

  def _init_transition(key):
    return transition.init(key, _latent_example(key), action_example)

  transition_network = networks_lib.FeedForwardNetwork(
    _init_transition,
    transition.apply,
  )

  # --- Prediction: latent state -> (policy, value, reward). ---
  def _prediction_fn(state: jnp.ndarray) -> PredictionOutput:
    trunk = hk.Sequential([_mlp(list(prediction_layer_sizes), True)])
    features = trunk(state)

    policy_head = hk.Sequential(
      [
        _mlp(list(head_layer_sizes), True),
        networks_lib.NormalTanhDistribution(action_dim),
      ]
    )
    # Value and reward heads end in a linear width-1 layer (no activation).
    value_head = hk.Sequential([_mlp(list(head_layer_sizes) + [1], False)])
    reward_head = hk.Sequential([_mlp(list(head_layer_sizes) + [1], False)])

    policy = policy_head(features)
    value = value_head(features)
    reward = reward_head(features)
    return PredictionOutput(
      state=state, policy=policy, value=value, reward=reward
    )

  prediction = hk.without_apply_rng(hk.transform(_prediction_fn))

  def _init_prediction(key):
    return prediction.init(key, _latent_example(key))

  prediction_network = networks_lib.FeedForwardNetwork(
    _init_prediction, prediction.apply
  )

  # --- Policy helpers; `params` here is the policy distribution object. ---
  def _log_prob(params, actions):
    return params.log_prob(actions)

  def _sample(params, shape, key):
    return params.sample(sample_shape=shape, seed=key)

  def _sample_eval(params, key):
    del key  # Evaluation is deterministic: take the mode.
    return params.mode()

  return OSNetworks(
    encoder_network=encoder_network,
    transition_network=transition_network,
    prediction_network=prediction_network,
    log_prob=_log_prob,
    sample=_sample,
    sample_eval=_sample_eval,
  )


class Params(NamedTuple):
  """Parameters of the three agent networks."""

  # Parameters of the observation encoder.
  encoder_params: networks_lib.Params
  # Parameters of the latent transition (dynamics) network.
  transition_params: networks_lib.Params
  # Parameters of the prediction (policy/value/reward) network.
  prediction_params: networks_lib.Params


class TrainingState(NamedTuple):
  """Contains training state for the learner."""

  # Online parameters updated by every SGD step.
  params: Params
  # Target parameters, periodically refreshed from `params`.
  target_params: Params
  # Optimizer state matching `params`.
  opt_state: optax.OptState
  # Number of SGD steps taken so far.
  steps: int
  # PRNG key carried across steps; split on every SGD step.
  key: networks_lib.PRNGKey


def root_unroll(
  networks: OSNetworks,
  params: Params,
  state: jnp.ndarray,
  action_eval: jnp.ndarray,
) -> AgentOutput:
  """Runs the prediction network on already-encoded latent states.

  Args:
    networks: Agent networks.
    params: Parameters; only `prediction_params` is used.
    state: Latent states to evaluate.
    action_eval: Actions whose log-probability under the predicted policy is
      reported.

  Returns:
    An `AgentOutput` with the value and reward squeezed along their last axis.
  """
  out: PredictionOutput = networks.prediction_network.apply(
    params.prediction_params, state
  )
  logp = networks.log_prob(out.policy, action_eval)
  value = jnp.squeeze(out.value, axis=-1)
  reward = jnp.squeeze(out.reward, axis=-1)
  return AgentOutput(
    state=out.state, logp_action=logp, value=value, reward=reward
  )


def scale_gradient(g, scale: float):
  """Returns `g` unchanged in value, with its gradient multiplied by `scale`.

  The forward value is `g * scale + g * (1 - scale)`; only the first term
  lets gradients through, so the backward pass sees a factor of `scale`.
  """
  differentiable_part = g * scale
  constant_part = jax.lax.stop_gradient(g) * (1.0 - scale)
  return differentiable_part + constant_part


def model_unroll(
  networks: OSNetworks,
  params: Params,
  state: jnp.ndarray,
  action_sequence: jnp.ndarray,
  action_eval: jnp.ndarray,
) -> AgentOutput:
  """Rolls the transition model forward and predicts at every latent state.

  Args:
    networks: Agent networks.
    params: Parameters; `transition_params` and `prediction_params` are used.
    state: Initial latent state of the rollout.
    action_sequence: Actions fed to the transition network, scanned over the
      leading axis.
    action_eval: Actions whose log-probability is evaluated at the unrolled
      states.

  Returns:
    An `AgentOutput` for the unrolled states (the initial state is not
    included), with value and reward squeezed along their last axis.
  """

  def _step(latent: jnp.ndarray, action: jnp.ndarray):
    predicted = networks.transition_network.apply(
      params.transition_params, latent, action
    )
    # Halve the gradient flowing back through each transition step.
    predicted = scale_gradient(predicted, 0.5)
    return predicted, predicted

  _, latents = jax.lax.scan(_step, state, action_sequence)
  out: PredictionOutput = networks.prediction_network.apply(
    params.prediction_params, latents
  )
  logp = networks.log_prob(out.policy, action_eval)
  value = jnp.squeeze(out.value, axis=-1)
  reward = jnp.squeeze(out.reward, axis=-1)
  return AgentOutput(
    state=out.state, logp_action=logp, value=value, reward=reward
  )


def _compute_q(
  networks: OSNetworks,
  params: Params,
  state: jnp.ndarray,
  action: jnp.ndarray,
  discount: float,
) -> jnp.ndarray:
  """One-step model-based Q estimate: `reward + discount * value`.

  The transition network maps `(state, action)` to a successor latent state,
  on which the prediction network supplies the reward and value terms. The
  result keeps the shape of the raw prediction-head outputs (no squeeze is
  applied here).
  """
  successor = networks.transition_network.apply(
    params.transition_params, state, action
  )
  successor_pred: PredictionOutput = networks.prediction_network.apply(
    params.prediction_params, successor
  )
  return successor_pred.reward + discount * successor_pred.value


def one_step_adv_exp(
  networks: OSNetworks,
  target_params: Params,
  trajectory: ActorOutput,
  target_roots: AgentOutput,
  discount: float,
  clip_threshold: float = 1.0,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Computes exponentiated one-step advantages of the dataset actions.

  The advantage of each dataset action is the one-step model-based Q estimate
  (see `_compute_q`) minus the value predicted at the same state. The value
  head is used as the baseline.

  Args:
    networks: Agent networks.
    target_params: Parameters used for the Q estimate.
    trajectory: Trajectory whose `action` field is evaluated.
    target_roots: Root outputs for the same steps as `trajectory`; `state`
      and the already-squeezed `value` are read.
    discount: Discount factor applied to the successor value.
    clip_threshold: The advantage is clipped to
      `[-clip_threshold, clip_threshold]` before exponentiation.

  Returns:
    A `(coeff, adv)` tuple with the same shape as `target_roots.value`, where
    `coeff = exp(clip(adv))`.
  """
  q_data = _compute_q(
    networks, target_params, target_roots.state, trajectory.action, discount
  )
  # `_compute_q` keeps the trailing singleton axis of the prediction heads,
  # whereas `target_roots.value` has already been squeezed. Drop that axis so
  # the subtraction is element-wise; otherwise [T, 1] - [T] broadcasts to a
  # [T, T] matrix that mixes advantages across timesteps.
  q_data = jnp.squeeze(q_data, axis=-1)
  adv = q_data - target_roots.value
  coeff = jnp.exp(jnp.clip(adv, -clip_threshold, clip_threshold))
  return coeff, adv


class OneStepLearner(acme.Learner):
  """Model-based One-step Lookahead learner."""

  _state: TrainingState

  def __init__(
    self,
    networks: OSNetworks,
    random_key: networks_lib.PRNGKey,
    discount_factor: float,
    target_update_period: int,
    unroll_steps: int,
    td_steps: int,
    value_coeff: float,
    policy_coeff: float,
    reward_coeff: float,
    demonstrations: Iterator[ActorOutput],
    optimizer: optax.GradientTransformation,
    grad_updates_per_batch: int = 1,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    **_,
  ):
    """Initialises the learner.

    Args:
      networks: Agent networks.
      random_key: PRNG key used for parameter initialisation and training.
      discount_factor: Discount used for n-step returns and advantages.
      target_update_period: Number of SGD steps between target-network
        refreshes.
      unroll_steps: Number of model steps unrolled from the root state.
      td_steps: Number of reward steps in the n-step value target. Each
        trajectory must contain at least `unroll_steps + td_steps + 1` steps.
      value_coeff: Weight of the value loss.
      policy_coeff: Weight of the policy loss.
      reward_coeff: Weight of the reward loss.
      demonstrations: Iterator yielding batches of trajectories.
      optimizer: Optax gradient transformation.
      grad_updates_per_batch: Number of SGD steps performed per sampled batch.
      counter: Counter used to track steps and walltime.
      logger: Logger for training metrics. When `None`, metrics are not
        written.
      **_: Extra keyword arguments are accepted and ignored.
    """

    def loss(
      params: Params,
      target_params: Params,
      trajectory: ActorOutput,
      rng_key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.ndarray, Any]:
      """Computes the loss and metrics for one (unbatched) trajectory."""
      del rng_key
      state = networks.encoder_network.apply(
        params.encoder_params, trajectory.observation
      )
      target_state = networks.encoder_network.apply(
        target_params.encoder_params, trajectory.observation
      )

      # 1. Model unroll.
      root_state = jax.tree_util.tree_map(lambda t: t[:1], state)
      learner_root = root_unroll(
        networks, params, root_state, trajectory.action[:1]
      )
      # Drop the leading length-1 axis kept above for the network call.
      learner_root = jax.tree_util.tree_map(lambda t: t[0], learner_root)

      unroll_trajectory: ActorOutput = jax.tree_util.tree_map(
        lambda t: t[:unroll_steps + 1], trajectory
      )
      # True from the first unroll step at which a new episode has started.
      invalid_action_mask = (
        jnp.cumprod(1.0 - unroll_trajectory.is_first[1:]) == 0.0
      )
      action_sequence = unroll_trajectory.action[:unroll_steps]
      dummy_actions = jnp.zeros_like(action_sequence)
      # Feed zero actions to the model once the unroll crosses an episode
      # boundary.
      simulate_action_sequence = jax.lax.select(
        jnp.broadcast_to(invalid_action_mask[:, None], dummy_actions.shape),
        dummy_actions,
        action_sequence,
      )
      model_out = model_unroll(
        networks,
        params,
        learner_root.state,
        simulate_action_sequence,
        unroll_trajectory.action[1:],
      )

      # 2. Construct targets.
      target_roots = root_unroll(
        networks, target_params, target_state, trajectory.action
      )

      # Reward.
      rewards = trajectory.reward
      reward_target = jax.lax.select(
        invalid_action_mask,
        jnp.zeros_like(rewards[:unroll_steps]),
        rewards[:unroll_steps],
      )

      # Value.
      discounts = (1.0 - trajectory.is_first[1:]) * discount_factor
      v_bootstrap = target_roots.value

      def n_step_return(i):
        # `td_steps` rewards followed by the bootstrap value at step
        # `i + td_steps`, each weighted by the cumulative discount up to it.
        bootstrap_value = v_bootstrap[i + td_steps]
        _rewards = jnp.concatenate(
          [rewards[i:i + td_steps], bootstrap_value[None]], axis=0
        )
        _discounts = jnp.concatenate(
          [jnp.ones((1,)),
           jnp.cumprod(discounts[i:i + td_steps])],
          axis=0,
        )
        return jnp.sum(_rewards * _discounts)

      returns = jnp.stack(
        [n_step_return(i) for i in range(unroll_steps + 1)]
      )
      # True from the first step (inclusive) flagged as the last of an episode.
      zero_return_mask = jnp.cumprod(1.0 - unroll_trajectory.is_last) == 0.0
      value_target = jax.lax.select(
        zero_return_mask, jnp.zeros_like(returns), returns
      )
      value_target = jax.lax.stop_gradient(value_target)

      # 3. Compute losses.
      reward_loss = rlax.l2_loss(model_out.reward - reward_target)
      reward_loss = jax.lax.select(
        invalid_action_mask,
        jnp.zeros_like(reward_loss),
        reward_loss,
      )
      reward_loss = jnp.mean(reward_loss)

      value_pred = jnp.concatenate([learner_root.value[None], model_out.value])
      value_loss = rlax.l2_loss(value_pred - value_target)
      value_loss = jax.lax.select(
        zero_return_mask, jnp.zeros_like(value_loss), value_loss
      )
      value_loss = jnp.mean(value_loss)

      coeff, adv = one_step_adv_exp(
        networks,
        target_params,
        unroll_trajectory,
        jax.tree_util.tree_map(
          lambda x: x[:unroll_steps + 1], target_roots
        ),
        discount_factor,
      )
      coeff = jax.lax.stop_gradient(coeff)
      logp_action = jnp.concatenate(
        [learner_root.logp_action[None], model_out.logp_action]
      )
      # Advantage-weighted log-likelihood per step: `coeff` and `logp_action`
      # both have shape [unroll_steps + 1], so each step's log-probability is
      # weighted by that same step's coefficient.
      policy_loss = -(logp_action * coeff)
      unreachable_action_mask = jnp.cumprod(
        1.0 - trajectory.is_first[1:]
      ) == 0.0
      policy_loss = jax.lax.select(
        unreachable_action_mask[:unroll_steps + 1],
        jnp.zeros_like(policy_loss),
        policy_loss,
      )
      policy_loss = jnp.mean(policy_loss)

      total_loss = (
        policy_loss * policy_coeff + value_loss * value_coeff +
        reward_loss * reward_coeff
      )
      return total_loss, {
        "policy_loss": policy_loss,
        "value_loss": value_loss,
        "reward_loss": reward_loss,
        "coeff": coeff,
        "adv": adv,
      }

    def batch_loss(
      params: Params,
      target_params: Params,
      trajectory: ActorOutput,
      rng_key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.ndarray, Any]:
      """Averages `loss` over the batch axis and summarises its metrics."""
      bs = len(trajectory.reward)
      rng_keys = jax.random.split(rng_key, bs)
      losses, metrics = jax.vmap(loss, (None, None, 0, 0))(
        params,
        target_params,
        trajectory,
        rng_keys,
      )
      metrics_new = {f"{k}": jnp.mean(v) for k, v in metrics.items()}
      # Additionally report the spread of the policy weighting terms.
      std_keys = [k for k in ("coeff", "adv") if k in metrics]
      metrics_new.update({f"{k}_std": jnp.std(metrics[k]) for k in std_keys})
      return jnp.mean(losses), metrics_new

    def sgd_step(
      state: TrainingState,
      trajectory: ActorOutput,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      """Performs one optimizer step and returns the new state and metrics."""

      key, key_loss = jax.random.split(state.key, 2)

      # Compute losses and their gradients.
      grads, log = jax.grad(
        batch_loss, has_aux=True
      )(state.params, state.target_params, trajectory, key_loss)

      # Get optimizer updates and state. The current params are passed so
      # that transformations which need them (e.g. weight decay) also work.
      updates, opt_state = optimizer.update(
        grads, state.opt_state, state.params
      )

      # Apply optimizer updates to parameters.
      params = optax.apply_updates(state.params, updates)

      steps = state.steps + 1

      # Periodically update target networks.
      target_params = optax.periodic_update(
        params, state.target_params, steps, target_update_period
      )

      new_state = TrainingState(
        params=params,
        target_params=target_params,
        opt_state=opt_state,
        steps=steps,
        key=key,
      )

      log.update(
        {
          "grad_norm": optax.global_norm(grads),
          "update_norm": optax.global_norm(updates),
          "params_norm": optax.global_norm(params),
        }
      )

      return new_state, log

    sgd_step = utils.process_multiple_batches(sgd_step, grad_updates_per_batch)
    self._sgd_step = jax.jit(sgd_step)

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger

    # Iterator over batches of demonstration trajectories.
    self._iterator = demonstrations

    # Create the network parameters and copy into the target network parameters.
    key, key_encoder, key_transition, key_prediction = jax.random.split(
      random_key, 4
    )
    initial_encoder_params = networks.encoder_network.init(key_encoder)
    initial_transition_params = networks.transition_network.init(
      key_transition
    )
    initial_prediction_params = networks.prediction_network.init(
      key_prediction
    )
    initial_params = Params(
      encoder_params=initial_encoder_params,
      transition_params=initial_transition_params,
      prediction_params=initial_prediction_params,
    )
    initial_target_params = initial_params

    # Initialize optimizers.
    initial_opt_state = optimizer.init(initial_params)

    # Create initial state.
    self._state = TrainingState(
      params=initial_params,
      target_params=initial_target_params,
      opt_state=initial_opt_state,
      steps=0,
      key=key,
    )

    # Do not record timestamps until after the first learning step is done,
    # so that start-up time is not counted as training walltime.
    self._timestamp = None

  def step(self):
    """Runs one learner step on the next batch and logs the metrics."""
    timesteps = next(self._iterator)

    self._state, metrics = self._sgd_step(self._state, timesteps)

    # Compute elapsed time.
    timestamp = time.time()
    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
    self._timestamp = timestamp

    # Increment counts and record the current time
    counts = self._counter.increment(steps=1, walltime=elapsed_time)

    # `logger` is optional in the constructor; skip writing when absent
    # instead of failing on `None.write`.
    if self._logger is not None:
      self._logger.write({**metrics, **counts})

  def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
    """Returns the online parameters for the requested network names.

    Valid names are "encoder", "transition" and "prediction". Target
    parameters are an internal detail and are not exposed.
    """
    variables = {
      "encoder": self._state.params.encoder_params,
      "transition": self._state.params.transition_params,
      "prediction": self._state.params.prediction_params,
    }
    return [variables[name] for name in names]

  def save(self) -> TrainingState:
    """Returns the current training state for checkpointing."""
    return self._state

  def restore(self, state: TrainingState):
    """Replaces the current training state with `state`."""
    self._state = state


class OneStepBuilder(AgentBuilder):
  """Builds the networks, learner and evaluator of the Onestep agent."""

  @staticmethod
  def make_default_configs() -> ConfigDict:
    """Returns the default hyper-parameter configuration."""
    config = ConfigDict()
    config.encoder_layer_sizes = [256, "ln", 256]
    config.transition_layer_sizes = [256, "ln", 256]
    config.prediction_layer_sizes = [256, "ln", 256, 256]
    config.head_layer_sizes = [32]
    config.discount_factor = 0.99
    config.batch_size = 128
    config.learning_rate = 1e-4
    config.target_update_period = 200
    config.unroll_steps = 5
    config.td_steps = 5
    config.value_coeff = 0.25
    config.policy_coeff = 1.0
    config.reward_coeff = 1.0
    config.activation = jax.nn.elu
    # NOTE(review): OneStepLearner's n-step return reads the bootstrap value
    # at index `i + td_steps` for `i` up to `unroll_steps`, i.e. it needs
    # `unroll_steps + td_steps + 1` timesteps per trajectory. Confirm that the
    # data pipeline consuming `trajectory_length` adds the extra step.
    config.trajectory_length = config.unroll_steps + config.td_steps

    config.total_steps = 1_200_000
    config.eval_period = 12_000
    config.save_period = 60_000
    return config

  def make_networks(
    self, env_spec: specs.EnvironmentSpec, **kwargs
  ) -> OSNetworks:
    """Creates the agent networks for `env_spec`; see module `make_networks`."""
    return make_networks(env_spec, **kwargs)

  def make_learner(self, **kwargs) -> OneStepLearner:
    """Creates the learner.

    Args:
      **kwargs: Forwarded to `OneStepLearner`. If no `optimizer` is given, an
        Adam optimizer is built from `kwargs["learning_rate"]`.
    """
    learner_kwargs = dict(kwargs)
    # Honour an explicitly supplied optimizer instead of raising a
    # duplicate-keyword TypeError; default to Adam otherwise.
    optimizer = learner_kwargs.pop("optimizer", None)
    if optimizer is None:
      optimizer = optax.adam(learner_kwargs["learning_rate"])
    return OneStepLearner(**learner_kwargs, optimizer=optimizer)

  def make_evaluator(
    self,
    networks: OSNetworks,
    learner: OneStepLearner,
    rng_key: networks_lib.PRNGKey,
  ) -> acme.Actor:
    """Creates a CPU actor that acts with the learner's current policy.

    The actor pulls the "encoder" and "prediction" variables from `learner`
    and selects actions with `networks.sample_eval`.
    """

    def evaluator_network(
      params: List[networks_lib.Params],
      key: networks_lib.PRNGKey,
      observation: jnp.ndarray,
    ) -> jnp.ndarray:
      # `params` follows the order requested from the variable client below:
      # [encoder params, prediction params].
      state = networks.encoder_network.apply(params[0], observation)
      prediction = networks.prediction_network.apply(params[1], state)
      dist_params = prediction.policy
      return networks.sample_eval(dist_params, key)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network
    )
    variable_client = variable_utils.VariableClient(
      learner, ["encoder", "prediction"], device="cpu"
    )
    evaluator = actors.GenericActor(
      actor_core, rng_key, variable_client, backend="cpu"
    )
    return evaluator
