from typing import Dict

import acme
import dm_env
import jax
import numpy as np
from acme import types

from rosmo.agent.crr_discrete.learning import Params
from rosmo.agent.crr_discrete.network import CRRNetworks
from rosmo.types import ActorOutput


class CRREvalActor(acme.core.Actor):
  """CRR evaluation actor.

  Keeps the most recent timestep as an `ActorOutput` and, on each
  `select_action` call, runs the torso and policy networks on the stored
  observation and samples an action with `networks.sample`.

  Parameters are not fetched by the actor itself: they must be injected with
  `update_params` before an episode is started with `observe_first`.
  """

  def __init__(
    self,
    networks: CRRNetworks,
    config: Dict,
  ) -> None:
    """Builds the actor and its jitted action-sampling step.

    Args:
      networks: CRR networks; `torso_network`, `policy_network` and `sample`
        are used here.
      config: Configuration mapping; only `config["seed"]` is read, to seed
        the actor's PRNG key.
    """
    self._networks = networks
    self._rng_key = jax.random.PRNGKey(config["seed"])
    self._params = None
    # Set by `observe_first` / `observe`; None until an episode has started.
    self._timestep = None

    def agent_step(key, params: Params, observation):
      # Split so the carried key is never reused for sampling.
      key, step_key = jax.random.split(key)
      torso = networks.torso_network.apply(params.torso_params, observation)
      dist_params = networks.policy_network.apply(params.policy_params, torso)
      return key, networks.sample(dist_params, step_key)

    self._agent_step = jax.jit(agent_step)

  def select_action(self, observation: types.NestedArray) -> types.NestedArray:
    """Samples an action for the most recently observed timestep.

    Args:
      observation: Unused. The observation stored by `observe_first` /
        `observe` is fed to the networks instead (presumably identical under
        the usual Acme environment loop -- confirm against the caller).

    Returns:
      The sampled action converted to a Python scalar via `.item()`.

    Raises:
      RuntimeError: If no timestep has been observed yet.
    """
    if self._timestep is None:
      raise RuntimeError(
        "no timestep observed; call observe_first before select_action"
      )

    # Only the observation is consumed by the networks, so only it is moved to
    # the device and given a leading batch dimension.
    batched_observation = jax.tree_util.tree_map(
      lambda t: t[None], jax.device_put(self._timestep.observation)
    )
    self._rng_key, action = self._agent_step(
      self._rng_key, self._params, batched_observation
    )
    action = jax.device_get(action)
    return action.item()

  def observe_first(self, timestep: dm_env.TimeStep):
    """Records the first timestep of an episode.

    Action and reward are filled with zero placeholders since no action has
    been taken yet.

    Args:
      timestep: The first timestep; only its `observation` is read.

    Raises:
      RuntimeError: If `update_params` has not been called yet.
    """
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if self._params is None:
      raise RuntimeError("params not initialized")
    self._timestep = ActorOutput(
      action=np.zeros((1,), dtype=np.int32),
      reward=np.zeros((1,), dtype=np.float32),
      observation=timestep.observation,
      is_first=np.ones((1,), dtype=np.float32),
      is_last=np.zeros((1,), dtype=np.float32),
    )

  def observe(
    self,
    action: types.NestedArray,
    next_timestep: dm_env.TimeStep,
  ):
    """Records the action taken and the timestep that resulted from it.

    Args:
      action: The action that was executed.
      next_timestep: The timestep returned by the environment afterwards.
    """
    self._timestep = ActorOutput(
      action=action,
      reward=next_timestep.reward,
      observation=next_timestep.observation,
      is_first=next_timestep.first(),  # previous last = this first.
      is_last=next_timestep.last(),
    )

  def update(self, wait: bool = False):
    """No-op: parameters are pushed in through `update_params` instead."""
    pass

  def update_params(self, params: Params):
    """Replaces the network parameters used for action selection.

    Args:
      params: New parameters; `torso_params` and `policy_params` are read
        when sampling actions.
    """
    self._params = params
