# coding=utf-8
# Copyright 2022 The Multi Task Atari Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Networks for offline RL agents with or without task conditioning."""

import collections
from flax import linen as nn
import gin
import jax.numpy as jnp
import jax
import time
import numpy as onp
from typing import Any, Tuple

import functools

# Output container returned by the networks in this module:
#   q_values: per-action Q-value estimates from the final Dense layer.
#   representation: the last hidden layer that feeds the Q-value layer.
#   representation_before_task: the flattened (optionally normalized) conv
#     features, taken before any task-id conditioning is applied.
# The typename matches the module attribute name so instances can be pickled
# (pickle resolves namedtuple classes by their name on the defining module).
NetworkType = collections.namedtuple(
    'NetworkType',
    ['q_values', 'representation', 'representation_before_task'])


def preprocess_atari_inputs(x):
  """Converts raw Atari 2600 pixel values to float32 and divides them by 255.

  Args:
    x: array of pixel values; must support `.astype` (presumably uint8 frames
      in [0, 255], which map to [0, 1] -- confirm with callers).

  Returns:
    `x` cast to float32 and scaled by a division by 255.
  """
  as_float = x.astype(jnp.float32)
  return as_float / 255.0


@gin.configurable
class Stack(nn.Module):
  """Conv stem, optional downsampling, then a chain of residual conv blocks.

  Attributes:
    num_ch: output channel count of every convolution in the stack.
    num_blocks: number of residual blocks applied after the stem.
    use_max_pooling: if True, a 3x3 max-pool with stride 2 follows the stem.
  """
  num_ch: int
  num_blocks: int
  use_max_pooling: bool = True

  @nn.compact
  def __call__(self, x):
    def conv3x3(inputs, **conv_kwargs):
      # Stride-1, SAME-padded 3x3 convolution producing `num_ch` channels.
      return nn.Conv(
          features=self.num_ch,
          kernel_size=(3, 3),
          strides=1,
          padding='SAME',
          **conv_kwargs)(inputs)

    # Only the stem convolution uses Xavier-uniform init; the residual convs
    # below keep the Flax default kernel initializer.
    out = conv3x3(x, kernel_init=nn.initializers.xavier_uniform())
    if self.use_max_pooling:
      out = nn.max_pool(
          out, window_shape=(3, 3), padding='SAME', strides=(2, 2))

    # Pre-activation residual blocks: relu -> conv -> relu -> conv, plus skip.
    for _ in range(self.num_blocks):
      skip = out
      hidden = conv3x3(nn.relu(out))
      hidden = conv3x3(nn.relu(hidden))
      out = hidden + skip

    return out


@gin.configurable
class ImpalaNetworkWithRepresentations(nn.Module):
  """Impala network that also returns intermediate representations.

  Three residual `Stack`s form the conv torso. Its ReLU'd, flattened output is
  optionally projected and L2-normalized, then optionally conditioned on a task
  id. Finally it is fed either to a single MLP head or to one MLP head per task.

  Attributes:
    num_actions: number of Q-values produced by the output layer.
    inputs_preprocessed: if False, inputs are passed through
      `preprocess_atari_inputs` first.
    nn_scale: multiplier on the channel count of every `Stack`.
    with_task_id: enables task-id conditioning when a `task_id` is given.
    tanh_layer_norm: despite the name, applies a 512-unit Dense layer followed
      by L2 normalization of its output (no tanh and no LayerNorm).
    use_multiple_heads: together with `with_task_id` and a `task_id`, builds
      one MLP head per `task_id` entry instead of tanh-gating the features.
  """
  num_actions: int
  inputs_preprocessed: bool = False
  nn_scale: int = 1
  with_task_id: bool = False

  tanh_layer_norm: bool = True
  use_multiple_heads: bool = False

  def setup(self):
    stack_sizes = (16, 32, 32)
    self.stack_blocks = [
        Stack(num_ch=size * self.nn_scale, num_blocks=2)
        for size in stack_sizes
    ]

  @nn.compact
  def __call__(self, x, task_id=None):
    """Runs the network.

    Args:
      x: observation frames. Every axis is flattened after the conv torso, so
        this presumably is a single unbatched observation (e.g. the network is
        applied under `jax.vmap`) -- TODO confirm with callers.
      task_id: optional task vector, used only when `with_task_id` is True. The
        multi-head path indexes it as `task_id[None, :]`, i.e. it expects a 1-D
        vector whose length sets the number of heads.

    Returns:
      NetworkType(q_values, representation, representation_before_task).
    """
    initializer = nn.initializers.xavier_uniform()
    if not self.inputs_preprocessed:
      x = preprocess_atari_inputs(x)
    conv_out = x

    for stack in self.stack_blocks:
      conv_out = stack(conv_out)

    conv_out = nn.relu(conv_out)
    conv_out = conv_out.reshape(-1)

    if self.tanh_layer_norm:
      # Despite the flag name (and the log message below), this is a 512-d
      # Dense bottleneck followed by L2 normalization; the epsilon avoids
      # division by zero for an all-zero projection.
      print('Adding a dimensional bottleneck and tanh like in DDPG')
      x = nn.Dense(features=512, kernel_init=initializer)(conv_out)
      x = x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + 1e-6)
      conv_out = x

    representation_before_task = conv_out

    if task_id is not None and self.with_task_id and not self.use_multiple_heads:
      # Gate the features multiplicatively with a tanh embedding of the task
      # id, then append a second, half-width linear embedding of the task id.
      x = conv_out
      print('Using task ids with tanh gating')
      task_id_embed = nn.Dense(features=x.shape[-1],
                               kernel_init=initializer)(task_id)
      task_id_embed = jnp.tanh(task_id_embed)
      print('task id embed shape: ', task_id_embed.shape)
      x = x * task_id_embed

      task_id_concat = nn.Dense(features=x.shape[-1] // 2,
                                kernel_init=initializer)(task_id)
      print('task id concat shape: ', task_id_concat.shape)
      conv_out = jnp.concatenate([x, task_id_concat], axis=-1)
      print('post flatten with task ids', conv_out.shape)

    if self.use_multiple_heads and self.with_task_id and task_id is not None:
      # One independent MLP head (two hidden layers) per task-id entry; the
      # head outputs are combined as a task_id-weighted sum.
      def get_representation(inp):
        hidden = nn.Dense(features=512, kernel_init=initializer)(inp)
        representation = nn.relu(hidden)
        representation = nn.Dense(
            features=512, kernel_init=initializer)(representation)
        representation = nn.relu(representation)
        q_values = nn.Dense(
            features=self.num_actions, kernel_init=initializer)(
                representation)
        return representation, q_values

      print('using multiple output heads')
      all_reps = []
      all_q_vals = []
      for _ in range(task_id.shape[-1]):
        rep, q_val = get_representation(conv_out)
        all_reps.append(rep)
        all_q_vals.append(q_val)

      print('stacked representations', len(all_reps), len(all_q_vals))

      # Heads are stacked along a new trailing axis.
      all_reps = jnp.stack(all_reps, axis=-1)
      all_q_vals = jnp.stack(all_q_vals, axis=-1)
      print('stacked all reps shape', all_reps.shape, all_q_vals.shape)
      # NOTE(review): presumably task_id is one-hot, so this selects a single
      # head; a soft task_id would blend heads -- confirm with callers.
      representation = jnp.sum(all_reps * task_id[None, :], axis=-1)
      q_values = jnp.sum(all_q_vals * task_id[None, :], axis=-1)
      print('representation shape', representation.shape, q_values.shape)
    else:
      # Single shared MLP head with three hidden layers.
      conv_out = nn.Dense(features=512, kernel_init=initializer)(conv_out)
      representation = nn.relu(conv_out)
      representation = nn.Dense(
          features=512, kernel_init=initializer)(representation)
      representation = nn.relu(representation)
      representation = nn.Dense(
          features=512, kernel_init=initializer)(representation)
      representation = nn.relu(representation)

      # Compute Q-values finally
      q_values = nn.Dense(
          features=self.num_actions, kernel_init=initializer)(
              representation)
    return NetworkType(q_values, representation,
                       representation_before_task=representation_before_task)


# Preconfigured multi-head variant: skips the Dense + L2-normalization
# bottleneck (`tanh_layer_norm=False`) and enables one MLP head per task
# (`use_multiple_heads=True`).
MultiHeadIMPALA = functools.partial(
    ImpalaNetworkWithRepresentations,
    tanh_layer_norm=False,
    use_multiple_heads=True,
)


def _conv_dimension_numbers(input_shape):
  """Returns channels-last `ConvDimensionNumbers` for an input of this rank.

  Input and output layouts are (batch, spatial..., features). The kernel
  layout is (spatial..., in_features, out_features). Only
  `len(input_shape)` is used.
  """
  rank = len(input_shape)
  batch_dim, feature_dim = 0, rank - 1
  spatial_dims = tuple(range(1, rank - 1))
  io_spec = (batch_dim, feature_dim) + spatial_dims
  kernel_spec = (rank - 1, rank - 2) + tuple(range(rank - 2))
  return jax.lax.ConvDimensionNumbers(io_spec, kernel_spec, io_spec)
