from absl import flags

from influence.full_last import compute_influence_full_last

# Command-line flags controlling influence estimation. Both are enum flags,
# so only the values listed in ``enum_values`` are accepted.
flags.DEFINE_enum(
    name='if_method',
    default='full_last',
    enum_values=['full_last'],
    help='method to estimate influence function',
)
# Not read by the code in this module section; presumably consumed by the
# estimator implementations — confirm before removing or renaming.
flags.DEFINE_enum(
    name='inv_method',
    default='base',
    enum_values=['base'],
    help='matrix inverse computation method',
)
FLAGS = flags.FLAGS

def compute_influence(state, data_loader, data_loader_hess, num_classes, num_samples, *args, **kwargs):
    """Estimate influence using the method selected by the ``--if_method`` flag.

    This is a thin dispatcher. Every argument is forwarded unchanged to the
    selected estimator, and the estimator's return value is returned as-is.

    Args:
        state: Forwarded to the estimator (presumably the model/training
            state — confirm against the estimator implementations).
        data_loader: Forwarded to the estimator.
        data_loader_hess: Forwarded to the estimator; presumably the data used
            for the Hessian-related computation — confirm.
        num_classes: Forwarded to the estimator.
        num_samples: Forwarded to the estimator.
        *args: Extra positional arguments forwarded to the estimator.
        **kwargs: Extra keyword arguments forwarded to the estimator.

    Returns:
        Whatever the selected estimator returns.

    Raises:
        ValueError: If ``FLAGS.if_method`` has no registered estimator, e.g. a
            new enum value was added without a matching entry below.
    """
    # Explicit registry keyed by the values accepted by --if_method. Referencing
    # the implementations directly (instead of looking them up by name in
    # globals()) keeps their imports visibly used, so linters/auto-fixers will
    # not strip them, and makes the supported set easy to audit.
    estimators = {
        'full_last': compute_influence_full_last,
    }
    method = FLAGS.if_method
    try:
        estimate_method = estimators[method]
    except KeyError:
        raise ValueError(
            f'Unknown influence estimation method {method!r}; '
            f'expected one of {sorted(estimators)}'
        ) from None
    return estimate_method(state, data_loader, data_loader_hess, num_classes, num_samples, *args, **kwargs)