import dataclasses
from typing import Dict, List, Optional

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from absl import logging
from acme import specs
from acme.jax import networks as networks_lib
from acme.jax import utils

from rosmo.agent.crr_discrete.network import get_prediction_head
from rosmo.agent.muzero.network import Representation
from rosmo.types import Array


@dataclasses.dataclass
class CQLNetworks:
  """Container bundling the Q-network and environment spec of the CQL agent."""

  # Q-network exposed as an init/apply pair (acme `FeedForwardNetwork`).
  q_network: networks_lib.FeedForwardNetwork

  # Spec of the environment the networks were created for.
  environment_specs: specs.EnvironmentSpec


class QuantileNetwork(hk.Module):
  """Quantile network as in the original CQL implementation.

  Three convolutions (8x8 stride 4, 4x4 stride 2, 3x3 stride 1) followed by
  a 512-unit hidden layer and a linear output layer producing
  `num_atoms * num_actions` values per observation.

  Ref: https://github.com/aviralkumar2907/CQL/
    file: atari/batch_rl/multi_head/atari_helpers.py
  """

  def __init__(
    self,
    num_actions: int,
    num_atoms: int,
    name: Optional[str] = "quantile_network",
  ):
    """Creates the layers of the network.

    Args:
      num_actions: Number of discrete actions.
      num_atoms: Number of quantile atoms predicted per action.
      name: Haiku module name.
    """
    super().__init__(name)
    self._num_actions = num_actions
    self._num_atoms = num_atoms
    self._activation_fn = jax.nn.relu
    # Same initializer for every layer: uniform variance scaling on fan-in.
    self._kernel_initializer = hk.initializers.VarianceScaling(
      scale=1 / np.sqrt(3), mode="fan_in", distribution="uniform"
    )
    # Defining layers.
    self._conv1 = hk.Conv2D(
      32,
      kernel_shape=(8, 8),
      stride=4,
      padding="SAME",
      w_init=self._kernel_initializer,
    )
    self._conv2 = hk.Conv2D(
      64,
      kernel_shape=(4, 4),
      stride=2,
      padding="SAME",
      w_init=self._kernel_initializer,
    )
    # The last convolution uses stride 1, matching the referenced
    # implementation (strides 4, 2, 1 across the three conv layers).
    self._conv3 = hk.Conv2D(
      64,
      kernel_shape=(3, 3),
      stride=1,
      padding="SAME",
      w_init=self._kernel_initializer,
    )
    self._dense1 = hk.Linear(512, w_init=self._kernel_initializer)
    self._dense2 = hk.Linear(
      num_atoms * num_actions, w_init=self._kernel_initializer
    )

  def __call__(self, observations: Array) -> Array:
    """Maps observations to flat quantile outputs.

    Args:
      observations: Image observations; they are divided by 255 before the
        first convolution (presumably raw pixel values in [0, 255] - confirm
        against the data pipeline).

    Returns:
      Array whose last dimension has size `num_atoms * num_actions`.
    """
    net = [
      lambda x: x / 255.0,
      self._conv1,
      self._activation_fn,
      self._conv2,
      self._activation_fn,
      self._conv3,
      self._activation_fn,
      # Negative preserve_dims: collapse the trailing three (H, W, C) dims.
      hk.Flatten(-3),
      self._dense1,
      self._activation_fn,
      self._dense2,
    ]
    logits = hk.Sequential(net)(observations)

    return logits


def make_networks(
  env_spec: specs.EnvironmentSpec,
  channels: int,
  num_atoms: int,
  output_init_scale: float,
  blocks_torso: int,
  blocks_qf: int,
  reduced_channels_head: int,
  fc_layers_qf: List[int],
  original: bool = True,
) -> CQLNetworks:
  """Builds the quantile Q-network bundle for the CQL agent.

  Args:
    env_spec: Environment spec; supplies the observation spec used to build
      dummy inputs and the number of discrete actions.
    channels: Passed to `Representation` and the prediction head.
    num_atoms: Number of quantile atoms per action.
    output_init_scale: Passed to the prediction head.
    blocks_torso: Passed to `Representation`.
    blocks_qf: Passed to the prediction head.
    reduced_channels_head: Passed to the prediction head.
    fc_layers_qf: Passed to the prediction head.
    original: If True, use `QuantileNetwork` and ignore the torso/head
      arguments above; otherwise use `Representation` plus a prediction head.

  Returns:
    A `CQLNetworks` holding the transformed Q-network and `env_spec`.
  """
  num_actions = env_spec.actions.num_values
  # Batched all-zero observation, only used to initialise the parameters.
  init_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))

  def _original_logits(obs: jnp.ndarray) -> jnp.ndarray:
    # Network mirroring the original CQL implementation.
    logging.info("[Network] Use CQL original network.")
    return QuantileNetwork(num_actions, num_atoms)(obs)

  def _torso_head_logits(obs: jnp.ndarray) -> jnp.ndarray:
    # Representation torso followed by a prediction head.
    state = Representation(channels, blocks_torso)(obs)
    head_layers = get_prediction_head(
      num_atoms * num_actions,
      channels,
      blocks_qf,
      reduced_channels_head,
      fc_layers_qf,
      output_init_scale,
    )
    return hk.Sequential(head_layers)(state)

  def _q_value_fn(obs: jnp.ndarray, *args, **kwargs) -> Dict[str, jnp.ndarray]:
    # Extra positional/keyword inputs are accepted but not used.
    del args, kwargs
    flat = _original_logits(obs) if original else _torso_head_logits(obs)

    # Split the flat output into (atoms, actions); per the original note,
    # rlax.quantile_q_learning expects this layout.
    target_shape = (flat.shape[0], num_atoms, num_actions)
    quantiles = jnp.reshape(flat, target_shape)  # (B, a, |A|)
    # Averaging over the atom axis gives one Q-value per action: (B, |A|).
    return {"q_dist": quantiles, "q_value": jnp.mean(quantiles, axis=1)}

  qf = hk.without_apply_rng(hk.transform(_q_value_fn))

  def _init(key):
    return qf.init(key, init_obs)

  return CQLNetworks(
    q_network=networks_lib.FeedForwardNetwork(_init, qf.apply),
    environment_specs=env_spec,
  )
