import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import argparse

import torch
from torchvision.transforms import ToTensor
from diffusers import StableDiffusionXLPipeline
from customize_scheduler.customize_euler import CustomEuler
from customize_pipeline.custom_sdxl_pipeline import CustomizeStableDiffusionXLPipeline
from reward_models.aesthetic_score.reward_model import AestheticClassifier
from reward_models.hps_v2_score.hps_score import HPSV2Score
from reward_models.reward_interface import UnifiedReward
from PIL import Image, ImageDraw
import time
import gc


def _annotate_image(image, text, footer_height=100):
    """Return a new RGB image: ``image`` on top, a white footer below holding ``text``.

    Args:
        image: PIL image to annotate (not modified).
        text: Multi-line caption drawn in black at (10, image.height + 10).
        footer_height: Extra pixels of white space added below the image.

    Returns:
        A new ``PIL.Image.Image`` of size (image.width, image.height + footer_height).
    """
    annotated = Image.new(
        "RGB", (image.width, image.height + footer_height), (255, 255, 255)
    )
    annotated.paste(image, (0, 0))
    ImageDraw.Draw(annotated).text((10, image.height + 10), text, fill=(0, 0, 0))
    return annotated


def main():
    """Generate one SDXL-Turbo image per beta, score it, and save annotated results.

    Per beta value this:
      * builds a ``CustomizeStableDiffusionXLPipeline`` with the ``CustomEuler``
        scheduler and runs a single guidance-free inference step,
      * saves the raw image to ``inference_test_ours.jpg`` (overwritten each run),
      * scores it with ``UnifiedReward(args.reward)``,
      * saves a captioned copy to ``experiment/<reward>/generated_image_<ts>_beta<beta>.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="Saturn rises on the horizon.")
    parser.add_argument(
        "--reward",
        type=str,
        default="pickscore",
        help="Reward name passed to UnifiedReward (e.g. pickscore, imagereward, "
        "hps_v2, aesthetic); also names the output sub-directory.",
    )
    args = parser.parse_args()

    # Create the output directory up front; image.save() does not create parents.
    output_dir = os.path.join("experiment", args.reward)
    os.makedirs(output_dir, exist_ok=True)

    # Previously swept: [0, 0.1, 0.3, 0.5, 0.7, 1.0, 1.2]
    betas = [1]
    for beta in betas:
        experiment_config = {
            "beta": beta,
            "prompt": args.prompt,
            "use_r_guidance": False,
        }

        pipe = CustomizeStableDiffusionXLPipeline.from_pretrained(
            "stabilityai/sdxl-turbo",
            use_safetensors=True,
        )
        unified_reward = UnifiedReward(args.reward)

        # Swap in the custom scheduler, keeping the pretrained scheduler config.
        pipe.scheduler = CustomEuler.from_config(pipe.scheduler.config)
        pipe = pipe.to("cuda")

        generator = torch.Generator(device="cpu").manual_seed(0)

        image = pipe(
            prompt=args.prompt,
            num_inference_steps=1,
            guidance_scale=0.0,  # SDXL-Turbo is meant to be run without CFG
            generator=generator,
            experiment_config=experiment_config,
        ).images[0]

        image.save("inference_test_ours.jpg")

        image_ = ToTensor()(image)
        score = unified_reward.score(image_, args.prompt)
        if isinstance(score, torch.Tensor):
            score = score.item()

        caption = (
            f"Prompt: {args.prompt}\n"
            f"Reward ({args.reward}): {score:.2f}\n"
            f"Experiment Config:\n"
            f"  Beta: {experiment_config['beta']}\n"
            f"  Use R Guidance: {experiment_config['use_r_guidance']}"
        )
        annotated = _annotate_image(image, caption, footer_height=100)

        # Include beta so runs finishing within the same second don't overwrite.
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        annotated.save(
            os.path.join(output_dir, f"generated_image_{timestamp}_beta{beta}.png")
        )

        # Release GPU memory before the next (re)load of the pipeline.
        del pipe, unified_reward, image, image_, annotated, score
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
