from diffusers import StableDiffusionPipeline, AutoencoderKL, StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0
from LayerDiffuse_DiffusersCLI.diffusers_kdiffusion_sdxl import KDiffusionStableDiffusionXLPipeline

from transformers import CLIPTextModel, CLIPTokenizer
import safetensors.torch as sf
import torch
import os
from PIL import Image

from LayerDiffuse_DiffusersCLI.lib_layerdiffuse.utils import download_model
from LayerDiffuse_DiffusersCLI.lib_layerdiffuse.vae import TransparentVAEDecoder, TransparentVAEEncoder
from LayerDiffuse_DiffusersCLI import memory_management

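# prepare_model() populates the module globals `pipeline` (a KDiffusionStableDiffusionXLPipeline),
# `transparent_encoder` and `transparent_decoder`, which layerDiffusion_text2img() relies on.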
def prepare_model():
    """Load SDXL and LayerDiffuse weights and build the global generation pipeline.

    Every SDXL component is loaded with ``local_files_only=True``, so it must
    already be present in the local Hugging Face cache. The three LayerDiffuse
    checkpoints are obtained through ``download_model`` into
    ``/root/.cache/layerdiffusion/models``.

    The LayerDiffuse attention checkpoint holds per-key offsets that are *added*
    to the matching UNet weights. The transparent VAE encoder/decoder are built
    from the other two checkpoints.

    Side effects -- sets these module globals:
        pipeline: ``KDiffusionStableDiffusionXLPipeline`` built without a
            diffusers scheduler, with its VAE, text encoders and UNet on CUDA.
        transparent_encoder, transparent_decoder: LayerDiffuse VAE modules on CUDA.
    """
    global pipeline, transparent_encoder, transparent_decoder

    sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0'

    # --- SDXL components (fp16, local cache only) ---
    tokenizer = CLIPTokenizer.from_pretrained(
        sdxl_name, subfolder="tokenizer", local_files_only=True)
    tokenizer_2 = CLIPTokenizer.from_pretrained(
        sdxl_name, subfolder="tokenizer_2", local_files_only=True)
    text_encoder = CLIPTextModel.from_pretrained(
        sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16", local_files_only=True)
    text_encoder_2 = CLIPTextModel.from_pretrained(
        sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16", local_files_only=True)
    # fp16-fix VAE is used here; an earlier variant loaded the stock SDXL VAE
    # (subfolder="vae") in bfloat16 instead.
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        use_safetensors=True,
        torch_dtype=torch.float16,
        local_files_only=True,
    )
    unet = UNet2DConditionModel.from_pretrained(
        sdxl_name, subfolder="unet", torch_dtype=torch.float16, local_files_only=True)

    # Use diffusers' PyTorch-2 scaled-dot-product attention processor for both networks.
    unet.set_attn_processor(AttnProcessor2_0())
    vae.set_attn_processor(AttnProcessor2_0())

    # --- LayerDiffuse checkpoints ---
    hub_url = 'https://huggingface.co/lllyasviel/LayerDiffuse_Diffusers/resolve/main'
    cache_dir = '/root/.cache/layerdiffusion/models'

    def fetch(filename):
        # Same URL / local path layout for every LayerDiffuse checkpoint.
        return download_model(url=f'{hub_url}/{filename}', local_path=f'{cache_dir}/{filename}')

    path_ld_diffusers_sdxl_attn = fetch('ld_diffusers_sdxl_attn.safetensors')
    path_ld_diffusers_sdxl_vae_transparent_encoder = fetch(
        'ld_diffusers_sdxl_vae_transparent_encoder.safetensors')
    path_ld_diffusers_sdxl_vae_transparent_decoder = fetch(
        'ld_diffusers_sdxl_vae_transparent_decoder.safetensors')

    # --- Merge attention offsets into the UNet: merged = original + offset ---
    sd_offset = {k: v.float() for k, v in sf.load_file(path_ld_diffusers_sdxl_attn).items()}
    sd_origin = unet.state_dict()

    # Offsets whose key matches no UNet parameter would otherwise be dropped
    # silently (e.g. after a diffusers key-naming change), leaving the model
    # without the LayerDiffuse weights and no indication why.
    unmatched = sd_offset.keys() - sd_origin.keys()
    if unmatched:
        import warnings
        warnings.warn(
            f"{len(unmatched)} LayerDiffuse attention offset key(s) match no UNet "
            f"parameter and were ignored (e.g. {sorted(unmatched)[0]!r})",
            stacklevel=2,
        )

    sd_merged = {
        k: (v + sd_offset[k]) if k in sd_offset else v
        for k, v in sd_origin.items()
    }
    # strict=True holds by construction: sd_merged has exactly the UNet's keys.
    unet.load_state_dict(sd_merged, strict=True)
    del sd_offset, sd_origin, sd_merged  # drop the temporary state dicts

    transparent_encoder = TransparentVAEEncoder(path_ld_diffusers_sdxl_vae_transparent_encoder).to("cuda")
    transparent_decoder = TransparentVAEDecoder(path_ld_diffusers_sdxl_vae_transparent_decoder).to("cuda")

    pipeline = KDiffusionStableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=text_encoder_2,
        tokenizer_2=tokenizer_2,
        unet=unet,
        scheduler=None,  # We completely give up diffusers sampling system and use A1111's method
    )
    # Only the torch modules are moved; tokenizers have no device.
    for module in (vae, text_encoder, text_encoder_2, unet):
        module.to("cuda")
def layerDiffusion_text2img(prompt, output_dir, *, negative_prompt=None, seed=12345,
                            guidance_scale=5.0, num_inference_steps=25):
    """Generate one image for *prompt* with the LayerDiffuse transparent decoder and save it as PNG.

    ``prepare_model()`` must have been called first: this function reads the
    ``pipeline`` and ``transparent_decoder`` module globals it sets.

    Args:
        prompt: Positive text prompt.
        output_dir: Directory the PNG is written to; created if it does not
            exist. An empty string means the current working directory.
        negative_prompt: Negative prompt; ``None`` uses a generic
            low-quality/artifact list.
        seed: Seed for the ``torch.Generator`` on ``memory_management.gpu``.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.
        num_inference_steps: Number of sampling steps.

    Returns:
        Path of the saved PNG. The filename is derived from *prompt*: spaces
        become underscores. Path separators and characters invalid in
        filenames are replaced with ``_``, so the file always lands directly
        in *output_dir*. Over-long names are truncated.
    """
    default_negative = 'low quality, blurry, noisy, overexposed, artifacts, low detail, grainy, over-saturated'
    if negative_prompt is None:
        negative_prompt = default_negative

    # Resolve the output location before the expensive sampling run so a bad
    # directory fails fast.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    result_list_path = os.path.join(output_dir, _prompt_to_filename(prompt))

    with torch.inference_mode():
        rng = torch.Generator(device=memory_management.gpu).manual_seed(seed)

        positive_cond, positive_pooler = pipeline.encode_cropped_prompt_77tokens(prompt)
        negative_cond, negative_pooler = pipeline.encode_cropped_prompt_77tokens(negative_prompt)

        # All-zero 1x4x128x128 latent (presumably 1024x1024 px after the VAE's
        # 8x upsampling); with strength=1.0 the sampler presumably starts from
        # pure noise -- confirm in KDiffusionStableDiffusionXLPipeline.
        initial_latent = torch.zeros(size=(1, 4, 128, 128), dtype=pipeline.unet.dtype, device=pipeline.unet.device)
        latents = pipeline(
            initial_latent=initial_latent,
            strength=1.0,
            num_inference_steps=num_inference_steps,
            batch_size=1,
            prompt_embeds=positive_cond,
            negative_prompt_embeds=negative_cond,
            pooled_prompt_embeds=positive_pooler,
            negative_pooled_prompt_embeds=negative_pooler,
            generator=rng,
            guidance_scale=guidance_scale,
        ).images

        # Undo the VAE latent scaling before handing latents to the transparent decoder.
        latents = latents.to(dtype=pipeline.vae.dtype, device=pipeline.vae.device) / pipeline.vae.config.scaling_factor
        result_list, _vis_list = transparent_decoder(pipeline.vae, latents)

        Image.fromarray(result_list[0]).save(result_list_path, format='PNG')
    return result_list_path


def _prompt_to_filename(prompt, max_bytes=200):
    """Return a single-path-component ``.png`` filename derived from *prompt*.

    Spaces become underscores, as in the original naming scheme. Path
    separators, characters Windows forbids in filenames, and ASCII control
    characters are also replaced with ``_``. This means a prompt cannot create
    sub-directories or escape the output directory. The stem is then
    truncated to *max_bytes* UTF-8 bytes to stay under the common 255-byte
    filename limit.
    """
    stem = ''.join(
        '_' if ch == ' ' or ch in '<>:"/\\|?*' or ord(ch) < 32 else ch
        for ch in prompt
    )
    # errors='ignore' drops a multi-byte character cut in half by the byte slice.
    stem = stem.encode('utf-8')[:max_bytes].decode('utf-8', 'ignore')
    return f"{stem}.png"