import dataclasses
from typing import Callable, List, Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from acme import specs
from acme.jax import networks as networks_lib
from acme.jax import utils

from rosmo.agent.muzero.network import (
  Representation,
  ResConvBlock,
  get_ln_relu_layers,
  get_prediction_head_layers,
)


@dataclasses.dataclass
class CRRNetworks:
  """Networks and pure functions for the CRR agent.

  Bundles the feed-forward networks built by the factory functions in this
  module together with the helper callables used to query the policy output.
  """

  # Observation encoder. `make_networks_continuous` passes None here because
  # it builds no torso, hence the Optional annotation.
  torso_network: Optional[networks_lib.FeedForwardNetwork]
  # Policy network; its output is what `log_prob`, `sample` and
  # `sample_eval` receive as their first argument.
  policy_network: networks_lib.FeedForwardNetwork
  # Critic network; applied to (features, action) pairs.
  critic_network: networks_lib.FeedForwardNetwork
  # Called as log_prob(policy_output, actions).
  log_prob: networks_lib.LogProbFn
  # Called as sample(policy_output, key); used to draw an action.
  sample: networks_lib.SampleFn
  # Called as sample_eval(policy_output, key); the factories in this module
  # ignore the key and return the mode of the policy output.
  sample_eval: networks_lib.SampleFn


def get_prediction_head(
  num_preds: int,
  channels: int,
  num_blocks: int,
  reduced_channel: int,
  mlp_layers: List[int],
  output_init_scale: Optional[float] = None,
  use_projection: bool = False,
):
  """Assemble the list of layers that make up a prediction head.

  The returned list holds, in order: residual conv blocks, the layers from
  `get_ln_relu_layers()`, and the layers from `get_prediction_head_layers()`.
  It must be called inside a Haiku transform, since it instantiates modules.

  Args:
    num_preds: forwarded to `get_prediction_head_layers` as the number of
      predictions.
    channels: channel count given to every `ResConvBlock`.
    num_blocks: total number of residual blocks; one block is always
      created, followed by `num_blocks - 1` more.
    reduced_channel: forwarded to `get_prediction_head_layers`.
    mlp_layers: forwarded to `get_prediction_head_layers`.
    output_init_scale: when given, a `VarianceScaling` initializer with this
      scale is forwarded as the output initializer; otherwise None is
      forwarded.
    use_projection: applied to the first residual block only; the remaining
      blocks never use a projection.

  Returns:
    A list of layers, suitable for wrapping in `hk.Sequential`.
  """
  initializer = (
    None
    if output_init_scale is None
    else hk.initializers.VarianceScaling(scale=output_init_scale)
  )

  # Only the leading block may project its input; the rest keep the width.
  layers = [ResConvBlock(channels, stride=1, use_projection=use_projection)]
  for _ in range(1, num_blocks):
    layers.append(ResConvBlock(channels, stride=1, use_projection=False))

  layers += get_ln_relu_layers()
  layers += get_prediction_head_layers(
    reduced_channel, mlp_layers, num_preds, initializer
  )
  return layers


def make_networks(
  env_spec: specs.EnvironmentSpec,
  channels: int,
  num_bins: int,
  output_init_scale: float,
  blocks_torso: int,
  blocks_policy: int,
  blocks_value: int,
  reduced_channels_head: int,
  fc_layers_policy: List[int],
  fc_layers_value: List[int],
) -> CRRNetworks:
  """Build the torso, policy and critic networks for a discrete-action agent.

  Args:
    env_spec: environment spec; its observations and actions provide the
      dummy inputs used to initialise parameters, and
      `env_spec.actions.num_values` sizes the policy head and the one-hot
      action encoding.
    channels: channel count for the torso and for both prediction heads.
    num_bins: passed to the critic head as its number of predictions.
    output_init_scale: output initializer scale for the critic head.
    blocks_torso: number of blocks given to `Representation`.
    blocks_policy: number of residual blocks in the policy head.
    blocks_value: number of residual blocks in the critic head.
    reduced_channels_head: reduced channel count given to both heads.
    fc_layers_policy: MLP layer sizes given to the policy head.
    fc_layers_value: MLP layer sizes given to the critic head.

  Returns:
    A `CRRNetworks` holding the three networks and the policy helpers.
  """
  # Batched placeholder inputs, only used to shape the parameters at init.
  batched_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  batched_action = utils.add_batch_dim(
    jnp.array(env_spec.actions.generate_value())
  )

  def _torso_fn(obs: jnp.ndarray) -> jnp.ndarray:
    return Representation(channels, blocks_torso)(obs)

  torso = hk.without_apply_rng(hk.transform(_torso_fn))

  def _init_torso(key):
    return torso.init(key, batched_obs)

  torso_network = networks_lib.FeedForwardNetwork(_init_torso, torso.apply)

  def _encoded_placeholder(key):
    # Both heads consume torso features, so their init needs a sample of
    # the torso output; freshly initialised torso parameters produce one.
    return torso.apply(torso.init(key, batched_obs), batched_obs)

  def _policy_fn(torso_out: jnp.ndarray) -> jnp.ndarray:
    layers = get_prediction_head(
      num_preds=env_spec.actions.num_values,
      channels=channels,
      num_blocks=blocks_policy,
      reduced_channel=reduced_channels_head,
      mlp_layers=fc_layers_policy,
    )
    # Replace the last Linear with Categorical.
    layers[-1] = networks_lib.CategoricalHead(env_spec.actions.num_values)
    return hk.Sequential(layers)(torso_out)

  policy = hk.without_apply_rng(hk.transform(_policy_fn))

  def _init_policy(key):
    return policy.init(key, _encoded_placeholder(key))

  policy_network = networks_lib.FeedForwardNetwork(_init_policy, policy.apply)

  def _critic_fn(torso_out: jnp.ndarray, action: jnp.ndarray):
    # One-hot encode the action, insert two singleton axes after the batch
    # axis, then broadcast so it matches every leading dimension of
    # `torso_out` and can be concatenated along the last axis.
    action_planes = hk.one_hot(action, env_spec.actions.num_values)
    action_planes = action_planes[:, None, None, :]
    target_shape = (*torso_out.shape[:-1], action_planes.shape[-1])
    action_planes = jnp.broadcast_to(action_planes, target_shape)

    head = hk.Sequential(
      get_prediction_head(
        num_preds=num_bins,
        channels=channels,
        num_blocks=blocks_value,
        reduced_channel=reduced_channels_head,
        mlp_layers=fc_layers_value,
        output_init_scale=output_init_scale,
        use_projection=True,
      )
    )
    critic_input = jnp.concatenate([torso_out, action_planes], axis=-1)
    return head(critic_input)

  critic = hk.without_apply_rng(hk.transform(_critic_fn))

  def _init_critic(key):
    return critic.init(key, _encoded_placeholder(key), batched_action)

  critic_network = networks_lib.FeedForwardNetwork(_init_critic, critic.apply)

  # The helpers below receive the policy network's output as `params`.
  def _log_prob(params, actions):
    return params.log_prob(actions)

  def _sample(params, key):
    return params.sample(seed=key)

  def _sample_eval(params, key):
    # Evaluation is deterministic: the key is ignored and the mode is used.
    return params.mode()

  return CRRNetworks(
    torso_network=torso_network,
    policy_network=policy_network,
    critic_network=critic_network,
    log_prob=_log_prob,
    sample=_sample,
    sample_eval=_sample_eval,
  )


def make_networks_continuous(
  spec: specs.EnvironmentSpec,
  policy_layer_sizes: Tuple[int, ...] = (256, 256),
  critic_layer_sizes: Tuple[int, ...] = (256, 256),
  activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
) -> CRRNetworks:
  """Build MLP policy and critic networks for a continuous-action agent.

  No torso is created; the returned `CRRNetworks.torso_network` is None and
  both networks take the observation directly.

  Args:
    spec: environment spec; its observations and actions provide the dummy
      inputs used to initialise parameters, and the action shape sets the
      policy output dimension.
    policy_layer_sizes: hidden layer sizes of the policy MLP.
    critic_layer_sizes: hidden layer sizes of the critic MLP; one extra
      output unit is appended.
    activation: activation function used in both MLPs.

  Returns:
    A `CRRNetworks` holding the policy, the critic and the policy helpers.
  """
  # Total number of action dimensions (product of the action shape).
  action_dim = np.prod(spec.actions.shape, dtype=int)

  # Batched placeholder inputs, only used to shape the parameters at init.
  batched_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
  batched_action = utils.add_batch_dim(utils.zeros_like(spec.actions))

  def _uniform_fan_in():
    return hk.initializers.VarianceScaling(1.0, "fan_in", "uniform")

  def _policy_fn(obs: jnp.ndarray) -> jnp.ndarray:
    trunk = hk.nets.MLP(
      [*policy_layer_sizes],
      w_init=_uniform_fan_in(),
      activation=activation,
      activate_final=True,
    )
    distribution_head = networks_lib.NormalTanhDistribution(action_dim)
    return hk.Sequential([trunk, distribution_head])(obs)

  policy = hk.without_apply_rng(hk.transform(_policy_fn))

  def _init_policy(key):
    return policy.init(key, batched_obs)

  policy_network = networks_lib.FeedForwardNetwork(_init_policy, policy.apply)

  def _critic_fn(obs, action):
    q_mlp = hk.nets.MLP(
      [*critic_layer_sizes, 1],
      w_init=_uniform_fan_in(),
      activation=activation,
    )
    q_network = hk.Sequential([q_mlp])
    # The critic sees the observation and action joined on the last axis.
    return q_network(jnp.concatenate([obs, action], axis=-1))

  critic = hk.without_apply_rng(hk.transform(_critic_fn))

  def _init_critic(key):
    return critic.init(key, batched_obs, batched_action)

  critic_network = networks_lib.FeedForwardNetwork(_init_critic, critic.apply)

  # The helpers below receive the policy network's output as `params`.
  def _log_prob(params, actions):
    return params.log_prob(actions)

  def _sample(params, key):
    return params.sample(seed=key)

  def _sample_eval(params, key):
    # Evaluation is deterministic: the key is ignored and the mode is used.
    return params.mode()

  return CRRNetworks(
    torso_network=None,
    policy_network=policy_network,
    critic_network=critic_network,
    log_prob=_log_prob,
    sample=_sample,
    sample_eval=_sample_eval,
  )
