import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from diffusers import DDIMScheduler, AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

# --- 1. Model Loading & Setup ---
def load_models(model_id="runwayml/stable-diffusion-v1-5", device="cuda"):
    """Fetch the components of a Stable Diffusion checkpoint from the Hugging Face Hub.

    Each component is read from its own ``subfolder`` of ``model_id``. The
    text encoder, VAE and UNet are moved to ``device``; the tokenizer and the
    scheduler hold no tensors to move.

    Args:
        model_id: Hub repository id (or local path) of the checkpoint.
        device: Torch device the neural components are placed on.

    Returns:
        ``(tokenizer, text_encoder, vae, unet, scheduler)`` in that order.
    """
    def fetch(component_cls, subfolder):
        return component_cls.from_pretrained(model_id, subfolder=subfolder)

    tokenizer = fetch(CLIPTokenizer, "tokenizer")
    text_encoder = fetch(CLIPTextModel, "text_encoder").to(device)
    vae = fetch(AutoencoderKL, "vae").to(device)
    unet = fetch(UNet2DConditionModel, "unet").to(device)
    scheduler = fetch(DDIMScheduler, "scheduler")
    return (tokenizer, text_encoder, vae, unet, scheduler)

@torch.no_grad()
def image_to_latent(image, vae, size=512):
    """Encode a PIL image into a scaled VAE latent.

    The image is converted to RGB, resized to ``size`` x ``size`` with Lanczos
    resampling and mapped from [0, 255] to [-1, 1]. It is then encoded by the VAE.
    A sample is drawn from the latent distribution and multiplied by
    ``vae.config.scaling_factor``.

    Args:
        image: Source ``PIL.Image.Image`` in any mode (converted to RGB here so
            RGBA / grayscale inputs still yield three channels).
        vae: ``AutoencoderKL`` used for encoding.
        size: Side length the image is resized to before encoding.

    Returns:
        Latent tensor on ``vae.device`` with ``vae.dtype``.
    """
    # Autograd is disabled by the decorator: this is pure inference, and a
    # recorded graph would only waste memory downstream.
    image = image.convert("RGB").resize((size, size), Image.LANCZOS)
    image_np = np.asarray(image, dtype=np.float32) / 255.0 * 2.0 - 1.0
    # HWC -> NCHW, matching the VAE's device and dtype (e.g. fp16 weights).
    image_tensor = (
        torch.from_numpy(image_np)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device=vae.device, dtype=vae.dtype)
    )
    latent = vae.encode(image_tensor).latent_dist.sample() * vae.config.scaling_factor
    return latent

# --- 2. DDIM Inversion ---
@torch.no_grad()
def invert(initial_latent, unet, scheduler, text_embeds, num_steps=50, guidance_scale=1.0,
           uncond_embeds=None):
    """Run deterministic DDIM inversion and record the latent at every timestep.

    The function walks the schedule of ``scheduler.set_timesteps(num_steps)``
    from the least to the most noisy timestep, starting at ``initial_latent``.
    At each timestep ``t`` the noise is predicted at the current latent, and the
    DDIM (eta = 0) update is applied in reverse. That moves the latent from noise
    level ``t - stride`` to noise level ``t``. This is the exact algebraic
    inverse of ``DDIMScheduler.step``, which maps ``t`` to ``t - stride``. For the
    lowest timestep, the starting noise level is ``final_alpha_cumprod``, which
    is what the scheduler also uses for its final denoising step.

    Args:
        initial_latent: Clean, scaled VAE latent of the source image.
        unet: ``UNet2DConditionModel`` predicting noise.
        scheduler: ``DDIMScheduler``; its timesteps are (re)set here.
        text_embeds: Conditioning embeddings (typically the source prompt).
        num_steps: Number of DDIM steps. It must match the value later used for
            editing, because results are keyed by timestep.
        guidance_scale: Classifier-free guidance scale. ``1.0`` (default) uses
            the conditional prediction only; any other value requires
            ``uncond_embeds``.
        uncond_embeds: Unconditional (empty-prompt) embeddings, used only when
            ``guidance_scale != 1.0``.

    Returns:
        dict mapping every ``int`` timestep of ``scheduler.timesteps`` to the
        inverted latent at that noise level. The entry for
        ``scheduler.timesteps[0]`` is the fully inverted latent x_T.

    Raises:
        ValueError: If guidance is requested without ``uncond_embeds``, or if
            the scheduler is not configured for epsilon prediction (the
            inversion formula below assumes epsilon prediction).
    """
    use_cfg = guidance_scale != 1.0
    if use_cfg and uncond_embeds is None:
        raise ValueError("guidance_scale != 1.0 requires uncond_embeds")
    prediction_type = getattr(scheduler.config, "prediction_type", "epsilon")
    if prediction_type != "epsilon":
        raise ValueError(f"DDIM inversion only supports epsilon prediction, got {prediction_type!r}")

    print("Inverting the source image...")
    scheduler.set_timesteps(num_steps)
    # Same spacing DDIMScheduler.step uses to derive the previous timestep.
    stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    if use_cfg:
        cfg_embeds = torch.cat([uncond_embeds, text_embeds])

    latents = initial_latent.clone()
    inverted_latents = {}
    for t in tqdm(scheduler.timesteps.flip(0)):  # ascending: least noisy first
        t_next = t.item()
        t_cur = t_next - stride  # noise level `latents` currently sits at
        alpha_cur = float(
            scheduler.alphas_cumprod[t_cur] if t_cur >= 0 else scheduler.final_alpha_cumprod
        )
        alpha_next = float(scheduler.alphas_cumprod[t_next])

        if use_cfg:
            noise_uncond, noise_cond = unet(
                torch.cat([latents] * 2), t, encoder_hidden_states=cfg_embeds
            ).sample.chunk(2)
            noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
        else:
            noise_pred = unet(latents, t, encoder_hidden_states=text_embeds).sample

        # Reverse DDIM update: estimate x_0, then re-noise it to level t_next.
        pred_x0 = (latents - (1.0 - alpha_cur) ** 0.5 * noise_pred) / alpha_cur ** 0.5
        latents = alpha_next ** 0.5 * pred_x0 + (1.0 - alpha_next) ** 0.5 * noise_pred
        inverted_latents[t_next] = latents

    return inverted_latents

# --- 3. Core Prism-Edit Logic ---
@torch.no_grad()
def prism_edit(
    unet, scheduler, vae,
    uncond_embeds, src_embeds, tgt_embeds,
    initial_latent, inverted_latents,
    num_steps=50,
    probe_steps=10,
    mask_threshold=1.0,
    guidance_scale=7.5
):
    """Perform the Prism-Edit algorithm on an inverted source image.

    Stage 1 probes the first ``probe_steps`` timesteps along the inversion
    trajectory and accumulates the difference between target-conditioned and
    unconditional noise predictions. The channel-wise norm of that sum is
    z-score normalised and thresholded into a binary edit mask.

    Stage 2 denoises from the fully inverted latent x_T. At every step it
    re-injects the inverted latent outside the mask, and it applies
    classifier-free guidance only inside the mask. After the last step, the
    unmasked region is restored from the clean source latent.

    Autograd is disabled for the whole call (decorator). That keeps the UNet/VAE
    graph from accumulating across steps, and it lets the decoded tensor be
    converted to NumPy.

    Args:
        unet, scheduler, vae: Stable Diffusion components.
        uncond_embeds: Embeddings of the empty prompt.
        src_embeds: Source-prompt embeddings. They are accepted for interface
            compatibility but are not used by this implementation.
        tgt_embeds: Target-prompt embeddings.
        initial_latent: Clean, scaled VAE latent of the source image. It is
            used to restore the unmasked region after the final step.
        inverted_latents: Mapping ``int timestep -> latent`` with an entry for
            every timestep of ``scheduler.set_timesteps(num_steps)``.
        num_steps: Number of DDIM steps; must match the inversion.
        probe_steps: Number of leading timesteps used to build the mask,
            in ``[1, num_steps]``.
        mask_threshold: Z-score at or above which a latent pixel is edited.
        guidance_scale: Classifier-free guidance scale inside the mask.

    Returns:
        The edited image as a ``PIL.Image.Image``.

    Raises:
        ValueError: If ``probe_steps`` is outside ``[1, len(scheduler.timesteps)]``.
        KeyError: If ``inverted_latents`` lacks a required timestep.
    """
    scheduler.set_timesteps(num_steps)
    n_timesteps = len(scheduler.timesteps)
    if not 1 <= probe_steps <= n_timesteps:
        raise ValueError(f"probe_steps must be in [1, {n_timesteps}], got {probe_steps}")

    # --- STAGE 1: Semantic Map Extraction ---
    print("\n--- Stage 1: Extracting Semantic Map ---")
    final_mask = _extract_semantic_mask(
        unet, scheduler, uncond_embeds, tgt_embeds,
        inverted_latents, probe_steps, mask_threshold,
    )

    # --- STAGE 2: Mask-based Editing ---
    print("\n--- Stage 2: Disentangled Denoising with Mask ---")
    latents = _masked_denoise(
        unet, scheduler, uncond_embeds, tgt_embeds,
        initial_latent, inverted_latents, final_mask, guidance_scale,
    )

    # --- Final Image Decoding ---
    return _decode_latents(latents, vae)


def _extract_semantic_mask(unet, scheduler, uncond_embeds, tgt_embeds,
                           inverted_latents, probe_steps, mask_threshold):
    """Return a binary mask of latent pixels the target prompt pushes to change.

    The mask has the latents' shape with dim 1 reduced to size 1, so it
    broadcasts over the latent channels.
    """
    delta_sum = None
    # Each probe uses the inverted latent at its own timestep (no look-ahead,
    # so probe_steps == num_steps is valid).
    for t in scheduler.timesteps[:probe_steps]:
        x_t = inverted_latents[t.item()]
        noise_uncond = unet(x_t, t, encoder_hidden_states=uncond_embeds).sample
        noise_tgt = unet(x_t, t, encoder_hidden_states=tgt_embeds).sample
        delta = noise_tgt - noise_uncond
        delta_sum = delta if delta_sum is None else delta_sum + delta

    strength = torch.linalg.norm(delta_sum, dim=1, keepdim=True)
    z_score = (strength - strength.mean()) / (strength.std() + 1e-8)
    return (z_score >= mask_threshold).to(delta_sum.dtype)


def _masked_denoise(unet, scheduler, uncond_embeds, tgt_embeds,
                    initial_latent, inverted_latents, final_mask, guidance_scale):
    """Denoise from x_T with masked guidance and static background blending."""
    keep_mask = 1 - final_mask
    # Loop-invariant: batch of [unconditional, target] embeddings for CFG.
    text_embeds = torch.cat([uncond_embeds, tgt_embeds])

    # Denoising must start from the fully inverted latent x_T, not the clean latent.
    latents = inverted_latents[scheduler.timesteps[0].item()]
    for t in tqdm(scheduler.timesteps):
        # 1. Static blending: outside the mask, follow the inversion trajectory.
        latents = latents * final_mask + inverted_latents[t.item()] * keep_mask

        # 2. Guidance applied only inside the mask. The mask could also be made
        #    soft, or the guidance scaled, here.
        noise_pred = unet(torch.cat([latents] * 2), t, encoder_hidden_states=text_embeds).sample
        noise_uncond, noise_text = noise_pred.chunk(2)
        modulated_noise = noise_uncond + guidance_scale * ((noise_text - noise_uncond) * final_mask)

        # 3. Scheduler step (t -> previous, less noisy timestep).
        latents = scheduler.step(modulated_noise, t, latents, return_dict=False)[0]

    # inverted_latents has no entry below the last timestep; the clean source
    # latent plays that role, so apply the same blending one final time.
    return latents * final_mask + initial_latent * keep_mask


def _decode_latents(latents, vae):
    """Decode the first latent in the batch into an 8-bit RGB PIL image."""
    image = vae.decode(latents / vae.config.scaling_factor).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    pixels = image[0].permute(1, 2, 0).float().cpu().numpy()
    return Image.fromarray((pixels * 255).round().astype(np.uint8))


if __name__ == "__main__":
    # --- User Configuration ---
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    MODEL_ID = "runwayml/stable-diffusion-v1-5"
    # Inversion and editing must share one schedule: prism_edit looks up the
    # inverted latents by the timesteps of scheduler.set_timesteps(NUM_STEPS).
    NUM_STEPS = 50

    SOURCE_IMAGE_PATH = "path/to/your/source_image.png" # <<< Path to the source image to be edited
    SOURCE_PROMPT = "a photo of a horse"                 # <<< Prompt describing the source image
    TARGET_PROMPT = "a photo of a zebra"                 # <<< Prompt describing the target edit

    # --- Load Models ---
    tokenizer, text_encoder, vae, unet, scheduler = load_models(MODEL_ID, DEVICE)

    # --- Encode Prompts ---
    def encode_prompt(prompt):
        """Return CLIP hidden states for a list of prompts, padded/truncated to the tokenizer max length."""
        inputs = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        return text_encoder(inputs.input_ids.to(DEVICE))[0]

    # Inference only: disable autograd so no graph is recorded through the
    # text encoder, VAE and UNet (this also keeps decoded tensors NumPy-convertible).
    with torch.no_grad():
        uncond_embeds = encode_prompt([""])
        src_embeds = encode_prompt([SOURCE_PROMPT])
        tgt_embeds = encode_prompt([TARGET_PROMPT])

        # --- Run Inversion ---
        # Close the file handle once the pixels are loaded by convert().
        with Image.open(SOURCE_IMAGE_PATH) as raw_image:
            source_image = raw_image.convert("RGB")
        initial_latent = image_to_latent(source_image, vae)
        inverted_latents = invert(initial_latent, unet, scheduler, src_embeds, num_steps=NUM_STEPS)

        # --- Run Prism-Edit ---
        edited_image = prism_edit(
            unet, scheduler, vae,
            uncond_embeds, src_embeds, tgt_embeds,
            initial_latent, inverted_latents,
            num_steps=NUM_STEPS,
        )

    # --- Save Result ---
    edited_image.save("edited_image.png")
    print("Edited image saved to edited_image.png")