import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import argparse

import torch
from torchvision.transforms import ToTensor
from diffusers import StableDiffusionXLPipeline
from customize_scheduler.customize_euler import CustomEuler
from customize_pipeline.custom_sdxl_pipeline import CustomizeStableDiffusionXLPipeline
from reward_models.aesthetic_score.reward_model import AestheticClassifier
from reward_models.hps_v2_score.hps_score import HPSV2Score
from reward_models.reward_interface import UnifiedReward
from PIL import Image, ImageDraw
import time
import gc

# Values passed to the pipeline as experiment_config["beta"]; one sweep step each.
BETAS = (0, 0.01, 0.05, 0.1, 0.5, 1.0)

# Prompts rendered for every beta. Alternative prompts kept for quick swapping:
#   "Saturn rises on the horizon."
#   "a watercolor painting of a super cute kitten wearing a hat of flowers"
#   "A galaxy-colored figurine floating over the sea at sunset, photorealistic."
#   "A profile picture of an anime boy, half robot, brown hair"
#   "a photo of a frog holding an apple while smiling in the forest"
PROMPTS = (
    "A 3D Rendering of a cockatoo wearing sunglasses. The sunglasses have a deep black frame with bright pink lenses. Fashion photography, volumetric lighting, CG rendering",
)

# Name handed to UnifiedReward. The script previously also listed
# "imagereward", "hps_v2" and "hacked_grey_reward" as alternatives.
REWARD_NAME = "pickscore"
MODEL_ID = "stabilityai/sdxl-turbo"
OUTPUT_DIR = "generated_image"
SCORE_FILE = "image_scores.txt"


def image_filename(beta, prompt_index=0, num_prompts=1):
    """Return the file name used for one generated image.

    A single-prompt run keeps the historical ``inference_test_<beta>.jpg``
    name. With several prompts the prompt index is appended so that images
    generated for the same beta do not overwrite each other.
    """
    if num_prompts > 1:
        return "inference_test_%f_prompt%d.jpg" % (beta, prompt_index)
    return "inference_test_%f.jpg" % beta


def format_score_line(index, score, prompt):
    """Return the line written to the score file for image number *index*."""
    return f"Image {index}: {score}, Prompt: {prompt}\n"


def build_pipeline(reward_name=REWARD_NAME, model_id=MODEL_ID, device="cuda"):
    """Load the customized SDXL pipeline and attach reward model and scheduler.

    Returns:
        A ``(pipe, unified_reward)`` tuple: the pipeline moved to *device* and
        the ``UnifiedReward`` instance that was registered on it.
    """
    pipe = CustomizeStableDiffusionXLPipeline.from_pretrained(
        model_id,
        use_safetensors=True,
        torch_dtype=torch.float16,
        variant="fp16",
    )

    unified_reward = UnifiedReward(reward_name)
    pipe.set_unified_reward_interface(unified_reward)

    # Replace the stock scheduler with the custom one, reusing its config.
    pipe.scheduler = CustomEuler.from_config(pipe.scheduler.config)

    return pipe.to(device), unified_reward


def generate_images(pipe, betas=BETAS, prompts=PROMPTS, output_dir=OUTPUT_DIR):
    """Generate and save one image per (beta, prompt) pair.

    Returns:
        A list of ``(prompt, image)`` tuples in generation order, so each
        image can later be scored against the prompt that produced it.
    """
    # image.save() does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    records = []
    for beta in betas:
        # Re-seed for every beta so each sweep step starts from the same
        # generator state and differs only in experiment_config.
        generator = torch.Generator(device="cpu").manual_seed(0)

        for prompt_index, prompt in enumerate(prompts):
            experiment_config = {
                "beta": beta,
                "prompt": prompt,
                "use_r_guidance": True,
            }
            image = pipe(
                prompt=prompt,
                num_inference_steps=1,
                guidance_scale=0.0,  # this is important for turbo model
                generator=generator,
                experiment_config=experiment_config,
            ).images[0]

            name = image_filename(beta, prompt_index, len(prompts))
            image.save(os.path.join(output_dir, name))
            records.append((prompt, image))

            # Release cached GPU memory between generations.
            gc.collect()
            torch.cuda.empty_cache()

    return records


def score_images(unified_reward, records, score_path=SCORE_FILE):
    """Score every generated image and write all results to *score_path*.

    The file is opened once, outside the loop: opening it with mode "w" per
    image would truncate it each time and keep only the last score.
    """
    to_tensor = ToTensor()
    with open(score_path, "w", encoding="utf-8") as score_file:
        for index, (prompt, image) in enumerate(records):
            score = unified_reward.score(to_tensor(image), prompt)
            if isinstance(score, torch.Tensor):
                score = score.item()
            score_file.write(format_score_line(index, score, prompt))
            print(f"Generated image {index} with score: {score}")


def main():
    """Run the beta sweep: build the pipeline, generate images, score them."""
    # The pipeline and reward model do not depend on beta, so they are built
    # once and reused for the whole sweep instead of being reloaded per beta.
    # NOTE(review): assumes the custom pipeline/scheduler keep no state
    # between calls -- confirm against their implementations.
    pipe, unified_reward = build_pipeline()
    records = generate_images(pipe)
    score_images(unified_reward, records)


if __name__ == "__main__":
    main()
