from typing import Optional

from videosys.utils.logging import logger

# Process-wide BWCache manager; stays None until set_bwcache_manager() installs one.
BWCACHE_MANAGER: "Optional[BWCacheManager]" = None

class BWCacheConfig:
    """Plain container for BWCache hyper-parameters.

    Every field defaults to ``None`` and is stored exactly as given; nothing
    is validated here.

    Args:
        thresh: Threshold that ``BWCacheManager.if_reuse_cache`` compares
            ``acu_l1 / depth`` against (reuse when strictly below it).
        reuse_interval: Value handed back by
            ``BWCacheManager.get_reuse_interval``.
        last_step: Value handed back by ``BWCacheManager.get_last_step``.
    """

    def __init__(
        self,
        thresh: Optional[float] = None,
        reuse_interval: Optional[int] = None,
        last_step: Optional[float] = None,
    ):
        # Annotations are Optional[...] because None is an accepted (default)
        # value; bare `float = None` is an implicit Optional, which PEP 484
        # disallows.
        self.thresh = thresh
        self.reuse_interval = reuse_interval
        self.last_step = last_step

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(thresh={self.thresh!r}, "
            f"reuse_interval={self.reuse_interval!r}, last_step={self.last_step!r})"
        )

class BWCacheManager:
    """Holds a ``BWCacheConfig`` and answers cache-reuse queries against it."""

    def __init__(self, config: BWCacheConfig):
        self.config: BWCacheConfig = config

        logger.info(
            f"Init BWCache. thresh: {config.thresh}, "
            f"reuse_interval: {config.reuse_interval}, last_step: {config.last_step}."
        )

    def if_reuse_cache(self, acu_l1, depth):
        """Decide whether the cached result may be reused.

        Args:
            acu_l1: Accumulated value divided by ``depth`` before comparison
                (presumably an accumulated L1 distance; may be a Python
                number or a scalar tensor — confirm against callers).
            depth: Divisor applied to ``acu_l1``.

        Returns:
            ``acu_l1 / depth < config.thresh``. Returns ``False`` (i.e.
            recompute) when no threshold is configured or when ``depth <= 0``
            makes the average undefined.
        """
        thresh = self.config.thresh
        # Recomputing is always correct while reusing is an approximation, so
        # undefined inputs fall back to "don't reuse" instead of raising
        # TypeError (thresh is None) or ZeroDivisionError (depth == 0).
        if thresh is None or depth <= 0:
            return False
        return acu_l1 / depth < thresh

    def get_reuse_interval(self):
        """Return ``config.reuse_interval`` unchanged."""
        return self.config.reuse_interval

    def get_last_step(self):
        """Return ``config.last_step`` unchanged."""
        return self.config.last_step


def _require_bwcache_manager():
    """Return the installed global BWCache manager, raising if none is set."""
    manager = BWCACHE_MANAGER
    if manager is None:
        # Without this check callers get an opaque
        # "'NoneType' object has no attribute ..." AttributeError.
        raise RuntimeError(
            "BWCache manager is not initialized; call set_bwcache_manager() first."
        )
    return manager

def if_reuse_cache(acu_l1, depth):
    """Forward to ``if_reuse_cache`` on the global BWCache manager.

    Raises:
        RuntimeError: if no manager has been installed.
    """
    return _require_bwcache_manager().if_reuse_cache(acu_l1, depth)

def get_reuse_interval():
    """Forward to ``get_reuse_interval`` on the global BWCache manager.

    Raises:
        RuntimeError: if no manager has been installed.
    """
    return _require_bwcache_manager().get_reuse_interval()

def get_last_step():
    """Forward to ``get_last_step`` on the global BWCache manager.

    Raises:
        RuntimeError: if no manager has been installed.
    """
    return _require_bwcache_manager().get_last_step()

def set_bwcache_manager(config: BWCacheConfig):
    """Build a ``BWCacheManager`` from *config* and install it globally.

    Any previously installed manager is replaced. If construction raises,
    the existing global is left untouched.
    """
    global BWCACHE_MANAGER
    new_manager = BWCacheManager(config)
    BWCACHE_MANAGER = new_manager

def enable_bwcache():
    """Return True once a manager has been installed via ``set_bwcache_manager``."""
    return BWCACHE_MANAGER is not None