# coding=utf-8
# Copyright 2022 The Multi Task Atari Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Large networks for running multi-game Atari."""

import collections
from flax import linen as nn
import gin
import jax.numpy as jnp
import jax
import time

import functools

import numpy as onp

from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

# Loose type aliases used in annotations in this module. Most are plain `Any`
# and exist only to make signatures self-describing.
PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any
Array = Any
# Anything accepted as a `precision=` argument: nothing, a string, one
# `jax.lax.Precision`, or a pair of strings / precisions.
PrecisionLike = Union[
    None,
    str,
    jax.lax.Precision,
    Tuple[str, str],
    Tuple[jax.lax.Precision, jax.lax.Precision],
]

# Module-level LeCun-normal kernel initializer.
default_kernel_init = nn.initializers.lecun_normal()


# Result of `ResNetEncoder` when `use_distributional` is off.
#   q_values: one value per action.
#   representation: features fed into the Q head(s).
#   representation_before_task: pooled torso features, captured before task
#     gating and before any stop_gradient.
# The typename matches the module attribute so instances can be pickled
# (pickle resolves classes by `module.<typename>`) and reprs are unambiguous.
NetworkType = collections.namedtuple(
    'NetworkType',
    ['q_values', 'representation', 'representation_before_task']
)

# Result of `ResNetEncoder` when `use_distributional` (C51) is on. It adds the
# per-action atom `logits` and their softmax `probabilities`; `q_values` is
# the expectation of the support under those probabilities.
DiscretizedNetworkType = collections.namedtuple(
    'DiscretizedNetworkType',
    ['q_values', 'logits', 'probabilities',
     'representation', 'representation_before_task']
)



def preprocess_atari_inputs(x):
  """Scales raw pixel intensities into float32 values.

  Args:
    x: array exposing `.astype` (presumably uint8 frames in [0, 255] —
      confirm against the caller).

  Returns:
    `x` cast to float32 and divided by 255.
  """
  as_float = x.astype(jnp.float32)
  return as_float / 255.0


class MyGroupNorm(nn.GroupNorm):
  """`nn.GroupNorm` that also accepts a single rank-3 input.

  A rank-3 input gets a temporary leading axis of size 1 before
  normalization, and that axis is removed again afterwards. Inputs of any
  other rank go to `nn.GroupNorm` unchanged.

  NOTE(review): presumably the rank-3 input is one unbatched (H, W, C) map,
  and the extra axis keeps GroupNorm from treating H as a batch axis —
  confirm.
  """

  def __call__(self, x):
    if x.ndim != 3:
      return super().__call__(x)
    with_batch_axis = super().__call__(x[jnp.newaxis])
    return with_batch_axis[0]


@gin.configurable
class ResNetBlock(nn.Module):
  """Basic residual block: two 3x3 convolutions plus a shortcut.

  Attributes:
    filters: output channels of both 3x3 convolutions.
    conv: convolution constructor, called as conv(features, kernel[, strides]).
    norm: normalization constructor, called with keyword arguments only.
    act: activation applied inside the block and after the residual sum.
    strides: strides of the first 3x3 conv and of the shortcut projection.
  """
  filters: int
  conv: Any
  norm: Any
  act: Any
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    shortcut = x

    # Main path: conv -> norm -> act -> conv -> norm. Submodules are created
    # in this fixed order because Flax derives their auto-names from it.
    h = self.conv(self.filters, (3, 3), self.strides)(x)
    h = self.act(self.norm()(h))
    h = self.conv(self.filters, (3, 3))(h)
    h = self.norm()(h)

    # Project the shortcut with a 1x1 conv + norm when the main path changed
    # the shape (stride or channel count).
    if shortcut.shape != h.shape:
      shortcut = self.conv(self.filters, (1, 1), self.strides,
                           name='conv_proj')(shortcut)
      shortcut = self.norm(name='norm_proj')(shortcut)

    return self.act(shortcut + h)


@gin.configurable
class BottleneckResNetBlock(nn.Module):
  """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convs, 4x channel output.

  Attributes:
    filters: width of the two inner convolutions; the block emits
      `4 * filters` channels.
    conv: convolution constructor, called as conv(features, kernel[, strides]).
    norm: normalization constructor, called with keyword arguments only.
    act: activation function.
    strides: strides of the 3x3 conv and of the shortcut projection.
  """
  filters: int
  conv: Any
  norm: Any
  act: Any
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    shortcut = x
    out_channels = self.filters * 4

    # 1x1 conv to `filters` channels.
    h = self.conv(self.filters, (1, 1))(x)
    h = self.act(self.norm()(h))
    # 3x3 conv; carries the block's stride.
    h = self.conv(self.filters, (3, 3), self.strides)(h)
    h = self.act(self.norm()(h))
    # 1x1 conv to 4 * filters channels. The last norm's scale is
    # zero-initialized (commonly done so the residual branch starts small).
    h = self.conv(out_channels, (1, 1))(h)
    h = self.norm(scale_init=nn.initializers.zeros)(h)

    # Shortcut projection when the shape changed.
    if shortcut.shape != h.shape:
      shortcut = self.conv(out_channels, (1, 1), self.strides,
                           name='conv_proj')(shortcut)
      shortcut = self.norm(name='norm_proj')(shortcut)

    return self.act(shortcut + h)


@gin.configurable
class SpatialLearnedEmbeddings(nn.Module):
  """Learned spatial pooling of a convolutional feature map.

  Holds a kernel of shape (height, width, channel, num_features). Each of the
  `num_features` slices is multiplied elementwise with the feature map and
  summed over the two spatial axes. The result is `channel * num_features`
  values per example.

  Attributes:
    height: spatial height of the expected input.
    width: spatial width of the expected input.
    channel: channel count of the expected input.
    num_features: number of learned spatial kernels.
    param_dtype: dtype of the kernel parameter.
  """
  height: int
  width: int
  channel: int
  num_features: int = 5
  param_dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, features):
    """Pools `features` of shape (..., H, W, C) into (..., C * num_features).

    Leading batch axes are optional. An unbatched (H, W, C) input yields a
    flat (C * num_features,) vector, exactly as before. A batched
    (B, H, W, C) input now yields (B, C * num_features); previously the
    reduction ran over the batch axis and the result mixed examples.
    """
    kernel = self.param(
        'kernel', nn.initializers.xavier_uniform(),
        (self.height, self.width, self.channel, self.num_features),
        self.param_dtype)

    # (..., H, W, C, 1) * (H, W, C, F) -> (..., H, W, C, F); the kernel
    # broadcasts over any leading batch axes.
    weighted = jnp.expand_dims(features, -1) * kernel
    # Reduce H and W only, counted from the end so batch axes survive.
    pooled = jnp.sum(weighted, axis=(-4, -3))  # (..., C, F)
    # Flatten (C, F) row-major, matching the original unbatched ordering.
    return jnp.reshape(pooled, pooled.shape[:-2] + (-1,))


@gin.configurable
class ResNetEncoder(nn.Module):
  """ResNetV1 torso followed by configurable (task-conditioned) Q-value heads.

  Pipeline:
    1. Stem: a 7x7 conv, norm, ReLU and 3x3 max-pool.
    2. Stages: `stage_sizes` stages of `block_cls` blocks, optionally with
       multiplicative conditioning on `task_id`.
    3. Pooling: learned spatial embeddings or global average pooling, then an
       optional Dense(2048) + LayerNorm projection.
    4. Gating: optional tanh task-id gating.
    5. Heads: one of the head variants picked by the flags below.

  NOTE(review): the heads index `task_id[None, :]` and reshape per-example
  C51 outputs to (num_actions, num_atoms). The module therefore appears to
  process one unbatched example, with batching presumably done by the caller
  (e.g. via vmap) — confirm.
  """
  stage_sizes: Sequence[int]
  block_cls: Any
  num_actions: int

  num_filters: int = 64
  dtype: Any = jnp.float32
  act: Any = nn.relu
  conv: Any = nn.Conv
  norm: str = 'group'  # One of 'batch', 'group' or 'layer'.
  use_multiplicative_cond: bool = False
  use_spatial_learned_embeddings: bool = True
  num_spatial_blocks: int = 8
  with_task_id: bool = False
  inputs_preprocessed: bool = False

  use_multiple_heads: bool = False
  # Despite the name, no tanh is applied: this adds Dense(2048) + LayerNorm.
  tanh_layer_norm: bool = False

  single_layer_head: bool = False

  use_mixture_clustering: bool = False
  n_clusters: int = 20

  wide_factor: int = 1

  # Only used when using C51
  use_distributional: bool = False
  num_atoms: int = 51
  vmax: float = 20.0

  @nn.compact
  def __call__(self, x, task_id=None, training=True, stop_grad_rep=False,
               support=None):
    """Computes Q-values and representations for one observation.

    Args:
      x: observation. Cast to float32 and divided by 255 unless
        `inputs_preprocessed` is set.
      task_id: task encoding used for conditioning, gating and head
        selection. The multi-head variants build one head per entry of its
        last axis and weight them by it, so it is presumably a one-hot (or
        soft) vector over tasks — confirm.
      training: only consulted for `norm='batch'`; BatchNorm uses running
        averages when it is False.
      stop_grad_rep: if True, applies stop_gradient to the pooled features
        before task gating and the heads. `representation_before_task` is
        still returned without the stop.
      support: C51 atom locations. Defaults to
        linspace(-vmax, vmax, num_atoms). Only used with `use_distributional`.

    Returns:
      `DiscretizedNetworkType` if `use_distributional`, else `NetworkType`.

    Raises:
      ValueError: for an unknown `norm`, or if the mixture-clustering head is
        selected without a `task_id`.
    """
    initializer = nn.initializers.xavier_uniform()

    # Width of the flat Q head output: one logit per (action, atom) for C51.
    output_actions = self.num_actions
    if self.use_distributional:
      output_actions = self.num_actions * self.num_atoms

    def distributional_outputs(flat_q_values):
      """Turns flat C51 head outputs into (q_values, logits, probabilities)."""
      logits = flat_q_values.reshape((self.num_actions, self.num_atoms))
      probabilities = jax.nn.softmax(logits, axis=-1)
      atoms = support
      if atoms is None:
        atoms = jnp.linspace(-self.vmax, self.vmax, self.num_atoms)
      print('Support in network: ', atoms)
      # Expected return under each action's atom distribution.
      q_values = jnp.sum(atoms * probabilities, axis=-1)
      return q_values, logits, probabilities

    if not self.inputs_preprocessed:
      x = preprocess_atari_inputs(x)

    # Use kaiming_init for resnets
    conv = functools.partial(
        self.conv, use_bias=False,
        dtype=self.dtype, kernel_init=nn.initializers.kaiming_normal())

    if self.norm == 'batch':
      norm = functools.partial(
          nn.BatchNorm, use_running_average=not training, momentum=0.9,
          epsilon=1e-5, dtype=self.dtype)
    elif self.norm == 'group':
      norm = functools.partial(
          MyGroupNorm, num_groups=4, epsilon=1e-5, dtype=self.dtype)
    elif self.norm == 'layer':
      norm = functools.partial(nn.LayerNorm, epsilon=1e-5, dtype=self.dtype)
    else:
      raise ValueError('norm not found')

    # ---- Stem. ----
    print('input ', x.shape)
    # Stride schedule:
    #   strides[0] -> stem conv
    #   strides[1] -> max-pool
    #   strides[i + 1] -> first block of stage i, for i > 0
    # Stage 0 always uses stride 1. More than 4 stages would index past the
    # end of this tuple.
    strides = (2, 2, 2, 1, 1)
    x = conv(self.num_filters, (7, 7), (strides[0], strides[0]),
             padding=[(3, 3), (3, 3)],
             name='conv_init')(x)
    print('post conv1', x.shape)

    x = norm(name='bn_init')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(strides[1], strides[1]), padding='SAME')

    # ---- Residual stages. ----
    cond_var = task_id
    print('post maxpool1', x.shape)
    for i, block_size in enumerate(self.stage_sizes):
      for j in range(block_size):
        stride = (strides[i + 1], strides[i + 1]) if i > 0 and j == 0 else (1, 1)
        x = self.block_cls(self.num_filters * 2 ** i * self.wide_factor,
                           strides=stride,
                           conv=conv,
                           norm=norm,
                           act=self.act)(x)
        print('post block layer ', x.shape)
        if self.use_multiplicative_cond:
          assert cond_var is not None, "Cond var is None, nothing to condition on"
          print("Using Multiplicative Cond!")
          # Project the conditioning vector to the channel width and scale
          # every spatial position by it. The pre-ReLU projection is the
          # multiplier; the ReLU'd projection feeds the next block's Dense.
          cond_var = nn.Dense(x.shape[-1], kernel_init=initializer)(cond_var)
          x_mult = jnp.expand_dims(jnp.expand_dims(cond_var, 0), 0)
          cond_var = nn.relu(cond_var)
          print('x_mult shape:', x_mult.shape)
          x = x * x_mult
      print('post block ', x.shape)

    # ---- Pool the feature map into a vector. ----
    if self.use_spatial_learned_embeddings:
      # Learn spatial embeddings.
      height, width, channel = x.shape[len(x.shape) - 3:]
      print('pre spatial learned embeddings', x.shape)
      x = SpatialLearnedEmbeddings(
          height=height, width=width, channel=channel,
          num_features=self.num_spatial_blocks
      )(x)
      print('post spatial learned embeddings', x.shape)
    else:
      # Global average pooling over the two spatial axes.
      x = jnp.mean(x, axis=(len(x.shape) - 3, len(x.shape) - 2))
      print('post flatten', x.shape)

    if self.tanh_layer_norm:
      # No tanh despite the flag name: Dense(2048) followed by LayerNorm.
      print('Adding a dimensional bottleneck and tanh like in DDPG')
      x = nn.Dense(features=2048, kernel_init=initializer)(x)
      x = nn.LayerNorm()(x)
    representation_before_task = x

    if stop_grad_rep:
      x = jax.lax.stop_gradient(x)

    if task_id is not None and self.with_task_id and not self.use_multiple_heads:
      # Gate the features with tanh(Dense(task_id)), then append a task
      # embedding half the feature width.
      print('Using task ids with tanh gating')
      task_id_embed = nn.Dense(features=x.shape[-1],
                               kernel_init=initializer)(task_id)
      task_id_embed = jnp.tanh(task_id_embed)
      print('task id embed shape: ', task_id_embed.shape)
      x = x * task_id_embed

      task_id_concat = nn.Dense(features=int(x.shape[-1] // 2),
                                kernel_init=initializer)(task_id)
      print('task id concat shape: ', task_id_concat.shape)
      x = jnp.concatenate([x, task_id_concat], axis=-1)
      print('post flatten with task ids', x.shape)

    # ---- Q-value heads. ----
    # Branch order matters: with a task_id, `single_layer_head` wins over
    # `use_mixture_clustering`.
    if (self.use_multiple_heads and self.with_task_id and
        task_id is not None and not self.use_mixture_clustering and
        not self.single_layer_head):
      # One independent MLP (two 512-unit layers) plus a linear head per
      # task. Outputs are mixed using task_id as weights.
      def get_representation(inp, idx=0):
        x = nn.Dense(features=512, kernel_init=initializer)(inp)
        x = nn.relu(x)
        x = nn.Dense(features=512, kernel_init=initializer)(x)
        representation = nn.relu(x)  # Use penultimate layer as representation
        q_values = nn.Dense(
            features=output_actions, kernel_init=initializer,
            name='last_layer_weight_' + str(idx))(
                representation)
        return representation, q_values

      print('using multiple output heads')
      all_reps = []
      all_q_vals = []
      for idx in range(task_id.shape[-1]):
        rep, q_val = get_representation(x, idx=idx)
        all_reps.append(rep)
        all_q_vals.append(q_val)

      print('stacked representations', len(all_reps), len(all_q_vals))

      all_reps = jnp.stack(all_reps, axis=-1)
      all_q_vals = jnp.stack(all_q_vals, axis=-1)
      print('stacked all reps shape', all_reps.shape, all_q_vals.shape)
      representation = jnp.sum(all_reps * task_id[None, :], axis=-1)
      q_values = jnp.sum(all_q_vals * task_id[None, :], axis=-1)

      if self.use_distributional:
        q_values, logits, probabilities = distributional_outputs(q_values)

      print('representation shape', representation.shape, q_values.shape)

    elif (self.use_multiple_heads and self.single_layer_head and
          self.with_task_id and task_id is not None):
      # Shared trunk: two 1024-unit layers, L2-normalized. Then one linear
      # head per task, mixed using task_id as weights.
      print('using multiple output heads, but with a single layer')
      all_q_vals = []
      x = nn.Dense(features=1024, kernel_init=initializer)(x)
      x = nn.relu(x)
      x = nn.Dense(features=1024, kernel_init=initializer)(x)
      representation = nn.relu(x)  # Use penultimate layer as representation
      representation = representation / (
          jnp.linalg.norm(representation, axis=-1, keepdims=True) + 1e-6)

      for idx in range(task_id.shape[-1]):
        q_val = nn.Dense(
            features=output_actions, kernel_init=initializer,
            name='last_layer_weight_' + str(idx))(representation)
        all_q_vals.append(q_val)

      print('stacked q_values', len(all_q_vals))

      all_q_vals = jnp.stack(all_q_vals, axis=-1)
      print('stacked all q_vals shape', all_q_vals.shape)
      q_values = jnp.sum(all_q_vals * task_id[None, :], axis=-1)

      if self.use_distributional:
        q_values, logits, probabilities = distributional_outputs(q_values)

      print('representation shape', representation.shape, q_values.shape)

    elif (self.use_multiple_heads and self.with_task_id and
          self.use_mixture_clustering):
      # `n_clusters` shared heads, mixed by softmax attention weights
      # computed from the features concatenated with the task id.
      if task_id is None:
        raise ValueError('use_mixture_clustering requires a task_id.')

      def get_representation(inp):
        x = nn.Dense(features=512, kernel_init=initializer)(inp)
        x = nn.relu(x)
        x = nn.Dense(features=512, kernel_init=initializer)(x)
        representation = nn.relu(x)  # Use penultimate layer as representation
        representation = representation / (
            jnp.linalg.norm(representation, axis=-1, keepdims=True) + 1e-6)
        # Sized by output_actions so the C51 variant also works here.
        q_values = nn.Dense(
            features=output_actions, kernel_init=initializer)(
                representation)
        return representation, q_values

      def map_task_id_to_head(x, task_id):
        tmp_x = jnp.concatenate([x, task_id], axis=-1)
        x = nn.Dense(features=512, kernel_init=initializer)(tmp_x)
        x = nn.relu(x)
        x = nn.Dense(features=self.n_clusters, kernel_init=initializer)(x)
        return jax.nn.softmax(x, axis=-1)

      print('using multiple output heads with attention')
      all_reps = []
      all_q_vals = []
      for _ in range(self.n_clusters):
        rep, q_val = get_representation(x)
        all_reps.append(rep)
        all_q_vals.append(q_val)

      print('stacked representations', len(all_reps), len(all_q_vals))
      attention_wt = map_task_id_to_head(x, task_id)
      all_reps = jnp.stack(all_reps, axis=-1)
      all_q_vals = jnp.stack(all_q_vals, axis=-1)
      print('stacked all reps shape', all_reps.shape, all_q_vals.shape, attention_wt.shape)
      representation = jnp.sum(all_reps * attention_wt[None, :], axis=-1)
      q_values = jnp.sum(all_q_vals * attention_wt[None, :], axis=-1)

      if self.use_distributional:
        # Without this, `logits`/`probabilities` would be unbound at return.
        q_values, logits, probabilities = distributional_outputs(q_values)

      print('representation shape', representation.shape, q_values.shape, attention_wt.shape)
    else:
      # Default head: three 512-unit ReLU layers, then one linear Q head.
      x = nn.Dense(features=512, kernel_init=initializer)(x)
      x = nn.relu(x)
      x = nn.Dense(features=512, kernel_init=initializer)(x)
      x = nn.relu(x)
      x = nn.Dense(features=512, kernel_init=initializer)(x)
      representation = nn.relu(x)

      # Compute final Q-values here...
      q_values = nn.Dense(
          features=output_actions, kernel_init=initializer,
          name='last_layer_weight')(
              representation)

      if self.use_distributional:
        q_values, logits, probabilities = distributional_outputs(q_values)
        print('q_values.shape: ', logits.shape, q_values.shape)

    if self.use_distributional:
      return DiscretizedNetworkType(
          q_values, logits, probabilities,
          representation, representation_before_task)

    return NetworkType(q_values, representation,
                       representation_before_task)


# Single-head ResNet encoders. Only the depth (`stage_sizes`) and the block
# type differ: 18/34 use basic blocks, 50/101/152 use bottleneck blocks.
ResNet18 = functools.partial(
    ResNetEncoder,
    stage_sizes=(2, 2, 2, 2),
    block_cls=ResNetBlock,
    use_spatial_learned_embeddings=True,
    tanh_layer_norm=False,
    use_multiplicative_cond=False,
)
ResNet34 = functools.partial(
    ResNetEncoder,
    stage_sizes=(3, 4, 6, 3),
    block_cls=ResNetBlock,
    use_spatial_learned_embeddings=True,
    tanh_layer_norm=False,
    use_multiplicative_cond=False,
)
ResNet50 = functools.partial(
    ResNetEncoder,
    stage_sizes=(3, 4, 6, 3),
    block_cls=BottleneckResNetBlock,
    use_spatial_learned_embeddings=True,
    tanh_layer_norm=False,
    use_multiplicative_cond=False,
)
ResNet101 = functools.partial(
    ResNetEncoder,
    stage_sizes=(3, 4, 23, 3),
    block_cls=BottleneckResNetBlock,
    use_spatial_learned_embeddings=True,
    tanh_layer_norm=False,
    use_multiplicative_cond=False,
)
ResNet152 = functools.partial(
    ResNetEncoder,
    stage_sizes=(3, 8, 36, 3),
    block_cls=BottleneckResNetBlock,
    use_spatial_learned_embeddings=True,
    tanh_layer_norm=False,
    use_multiplicative_cond=False,
)

# Multi-head encoders: one linear Q head per task on a shared trunk
# (`single_layer_head=True`), plus a Dense(2048) + LayerNorm bottleneck
# (`tanh_layer_norm=True`).
# NOTE(review): `with_task_id` is not set here, and it defaults to False in
# ResNetEncoder. The per-task heads are only used when the caller (or gin)
# also sets with_task_id=True and passes a task_id — confirm.
MultiHeadResNet18 = functools.partial(
    ResNetEncoder,
    stage_sizes=(2, 2, 2, 2),
    block_cls=ResNetBlock,
    use_spatial_learned_embeddings=True,
    use_multiple_heads=True,
    use_multiplicative_cond=False,
    single_layer_head=True,
    tanh_layer_norm=True,
)
MultiHeadResNet34 = functools.partial(
    ResNetEncoder,
    stage_sizes=(3, 4, 6, 3),
    block_cls=ResNetBlock,
    use_spatial_learned_embeddings=True,
    use_multiple_heads=True,
    use_multiplicative_cond=False,
    single_layer_head=True,
    tanh_layer_norm=True,
)
MultiHeadResNet50 = functools.partial(
    ResNetEncoder,
    stage_sizes=(3, 4, 6, 3),
    block_cls=BottleneckResNetBlock,
    use_spatial_learned_embeddings=True,
    use_multiple_heads=True,
    use_multiplicative_cond=False,
    single_layer_head=True,
    tanh_layer_norm=True,
)
MultiHeadResNet101 = functools.partial(
    ResNetEncoder,
    stage_sizes=(3, 4, 23, 3),
    block_cls=BottleneckResNetBlock,
    use_spatial_learned_embeddings=True,
    use_multiple_heads=True,
    use_multiplicative_cond=False,
    single_layer_head=True,
    tanh_layer_norm=True,
)

# C51 (distributional) variants of the multi-head encoders. Each Q head
# emits num_atoms=51 logits per action over a support in [-20, 20]
# (vmax=20.0).
# NOTE(review): like the scalar multi-head presets, these leave
# `with_task_id` at its default (False) — confirm it is set by the caller.
DistributionalMultiHeadResNet18 = functools.partial(
    ResNetEncoder,
    stage_sizes=(2, 2, 2, 2),
    block_cls=ResNetBlock,
    use_spatial_learned_embeddings=True,
    use_multiple_heads=True,
    use_multiplicative_cond=False,
    single_layer_head=True,
    tanh_layer_norm=True,
    use_distributional=True,
    vmax=20.0,
    num_atoms=51,
)
DistributionalMultiHeadResNet34 = functools.partial(
    ResNetEncoder,
    stage_sizes=(3, 4, 6, 3),
    block_cls=ResNetBlock,
    use_spatial_learned_embeddings=True,
    use_multiple_heads=True,
    use_multiplicative_cond=False,
    single_layer_head=True,
    tanh_layer_norm=True,
    use_distributional=True,
    vmax=20.0,
    num_atoms=51,
)
DistributionalMultiHeadResNet50 = functools.partial(
    ResNetEncoder,
    stage_sizes=(3, 4, 6, 3),
    block_cls=BottleneckResNetBlock,
    use_spatial_learned_embeddings=True,
    use_multiple_heads=True,
    use_multiplicative_cond=False,
    single_layer_head=True,
    tanh_layer_norm=True,
    use_distributional=True,
    vmax=20.0,
    num_atoms=51,
)
DistributionalMultiHeadResNet101 = functools.partial(
    ResNetEncoder,
    stage_sizes=(3, 4, 23, 3),
    block_cls=BottleneckResNetBlock,
    use_spatial_learned_embeddings=True,
    use_multiple_heads=True,
    use_multiplicative_cond=False,
    single_layer_head=True,
    tanh_layer_norm=True,
    use_distributional=True,
    vmax=20.0,
    num_atoms=51,
)