import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import pandas as pd
import numpy as np
from utils.stable_diffusion import (
    load_sd_components,
    load_text_components,
    generate_images,
)
from tqdm import tqdm

# Load the Stable Diffusion v1-4 components, move them to the GPU and freeze
# their weights with requires_grad_(False): this script only passes them to
# generate_images and never trains them.
vae, unet, scheduler = load_sd_components("v1-4")
tokenizer, text_encoder = load_text_components("v1-4")

torch_device = "cuda"
for _model in (vae, text_encoder, unet):
    _model.to(torch_device)
    _model.requires_grad_(False)
del _model  # keep the module namespace identical to the unrolled version


def get_gen_images():
    """Render one image per benign prompt and record the results in a CSV.

    Prompts are read from the ``prompt`` column of
    ``prompts/benign_prompts.csv``. Each prompt is passed to
    ``generate_images`` with 50 inference steps, guidance scale 7 and seed 42.
    The first returned image is saved as
    ``generated_images/image_<i>_<prompt prefix>.png``.

    The index is written to ``prompts/gen_images_with_prompts.csv``
    (``;``-separated) with columns ``image_path``, ``prompt`` and an empty
    ``blocked_indices``.
    """
    prompts = pd.read_csv("prompts/benign_prompts.csv").prompt.values
    output_dir = "generated_images"
    # Saving into a folder that does not exist raises (assumes generate_images
    # returns PIL-style images whose save() opens the path directly), so make
    # sure the output folder exists before the first save.
    os.makedirs(output_dir, exist_ok=True)
    image_paths, image_prompts = [], []

    # enumerate() hides the length from tqdm; pass it so the bar shows progress.
    for i, prompt in tqdm(enumerate(prompts), total=len(prompts)):
        images = generate_images(
            prompts=[prompt],
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            num_inference_steps=50,
            guidance_scale=7,
            seed=42,
        )

        # "/" would be read as a path separator, spaces are replaced for readability.
        safe_prefix = prompt[:20].replace(" ", "_").replace("/", "_")
        image_path = os.path.join(output_dir, f"image_{i}_{safe_prefix}.png")
        images[0].save(image_path)
        image_paths.append(image_path)
        image_prompts.append(prompt)

    pd.DataFrame.from_dict(
        {
            "image_path": image_paths,
            "prompt": image_prompts,
            "blocked_indices": [None] * len(image_paths),
        }
    ).to_csv("prompts/gen_images_with_prompts.csv", index=False, sep=";")


def _parse_image_index(name):
    """Return the integer index encoded in a ``<prefix>_<index>.<ext>`` file name.

    Returns ``None`` for names that do not follow that pattern (for example
    ``.DS_Store``), so stray files in the image folder are skipped instead of
    aborting the whole scan.
    """
    parts = name.split("_")
    if len(parts) < 2:
        return None
    stem, _ext = os.path.splitext(parts[1])
    try:
        return int(stem)
    except ValueError:
        return None


def _sample_memorized_images(mem_type, output_csv, n_samples=100, seed=None):
    """Sample memorized LAION images of one type and write an image/prompt CSV.

    Rows of ``prompts/memorized_laion_prompts.csv`` (``;``-separated) whose
    ``type`` column equals ``mem_type`` are matched to files in
    ``memorized_images/`` through the index encoded in the file name and the
    row's ``Index`` column. Each matched image is paired with that row's
    ``Caption``. Up to ``n_samples`` pairs are chosen by random permutation.
    They are written to ``output_csv`` (``;``-separated) with columns
    ``image_path``, ``prompt`` and an empty ``blocked_indices``.

    Args:
        mem_type: Value of the ``type`` column to keep (e.g. ``"VM"`` or ``"TM"``).
        output_csv: Destination path of the CSV.
        n_samples: Maximum number of image/prompt pairs to keep.
        seed: Seed for a dedicated NumPy generator. ``None`` uses the global
            NumPy RNG, as the original implementation did.
    """
    df = pd.read_csv("prompts/memorized_laion_prompts.csv", sep=";")
    df = df.loc[df["type"] == mem_type]
    # Pair captions by index, not by position. The file listing is sorted
    # lexicographically ("image_10" < "image_2") and need not cover every row,
    # so zipping it positionally with the DataFrame can mismatch prompts.
    captions = {
        int(idx): caption
        for idx, caption in zip(df["Index"].values, df["Caption"].values)
        if pd.notna(idx)
    }

    image_paths, image_prompts = [], []
    for name in sorted(os.listdir("memorized_images")):
        idx = _parse_image_index(name)
        if idx is None or idx not in captions:
            continue
        image_paths.append(f"memorized_images/{name}")
        image_prompts.append(captions[idx])

    rng = np.random if seed is None else np.random.default_rng(seed)
    indices = rng.permutation(len(image_paths))[:n_samples]
    image_paths = [image_paths[i] for i in indices]
    image_prompts = [image_prompts[i] for i in indices]

    pd.DataFrame.from_dict(
        {
            "image_path": image_paths,
            "prompt": image_prompts,
            "blocked_indices": [None] * len(image_paths),
        }
    ).to_csv(output_csv, index=False, sep=";")


def get_vm_images(n_samples=100, seed=None):
    """Write up to ``n_samples`` verbatim-memorized (``VM``) image/prompt pairs
    to ``prompts/vm_images_with_prompts.csv``."""
    _sample_memorized_images(
        "VM", "prompts/vm_images_with_prompts.csv", n_samples=n_samples, seed=seed
    )


def get_tm_images(n_samples=100, seed=None):
    """Write up to ``n_samples`` template-memorized (``TM``) image/prompt pairs
    to ``prompts/tm_images_with_prompts.csv``."""
    _sample_memorized_images(
        "TM", "prompts/tm_images_with_prompts.csv", n_samples=n_samples, seed=seed
    )


def get_generated_tm_images():
    """Re-generate images for the sampled template-memorized prompts.

    Reads the ``prompt`` column of ``prompts/tm_images_with_prompts.csv``
    (written by ``get_tm_images``). Each prompt is rendered with 50 inference
    steps, guidance scale 7 and seed 42, and the first returned image is saved
    as ``generated_tm/image_<i>_<prompt prefix>.png``.

    The index is written to ``prompts/gen_tm_images_with_prompts.csv``
    (``;``-separated) with columns ``image_path``, ``prompt`` and an empty
    ``blocked_indices``.
    """
    tm_images_with_prompts = pd.read_csv("prompts/tm_images_with_prompts.csv", sep=";")

    prompts = tm_images_with_prompts.prompt.values
    output_dir = "generated_tm"
    # Saving into a folder that does not exist raises (assumes generate_images
    # returns PIL-style images whose save() opens the path directly), so make
    # sure the output folder exists before the first save.
    os.makedirs(output_dir, exist_ok=True)
    image_paths, image_prompts = [], []

    # enumerate() hides the length from tqdm; pass it so the bar shows progress.
    for i, prompt in tqdm(enumerate(prompts), total=len(prompts)):
        images = generate_images(
            prompts=[prompt],
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            num_inference_steps=50,
            guidance_scale=7,
            seed=42,
        )

        # "/" would be read as a path separator, spaces are replaced for readability.
        safe_prefix = prompt[:20].replace(" ", "_").replace("/", "_")
        image_path = os.path.join(output_dir, f"image_{i}_{safe_prefix}.png")
        images[0].save(image_path)
        image_paths.append(image_path)
        image_prompts.append(prompt)

    pd.DataFrame.from_dict(
        {
            "image_path": image_paths,
            "prompt": image_prompts,
            "blocked_indices": [None] * len(image_paths),
        }
    ).to_csv("prompts/gen_tm_images_with_prompts.csv", index=False, sep=";")


def main():
    """Build every image/prompt CSV used downstream, one dataset at a time.

    Order matters: get_generated_tm_images reads the CSV that get_tm_images
    writes.
    """
    steps = (
        get_gen_images,
        get_vm_images,
        get_tm_images,
        get_generated_tm_images,
    )
    for step in steps:
        step()


if __name__ == "__main__":
    main()
