"""Haiku neural network modules."""
import dataclasses
from typing import List, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
from acme import specs
from acme.jax import networks as networks_lib
from acme.jax import utils

from rosmo.types import Array, Forwardable


def get_prediction_head_layers(
  reduced_channels_head: int,
  mlp_layers: List[int],
  num_predictions: int,
  w_init: hk.initializers.Initializer = None,
) -> List[Forwardable]:
  """Assemble the ordered layer list of a single prediction head.

  The head consists of three stages:
    1. A bias-free 1x1 convolution, LayerNorm over the last three axes, ReLU
       and `hk.Flatten(-3)`.
    2. For every entry of `mlp_layers`: a bias-free Linear layer, LayerNorm
       over the last axis and ReLU.
    3. A final Linear layer producing `num_predictions` outputs.

  Args:
    reduced_channels_head: Output channels of the leading 1x1 convolution.
    mlp_layers: Output size of each hidden Linear layer, in order.
    num_predictions: Output size of the final Linear layer.
    w_init: Weight initializer passed to the final Linear layer only.

  Returns:
    The layers in application order. They are not chained here; the caller
    composes them (e.g. with `hk.Sequential`).
  """
  # Stage 1: convolutional stem.
  channel_reducer = hk.Conv2D(
    reduced_channels_head,
    kernel_shape=1,
    stride=1,
    padding="SAME",
    with_bias=False,
  )
  stem_norm = hk.LayerNorm(
    axis=(-3, -2, -1), create_scale=True, create_offset=True
  )
  head = [channel_reducer, stem_norm, jax.nn.relu, hk.Flatten(-3)]

  # Stage 2: hidden MLP layers, each followed by its own normalization.
  for width in mlp_layers:
    head.append(hk.Linear(width, with_bias=False))
    head.append(hk.LayerNorm(axis=-1, create_scale=True, create_offset=True))
    head.append(jax.nn.relu)

  # Stage 3: output projection (the only layer that keeps its bias).
  head.append(hk.Linear(num_predictions, w_init=w_init))
  return head


def get_ln_relu_layers() -> List[Forwardable]:
  """Return a new two-element list: LayerNorm followed by ReLU.

  The LayerNorm normalizes over the last three axes and creates both a scale
  and an offset parameter. A new LayerNorm module is constructed on every call.
  """
  norm = hk.LayerNorm(
    axis=(-3, -2, -1), create_scale=True, create_offset=True
  )
  return [norm, jax.nn.relu]


def renormalize_fun(state: Array) -> Array:
  """Min-max rescale `state` over its last three axes.

  The minimum and maximum are taken jointly over axes (-3, -2, -1), so every
  leading index is normalized independently. Both statistics are wrapped in
  `stop_gradient`, so gradients only flow through `state` itself.

  Args:
    state: Array with at least three dimensions.

  Returns:
    `(state - min) / (max - min)` with the same shape as `state`. When the
    range is below 1e-5, 1e-5 is added to it to avoid dividing by ~zero.
  """
  reduce_axes = (-3, -2, -1)
  eps = 1e-5

  # keepdims retains the reduced axes with size 1, so the statistics broadcast
  # directly against `state`.
  lowest = jax.lax.stop_gradient(
    jnp.min(state, axis=reduce_axes, keepdims=True)
  )
  highest = jax.lax.stop_gradient(
    jnp.max(state, axis=reduce_axes, keepdims=True)
  )

  span = highest - lowest
  span = jnp.where(span < eps, span + eps, span)
  return (state - lowest) / span


class ResConvBlock(hk.Module):
  """Pre-activation residual block built from two 3x3 convolutions.

  The residual path is LayerNorm -> ReLU -> conv -> LayerNorm -> ReLU -> conv.
  The shortcut is either the untouched input or, when `use_projection` is set,
  a 3x3 convolution applied to the pre-activated input.
  """

  def __init__(
    self,
    channels: int,
    stride: int,
    use_projection: bool,
    name: str = "res_conv_block",
  ):
    """Create the block's sub-modules.

    Args:
      channels: Output channels of every convolution in the block.
      stride: Stride of the first convolution and of the projection
        convolution; the second convolution always uses stride 1.
      use_projection: Whether the shortcut goes through its own convolution
        instead of being the identity.
      name: Haiku module name.
    """
    super().__init__(name=name)
    self._use_projection = use_projection

    conv_kwargs = dict(kernel_shape=3, padding="SAME", with_bias=False)
    norm_kwargs = dict(
      axis=(-3, -2, -1), create_scale=True, create_offset=True
    )

    # Keep the projection -> conv_0 -> conv_1 construction order:
    # auto-generated module names (and hence parameter keys) presumably
    # depend on it.
    if use_projection:
      self._proj_conv = hk.Conv2D(channels, stride=stride, **conv_kwargs)
    self._conv_0 = hk.Conv2D(channels, stride=stride, **conv_kwargs)
    self._ln_0 = hk.LayerNorm(**norm_kwargs)
    self._conv_1 = hk.Conv2D(channels, stride=1, **conv_kwargs)
    self._ln_1 = hk.LayerNorm(**norm_kwargs)

  def __call__(self, x: Array) -> Array:
    """Apply the block to `x` and return shortcut + residual."""
    # NOTE: Using LayerNorm is fine (https://arxiv.org/pdf/2104.06294.pdf Appendix A).
    pre_activated = jax.nn.relu(self._ln_0(x))

    # The projected shortcut branches off after the pre-activation, whereas
    # the identity shortcut uses the raw input.
    if self._use_projection:
      shortcut = self._proj_conv(pre_activated)
    else:
      shortcut = x

    residual = self._conv_0(pre_activated)
    residual = jax.nn.relu(self._ln_1(residual))
    residual = self._conv_1(residual)
    return shortcut + residual


class Representation(hk.Module):
  """Encoder that maps observations to a latent state.

  The network divides its input by 255.0, reduces the spatial resolution with
  a stride-2 convolution, a stride-2 residual block and two stride-2 average
  pools (interleaved with stride-1 residual blocks), and then applies
  `num_blocks` further residual blocks.
  """

  def __init__(
    self,
    channels: int,
    num_blocks: int,
    channel_scaling: float = -1.0,
    name: str = "representation",
  ):
    """Store the architecture hyper-parameters.

    Args:
      channels: Channel width of the torso (the first stage uses half of it).
      num_blocks: Number of residual blocks in the encoding stage.
      channel_scaling: If positive, the last layer is replaced by a
        projecting residual block with `int(channels * channel_scaling)`
        channels. Non-positive values disable the replacement.
      name: Haiku module name.
    """
    super().__init__(name=name)
    self._channels = channels
    self._num_blocks = num_blocks
    # Scale representational capacity; 0 means "keep the torso width".
    self._hidden_channels = (
      int(channels * channel_scaling) if channel_scaling > 0 else 0
    )

  def __call__(self, observations: Array) -> Array:
    """Encode `observations` into a latent state."""
    half_width = self._channels // 2
    full_width = self._channels

    def _plain_block(width: int) -> ResConvBlock:
      # Stride-1 residual block with an identity shortcut.
      return ResConvBlock(width, stride=1, use_projection=False)

    def _downsample_pool() -> hk.AvgPool:
      return hk.AvgPool(
        window_shape=(3, 3, 1), strides=(2, 2, 1), padding="SAME"
      )

    # 1. Downsampling. The list literal is evaluated left to right, so the
    # modules are constructed in exactly the order in which they are applied.
    layers = [
      lambda x: x / 255.0,
      hk.Conv2D(
        half_width,
        kernel_shape=3,
        stride=2,
        padding="SAME",
        with_bias=False,
      ),
      _plain_block(half_width),
      ResConvBlock(full_width, stride=2, use_projection=True),
      _plain_block(full_width),
      _downsample_pool(),
      _plain_block(full_width),
      _downsample_pool(),
    ]

    # 2. Encoding.
    for _ in range(self._num_blocks):
      layers.append(_plain_block(full_width))

    if self._hidden_channels:
      # Replace the last layer with a projecting block of the rescaled width.
      # The replaced block is still constructed above before being dropped.
      # NOTE(review): pop() removes whatever is last, so with num_blocks == 0
      # it drops the final AvgPool rather than an encoding block - confirm
      # that configuration is never used.
      layers.pop()
      layers.append(
        ResConvBlock(self._hidden_channels, stride=1, use_projection=True)
      )

    encoder = hk.Sequential(layers)
    return encoder(observations)


class Transition(hk.Module):
  """Dynamics module: predicts the next latent state from state and action."""

  def __init__(
    self,
    channels: int,
    num_blocks: int,
    name: str = "transition",
  ):
    """Store the architecture hyper-parameters.

    Args:
      channels: Output channels of the convolution that fuses state and
        action. Its output is added to `prev_state`, so it is expected to
        match the channel count of `prev_state`.
      num_blocks: Number of residual blocks applied after the fusion.
      name: Haiku module name.
    """
    super().__init__(name=name)
    self._channels = channels
    self._num_blocks = num_blocks

  def __call__(self, encoded_action: Array, prev_state: Array) -> Array:
    """Compute the next state.

    Args:
      encoded_action: Action encoding, concatenated to the state along the
        last axis.
      prev_state: Current latent state.

    Returns:
      The output of the residual tower applied to the fused state/action.
    """
    state_channels = prev_state.shape[-1]

    # Pre-activate the incoming state before mixing in the action.
    state_norm = hk.LayerNorm(
      axis=(-3, -2, -1), create_scale=True, create_offset=True
    )
    activated_state = jax.nn.relu(state_norm(prev_state))

    fuse_conv = hk.Conv2D(
      self._channels,
      kernel_shape=3,
      stride=1,
      padding="SAME",
      with_bias=False,
    )
    state_and_action = jnp.concatenate(
      [activated_state, encoded_action], axis=-1
    )
    # Residual link to maintain recurrent info flow: the un-normalized input
    # state is added back onto the fused features.
    fused = fuse_conv(state_and_action) + prev_state

    tower = hk.Sequential(
      [
        ResConvBlock(state_channels, stride=1, use_projection=False)
        for _ in range(self._num_blocks)
      ]
    )
    return tower(fused)


class Prediction(hk.Module):
  """Predicts policy, reward and value logits from a latent state.

  The reward head reads the input state directly. The value and policy heads
  share a tower of residual blocks that is applied to the state first.
  """

  def __init__(
    self,
    num_blocks,
    num_actions: int,
    num_bins: int,
    channel: int,
    fc_layers_reward: List[int],
    fc_layers_value: List[int],
    fc_layers_policy: List[int],
    output_init_scale: float,
    name="prediction",
  ):
    """Store the architecture hyper-parameters.

    Args:
      num_blocks: Number of residual blocks in the shared value/policy tower.
      num_actions: Output size of the policy head.
      num_bins: Output size of the reward head and of the value head.
      channel: Output channels of the 1x1 convolution that opens each head.
      fc_layers_reward: Hidden layer sizes of the reward head.
      fc_layers_value: Hidden layer sizes of the value head.
      fc_layers_policy: Hidden layer sizes of the policy head.
      output_init_scale: Scale of the VarianceScaling initializer used for
        the final layer of every head.
      name: Haiku module name.
    """
    super().__init__(name=name)
    self._num_blocks = num_blocks
    self._num_actions = num_actions
    self._num_bins = num_bins
    self._channel = channel
    self._fc_layers_reward = fc_layers_reward
    self._fc_layers_value = fc_layers_value
    self._fc_layers_policy = fc_layers_policy
    self._output_init_scale = output_init_scale

  def __call__(self, states: Array) -> Tuple[Array, Array, Array]:
    """Return `(policy_logits, reward_logits, value_logits)` for `states`."""
    output_init = hk.initializers.VarianceScaling(
      scale=self._output_init_scale
    )

    # Add LN+Relu due to pre-activation. All three prefixes are created up
    # front, before any head-specific layer, to keep the module construction
    # order unchanged.
    reward_layers = get_ln_relu_layers()
    value_layers = get_ln_relu_layers()
    policy_layers = get_ln_relu_layers()

    # Reward head: operates on the raw input state.
    reward_layers += get_prediction_head_layers(
      self._channel,
      self._fc_layers_reward,
      self._num_bins,
      output_init,
    )
    reward_logits = hk.Sequential(reward_layers)(states)

    # Residual tower shared by the value and policy heads.
    tower = hk.Sequential(
      [
        ResConvBlock(states.shape[-1], stride=1, use_projection=False)
        for _ in range(self._num_blocks)
      ]
    )
    features = tower(states)

    # Value head.
    value_layers += get_prediction_head_layers(
      self._channel,
      self._fc_layers_value,
      self._num_bins,
      output_init,
    )
    value_logits = hk.Sequential(value_layers)(features)

    # Policy head.
    policy_layers += get_prediction_head_layers(
      self._channel,
      self._fc_layers_policy,
      self._num_actions,
      output_init,
    )
    policy_logits = hk.Sequential(policy_layers)(features)

    return policy_logits, reward_logits, value_logits


@dataclasses.dataclass
class Networks:
  """Container for the three model networks and the environment spec.

  Each network is an init/apply pair wrapped in a `FeedForwardNetwork`.
  """
  # Encodes observations into a latent state.
  representation_network: networks_lib.FeedForwardNetwork
  # Maps an action and a latent state to the next latent state.
  transition_network: networks_lib.FeedForwardNetwork
  # Maps a latent state to policy, reward and value logits.
  prediction_network: networks_lib.FeedForwardNetwork

  # Environment spec the networks were built for.
  environment_specs: specs.EnvironmentSpec


def make_networks(
  env_spec: specs.EnvironmentSpec,
  channels: int,
  num_bins: int,
  output_init_scale: float,
  channel_scaling: float,
  blocks_representation: int,
  blocks_prediction: int,
  blocks_transition: int,
  reduced_channels_head: int,
  fc_layers_reward: List[int],
  fc_layers_value: List[int],
  fc_layers_policy: List[int],
  renormalize: bool = False,
) -> Networks:
  """Build the representation, transition and prediction networks.

  Args:
    env_spec: Environment spec; supplies the number of discrete actions, a
      sample action and the observation spec used for dummy inputs.
    channels: Channel width of the representation torso.
    num_bins: Output size of the reward and value heads.
    output_init_scale: Initializer scale for the final layer of each head.
    channel_scaling: If positive, the latent state has
      `int(channels * channel_scaling)` channels; otherwise `channels`.
    blocks_representation: Residual blocks in the representation encoder.
    blocks_prediction: Residual blocks in the prediction tower.
    blocks_transition: Residual blocks in the transition module.
    reduced_channels_head: Channels of the 1x1 convolution in each head.
    fc_layers_reward: Hidden layer sizes of the reward head.
    fc_layers_value: Hidden layer sizes of the value head.
    fc_layers_policy: Hidden layer sizes of the policy head.
    renormalize: Whether to pass the outputs of the representation and
      transition networks through `renormalize_fun`.

  Returns:
    A `Networks` bundle whose members expose `init(key)` and an rng-free
    `apply`.
  """
  num_actions = env_spec.actions.num_values

  # Channel count of the latent state handed to the transition module.
  hidden_channels = (
    int(channels * channel_scaling) if channel_scaling > 0 else channels
  )

  def _representation_fun(observations: Array) -> Array:
    encoder = Representation(channels, blocks_representation, channel_scaling)
    state = encoder(observations)
    return renormalize_fun(state) if renormalize else state

  def _transition_fun(action: Array, state: Array) -> Array:
    # NOTE Biased plane is much worse than one-hot action (MsPacman).
    one_hot_action = hk.one_hot(action, num_actions)[None, :]
    # Tile the one-hot vector over every leading position of the state.
    plane_shape = state.shape[:-1] + one_hot_action.shape[-1:]
    encoded_action = jnp.broadcast_to(one_hot_action, plane_shape)

    dynamics = Transition(hidden_channels, blocks_transition)
    next_state = dynamics(encoded_action, state)
    return renormalize_fun(next_state) if renormalize else next_state

  def _prediction_fun(states: Array) -> Tuple[Array, Array, Array]:
    heads = Prediction(
      blocks_prediction,
      num_actions,
      num_bins,
      reduced_channels_head,
      fc_layers_reward,
      fc_layers_value,
      fc_layers_policy,
      output_init_scale,
    )
    return heads(states)

  representation = hk.without_apply_rng(hk.transform(_representation_fun))
  transition = hk.without_apply_rng(hk.transform(_transition_fun))
  prediction = hk.without_apply_rng(hk.transform(_prediction_fun))

  # Dummy inputs used only for parameter initialization. They are left
  # without a batch dimension (the utils.add_batch_dim calls are disabled).
  dummy_action = jnp.array([env_spec.actions.generate_value()])
  dummy_obs = utils.zeros_like(env_spec.observations)

  def _dummy_state(key):
    # Run a freshly initialized encoder once to obtain a correctly shaped
    # latent state.
    encoder_params = representation.init(key, dummy_obs)
    return representation.apply(encoder_params, dummy_obs)

  def _init_representation(key):
    return representation.init(key, dummy_obs)

  def _init_transition(key):
    return transition.init(key, dummy_action, _dummy_state(key))

  def _init_prediction(key):
    return prediction.init(key, _dummy_state(key))

  return Networks(
    representation_network=networks_lib.FeedForwardNetwork(
      _init_representation, representation.apply
    ),
    transition_network=networks_lib.FeedForwardNetwork(
      _init_transition, transition.apply
    ),
    prediction_network=networks_lib.FeedForwardNetwork(
      _init_prediction, prediction.apply
    ),
    environment_specs=env_spec,
  )
