# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MLPs to compute functions of categorical (one-hot encoded) variables."""

from typing import Callable

import numpy as np

from tracr.craft import bases
from tracr.craft import transformers
from tracr.craft import vectorspace_fns

# Default for the `one_space` argument of `sequence_map_categorical_mlp`: a
# space built from the single name "one".
_ONE_SPACE = bases.VectorSpaceWithBasis.from_names(["one"])


def map_categorical_mlp(
    input_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    operation: Callable[[bases.BasisDirection], bases.BasisDirection],
) -> transformers.MLP:
  """Builds an MLP encoding a categorical function f(x) of a single variable.

  The first layer acts as a lookup table: every basis direction d of
  `input_space` is sent to the basis vector of `output_space` for
  `operation(d)`. A direction whose image does not lie in `output_space` is
  sent to the null vector instead. The second layer is
  `vectorspace_fns.project(output_space, output_space)`.

  Args:
    input_space: space containing the input x.
    output_space: space containing the possible outputs.
    operation: function mapping an input basis direction to an output basis
      direction.

  Returns:
    A `transformers.MLP` built from the two layers described above.
  """

  def lookup(direction):
    # Anything outside the input space contributes nothing to the output.
    if direction not in input_space:
      return output_space.null_vector()
    target = operation(direction)
    # Results that fall outside the output space are dropped as well.
    if target not in output_space:
      return output_space.null_vector()
    return output_space.vector_from_basis_direction(target)

  lookup_layer = vectorspace_fns.Linear.from_action(
      input_space, output_space, lookup)
  passthrough_layer = vectorspace_fns.project(output_space, output_space)

  return transformers.MLP(lookup_layer, passthrough_layer)


def map_categorical_to_numerical_mlp(
    input_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    operation: Callable[[bases.Value], float],
) -> transformers.MLP:
  """Builds an MLP computing f(x) from a categorical to a numerical variable.

  The first layer sends every basis direction d of `input_space` to
  `operation(d.value)` times the single basis vector of `output_space`, i.e.
    output = sum(f(i.value)*input_i for all i in input space)
  The second layer is `vectorspace_fns.project(output_space, output_space)`.

  Args:
    input_space: Vector space containing the input x.
    output_space: Vector space to write the numerical output to. It is checked
      with `bases.ensure_dims(..., num_dims=1)`.
    operation: A function applied to the *value* of each input basis direction
      (not to the direction itself), returning the number to write.

  Returns:
    A `transformers.MLP` built from the two layers described above.
  """
  bases.ensure_dims(output_space, num_dims=1, name="output_space")
  unit_output = output_space.vector_from_basis_direction(output_space.basis[0])

  def scaled_output(direction):
    # Anything outside the input space contributes nothing to the output.
    if direction not in input_space:
      return output_space.null_vector()
    return operation(direction.value) * unit_output

  lookup_layer = vectorspace_fns.Linear.from_action(
      input_space, output_space, scaled_output)
  passthrough_layer = vectorspace_fns.project(output_space, output_space)

  return transformers.MLP(lookup_layer, passthrough_layer)


def sequence_map_categorical_mlp(
    input1_space: bases.VectorSpaceWithBasis,
    input2_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    operation: Callable[[bases.BasisDirection, bases.BasisDirection],
                        bases.BasisDirection],
    one_space: bases.VectorSpaceWithBasis = _ONE_SPACE,
    hidden_name: bases.Name = "__hidden__",
) -> transformers.MLP:
  """Builds an MLP encoding a categorical function f(x, y) of two variables.

  There is one hidden dimension per pair (i, j) of basis directions, with i
  taken from `input1_space` and j from `input2_space`. The first layer adds
  x_i and y_j and subtracts the `one_space` direction, so the hidden layer
  computes the logical and of the two inputs
    hidden_i_j = ReLU(x_i + y_j - 1)

  The second layer combines this with a lookup table
    output_k = sum(f(i, j)*hidden_i_j for all i,j in input space)
  where pairs whose image f(i, j) is not in `output_space` contribute nothing.

  Args:
    input1_space: Vector space containing the input x.
    input2_space: Vector space containing the input y.
    output_space: Vector space to write outputs to.
    operation: A function taking one basis direction from each input space
      and returning an output basis direction.
    one_space: a reserved 1-d space that always contains a 1.
    hidden_name: Name for hidden dimensions.

  Returns:
    A `transformers.MLP` built from the two layers described above.

  Raises:
    ValueError: if the two input spaces share a basis direction.
  """
  bases.ensure_dims(one_space, num_dims=1, name="one_space")

  if not set(input1_space.basis).isdisjoint(input2_space.basis):
    raise ValueError("Input spaces to a SequenceMap must be disjoint. "
                     "If input spaces are the same, use Map instead!")

  input_space = bases.direct_sum(input1_space, input2_space, one_space)

  def to_hidden(x, y):
    # The hidden direction's value records both source directions so that the
    # second layer can recover them.
    return bases.BasisDirection(hidden_name, (x.name, x.value, y.name, y.value))

  hidden_space = bases.VectorSpaceWithBasis([
      to_hidden(left, right)
      for left in input1_space.basis
      for right in input2_space.basis
  ])

  def sum_of_hidden(pairs):
    # Adds up the hidden basis vectors for every (x, y) pair in `pairs`.
    total = hidden_space.null_vector()
    for left, right in pairs:
      total += hidden_space.vector_from_basis_direction(to_hidden(left, right))
    return total

  def logical_and(direction):
    if direction in one_space:
      # The constant direction supplies the "-1" on every hidden dimension.
      return bases.VectorInBasis(hidden_space.basis,
                                 -np.ones(hidden_space.num_dims))
    if direction in input1_space:
      return sum_of_hidden(
          (direction, right) for right in input2_space.basis)
    # Any remaining direction is treated as belonging to input2_space.
    return sum_of_hidden((left, direction) for left in input1_space.basis)

  first_layer = vectorspace_fns.Linear.from_action(input_space, hidden_space,
                                                   logical_and)

  def operation_fn(hidden_direction):
    # Unpack the pair of input directions stored by `to_hidden`.
    x_name, x_value, y_name, y_value = hidden_direction.value
    target = operation(
        bases.BasisDirection(x_name, x_value),
        bases.BasisDirection(y_name, y_value),
    )
    if target not in output_space:
      return output_space.null_vector()
    return output_space.vector_from_basis_direction(target)

  second_layer = vectorspace_fns.Linear.from_action(hidden_space, output_space,
                                                    operation_fn)

  return transformers.MLP(first_layer, second_layer)
