from typing import Dict

import acme
import dm_env
import jax
import jax.numpy as jnp
import numpy as np
from absl import logging
from acme import types

from rosmo.agent.cql_discrete.learning import Params
from rosmo.agent.cql_discrete.network import CQLNetworks
from rosmo.types import ActorOutput


class CQLEvalActor(acme.core.Actor):
  """CQL evaluation actor with epsilon-greedy action selection.

  The greedy action is the argmax of ``network_out["q_value"]`` where
  ``network_out = networks.q_network.apply(params.q_params, observation)``.
  With probability ``config["epsilon_eval"]`` a uniformly random action in
  ``[0, num_actions)`` is taken instead. Parameters must be provided through
  ``update_params`` before ``observe_first`` is called.
  """

  def __init__(
    self,
    networks: CQLNetworks,
    num_actions: int,
    config: Dict,
  ) -> None:
    """Build the actor and JIT-compile its greedy policy functions.

    Args:
      networks: CQL networks; only ``q_network`` is used here.
      num_actions: Number of discrete actions; random actions are drawn from
        ``[0, num_actions)``.
      config: Mapping that must contain ``"seed"`` and ``"epsilon_eval"``.
    """
    self._networks = networks
    # NOTE(review): this key is not consumed anywhere in this class; the
    # exploration below draws from NumPy's global RNG, so config["seed"] does
    # not seed it.
    self._rng_key = jax.random.PRNGKey(config["seed"])
    self._params = None
    self._timestep = None
    self._epsilon_eval = config["epsilon_eval"]
    self._num_actions = num_actions

    def agent_step(params: Params, observation):
      # Greedy action for a batch of size one; `[0]` drops the batch axis.
      network_out = networks.q_network.apply(params.q_params, observation)
      return jnp.argmax(network_out["q_value"], axis=1)[0]

    def batch_agent_step(params: Params, observation):
      # Greedy action for each element along the leading (batch) axis.
      network_out = networks.q_network.apply(params.q_params, observation)
      return jnp.argmax(network_out["q_value"], axis=1)

    self._agent_step = jax.jit(agent_step)
    self._batch_agent_step = jax.jit(batch_agent_step)

    logging.info(f"[Actor] Using epsilon-greedy ({self._epsilon_eval}).")

  def select_action(self, observation: types.NestedArray) -> types.NestedArray:
    """Return an epsilon-greedy action for a single (unbatched) observation.

    Args:
      observation: Observation (possibly nested) for one environment step.

    Returns:
      A random Python int with probability ``epsilon_eval``, otherwise the
      greedy action fetched back from the device.
    """
    # uniform() samples from [0, 1), so a strict `<` explores with
    # probability exactly epsilon (and never when epsilon is 0).
    if np.random.uniform() < self._epsilon_eval:
      return np.random.randint(low=0, high=self._num_actions)
    # Only pay for the device transfer when the network is actually queried;
    # add a leading batch axis of size one to every leaf.
    batched_observation = jax.tree_util.tree_map(
      lambda t: t[None], jax.device_put(observation)
    )
    action = self._agent_step(self._params, batched_observation)
    return jax.device_get(action)

  def batch_select_action(
    self, observation: types.NestedArray
  ) -> types.NestedArray:
    """Return epsilon-greedy actions for a batch of observations.

    Exploration is decided independently per batch element, so every element
    takes a random action with probability ``epsilon_eval`` regardless of the
    others.

    Args:
      observation: Array whose leading axis is the batch axis.

    Returns:
      A NumPy integer array with one action per batch element (assumes
      ``q_value`` is shaped ``(batch, num_actions)`` — TODO confirm).
    """
    batch_size = observation.shape[0]
    explore = np.random.uniform(size=(batch_size,)) < self._epsilon_eval
    random_actions = np.random.randint(
      low=0, high=self._num_actions, size=(batch_size,)
    )
    if explore.all():
      # Every element explores (also covers an empty batch): skip the network.
      return random_actions
    greedy_actions = np.asarray(
      jax.device_get(
        self._batch_agent_step(self._params, jax.device_put(observation))
      )
    )
    return np.where(explore, random_actions, greedy_actions).astype(
      greedy_actions.dtype
    )

  def observe_first(self, timestep: dm_env.TimeStep):
    """Record the first timestep of an episode.

    Raises:
      AssertionError: if ``update_params`` has not been called yet.
    """
    assert self._params is not None, "params not initialized"
    self._timestep = ActorOutput(
      action=np.zeros((1,), dtype=np.int32),
      reward=np.zeros((1,), dtype=np.float32),
      observation=timestep.observation,
      is_first=np.ones((1,), dtype=np.float32),
      is_last=np.zeros((1,), dtype=np.float32),
    )

  def observe(
    self,
    action: types.NestedArray,
    next_timestep: dm_env.TimeStep,
  ):
    """Record the action taken and the resulting timestep."""
    self._timestep = ActorOutput(
      action=action,
      reward=next_timestep.reward,
      observation=next_timestep.observation,
      is_first=next_timestep.first(),  # previous last = this first.
      is_last=next_timestep.last(),
    )

  def update(self, wait: bool = False):
    """No-op; parameters are pushed explicitly via ``update_params``."""
    pass

  def update_params(self, params: Params):
    """Replace the parameters used by the greedy policy."""
    self._params = params
