from typing import Dict
import torch


def momentum_update(cache_dic: Dict, current: Dict, feature: torch.Tensor, updated_taylor_factors):
    """
    Blend freshly computed Taylor factors into the stored per-module momentum.

    Builds a dict ``{order: value}`` whose order 0 is always ``feature`` and stores it in
    ``cache_dic['Taylor_momentum'][-1][current['layer']][current['module']]``.
    For each order ``k = i + 1`` (``i`` in ``range(max_order)``):

    * on the final step (``step >= num_steps - 1``) order 1 is set to ``feature`` and
      higher orders to ``0``;
    * otherwise, if a previous momentum value for order ``k`` exists and
      ``step < num_steps - first_enhance + 1``, order ``k`` becomes
      ``(beta * previous[k] + (1 - beta) * updated_taylor_factors[k]) / (1 - |beta| ** step)``;
    * otherwise the loop stops and higher orders are omitted.

    :param cache_dic: Cache dictionary. Reads ``'momentum_beta'``, ``'max_order'``,
        ``'first_enhance'`` and ``'Taylor_momentum'``; writes ``'Taylor_momentum'``.
    :param current: Current step information. Reads ``'layer'``, ``'module'``,
        ``'step'`` and ``'num_steps'``.
    :param feature: Value stored as order 0 (and as order 1 on the final step).
    :param updated_taylor_factors: Fresh Taylor factors indexed by order
        (presumably a dict ``{order: tensor}`` — confirm against the caller).
    """
    layer, module = current['layer'], current['module']
    # Only the layer-level container is resolved up front: the module entry may not
    # exist yet on a final-step call, which never reads it.
    layer_momentum = cache_dic['Taylor_momentum'][-1][layer]
    beta = cache_dic['momentum_beta'][-1][layer][module]

    step = current['step']
    num_steps = current['num_steps']
    is_final_step = step >= num_steps - 1

    # Bias-correction divisor. It is 0 when step == 0 or |beta| == 1; dividing would
    # produce inf/NaN (or ZeroDivisionError for Python floats), so skip the correction.
    fix = 1 - abs(beta) ** step
    if fix == 0:
        fix = 1

    momentum = {0: feature}
    for i in range(cache_dic['max_order']):
        order = i + 1
        if is_final_step:
            momentum[order] = feature if i == 0 else 0
        elif (layer_momentum[module].get(order, None) is not None
              and step < num_steps - cache_dic['first_enhance'] + 1):
            # Guard on the same key that is read below; checking ``order - 1`` here
            # raised KeyError when the previous momentum stopped one order short.
            previous = layer_momentum[module][order]
            momentum[order] = (beta * previous + (1 - beta) * updated_taylor_factors[order]) / fix
        else:
            break

    layer_momentum[module] = momentum


def update_beta(cache_dic, current, real_output, formula_value):
    """
    Nudge the current module's momentum coefficient ``beta`` and clamp it.

    ``beta`` (stored at ``cache_dic['momentum_beta'][-1][layer][module]``) is increased
    by ``current['gamma']`` when the L2 norm of ``real_output`` is larger than that of
    ``formula_value``. It is decreased by the same amount when the norm is smaller, and
    left unchanged when the norms are equal. The result is then clamped: values above
    ``current['b']`` become ``b``; otherwise values below ``current['a']`` become ``a``.

    :param cache_dic: Cache dictionary; ``'momentum_beta'`` is read and written.
    :param current: Current step information; reads ``'layer'``, ``'module'``,
        ``'gamma'``, ``'a'`` and ``'b'``.
    :param real_output: Actually computed output.
    :param formula_value: Value predicted by the approximation formula.
    """
    # Compare L2 norms of the real output and the formula prediction.
    real_norm = torch.norm(real_output)
    predicted_norm = torch.norm(formula_value)

    step_size = current['gamma']
    layer_betas = cache_dic['momentum_beta'][-1][current['layer']]
    module_key = current['module']
    beta = layer_betas[module_key]

    if real_norm > predicted_norm:
        beta += step_size
    elif real_norm < predicted_norm:
        beta -= step_size

    lower = current['a']
    upper = current['b']
    # Upper bound is checked first; the lower bound applies only if beta is not above it.
    if beta > upper:
        beta = upper
    elif beta < lower:
        beta = lower

    layer_betas[module_key] = beta

