import json
import os

import torch
from diffusers import DiffusionPipeline, FluxPipeline, StableDiffusion3Pipeline, StableDiffusionPipeline

# Hugging Face Hub ids of the text-to-image checkpoints being compared.
model_id_sd15 = "pt-sk/stable-diffusion-1.5"
model_id_sd21 = "stabilityai/stable-diffusion-2-1"
model_id_sdxl = "stabilityai/stable-diffusion-xl-base-1.0"
# `from_pretrained` needs the diffusers-format repository; the plain
# "stable-diffusion-3-medium" repo only carries single-file checkpoints.
model_id_sd3 = "stabilityai/stable-diffusion-3-medium-diffusers"
model_id_flux = "black-forest-labs/FLUX.1-dev"

# Load every pipeline once, up front, so the prompt loop below only runs
# inference. The Stable Diffusion variants are loaded in fp16 and moved to
# the GPU.
sd15 = StableDiffusionPipeline.from_pretrained(model_id_sd15, torch_dtype=torch.float16)
sd15 = sd15.to("cuda")
sd21 = StableDiffusionPipeline.from_pretrained(model_id_sd21, torch_dtype=torch.float16)
sd21 = sd21.to("cuda")
sdxl = DiffusionPipeline.from_pretrained(
    model_id_sdxl, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
)
sdxl = sdxl.to("cuda")
sd3 = StableDiffusion3Pipeline.from_pretrained(model_id_sd3, torch_dtype=torch.float16)
sd3 = sd3.to("cuda")
# The class exported by diffusers is `FluxPipeline` (not `FLUXPipeline`).
flux = FluxPipeline.from_pretrained(model_id_flux, torch_dtype=torch.bfloat16)
# Without a device placement the pipeline would run on the CPU. Offloading
# moves each sub-model to the GPU only while it is in use, which keeps FLUX
# from competing for VRAM with the four pipelines already resident there.
# NOTE(review): swap for `flux.to("cuda")` if the GPU can hold all five models.
flux.enable_model_cpu_offload()


# The prompt file is indexed positionally below, so it is expected to hold a
# JSON array with one prompt per entry.
with open('data/attribute_color_data.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Root folder for the generated images; each model writes into its own
# sub-folder named after its tag.
output_root = "../datasets/rank_compositional_generation/train/iteration0/attribute_binding/color"

# (tag, pipeline, extra call arguments). The tag doubles as the output
# sub-folder name and the file-name prefix.
generation_jobs = (
    ("sd15", sd15, {}),
    ("sd21", sd21, {}),
    ("sdxl", sdxl, {}),
    ("sd3", sd3, {"num_inference_steps": 28, "guidance_scale": 7.0}),
    ("flux", flux, {"guidance_scale": 3.5, "num_inference_steps": 50, "max_sequence_length": 512}),
)

# Create the output folders up front; saving an image does not create
# missing parent directories and would otherwise fail on a fresh checkout.
for tag, _pipeline, _extra in generation_jobs:
    os.makedirs(f"{output_root}/{tag}", exist_ok=True)

for i, prompt in enumerate(data):
    for tag, pipeline, extra in generation_jobs:
        # A fresh CPU generator seeded with 0 for every call, so each
        # model/prompt pair is generated from the same fixed seed.
        generator = torch.Generator("cpu").manual_seed(0)
        image = pipeline(prompt=prompt, generator=generator, **extra).images[0]
        image.save(f"{output_root}/{tag}/{tag}_prompt_{i}.png")