from typing import Dict
import torch

def momentum_update(cache_dic: Dict, current: Dict, feature: torch.Tensor, updated_taylor_factors):
    """
    Update the per-module momentum of the Taylor factors and store it in the cache.

    Builds a dict keyed by order. Key 0 always holds ``feature``. Keys
    1..``cache_dic['max_order']`` hold the momentum for the corresponding order.
    The dict replaces ``cache_dic['Taylor_momentum'][-1][stream][layer][module]``.

    * Warm-up (``step <= first_enhance - 2``): order 1 is ``feature`` and every
      higher order is 0.
    * Afterwards, each order is
      ``(beta * previous[order] + (1 - beta) * updated_taylor_factors[order]) / (1 - |beta|**step)``.
      Filling stops at the first order with no previous momentum, so the
      stored dict can hold fewer than ``max_order + 1`` entries.

    :param cache_dic: Cache dictionary. Reads 'momentum_beta', 'max_order',
        'first_enhance' and 'Taylor_momentum'; writes 'Taylor_momentum'.
    :param current: Current step information. Reads 'stream', 'layer',
        'module' and 'step'.
    :param feature: Value stored at order 0, and at order 1 during warm-up.
    :param updated_taylor_factors: Indexable by order (1..max_order). Holds the
        freshly computed factors that are blended into the momentum after warm-up.
    """
    stream, layer, module = current['stream'], current['layer'], current['module']
    beta = cache_dic['momentum_beta'][-1][stream][layer][module]

    # Per-layer container of momentum dicts. The entry for this module may not
    # exist yet (e.g. on the very first call), so it is read with .get().
    layer_cache = cache_dic['Taylor_momentum'][-1][stream][layer]
    previous = layer_cache.get(module)

    t = current['step']
    warming_up = t <= cache_dic['first_enhance'] - 2
    # Bias-correction divisor, similar to Adam's 1 - beta**t.
    # NOTE(review): the corrected value is what gets stored and fed back on the
    # next step (Adam keeps the raw EMA instead); it is also 0 when t == 0 or
    # |beta| == 1 — confirm both are intended.
    fix = 1 - abs(beta) ** t

    momentum = {0: feature}
    for i in range(cache_dic['max_order']):
        order = i + 1
        if warming_up:
            momentum[order] = feature if i == 0 else 0
        elif previous is not None and previous.get(order) is not None:
            # Guard on the same key that is read below. The previous code
            # checked `i` but read `i + 1`, which raised KeyError on short histories.
            momentum[order] = (beta * previous[order] + (1 - beta) * updated_taylor_factors[order]) / fix
        else:
            # No history for this order: keep only the orders computed so far.
            break

    layer_cache[module] = momentum


def update_beta(cache_dic, current, real_output, formula_value):
    """
    Move the current module's momentum coefficient by one ``gamma`` step.

    The norms of ``real_output`` and ``formula_value`` are compared with
    ``torch.norm``. Beta grows by ``current['gamma']`` when the real norm is
    larger, shrinks by it when smaller, and is left unchanged otherwise. The
    result is clipped to ``[current['a'], current['b']]``. The upper bound is
    tested first, so the clipping assumes ``a <= b``. The clipped value is
    written back to ``cache_dic['momentum_beta'][-1][stream][layer][module]``.

    :param cache_dic: Cache dictionary holding 'momentum_beta'.
    :param current: Current step information. Reads 'stream', 'layer',
        'module', 'gamma', 'a' and 'b'.
    :param real_output: Tensor whose norm is the reference (presumably the
        fully computed output — confirm with callers).
    :param formula_value: Tensor whose norm is compared against the reference
        (presumably the approximated output).
    """
    real_norm = torch.norm(real_output)
    approx_norm = torch.norm(formula_value)

    step_size = current['gamma']
    stream, layer, module = current['stream'], current['layer'], current['module']
    beta_by_module = cache_dic['momentum_beta'][-1][stream][layer]
    beta = beta_by_module[module]

    if real_norm < approx_norm:
        beta -= step_size
    elif real_norm > approx_norm:
        beta += step_size

    lower, upper = current['a'], current['b']
    # Upper bound checked before the lower one, mirroring an if/elif clamp.
    beta = upper if beta > upper else (lower if beta < lower else beta)

    beta_by_module[module] = beta

