import numpy as np

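'''
Each scaler is a function that takes as input the features h (e.g. B x N x Din),
the quantity D used for scaling (presumably node degrees), and avg_d (a dictionary
containing averages over the training set; the amplification and attenuation
scalers read avg_d["log"]), and returns the scaled features as output.
h and D must be broadcast-compatible (e.g. D a scalar, or an array shaped to
broadcast against h).
'''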


def scale_identity(h, D=None, avg_d=None):
    """No-op scaler: hand back ``h`` exactly as received.

    ``D`` and ``avg_d`` are ignored; they exist only so this function shares
    the calling convention of the other scalers.
    """
    unscaled = h
    return unscaled


def scale_amplification(h, D, avg_d):
    """Amplify ``h`` by the log of ``D``, normalised by its training average.

    The result is ``h * (log(D + 1) / d)`` with ``d = avg_d["log"]``, the
    training-set average of ``log(D + 1)``.
    """
    log_term = np.log(D + 1)
    factor = log_term / avg_d["log"]
    return h * factor


def scale_attenuation(h, D, avg_d):
    """Attenuate ``h`` by the inverse log of ``D``, normalised by the training average.

    Computes ``d / log(D + 1) * h``, where ``d = avg_d["log"]`` is the average
    of ``log(D + 1)`` over the training set. For ``D >= 1`` the factor is the
    reciprocal of the one used by ``scale_amplification``.

    ``D`` is clamped to at least 1 before taking the log. For ``D == 0`` the
    denominator ``log(1)`` would be zero, which yields ``inf`` (and ``nan`` once
    multiplied by zero-valued features). Results for ``D >= 1`` are unchanged.

    Args:
        h: values to scale; must broadcast against ``D``.
        D: scalar or array, presumably non-negative node degrees.
        avg_d: mapping that provides the key ``"log"``.

    Returns:
        ``h * (avg_d["log"] / log(max(D, 1) + 1))``.
    """
    # Clamp so that log(D + 1) >= log(2) > 0 and the division is always finite.
    D = np.maximum(D, 1)
    return h * (avg_d["log"] / np.log(D + 1))


# Registry mapping scaler names to their implementations.
SCALERS = dict(
    identity=scale_identity,
    amplification=scale_amplification,
    attenuation=scale_attenuation,
)
