import torch
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor
# from reward import CLIPScorer
device = "cuda"

# Load the SDXL base checkpoint in half precision, then move it onto `device`.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe = pipe.to(device)

# Replace the checkpoint's scheduler with a DDIM scheduler built from the
# same scheduler config.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
# NOTE(review): every pipeline call in this file also passes
# num_inference_steps=50, so this explicit call may be redundant -- confirm.
pipe.scheduler.set_timesteps(50)
# NOTE(review): this only sets a plain attribute on the scheduler object;
# verify that anything actually reads `scheduler.eta` (eta is presumably
# meant to select deterministic DDIM sampling).
pipe.scheduler.eta = 0.0



def transport(z: torch.Tensor) -> "list[list[Image.Image]]":
    """
    Turn a batch of latent chains into images using the module-level `pipe`.

    All chains are flattened into one batch, passed to `pipe` as `latents`
    with an empty prompt, `guidance_scale=0.0` and 50 inference steps, and
    the resulting images are regrouped per chain (chain-major order).

    z: (n_chains, n_imgs, C, H, W)
    returns: List[List[PIL.Image]] -- `n_chains` lists of `n_imgs` images each
    raises: ValueError if `z` is not 5-dimensional
    """
    if z.ndim != 5:
        raise ValueError(
            "expected z of shape (n_chains, n_imgs, C, H, W), "
            f"got {tuple(z.shape)}"
        )
    n_chains, n_imgs, C, H, W = z.shape
    n_total = n_chains * n_imgs

    # Nothing to generate: return empty chains instead of handing the
    # pipeline an empty prompt list.
    if n_total == 0:
        return [[] for _ in range(n_chains)]

    # reshape (rather than view) also accepts non-contiguous inputs, e.g. a
    # tensor produced by permute/transpose, copying only when it has to.
    z_flat = z.reshape(n_total, C, H, W)

    # NOTE(review): height/width are not passed, so the pipeline's own
    # defaults apply to anything it derives from them, while the decoded
    # image size follows the latents -- confirm this is intended.
    result = pipe(
        prompt=[""] * n_total,
        latents=z_flat,
        num_inference_steps=50,
        guidance_scale=0.0,
    )

    images = result.images  # flat list, length = n_chains * n_imgs

    # Consecutive runs of n_imgs images belong to the same chain.
    return [
        images[start : start + n_imgs]
        for start in range(0, n_total, n_imgs)
    ]
    


# def get_energy(z):
#     chains = transport(z)

#     n_chains = len(chains)
#     n_imgs = len(chains[0])

#     images_flat = [
#         img
#         for chain in chains
#         for img in chain
#     ]
#     clip_scorer = CLIPScorer(device=device)
#     rewards_flat = clip_scorer.score_images(
#         images_flat,
#         text_prompt="Picture of Horse",
#     )
    
#     return rewards_flat.view(n_chains, n_imgs)

# Seed the global RNG so the latents drawn below are the same on every run.
torch.manual_seed(42)
batch_size, n_chains = 4, 1

print('getting here')
# One half-precision Gaussian latent per image, 64x64 spatially, with as many
# channels as the pipeline's UNet takes as input; sampled directly on `device`.
z = torch.randn(
    batch_size,
    pipe.unet.config.in_channels,
    64,
    64,
    dtype=torch.float16,
    device=device,
)


# rewards = get_energy(z)
# print("First chain scores:")
# print(rewards[0])

# print("First image in first chain:")
# print(rewards[0, 0].item())


# print(f'Number of chains: {len(chains)}')
# print(f'Images per chain: {len(chains[0])}')

# print(f'1st image type: {type(chains[0][0])}')
# print(f'1st image size (W, H): {chains[0][0].size}')
# print(f'1st image mode: {chains[0][0].mode}')

# Every image in the batch is generated from the same text prompt.
prompts = ["a black dog and a white cat" for _ in range(batch_size)]

# Run the pipeline from the pre-sampled latents `z` rather than letting it
# draw its own noise.
generation_kwargs = dict(
    prompt=prompts,
    latents=z,
    num_inference_steps=50,
    guidance_scale=10.0,
)
result = pipe(**generation_kwargs)

images = result.images

# Write each generated image to the working directory and report its name.
for i, img in enumerate(images):
    img.save(f"option_{i}.png")
    print(f"Saved: option_{i}.png")




# def transport(z)





# # Get PIL Image
# image = result.images[0]

# # change to tensor if needed
# # image_tensor = pil_to_tensor(image).float().to(device) / 255.0



# clip_scorer = CLIPScorer(device=device)
# score = clip_scorer.score_image(
#     image,  # Pass tensor, not PIL
#     "Picture of horse"
# )

# print(f"CLIP similarity score: {score:.4f}")


