# Based on:
# https://bitbashing.io/comparing-floats.html
# https://www.mathworks.com/matlabcentral/answers/475832-ulp-measurement-implementation-and-flavors
import torch
import decompose_fp_cuda


def get_ulp_distance_cuda(x, y):
    """Return the element-wise distance between ``x`` and ``y`` in ULPs.

    [Reference] https://bitbashing.io/comparing-floats.html
    Adjacent floats of the same sign have adjacent integer values when their
    bit patterns are reinterpreted as integers (type punning), so the ULP
    distance between two same-signed floats is the absolute difference of
    those integers.

    Args:
        x: floating-point tensor. The ``<< 23`` below assumes an IEEE-754
            binary32 layout (23 fraction bits); presumably float32 tensors on
            a CUDA device, as required by ``decompose_fp_cuda`` -- TODO confirm.
        y: floating-point tensor compared element-wise against ``x``.

    Returns:
        An int32 tensor of ULP distances:
        * 0 where ``x == y`` (this includes ``+0.0`` vs ``-0.0``);
        * ``0x7FFFFFFF`` (max int32) where the signs differ;
        * ``|bits(x) - bits(y)|`` otherwise.
        NaN and inf are not special-cased, so their results are not meaningful.
    """
    # Output buffers for the sign, exponent and fraction fields. They are
    # written in place by decompose_fp_cuda.forward (its return value is not
    # used). zeros_like keeps the input's device/shape and avoids the extra
    # copy and float->int conversion of ``clone().int()``.
    s_x = torch.zeros_like(x, dtype=torch.int32)
    e_x = torch.zeros_like(x, dtype=torch.int32)
    f_x = torch.zeros_like(x, dtype=torch.int32)

    s_y = torch.zeros_like(y, dtype=torch.int32)
    e_y = torch.zeros_like(y, dtype=torch.int32)
    f_y = torch.zeros_like(y, dtype=torch.int32)

    decompose_fp_cuda.forward(s_x, e_x, f_x, x)
    decompose_fp_cuda.forward(s_y, e_y, f_y, y)

    # Type-punning of the magnitude: exponent above the 23 fraction bits.
    # Max value is (255 << 23) + (2**23 - 1) == 0x7FFFFFFF, so no int32 overflow.
    x_int = (e_x << 23) + f_x
    y_int = (e_y << 23) + f_y

    # ULP distance between magnitudes.
    # NOTE: NaN and inf are not meaningful here and are deliberately ignored.
    ulp_distance = (x_int - y_int).abs()

    # Differently-signed floats are not comparable: report the max distance.
    ulp_distance[s_x != s_y] = 0x7FFFFFFF  # max int32

    # Equal values are 0 ULPs apart. This must come after the sign mask so that
    # +0.0 and -0.0 (different sign bits, but equal) yield 0, as in the
    # reference implementation. NaN never compares equal, so it is unaffected.
    ulp_distance[x == y] = 0

    return ulp_distance

