def init_cache(diffloss_d, num_sample_step, cache_type):
    """Build the initial, empty per-index cache.

    Args:
        diffloss_d: Number of index slots to create; the slot keys are
            ``0 .. diffloss_d - 1`` (none if ``diffloss_d <= 0``).
        num_sample_step: Currently ignored; kept for call-site compatibility.
        cache_type: Currently ignored; kept for call-site compatibility.

    Returns:
        A dict with the single key ``-1`` mapping to a dict whose keys are
        ``range(diffloss_d)`` and whose values are distinct empty dicts.
    """
    # NOTE(review): an earlier version of this function also created one
    # entry per sample step and stored ``cache_type`` under 'cache_type';
    # the current contract only creates the ``-1`` entry.
    slots = {}
    for idx in range(diffloss_d):
        # A fresh dict per index so callers can fill slots independently.
        slots[idx] = {}
    return {-1: slots}
    

    