import torch
from pipeline import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel


# Hugging Face repo that supplies every pipeline component other than the
# swapped-in draft transformer.
_BASE_MODEL_ID = "black-forest-labs/FLUX.1-dev"

# Transformer loaded from the `svdq-int4-flux.1-dev` checkpoint (a quantized
# variant, going by the checkpoint name); it is used only by `draft_pipe`.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev"
)

# Base pipeline: stock FLUX.1-dev components in bfloat16, placed on the GPU.
pipe = FluxPipeline.from_pretrained(
    _BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Draft pipeline: same repo, but with the Nunchaku transformer substituted for
# the default one.
draft_pipe = FluxPipeline.from_pretrained(
    _BASE_MODEL_ID,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

# ---------------------------------------------------------------------------
# Run parameters (read as module-level globals by `prompt_serve`)
# ---------------------------------------------------------------------------

# Number of inference steps requested from both pipelines.
inference_steps = 50

# After the initial stage, every `spec_k_2`-th step index is denoised by
# `pipe`; the remaining steps are denoised by `draft_pipe`.
spec_k_2 = 10

# NOTE(review): not referenced anywhere in the visible script.
spec_k_1 = 10

# Step indices strictly below this (float) threshold always use `pipe`.
stage_1 = inference_steps * 0.10

# Seed for the CUDA generators handed to both pipelines.
seed = 42

prompt = "A person on a bench, and one on a wheelchair sitting by a seawall looking out toward the ocean"


def prompt_serve(prompt):
    """Generate one 512x512 image for ``prompt``, mixing base and draft steps.

    Both pipelines are first invoked with identical arguments. The value
    returned by ``draft_pipe`` is discarded; the 11-tuple returned by ``pipe``
    is unpacked into the latents, timesteps and conditioning inputs that
    drive the manual denoising loop.

    For step index ``i`` the noise prediction comes from ``pipe`` when
    ``i < stage_1`` or ``i % spec_k_2 == 0``, and from ``draft_pipe``
    otherwise. The latent update always goes through ``pipe.get_next_noise``.

    Reads the module-level ``pipe``, ``draft_pipe``, ``inference_steps``,
    ``stage_1``, ``spec_k_2`` and ``seed``.

    Args:
        prompt: Text prompt forwarded to both pipelines.

    Returns:
        The first element of ``.images`` on the object returned by
        ``pipe.__generate_image__``.
    """
    image_size = 512  # used for both height and width

    def setup_kwargs():
        # Build a fresh generator on every call so that each pipeline starts
        # from the same seed rather than sharing one advancing generator.
        return dict(
            height=image_size,
            width=image_size,
            prompt=prompt,
            num_inference_steps=inference_steps,
            guidance_scale=3.5,
            generator=torch.Generator(device="cuda").manual_seed(seed),
        )

    # Return value intentionally unused.
    # NOTE(review): presumably this call is kept to initialise state inside
    # draft_pipe before denoise_forward is used below -- confirm against the
    # pipeline implementation.
    draft_pipe(**setup_kwargs())

    (
        latents,
        latent_image_ids,
        timesteps,
        _num_warmup_steps,  # returned by the pipeline but not needed here
        guidance,
        output_type,
        return_dict,
        text_ids,
        prompt_embeds,
        pooled_prompt_embeds,
        joint_attention_kwargs,
    ) = pipe(**setup_kwargs())

    # This loop runs outside the pipeline's own __call__, so disable autograd
    # explicitly: nothing here needs gradients, and chaining `latents` across
    # steps with grad tracking enabled would retain every step's graph.
    with torch.no_grad():
        for i, t in enumerate(timesteps):
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            # Initial stage and every spec_k_2-th step use the base pipeline;
            # all other steps use the draft pipeline.
            use_base = i < stage_1 or i % spec_k_2 == 0
            denoiser = pipe if use_base else draft_pipe

            noise_pred = denoiser.denoise_forward(
                latents,
                timestep,
                guidance,
                pooled_prompt_embeds,
                prompt_embeds,
                text_ids,
                latent_image_ids,
                joint_attention_kwargs,
            )
            # The update step is always taken with the base pipeline,
            # regardless of which pipeline produced the prediction.
            latents = pipe.get_next_noise(noise_pred, i, t, latents, return_dict=False)

        result = pipe.__generate_image__(
            latents, image_size, image_size, return_dict, output_type
        )

    return result.images[0]

# Generate a single image for the configured prompt and write it to disk.
image = prompt_serve(prompt)
image.save("flux-fast.png")
