# based on 
# https://bitbashing.io/comparing-floats.html
# https://www.mathworks.com/matlabcentral/answers/475832-ulp-measurement-implementation-and-flavors
import torch
from prealign_utils import decompose_fp

# Tolerance named after FLT_EPSILON from the C standard library (<float.h>), which is
# the difference between 1.0 and the smallest representable value greater than 1.0.
# NOTE: C's FLT_EPSILON for IEEE-754 single precision is 2**-23 (~1.19e-7); the value
# used here (1e-5) is much looser — presumably a deliberate tolerance, confirm intent.
FLT_EPSILON = 0.00001


def get_ulp_distance(x, y):
    """Return the element-wise distance between ``x`` and ``y`` in ULPs.

    [Reference] https://bitbashing.io/comparing-floats.html
    Adjacent floats (of the same sign) have adjacent integer values when their
    bits are reinterpreted as integers ("type punning"), so the ULP distance is
    the absolute difference of those integer views.

    The integer view is rebuilt from ``decompose_fp`` as
    ``(exponent << 23) + fraction``. This assumes a 23-bit fraction field
    (IEEE-754 single precision); confirm that decompose_fp matches.

    Args:
        x, y: tensors to compare. They are assumed to be float32 with the same
            shape (TODO confirm against callers).

    Returns:
        Tensor of ULP distances. Elements whose sign bits differ are set to
        0x7FFFFFFF (max int32). The exception is values that compare equal,
        such as +0.0 and -0.0, which are 0 ULPs apart.

    NOTE: NaN and inf are not handled specially, so their distances are not
    meaningful.
    """
    s_x, e_x, f_x = decompose_fp(x)
    s_y, e_y, f_y = decompose_fp(y)

    # type-punning of float value: exponent bits above the 23 fraction bits
    x_int = (e_x << 23) + f_x
    y_int = (e_y << 23) + f_y

    # magnitude distance in ULPs (sign is handled separately below)
    ulp_distance = (x_int - y_int).abs()

    # Max distance for differently-signed floats
    ulp_distance[s_x != s_y] = 0x7FFFFFFF  # max int32

    # Values that compare equal are 0 ULPs apart. This must run after the sign
    # override so that +0.0 vs -0.0 (different sign bits, equal values) gives 0.
    ulp_distance[x == y] = 0

    return ulp_distance

def relativelyEqual(x, y, maxRelativeDiff=1e-5):
    """Element-wise test that ``x`` and ``y`` agree within a relative tolerance.

    [Reference] https://bitbashing.io/comparing-floats.html
    An element counts as equal when
    ``|x - y| <= maxRelativeDiff * max(|x|, |y|)``, so the allowed gap grows with
    the larger of the two magnitudes.

    NaN inputs compare unequal. Equal infinities also compare unequal, because
    inf - inf is NaN.

    Returns:
        Boolean tensor with the broadcast shape of ``x`` and ``y``.
    """
    tolerance = torch.maximum(x.abs(), y.abs()) * maxRelativeDiff
    return torch.abs(x - y) <= tolerance


def get_relative_difference(x, y):
    """Return the element-wise relative difference ``|x - y| / min(|x|, |y|)``.

    [Reference] https://bitbashing.io/comparing-floats.html
    The difference is scaled by the *smaller* magnitude, which makes this the
    stricter (larger) of the two common relative-difference definitions.

    Identical elements always yield 0. Without this rule, two zeros would give
    0/0 = NaN. If exactly one of the pair is 0, the result is inf.

    Returns:
        Floating-point tensor with the broadcast shape of ``x`` and ``y``.
    """
    diff = (x - y).abs()
    rel = diff / torch.min(x.abs(), y.abs())
    # 0 vs 0 divides 0 by 0 (NaN); equal values differ by nothing, so report 0.
    return rel.masked_fill(diff == 0, 0)

def get_relative_error(answer, x):
    """Element-wise relative error of ``x`` with respect to the reference ``answer``.

    [Reference] https://bitbashing.io/comparing-floats.html
    Computes ``|(answer - x) / answer|``. Where ``answer`` is 0 that ratio is
    undefined, so those elements are patched instead:

      * 0 where ``x`` equals ``answer``;
      * ``|x|`` (the absolute error) otherwise.

    Note: ``x`` is indexed with a boolean mask built from ``answer``, so the two
    are expected to have the same shape.
    """
    ref_is_zero = answer.eq(0)
    rel_err = (answer - x).div(answer).abs()

    # zero reference, exact match -> no error
    rel_err[ref_is_zero & answer.eq(x)] = 0

    # zero reference, x differs -> fall back to absolute error
    differs = ref_is_zero & answer.ne(x)
    rel_err[differs] = x[differs].abs()

    return rel_err

def get_epsilon_difference(x, y, maxRelativeDiff=1e-5):
    """Express the relative difference of ``x`` and ``y`` in units of ``maxRelativeDiff``.

    A result <= 1 means the pair is within ``maxRelativeDiff`` of each other,
    measured relative to the smaller magnitude (see get_relative_difference).
    """
    rel_diff = get_relative_difference(x, y)
    return rel_diff / maxRelativeDiff


# get l2 norm error
def get_l2_norm_error(reference, data, *, eps=1e-7):
    """Return the relative L2-norm error ``||reference - data|| / ||reference||``.

    Args:
        reference: reference tensor.
        data: tensor compared against ``reference``. It must broadcast with it.
        eps: keyword-only threshold on the *squared* L2 norm (sum of squares) of
            ``reference``. Below it, the reference is treated as zero. The
            default 1e-7 matches the previously hard-coded value.

    Returns:
        A 0-dim tensor holding the relative error. When the reference norm is
        treated as zero, an error message is printed and the Python int ``0``
        is returned instead.
    """
    diff = reference - data
    err_sq = torch.mul(diff, diff).sum()
    ref_sq = torch.mul(reference, reference).sum()

    # NOTE: eps is compared against the squared norm. With the default, any
    # reference whose norm is below sqrt(1e-7) (~3.2e-4) is reported as zero.
    if ref_sq.abs() < eps:
        print("ERROR, reference l2-norm is 0")
        return 0
    return err_sq.sqrt() / ref_sq.sqrt()

