def cache_init(model_kwargs, num_steps, num_layers=28):
    '''
    Initialization for cache.

    Builds three identically shaped nested tables ('cache', 'Taylor_momentum',
    'momentum_beta') plus the scalar bookkeeping that accompanies them.

    Each table maps a step key to an inner dict. The step keys are -1, followed
    by 0..num_steps-1, in that insertion order. Each inner dict maps
    0..num_layers-1 to its own fresh empty dict.

    Args:
        model_kwargs: mapping that must provide the keys 'interval',
            'max_order', 'test_FLOPs', 'a', 'b' and 'gamma'.
        num_steps: number of regular step slots (0..num_steps-1) to allocate
            in addition to the extra slot -1.
        num_layers: number of inner slots per step. Defaults to 28, the value
            that was previously hard-coded. Presumably one per transformer
            block; confirm against the model depth.

    Returns:
        (cache_dic, current):
          - cache_dic: the three tables, 'flops' = 0.0, 'first_enhance' = 2,
            'cache_counter' = 0, and 'interval' / 'max_order' / 'test_FLOPs'
            copied from model_kwargs.
          - current: 'num_steps', 'activated_steps' = [49], and 'a' / 'b' /
            'gamma' copied from model_kwargs.

    Raises:
        KeyError: if model_kwargs lacks one of the required keys.
    '''
    # Slot -1 comes first, then 0..num_steps-1, so key order matches what
    # callers iterating these dicts have always seen.
    step_keys = [-1, *range(num_steps)]

    def _empty_table():
        # Build a fresh {step: {idx: {}}} table. Every leaf is a distinct dict,
        # so writing into one slot never leaks into another (or another table).
        return {
            step: {idx: {} for idx in range(num_layers)}
            for step in step_keys
        }

    cache_dic = {
        'cache':           _empty_table(),
        'flops':           0.0,
        'interval':        model_kwargs['interval'],
        'max_order':       model_kwargs['max_order'],
        'test_FLOPs':      model_kwargs['test_FLOPs'],
        'first_enhance':   2,
        'cache_counter':   0,
        'Taylor_momentum': _empty_table(),
        'momentum_beta':   _empty_table(),
    }

    current = {
        'num_steps': num_steps,
        # NOTE(review): 49 equals num_steps - 1 only when num_steps == 50.
        # Kept unchanged to preserve behaviour. Confirm whether it should be
        # derived from num_steps instead.
        'activated_steps': [49],
        'a': model_kwargs['a'],
        'b': model_kwargs['b'],
        'gamma': model_kwargs['gamma'],
    }

    return cache_dic, current
    