import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from diffusers import DDIMScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn.functional as F
from typing import Tuple
import argparse
import os


class InversionSampler:
    """
    Inversion Sampler class implementing DDIM inversion and sampling functions.

    The sampler only holds a scheduler and a device. UNets and embeddings are
    passed to each call, so one instance can be reused with different models.
    """

    def __init__(self, scheduler: DDIMScheduler, device: str = "cuda"):
        """
        Store the scheduler and target device.

        Args:
            scheduler: Scheduler used for timestep configuration, input scaling,
                denoising steps and the ``alphas_cumprod`` table
            device: Device that latents, embeddings and timesteps are placed on
        """
        self.scheduler = scheduler
        self.device = device

    @torch.no_grad()
    def ddim_sampling(
            self,
            initial_latents: torch.Tensor,
            unet_ori: torch.nn.Module,
            unet_fin: torch.nn.Module,
            text_embeddings: torch.Tensor,
            guidance_scale: float = 7.5,
            num_inference_steps: int = 50,
    ) -> torch.Tensor:
        """
        Perform DDIM sampling to generate images from noise latents.

        Both UNets are evaluated on the same latents with the same
        ``text_embeddings``. The prediction of ``unet_ori`` is used as the
        baseline branch of the guidance formula and the prediction of
        ``unet_fin`` as the guided branch, i.e.
        ``eps = eps_ori + guidance_scale * (eps_fin - eps_ori)``.

        Args:
            initial_latents: Initial noise tensor (batch_size, 4, H, W)
            unet_ori: Original UNet model providing the baseline noise prediction
            unet_fin: Finetuned UNet model providing the guided noise prediction
            text_embeddings: Text embeddings (batch_size, seq_len, emb_dim), fed to both UNets
            guidance_scale: Scale applied to the difference between the two predictions
            num_inference_steps: Number of denoising steps

        Returns:
            Generated image latents (batch_size, 4, H, W)
        """
        batch_size = initial_latents.shape[0]

        # Configure scheduler timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        latents = initial_latents.to(self.device).detach().clone()
        # Keep the embeddings on the same device as the latents
        text_embeddings = text_embeddings.to(self.device)

        # Denoising loop
        for t in tqdm(self.scheduler.timesteps, desc="DDIM Sampling"):
            latent_model_input = self.scheduler.scale_model_input(latents, t)
            timestep_batch = t.expand(batch_size)

            # Predict the noise residual with each model
            noise_pred_cond = unet_fin(
                latent_model_input,
                timestep_batch,
                text_embeddings
            ).sample
            noise_pred_uncond = unet_ori(
                latent_model_input,
                timestep_batch,
                text_embeddings
            ).sample

            # Guidance between the original and the finetuned prediction
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # Deterministic DDIM update (eta=0.0)
            latents = self.scheduler.step(noise_pred, t, latents, eta=0.0).prev_sample

        return latents

    @torch.no_grad()
    def ddim_inversion(
            self,
            target_latents: torch.Tensor,
            unet: torch.nn.Module,
            text_embeddings: torch.Tensor,
            uncond_embeddings: torch.Tensor,
            guidance_scale: float = 7.5,
            num_inversion_steps: int = 50,
    ) -> torch.Tensor:
        """
        Perform DDIM inversion to reconstruct noise latents from images.

        The scheduler timesteps are walked in ascending order and the final
        (largest) timestep is skipped, so ``num_inversion_steps - 1`` UNet
        evaluations are performed.

        Args:
            target_latents: Target image latents (batch_size, 4, H, W)
            unet: UNet model for noise prediction
            text_embeddings: Conditional text embeddings (batch_size, seq_len, emb_dim)
            uncond_embeddings: Unconditional embeddings (batch_size, seq_len, emb_dim)
            guidance_scale: Scale for classifier-free guidance
            num_inversion_steps: Number of inversion steps

        Returns:
            Reconstructed noise latents (batch_size, 4, H, W)
        """
        batch_size = target_latents.shape[0]

        # Unconditional rows first, conditional rows second; this order must
        # match the chunk(2) unpacking of the UNet output below.
        conditioning = torch.cat([uncond_embeddings, text_embeddings]).to(self.device)

        # Configure scheduler timesteps and walk them in ascending order,
        # dropping the final timestep (it is skipped for reconstruction).
        self.scheduler.set_timesteps(num_inversion_steps, device=self.device)
        reverse_timesteps = self.scheduler.timesteps.flip(0)[: max(num_inversion_steps - 1, 0)]
        latents = target_latents.to(self.device).detach().clone()

        # Distance between two consecutive inference timesteps, derived from
        # the scheduler's training schedule instead of a hard-coded 1000.
        timestep_stride = self.scheduler.config.num_train_timesteps // num_inversion_steps

        # Inversion loop
        for t in tqdm(reverse_timesteps, desc="DDIM Inversion"):
            # Duplicate latents so one UNet call yields both guidance branches
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict noise residual
            noise_pred = unet(
                latent_model_input,
                t.expand(batch_size * 2),
                conditioning
            ).sample

            # Apply classifier-free guidance
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # Approximate inversion: the latents are treated as living at the
            # previous timestep (current_t) and are moved up to t (next_t).
            next_t = int(t)
            current_t = max(0, next_t - timestep_stride)
            alpha_t = self.scheduler.alphas_cumprod[current_t]
            alpha_t_next = self.scheduler.alphas_cumprod[next_t]

            # Inverted update step (re-arranging the update step to get x(t) (new latents)
            # as a function of x(t-1) (current latents))
            latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (alpha_t_next.sqrt() / alpha_t.sqrt()) + (
                    1 - alpha_t_next
            ).sqrt() * noise_pred

        return latents

    @torch.no_grad()
    def inversion_based_sampling(
            self,
            target_latents: torch.Tensor,
            text_embeddings: torch.Tensor,
            uncond_embeddings: torch.Tensor,
            original_unet: torch.nn.Module,
            finetuned_unet: torch.nn.Module,
            num_steps: int = 50,
            guidance_scale_inv: float = 7.5,
            guidance_scale_sam: float = 7.5,
    ) -> torch.Tensor:
        """
        Perform inversion-based sampling using two UNet models.

        The target latents are first inverted with the original UNet, then the
        resulting latents are denoised again with guidance between the original
        and the finetuned UNet.

        Args:
            target_latents: Target image latents (batch_size, 4, H, W)
            text_embeddings: Conditional text embeddings (batch_size, seq_len, emb_dim)
            uncond_embeddings: Unconditional embeddings (batch_size, seq_len, emb_dim),
                only used during inversion
            original_unet: Pretrained UNet used for inversion and as sampling baseline
            finetuned_unet: Finetuned UNet used for sampling
            num_steps: Number of inversion/sampling steps
            guidance_scale_inv: Scale for classifier-free guidance during inversion
            guidance_scale_sam: Scale for guidance during sampling

        Returns:
            Generated image latents (batch_size, 4, H, W)
        """
        # Step 1: Inversion with original model
        reconstructed_latents = self.ddim_inversion(
            target_latents=target_latents,
            unet=original_unet,
            text_embeddings=text_embeddings,
            uncond_embeddings=uncond_embeddings,
            guidance_scale=guidance_scale_inv,
            num_inversion_steps=num_steps
        )

        # Step 2: Sampling with finetuned model
        generated_latents = self.ddim_sampling(
            initial_latents=reconstructed_latents,
            unet_ori=original_unet,
            unet_fin=finetuned_unet,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_sam,
            num_inference_steps=num_steps
        )

        return generated_latents


def load_models(args):
    """
    Load all required model components

    Reads ``model_name``, ``device``, ``prompt``, ``original_unet_path`` and
    ``finetuned_unet_path`` from ``args``. Both UNets are first instantiated
    from the pretrained model and then overwritten with the state dicts found
    at the two checkpoint paths (``{prompt}`` placeholders are filled in).

    Args:
        args: Parsed command line arguments

    Returns:
        Tuple of model components: (original_unet, finetuned_unet, text_encoder, tokenizer, scheduler, vae)
    """
    # Load original model components
    original_unet = UNet2DConditionModel.from_pretrained(args.model_name, subfolder="unet").to(args.device)
    text_encoder = CLIPTextModel.from_pretrained(args.model_name, subfolder="text_encoder").to(args.device)
    tokenizer = CLIPTokenizer.from_pretrained(args.model_name, subfolder="tokenizer")
    finetuned_unet = UNet2DConditionModel.from_pretrained(args.model_name, subfolder="unet").to(args.device)

    # Resolve checkpoint paths
    original_unet_path = args.original_unet_path.format(prompt=args.prompt)
    finetuned_unet_path = args.finetuned_unet_path.format(prompt=args.prompt)

    # map_location keeps checkpoints that were saved on a GPU loadable when
    # running on a different device (e.g. CPU-only machines).
    original_unet.load_state_dict(torch.load(original_unet_path, map_location=args.device))
    finetuned_unet.load_state_dict(torch.load(finetuned_unet_path, map_location=args.device))

    # Load scheduler
    scheduler = DDIMScheduler.from_pretrained(args.model_name, subfolder="scheduler")

    # Load VAE (used for converting latents to images)
    vae = AutoencoderKL.from_pretrained(args.model_name, subfolder="vae").to(args.device)

    return original_unet, finetuned_unet, text_encoder, tokenizer, scheduler, vae


def prepare_text_embeddings(prompt, tokenizer, text_encoder, device):
    """
    Encode a prompt together with the empty prompt in a single encoder pass

    Args:
        prompt: Text prompt for generation
        tokenizer: Text tokenizer
        text_encoder: Text encoder model
        device: Device the token ids are moved to before encoding

    Returns:
        Tuple of (conditional_embeddings, unconditional_embeddings)
    """
    # Row 0 is the prompt, row 1 the empty string used as unconditional input
    prompts = [prompt, ""]
    tokenized = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        token_ids = tokenized.input_ids.to(device)
        encoder_output = text_encoder(token_ids)

    # The first element of the encoder output holds the embeddings;
    # split the two rows back into conditional / unconditional parts.
    conditional, unconditional = encoder_output[0].chunk(2)
    return conditional, unconditional


def load_latents_from_file(file_path, device):
    """
    Read a saved latent tensor from disk and report its shape

    Args:
        file_path: Path to the latents file
        device: Device the stored tensors are mapped to while loading

    Returns:
        Loaded latent tensors
    """
    stored = torch.load(file_path, map_location=device)
    stored_shape = stored.shape
    print(f"Loaded latents from {file_path} with shape: {stored_shape}")
    return stored


def decode_latent_batch(vae, latents_batch):
    """
    Batch decode latent tensors to images

    The latents are divided by the VAE's scaling factor before decoding. The
    factor is read from ``vae.config.scaling_factor`` when available and falls
    back to 0.18215 otherwise.

    Args:
        vae: VAE model for decoding; ``vae.decode(x).sample`` must return a
            (batch, 3, height, width) tensor with values nominally in [-1, 1]
        latents_batch: Latent tensors (batch, channels, h, w)

    Returns:
        Numpy array of decoded images in uint8 format (batch, height, width, 3)
    """
    # Prefer the factor the VAE was configured with over a hard-coded constant
    scaling_factor = getattr(getattr(vae, "config", None), "scaling_factor", 0.18215)

    # Undo the latent scaling and decode
    latents = 1 / scaling_factor * latents_batch
    with torch.no_grad():
        images = vae.decode(latents).sample

    # Map [-1, 1] to [0, 1]; values outside the range are clipped
    images = (images / 2 + 0.5).clamp(0, 1)
    # (b, c, h, w) -> (b, h, w, c); cast to float32 so half/bfloat16 outputs
    # convert to numpy and are quantised at full precision
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    # Scale to [0, 255]; astype truncates the fractional part
    images = (images * 255).astype(np.uint8)
    return images


def create_collage(images, rows=None, cols=None):
    """
    Create a collage from a list of images

    Images are placed left-to-right, top-to-bottom on a grid whose cell size is
    taken from the first image. A missing ``rows``/``cols`` value is derived
    from the number of images; when both are missing a near-square grid is used.

    When both ``rows`` and ``cols`` are given and ``rows * cols`` is smaller
    than the number of images, the surplus images are pasted at positions
    below the canvas. NOTE(review): presumably PIL clips those silently, so
    they would be missing from the result; confirm.

    Args:
        images: Non-empty list of image arrays (height, width, 3), all the same size
        rows: Number of rows (optional)
        cols: Number of columns (optional)

    Returns:
        PIL Image object containing the collage

    Raises:
        ValueError: If ``images`` is empty, or ``rows``/``cols`` is given but not positive
    """
    num_images = len(images)
    if num_images == 0:
        raise ValueError("create_collage requires at least one image")
    if (rows is not None and rows <= 0) or (cols is not None and cols <= 0):
        raise ValueError("rows and cols must be positive when given")

    # Automatically determine layout (rows x columns)
    if rows is None and cols is None:
        cols = int(np.ceil(np.sqrt(num_images)))
        rows = int(np.ceil(num_images / cols))
    elif rows is None:
        rows = int(np.ceil(num_images / cols))
    elif cols is None:
        cols = int(np.ceil(num_images / rows))

    # Get single image dimensions (assuming all images have the same size)
    height, width, _ = images[0].shape
    print(f"Single image size: {width}x{height} pixels")

    # Create blank canvas large enough for the whole grid
    collage = Image.new('RGB', (cols * width, rows * height))

    # Paste images one by one in row-major order
    for i, img in enumerate(images):
        row, col = divmod(i, cols)
        collage.paste(Image.fromarray(img), (col * width, row * height))

    return collage


def main():
    """
    Main function to execute inversion-based sampling

    Loads the models, the stored text embedding and the stored latents, runs
    inversion followed by sampling, decodes the result and saves all generated
    images stacked vertically as ``<output_dir>/<prompt>.jpg``.
    """
    # 1. Parse arguments
    args = parse_args()

    # 2. Format path placeholders
    args.latents_path = args.latents_path.format(prompt=args.prompt)
    args.text_path = args.text_path.format(prompt=args.prompt)

    # 3. Load models
    original_unet, finetuned_unet, text_encoder, tokenizer, scheduler, vae = load_models(args)

    # 4. Initialize inversion sampler
    sampler = InversionSampler(scheduler, args.device)

    # 5. Prepare text embeddings
    # The embedding encoded from the prompt is only used for the similarity
    # report; sampling uses the embedding loaded from text_path.
    cond_embeddings_ini, uncond_embeddings = prepare_text_embeddings(
        args.prompt, tokenizer, text_encoder, args.device
    )
    cond_embeddings = torch.load(args.text_path, map_location=args.device)
    emb_sim = F.cosine_similarity(cond_embeddings_ini.flatten(), cond_embeddings.flatten(), dim=0)
    print(f"Text embedding similarity: {emb_sim.item():.4f}")

    # 6. Load initial latents from file
    target_latents = load_latents_from_file(args.latents_path, args.device)
    num_latents = target_latents.shape[0]

    # 7. Perform inversion-based sampling
    # The embeddings are broadcast along the batch dimension to match the latents.
    generated_latents = sampler.inversion_based_sampling(
        target_latents=target_latents,
        text_embeddings=cond_embeddings.expand(num_latents, -1, -1),
        uncond_embeddings=uncond_embeddings.expand(num_latents, -1, -1),
        original_unet=original_unet,
        finetuned_unet=finetuned_unet,
        num_steps=args.num_steps,
        guidance_scale_inv=args.guidance_scale_inv,
        guidance_scale_sam=args.guidance_scale_sam
    )

    # 8. Batch decode latent tensors
    generated_images = decode_latent_batch(vae, generated_latents)
    print(f"Decoded image shape: {generated_images.shape}")  # (num_images, H, W, 3)

    # 9. Create collage
    collage = create_collage(
        images=list(generated_images),
        rows=len(generated_images),  # One row per image ...
        cols=1  # ... in a single column (vertical arrangement)
    )

    # 10. Save collage
    os.makedirs(args.output_dir, exist_ok=True)
    collage_path = os.path.join(args.output_dir, f"{args.prompt}.jpg")
    collage.save(collage_path)
    print(f"Saved collage to: {collage_path} Size: {collage.width}x{collage.height} pixels")

def parse_args(argv=None):
    """
    Parse command line arguments for the inversion-based sampling process

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv``.

    Returns:
        Parsed arguments object
    """
    parser = argparse.ArgumentParser(description="Inversion-based Sampling with Stable Diffusion")

    # Model parameters
    parser.add_argument("--model_name", type=str, default="sd_model/tiny-sd",
                        help="Pretrained model name or path")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to run on (cuda/cpu)")

    # Sampling parameters
    parser.add_argument("--prompt", type=str, default="dog",
                        help="Generation prompt")
    parser.add_argument("--num_steps", type=int, default=50,
                        help="Number of inversion and sampling steps")
    parser.add_argument("--guidance_scale_inv", type=float, default=1.0,
                        help="Guidance scale during inversion")
    parser.add_argument("--guidance_scale_sam", type=float, default=2.0,
                        help="Guidance scale during sampling")

    # Path parameters ({prompt} is substituted by the callers, not here)
    parser.add_argument("--latents_path", type=str,
                        default="ckpts/GradCFG/{prompt}/0_1_50_0.01/dummy_images_epoch4000.pth",
                        help="Path to latents file, can use {prompt} as placeholder")
    parser.add_argument("--text_path", type=str,
                        default="ckpts/GradCFG/{prompt}/0_1_50_0.01/text_emb_epoch4000.pth",
                        help="Path to text embeddings file, can use {prompt} as placeholder")
    parser.add_argument("--original_unet_path", type=str,
                        default="ckpts/model_ckpt/{prompt}/epoch_0.pt",
                        help="Path to original UNet weights, can use {prompt} as placeholder")
    parser.add_argument("--finetuned_unet_path", type=str,
                        default="ckpts/model_ckpt/{prompt}/epoch_200.pt",
                        help="Path to finetuned UNet weights, can use {prompt} as placeholder")
    parser.add_argument("--output_dir", type=str, default="Inv_Sam_results",
                        help="Output directory")

    return parser.parse_args(argv)

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()