from typing import Dict
import torch
import math
from momentum import momentum_update


def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
    """
    Rebuild the cached Taylor factors of one (layer, module) from a fresh feature.

    Order 0 is ``feature`` itself. Each higher order ``k + 1`` is the finite
    difference between the new order-``k`` factor and the previously cached
    order-``k`` factor, divided by the gap between the last two entries of
    ``current['activated_steps']``. Expansion stops at the first order that
    has no previously cached factor, or once ``current['step']`` is no longer
    below ``num_steps - first_enhance + 1``. The resulting factors replace the
    cache entry and are then passed to ``momentum_update``.

    :param cache_dic: Cache dictionary; reads ``max_order`` and ``first_enhance``
        and reads/writes ``cache[-1][layer][module]``.
    :param current: Current step information; reads ``activated_steps``,
        ``step``, ``num_steps``, ``layer`` and ``module``.
    :param feature: Freshly computed module output for this step.
    """
    # NOTE: an alternative variant measured this gap with current['activated_times'].
    step_gap = current['activated_steps'][-1] - current['activated_steps'][-2]

    factors = {0: feature}
    layer_cache = cache_dic['cache'][-1][current['layer']]

    for order in range(cache_dic['max_order']):
        if current['step'] == current['num_steps'] - 1:
            # Placeholder for this order. It survives only if the loop breaks
            # below; otherwise it is overwritten by the finite difference.
            # NOTE(review): intent of the zero term at this step is not visible
            # here — confirm against momentum_update.
            factors[order + 1] = 0

        previous = layer_cache[current['module']].get(order)
        # The step-window test is only evaluated when a previous factor exists.
        if previous is None or not (
            current['step'] < current['num_steps'] - cache_dic['first_enhance'] + 1
        ):
            break

        factors[order + 1] = (factors[order] - previous) / step_gap

    layer_cache[current['module']] = factors

    momentum_update(cache_dic, current, feature, factors)





def taylor_cache_init(cache_dic: Dict, current: Dict):
    """
    Reset the per-(layer, module) Taylor state at step ``num_steps - 1``.

    At any other step this is a no-op. At that step it replaces the entries for
    ``current['layer']`` / ``current['module']`` in the latest slot
    (index ``-1``) of ``cache``, ``Taylor_momentum`` and ``momentum_beta``.
    NOTE(review): step ``num_steps - 1`` is presumably the first step of a
    count-down schedule — confirm with the caller.

    :param cache_dic: Cache dictionary holding the ``cache``, ``Taylor_momentum``
        and ``momentum_beta`` stores.
    :param current: Current step information; reads ``step`` and ``num_steps``,
        and ``layer``/``module`` when resetting.
    """
    if current['step'] != current['num_steps'] - 1:
        return

    layer = current['layer']
    module = current['module']

    cache_dic['cache'][-1][layer][module] = {}
    cache_dic['Taylor_momentum'][-1][layer][module] = {}
    # Initial beta value; its meaning is defined by the momentum code, not here.
    cache_dic['momentum_beta'][-1][layer][module] = -0.01


def taylor_formula(module_dict: Dict, distance: int) -> torch.Tensor:
    """
    Evaluate the truncated Taylor series stored in ``module_dict`` at ``distance``.

    Computes ``sum_i module_dict[i] * distance**i / i!`` for
    ``i = 0 .. len(module_dict) - 1``.

    :param module_dict: Mapping from derivative order to its Taylor factor.
        Keys must be the contiguous integers ``0 .. len(module_dict) - 1``;
        a missing order raises ``KeyError``.
    :param distance: Offset at which to evaluate the expansion (presumably the
        current step minus the last fully computed step — confirm with caller).
    :return: The Taylor approximation. Returns the int ``0`` when
        ``module_dict`` is empty.
    """
    output = 0

    for order in range(len(module_dict)):
        # The order-0 term is always a freshly allocated value (float * factor),
        # so accumulating in place never mutates the caller's cached factors.
        output += (1 / math.factorial(order)) * module_dict[order] * (distance ** order)

    return output