import math
import torch

def cache_init(threshold=5, use_cpu=False, total_steps=50):
    '''
    Create the state dict used by the step-wise residual cache.

    Args:
        threshold: Stored as ``fresh_threshold``. ``cal_if_full`` uses it as the
            period of its counter (``count % fresh_threshold``), so it must be
            >= 1.
        use_cpu: If True, ``update_residual_cache`` moves stored residuals to
            CPU, and ``update_feature`` moves them back before use.
        total_steps: Number of ``update_step`` calls after which
            ``current_step`` wraps to 0 and the cached state is cleared.
            Defaults to 50.

    Returns:
        A new dict. Every call creates its own ``activated_steps`` list, so
        separate caches never share state.
    '''
    cache = {}

    cache['fresh_threshold'] = threshold

    cache['current_step'] = 0
    cache['total_steps'] = total_steps
    cache['count'] = 0
    # Initialized to 0; not read by any function visible alongside this one.
    cache['order'] = 0

    # Filled by update_residual_cache; cleared at the end of a run by update_step.
    cache['residual_cache'] = None
    cache['activated_steps'] = []
    cache['use_cpu'] = use_cpu

    return cache


def cache_init_no_cache(threshold=5, total_steps=50):
    '''
    Create a minimal state dict that holds only the step counters.

    The dict has no residual-storage keys (``residual_cache``,
    ``activated_steps``, ``use_cpu``), so pair it with
    ``update_step(cache, use_cache=False)``.

    Args:
        threshold: Stored as ``fresh_threshold``. It must be >= 1 if the dict is
            passed to ``cal_if_full``, which takes ``count % fresh_threshold``.
        total_steps: Number of ``update_step`` calls after which the counters
            wrap to 0. Defaults to 50.

    Returns:
        A new dict with ``fresh_threshold``, ``current_step``, ``total_steps``
        and ``count``.
    '''
    cache = {}
    cache['fresh_threshold'] = threshold
    cache['current_step'] = 0
    cache['total_steps'] = total_steps
    cache['count'] = 0

    return cache


def update_step(cache, use_cache=True):
    '''
    Advance ``cache['current_step']`` by one and wrap at ``total_steps``.

    When the incremented step reaches ``cache['total_steps']``, ``current_step``
    and ``count`` are reset to 0. If ``use_cache`` is true, ``residual_cache``
    is also set to None and ``activated_steps`` is emptied in place. Pass
    ``use_cache=False`` for a dict without those keys.

    Returns the same (mutated) dict.
    '''
    next_step = cache['current_step'] + 1
    cache['current_step'] = next_step

    if next_step < cache['total_steps']:
        return cache

    # End of a full run: restart counters for the next one.
    cache['current_step'] = 0
    cache['count'] = 0

    if not use_cache:
        return cache

    cache['residual_cache'] = None
    # clear() rather than rebinding, so existing references see the empty list.
    cache['activated_steps'].clear()
    return cache

def update_residual_cache(cache, residual_cache):
    '''
    Store a new residual and log the current step as activated.

    ``cache['current_step']`` is appended to ``cache['activated_steps']``. The
    residual is then stored under ``cache['residual_cache']``, moved to CPU
    first when ``cache['use_cpu']`` is true (presumably to free accelerator
    memory; confirm with callers).

    Returns the same (mutated) dict.
    '''
    cache['activated_steps'].append(cache['current_step'])

    offload = cache['use_cpu']
    cache['residual_cache'] = residual_cache.to("cpu") if offload else residual_cache

    return cache


def update_feature(cache, current_feature):
    '''
    Return ``current_feature`` plus the cached residual.

    If ``cache['use_cpu']`` is true, the residual was offloaded to CPU, so it is
    first moved back onto ``current_feature``'s device. The cached tensor itself
    is not modified.

    Args:
        cache: Dict with ``use_cpu`` and a non-None ``residual_cache`` tensor.
        current_feature: Tensor to add the residual to. Presumably
            broadcast-compatible with the residual; confirm with callers.

    Returns:
        A new tensor ``current_feature + residual``.
    '''
    residual = cache['residual_cache']
    if cache['use_cpu']:
        # Use the feature's own device instead of a fixed "cuda". A fixed
        # "cuda" fails on CPU-only setups and targets the wrong GPU when the
        # feature lives on e.g. cuda:1.
        residual = residual.to(current_feature.device)
    return current_feature + residual


def cal_if_full(cache):
    '''
    Decide whether this step needs a full compute, and advance the counter.

    Returns True when ``cache['current_step']`` is 0, or when ``cache['count']``
    equals ``fresh_threshold - 1`` (the last slot of the current cycle). In both
    cases ``count`` is then incremented modulo ``fresh_threshold``. With
    threshold N, a run that starts from count 0 gets True at steps
    0, N-1, 2N-1, ...

    ``fresh_threshold`` must be >= 1; the modulo raises ZeroDivisionError for 0.
    '''
    at_first_step = cache['current_step'] == 0
    period = cache['fresh_threshold']
    slot = cache['count']
    at_cycle_end = slot == period - 1

    cache['count'] = (slot + 1) % period

    return at_first_step or at_cycle_end