from typing import Any, Dict

import numpy as np
import reverb
import tensorflow as tf
from acme import types

from rosmo.types import ActorOutput


def transform_timesteps(steps: Dict[str, np.ndarray]) -> ActorOutput:
  """Repackage a dict of per-step arrays as an `ActorOutput`.

  Args:
    steps: Mapping that must provide the keys "observation", "reward",
      "is_first", "is_last" and "action".

  Returns:
    An `ActorOutput` whose fields of the same names hold the
    corresponding entries of `steps`, passed through unchanged.
  """
  field_names = ("observation", "reward", "is_first", "is_last", "action")
  selected = {name: steps[name] for name in field_names}
  return ActorOutput(**selected)


def decode_and_stack_frames(
  episode: Dict[str, Any], num_stack_frames: int = 4
) -> Dict[str, Any]:
  """Decode PNGs from an episode and stack each frame with its history.

  For every step t the stacked observation holds, along the last axis, the
  frames t-(k-1), ..., t-1, t (oldest first, current frame last). History
  slots that would reach before the start of the episode are zero-filled.

  Args:
    episode: Episode data containing N steps of transitions. The
      PNG-encoded frames are read from `episode["steps"]["observation"]`.
    num_stack_frames: Number of frames to be stacked (k). Defaults to 4.

  Returns:
    The same `episode` dict, modified in place: its
    `["steps"]["observation"]` entry is replaced by the stacked frames as
    an (N, H, W, k) NumPy array (84x84 uint8 frames for the data this was
    written for, per the original docstring -- TODO confirm).

  Raises:
    ValueError: If the episode has no frames or `num_stack_frames` < 1
      (raised by the underlying NumPy stack/concatenate calls).
  """
  episode_png = episode["steps"]["observation"]
  # Each PNG is decoded as a single-channel image, so `frames` is
  # (N, H, W, 1).
  frames = np.stack(
    [
      tf.image.decode_png(png_str, channels=1).numpy()
      for png_str in episode_png
    ]
  )
  num_frames = frames.shape[0]

  stacks = []
  # Largest delay first, so the last channel is the current frame.
  for delay in reversed(range(num_stack_frames)):
    if delay == 0:
      stacks.append(frames)
      continue
    # Position t must show frame t - delay, i.e. the episode shifted
    # *forward* in time with zeros in the first `delay` slots. (Taking
    # `frames[delay:]` here would put frame t back at position t and make
    # every channel identical.)
    delayed = np.zeros_like(frames)
    # Clamp so that episodes shorter than `delay` yield an all-zero channel
    # of the right length instead of a shape mismatch.
    delayed[delay:] = frames[: max(num_frames - delay, 0)]
    stacks.append(delayed)

  episode["steps"]["observation"] = np.concatenate(stacks, axis=-1)
  return episode


def discard_extras(sample: reverb.ReplaySample):
  """Return `sample` with its data cut down to the first five elements.

  Args:
    sample: Replay sample whose `data` supports slicing.

  Returns:
    The result of `sample._replace(...)` with `data` set to
    `sample.data[:5]`; anything after the fifth element is dropped.
  """
  # NOTE(review): presumably the five kept entries are the ones consumed by
  # `reverb_to_transition` (indices 0-4) -- confirm against the data pipeline.
  kept = sample.data[:5]
  return sample._replace(data=kept)


def reverb_to_transition(sample: reverb.ReplaySample):
  """Build a `types.Transition` from the leading entries of `sample.data`.

  Args:
    sample: Replay sample whose `data` is indexable at positions 0-4.

  Returns:
    A `types.Transition` with observation, action, reward, discount and
    next_observation taken from `sample.data[0]` .. `sample.data[4]`, in
    that order.
  """
  payload = sample.data
  # Position in `payload` -> Transition field name.
  field_names = (
    "observation",
    "action",
    "reward",
    "discount",
    "next_observation",
  )
  # NOTE: we omit the next_action and extras data
  fields = {name: payload[idx] for idx, name in enumerate(field_names)}
  return types.Transition(**fields)
