# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Combines all steps of compiling a RASP program."""

from typing import Set

from tracr.compiler import assemble
from tracr.compiler import basis_inference
from tracr.compiler import craft_graph_to_model
from tracr.compiler import craft_model_to_transformer
from tracr.compiler import expr_to_craft_graph
from tracr.compiler import rasp_to_graph
from tracr.craft import bases
from tracr.rasp import rasp

# Default names of the special BOS and PAD tokens that the compiler adds on
# top of the user vocab (defaults for `compile_rasp_to_model`'s
# `compiler_bos` / `compiler_pad` arguments).
COMPILER_BOS, COMPILER_PAD = "compiler_bos", "compiler_pad"


def compile_rasp_to_model(
    program: rasp.SOp,
    vocab: Set[rasp.Value],
    max_seq_len: int,
    causal: bool = False,
    compiler_bos: str = COMPILER_BOS,
    compiler_pad: str = COMPILER_PAD,
    mlp_exactness: int = 100) -> assemble.AssembledTransformerModel:
  """Compile a RASP program to transformer weights.

  The compilation runs these stages in order:
    1. extract a computational graph from the RASP program;
    2. run basis inference over that graph using `vocab` and `max_seq_len`;
    3. attach craft components to the graph nodes;
    4. assemble the graph into a craft model;
    5. convert the craft model into an assembled transformer.

  Args:
    program: the RASP program to compile.
    vocab: the set of vocab tokens expected by RASP.
    max_seq_len: the maximum sequence length for the compiled model.
    causal: if True, outputs a model with causal masking.
    compiler_bos: the name of the special BOS token that will be added by the
      compiler. Must not be present in the vocab and must differ from
      `compiler_pad`.
    compiler_pad: the name of the special PAD token that will be added by the
      compiler. Must not be present in the vocab and must differ from
      `compiler_bos`.
    mlp_exactness: Controls the approximation of the MLP layers. In theory,
      larger values yield a better approximation. But too large values can cause
      numerical issues due to large parameter norms. Reasonable values are
      between 1 and 100.

  Returns:
    The compiled model.

  Raises:
    ValueError: if `compiler_bos` or `compiler_pad` is present in `vocab`, or
      if `compiler_bos` and `compiler_pad` are the same token.
  """

  if compiler_bos in vocab:
    raise ValueError("Compiler BOS token must not be present in the vocab. "
                     f"Found '{compiler_bos}' in {vocab}")

  if compiler_pad in vocab:
    raise ValueError("Compiler PAD token must not be present in the vocab. "
                     f"Found '{compiler_pad}' in {vocab}")

  # Both names are forwarded to the transformer conversion as distinct special
  # tokens. If they are equal, they name the same token, and padding could no
  # longer be distinguished from BOS. Fail loudly instead of miscompiling.
  if compiler_bos == compiler_pad:
    raise ValueError("Compiler BOS and PAD tokens must be different. "
                     f"Got '{compiler_bos}' for both.")

  extracted = rasp_to_graph.extract_rasp_graph(program)
  graph, sources, sink = extracted.graph, extracted.sources, extracted.sink

  # The return value is discarded and `graph` is used afterwards, so this
  # presumably annotates the graph nodes in place.
  basis_inference.infer_bases(
      graph,
      sink,
      vocab,
      max_seq_len,
  )

  # The BOS token is represented as a basis direction in the `tokens` space,
  # keyed by the compiler BOS name.
  expr_to_craft_graph.add_craft_components_to_rasp_graph(
      graph,
      bos_dir=bases.BasisDirection(rasp.tokens.label, compiler_bos),
      mlp_exactness=mlp_exactness,
  )

  craft_model = craft_graph_to_model.craft_graph_to_model(graph, sources)

  return craft_model_to_transformer.craft_model_to_transformer(
      craft_model=craft_model,
      graph=graph,
      sink=sink,
      max_seq_len=max_seq_len,
      causal=causal,
      compiler_bos=compiler_bos,
      compiler_pad=compiler_pad,
  )
