# coding=utf-8
# Copyright 2021 The Circuit Training Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""New circuittraining Model for generalization."""
import sys
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Text
from typing import Union

import gin
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.networks import network
from tf_agents.specs import distribution_spec
from tf_agents.specs import tensor_spec
from tf_agents.typing import types
from tf_agents.utils import nest_utils

from a2perf.domains.circuit_training.circuit_training.environment import \
  observation_config as observation_config_lib
from .circuit_training import fully_connected_model_lib


def create_circuit_training_ppo_models_fn(
    rl_architecture: str,
    observation_tensor_spec: types.NestedTensorSpec,
    action_tensor_spec: types.NestedTensorSpec,
    static_features: Dict[str, np.ndarray],
    use_model_tpu: bool = False,
    seed: int = 0,
) -> tuple[Any, Any]:
  """Builds the actor and value networks used by PPO.

  Only the exact string 'generalization' selects the GRL actor/critic pair;
  every other value of `rl_architecture` falls back to the fully connected
  networks. `static_features`, `use_model_tpu` and `seed` are consumed by the
  'generalization' path only.

  Args:
    rl_architecture: The RL architecture.
    observation_tensor_spec: Env observation spec.
    action_tensor_spec: Env action spec.
    static_features: Env static features.
    use_model_tpu: Use TPU model.
    seed: Random seed.

  Returns:
    Tuple of actor_net, value_net.
  """
  if rl_architecture != 'generalization':
    # Fully connected fallback: the actor is created before the critic.
    fc_actor = fully_connected_model_lib.create_actor_net(
        observation_tensor_spec, action_tensor_spec
    )
    fc_critic = fully_connected_model_lib.create_value_net(
        observation_tensor_spec
    )
    return fc_actor, fc_critic

  grl_actor, grl_critic = create_grl_actor_critic_models(
      observation_tensor_spec,
      action_tensor_spec,
      static_features,
      use_model_tpu=use_model_tpu,
      seed=seed,
  )
  return grl_actor, grl_critic


def create_circuit_training_dqn_models_fn(
    rl_architecture: str,
    observation_tensor_spec: types.NestedTensorSpec,
    action_tensor_spec: types.NestedTensorSpec,
    static_features: Dict[str, np.ndarray],
    use_model_tpu: bool = False,
    seed: int = 0,
) -> Any:
  """Builds the Q-network used by DQN.

  Only the exact string 'generalization' selects the GRL Q-model; every other
  value of `rl_architecture` falls back to the fully connected Q-network, which
  ignores `static_features`, `use_model_tpu` and `seed`.

  Args:
    rl_architecture: The RL architecture.
    observation_tensor_spec: Env observation spec.
    action_tensor_spec: Env action spec.
    static_features: Env static features.
    use_model_tpu: Use TPU model.
    seed: Random seed.

  Returns:
    DQN network.
  """
  if rl_architecture != 'generalization':
    return fully_connected_model_lib.create_q_network(
        observation_tensor_spec, action_tensor_spec
    )

  return create_grl_q_models(
      observation_tensor_spec,
      action_tensor_spec,
      static_features,
      use_model_tpu=use_model_tpu,
      seed=seed,
  )


# Reimplements internal function
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/smart_cond.py.
def smart_cond(
    pred: Union[bool, tf.Tensor],
    true_fn: Callable[[], tf.Tensor],
    false_fn: Callable[[], tf.Tensor],
    name: Optional[str] = None,
) -> tf.Tensor:
  """Returns `true_fn()` if predicate `pred` is true, else `false_fn()`.

  The branch is picked in Python (no `tf.cond` op is created) only when `pred`
  is not a `tf.Tensor` and `tf.get_static_value` can resolve it, e.g. a plain
  Python bool. Any `tf.Tensor` predicate, even one with a constant value, and
  any predicate without a static value is routed through `tf.cond`.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  if not callable(true_fn):
    raise TypeError('`true_fn` must be callable.')
  if not callable(false_fn):
    raise TypeError('`false_fn` must be callable.')

  static_pred = tf.get_static_value(pred)
  resolved_in_python = (
      not isinstance(pred, tf.Tensor) and static_pred is not None
  )
  if not resolved_in_python:
    return tf.cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)

  return true_fn() if static_pred else false_fn()


@gin.configurable
class CircuitTrainingModel(tf.keras.layers.Layer):
  """GCN-based model for circuit training.

  The model encodes netlist metadata and per-node features, runs a stack of
  edge-centric GCN layers, attends over all nodes with respect to the current
  node, and feeds the concatenated embedding to a deconvolutional policy head
  (location logits) and, optionally, to a value head.
  """

  # Small constant guarding the mean aggregation against division by zero and
  # the noised policy against log(0).
  EPSILON = 1e-6

  def __init__(
      self,
      all_static_features: Optional[dict[str, np.ndarray]] = None,
      observation_config: Optional[
        observation_config_lib.ObservationConfig
      ] = None,
      num_gcn_layers: int = 3,
      edge_fc_layers: int = 1,
      gcn_node_dim: int = 8,
      dirichlet_alpha: float = 0.1,
      policy_noise_weight: float = 0.0,
      is_augmented: bool = False,
      seed: int = 0,
      include_min_max_var: bool = True,
      use_value_head: bool = True,
  ):
    """Builds the circuit training model.

    Args:
      all_static_features: the static features keyed by the feature name. When
        empty or None, static features are read from the call inputs instead.
      observation_config: Optional observation config. A default
        `ObservationConfig` is created when omitted.
      num_gcn_layers: Number of GCN layers. Must be at least 1 because the
        forward pass pools the edge embeddings of the last GCN layer.
      edge_fc_layers: Number of fully connected layers in the GCN kernel.
      gcn_node_dim: Node embedding dimension.
      dirichlet_alpha: Dirichlet concentration value.
      policy_noise_weight: Weight of the noise added to policy.
      is_augmented: Whether the model uses augmented inputs.
      seed: Seed for sampling noise.
      include_min_max_var: If set include reduce_ min, max, and variance of all
        edges beside the reduce_mean.
      use_value_head: Whether to build the value head. When False, `call`
        returns None in place of the value.

    Raises:
      ValueError: If `num_gcn_layers` is smaller than 1.
    """
    super(CircuitTrainingModel, self).__init__()
    if num_gcn_layers < 1:
      # `call` reads the edge embeddings produced by the last GCN layer, so a
      # model without any GCN layer could never run; fail at construction.
      raise ValueError(
          f'num_gcn_layers must be at least 1, got {num_gcn_layers}.'
      )

    self._num_gcn_layers = num_gcn_layers
    self._gcn_node_dim = gcn_node_dim
    self._dirichlet_alpha = dirichlet_alpha
    self._policy_noise_weight = policy_noise_weight
    self._is_augmented = is_augmented
    self._seed = seed
    self._use_value_head = use_value_head
    self._include_min_max_var = include_min_max_var
    self._all_static_features = all_static_features
    self._observation_config = (
        observation_config or observation_config_lib.ObservationConfig()
    )

    # A single seeded initializer instance is shared by every layer below.
    seed_stream = tfp.util.SeedStream(
        self._seed, salt='kernel_initializer_seed'
    )
    kernel_initializer = tf.keras.initializers.glorot_normal(
        seed=seed_stream() % sys.maxsize
    )

    self._metadata_encoder = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(
                self._gcn_node_dim, kernel_initializer=kernel_initializer
            ),
            tf.keras.layers.ReLU(),
        ],
        name='metadata_encoder',
    )
    self._feature_encoder = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(
                self._gcn_node_dim, kernel_initializer=kernel_initializer
            ),
            tf.keras.layers.ReLU(),
        ],
        name='feature_encoder',
    )

    # Edge-centric GCN layers.
    def create_edge_fc(name=None) -> tf.keras.layers.Layer:
      """Returns a stack of `edge_fc_layers` Dense+ReLU pairs."""
      seq = tf.keras.Sequential(name=name)
      for _ in range(edge_fc_layers):
        seq.add(
            tf.keras.layers.Dense(
                self._gcn_node_dim, kernel_initializer=kernel_initializer
            )
        )
        seq.add(tf.keras.layers.ReLU())
      return seq

    self._edge_fc_list = [
        create_edge_fc(name='edge_fc_%d' % i)
        for i in range(self._num_gcn_layers)
    ]

    # Dot-product attention layer, a.k.a. Luong-style attention [1].
    # [1] Luong, et al, 2015.
    self._attention_layer = tf.keras.layers.Attention(name='attention_layer')
    self._attention_query_layer = tf.keras.layers.Dense(
        self._gcn_node_dim,
        name='attention_query_layer',
        kernel_initializer=kernel_initializer,
    )
    self._attention_key_layer = tf.keras.layers.Dense(
        self._gcn_node_dim,
        name='attention_key_layer',
        kernel_initializer=kernel_initializer,
    )
    self._attention_value_layer = tf.keras.layers.Dense(
        self._gcn_node_dim,
        name='attention_value_layer',
        kernel_initializer=kernel_initializer,
    )

    if use_value_head:
      self._value_head = tf.keras.Sequential(
          [
              tf.keras.layers.Dense(32, kernel_initializer=kernel_initializer),
              tf.keras.layers.ReLU(),
              tf.keras.layers.Dense(8, kernel_initializer=kernel_initializer),
              tf.keras.layers.ReLU(),
              tf.keras.layers.Dense(1, kernel_initializer=kernel_initializer),
          ],
          name='value_head',
      )

    if self._is_augmented:
      self._augmented_embedding_layer = tf.keras.layers.Dense(
          self._gcn_node_dim,
          name='augmented_embedding_layer',
          kernel_initializer=kernel_initializer,
      )

    # Side length of the coarse grid the policy head starts from. The four
    # stride-2 transposed convolutions below upsample it by 16x.
    # The Dense width is derived from this single value so that it always
    # matches the Reshape target. The previous unparenthesized expression
    # `g // 16 * g // 16 * 32` evaluated as `((g // 16) * g) // 16 * 32`, which
    # only equals `(g // 16) ** 2 * 32` when g is a multiple of 16.
    grid_cells = self._observation_config.max_grid_size // 16

    # GAN-like deconv layers to generated the policy image.
    # See figures in http://shortn/_9HCSFwasnu.
    # The shape comments below assume max_grid_size == 128.
    self._policy_location_head = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(
                grid_cells * grid_cells * 32,
                kernel_initializer=kernel_initializer,
            ),
            # 128/16*128/16*32 = 8*8*32
            tf.keras.layers.ReLU(),
            tf.keras.layers.Reshape(target_shape=(grid_cells, grid_cells, 32)),
            # 8x8x32
            tf.keras.layers.Conv2DTranspose(
                filters=16,
                kernel_size=3,
                strides=2,
                padding='same',
                kernel_initializer=kernel_initializer,
            ),
            # 16x16x16
            tf.keras.layers.ReLU(),
            tf.keras.layers.Conv2DTranspose(
                filters=8,
                kernel_size=3,
                strides=2,
                padding='same',
                kernel_initializer=kernel_initializer,
            ),
            # 32x32x8
            tf.keras.layers.ReLU(),
            tf.keras.layers.Conv2DTranspose(
                filters=4,
                kernel_size=3,
                strides=2,
                padding='same',
                kernel_initializer=kernel_initializer,
            ),
            # 64x64x4
            tf.keras.layers.ReLU(),
            tf.keras.layers.Conv2DTranspose(
                filters=2,
                kernel_size=3,
                strides=2,
                padding='same',
                kernel_initializer=kernel_initializer,
            ),
            # 128x128x2
            tf.keras.layers.ReLU(),
            # No activation.
            tf.keras.layers.Conv2DTranspose(
                filters=1,
                kernel_size=3,
                strides=1,
                padding='same',
                kernel_initializer=kernel_initializer,
            ),
            # 128x128x1
            tf.keras.layers.Flatten(),
        ],
        name='policy_location_head',
    )

  def _scatter_count(
      self, edge_h: tf.Tensor, indices: tf.Tensor
  ) -> tuple[tf.Tensor, tf.Tensor]:
    """Aggregate (reduce sum) edge embeddings to nodes.

    Args:
      edge_h: A [-1, #edges, h] tensor of edge embeddings.
      indices: A [-1, #edges] tensor of node index of an edge (sparse adjacency
        indices).

    Returns:
      A [-1, #nodes, h] tensor of aggregated node embeddings and a
      [-1, #nodes, h] tensor holding, in every latent channel, the number of
      edges scattered to each node (used for finding the mean). #nodes is
      `observation_config.max_num_nodes`.
    """
    batch = tf.shape(edge_h)[0]
    num_items = tf.shape(edge_h)[1]
    num_latents = edge_h.shape[2]

    h_node = tf.zeros(
        [batch, self._observation_config.max_num_nodes, num_latents]
    )
    count_edge = tf.zeros_like(h_node)
    count = tf.ones_like(edge_h)

    # Pair every node index with its batch index so a single scatter handles
    # the whole batch: idx[b, e] == (b, indices[b, e]).
    b_indices = tf.tile(
        tf.expand_dims(tf.range(0, tf.cast(batch, dtype=tf.int32)), -1),
        [1, num_items],
    )
    idx = tf.stack([b_indices, indices], axis=-1)
    h_node = tf.tensor_scatter_nd_add(h_node, idx, edge_h)
    count_edge = tf.tensor_scatter_nd_add(count_edge, idx, count)

    return h_node, count_edge

  def gather_to_edges(
      self,
      h_nodes: tf.Tensor,
      sparse_adj_i: tf.Tensor,
      sparse_adj_j: tf.Tensor,
      sparse_adj_weight: tf.Tensor,
  ) -> tuple[tf.Tensor, tf.Tensor]:
    """Gathers node embeddings to edges.

    For each edge, there are two node embeddings. It concats them together with
    the edge weight. It also masks the output with 0 for edges with no weight.

    Args:
      h_nodes: A [-1, #node, h] tensor of node embeddings.
      sparse_adj_i: A [-1, #edges] tensor for the 1st node index of an edge.
      sparse_adj_j: A [-1, #edges] tensor for the 2nd node index of an edge.
      sparse_adj_weight: A [-1, #edges, 1] tensor for the weight of an edge
        (`call` expands the trailing axis before passing it in). 0 for fake
        padded edges.

    Returns:
      A pair of [-1, #edges, 2*h+1] tensors of edge embeddings: the first
      concatenates (node i, node j, weight), the second (node j, node i,
      weight).
    """

    h_edges_1 = tf.gather(h_nodes, sparse_adj_i, batch_dims=1)
    h_edges_2 = tf.gather(h_nodes, sparse_adj_j, batch_dims=1)
    h_edges_12 = tf.concat([h_edges_1, h_edges_2, sparse_adj_weight], axis=-1)
    h_edges_21 = tf.concat([h_edges_2, h_edges_1, sparse_adj_weight], axis=-1)
    # Both orientations have the same shape, so one mask serves both.
    mask = tf.broadcast_to(
        tf.not_equal(sparse_adj_weight, 0.0), tf.shape(h_edges_12)
    )
    h_edges_i_j = tf.where(mask, h_edges_12, tf.zeros_like(h_edges_12))
    h_edges_j_i = tf.where(mask, h_edges_21, tf.zeros_like(h_edges_21))
    return h_edges_i_j, h_edges_j_i

  def scatter_to_nodes(
      self, h_edges: tf.Tensor, sparse_adj_i: tf.Tensor, sparse_adj_j: tf.Tensor
  ) -> tf.Tensor:
    """Scatters edge embeddings to nodes via mean aggregation.

    For each node, it aggregates the embeddings of all the connected edges by
    averaging them.

    Args:
      h_edges: A [-1, #edges, h] tensor of edge embeddings.
      sparse_adj_i: A [-1, #edges] tensor for the 1st node index of an edge.
      sparse_adj_j: A [-1, #edges] tensor for the 2nd node index of an edge.

    Returns:
      A [-1, #node, h] tensor of node embeddings.
    """
    h_nodes_1, count_1 = self._scatter_count(h_edges, sparse_adj_i)
    h_nodes_2, count_2 = self._scatter_count(h_edges, sparse_adj_j)
    # EPSILON keeps nodes without any edge from dividing by zero.
    return (h_nodes_1 + h_nodes_2) / (count_1 + count_2 + self.EPSILON)

  def self_attention(
      self,
      h_current_node: tf.Tensor,
      h_nodes: tf.Tensor,
      training: bool = False,
  ) -> tf.Tensor:
    """Returns self-attention wrt to the current node.

    Args:
      h_current_node: A [-1, 1, h] tensor of the current node embedding.
      h_nodes: A [-1, #nodes, h] tensor of all node embeddings.
      training: Set in the training mode.

    Returns:
      A [-1, h] tensor of the weighted average of the node embeddings where
      the weight is the attention score with respect to the current node.
    """
    query = self._attention_query_layer(h_current_node, training=training)
    values = self._attention_value_layer(h_nodes, training=training)
    keys = self._attention_key_layer(h_nodes, training=training)
    # The attention layer receives its inputs as [query, value, key].
    h_attended = self._attention_layer([query, values, keys], training=training)
    h_attended = tf.squeeze(h_attended, axis=1)
    return h_attended

  def add_noise(self, logits: tf.Tensor) -> tf.Tensor:
    """Adds a non-trainable dirichlet noise to the policy.

    Args:
      logits: The policy logits; softmax is applied over the last axis.

    Returns:
      The log of the mixture `(1 - w) * softmax(logits) + w * noise`, where `w`
      is the policy noise weight, with EPSILON added before the log.
    """
    seed_stream = tfp.util.SeedStream(self._seed, salt='noise_seed')

    probs = tf.nn.softmax(logits)
    alphas = tf.fill(tf.shape(probs), self._dirichlet_alpha)
    dirichlet_distribution = tfp.distributions.Dirichlet(alphas)
    noise = dirichlet_distribution.sample(seed=seed_stream() % sys.maxsize)
    noised_probs = (1.0 - self._policy_noise_weight) * probs + (
        self._policy_noise_weight
    ) * noise

    noised_logit = tf.math.log(noised_probs + self.EPSILON)

    return noised_logit

  def _get_static_input(
      self, static_feature_key: str, inputs: dict[str, tf.Tensor]
  ) -> tf.Tensor:
    """Returns the tensor for a particular static feature.

    When the model was built without static features, the feature is read
    straight from `inputs`. Otherwise it is looked up in the stored static
    features using `inputs['netlist_index']`.

    Args:
      static_feature_key: a feature key defined in
        observation_config_lib.STATIC_OBSERVATIONS
      inputs: the dictionary of input features.

    Returns:
      A tensor for the static feature.

    Raises:
      ValueError: If static features were provided at construction but do not
        contain `static_feature_key`.
    """
    if not self._all_static_features:
      return inputs[static_feature_key]

    if static_feature_key not in self._all_static_features:
      raise ValueError(f'Static feature {static_feature_key} not found.')

    netlist_index = inputs['netlist_index']
    netlist_index = tf.squeeze(netlist_index, axis=-1)
    # Cap the index with the size of the number of static features.
    # In the collect jobs, we use only one static feature, but we have to
    # index the env with its index in the train job. The cap is added, so we
    # don't need to change the index for collect jobs for local use and in
    # replay buffer.
    netlist_index = tf.cast(
        tf.minimum(
            tf.cast(netlist_index, dtype=tf.float32),
            tf.cast(
                self._all_static_features[static_feature_key].shape[0] - 1,
                dtype=tf.float32,
            ),
        ),
        dtype=tf.int32,
    )
    return tf.gather(
        self._all_static_features[static_feature_key], netlist_index
    )

  def call(
      self,
      inputs: dict[str, tf.Tensor],
      training: bool = False,
      is_eval: bool = False,
      finetune_value_only: bool = False,
  ) -> tuple[dict[str, tf.Tensor], Optional[tf.Tensor]]:
    """Runs the forward pass.

    Args:
      inputs: The observation dictionary. Reads 'is_node_placed',
        'locations_x', 'locations_y' and 'current_node', plus
        'augmented_features' when the model is augmented. Static features are
        resolved through `_get_static_input`.
      training: Set in the training mode.
      is_eval: When True the policy logits are returned without the dirichlet
        noise.
      finetune_value_only: When True, gradients are stopped at the shared
        embedding and at the policy logits.

    Returns:
      A tuple of the logits dictionary (key 'location') and the value tensor,
      or None in place of the value when the model has no value head.
    """
    # Netlist metadata.
    netlist_metadata_inputs = [
        self._get_static_input(key, inputs)
        # pytype: disable=wrong-arg-types  # always-use-return-annotations
        for key in observation_config_lib.NETLIST_METADATA
    ]

    # Graph.
    # pytype: disable=wrong-arg-types  # dynamic-method-lookup
    sparse_adj_weight = self._get_static_input('sparse_adj_weight', inputs)
    sparse_adj_i = tf.cast(
        self._get_static_input('sparse_adj_i', inputs), dtype=tf.int32
    )
    sparse_adj_j = tf.cast(
        self._get_static_input('sparse_adj_j', inputs), dtype=tf.int32
    )

    # Node features.
    node_types = self._get_static_input('node_types', inputs)
    is_node_placed = tf.cast(inputs['is_node_placed'], dtype=tf.float32)
    macros_w = self._get_static_input('macros_w', inputs)
    macros_h = self._get_static_input('macros_h', inputs)
    # pytype: enable=wrong-arg-types  # dynamic-method-lookup
    locations_x = inputs['locations_x']
    locations_y = inputs['locations_y']

    # Current node.
    current_node = tf.cast(inputs['current_node'], dtype=tf.int32)

    # One-hot style indicators of the node type.
    is_hard_macro = tf.cast(
        tf.math.equal(node_types, observation_config_lib.HARD_MACRO),
        dtype=tf.float32,
    )
    is_soft_macro = tf.cast(
        tf.math.equal(node_types, observation_config_lib.SOFT_MACRO),
        dtype=tf.float32,
    )
    is_port_cluster = tf.cast(
        tf.math.equal(node_types, observation_config_lib.PORT_CLUSTER),
        dtype=tf.float32,
    )

    netlist_metadata = tf.concat(netlist_metadata_inputs, axis=1)
    h_metadata = self._metadata_encoder(netlist_metadata, training=training)

    h_nodes = tf.stack(
        [
            locations_x,
            locations_y,
            macros_w,
            macros_h,
            is_hard_macro,
            is_soft_macro,
            is_port_cluster,
            is_node_placed,
        ],
        axis=2,
    )

    h_nodes = self._feature_encoder(h_nodes, training=training)

    # Edge-centric GCN
    #
    # Here, we are using a modified version of Graph Convolutional Network
    # (GCN)[1] that focuses on edge properties rather than node properties.
    # In this modified GCN, the features of neighbouring nodes are
    # mixed together to create edge features. Then, edge features are
    # aggregated on the connected nodes to create the output node embedding.
    # The GCN message passing happens indirectly between neighbouring nodes
    # through the mixing on the edges.
    #
    # Edge-centric GCN for Circuit Training
    #
    # The nodes of the circuit training observation graph are hard macros,
    # soft macros, and port clusters and the edges are the wires between them.
    # The intuition behind using edge-centric GCN was that the wirelength and
    # congestion costs (reward signals) depends on properties of the
    # wires (edge) and not the macros.
    # This architecture has shown promising results on predicting supervised
    # graph regression for predicting wirelength and congestion and we hope
    # it performs well in reinforcement setting to predict value and policy.
    #
    # An alternative approach was applying original GCN on the Line Graph of
    # the ckt graph (see https://en.wikipedia.org/wiki/Line_graph).
    # Nodes of the line graph correspond to the edges of the original graph.
    # However, the adjacency matrix of the line graph will be prohibitively
    # large and can't be readily processed by GCN.
    #
    # See figures in http://shortn/_j1NsgZBqAr for edge-centric GCN.
    #
    # [1] Kipf and Welling, 2016.
    sparse_adj_weight = tf.expand_dims(
        sparse_adj_weight, axis=-1, name='sparse_adj_weight'
    )

    for i in range(self._num_gcn_layers):
      # For bi-directional graph.
      h_edges_i_j, h_edges_j_i = self.gather_to_edges(
          h_nodes=h_nodes,
          sparse_adj_i=sparse_adj_i,
          sparse_adj_j=sparse_adj_j,
          sparse_adj_weight=sparse_adj_weight,
      )
      # The same edge network processes both orientations; averaging makes
      # the edge embedding independent of the edge direction.
      h_edges = (
          self._edge_fc_list[i](h_edges_i_j, training=training)
          + self._edge_fc_list[i](h_edges_j_i, training=training)
      ) / 2.0
      h_nodes_new = self.scatter_to_nodes(h_edges, sparse_adj_i, sparse_adj_j)
      # Skip connection.
      h_nodes = h_nodes_new + h_nodes

    observation_hiddens = []
    observation_hiddens.append(h_metadata)

    # Graph-level summaries of the edge embeddings of the last GCN layer.
    h_all_edges_mean = tf.reduce_mean(h_edges, axis=1)
    observation_hiddens.append(h_all_edges_mean)

    if self._include_min_max_var:
      h_all_edges_var = tf.math.reduce_variance(h_edges, axis=1)
      observation_hiddens.append(h_all_edges_var)

      h_all_edges_max = tf.math.reduce_max(h_edges, axis=1)
      observation_hiddens.append(h_all_edges_max)

      h_all_edges_min = tf.math.reduce_min(h_edges, axis=1)
      observation_hiddens.append(h_all_edges_min)

    h_current_node = tf.gather(h_nodes, current_node, batch_dims=1)

    h_attended = self.self_attention(h_current_node, h_nodes, training=training)
    observation_hiddens.append(h_attended)

    h_current_node = tf.squeeze(h_current_node, axis=1)
    observation_hiddens.append(h_current_node)

    if self._is_augmented:
      augmented_embedding = self._augmented_embedding_layer(
          inputs['augmented_features']
      )
      observation_hiddens.append(augmented_embedding)

    finetune_value_only = tf.convert_to_tensor(
        finetune_value_only, dtype=tf.bool
    )
    h = tf.concat(observation_hiddens, axis=1)
    # When only the value head is fine-tuned, no gradient flows back into the
    # shared encoder.
    h = tf.cond(finetune_value_only, lambda: tf.stop_gradient(h), lambda: h)

    location_logits = self._policy_location_head(h, training=training)
    # smart_cond avoids using tf.cond when condition value is static.
    logits = {
        'location': smart_cond(
            is_eval,
            lambda: location_logits,
            lambda: self.add_noise(location_logits),
        ),
    }

    logits = {
        'location': tf.cond(
            finetune_value_only,
            lambda: tf.stop_gradient(logits['location']),
            lambda: logits['location'],
        )
    }

    if self._use_value_head:
      value = self._value_head(h, training=training)
      return logits, value
    else:
      return logits, None


class CircuitTrainingTPUModel(CircuitTrainingModel):
  """Model optimized for TPU performance using map_fn.

  Gather and scatter operations are expressed per batch element through
  `tf.map_fn` instead of batched gather/scatter.
  """

  def _scatter_count(
      self, edge_h: tf.Tensor, indices: tf.Tensor
  ) -> tuple[tf.Tensor, tf.Tensor]:
    """Aggregate (reduce sum) edge embeddings to nodes.

    Args:
      edge_h: A [-1, #edges, h] tensor of edge embeddings.
      indices: A [-1, #edges] tensor of node index of an edge (sparse adjacency
        indices).

    Returns:
      A [-1, #nodes, h] tensor of aggregated node embeddings and a
      [-1, #nodes, h] tensor holding, in every latent channel, the number of
      edges scattered to each node (used for finding the mean). #nodes is
      `observation_config.max_num_nodes`.
    """
    batch = tf.shape(edge_h)[0]
    num_latents = edge_h.shape[2]
    h_node = tf.zeros(
        [batch, self._observation_config.max_num_nodes, num_latents]
    )
    count_edge = tf.zeros_like(h_node)
    count = tf.ones_like(edge_h)

    # Per-example scatter: each map_fn step sees one [#nodes, h] target, one
    # [#edges, 1] index tensor and one [#edges, h] update tensor.
    indices = tf.expand_dims(indices, axis=-1)
    h_node = tf.map_fn(
        lambda x: tf.tensor_scatter_nd_add(x[0], x[1], x[2]),
        (h_node, indices, edge_h),
        fn_output_signature=tf.TensorSpec(
            shape=(self._observation_config.max_num_nodes, num_latents)
        ),
    )
    count_edge = tf.map_fn(
        lambda x: tf.tensor_scatter_nd_add(x[0], x[1], x[2]),
        (count_edge, indices, count),
        fn_output_signature=tf.TensorSpec(
            shape=(self._observation_config.max_num_nodes, num_latents)
        ),
    )
    return h_node, count_edge

  def gather_to_edges(
      self,
      h_nodes: tf.Tensor,
      sparse_adj_i: tf.Tensor,
      sparse_adj_j: tf.Tensor,
      sparse_adj_weight: tf.Tensor,
  ) -> tuple[tf.Tensor, tf.Tensor]:
    """Gathers node embeddings to edges.

    For each edge, there are two node embeddings. It concats them together with
    the edge weight. It also masks the output with 0 for edges with no weight.

    This override returns both edge orientations, as the annotation and the
    base class method do. It previously returned only the (i, j) tensor.

    Args:
      h_nodes: A [-1, #node, h] tensor of node embeddings.
      sparse_adj_i: A [-1, #edges] tensor for the 1st node index of an edge.
      sparse_adj_j: A [-1, #edges] tensor for the 2nd node index of an edge.
      sparse_adj_weight: A [-1, #edges, 1] tensor for the weight of an edge
        (`call` expands the trailing axis before passing it in). 0 for fake
        padded edges.

    Returns:
      A pair of [-1, #edges, 2*h+1] tensors of edge embeddings: the first
      concatenates (node i, node j, weight), the second (node j, node i,
      weight).
    """

    h_edges_1 = tf.map_fn(
        lambda x: tf.gather(x[0], x[1], batch_dims=0),
        (h_nodes, sparse_adj_i),
        fn_output_signature=tf.float32,
    )
    h_edges_2 = tf.map_fn(
        lambda x: tf.gather(x[0], x[1], batch_dims=0),
        (h_nodes, sparse_adj_j),
        fn_output_signature=tf.float32,
    )
    h_edges_12 = tf.concat([h_edges_1, h_edges_2, sparse_adj_weight], axis=-1)
    h_edges_21 = tf.concat([h_edges_2, h_edges_1, sparse_adj_weight], axis=-1)
    # Both orientations have the same shape, so one mask serves both.
    mask = tf.broadcast_to(sparse_adj_weight != 0.0, tf.shape(h_edges_12))
    h_edges_i_j = tf.where(mask, h_edges_12, tf.zeros_like(h_edges_12))
    h_edges_j_i = tf.where(mask, h_edges_21, tf.zeros_like(h_edges_21))
    return h_edges_i_j, h_edges_j_i

  def call(
      self,
      inputs: dict[str, tf.Tensor],
      training: bool = False,
      is_eval: bool = False,
      finetune_value_only: bool = False,
  ) -> tuple[dict[str, tf.Tensor], Optional[tf.Tensor]]:
    """Runs the forward pass.

    Args:
      inputs: The observation dictionary. Reads 'is_node_placed',
        'locations_x', 'locations_y' and 'current_node', plus
        'augmented_features' when the model is augmented. Static features are
        resolved through `_get_static_input`.
      training: Set in the training mode.
      is_eval: When True the policy logits are returned without the dirichlet
        noise.
      finetune_value_only: When True, gradients are stopped at the shared
        embedding and at the policy logits.

    Returns:
      A tuple of the logits dictionary (key 'location') and the value tensor,
      or None in place of the value when the model has no value head.
    """
    # Netlist metadata.
    netlist_metadata_inputs = [
        self._get_static_input(key, inputs)
        for key in observation_config_lib.NETLIST_METADATA
    ]

    # Graph.
    sparse_adj_weight = self._get_static_input('sparse_adj_weight', inputs)
    sparse_adj_i = tf.cast(
        self._get_static_input('sparse_adj_i', inputs), dtype=tf.int32
    )
    sparse_adj_j = tf.cast(
        self._get_static_input('sparse_adj_j', inputs), dtype=tf.int32
    )

    # Node features.
    node_types = self._get_static_input('node_types', inputs)
    is_node_placed = tf.cast(inputs['is_node_placed'], dtype=tf.float32)
    macros_w = self._get_static_input('macros_w', inputs)
    macros_h = self._get_static_input('macros_h', inputs)
    locations_x = inputs['locations_x']
    locations_y = inputs['locations_y']

    # Current node.
    current_node = tf.cast(inputs['current_node'], dtype=tf.int32)

    is_hard_macro = tf.cast(
        tf.math.equal(node_types, observation_config_lib.HARD_MACRO),
        dtype=tf.float32,
    )
    is_soft_macro = tf.cast(
        tf.math.equal(node_types, observation_config_lib.SOFT_MACRO),
        dtype=tf.float32,
    )
    is_port_cluster = tf.cast(
        tf.math.equal(node_types, observation_config_lib.PORT_CLUSTER),
        dtype=tf.float32,
    )

    netlist_metadata = tf.concat(netlist_metadata_inputs, axis=1)
    h_metadata = self._metadata_encoder(netlist_metadata, training=training)

    h_nodes = tf.stack(
        [
            locations_x,
            locations_y,
            macros_w,
            macros_h,
            is_hard_macro,
            is_soft_macro,
            is_port_cluster,
            is_node_placed,
        ],
        axis=2,
    )

    h_nodes = self._feature_encoder(h_nodes, training=training)

    # Edge-centric GCN
    #
    # Here, we are using a modified version of Graph Convolutional Network
    # (GCN)[1] that focuses on edge properties rather than node properties.
    # In this modified GCN, the features of neighbouring nodes are
    # mixed together to create edge features. Then, edge features are
    # aggregated on the connected nodes to create the output node embedding.
    # The GCN message passing happens indirectly between neighbouring nodes
    # through the mixing on the edges.
    #
    # Edge-centric GCN for Circuit Training
    #
    # The nodes of the circuit training observation graph are hard macros,
    # soft macros, and port clusters and the edges are the wires between them.
    # The intuition behind using edge-centric GCN was that the wirelength and
    # congestion costs (reward signals) depends on properties of the
    # wires (edge) and not the macros.
    # This architecture has shown promising results on predicting supervised
    # graph regression for predicting wirelength and congestion and we hope
    # it performs well in reinforcement setting to predict value and policy.
    #
    # An alternative approach was applying original GCN on the Line Graph of
    # the ckt graph (see https://en.wikipedia.org/wiki/Line_graph).
    # Nodes of the line graph correspond to the edges of the original graph.
    # However, the adjacency matrix of the line graph will be prohibitively
    # large and can't be readily processed by GCN.
    #
    # See figures in http://shortn/_j1NsgZBqAr for edge-centric GCN.
    #
    # [1] Kipf and Welling, 2016.
    sparse_adj_weight = tf.expand_dims(
        sparse_adj_weight, axis=-1, name='sparse_adj_weight'
    )

    def update_tpu(h_nodes, i=0):
      """Applies GCN layer `i`; returns the new node and edge embeddings."""
      # Optimizing scatter/gather performance on TPUs.
      # For bi-directional graph.
      h_edges_i_j, h_edges_j_i = self.gather_to_edges(
          h_nodes=h_nodes,
          sparse_adj_i=sparse_adj_i,
          sparse_adj_j=sparse_adj_j,
          sparse_adj_weight=sparse_adj_weight,
      )

      # The same edge network processes both orientations; averaging makes
      # the edge embedding independent of the edge direction.
      h_edges = (
          self._edge_fc_list[i](h_edges_i_j, training=training)
          + self._edge_fc_list[i](h_edges_j_i, training=training)
      ) / 2.0

      # Mean-aggregate the edge embeddings onto both end nodes of each edge.
      h_node = tf.zeros_like(h_nodes)
      num_latents = h_edges.shape[2]
      count_edge = tf.zeros_like(h_node)
      count = tf.ones_like(h_edges)
      indices = tf.expand_dims(sparse_adj_i, axis=-1)
      h_nodes_1 = tf.map_fn(
          lambda x: tf.tensor_scatter_nd_add(x[0], x[1], x[2]),
          (h_node, indices, h_edges),
          fn_output_signature=tf.TensorSpec(
              shape=(self._observation_config.max_num_nodes, num_latents)
          ),
      )
      count_1 = tf.map_fn(
          lambda x: tf.tensor_scatter_nd_add(x[0], x[1], x[2]),
          (count_edge, indices, count),
          fn_output_signature=tf.TensorSpec(
              shape=(self._observation_config.max_num_nodes, num_latents)
          ),
      )
      indices = tf.expand_dims(sparse_adj_j, axis=-1)
      h_nodes_2 = tf.map_fn(
          lambda x: tf.tensor_scatter_nd_add(x[0], x[1], x[2]),
          (h_node, indices, h_edges),
          fn_output_signature=tf.TensorSpec(
              shape=(self._observation_config.max_num_nodes, num_latents)
          ),
      )
      count_2 = tf.map_fn(
          lambda x: tf.tensor_scatter_nd_add(x[0], x[1], x[2]),
          (count_edge, indices, count),
          fn_output_signature=tf.TensorSpec(
              shape=(self._observation_config.max_num_nodes, num_latents)
          ),
      )

      # EPSILON keeps nodes without any edge from dividing by zero.
      h_nodes_new = (h_nodes_1 + h_nodes_2) / (count_1 + count_2 + self.EPSILON)
      # Skip connection.
      h_nodes = h_nodes_new + h_nodes
      return h_nodes, h_edges

    h_nodes = tf.identity(h_nodes, 'initial_h_nodes')
    for i in range(self._num_gcn_layers):
      h_nodes, h_edges = update_tpu(h_nodes, i)

    observation_hiddens = []
    observation_hiddens.append(h_metadata)

    # Graph-level summaries of the edge embeddings of the last GCN layer.
    h_all_edges_mean = tf.reduce_mean(h_edges, axis=1)
    observation_hiddens.append(h_all_edges_mean)

    if self._include_min_max_var:
      h_all_edges_var = tf.math.reduce_variance(h_edges, axis=1)
      observation_hiddens.append(h_all_edges_var)

      h_all_edges_max = tf.math.reduce_max(h_edges, axis=1)
      observation_hiddens.append(h_all_edges_max)

      h_all_edges_min = tf.math.reduce_min(h_edges, axis=1)
      observation_hiddens.append(h_all_edges_min)

    h_current_node = tf.map_fn(
        lambda x: tf.gather(x[0], x[1], batch_dims=0),
        (h_nodes, current_node),
        fn_output_signature=tf.float32,
    )

    h_attended = self.self_attention(h_current_node, h_nodes, training=training)
    observation_hiddens.append(h_attended)

    h_current_node = tf.squeeze(h_current_node, axis=1)
    observation_hiddens.append(h_current_node)

    if self._is_augmented:
      augmented_embedding = self._augmented_embedding_layer(
          inputs['augmented_features']
      )
      observation_hiddens.append(augmented_embedding)

    finetune_value_only = tf.convert_to_tensor(
        finetune_value_only, dtype=tf.bool
    )
    h = tf.concat(observation_hiddens, axis=1)
    # When only the value head is fine-tuned, no gradient flows back into the
    # shared encoder.
    h = tf.cond(finetune_value_only, lambda: tf.stop_gradient(h), lambda: h)

    location_logits = self._policy_location_head(h, training=training)
    # smart_cond avoids using tf.cond when condition value is static.
    logits = {
        'location': smart_cond(
            is_eval,
            lambda: location_logits,
            lambda: self.add_noise(location_logits),
        ),
    }
    # The base constructor only creates the value head when `use_value_head`
    # is set, so calling it unconditionally raised AttributeError otherwise.
    if self._use_value_head:
      value = self._value_head(h, training=training)
    else:
      value = None
    logits = {
        'location': tf.cond(
            finetune_value_only,
            lambda: tf.stop_gradient(logits['location']),
            lambda: logits['location'],
        )
    }

    return logits, value


@gin.configurable(module='circuittraining.models')
class GrlModel(network.Network):
  """Shared GRL network that wraps one of the circuit training models.

  Depending on `use_model_tpu`, an instance owns either a
  `CircuitTrainingTPUModel` or a `CircuitTrainingModel`. Every call is
  forwarded to that model and its `(logits, value)` result is repackaged as a
  dict with the keys `'logits'` and `'value'`.
  """

  def __init__(
      self,
      input_tensors_spec: types.NestedTensorSpec,
      output_tensors_spec: types.NestedTensorSpec,
      all_static_features: Dict[str, np.ndarray],
      name: Optional[Text] = None,
      state_spec: types.NestedTensorSpec = (),
      policy_noise_weight: float = 0.0,
      use_model_tpu: bool = True,
      is_augmented: bool = False,
      use_value_head: bool = True,
      seed: int = 0,
  ):
    """Builds the wrapped model.

    Args:
      input_tensors_spec: Spec of the network inputs; passed to the base
        `network.Network` as `input_tensor_spec`.
      output_tensors_spec: Accepted for signature compatibility; this
        constructor does not read it.
      all_static_features: Static features forwarded to the wrapped model.
      name: Optional network name passed to the base class.
      state_spec: Network state spec passed to the base class.
      policy_noise_weight: Forwarded to the wrapped model.
      use_model_tpu: If True wrap `CircuitTrainingTPUModel`, otherwise wrap
        `CircuitTrainingModel`.
      is_augmented: Forwarded to the wrapped model.
      use_value_head: Forwarded to `CircuitTrainingModel` only; it is not
        passed along when `use_model_tpu` is True.
      seed: Random seed forwarded to the wrapped model.
    """
    super(GrlModel, self).__init__(
        input_tensor_spec=input_tensors_spec, state_spec=state_spec, name=name
    )

    # Keyword arguments that both model variants receive.
    shared_kwargs = dict(
        policy_noise_weight=policy_noise_weight,
        all_static_features=all_static_features,
        is_augmented=is_augmented,
        seed=seed,
    )
    if not use_model_tpu:
      self._model = CircuitTrainingModel(
          use_value_head=use_value_head, **shared_kwargs
      )
    else:
      # NOTE(review): `use_value_head` is intentionally not forwarded here,
      # mirroring the original behavior -- confirm the TPU model has no such
      # option.
      self._model = CircuitTrainingTPUModel(**shared_kwargs)

  def call(self, inputs, network_state=(), finetune_value_only=False):
    """Runs the wrapped model.

    Args:
      inputs: Observations handed to the wrapped model unchanged.
      network_state: Returned unchanged as the second element.
      finetune_value_only: Forwarded to the wrapped model.

    Returns:
      A pair `(outputs, network_state)` where `outputs` is a dict holding the
      wrapped model's logits under `'logits'` and its value under `'value'`.
    """
    model_result = self._model(inputs, finetune_value_only=finetune_value_only)
    logits, value = model_result
    outputs = {'logits': logits, 'value': value}
    return outputs, network_state


@gin.configurable(module='circuittraining.models')
class GrlPolicyModel(network.DistributionNetwork):
  """Policy network producing a masked Categorical over locations.

  The network delegates the forward pass to `shared_network`, masks the
  returned `'location'` logits with the observation's `'mask'` entry and wraps
  the result in a `tfp.distributions.Categorical` distribution.
  """

  def __init__(
      self,
      shared_network: network.Network,
      input_tensors_spec: types.NestedTensorSpec,
      output_tensors_spec: types.NestedTensorSpec,
      name: Optional[Text] = 'GrlPolicyModel',
  ):
    """Builds the policy network and its output distribution spec.

    Args:
      shared_network: Network called on the inputs; its output is indexed as
        `output['logits']['location']`.
      input_tensors_spec: Spec of the network inputs.
      output_tensors_spec: Action spec; its `minimum`, `maximum` and `dtype`
        attributes are read to build the distribution spec.
      name: Network name. `None` is accepted; the logits spec then falls back
        to the default `'GrlPolicyModel'` prefix.
    """
    super(GrlPolicyModel, self).__init__(
        input_tensor_spec=input_tensors_spec,
        state_spec=(),
        output_spec=output_tensors_spec,
        name=name,
    )

    self._input_tensors_spec = input_tensors_spec
    self._shared_network = shared_network
    self._output_tensors_spec = output_tensors_spec

    # Unique values of (maximum - minimum + 1); used as the logits shape.
    n_unique_actions = np.unique(
        output_tensors_spec.maximum - output_tensors_spec.minimum + 1
    )
    # `name` is Optional, so concatenating it directly would raise a
    # TypeError for None. Only the None case is changed; every other value
    # (including the empty string) yields the same spec name as before.
    logits_name_prefix = name if name is not None else 'GrlPolicyModel'
    input_param_spec = {
        'logits': tensor_spec.TensorSpec(
            shape=n_unique_actions,
            dtype=tf.float32,
            name=logits_name_prefix + '_logits',
        )
    }
    self._output_dist_spec = distribution_spec.DistributionSpec(
        tfp.distributions.Categorical,
        input_param_spec,
        sample_spec=output_tensors_spec,
        dtype=output_tensors_spec.dtype,
    )

  @property
  def output_spec(self):
    """Returns the distribution spec built in the constructor."""
    return self._output_dist_spec

  @property
  def distribution_tensor_spec(self):
    """Returns the same distribution spec as `output_spec`."""
    return self._output_dist_spec

  def call(self, inputs, step_types=None, network_state=()):
    """Computes the masked action distribution.

    Args:
      inputs: Observation structure; must contain a `'mask'` entry.
      step_types: Unused.
      network_state: Returned unchanged as the second element.

    Returns:
      A pair `(distribution, network_state)`.
    """
    outer_rank = nest_utils.get_outer_rank(inputs, self._input_tensors_spec)
    if outer_rank == 0:
      # Inputs without an outer dimension are reshaped to (1, -1).
      inputs = tf.nest.map_structure(lambda x: tf.reshape(x, (1, -1)), inputs)
    model_out, _ = self._shared_network(inputs)

    # Entries whose mask casts to False get this large negative logit.
    paddings = tf.ones_like(inputs['mask'], dtype=tf.float32) * (
        -(2.0 ** 32) + 1
    )
    masked_logits = tf.where(
        tf.cast(inputs['mask'], tf.bool),
        model_out['logits']['location'],
        paddings,
    )

    output_dist = self._output_dist_spec.build_distribution(
        logits=masked_logits
    )

    return output_dist, network_state


@gin.configurable(module='circuittraining.models')
class GrlValueModel(network.Network):
  """Value network that reads the `'value'` output of a shared network."""

  def __init__(
      self,
      input_tensors_spec: types.NestedTensorSpec,
      shared_network: network.Network,
      name: Optional[Text] = None,
  ):
    """Stores the shared network and the input spec.

    Args:
      input_tensors_spec: Spec of the network inputs.
      shared_network: Network called on the inputs; its output is indexed
        with `'value'`.
      name: Optional network name passed to the base class.
    """
    super(GrlValueModel, self).__init__(
        input_tensor_spec=input_tensors_spec, state_spec=(), name=name
    )

    # Starts disabled; toggled through `set_finetune_value_only`.
    self._finetune_value_only = False
    self._shared_network = shared_network
    self._input_tensors_spec = input_tensors_spec

  def set_finetune_value_only(self, finetune_value_only: bool):
    """Sets the flag forwarded to the shared network on each call."""
    self._finetune_value_only = finetune_value_only

  def call(self, inputs, step_types=None, network_state=()):
    """Computes the value prediction.

    Args:
      inputs: Observation structure handed to the shared network.
      step_types: Unused.
      network_state: Returned unchanged as the second element.

    Returns:
      A pair `(value, network_state)` where `value` is the shared network's
      `'value'` output with its last axis squeezed away.
    """
    if nest_utils.get_outer_rank(inputs, self._input_tensors_spec) == 0:
      # Inputs without an outer dimension are reshaped to (1, -1).
      inputs = tf.nest.map_structure(
          lambda tensor: tf.reshape(tensor, (1, -1)), inputs
      )

    shared_out, _ = self._shared_network(
        inputs, finetune_value_only=self._finetune_value_only
    )

    # Drop the trailing axis of the value prediction (e.g. [B, T, 1] ->
    # [B, T]).
    value_prediction = tf.squeeze(shared_out['value'], -1)
    return value_prediction, network_state


@gin.configurable(module='circuittraining.models')
class GrlQModel(network.Network):
  """Q network returning the shared network's masked location logits."""

  def __init__(
      self,
      input_tensors_spec: types.NestedTensorSpec,
      shared_network: network.Network,
      name: Optional[Text] = None,
  ):
    """Stores the shared network and the input spec.

    Args:
      input_tensors_spec: Spec of the network inputs.
      shared_network: Network called on the inputs; its output is indexed as
        `output['logits']['location']`.
      name: Optional network name passed to the base class.
    """
    super(GrlQModel, self).__init__(
        input_tensor_spec=input_tensors_spec, state_spec=(), name=name
    )

    self._shared_network = shared_network
    self._input_tensors_spec = input_tensors_spec

  def call(self, inputs, step_types=None, network_state=()):
    """Computes masked logits.

    Args:
      inputs: Observation structure; must contain a `'mask'` entry.
      step_types: Unused.
      network_state: Returned unchanged as the second element.

    Returns:
      A pair `(masked_logits, network_state)`. Positions whose mask casts to
      False hold a large negative constant instead of the model's logit.
    """
    if nest_utils.get_outer_rank(inputs, self._input_tensors_spec) == 0:
      # Inputs without an outer dimension are reshaped to (1, -1).
      inputs = tf.nest.map_structure(
          lambda tensor: tf.reshape(tensor, (1, -1)), inputs
      )

    shared_out, _ = self._shared_network(inputs)
    location_logits = shared_out['logits']['location']

    mask = inputs['mask']
    # Value written wherever the mask is False.
    masked_out_value = -(2.0 ** 32) + 1
    fill = tf.ones_like(mask, dtype=tf.float32) * masked_out_value
    q_values = tf.where(tf.cast(mask, tf.bool), location_logits, fill)

    return q_values, network_state


def create_grl_actor_critic_models(
    observation_tensor_spec: types.NestedTensorSpec,
    action_tensor_spec: types.NestedTensorSpec,
    all_static_features: Dict[str, np.ndarray],
    use_model_tpu: bool = False,
    is_augmented: bool = False,
    seed: int = 0,
):
  """Builds a policy network and a value network on one shared `GrlModel`.

  A single `GrlModel` is constructed and handed to both a `GrlPolicyModel`
  and a `GrlValueModel`, so the two returned networks use the same underlying
  model instance.

  Args:
    observation_tensor_spec: Tensor spec for the observations.
    action_tensor_spec: Tensor spec for the actions.
    all_static_features: Static features from the environment, forwarded to
      the shared model.
    use_model_tpu: Selects which model variant `GrlModel` wraps. According to
      the original documentation the TPU variant relies on map_fn for TPU
      performance and both variants produce the same outputs for the same
      inputs.
    is_augmented: Whether the model uses augmented features.
    seed: Random seed.

  Returns:
    A tuple `(policy_network, value_network)`.
  """
  shared_net = GrlModel(
      input_tensors_spec=observation_tensor_spec,
      output_tensors_spec=action_tensor_spec,
      all_static_features=all_static_features,
      use_model_tpu=use_model_tpu,
      is_augmented=is_augmented,
      seed=seed,
  )
  actor_net = GrlPolicyModel(
      shared_network=shared_net,
      input_tensors_spec=observation_tensor_spec,
      output_tensors_spec=action_tensor_spec,
  )
  value_net = GrlValueModel(
      input_tensors_spec=observation_tensor_spec, shared_network=shared_net
  )
  return actor_net, value_net


def create_grl_q_models(
    observation_tensor_spec: types.NestedTensorSpec,
    action_tensor_spec: types.NestedTensorSpec,
    all_static_features: Dict[str, np.ndarray],
    use_model_tpu: bool = False,
    is_augmented: bool = False,
    seed: int = 0,
) -> 'GrlQModel':
  """Creates the GRL Q network from scratch.

  A `GrlModel` is built with `use_value_head=False` and wrapped in a
  `GrlQModel`.

  Args:
    observation_tensor_spec: tensor spec for the observations.
    action_tensor_spec: tensor spec for the actions.
    all_static_features: static features from the environment to pass into the
      models.
    use_model_tpu: boolean flag selecting which model variant `GrlModel`
      wraps. Note that `GrlModel` forwards `use_value_head` only to the
      non-TPU variant.
    is_augmented: Whether the model uses augmented features.
    seed: Random seed.

  Returns:
    A single `GrlQModel` (not a tuple) wrapping the shared `GrlModel`.
  """
  grl_shared_net = GrlModel(
      observation_tensor_spec,
      action_tensor_spec,
      all_static_features=all_static_features,
      use_model_tpu=use_model_tpu,
      use_value_head=False,
      is_augmented=is_augmented,
      seed=seed,
  )

  grl_q_net = GrlQModel(observation_tensor_spec, grl_shared_net)
  return grl_q_net
