import torch
from flux.model import Flux
from torch import Tensor
from scaling_cache.cache_functions import cache_init

# Module-level list, created empty. It is not read or written anywhere in this
# chunk; presumably other code appends to it. TODO(review): confirm its users.
CAL_AMOUNT_LIST = list()

def denoise_cache(
    model: Flux,
    model_name: str,
    mode: str,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
):
    """Run the Euler sampling loop for a cache-enabled Flux model.

    For each consecutive pair ``(t_curr, t_prev)`` in ``timesteps`` the model
    is evaluated at ``t_curr`` and ``img`` is advanced by one explicit Euler
    step, ``img = img + (t_prev - t_curr) * pred``. Before each forward pass
    ``t_curr`` is stored in ``model.current['t']``, and after each step
    ``model.current['step']`` is incremented. ``model.cache_release()`` is
    called once the loop ends, including when a forward pass raises.

    NOTE(review): this function does not initialise the cache itself.
    ``model.current`` must already be a mutable mapping containing a numeric
    ``'step'`` entry. It is presumably set up elsewhere (e.g. via
    ``cache_init``); confirm against the caller. ``model_name`` and ``mode``
    are accepted but not used in this function.

    Args:
        model: Flux model exposing a ``current`` mapping and ``cache_release()``.
        model_name: Unused here; kept for interface compatibility.
        mode: Unused here; kept for interface compatibility.
        img: Latent input. Per-step timestep/guidance vectors are built with
            one entry per ``img.shape[0]``, matching ``img``'s dtype and device.
        img_ids: Forwarded to the model unchanged.
        txt: Forwarded to the model unchanged.
        txt_ids: Forwarded to the model unchanged.
        vec: Forwarded to the model as ``y``.
        timesteps: Sampling schedule; ``len(timesteps) - 1`` steps are taken.
        guidance: Scalar guidance value, broadcast to one entry per
            ``img.shape[0]``.

    Returns:
        ``img`` after the final Euler step (unchanged if fewer than two
        timesteps are given).
    """
    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    try:
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
            # Publish the current timestep on the model; presumably read by
            # its caching logic during the forward pass.
            model.current['t'] = t_curr
            pred = model(
                img=img,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec,
                guidance=guidance_vec,
            )
            # Explicit Euler step from t_curr to t_prev.
            img = img + (t_prev - t_curr) * pred
            model.current['step'] += 1
    finally:
        # Release the model's cache even if a forward pass raised, so cached
        # state from an aborted run is not left attached to the model.
        model.cache_release()
    return img
