# coding=utf-8
# Copyright 2021 The Neural Sddp Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Conditional piecewise linear nn."""

from typing import Callable, Sequence
from flax import linen as nn
import jax
import jax.numpy as jnp
from ot_metrics import emd_approx


class MLP(nn.Module):
  """Feed-forward stack of `nn.Dense` layers.

  Every entry of `sizes` except the last yields a Dense layer followed by a
  ReLU. The last entry is the width of the final Dense layer, which is
  returned without an activation.
  """
  sizes: Sequence[int]

  @nn.compact
  def __call__(self, x):
    hidden_widths = self.sizes[:-1]
    activations = x
    for width in hidden_widths:
      # Dense then ReLU for every hidden layer.
      activations = nn.relu(nn.Dense(width)(activations))
    # Final projection is linear (no activation).
    output_layer = nn.Dense(self.sizes[-1])
    return output_layer(activations)


class CondPiecewiseNN(nn.Module):
  """Generate piecewise linear functions based on conditional features.

  For every input row the network emits `num_pieces` vectors of length
  `num_vars + 1`.

  Attributes:
    num_vars: each generated piece has `num_vars + 1` parameters.
    num_stages: the stage embedding table has `num_stages - 1` rows.
    hidden_size: width of the feature projection, of the stage embedding and
      of every hidden layer of the MLP head.
    num_pieces: number of pieces generated per input row.
    num_layers: number of hidden layers in the MLP head.
    kernel_init: kernel initializer (see the note next to the field).
    bias_init: bias initializer (see the note next to the field).
  """
  num_vars: int
  num_stages: int
  hidden_size: int
  num_pieces: int
  num_layers: int
  # NOTE(review): kernel_init / bias_init are declared but are not passed to
  # any layer built in this class, so the layers use their own defaults.
  # Left as is, because wiring them in would change initialization for
  # existing callers -- confirm the intent.
  kernel_init: Callable = nn.initializers.xavier_uniform()
  bias_init: Callable = nn.initializers.zeros

  @nn.compact
  def __call__(self, cond_feat, stage_idx):
    """Predict the piece parameters.

    Args:
      cond_feat: conditional features; projected to `hidden_size` by a Dense
        layer (assumed to be [bsize, feat_dim] -- TODO confirm).
      stage_idx: integer indices looked up in the stage embedding table
        (presumably in [0, num_stages - 2]; verify against callers).
    Returns:
      Array reshaped to [-1, num_pieces, num_vars + 1].
    """
    time_embedding = nn.Embed(num_embeddings=self.num_stages - 1,
                              features=self.hidden_size)(stage_idx)
    cond_feat = nn.Dense(self.hidden_size)(cond_feat) + time_embedding
    # `num_layers` hidden layers, then one flat output holding all pieces.
    out_width = (self.num_vars + 1) * self.num_pieces
    layer_sizes = (self.hidden_size,) * self.num_layers + (out_width,)
    joint_param = MLP(layer_sizes)(cond_feat)
    return jnp.reshape(joint_param,
                       (-1, self.num_pieces, self.num_vars + 1))

  def emd_approx(self, pred_params, target_pieces):
    """Calculate batched emd distance.

    Args:
      pred_params: tensor of shape [bsize, self.num_pieces, n_vars + 1]
      target_pieces: tensor of shape [bsize, n_pieces, n_vars + 1]
    Returns:
      The result of mapping the imported `emd_approx` function over the
      leading (batch) axis, called as `emd_approx(target, pred)`.
    """
    # Inside a method, `emd_approx` resolves to the module-level import, not
    # to this method.
    return jax.vmap(emd_approx)(target_pieces, pred_params)

  def mse(self, pred_params, target_pieces):
    """Calculate the per-example sum of squared errors.

    When the two inputs have a different number of pieces (axis 1), only the
    trailing `n` pieces of each are compared, `n` being the smaller count.

    Args:
      pred_params: tensor of shape [bsize, n_pred_pieces, n_vars + 1]
      target_pieces: tensor of shape [bsize, n_target_pieces, n_vars + 1]
    Returns:
      Tensor of shape [bsize]: squared differences summed over axes 1 and 2
      (a sum, not a mean).
    """
    num_pred = pred_params.shape[1]
    num_target = target_pieces.shape[1]
    n = min(num_pred, num_target)
    # Use explicit start offsets rather than `-n:`; with n == 0 the slice
    # `-0:` would select every piece instead of none.
    dist = (pred_params[:, num_pred - n:, :] -
            target_pieces[:, num_target - n:, :])
    return jnp.sum(dist ** 2, axis=(1, 2))
