

def calc_time_sample_prob(time_diff, alpha=-0.01, min_time_diff=0):
    """Return exponential time-decay weights ``exp(alpha * time_diff)``.

    Entries whose ``time_diff`` is not strictly greater than ``min_time_diff``
    get weight 0. With the default negative ``alpha`` the weight decays from
    1 toward 0 as ``time_diff`` grows.

    Args:
        time_diff: Tensor of time differences of any shape; it is flattened.
            Integer (or bool) tensors are promoted to float32, since the
            fractional weights cannot be stored in an integer buffer.
        alpha: Rate multiplied into ``time_diff`` before ``exp``.
        min_time_diff: Exclusive lower threshold; only entries with
            ``time_diff > min_time_diff`` receive a nonzero weight.

    Returns:
        1-D tensor with ``time_diff.numel()`` elements, on the same device as
        ``time_diff``, with the input's floating dtype (float32 for
        non-floating input).
    """
    # reshape (not view) so non-contiguous inputs are accepted as well.
    time_diff = time_diff.reshape(-1)
    if not time_diff.is_floating_point():
        # new_zeros inherits dtype; an integer buffer would reject the
        # float exp() values on masked assignment.
        time_diff = time_diff.float()
    prob = time_diff.new_zeros(time_diff.size())
    mask = time_diff > min_time_diff
    # Evaluate exp only on the selected entries so masked-out values
    # (e.g. large negative differences) cannot overflow.
    prob[mask] = (time_diff[mask] * alpha).exp()
    return prob

