# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for calculating losses."""

from typing import Dict, List, Tuple
import chex
from clrs._src import probing
from clrs._src import specs

import haiku as hk
import jax
import jax.numpy as jnp

import optax

# Private aliases for externally defined array/data-point types.
_Array = chex.Array
_DataPoint = probing.DataPoint

# Aliases for constant namespaces from the `specs` module.
_Location = specs.Location
_OutputClass = specs.OutputClass
_Type = specs.Type

# Prediction containers: a str-keyed dict of arrays, and a list of such dicts.
_PredTrajectory = Dict[str, _Array]
_PredTrajectories = List[_PredTrajectory]

# Floor for loss-normalisation denominators, so an all-zero mask yields a
# zero loss instead of a division by zero.
EPS = 1e-12


def _expand_to(x: _Array, y: _Array) -> _Array:
  while len(y.shape) > len(x.shape):
    x = jnp.expand_dims(x, -1)
  return x


def _expand_and_broadcast_to(x: _Array, y: _Array) -> _Array:
  """Right-pads `x` with singleton axes, then broadcasts it to `y`'s shape."""
  expanded = _expand_to(x, y)
  target_shape = jnp.shape(y)
  return jnp.broadcast_to(expanded, target_shape)


def output_loss_chunked(truth: _DataPoint, pred: _Array,
                        is_last: _Array, nb_nodes: int) -> float:
  """Output loss for time-chunked training.

  Only entries where `is_last` is nonzero contribute to the loss. `is_last`
  is padded with trailing singleton axes and broadcast to the shape of the
  per-element loss. MASK outputs additionally drop entries whose target is
  `_OutputClass.MASKED`. MASK_ONE / CATEGORICAL outputs drop entries that
  have no `_OutputClass.POSITIVE` class and zero out masked classes in the
  target.

  Args:
    truth: Ground-truth output; `truth.type_` selects the loss.
    pred: Model predictions. They are passed through `log_softmax` for
      MASK_ONE, CATEGORICAL and POINTER, and treated as logits in the MASK
      binary cross-entropy. For PERMUTATION_POINTER they are used as-is
      inside the cross-entropy sum (no `log_softmax` is applied).
    is_last: Indicator of which entries contribute to the loss.
    nb_nodes: Number of classes used to one-hot encode POINTER targets.

  Returns:
    The mean loss over contributing entries, or 0 if none contribute.

  Raises:
    ValueError: If `truth.type_` is not a supported output type.
  """
  mask = None

  if truth.type_ == _Type.SCALAR:
    loss = (pred - truth.data)**2

  elif truth.type_ == _Type.MASK:
    # Numerically stable binary cross-entropy with logits.
    loss = (
        jnp.maximum(pred, 0) - pred * truth.data +
        jnp.log1p(jnp.exp(-jnp.abs(pred))))
    mask = (truth.data != _OutputClass.MASKED)

  elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]:
    # Score only entries that have a positive class; masked classes are
    # removed from the target distribution.
    mask = jnp.any(truth.data == _OutputClass.POSITIVE, axis=-1)
    masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype(
        jnp.float32)
    loss = -jnp.sum(masked_truth * jax.nn.log_softmax(pred), axis=-1)

  elif truth.type_ == _Type.POINTER:
    loss = -jnp.sum(
        hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred), axis=-1)

  elif truth.type_ == _Type.PERMUTATION_POINTER:
    # Predictions are NxN logits aiming to represent a doubly stochastic matrix.
    # Compute the cross entropy between doubly stochastic pred and truth_data
    loss = -jnp.sum(truth.data * pred, axis=-1)

  else:
    # Without this, an unknown type surfaces as an UnboundLocalError below.
    raise ValueError(f'Unsupported output type: {truth.type_}')

  last_mask = _expand_and_broadcast_to(is_last, loss)
  mask = last_mask if mask is None else mask * last_mask
  total_mask = jnp.maximum(jnp.sum(mask), EPS)
  return jnp.sum(jnp.where(mask, loss, 0.0)) / total_mask


def output_loss(truth: _DataPoint, pred: _Array, nb_nodes: int) -> float:
  """Output loss for full-sample training.

  Args:
    truth: Ground-truth output; `truth.type_` selects the loss.
    pred: Model predictions. They are passed through `log_softmax` for
      MASK_ONE, CATEGORICAL and POINTER, and treated as logits in the MASK
      binary cross-entropy. For PERMUTATION_POINTER they are used as-is
      inside the cross-entropy sum.
    nb_nodes: Number of classes used to one-hot encode POINTER targets.

  Returns:
    The scalar loss. MASK averages over entries not marked
    `_OutputClass.MASKED`. MASK_ONE / CATEGORICAL normalise by the number of
    `_OutputClass.POSITIVE` entries. Both denominators are floored at `EPS`,
    so a fully-masked batch yields 0 instead of NaN/inf.

  Raises:
    ValueError: If `truth.type_` is not a supported output type.
  """
  if truth.type_ == _Type.SCALAR:
    total_loss = jnp.mean((pred - truth.data)**2)

  elif truth.type_ == _Type.MASK:
    # Numerically stable binary cross-entropy with logits.
    loss = (
        jnp.maximum(pred, 0) - pred * truth.data +
        jnp.log1p(jnp.exp(-jnp.abs(pred))))
    mask = (truth.data != _OutputClass.MASKED).astype(jnp.float32)
    total_loss = jnp.sum(loss * mask) / jnp.maximum(jnp.sum(mask), EPS)

  elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]:
    masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype(
        jnp.float32)
    nb_positive = jnp.sum(truth.data == _OutputClass.POSITIVE)
    total_loss = (-jnp.sum(masked_truth * jax.nn.log_softmax(pred)) /
                  jnp.maximum(nb_positive, EPS))

  elif truth.type_ == _Type.POINTER:
    total_loss = (
        jnp.mean(-jnp.sum(
            hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred),
            axis=-1)))

  elif truth.type_ == _Type.PERMUTATION_POINTER:
    # Predictions are NxN logits aiming to represent a doubly stochastic matrix.
    # Compute the cross entropy between doubly stochastic pred and truth_data
    total_loss = jnp.mean(-jnp.sum(truth.data * pred, axis=-1))

  else:
    # Without this, an unknown type surfaces as an UnboundLocalError below.
    raise ValueError(f'Unsupported output type: {truth.type_}')

  return total_loss


def hint_loss_chunked(
    truth: _DataPoint,
    pred: _Array,
    is_first: _Array,
    nb_nodes: int,
):
  """Hint loss for time-chunked training.

  Each element is weighted by the helper's mask times `1 - is_first`.
  `is_first` is padded with trailing singleton axes to the loss's rank, so
  entries flagged by `is_first` are excluded. The result is the weighted
  mean, with the denominator floored at `EPS`.
  """
  elem_loss, weights = _hint_loss(
      truth_data=truth.data,
      truth_type=truth.type_,
      pred=pred,
      nb_nodes=nb_nodes,
  )
  not_first = (1 - _expand_to(is_first, elem_loss)).astype(jnp.float32)
  weights = weights * not_first
  denominator = jnp.maximum(jnp.sum(weights), EPS)
  return jnp.sum(elem_loss * weights) / denominator


def hint_loss(
    truth: _DataPoint,
    preds: List[_Array],
    lengths: _Array,
    nb_nodes: int,
    verbose: bool = False,
):
  """Hint loss for full-sample training.

  The first ground-truth hint step is skipped: `preds` is stacked along a new
  leading axis and compared against `truth.data[1:]`. Entries at step `t`
  are kept only where `lengths > t + 1`, and the result is the weighted mean
  with the denominator floored at `EPS`.

  Args:
    truth: Ground-truth hint trajectory; time presumably runs along axis 0 of
      `truth.data` (TODO confirm).
    preds: Per-step predictions, combined with `jnp.stack`.
    lengths: Trajectory lengths, compared against the step index.
    nb_nodes: Number of classes used to one-hot encode POINTER targets.
    verbose: If True, return `{'loss_' + truth.name: loss}` instead of the
      bare loss.

  Returns:
    The masked mean hint loss, or a single-entry dict holding it when
    `verbose` is set.
  """
  nb_steps = truth.data.shape[0] - 1
  step_loss, weights = _hint_loss(
      truth_data=truth.data[1:],
      truth_type=truth.type_,
      pred=jnp.stack(preds),
      nb_nodes=nb_nodes,
  )
  step_index = jnp.arange(nb_steps)[:, None]
  weights = weights * _is_not_done_broadcast(lengths, step_index, step_loss)
  loss = jnp.sum(step_loss * weights) / jnp.maximum(jnp.sum(weights), EPS)
  if verbose:
    return {'loss_' + truth.name: loss}
  return loss


def _hint_loss(
    truth_data: _Array,
    truth_type: str,
    pred: _Array,
    nb_nodes: int,
    *,
    temperature: float = 1e-1,
    cosine_eps: float = 1e-2,
    contrastive_weight: float = 1.,
) -> Tuple[_Array, _Array]:
  """Hint loss helper.

  Returns an unreduced loss and a matching weight mask. Callers in this
  module reduce them as `sum(loss * mask) / max(sum(mask), EPS)`.

  For MASK hints the ground truth is ignored and a self-supervised,
  InfoNCE-style contrastive objective is used instead. The batch axis of
  `pred` is split into even-indexed ("clean") and odd-indexed ("aug")
  halves, which presumably hold paired original/augmented samples (TODO
  confirm with the data pipeline). For every node, the clean embedding is
  compared by temperature-scaled cosine similarity against all node
  embeddings of the paired aug sample. The same node is the positive, and
  all nodes, positive included, form the softmax denominator. The result is
  a scalar averaged over time, pairs and nodes.

  Args:
    truth_data: Ground-truth hint values (unused for MASK).
    truth_type: Hint type, compared against `_Type` members.
    pred: Predictions for the hint. For MASK it must be 4-d and is unpacked
      as (T, B, N, H), presumably time, batch, nodes, features.
    nb_nodes: Number of classes used to one-hot encode POINTER targets.
    temperature: MASK only; divisor applied to cosine similarities.
    cosine_eps: MASK only; epsilon forwarded to `optax.cosine_similarity`.
    contrastive_weight: MASK only; multiplier on the contrastive loss.

  Returns:
    A `(loss, mask)` tuple. `mask` is all ones shaped like `loss`, except for
    CATEGORICAL, where it marks entries that have a positive class.

  Raises:
    ValueError: If `truth_type` is unsupported, or if a MASK prediction has
      an odd batch size (the clean/aug halves would not pair up).
  """
  mask = None
  if truth_type == _Type.SCALAR:
    loss = (pred - truth_data)**2

  elif truth_type == _Type.MASK:
    # Self-supervised objective, used for type MASK as (per the original
    # author) that state type from the specs is common to all algorithms.
    # NOTE(review): `loss` here is a scalar, so the mask returned below is
    # all ones and any per-step masking applied by callers (e.g. trajectory
    # lengths) cancels out in the normalisation — confirm this is intended.
    _, batch_size, nb_pred_nodes, _ = pred.shape
    if batch_size % 2:
      # Odd batches previously produced NaN (B=1), silently mismatched pairs
      # (B=3) or a broadcasting error (B>=5).
      raise ValueError(
          'MASK hint loss expects an even batch of interleaved (clean, aug) '
          f'samples, got batch size {batch_size}.')
    clean = pred[:, 0::2]  # (T, B/2, N, H)
    aug = pred[:, 1::2]  # (T, B/2, N, H)

    loss = 0.
    for node in range(nb_pred_nodes):
      anchor = clean[:, :, node:node + 1]  # (T, B/2, 1, H)
      positive = aug[:, :, node:node + 1]  # (T, B/2, 1, H)
      sim_pos = optax.cosine_similarity(
          anchor, positive, cosine_eps) / temperature  # (T, B/2, 1)
      sim_all = optax.cosine_similarity(
          anchor, aug, cosine_eps) / temperature  # (T, B/2, N)
      node_loss = jax.nn.logsumexp(sim_all, axis=-1, keepdims=True) - sim_pos
      loss += jnp.mean(node_loss) / nb_pred_nodes
    loss *= contrastive_weight

  elif truth_type == _Type.MASK_ONE:
    loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1,
                    keepdims=True)

  elif truth_type == _Type.CATEGORICAL:
    loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1)
    mask = jnp.any(truth_data == _OutputClass.POSITIVE, axis=-1).astype(
        jnp.float32)

  elif truth_type == _Type.POINTER:
    loss = -jnp.sum(
        hk.one_hot(truth_data, nb_nodes) * jax.nn.log_softmax(pred),
        axis=-1)

  elif truth_type == _Type.PERMUTATION_POINTER:
    # Predictions are NxN logits aiming to represent a doubly stochastic matrix.
    # Compute the cross entropy between doubly stochastic pred and truth_data
    loss = -jnp.sum(truth_data * pred, axis=-1)

  else:
    # Without this, an unknown type surfaces as an UnboundLocalError below.
    raise ValueError(f'Unsupported hint type: {truth_type}')

  if mask is None:
    mask = jnp.ones_like(loss)
  return loss, mask


def _is_not_done_broadcast(lengths, i, tensor):
  is_not_done = (lengths > i + 1) * 1.0
  while len(is_not_done.shape) < len(tensor.shape):  # pytype: disable=attribute-error  # numpy-scalars
    is_not_done = jnp.expand_dims(is_not_done, -1)
  return is_not_done
