"""Evaluating actor."""
from typing import Dict

import acme
import chex
import dm_env
import jax
import numpy as np
import rlax
from absl import logging
from acme import types

from rosmo.agent.muzero.learning import (
  mcts_improve,
  one_step_improve,
  root_unroll,
)
from rosmo.agent.muzero.network import Networks
from rosmo.agent.muzero.types import AgentOutput, Params
from rosmo.types import ActorOutput


class MuZeroEvalActor(acme.core.Actor):
  """MuZero evaluation actor.

  Acts one environment step at a time through a jitted, vmapped step function
  built in ``__init__``. Depending on ``config`` the action is sampled from the
  policy head directly, chosen by MCTS, or sampled from a one-step improved
  policy. Parameters must be provided via ``update_params`` before the first
  ``observe_first`` call.
  """

  def __init__(
    self,
    networks: Networks,
    config: Dict,
  ) -> None:
    """Build the jitted acting function.

    Args:
      networks: MuZero networks; ``representation_network`` and
        ``environment_specs`` are read here.
      config: Actor configuration. Required keys: ``discount_factor``,
        ``clipping_threshold``, ``seed``, ``use_bc``, ``use_qf``,
        ``num_simulations``, ``sampling``, ``search_depth``,
        ``improvement_op``, ``num_bins``. Optional key ``behavior``: when it
        contains ``"only"`` the actor samples from the policy head.
    """
    self._networks = networks
    self._discount_factor = config["discount_factor"]
    self._clipping_threshold = config["clipping_threshold"]
    self._environment_specs = networks.environment_specs
    self._rng_key = jax.random.PRNGKey(config["seed"])
    # When True, `select_action` ignores the model and acts uniformly at random.
    self._random_action = False
    self._params = None

    use_bc = config["use_bc"]
    use_qf = config["use_qf"]
    num_simulations = config["num_simulations"]
    sampling_method = config["sampling"]
    search_depth = config["search_depth"]
    improvement_op = config["improvement_op"]
    num_bins = config["num_bins"]
    # A missing/falsy `behavior` becomes "" so the substring test below does
    # not raise TypeError (a bool default is not iterable).
    behavior = config.get("behavior") or ""

    # NOTE: `logging.info` calls inside the `root_step` variants run only when
    # JAX traces the function (it is wrapped by `jax.jit` below), not on every
    # acting step.
    if use_bc or use_qf or "only" in behavior:

      def root_step(
        rng_key: chex.PRNGKey,
        params: Params,
        timesteps: ActorOutput,
        temperature: float,
      ):
        """Sample an action from the policy head (``temperature`` unused)."""
        logging.info("[Actor] Using policy to act.")
        # Policy head acting.
        # Add a dummy time dimension.
        trajectory: ActorOutput = jax.tree_util.tree_map(
          lambda t: t[None], timesteps
        )
        state = networks.representation_network.apply(
          params.representation, trajectory.observation
        )
        agent_out = root_unroll(self._networks, params, num_bins, state)
        # Squeeze the dummy time dimension.
        agent_out: AgentOutput = jax.tree_util.tree_map(
          lambda t: t.squeeze(axis=0), agent_out
        )
        pi = jax.nn.softmax(agent_out.policy_logits)
        action = rlax.categorical_sample(rng_key, pi)
        return action, agent_out

    else:

      def root_step(
        rng_key: chex.PRNGKey,
        params: Params,
        timesteps: ActorOutput,
        temperature: float,
      ):
        """Act with the model: MCTS, policy head, or one-step improvement.

        ``temperature`` is accepted for interface parity but not used.
        """
        # Model one-step acting.
        # Add a dummy time dimension.
        trajectory = jax.tree_util.tree_map(lambda t: t[None], timesteps)
        state = networks.representation_network.apply(
          params.representation, trajectory.observation
        )
        agent_out: AgentOutput = root_unroll(
          self._networks, params, num_bins, state
        )
        improve_key, sample_key = jax.random.split(rng_key)
        if improvement_op in ["mcts", "mcts_mpo"]:
          # NOTE(review): unlike the other branches, `agent_out` keeps the
          # dummy leading dimension here; presumably `mcts_improve` expects a
          # batched root — confirm before changing.
          if improvement_op == "mcts":
            logging.info("[Actor] Using mcts improv to act.")
            mcts_out = mcts_improve(
              self._networks,
              improve_key,
              params,
              num_bins,
              agent_out,
              self._discount_factor,
              num_simulations,
              search_depth,
            )
            action = mcts_out.action
          else:
            logging.info("[Actor] Using policy to act.")
            action = rlax.categorical_sample(
              sample_key, jax.nn.softmax(agent_out.policy_logits)
            )
        else:
          # Squeeze the dummy time dimension.
          agent_out: AgentOutput = jax.tree_util.tree_map(
            lambda t: t.squeeze(axis=0), agent_out
          )
          if num_simulations < 0 or sampling_method == "exact":
            logging.info("[Actor] Using onestep improv to act.")
            improved_policy, _ = one_step_improve(
              self._networks,
              improve_key,
              params,
              agent_out,
              num_bins,
              None,
              self._discount_factor,
              self._clipping_threshold,
              False,
              sampling_method,
            )
          else:
            logging.info("[Actor] Using policy to act.")
            improved_policy = jax.nn.softmax(agent_out.policy_logits)
          action = rlax.categorical_sample(sample_key, improved_policy)
        return action, agent_out

    def batch_step(
      rng_key: chex.PRNGKey,
      params: Params,
      timesteps: ActorOutput,
      temperature: float,
    ):
      """Run `root_step` over the leading batch axis of ``timesteps``.

      Returns:
        A tuple ``(new_rng_key, actions, agent_out)``; each batch element gets
        its own split of ``rng_key`` while ``params``/``temperature`` are shared.
      """
      batch_size = timesteps.reward.shape[0]
      rng_key, step_key = jax.random.split(rng_key)
      step_keys = jax.random.split(step_key, batch_size)
      batch_root_step = jax.vmap(root_step, (0, None, 0, None))
      actions, agent_out = batch_root_step(
        step_keys, params, timesteps, temperature
      )
      return rng_key, actions, agent_out

    self._agent_step = jax.jit(batch_step)

  def select_action(self, observation: types.NestedArray) -> types.NestedArray:
    """Return an action (a Python int) for the current timestep.

    ``observation`` is not read: the actor acts on the timestep recorded by
    the latest ``observe_first``/``observe`` call.
    """
    if self._random_action:
      # Return a Python int, matching the model-based path below.
      return int(
        np.random.randint(0, self._environment_specs.actions.num_values)
      )
    # Add a batch dimension of size 1 for the vmapped step function.
    batched_timestep = jax.tree_util.tree_map(
      lambda t: t[None], jax.device_put(self._timestep)
    )
    self._rng_key, action, _ = self._agent_step(
      self._rng_key, self._params, batched_timestep, 1.0
    )
    return jax.device_get(action).item()

  def observe_first(self, timestep: dm_env.TimeStep):
    """Record the first timestep of an episode.

    Raises:
      AssertionError: if ``update_params`` has not been called yet.
    """
    assert self._params is not None, "params not initialized"
    self._timestep = ActorOutput(
      action=np.zeros((1,), dtype=np.int32),
      reward=np.zeros((1,), dtype=np.float32),
      observation=timestep.observation,
      is_first=np.ones((1,), dtype=np.float32),
      is_last=np.zeros((1,), dtype=np.float32),
    )

  def observe(
    self,
    action: types.NestedArray,
    next_timestep: dm_env.TimeStep,
  ):
    """Record the timestep produced by taking ``action``.

    Non-observation fields are stored as shape ``(1,)`` arrays with the same
    dtypes used in ``observe_first`` so the jitted step function always sees
    one input signature and is not compiled a second time.
    """
    self._timestep = ActorOutput(
      action=np.asarray(action, dtype=np.int32).reshape((1,)),
      reward=np.asarray(next_timestep.reward, dtype=np.float32).reshape((1,)),
      observation=next_timestep.observation,
      is_first=np.asarray(next_timestep.first(), dtype=np.float32).reshape(
        (1,)
      ),
      is_last=np.asarray(next_timestep.last(), dtype=np.float32).reshape(
        (1,)
      ),
    )

  def update(self, wait: bool = False):
    """No-op; parameters are set explicitly via ``update_params``."""
    pass

  def update_params(self, params: Params):
    """Replace the parameters used by subsequent action selection."""
    self._params = params
