import argparse
import torch
import os
import numpy as np
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, DEISMultistepScheduler, UNet2DConditionModel
import torch.nn.functional as F


def set_seed(seed: int):
    """Seed PyTorch's random number generators and configure cuDNN flags.

    Args:
        seed: Value handed to ``torch.manual_seed`` and, when CUDA is
            available, to ``torch.cuda.manual_seed_all``.
    """
    cudnn = torch.backends.cudnn

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Request deterministic cuDNN algorithms and switch off its benchmark mode.
    cudnn.deterministic, cudnn.benchmark = True, False

def create_collage(images, rows=None, cols=None):
    """
    Create a collage from a list of images

    Images are placed left to right, top to bottom on a grid. Every grid cell
    has the size of the first image; only the first two axes (height, width)
    of that image are read, so the shape unpacking does not require a
    channel axis.

    Args:
        images: Non-empty list of image arrays (each is passed to
            ``Image.fromarray``)
        rows: Number of rows (optional, must be >= 1 when given)
        cols: Number of columns (optional, must be >= 1 when given)

    Returns:
        PIL Image object containing the collage

    Raises:
        ValueError: If ``images`` is empty, or if ``rows`` / ``cols`` is given
            and smaller than 1.
    """
    num_images = len(images)

    # Fail early with a clear message instead of an opaque ZeroDivisionError /
    # IndexError further down.
    if num_images == 0:
        raise ValueError("create_collage() requires at least one image")
    for name, value in (("rows", rows), ("cols", cols)):
        if value is not None and value < 1:
            raise ValueError(f"{name} must be at least 1, got {value}")

    # Automatically determine layout (rows x columns)
    if rows is None and cols is None:
        cols = int(np.ceil(np.sqrt(num_images)))
        rows = int(np.ceil(num_images / cols))
    elif rows is None:
        rows = int(np.ceil(num_images / cols))
    elif cols is None:
        cols = int(np.ceil(num_images / rows))
    # NOTE: when both rows and cols are supplied, the grid is not checked
    # against the number of images.

    # Get single image dimensions (assuming all images have the same size)
    height, width = images[0].shape[:2]
    print(f"Single image size: {width}x{height} pixels")

    # Create blank canvas large enough for the full grid
    collage = Image.new('RGB', (cols * width, rows * height))

    # Paste images one by one, filling each row before moving to the next
    for i, img in enumerate(images):
        row, col = divmod(i, cols)
        collage.paste(Image.fromarray(img), (col * width, row * height))

    return collage

def process_prompt(prompt, args, device, dtype, tokenizer, text_encoder, vae, unet):
    """
    Process a single prompt using the restored text embedding

    Loads the fine-tuned text embedding stored for ``prompt``, samples
    ``args.num_samples`` latents with classifier-free guidance, decodes them
    and saves all samples stacked vertically in a single PNG.

    Args:
        prompt: Prompt text. It is substituted into
            ``args.text_embedding_path`` and, with spaces replaced by
            underscores, used as the output file name.
        args: Parsed options. Reads ``output_dir``, ``text_embedding_path``,
            ``num_steps``, ``num_samples``, ``seed``, ``guidance_scale`` and
            ``model_dir`` (optional; falls back to ``"sd_model/tiny-sd"``).
        device: Device used for the loaded embedding, the token ids, the
            noise generator and the latents.
        dtype: dtype of the initial latents.
        tokenizer: Tokenizer used to encode the empty (unconditional) prompt.
        text_encoder: Text encoder producing the unconditional embeddings.
        vae: Autoencoder whose ``decode`` turns latents into images.
        unet: Noise-prediction network called at every timestep.

    Returns:
        Dict with keys ``prompt``, ``status`` and ``output_dir``. ``status``
        is ``"missing_embedding"``, ``"load_error"`` or ``"success"``; on
        success the dict additionally contains ``output_path``.
    """
    # Create output directory
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    print(f"\nProcessing prompt: {prompt}")
    print(f"Output directory: {args.output_dir}")

    # Format the text embedding path using the prompt
    text_path = args.text_embedding_path.format(prompt=prompt)

    if not os.path.exists(text_path):
        print(f"Warning: Fine-tuned embedding not found for prompt: {prompt}")
        print(f"Path: {text_path}")
        return {
            "prompt": prompt,
            "status": "missing_embedding",
            "output_dir": args.output_dir
        }

    try:
        # NOTE(security): torch.load unpickles the file; only load embeddings
        # that come from a trusted source.
        fine_tuned_embeddings = torch.load(text_path, map_location=device)
        print(f"Successfully loaded fine-tuned embedding for: {prompt}")
    except Exception as e:
        print(f"Error loading embedding for {prompt}: {str(e)}")
        return {
            "prompt": prompt,
            "status": "load_error",
            "output_dir": args.output_dir
        }

    # Configure scheduler. Use the configured model directory rather than a
    # fixed path so that --model_dir is honoured; callers whose args carry no
    # model_dir keep the previous default location.
    model_dir = getattr(args, "model_dir", "sd_model/tiny-sd")
    scheduler = DEISMultistepScheduler.from_pretrained(
        model_dir, subfolder="scheduler")
    scheduler.set_timesteps(args.num_steps)

    # Create unconditional embeddings (one empty prompt per sample)
    uncond_input = tokenizer(
        [""] * args.num_samples,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids)[0]

    # Initialize latent variables from a generator seeded with args.seed.
    # NOTE(review): the latent shape is fixed at 4x64x64 — confirm it matches
    # the loaded UNet/VAE.
    latents = torch.randn(
        (args.num_samples, 4, 64, 64),
        device=device,
        dtype=dtype,
        generator=torch.Generator(device).manual_seed(args.seed)
    )

    # Generate images with fine-tuned text embedding
    print(f"\nGenerating images with fine-tuned text embedding for: {prompt}")
    # The stored embedding is broadcast over the batch; expand() with two -1
    # arguments requires it to be 3-dimensional.
    cond_embeddings = fine_tuned_embeddings.expand(args.num_samples, -1, -1)
    # Unconditional half first, conditional half second — chunk(2) below
    # relies on this order.
    text_embeddings_cat = torch.cat([uncond_embeddings, cond_embeddings])

    current_latents = latents.clone()

    for t in tqdm(scheduler.timesteps):
        t = t.to(device)
        # Duplicate the latents so both halves are predicted in one UNet call
        latent_model_input = torch.cat([current_latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = unet(
                latent_model_input,
                t.reshape(1).repeat(args.num_samples * 2),
                text_embeddings_cat
            ).sample

        # Classifier-free guidance: move from the unconditional prediction
        # towards the conditional one by guidance_scale
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + args.guidance_scale * (
                noise_pred_text - noise_pred_uncond
        )

        current_latents = scheduler.step(
            noise_pred,
            t,
            current_latents,
        ).prev_sample

    # Decode images
    # NOTE(review): 0.18215 is presumably the VAE latent scaling factor —
    # confirm against the loaded VAE's config.
    current_latents = 1 / 0.18215 * current_latents
    with torch.no_grad():
        images = vae.decode(current_latents).sample

    # Shift/scale to [0, 1], then convert to channel-last numpy arrays
    images = (images / 2 + 0.5).clamp(0, 1)
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()

    # Stack every generated sample into one long image
    long_image = (images * 255).astype(np.uint8)
    collage = create_collage(
        images=[img for img in long_image],
        rows=len(long_image),  # Vertical arrangement: one row per sample
        cols=1  # All images in a single column
    )

    # Save the long image
    output_path = os.path.join(args.output_dir, f"{prompt.replace(' ', '_')}.png")
    collage.save(output_path)
    print(f"Saved long image: {output_path}")

    return {
        "prompt": prompt,
        "status": "success",
        "output_dir": args.output_dir,
        "output_path": output_path
    }


def main():
    """Run generation for the prompt given on the command line.

    Parses the options, seeds the RNGs, resolves the compute device, loads
    the tokenizer, text encoder, VAE and UNet from ``args.model_dir`` and
    passes everything to ``process_prompt``. An exception raised while
    processing is reported on stdout instead of being propagated.
    """
    args = parse_args()
    set_seed(args.seed)

    # Make sure the destination exists before the models are loaded.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Resolve the device; an explicit "cuda" request falls back to CPU when
    # CUDA is unavailable.
    cuda_ok = torch.cuda.is_available()
    device = args.device
    if device == "auto":
        device = "cuda" if cuda_ok else "cpu"
    if device == "cuda" and not cuda_ok:
        print("Warning: CUDA requested but not available. Falling back to CPU.")
        device = "cpu"

    dtype = torch.float32

    print(f"Processing prompt: {args.prompt}")
    print(f"Using device: {device}")
    print(f"Output directory: {args.output_dir}")
    print(f"Text embedding path template: {args.text_embedding_path}")

    # Every component lives in its own subfolder of the model directory.
    print("\nLoading model components...")
    model_dir = args.model_dir
    tokenizer = CLIPTokenizer.from_pretrained(model_dir, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_dir, subfolder="text_encoder")
    text_encoder = text_encoder.to(device)
    vae = AutoencoderKL.from_pretrained(model_dir, subfolder="vae")
    vae = vae.to(device)
    unet = UNet2DConditionModel.from_pretrained(model_dir, subfolder="unet")
    unet = unet.to(device)
    unet.eval()

    separator = "=" * 50
    try:
        result = process_prompt(
            args.prompt, args, device, dtype, tokenizer, text_encoder, vae, unet
        )
        print(f"\nCompleted processing: {args.prompt}")
        print(separator)

        # Report the outcome returned by process_prompt.
        print("\n===== PROCESSING SUMMARY =====")
        if result['status'] != 'success':
            print(f"Processing failed: {result['status']}")
        else:
            print("Successfully generated long image")
            print(f"Image saved at: {result['output_path']}")

        print("\nAll processing completed!")
    except Exception as err:
        print(f"\nError processing prompt '{args.prompt}': {str(err)}")
        print(separator)
        print("\nProcessing failed!")


def parse_args(argv=None):
    """Parse command line arguments

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line, which is
            the previous behaviour.

    Returns:
        ``argparse.Namespace`` with the attributes ``prompt``, ``model_dir``,
        ``output_dir``, ``text_embedding_path``, ``num_samples``,
        ``num_steps``, ``guidance_scale``, ``seed`` and ``device``.
    """
    parser = argparse.ArgumentParser(description="Generate images using restored text embeddings")
    parser.add_argument("--prompt", type=str, default="dog",
                        help="Prompt to process")
    parser.add_argument("--model_dir", type=str, default="sd_model/tiny-sd",
                        help="Directory containing the model")
    parser.add_argument("--output_dir", type=str, default="generated_images",
                        help="Output directory for generated images")
    parser.add_argument("--text_embedding_path", type=str,
                        default="ckpts/GradCFG/{prompt}/100_1_100_0.01/text_emb_epoch4000.pth",
                        help="Path template for restored text embeddings (use {prompt} as placeholder)")
    parser.add_argument("--num_samples", type=int, default=5,
                        help="Number of images to generate per prompt")
    parser.add_argument("--num_steps", type=int, default=50,
                        help="Number of denoising steps")
    parser.add_argument("--guidance_scale", type=float, default=7.5,
                        help="Guidance scale for generation")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")
    parser.add_argument("--device", type=str, default="auto",
                        choices=["auto", "cuda", "cpu"],
                        help="Device to use for computation (auto, cuda, or cpu)")
    # Passing an explicit list keeps the parser usable from tests and other
    # modules without touching sys.argv.
    return parser.parse_args(argv)



# Run the generation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()