# coding: utf-8

import functools
import pathlib

import torch

import tools

# ---------------------------------------------------------------------------- #
# Automated attack loader

def register(name, unchecked, check):
  """ Simple registration-wrapper helper, storing the attack in 'attacks' under 'name'.
  Args:
    name      Attack name (if already registered, a warning is emitted and nothing is registered)
    unchecked Associated function (see module description), called with keyword 'f_real' and the other keyword arguments
    check     Parameter validity check function, called with the same keyword arguments;
              returns None when the parameters are valid, anything else is reported as the error message
  Notes:
    The function stored in 'attacks' is the checked wrapper when '__debug__' is set, 'unchecked' itself otherwise.
    Either way, the stored function exposes the 'check', 'checked' and 'unchecked' attributes.
  """
  global attacks
  # Check if name already in use (first registration wins)
  if name in attacks:
    tools.warning(f"Unable to register {name!r} attack: name already in use")
    return
  # Closure wrapping the call with checks; 'wraps' keeps the attack's name/docstring for introspection
  @functools.wraps(unchecked)
  def checked(f_real, **kwargs):
    # Check parameter validity
    message = check(f_real=f_real, **kwargs)
    if message is not None:
      raise tools.UserException(f"Attack {name!r} cannot be used with the given parameters: {message}")
    # Attack
    res = unchecked(f_real=f_real, **kwargs)
    # Validate the returned value explicitly (not with 'assert'), so that 'checked' keeps checking even under 'python -O'
    if not (isinstance(res, list) and len(res) == f_real):
      raise AssertionError(f"Expected attack {name!r} to return a list of {f_real} Byzantine gradients, got {res!r}")
    return res
  # Select which function to call by default
  func = checked if __debug__ else unchecked
  # Bind all the (sub) functions to the selected function
  setattr(func, "check", check)
  setattr(func, "checked", checked)
  setattr(func, "unchecked", unchecked)
  # Export the selected function with the associated name
  attacks[name] = func

# Registered attacks (mapping name -> attack), filled by 'register'
attacks = {}

# Load native and all local modules
# NOTE(review): 'tools.import_directory' presumably imports the sibling modules of this file into globals(),
# and those modules are expected to call 'register' — confirm in 'tools'
with tools.Context("attacks", None):
  tools.import_directory(pathlib.Path(__file__).parent, globals())

# Bind/overwrite the attack names with the associated attacks in globals()
# ('update' instead of a module-level loop: a loop would leave stray 'name'/'attack' module attributes behind,
#  which would also clobber any attack registered under one of those two names)
globals().update(attacks)
