import tensorflow as tf

_ALGORITHMS = dict()


def register(name):
  """Return a decorator that records a factory in the algorithm registry.

  Usage::

    @register('my_algo')
    def build_my_algo(args, G, F, X, Y, strategy):
      ...

  Args:
    name: key under which the decorated callable is stored. `get_algorithm`
      looks factories up by `args.algorithm`, so this should match the value
      users pass there.

  Returns:
    A decorator that stores its argument in `_ALGORITHMS[name]` and returns
    the argument unchanged, so the decorated function stays usable directly.
    Registering a second callable under an existing `name` silently replaces
    the first one.
  """

  def add_to_dict(fn):
    # Only the dict's contents are mutated; the module-level name is never
    # rebound, so no `global` declaration is needed here.
    _ALGORITHMS[name] = fn
    return fn

  return add_to_dict


def get_algorithm(args, G, F, X, Y, strategy: tf.distribute.Strategy = None):
  """Build the algorithm selected by `args.algorithm`.

  Looks up the factory registered under `args.algorithm` (see `register`)
  and calls it with every argument forwarded positionally, in the order
  `(args, G, F, X, Y, strategy)`.

  Args:
    args: configuration object. Only its `algorithm` attribute is read here;
      the whole object is forwarded to the factory.
    G, F, X, Y: forwarded unchanged to the registered factory. Their meaning
      is defined by the factories, not by this function.
    strategy: optional distribution strategy forwarded to the factory;
      defaults to `None`.

  Returns:
    Whatever the registered factory returns.

  Raises:
    KeyError: if no factory is registered under `args.algorithm`. The
      message lists the currently registered names to aid debugging.
  """
  try:
    factory = _ALGORITHMS[args.algorithm]
  except KeyError:
    # map(str, ...) keeps the listing robust even if non-string names were
    # registered; `from None` hides the uninformative inner lookup error.
    available = ', '.join(sorted(map(str, _ALGORITHMS))) or '<none>'
    raise KeyError(f'algorithm {args.algorithm} not found; '
                   f'registered algorithms: {available}') from None

  # The call is deliberately outside the `try` so a KeyError raised inside
  # the factory itself is not misreported as an unknown algorithm.
  return factory(args, G, F, X, Y, strategy)
