# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for using gin configurations with T5X binaries."""
import os
from typing import Optional, Sequence

from absl import app
from absl import logging
from clu import metric_writers
import gin
import jax
import tensorflow as tf



def parse_gin_flags(gin_search_paths: Sequence[str],
                    gin_files: Sequence[str],
                    gin_bindings: Sequence[str],
                    skip_unknown: bool = False,
                    finalize_config: bool = True):
  """Parses the provided gin files and binding overrides.

  Args:
    gin_search_paths: paths that will be searched for gin files.
    gin_files: paths to gin config files to be parsed. Files are parsed in
      order, with conflicting settings being overridden by later files. Paths
      may be relative to paths in `gin_search_paths`.
    gin_bindings: individual gin bindings to be applied after the gin files are
      parsed. They are applied in order, with conflicting settings being
      overridden by later ones.
    skip_unknown: whether to ignore unknown bindings or raise an error (default
      behavior).
    finalize_config: whether to finalize the config so that it cannot be
      modified (default behavior).
  """
  # t5.data is imported here because it includes gin configurable functions
  # commonly used by task modules.
  # TODO(adarob): Strip gin from t5.data and remove this import.
  import t5.data  # pylint:disable=unused-import,g-import-not-at-top

  # Make every search path known to gin before any config file is resolved.
  for search_path in gin_search_paths:
    gin.add_config_file_search_path(search_path)

  # Parse the config files first, then the individual bindings passed via flag.
  gin.parse_config_files_and_bindings(
      gin_files,
      gin_bindings,
      skip_unknown=skip_unknown,
      finalize_config=finalize_config)

  # Log the resulting configuration.
  logging.info('Gin Configuration:\n%s', gin.config_str())


def rewrite_gin_args(args: Sequence[str]) -> Sequence[str]:
  """Rewrites `--gin.NAME=VALUE` flags to `--gin_bindings=NAME = VALUE`.

  Args:
    args: command-line arguments. Arguments that do not start with `--gin.` are
      passed through unchanged.

  Returns:
    A new list with the same number of entries as `args`, in the same order,
    where every `--gin.` argument has been rewritten. Each rewritten argument
    is also printed to stdout.

  Raises:
    ValueError: if an argument starts with `--gin.` but contains no `=`.
  """
  gin_prefix = '--gin.'
  rewritten = []
  for raw_arg in args:
    if raw_arg.startswith(gin_prefix):
      if '=' not in raw_arg:
        raise ValueError(
            "Gin bindings must be of the form '--gin.<param>=<value>', got: " +
            raw_arg)
      # Drop the prefix and split on the first '=' only, so that values which
      # themselves contain '=' are kept intact.
      name, _, value = raw_arg[len(gin_prefix):].partition('=')
      raw_arg = f'--gin_bindings={name} = {value}'
      print(f'Rewritten gin arg: {raw_arg}')
    rewritten.append(raw_arg)
  return rewritten


@gin.register
def summarize_gin_config(model_dir: str,
                         summary_writer: Optional[metric_writers.MetricWriter],
                         step: int):
  """Writes the gin config to the model dir and, optionally, a text summary.

  Nothing is done unless `jax.process_index()` is 0.

  Args:
    model_dir: directory in which `config.gin` is written; it is created if
      needed.
    summary_writer: optional metric writer. When not None, the config is also
      written as a text summary under the key 'config' and the writer is
      flushed.
    step: step passed to `summary_writer.write_texts`.
  """
  # Only the process with index 0 writes anything.
  if jax.process_index() != 0:
    return

  config_str = gin.config_str()
  tf.io.gfile.makedirs(model_dir)

  # Save the config as gin text.
  config_path = os.path.join(model_dir, 'config.gin')
  with tf.io.gfile.GFile(config_path, 'w') as config_file:
    config_file.write(config_str)

  if summary_writer is None:
    return

  # Also include the config, passed through `gin.markdown`, as a text summary.
  summary_writer.write_texts(step, {'config': gin.markdown(config_str)})
  summary_writer.flush()


def run(main):
  """Runs `main` via `app.run`, rewriting gin args before flags are parsed.

  Args:
    main: the main function, passed to `app.run` unchanged.
  """

  def _parse_rewritten_flags(argv):
    # Translate `--gin.NAME=VALUE` arguments first, then parse the flags.
    return app.parse_flags_with_usage(rewrite_gin_args(argv))

  app.run(main, flags_parser=_parse_rewritten_flags)


# ====================== Configurable Utility Functions ======================


@gin.configurable
def sum_fn(var1=gin.REQUIRED, var2=gin.REQUIRED):
  """Returns `var1 + var2`; intended for use inside gin files.

  Args:
    var1: left operand; marked `gin.REQUIRED`.
    var2: right operand; marked `gin.REQUIRED`.

  Returns:
    The result of `var1 + var2`.
  """
  total = var1 + var2
  return total


@gin.configurable
def bool_fn(var1=gin.REQUIRED):
  """Returns `bool(var1)`; intended for use inside gin files.

  Args:
    var1: value to convert; marked `gin.REQUIRED`.

  Returns:
    The truth value of `var1` as a `bool`.
  """
  as_bool = bool(var1)
  return as_bool
