from typing import Dict
import torch
import math

def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
    """
    Compute finite-difference derivative approximations and store them as Taylor factors.

    The freshly computed ``feature`` becomes the order-0 factor. Each higher order
    ``i + 1`` is the divided difference between the new order-``i`` factor and the
    previously cached order-``i`` factor, divided by the distance between the two most
    recent activated steps. Building stops at ``layer_max_order`` or at the first order
    missing (``None``/absent) from the previous cache. The cached factors for
    ``(current['layer'], current['module'])`` are replaced by the result.

    :param cache_dic: Cache dictionary; reads ``cache_dic['cache'][-1]`` and ``cache_dic['first_enhance']``.
    :param current: Current step information; reads ``step``, ``num_steps``, ``layer``,
        ``module``, ``activated_steps`` and ``layer_max_order``.
    :param feature: Freshly computed module output for this step.
    """
    layer, module = current['layer'], current['module']
    updated_taylor_factors = {0: feature}

    # Derivatives are only formed for steps below num_steps - first_enhance + 1. This
    # does not depend on the order, so it is checked once rather than per iteration.
    in_derivative_window = current['step'] < (current['num_steps'] - cache_dic['first_enhance'] + 1)
    if in_derivative_window:
        previous_factors = cache_dic['cache'][-1][layer][module]
        # Signed distance between the two most recent activated steps. Only needed (and
        # only indexed) when a derivative is actually formed.
        difference_distance = current['activated_steps'][-1] - current['activated_steps'][-2]
        max_order = current.get('layer_max_order')
        for i in range(max_order):
            previous = previous_factors.get(i)
            if previous is None:
                # No cached factor of this order yet: cannot form the next difference.
                break
            updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - previous) / difference_distance

    cache_dic['cache'][-1][layer][module] = updated_taylor_factors

def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
    """
    Evaluate the truncated Taylor expansion of a cached module output at the current step.

    Computes ``sum_i factor_i * x**i / i!``, where
    ``x = current['step'] - current['activated_steps'][-1]`` is the signed offset from
    the most recent fully computed step, and ``factor_i`` is the order-``i`` entry of
    ``cache_dic['cache'][-1][layer][module]``. Orders are read as keys ``0..len-1``.

    :param cache_dic: Cache dictionary holding the Taylor factors under key ``-1``.
    :param current: Current step information; reads ``step``, ``activated_steps``,
        ``layer`` and ``module``.
    :return: The approximated feature. This is the int ``0`` if no factors are cached.
    """
    x = current['step'] - current['activated_steps'][-1]
    # Hoist the nested lookup; it is loop-invariant.
    factors = cache_dic['cache'][-1][current['layer']][current['module']]
    output = 0

    for i in range(len(factors)):
        output += (1 / math.factorial(i)) * factors[i] * (x ** i)

    return output

def taylor_cache_init(cache_dic: Dict, current: Dict):
    """
    Reset the Taylor-factor slots of the current layer/module on the first sampling step.

    Only acts when ``current['step'] == current['num_steps'] - 1``; on any other step it
    does nothing. When it acts, ``cache_dic['cache'][-1][layer][module]`` is replaced by
    a fresh dict with keys ``0..layer_max_order`` all mapped to ``None``.

    :param cache_dic: Cache dictionary.
    :param current: Current step information; reads ``step``, ``num_steps``, ``layer``,
        ``module`` and ``layer_max_order``.
    """
    if current['step'] != current['num_steps'] - 1:
        return

    slots = {}
    cache_dic['cache'][-1][current['layer']][current['module']] = slots
    highest_order = current.get('layer_max_order')
    # One empty slot per order, including order 0.
    slots.update((order, None) for order in range(highest_order + 1))


def cal_type(cache_dic, current):
    '''
    Decide whether this step runs the full model ('full') or the Taylor approximation ('Taylor').

    A step is forced to 'full' when any of these holds:
    - it is one of the first ``cache_dic['first_enhance']`` sampling steps (the enhanced steps);
    - it is one of the final steps (``step <= 2``);
    - ``cache_counter`` has reached ``cache_interval - 1``.

    Full steps reset ``cache_counter`` to 0 and append the step to
    ``current['activated_steps']``. Taylor steps increment ``cache_counter``.
    '''
    step = current['step']
    in_tail = step <= 2
    # Step indices count down, so the first `first_enhance` steps are the highest ones;
    # those steps are enhanced (always computed in full).
    in_warmup = step > (current['num_steps'] - cache_dic['first_enhance'] - 1)
    interval = cache_dic['cache_interval']

    run_full = in_warmup or in_tail or cache_dic['cache_counter'] == interval - 1
    if not run_full:
        cache_dic['cache_counter'] += 1
        current['type'] = 'Taylor'
        return

    current['type'] = 'full'
    cache_dic['cache_counter'] = 0
    # Only the step index is recorded; derivative_approximation and taylor_formula read it.
    current['activated_steps'].append(step)


def cache_init(model_kwargs, num_steps, num_layers: int = 28):
    '''
    Initialization for cache.

    Builds the cache dictionary and the per-run ``current`` state used by the
    Taylor-caching helpers in this module.

    :param model_kwargs: Mapping that must provide ``cache_interval`` and ``test_FLOPs``.
        It may optionally provide ``first_enhance`` (default 2).
    :param num_steps: Number of sampling steps.
    :param num_layers: Number of layers to allocate cache slots for. The default of 28 is
        presumably the depth of the target model; confirm this for other models.
    :return: ``(cache_dic, current)``.
    '''
    # Slot -1 holds the Taylor factors read and written by derivative_approximation and
    # taylor_formula. Slots 0..num_steps-1 are also allocated, but those functions do
    # not read them.
    cache = {
        slot: {layer: {} for layer in range(num_layers)}
        for slot in range(-1, num_steps)
    }

    cache_dic = {
        'cache': cache,
        'flops': 0.0,
        'cache_interval': model_kwargs['cache_interval'],
        'test_FLOPs': model_kwargs['test_FLOPs'],
        # Number of initial sampling steps that cal_type forces to a full computation.
        'first_enhance': model_kwargs.get('first_enhance', 2),
        'cache_counter': 0,
    }

    # Seed activated_steps with the first sampled step (num_steps - 1, the step at which
    # taylor_cache_init fires). This guarantees activated_steps[-2] exists once the first
    # full step is appended. A fixed literal would only be right for one num_steps.
    current = {
        'num_steps': num_steps,
        'activated_steps': [num_steps - 1],
    }
    return cache_dic, current
