"""Various metrics re-written in tf for speed."""
import tensorflow as tf
import tensorflow_probability as tfp

###############################################################################
# Spearman corr-coeff.
#
# NOTE: The results are close to, but can differ slightly from, scipy's implementation (scipy.stats.spearmanr).


@tf.function
def _rank(x):
    """Return the 1-based ranks of the entries of the 1-D tensor ``x``.

    Tied values receive the average of the ranks they span (the same
    convention as ``scipy.stats.rankdata``'s default ``method="average"``).
    The result therefore does not depend on how an unstable sort happens to
    order equal elements.

    Args:
        x: 1-D tensor of values to rank.

    Returns:
        Tensor of shape ``(n, 1)``, where ``n = tf.shape(x)[0]``, whose row
        ``i`` holds the rank of ``x[i]``. The dtype is ``x.dtype`` when ``x``
        is floating point and ``tf.float32`` otherwise, because the average
        rank of a group of ties can be a half-integer.
    """
    # TODO: Vectorize this.
    rank_dtype = x.dtype if x.dtype.is_floating else tf.float32
    sorted_x = tf.sort(x, axis=-1, direction="ASCENDING")
    # For each x[i]: how many entries are strictly smaller, and how many are
    # smaller or equal.
    num_less = tf.searchsorted(sorted_x, x, side="left")
    num_less_equal = tf.searchsorted(sorted_x, x, side="right")
    # The ties of x[i] occupy the 1-based ranks num_less + 1 .. num_less_equal.
    # Their mean is (num_less + 1 + num_less_equal) / 2. Without ties this
    # reduces to the ordinary rank num_less + 1.
    ranks = tf.cast(num_less + num_less_equal + 1, rank_dtype) / 2
    # Keep the original (n, 1) layout expected by tfp.stats.correlation callers.
    return ranks[:, None]


@tf.function
def spearmanr_vv(x, y):
    """Spearman rank correlation between two vectors.

    Ranks both inputs with ``_rank`` and returns the Pearson correlation of
    those ranks as a scalar tensor.

    Args:
        x: 1-D tensor of samples.
        y: 1-D tensor of samples; presumably the same length as ``x``
            (TODO confirm callers guarantee this).

    Returns:
        Scalar tensor holding the correlation coefficient.
    """
    x_ranks = _rank(x)
    y_ranks = _rank(y)
    # The rank tensors are column vectors, so the correlation comes back as a
    # single-element matrix; flatten it down to a scalar.
    corr_matrix = tfp.stats.correlation(x_ranks, y_ranks)
    return tf.reshape(corr_matrix, [])


@tf.function
def spearmanr_mv(X, y):
    """Spearman rank correlation of every column of ``X`` against ``y``.

    Args:
        X: Matrix whose columns are correlated with ``y``. Assumed 2-D with
            shape ``(n_samples, n_columns)`` — TODO confirm.
        y: 1-D tensor of ``n_samples`` values.

    Returns:
        Tensor with one Spearman coefficient per column of ``X``, stacked
        along axis 0 by ``tf.vectorized_map``.
    """
    # TODO: Vectorize only the rank computation, call the tfp.stats.correlation on the matrices.
    def _column_against_y(column):
        return spearmanr_vv(column, y)

    # Transposing turns columns into rows, which vectorized_map iterates over.
    columns = tf.transpose(X)
    return tf.vectorized_map(_column_against_y, columns)


# @tf.function
# def spearmanr_mv(X, y):
#     # TODO: Vectorize only the rank computation, call the tfp.stats.correlation on the matrices.
#     rk_X = tf.vectorized_map(lambda x: tf.squeeze(_rank(x), axis=-1), tf.transpose(X))
#     rk_X = tf.transpose(rk_X)
#     return tf.squeeze(tfp.stats.correlation(rk_X, _rank(y)), axis=-1)


# def spearmanr(x, y, dtype=tf.float32):
#     x = _set_up_spearmanr(x, dtype)
#     y = _set_up_spearmanr(y, dtype)
#     return _spearmanr(x, y)
