import torch
from pipeline import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel


# Speculative denoising for FLUX.1-dev.
# - A quantized "draft" transformer (nunchaku, svdq-int4 weights) proposes most
#   denoising steps.
# - The full bf16 "target" pipeline runs the early steps and periodically
#   verifies the draft's proposals.

# Draft transformer: 4-bit (svdq-int4) FLUX.1-dev weights loaded through nunchaku.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
# Target pipeline: FLUX.1-dev in bfloat16.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# Draft pipeline: same checkpoint with the quantized transformer swapped in.
# NOTE(review): this loads a second copy of the non-transformer components
# (text encoders, VAE, ...). Passing pipe's components as overrides could save
# GPU memory, if the custom FluxPipeline accepts them — verify before changing.
draft_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16).to("cuda")

# parameters
inference_steps = 50
height = 512
width = 512
guidance_scale = 3.5
spec_k_2 = 9                    # after stage_2: (spec_k_2 - 1) draft steps, then one verification step
spec_k_1 = 3                    # in [stage_1, stage_2): the target runs when i % spec_k_1 == 0, else the draft
stage_2 = inference_steps*0.15  # step index where the mixed target/draft phase ends
stage_1 = inference_steps*0.08  # step index where the target-only phase ends
verify_threshold = 0.02         # draft accepted when mean |draft - target| latent difference <= this
seed = 42
prompt = "A person on a bench, and one on a wheelchair sitting by a seawall looking out toward the ocean"

if spec_k_1 < 1 or spec_k_2 < 1:
    # spec_k_1 == 0 would divide by zero in the loop.
    # spec_k_2 < 1 would match no branch and silently skip denoising steps.
    raise ValueError(f"spec_k_1 and spec_k_2 must be >= 1, got {spec_k_1} and {spec_k_2}")


# Call the draft pipeline with the same settings as the target; its return
# value is discarded.
# NOTE(review): presumably a warm-up and/or initialization of draft_pipe state
# that draft_pipe.denoise_forward relies on — confirm against pipeline.FluxPipeline.
draft_pipe(
    height=height,
    width=width,
    prompt=prompt,
    num_inference_steps=inference_steps,
    guidance_scale=guidance_scale,
    generator=torch.Generator(device="cuda").manual_seed(seed),
)


# The custom pipeline's __call__ returns the state the manual loop below needs
# (initial latents, timesteps, embeddings, ...), not a finished image.
latents, latent_image_ids, timesteps, num_warmup_steps, guidance, output_type, return_dict, text_ids, prompt_embeds, pooled_prompt_embeds, joint_attention_kwargs = pipe(
    height=height,
    width=width,
    prompt=prompt,
    num_inference_steps=inference_steps,
    guidance_scale=guidance_scale,
    generator=torch.Generator(device="cuda").manual_seed(seed),
)

skip_cnt = 0  # draft steps taken since the last verification (only counted once i >= stage_2)

continue_flag = False  # set when a verification step already produced the latents for step i+1

# Inference only: make sure no autograd graph is built across steps, in case
# the custom forward helpers are not already running under no_grad.
with torch.no_grad():
    for i, t in enumerate(timesteps):
        if continue_flag:
            # This step was already computed by the previous verification step.
            continue_flag = False
            continue

        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        if (i < stage_1) or (i < stage_2 and i % spec_k_1 == 0):
            # Target-only phase, plus every spec_k_1-th step of the mixed phase.
            noise_pred = pipe.denoise_forward(
                latents, timestep, guidance, pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids, joint_attention_kwargs)
            latents = pipe.get_next_noise(noise_pred, i, t, latents, return_dict=False)

        elif skip_cnt < (spec_k_2 - 1):
            # Draft step. The latent update still goes through pipe.get_next_noise.
            noise_pred = draft_pipe.denoise_forward(
                latents, timestep, guidance, pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids, joint_attention_kwargs)
            latents = pipe.get_next_noise(noise_pred, i, t, latents, return_dict=False)
            if i >= stage_2:
                skip_cnt += 1

        elif skip_cnt == (spec_k_2 - 1):
            # Verification step. The draft proposes step i. The target, in one
            # denoise_batch_forward call, evaluates step i from the current
            # latents and step i+1 from the draft's proposal.
            noise_pred_draft = draft_pipe.denoise_forward(
                latents, timestep, guidance, pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids, joint_attention_kwargs)
            latents_draft = pipe.get_next_noise(noise_pred_draft, i, t, latents, return_dict=False)

            latents_target_2 = None
            if i < (len(timesteps) - 1):
                new_t = timesteps[i + 1]
                new_timestep = new_t.expand(latents.shape[0]).to(latents.dtype)
                noise_pred_target_1, noise_pred_target_2 = pipe.denoise_batch_forward(
                    latents, timestep, latents_draft, new_timestep, guidance, pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids, joint_attention_kwargs)
                latents_target = pipe.get_next_noise(noise_pred_target_1, i, t, latents, return_dict=False)
                latents_target_2 = pipe.get_next_noise(noise_pred_target_2, i + 1, new_t, latents_draft, return_dict=False)
            else:
                # Last step: there is no step i+1 to compute speculatively.
                noise_pred_target = pipe.denoise_forward(
                    latents, timestep, guidance, pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids, joint_attention_kwargs)
                latents_target = pipe.get_next_noise(noise_pred_target, i, t, latents, return_dict=False)

            # Mean absolute difference between the draft's and target's step-i latents.
            l1_loss = torch.nn.functional.l1_loss(latents_draft, latents_target)
            loss_value = l1_loss.item()  # single host sync instead of two
            skip_cnt = 0
            if loss_value > verify_threshold:
                # Draft rejected: keep the target's step-i result.
                print(f"verify loss: {i}, loss: {loss_value}")
                latents = latents_target
            elif latents_target_2 is not None:
                # Draft accepted: take the target's step-(i+1) result computed
                # from the draft latents, and skip iteration i+1.
                latents = latents_target_2
                continue_flag = True
            else:
                latents = latents_target

    # Turn the final latents into an image via the custom helper.
    output = pipe.__generate_image__(latents, height, width, return_dict, output_type).images[0]

output.save("flux-slow.png")
