"""Utilities for preprocessing episodes into batches of transitions."""
from typing import Dict, List, Optional
import tensorflow as tf
import jax
import jax.numpy as jnp

from imitation_pretraining.data_utils import Batch


def encode_observations(
    batch: Batch,
    encoder,
    pop_pixels: bool = True,
    encode_next_obs: bool = True,
) -> Batch:
    """Attach encoder embeddings to the observations of a batch.

    Args:
        batch: Batch whose ``observation`` / ``next_observation`` are mappings.
        encoder: Object exposing ``encode(observation)``, ``None`` to leave the
            observations untouched, or a plain tuple ``(encoder, target_encoder)``.
            In the tuple form a non-``None`` ``target_encoder`` is applied to
            ``batch.observation`` and its output replaces ``batch.action``.
        pop_pixels: If true, remove the ``"pixels"`` entry from every
            observation that receives an embedding.
        encode_next_obs: If true (and an encoder is given), also embed
            ``batch.next_observation``.

    Returns:
        A new batch. Encoded observations are shallow copies carrying an extra
        ``"embedding"`` entry; the input mappings are not modified.
    """
    # Hack: a plain tuple carries a second encoder whose output is used as the
    # action target. The exact-type check is deliberate (tuple subclasses are
    # treated as ordinary encoders).
    if type(encoder) is tuple:
        encoder, target_encoder = encoder
        if target_encoder is not None:
            batch = batch._replace(action=target_encoder.encode(batch.observation))

    if encoder is None:
        return batch._replace(
            observation=batch.observation,
            next_observation=batch.next_observation,
        )

    def _with_embedding(raw_obs):
        # Copy so that the caller's observation mapping stays intact.
        encoded = dict(raw_obs, embedding=encoder.encode(raw_obs))
        if pop_pixels:
            encoded.pop("pixels")
        return encoded

    obs = _with_embedding(batch.observation)
    if encode_next_obs:
        next_obs = _with_embedding(batch.next_observation)
    else:
        next_obs = batch.next_observation

    return batch._replace(observation=obs, next_observation=next_obs)


# Prevent tf from using GPU since we are training with JAX: an empty device
# list is registered as the set of GPUs visible to TensorFlow. This executes as
# a side effect of importing this module.
# NOTE(review): TensorFlow documents that changing visible devices after its
# runtime has been initialized raises RuntimeError -- confirm this module is
# imported before any other TensorFlow work happens.
tf.config.set_visible_devices([], "GPU")


def stack_history_tf(obs: Dict, history: int) -> Dict:
    """Stack the last ``history`` observations along a new trailing axis.

    Every value in ``obs`` is treated as a sequence indexed by time on axis 0:
    ``(time, ...) -> (time, ..., history)``. At step ``t`` the trailing axis
    holds steps ``t - history + 1, ..., t`` (oldest first, current step last).
    Steps before the start of the sequence are filled with the first
    observation.

    Args:
        obs: Mapping from key to a tensor with time on axis 0.
        history: Number of frames to stack; must be at least 1. With
            ``history == 1`` only a trailing axis of size 1 is added.

    Returns:
        A new dict with the same keys and the stacked tensors.

    Raises:
        ValueError: If ``history`` is smaller than 1.
    """
    if history < 1:
        raise ValueError(f"history must be >= 1, got {history}")

    num_pad = history - 1
    stacked_obs = {}
    for key, frames in obs.items():
        # Append copies of the first frame: rolling forward wraps this tail
        # around to the front, where it acts as padding for the first steps.
        padded = tf.concat([frames, tf.repeat(frames[:1], num_pad, axis=0)], 0)
        # Largest shift first so the trailing axis runs from oldest to newest.
        shifted = [tf.roll(padded, shift=i, axis=0) for i in range(history)][::-1]
        stacked = tf.stack(shifted, axis=-1)
        # Drop the padded suffix. With no padding there is nothing to drop
        # (a plain ``[:-0]`` slice would wrongly return an empty tensor).
        stacked_obs[key] = stacked if num_pad == 0 else stacked[:-num_pad]
    return stacked_obs


def stack_actions_tf(action: tf.Tensor, nstep: int) -> tf.Tensor:
    """Stack each action with the ``nstep - 1`` actions that follow it.

    Maps ``(time, ...) -> (time, ..., nstep)``. Index ``i`` of the trailing
    axis at step ``t`` holds the action from step ``t + i``; positions past
    the end of the sequence repeat the final action.
    """
    num_pad = nstep - 1
    # Copies of the final action are put in FRONT: rolling backwards wraps
    # them around to the end, where they stand in for the missing future.
    front_padding = tf.repeat(action[-1:], num_pad, 0)
    padded = tf.concat([front_padding, action], 0)
    shifted = [tf.roll(padded, shift=-offset, axis=0) for offset in range(nstep)]
    # The first ``num_pad`` rows belong to the padding itself; discard them.
    return tf.stack(shifted, axis=-1)[num_pad:]


def sliding_average_tf(tensor, window_size):
    """Smooth a rank-2 tensor along axis 0 with a moving average.

    Each column is averaged independently with a uniform window of
    ``window_size`` entries, implemented as a 1-D convolution.

    Args:
        tensor: Rank-2 tensor, averaged along axis 0.
            NOTE(review): the kernel is float32, so the input presumably has
            to be float32 as well -- confirm against callers.
        window_size: Number of consecutive entries in each averaging window.

    Returns:
        A tensor with the same shape as ``tensor``. The convolution uses
        ``"SAME"`` padding, which TensorFlow documents as zero padding, so
        windows overlapping either end are completed with zeros.
    """
    # Uniform kernel of shape (width, in_channels=1, out_channels=1).
    kernel = tf.ones(shape=(window_size, 1, 1), dtype=tf.float32) / window_size
    # (rows, cols) -> (cols, rows, 1): every column becomes its own
    # single-channel sequence in the convolution's batch dimension.
    sequences = tf.transpose(tensor)[:, :, tf.newaxis]
    averaged = tf.nn.conv1d(sequences, filters=kernel, stride=1, padding="SAME")
    # Remove ONLY the channel axis. An axis-less squeeze would also drop the
    # row or column axis whenever it has size 1, changing the output rank.
    return tf.transpose(tf.squeeze(averaged, axis=-1))


def process_episode_tf(
    episode: Dict,
    gamma: jnp.float32,
    nstep: int,
    history: int = 1,
    average_actions: Optional[int] = None,
    include_goal_pixels: bool = True,
) -> Batch:
    """Convert one episode into a batch of n-step transitions.

    The first step of an episode holds the initial observation together with
    a dummy action/reward/discount, so actions, rewards and discounts are
    read starting from index 1.

    Args:
        episode: Dict with ``"observation"`` (a dict of per-step tensors) and
            ``"action"``, ``"reward"``, ``"discount"`` tensors, all indexed by
            time on axis 0.
        gamma: Discount factor applied once per step.
        nstep: Number of steps between ``observation`` and
            ``next_observation``. ``0`` reuses the observations as next
            observations and yields zero reward with unit discount.
        history: Number of observation frames to stack (stacking is applied
            only when this is greater than 1).
        average_actions: If not ``None``, window size of a sliding average
            applied to the actions before stacking.
        include_goal_pixels: If true, the final ``"pixels"`` frame of the
            episode is added as ``"goal_pixels"`` to both the observations
            and the next observations.

    Returns:
        ``Batch(observation, action, reward, discount, next_observation)``.
    """
    raw_obs = episode["observation"]
    rewards = episode["reward"]
    discounts = episode["discount"]

    # Every step except the last one starts a transition.
    obs = {key: frames[:-1] for key, frames in raw_obs.items()}

    action = episode["action"][1:]
    if average_actions is not None:
        action = sliding_average_tf(action, window_size=average_actions)
    if nstep > 1:  # Stack future actions.
        action = stack_actions_tf(action, nstep)

    if nstep == 0:  # Allow overload of nstep for static methods
        next_obs = obs
    else:
        # Shift by nstep and repeat the final observation at the end of the
        # episode so the result is as long as ``obs``.
        next_obs = {}
        for key, frames in raw_obs.items():
            tail = tf.repeat(frames[-1:], nstep - 1, axis=0)
            next_obs[key] = tf.concat([frames[nstep:], tail], 0)

    # Stack observation history.
    if history > 1:
        obs = stack_history_tf(obs, history)
        next_obs = stack_history_tf(next_obs, history)

    # The goal image is the last frame of the episode, tiled over every
    # transition of both the observations and the next observations.
    if include_goal_pixels:
        goal_pixels = raw_obs["pixels"][-1:]
        obs["goal_pixels"] = tf.repeat(goal_pixels, len(obs["pixels"]), axis=0)
        next_obs["goal_pixels"] = tf.repeat(
            goal_pixels, len(next_obs["pixels"]), axis=0
        )

    # n-step returns: zero padding makes look-ahead past the end of the
    # episode contribute no reward and a zero discount.
    padded_reward = tf.concat([rewards, tf.zeros_like(rewards[:nstep])], 0)
    padded_discount = tf.concat([discounts, tf.zeros_like(discounts[:nstep])], 0)
    reward = tf.zeros_like(rewards[1:])
    discount = tf.ones_like(discounts[1:])
    ep_len = len(reward)
    for offset in range(1, nstep + 1):
        window = slice(offset, offset + ep_len)
        reward = reward + discount * padded_reward[window]
        discount = discount * (padded_discount[window] * gamma)

    return Batch(obs, action, reward, discount, next_obs)


# TODO: remove list option for nstep?
def list_process_episode_tf(
    episode: Dict,
    gamma: jnp.float32,
    nstep: List[int],
    history: int = 1,
    include_goal_pixels: bool = False,
):
    """Process one episode for several values of ``nstep`` and concatenate.

    For every ``n`` in ``nstep`` the episode is processed with
    ``process_episode_tf``. Actions are brought to three axes and padded with
    zeros on the last axis up to ``max(nstep)`` so that the batches can be
    concatenated along axis 0.

    Args:
        episode: Episode dict as accepted by ``process_episode_tf``.
        gamma: Discount factor.
        nstep: List (or tuple) of n-step values; a single value is also
            accepted. For ``n == 0`` the actions are replaced by zeros.
        history: Number of observation frames to stack.
        include_goal_pixels: Forwarded to ``process_episode_tf``.

    Returns:
        A single ``Batch`` holding the transitions for all requested values
        of ``nstep``, concatenated in the given order.

    Raises:
        ValueError: If ``nstep`` is an empty sequence.
    """
    nstep = list(nstep) if isinstance(nstep, (list, tuple)) else [nstep]
    if not nstep:
        raise ValueError("nstep must contain at least one value")
    max_nstep = max(nstep)

    batch_list = []
    for n in nstep:
        # Pass by keyword: process_episode_tf has an ``average_actions``
        # parameter between ``history`` and ``include_goal_pixels``, so a
        # positional call would feed the flag into the wrong parameter.
        batch = process_episode_tf(
            episode,
            gamma,
            n,
            history=history,
            include_goal_pixels=include_goal_pixels,
        )
        if n == 0:  # Add zero action for nstep=0
            batch = batch._replace(action=tf.zeros_like(batch.action))
        if len(batch.action.shape) == 2:  # Ensure all actions have 3 axes
            batch = batch._replace(action=tf.expand_dims(batch.action, 2))
        # Pad actions to always have maximum length
        pad = max(max_nstep - batch.action.shape[-1], 0)
        batch = batch._replace(action=tf.pad(batch.action, [[0, 0], [0, 0], [0, pad]]))
        batch_list.append(batch)

    # Concatenate all batches into one big batch of transitions
    first = batch_list[0]
    return Batch(
        observation={
            k: tf.concat([b.observation[k] for b in batch_list], 0)
            for k in first.observation
        },
        action=tf.concat([b.action for b in batch_list], 0),
        reward=tf.concat([b.reward for b in batch_list], 0),
        discount=tf.concat([b.discount for b in batch_list], 0),
        next_observation={
            k: tf.concat([b.next_observation[k] for b in batch_list], 0)
            for k in first.next_observation
        },
    )


# TODO: Remove jax processing?


def stack_history_jax(obs: Dict, history: int) -> Dict:
    """Stack the last ``history`` observations along a new trailing axis.

    Every leaf of ``obs`` is treated as a sequence indexed by time on axis 0:
    ``(time, ...) -> (time, ..., history)``. At step ``t`` the trailing axis
    holds steps ``t, t - 1, ..., t - history + 1`` (current step first). Steps
    before the start of the sequence are filled with the first observation.

    NOTE(review): ``stack_history_tf`` reverses the rolled frames and therefore
    orders the trailing axis oldest-first; the order used here is kept as is
    to avoid changing behavior -- confirm which one consumers expect.

    Args:
        obs: Pytree (e.g. a dict) of arrays with time on axis 0.
        history: Number of frames to stack; must be at least 1. With
            ``history == 1`` only a trailing axis of size 1 is added.

    Returns:
        A pytree with the same structure holding the stacked arrays.

    Raises:
        ValueError: If ``history`` is smaller than 1.
    """
    if history < 1:
        raise ValueError(f"history must be >= 1, got {history}")

    num_pad = history - 1

    def stack_and_roll(frames):
        # Append copies of the first frame: rolling forward wraps this tail
        # around to the front, where it acts as padding for the first steps.
        padded = jnp.concatenate([frames, jnp.repeat(frames[:1], num_pad, 0)], 0)
        stacked = jnp.stack(
            [jnp.roll(padded, shift=i, axis=0) for i in range(history)], axis=-1
        )
        # Drop the padded suffix. With no padding there is nothing to drop
        # (a plain ``[:-0]`` slice would wrongly return an empty array).
        return stacked if num_pad == 0 else stacked[:-num_pad]

    return jax.tree_util.tree_map(stack_and_roll, obs)


def stack_actions_jax(action: jnp.ndarray, nstep: int) -> jnp.ndarray:
    """Stack each action with the ``nstep - 1`` actions that follow it.

    Maps ``(time, ...) -> (time, ..., nstep)``. Index ``i`` of the trailing
    axis at step ``t`` holds the action from step ``t + i``; positions past
    the end of the sequence repeat the final action.
    """
    num_steps = len(action)
    # Extend the sequence with copies of the final action so that every
    # look-ahead window below is fully populated.
    future_padding = jnp.repeat(action[-1:], nstep - 1, 0)
    extended = jnp.concatenate([action, future_padding], 0)
    # Window ``i`` starts ``i`` steps in the future and spans the full episode.
    windows = [extended[start : start + num_steps] for start in range(nstep)]
    return jnp.stack(windows, axis=-1)


def process_episode_jax(
    episode: Dict,
    gamma: jnp.float32,
    nstep: int,
    encoder,
    history: int = 1,
) -> Batch:
    """Convert one episode into a batch of n-step transitions.

    The first step of an episode holds the true observation together with a
    dummy action/reward/discount, so actions, rewards and discounts are read
    starting from index 1.

    Args:
        episode: Dict with ``"observation"`` (a pytree of per-step arrays) and
            ``"action"``, ``"reward"``, ``"discount"`` arrays, all indexed by
            time on axis 0.
        gamma: Discount factor applied once per step.
        nstep: Number of steps between ``observation`` and
            ``next_observation``. ``0`` reuses the observations as next
            observations and yields zero reward with unit discount.
        encoder: Passed to ``encode_observations`` when not ``None``.
        history: Number of observation frames to stack (stacking is applied
            only when this is greater than 1).

    Returns:
        ``Batch(observation, action, reward, discount, next_observation)``,
        with encoded observations if an encoder was given.
    """
    tree_map = jax.tree_util.tree_map
    raw_obs = episode["observation"]
    rewards = episode["reward"]
    discounts = episode["discount"]

    # Every step except the last one starts a transition.
    obs = tree_map(lambda frames: frames[:-1], raw_obs)

    action = episode["action"][1:]
    if nstep > 1:  # Stack future actions.
        action = stack_actions_jax(action, nstep)

    if nstep == 0:  # Allow overload of nstep for static methods
        next_obs = obs
    else:

        def shift_forward(frames):
            # Repeat the final observation at the end of the episode so the
            # result is as long as ``obs``.
            tail = jnp.repeat(frames[-1:], nstep - 1, 0)
            return jnp.concatenate([frames[nstep:], tail])

        next_obs = tree_map(shift_forward, raw_obs)

    # Stack observation history.
    if history > 1:
        obs = stack_history_jax(obs, history)
        next_obs = stack_history_jax(next_obs, history)

    # n-step returns: zero padding makes look-ahead past the end of the
    # episode contribute no reward and a zero discount.
    padded_reward = jnp.concatenate([rewards, jnp.zeros_like(rewards[:nstep])])
    padded_discount = jnp.concatenate([discounts, jnp.zeros_like(discounts[:nstep])])
    reward = jnp.zeros_like(rewards[1:])
    discount = jnp.ones_like(discounts[1:])
    ep_len = len(reward)
    for offset in range(1, nstep + 1):
        window = slice(offset, offset + ep_len)
        reward = reward + discount * padded_reward[window]
        discount = discount * (padded_discount[window] * gamma)

    batch = Batch(obs, action, reward, discount, next_obs)

    # Encode observations.
    if encoder is None:
        return batch
    return encode_observations(batch, encoder)
