import os
import numpy as np
import torch
import memory_management
import safetensors.torch as sf
from safetensors.torch import load_file
from PIL import Image
import diffusers
print(diffusers.__file__)
from diffusers_kdiffusion_sdxl import KDiffusionStableDiffusionXLPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPTextModel, CLIPTokenizer
from lib_layerdiffuse.vae import TransparentVAEDecoder, TransparentVAEEncoder
from lib_layerdiffuse.utils import download_model
import argparse
import glob



def _apply_unet_offsets(unet, offset_path):
    """Add the weight offsets stored in ``offset_path`` onto ``unet`` in place.

    Every key of the UNet state dict that also appears in the offset file is
    replaced by ``original + offset``; all other weights are kept unchanged.
    """
    sd_offset = sf.load_file(offset_path)
    sd_merged = {
        key: (weight + sd_offset[key]) if key in sd_offset else weight
        for key, weight in unet.state_dict().items()
    }
    unet.load_state_dict(sd_merged, strict=True)


def _list_image_paths(directory):
    """Return the sorted paths of the image files directly inside ``directory``."""
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
    return [
        os.path.join(directory, file_name)
        for file_name in sorted(os.listdir(directory))
        if file_name.lower().endswith(image_extensions)
    ]


def generate(dataset_name, class_names, args, base_dir=None):
    """Generate image-to-image variations for every (class, cluster) pair.

    For each class in ``class_names`` and each cluster index below
    ``args.nclusters``, the two textual-inversion embeddings found under
    ``learned_embed_cdp/`` are registered with the SDXL pipeline as the token
    ``<embed{class_id}_{cluster_id}>``. Every image of the matching cluster
    directory is then encoded with the transparent VAE encoder, re-sampled
    with the prompt ``"a photo of <embed...>"`` and the decoded results are
    written as PNG files below ``generated_fore/``.

    Args:
        dataset_name: Dataset directory name; used to build the embedding,
            cluster and output paths.
        class_names: Class directory names to process. The position of a name
            in this sequence is the ``class_id`` used in its token name.
            Entries named ``'no_detected'`` are skipped.
        args: Namespace providing ``nclusters``, ``strength``,
            ``max_strength`` and ``num_generated_per_image``.
        base_dir: Directory holding ``<class_name>/cluster_<id>/`` image
            folders. Defaults to
            ``cdp_clusters/<dataset_name>/<args.nclusters>_clusters``.
    """
    if base_dir is None:
        base_dir = f"cdp_clusters/{dataset_name}/{args.nclusters}_clusters"

    # Load models
    sdxl_name = 'SG161222/RealVisXL_V4.0'
    tokenizer = CLIPTokenizer.from_pretrained(
        sdxl_name, subfolder="tokenizer")
    tokenizer_2 = CLIPTokenizer.from_pretrained(
        sdxl_name, subfolder="tokenizer_2")
    text_encoder = CLIPTextModel.from_pretrained(
        sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16")
    text_encoder_2 = CLIPTextModel.from_pretrained(
        sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16")
    vae = AutoencoderKL.from_pretrained(
        sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16")  # bfloat16 vae
    unet = UNet2DConditionModel.from_pretrained(
        sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16")

    # SDP
    unet.set_attn_processor(AttnProcessor2_0())
    vae.set_attn_processor(AttnProcessor2_0())

    # Download the LayerDiffuse weights (attention offsets + transparent VAE parts)
    path_ld_diffusers_sdxl_attn = download_model(
        url='https://huggingface.co/lllyasviel/LayerDiffuse_Diffusers/resolve/main/ld_diffusers_sdxl_attn.safetensors',
        local_path='path/to/icml24/LayerDiffuse_DiffusersCLI/models/ld_diffusers_sdxl_attn.safetensors'
    )

    path_ld_diffusers_sdxl_vae_transparent_encoder = download_model(
        url='https://huggingface.co/lllyasviel/LayerDiffuse_Diffusers/resolve/main/ld_diffusers_sdxl_vae_transparent_encoder.safetensors',
        local_path='path/to/icml24/LayerDiffuse_DiffusersCLI/models/ld_diffusers_sdxl_vae_transparent_encoder.safetensors'
    )

    path_ld_diffusers_sdxl_vae_transparent_decoder = download_model(
        url='https://huggingface.co/lllyasviel/LayerDiffuse_Diffusers/resolve/main/ld_diffusers_sdxl_vae_transparent_decoder.safetensors',
        local_path='path/to/icml24/LayerDiffuse_DiffusersCLI/models/ld_diffusers_sdxl_vae_transparent_decoder.safetensors'
    )

    # Modify the UNet weights with the downloaded offsets
    _apply_unet_offsets(unet, path_ld_diffusers_sdxl_attn)

    transparent_encoder = TransparentVAEEncoder(path_ld_diffusers_sdxl_vae_transparent_encoder)
    transparent_decoder = TransparentVAEDecoder(path_ld_diffusers_sdxl_vae_transparent_decoder)

    # Pipelines
    pipeline = KDiffusionStableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=text_encoder_2,
        tokenizer_2=tokenizer_2,
        unet=unet,
        scheduler=None,  # We completely give up diffusers sampling system and use A1111's method
    )

    # Shared by the embedding input path and the image output path.
    run_tag = f"{dataset_name}-{args.strength}-{args.nclusters}clusters"

    # Register one token per (class, cluster) pair before moving anything to
    # the GPU; pairs without both embedding files are skipped.
    class_cluster_loaded = []
    for class_id, class_name in enumerate(class_names):
        if class_name == 'no_detected':
            continue
        for cluster_id in range(args.nclusters):
            print(f"Loading embed: {class_name} [{cluster_id}]")

            # Paths to the textual inversion embeddings
            embed_dir = os.path.join("learned_embed_cdp", run_tag, class_name, f"cluster_{cluster_id}")
            textual_inversion_path_1 = os.path.join(embed_dir, "learned_embeds_1.safetensors")
            textual_inversion_path_2 = os.path.join(embed_dir, "learned_embeds_2.safetensors")
            missing_paths = [
                path
                for path in (textual_inversion_path_1, textual_inversion_path_2)
                if not os.path.exists(path)
            ]
            if missing_paths:
                # Report the file(s) actually missing, not always the first one.
                for path in missing_paths:
                    print(f"{path} is not found")
                continue

            # The files store the vector under the generic key '<embed>';
            # rename it so each pair gets its own token.
            token = f"<embed{class_id}_{cluster_id}>"
            state_dict1 = sf.load_file(textual_inversion_path_1)
            state_dict2 = sf.load_file(textual_inversion_path_2)
            state_dict1[token] = state_dict1.pop('<embed>')
            state_dict2[token] = state_dict2.pop('<embed>')

            # Load textual inversion embeddings into both text encoders
            pipeline.load_textual_inversion(state_dict1, token=token, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
            pipeline.load_textual_inversion(state_dict2, token=token, text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
            class_cluster_loaded.append((class_name, class_id, cluster_id))

    if not class_cluster_loaded:
        print("No textual inversion embeddings were loaded; nothing to generate.")
        return

    memory_management.load_models_to_gpu([pipeline.text_encoder, pipeline.text_encoder_2, pipeline.vae, transparent_decoder, transparent_encoder, pipeline.unet])

    # Loop-invariant settings, computed once.
    strength = max(args.max_strength, args.strength)
    batch_size_limit = 5
    num_images = args.num_generated_per_image
    negative_cond, negative_pooler = pipeline.encode_cropped_prompt_77tokens(
        prompt1='',
        prompt2=''
    )

    # 'no_detected' never reaches this list, so no second check is needed here.
    for class_name, class_id, cluster_id in class_cluster_loaded:
        print(f"Processing class: {class_name} [{cluster_id}]")
        class_cluster_dir = os.path.join(base_dir, class_name, f"cluster_{cluster_id}")
        save_dir = os.path.join("generated_fore", run_tag, class_name)
        os.makedirs(save_dir, exist_ok=True)

        prompt = f"a photo of <embed{class_id}_{cluster_id}>"
        positive_cond, positive_pooler = pipeline.encode_cropped_prompt_77tokens(
            prompt1=prompt,
            prompt2=prompt
        )

        with torch.inference_mode():
            for imgpath in _list_image_paths(class_cluster_dir):
                # NOTE(review): the image is passed to the transparent encoder
                # in whatever mode PIL opens it; presumably RGBA is expected --
                # confirm for JPEG/BMP inputs.
                with Image.open(imgpath) as source_image:
                    source_pixels = np.array(source_image.resize((512, 512), Image.LANCZOS))
                initial_latent = transparent_encoder(pipeline.vae, [source_pixels]) * pipeline.vae.config.scaling_factor
                initial_latent = initial_latent.to(dtype=pipeline.unet.dtype, device=pipeline.unet.device)

                # One generator per source image: re-seeding it for every batch
                # would make each batch start from the same random state and
                # repeat the previous batch's samples.
                generator = torch.Generator(device='cuda').manual_seed(42)

                image_basename, _ = os.path.splitext(os.path.basename(imgpath))
                image_index = 0
                for batch_start in range(0, num_images, batch_size_limit):
                    current_batch_size = min(batch_size_limit, num_images - batch_start)
                    latents = pipeline(
                        initial_latent=initial_latent,
                        strength=strength,
                        num_inference_steps=25,
                        batch_size=current_batch_size,
                        prompt_embeds=positive_cond,
                        pooled_prompt_embeds=positive_pooler,
                        negative_prompt_embeds=negative_cond,
                        negative_pooled_prompt_embeds=negative_pooler,
                        generator=generator,
                        guidance_scale=7.0,
                    ).images

                    latents = latents.to(dtype=pipeline.vae.dtype, device=pipeline.vae.device) / pipeline.vae.config.scaling_factor
                    transparent_images = transparent_decoder(pipeline.vae, latents)

                    # Save as we go; the index keeps counting across batches.
                    for image in transparent_images:
                        save_path = os.path.join(save_dir, f"{image_basename}_cluster{cluster_id}_strength{strength}_{image_index}.png")
                        Image.fromarray(image).save(save_path, format='PNG')
                        image_index += 1


def split_list_into_parts(lst, num_parts):
    """Split ``lst`` into ``num_parts`` contiguous slices of near-equal size.

    The first ``len(lst) % num_parts`` slices receive one extra element, so
    slice sizes differ by at most one. Slices may be empty when ``lst`` has
    fewer elements than ``num_parts``.

    Args:
        lst: Sequence to split; it is only sliced, never modified.
        num_parts: Number of slices to produce; must be at least 1.

    Returns:
        A list with exactly ``num_parts`` slices, in original order.

    Raises:
        ValueError: If ``num_parts`` is smaller than 1.
    """
    if num_parts < 1:
        raise ValueError(f"num_parts must be at least 1, got {num_parts}")
    part_size, remainder = divmod(len(lst), num_parts)
    parts = []
    start = 0
    for part_index in range(num_parts):
        end = start + part_size + (1 if part_index < remainder else 0)
        parts.append(lst[start:end])
        start = end
    return parts


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate images using Stable Diffusion with Textual Inversion.")
    parser.add_argument("--num_generated_per_image", type=int, default=2, help="Number of images to generate per source image.")
    parser.add_argument("--split", type=int, required=True, help="Index of the split for processing.")
    parser.add_argument("--nsplits", type=int, required=True, help="Total number of splits the class list is divided into.")
    parser.add_argument("--dataset", type=str, required=True, choices=['cub', 'car', 'flower', "aircraft"])
    parser.add_argument("--strength", type=float, default=0.4)
    parser.add_argument("--max_strength", type=float, default=0.4)
    parser.add_argument("--nclusters", type=int, required=True)
    args = parser.parse_args()

    if args.nsplits < 1:
        parser.error("--nsplits must be at least 1")
    if not 0 <= args.split < args.nsplits:
        parser.error(f"--split must be in [0, {args.nsplits - 1}], got {args.split}")

    # Directory name used on disk for each supported --dataset choice.
    dataset_dir_names = {
        'cub': "CUB_200_2011",
        'car': "StandfordCar",
        'aircraft': "Aircraft",
    }
    dataset_name = dataset_dir_names.get(args.dataset)
    if dataset_name is None:
        # 'flower' is an accepted choice but has no directory mapping here;
        # fail with a clear message instead of a NameError further down.
        parser.error(f"--dataset {args.dataset!r} has no dataset directory configured")

    # NOTE: keep `base_dir` as a module-level name; generate() may rely on it
    # as a global.
    base_dir = f"cdp_clusters/{dataset_name}/{args.nclusters}_clusters"
    if not os.path.isdir(base_dir):
        parser.error(f"cluster directory not found: {base_dir}")
    class_names = sorted(os.listdir(base_dir))
    class_names_this_split = split_list_into_parts(class_names, args.nsplits)[args.split]
    generate(dataset_name, class_names_this_split, args)