# -*- coding: utf-8 -*-
"""no kwargs pickscore enabled lambda mixing.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1gZCrjLtNHU9rDe8LNiAoGiaeq8debPRg
"""

from diffusers import StableDiffusionPipeline, UNet2DConditionModel, StableDiffusionXLPipeline
import torch
torch.set_grad_enabled(False)

import os
import argparse
import json, sys
import prompts
torch.set_grad_enabled(False)


# Command-line interface for the script.
#   --Lambda      is forwarded to CustomCombinedPipeline as the mixing weight.
#   --Beta        is parsed here; its consumer is not visible in this part of the
#                 file (NOTE(review): confirm where it is used).
#   --model_name  is passed to UNet2DConditionModel.from_pretrained, so it must be
#                 supplied; argparse now reports a clear error when it is missing
#                 instead of letting `None` reach the loader.
parser = argparse.ArgumentParser(
    description="Run SDXL lambda sampling with a custom λ value."
)
parser.add_argument(
    "--Lambda",
    type=float,
    default=1.0,
    help="Lambda value for the sampling (e.g. 0.125)."
)

parser.add_argument(
    "--Beta",
    type=float,
    default=1.0,
    help="Beta value for the sampling (e.g. 0.125)."
)


parser.add_argument(
    "--model_name",
    type=str,
    required=True,
    help='Name of the model, e.g. "beta2000_64acc_600epoch".'
)


args = parser.parse_args()



# Fine-tuned UNet selected on the command line. `args.model_name` may be a hub id
# (e.g. 'mhdang/dpo-sd1.5-text2image-v1' or 'mhdang/dpo-sdxl-text2image-v1') or a
# local checkpoint directory (*/checkpoint-n/); weights are read from its `unet`
# subfolder in half precision.
dpo_unet = UNet2DConditionModel.from_pretrained(
    args.model_name,
    subfolder='unet',
    torch_dtype=torch.float16,
)
dpo_unet = dpo_unet.to('cuda')

# Base checkpoint. Alternatives that were tried:
#   "CompVis/stable-diffusion-v1-4"
#   "stabilityai/stable-diffusion-xl-base-1.0"
pretrained_model_name = "runwayml/stable-diffusion-v1-5"
_is_sdxl = 'stable-diffusion-xl' in pretrained_model_name

# Guidance scale used at generation time: 5 for SDXL checkpoints, 7.5 otherwise.
gs = 5 if _is_sdxl else 7.5

if _is_sdxl:
    pipe = StableDiffusionXLPipeline.from_pretrained(
        pretrained_model_name,
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
    pipe = pipe.to("cuda")
else:
    pipe = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name,
        torch_dtype=torch.float16,
    )
pipe = pipe.to('cuda')
pipe.safety_checker = None # Trigger-happy, blacks out >50% of "robot tiger"

# from diffusers import DDPMScheduler

# # Use the configuration of the default scheduler to initialize the DDPMScheduler
# pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config, steps_offset=0)

# # Can do clip_utils, aes_utils, hps_utils
# from utils.pickscore_utils import Selector
# # Score generations automatically w/ reward model
# ps_selector = Selector('cuda')

from typing import List, Optional, Tuple, Union

def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional[Union[str, "torch.device"]] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """Create a standard-normal tensor on `device` with the requested `dtype`.

    Args:
        shape: Target shape (tuple or list); `shape[0]` is the batch size.
        generator: A single generator, or a list with one generator per batch
            element so every element can be seeded individually.
        device: Device the returned tensor lives on (string or `torch.device`);
            defaults to the CPU.
        dtype: dtype handed to `torch.randn`.
        layout: Tensor layout; defaults to `torch.strided`.

    Returns:
        The sampled tensor, moved to `device`.

    Raises:
        ValueError: If a CUDA generator is supplied for a tensor that is wanted
            on a different device type.

    When a CPU generator is passed for a non-CPU device, sampling happens on the
    CPU and the result is moved afterwards.
    """
    if isinstance(device, str):
        device = torch.device(device)
    # Device on which the numbers are actually drawn; defaults to the target device.
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare device *types*: `device` is a torch.device here, not a string.
            if device.type != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # One draw per batch element, each with its own generator. `tuple(...)`
        # keeps this working when `shape` was given as a list.
        shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents

from diffusers import EulerAncestralDiscreteScheduler
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
# from diffusers.utils import randn_tensor # For noise generation
import PIL.Image

class CustomEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
    """Euler-ancestral scheduler whose `step` blends the predictions of two models.

    Each call takes one model output from a reference model and one from a second
    ("beta") model, performs the deterministic Euler update for both, and then
    samples the next latent from the combination of the two resulting Gaussians,
    weighted by `Lambda`.
    """

    def _predict_original_sample(self, model_output, sample, sigma):
        """Return the predicted x_0 for one model output under the configured prediction type."""
        if self.config.prediction_type == "epsilon":
            return sample - sigma * model_output
        if self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            return model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        if self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    def step(
        self,
        Lambda : float,
        model_output_ref: torch.Tensor,
        model_output_beta: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """Advance the sample by one step using two model outputs.

        Args:
            Lambda: Mixing weight; the reference branch is weighted by
                `1 - Lambda` and the beta branch by `Lambda`.
            model_output_ref: Output of the reference model for this step.
            model_output_beta: Output of the second model for this step.
            timestep: Current entry of `scheduler.timesteps` (not an integer index).
            sample: Current latent; it is upcast to float32 for the update.
            generator: Optional generator used to draw the ancestral noise.
            return_dict: When False a `(prev_sample, pred_original_sample_ref)`
                tuple is returned instead of the output dataclass.

        Returns:
            The next sample together with the reference branch's predicted x_0.

        Raises:
            ValueError: For integer timesteps or an unknown prediction type.
            NotImplementedError: For `prediction_type == "sample"`.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise,
        #    once per model
        pred_original_sample_ref = self._predict_original_sample(model_output_ref, sample, sigma)
        pred_original_sample_beta = self._predict_original_sample(model_output_beta, sample, sigma)

        # Split the move to the next sigma into a deterministic part (sigma_down)
        # and the standard deviation of the re-injected noise (sigma_up).
        sigma_to = self.sigmas[self.step_index + 1]
        sigma_up = (sigma_to**2 * (sigma**2 - sigma_to**2) / sigma**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. Convert to an ODE derivative and take the Euler step for each model
        dt = sigma_down - sigma

        derivative_ref = (sample - pred_original_sample_ref) / sigma
        prev_sample_ref = sample + derivative_ref * dt

        derivative_beta = (sample - pred_original_sample_beta) / sigma
        prev_sample_beta = sample + derivative_beta * dt

        # Crucial segment: combine the two Gaussians N(prev_sample_ref, v1) and
        # N(prev_sample_beta, v2) with weights (1 - Lambda) and Lambda.
        if timestep > 1:
            v1 = sigma_up ** 2
            v2 = sigma_up ** 2 # v1, v2 need not be the same, but in this case it is.
            v_new = 1/((1-Lambda)/v1 + Lambda/v2)
            std_new = v_new**0.5

            mean = v_new * ((1-Lambda) * prev_sample_ref/v1 + Lambda * prev_sample_beta/v2)

            device = model_output_ref.device
            noise = randn_tensor(model_output_ref.shape, dtype=model_output_ref.dtype, device=device, generator=generator)

            prev_sample = mean + noise * std_new

            # Cast sample back to model compatible dtype
            prev_sample = prev_sample.to(model_output_ref.dtype)
            # Remember the result: steps with timestep <= 1 reuse it unchanged.
            self.save = prev_sample
        else:
            # For timestep <= 1 the last stochastic result is returned as-is
            # (dividing by sigma_up**2 is not safe once the next sigma reaches zero).
            prev_sample = getattr(self, "save", None)
            if prev_sample is None:
                # No stochastic step has run on this scheduler yet, so there is
                # nothing to reuse; fall back to the noise-free weighted blend.
                prev_sample = ((1 - Lambda) * prev_sample_ref + Lambda * prev_sample_beta).to(
                    model_output_ref.dtype
                )

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample_ref,
            )

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample_ref
        )

from diffusers import DDPMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
# from diffusers.utils import randn_tensor # For noise generation
import PIL.Image

class CustomCombinedDDPMScheduler(DDPMScheduler):
    """DDPM scheduler whose `step` merges the updates implied by two noise predictions.

    Both model outputs go through the usual DDPM computation (predicted x_0,
    optional clipping, posterior mean). The two posterior means are then combined
    with weights `1 - lmbda` and `lmbda`, each scaled by the inverse of its noise
    scale, and noise with the combined scale is added.
    """

    def _split_learned_variance(self, model_output, sample):
        """Split off a learned-variance half of `model_output` when the model emits one."""
        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
            return model_output, predicted_variance
        return model_output, None

    def _predict_original_sample(self, model_output, sample, alpha_prod_t, beta_prod_t):
        """Return the "predicted x_0" of formula (15) from https://huggingface.co/papers/2006.11239."""
        if self.config.prediction_type == "epsilon":
            return (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        if self.config.prediction_type == "sample":
            return model_output
        if self.config.prediction_type == "v_prediction":
            return (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction`  for the DDPMScheduler."
        )

    def _clip_original_sample(self, pred_original_sample):
        """Threshold or clamp the predicted x_0 according to the scheduler config."""
        if self.config.thresholding:
            return self._threshold_sample(pred_original_sample)
        if self.config.clip_sample:
            return pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )
        return pred_original_sample

    def _noise_scale(self, t, predicted_variance):
        """Return the factor that multiplies the noise for the configured variance type."""
        variance = self._get_variance(t, predicted_variance=predicted_variance)
        if self.variance_type == "fixed_small_log":
            return variance
        if self.variance_type == "learned_range":
            return torch.exp(0.5 * variance)
        return variance ** 0.5

    def step(
        self,
        lmbda : float,
        model_output1: torch.FloatTensor, # Noise prediction from UNet 1
        model_output2: torch.FloatTensor, # Noise prediction from UNet 2
        timestep: int,
        sample: torch.Tensor,
        generator=None,
        return_dict: bool = True,
    ) -> Union[DDPMSchedulerOutput, Tuple]:
        """Predict the previous sample from two model outputs.

        Args:
            lmbda: Mixing weight; branch 1 is weighted by `1 - lmbda`, branch 2 by `lmbda`.
            model_output1: Output of the first UNet.
            model_output2: Output of the second UNet.
            timestep: Current discrete timestep.
            sample: Current noisy sample.
            generator: Optional generator for the added noise.
            return_dict: When False a tuple is returned instead of the dataclass.

        Returns:
            `DDPMSchedulerOutput` carrying the combined previous sample and the
            predicted x_0 of branch 1, or (when `return_dict` is False) a tuple of
            the combined previous sample and the predicted x_0 of branch 2.

        Raises:
            ValueError: If the configured prediction type is not supported.
        """
        t = timestep

        prev_t = self.previous_timestep(t)

        model_output1, predicted_variance1 = self._split_learned_variance(model_output1, sample)
        model_output2, predicted_variance2 = self._split_learned_variance(model_output2, sample)

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise, per branch
        pred_original_sample1 = self._predict_original_sample(model_output1, sample, alpha_prod_t, beta_prod_t)
        pred_original_sample2 = self._predict_original_sample(model_output2, sample, alpha_prod_t, beta_prod_t)

        # 3. Clip or threshold "predicted x_0"
        pred_original_sample1 = self._clip_original_sample(pred_original_sample1)
        pred_original_sample2 = self._clip_original_sample(pred_original_sample2)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://huggingface.co/papers/2006.11239
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t for each branch
        # See formula (7) from https://huggingface.co/papers/2006.11239
        pred_prev_sample1 = pred_original_sample_coeff * pred_original_sample1 + current_sample_coeff * sample
        pred_prev_sample2 = pred_original_sample_coeff * pred_original_sample2 + current_sample_coeff * sample

        # 6. Add noise. The noise is drawn on every call (also when it ends up
        # unused) so the generator advances the same way on every step.
        device = model_output1.device
        variance_noise = randn_tensor(
            model_output1.shape, generator=generator, device=device, dtype=model_output1.dtype
        )
        v1 = self._noise_scale(t, predicted_variance1)
        v2 = self._noise_scale(t, predicted_variance2)
        v = 1/((1-lmbda)/v1 + lmbda/v2)
        variance = v * variance_noise

        if t > 1:
            pred_prev_sample = v * ((1-lmbda)*pred_prev_sample1/v1 + lmbda * pred_prev_sample2/v2)
            pred_prev_sample = pred_prev_sample + variance
            # Remember the result: steps with t <= 1 reuse it unchanged.
            self.save = pred_prev_sample
        else:
            pred_prev_sample = getattr(self, "save", None)
            if pred_prev_sample is None:
                # Nothing has been stored yet on this scheduler; fall back to the
                # noise-free combination of the two branches.
                pred_prev_sample = v * ((1-lmbda)*pred_prev_sample1/v1 + lmbda * pred_prev_sample2/v2)

        if not return_dict:
            return (
                pred_prev_sample,
                pred_original_sample2, #doesn't matter
            )

        return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample1)

from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import deprecate, replace_example_docstring
from typing import Union, List, Optional, Callable, Dict, Any

from diffusers.image_processor import PipelineImageInput

from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)

# torch_xla is optional: it is imported only when diffusers reports it as
# available, and XLA_AVAILABLE records the outcome so the pipeline's denoising
# loop knows whether it may call xm.mark_step().
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger from diffusers' logging helper. randn_tensor and the custom
# Euler scheduler's step() (defined earlier in this file) look this name up at
# call time.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not accept
        the custom schedule that was requested.
    """
    # Imported locally: the custom-schedule branches need `inspect`, which is not
    # among this file's top-level imports.
    import inspect

    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        # Default path: let the scheduler derive its own schedule from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

class CustomCombinedPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline that queries two UNets per step and lets the scheduler merge them.

    At every denoising step both `unet` (base) and `dpo_unet` are evaluated on the
    same latent input. Classifier-free guidance is applied to each prediction
    separately, and both guided predictions are handed to `scheduler.step` together
    with the mixing weight `lmbda`.

    The scheduler must therefore expose a `step` that accepts the keywords
    `Lambda`, `model_output_ref`, `model_output_beta`, `timestep` and `sample`
    (as `CustomEulerAncestralDiscreteScheduler` in this file does).
    """

    def __init__(self,
                 vae, text_encoder, tokenizer, unet,  # base unet
                 dpo_unet,  # dpo unet
                 scheduler,  # scheduler whose step() takes both noise predictions
                 safety_checker, feature_extractor,
                 image_encoder,
                 requires_safety_checker: bool = True,
                 lmbda: float = 0.5):
        """Build the pipeline from the usual Stable Diffusion components plus a second UNet.

        Args:
            vae, text_encoder, tokenizer, unet, scheduler, safety_checker,
            feature_extractor, image_encoder, requires_safety_checker:
                Forwarded unchanged to `StableDiffusionPipeline.__init__`.
            dpo_unet: Second UNet, evaluated next to `unet` at every step.
            lmbda: Mixing weight passed to the scheduler as `Lambda`. In
                `CustomEulerAncestralDiscreteScheduler.step` the `dpo_unet`
                branch is weighted by `Lambda` and the base branch by `1 - Lambda`.
        """
        super().__init__(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler,
                         safety_checker=safety_checker, feature_extractor=feature_extractor,
                         image_encoder=image_encoder, requires_safety_checker=requires_safety_checker)

        # Register dpo_unet as a pipeline module rather than a plain attribute so
        # it is also moved to device, etc.
        self.register_modules(dpo_unet=dpo_unet)
        self.lmbda = lmbda

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        """Generate images, evaluating both UNets at every denoising step.

        The arguments mirror the ones this method forwards to the inherited
        `StableDiffusionPipeline` helpers (`check_inputs`, `encode_prompt`,
        `prepare_latents`, ...). The deprecated `callback` / `callback_steps`
        may still be supplied through `**kwargs`.

        Returns:
            `StableDiffusionPipelineOutput` with the images and the safety-checker
            flags, or the tuple `(image, has_nsfw_concept)` when `return_dict`
            is False.
        """
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        # 0. Default height and width to unet
        if not height or not width:
            height = (
                self.unet.config.sample_size
                if self._is_unet_config_sample_size_int
                else self.unet.config.sample_size[0]
            )
            width = (
                self.unet.config.sample_size
                if self._is_unet_config_sample_size_int
                else self.unet.config.sample_size[1]
            )
            height, width = height * self.vae_scale_factor, width * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # `rescale_noise_cfg` is only needed when guidance rescaling is active; it
        # is imported here because the file's top-level imports do not provide it.
        if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
            from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler=self.scheduler, num_inference_steps=num_inference_steps, device=device, timesteps=timesteps, sigmas=sigmas
        )

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.1 Add image embeds for IP-Adapter
        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
            else None
        )

        # 6.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Both UNets see exactly the same inputs.
                base_noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                dpo_noise_pred = self.dpo_unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance, independently for each model
                if self.do_classifier_free_guidance:
                    base_noise_pred_uncond, base_noise_pred_text = base_noise_pred.chunk(2)
                    base_noise_pred = base_noise_pred_uncond + self.guidance_scale * (base_noise_pred_text - base_noise_pred_uncond)

                    dpo_noise_pred_uncond, dpo_noise_pred_text = dpo_noise_pred.chunk(2)
                    dpo_noise_pred = dpo_noise_pred_uncond + self.guidance_scale * (dpo_noise_pred_text - dpo_noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://huggingface.co/papers/2305.08891
                    base_noise_pred = rescale_noise_cfg(base_noise_pred, base_noise_pred_text, guidance_rescale=self.guidance_rescale)
                    dpo_noise_pred = rescale_noise_cfg(dpo_noise_pred, dpo_noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1 from both predictions
                latents = self.scheduler.step(Lambda = self.lmbda, model_output_ref = base_noise_pred, model_output_beta = dpo_noise_pred, timestep = t, sample = latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

# Scheduler for the combined pipeline, configured like the base pipeline's
# scheduler but with steps_offset=0. The DDPM variant is the alternative:
# custom_scheduler = CustomCombinedDDPMScheduler.from_config(pipe.scheduler.config, steps_offset=0)
custom_scheduler = CustomEulerAncestralDiscreteScheduler.from_config(
    pipe.scheduler.config,
    steps_offset=0,
)

# Use the base pipeline's image encoder when it has one, otherwise None.
passed_image_encoder = getattr(pipe, 'image_encoder', None)

# Reuse every component of the base pipeline and add the second UNet.
_combined_components = dict(
    vae=pipe.vae,
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
    unet=pipe.unet,          # Base UNet
    dpo_unet=dpo_unet,       # DPO UNet
    scheduler=custom_scheduler,
    safety_checker=pipe.safety_checker,
    feature_extractor=pipe.feature_extractor,
    image_encoder=passed_image_encoder,
    requires_safety_checker=(pipe.safety_checker is not None),
    lmbda=args.Lambda,       # mixing weight taken from --Lambda
)
combined_pipe = CustomCombinedPipeline(**_combined_components)
combined_pipe = combined_pipe.to("cuda")

unets = [pipe.unet, dpo_unet]
names = ["Orig. SDXL", "DPO SDXL"]

def gen(image_id, prompt, seed=0, run_baseline=True, output_dir="sd15_outputs", save=False):
    """Generate one image for ``prompt`` with the combined pipeline.

    Args:
        image_id: Identifier used in the output filename.
        prompt: Text prompt passed to ``combined_pipe``.
        seed: Seed for the CUDA random generator that drives sampling.
        run_baseline: Unused; kept so existing callers keep working.
        output_dir: Directory the PNG is written to when ``save`` is true.
        save: When true, create ``output_dir`` if needed and write the image
            to ``<output_dir>/<image_id>_seed<seed>.png``. Defaults to False,
            which matches the previous behaviour of not writing any file.

    Returns:
        A one-element list holding the generated image.
    """
    print(prompt)

    # A fresh generator per call, seeded explicitly, so sampling depends only
    # on ``seed`` and not on earlier calls.
    generator = torch.Generator(device='cuda').manual_seed(seed)

    im = combined_pipe(
        prompt=prompt,
        generator=generator,
        guidance_scale=gs,
        num_inference_steps=50,
    ).images[0]

    if save:
        # Only report a saved file when one was really written.
        os.makedirs(output_dir, exist_ok=True)
        filename = os.path.join(output_dir, f"{image_id}_seed{seed}.png")
        im.save(filename)
        print(f"Saved to {filename}")

    return [im]

# import
from transformers import AutoProcessor, AutoModel
from PIL import Image
import torch


# Reward-model selector. Per the original note, clip_utils, aes_utils and
# hps_utils are the interchangeable options; this script uses aes_utils.
from utils.aes_utils import Selector
# Score generations automatically with the reward model.
# NOTE(review): the `ps_` prefix looks like a leftover from a PickScore variant
# of this notebook; the Selector instantiated here comes from aes_utils.
ps_selector = Selector('cuda')

# Sixteen sample prompts. In the visible part of the script they are only
# referenced by the disabled legacy loop near the end; the active scoring loop
# iterates prompts.test_prompts instead.
example_prompts = [
    "A pile of sand swirling in the wind forming the shape of a dancer",
    "A giant dinosaur frozen into a glacier and recently discovered by scientists, cinematic still",
    "a smiling beautiful sorceress with long dark hair and closed eyes wearing a dark top surrounded by glowing fire sparks at night, magical light fog, deep focus+closeup, hyper-realistic, volumetric lighting, dramatic lighting, beautiful composition, intricate details, instagram, trending, photograph, film grain and noise, 8K, cinematic, post-production",
    "A purple raven flying over big sur, light fog, deep focus+closeup, hyper-realistic, volumetric lighting, dramatic lighting, beautiful composition, intricate details, instagram, trending, photograph, film grain and noise, 8K, cinematic, post-production",
    "a smiling beautiful sorceress wearing a modest high necked blue suit surrounded by swirling rainbow aurora, hyper-realistic, cinematic, post-production",
    "Anthro humanoid turtle skydiving wearing goggles, gopro footage",
    "A man in a suit surfing in a river",
    "photo of a zebra dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography",
    "A typhoon in a tea cup, digital render",
    "A cute puppy leading a session of the United Nations, newspaper photography",
    "Worm eye view of rocketship",
    "Glass spheres in the desert, refraction render",
    "anthropmorphic coffee bean drinking coffee",
    "A baby kangaroo in a trenchcoat",
    "A towering hurricane of rainbow colors towering over a city, cinematic digital art",
    "A redwood tree rising up out of the ocean"
]

# Fifty prompts that form the first half of `all_prompts`.
# NOTE(review): presumably sampled from PartiPrompts (see the dataset note at
# the end of the file) — confirm. The spelling "partipromt" is kept because
# `all_prompts` refers to this exact name.
partipromt_50 = [
    'a family of four posing on the moon',
    'a family of bears passing by the geyser Old Faithful',
    'a high-quality photograph of an armadillo playing a bagpipe while standing on one leg',
    'The Statue of Liberty with the Manhattan skyline in the background.',
    'a milk container in a refrigerator',
    'a Christmas tree',
    'a dolphin in an astronaut suit on saturn',
    "the saying 'do unto others as they would do unto you' written on a white background",
    'a thumbnail image of a person skiing',
    'an old-fashioned cocktail next to a napkin',
    'a small kitchen with a white goat in it',
    'view of a clock tower from above',
    'A helicopter flies over Yosemite.',
    'A punk rock platypus in a studded leather jacket shouting into a microphone while standing on a boulder',
    'background pattern with alternating roses and skulls',
    'molecule',
    'Downtown Seattle at sunrise. detailed ink wash.',
    'a Christmas tree on a toy train',
    'a girl',
    'a red train is coming down the beach',
    'A shiny VW van that has flowers painted on it. A smiling sloth stands on grass in front of the van and is wearing a leather jacket, a cowboy hat, a kilt and a bowtie. The sloth is holding a quarterstaff and a big book. ink sketch.',
    'the mona lisa',
    'The mouse the cat watches is jumping in the air.',
    'a white towel with a cartoon of a cat on it',
    'A smiling sloth wearing a leather jacket, a cowboy hat, a kilt and a bowtie. The sloth is holding a quarterstaff and a big book. A shiny VW van with a cityscape painted on it and parked on grass.',
    'a pen-and-ink crosshatched drawing of a sphere with dark square on it',
    'a hamster dragon',
    'A Vietnam map',
    'a coffee mug floating in the sky',
    'a flag',
    'an owl standing on a wire',
    'a car with tires that have yellow rims',
    'teacup',
    'a three quarters view of a man getting into a car',
    'A giant cobra snake made from pancakes',
    'a black dog jumping up to hug a woman wearing a red sweater',
    'a half moon in the day sky',
    'a tree growing out of the middle of an intersection',
    'three chairs',
    'a turkey',
    'a snail made of harp',
    'a sword in a stone',
    'a doorknocker',
    'an emoji of a baby penguin wearing a blue hat, red gloves, green shirt, and yellow pants',
    'A castle made of cardboard.',
    'ten wine bottles',
    'a girl riding an ostrich',
    'Mars rises on the horizon.',
    'a yellow diamond-shaped sign with a turtle silhouette',
    'A single beam of light enter the room from the ceiling. The beam of light is illuminating an easel. On the easel there is a Rembrandt painting of a raccoon'
]



# hpsv2 = ['A film still of Luke Skywalker as a Sith Lord.', 'A minimalistic heart drawing created using Adobe Illustrator.', 'Portrait of a male furry anthro mountain goat in a pinstripe suit and waistcoat, smoking a cigar.', 'Renaissance noblewoman with blue eyes and pale skin in a classical portrait pose in the art style of Ib Iwerks.', 'A close-up image of a woman wearing a samurai mask, fire dancing in a dirty cyberpunk alley with smoke and mist.', 'A lemon character with sunglasses on the beach.', 'Abstract yin yang representation by Ivan Bilibin.', 'An image of Akira, from the artist Simon Stalenhag.', 'Galactus devouring planet earth, depicted in an artwork by Francisco Goya.', 'A hyperrealistic mixed media image of a proportionally sized human hand undergoing particle teleportation, with perfect symmetry and dim volumetric lighting.', 'A dolphin swimming in front of a Studio Ghibli logo backdrop.', 'A man drinking cosmic energy in an anime-style digital art by Park Sung-woo.', 'Two messy toilet stalls with toilets where one lid is raised. \n', 'A lemon character wearing sunglasses on the beach.', 'Undertale character Spromple Sploop, third brother of Sans.', 'A plane is on display near the water.', 'A young woman smiling in the etheric hypothalamus of her mind.', 'A portrait painting of a Red Borzoi Dog wearing a red beret as an Overwatch character.', 'A giant guardian wearing road sign armor, a popular character design on Artstation.', 'A bicycle leaned against the hallway wall in a house', 'A white toilet sitting under a window next to a chair.', 'A manga-style illustration of Harry Potter as a Gundam mech.', 'A Nintendo 64 controller with anthropomorphic features consuming small children.', 'A portrait painting of Yondu Udonta in an asymmetrical profile shot, incorporating bold shapes and hard edges with a stylized street art aesthetic.', 'A surreal portrait of a young Spanish man wearing sock and titled "Super Spy Captain" with deep purple hair and green eyes on an orange background.', 'A portrait of a stylized business cat in sharp focus with a medium shot perspective, resembling boxart.', 'A man sitting on a black and yellow bench on the phone.', 'A pen illustration of a man wrestling his phone by Gustave Doré with crosshatching and pops of colorful Ben Day dots.', 'A masterpiece.', 'Albus Dumbledore dressed up as Wonder Woman.']

# Fifty prompts that form the second half of `all_prompts`.
# NOTE(review): presumably sampled from the HPSv2 benchmark — confirm.
# Two prompts contain a line break inside the text; it is written as an explicit
# "\n" escape so every string literal stays on one physical line (a raw line
# break inside a single-quoted literal is a syntax error).
hpsv2_50 = [
    'A teenage mutant ninja turtle, Leonardo, enjoys a cup of tea at a wooden desk in a sci-fi space station orbiting a large planet visible through a window.',
    'The top of a steeped church building with clocks and small windows.  ',
    'A bald general with an angry expression in an intricately detailed and elegant digital painting.',
    'A furry cat girl.',
    'Celine Dion appears angry at a kitten in a hot tub.',
    'A vampire sits at a banquet table in a dungeon setting surrounded by plates of rats and spiders and red candles.',
    'A bathroom with a small sink and toilet. ',
    'Three small dinosaurs entering a grocery store painted by Thomas Kinkade.',
    'Large sized kitchen with a dining room section.',
    '"Front centered symmetrical portrait of Elisha Cuthbert as a D&D paladin with cinematic lighting."',
    'There are orange slices in canning jars without lids.',
    'A portrait of an orc in a fantasy art style.',
    'Exterior image of a small magic items and curios shop in a busy fantasy city.',
    'A portrait of a character in a scenic environment.',
    'A plane riding down a runway of an airport.',
    'A stylized digital art image of a cherry tree overlooking a valley with a waterfall during sunset.',
    'A girl sneaking behind a giant wooden door with archaic symbols embedded onto it, in a cave with the waterfall, illustrated in comics style.',
    'Flag design for communist European Union featuring a hammer and sickle.',
    'A digital painting of Teemo from League of Legends, wearing cyborg parts and a new skin, in a fantasy MMORPG style.',
    'an empty bench sitting on the side of a sidewalk',
    'Image of xqc with a distinctive underbite and big, long nose.',
    'An abstract collage featuring grey and lilac colors with a touch of sparkle.',
    'Patrick Bateman beating an anthropomorphic wolf cosplay.',
    'An ultra-realistic illustration of a bird god swinging a gold metal stick weapon, with a blue man face and yellow bird mouth, and intricate traditional Chinese elements.',
    'Several people standing next to \neach other that are snow skiing.',
    'A little orange kitten sits on a pink heart-shaped pillow.',
    'A pink bicycle leaning against a fence near a river.',
    'A landscape painting of a China mountain village with a turbulent blood lake.',
    'Abstract yin yang representation by Ivan Bilibin.',
    'A half body portrait of an Asian cyberpunk mechanoid fashion idol wearing a neon jellyfish headdress and xenomorphic body suit.',
    'A realistic anime painting of a cosmic woman wearing clothes made of universes with glowing red eyes.',
    'A Walter White funko pop figurine.',
    'A VTuber model concept art of a beautiful girl in a black and yellow hoodie looking on a smartphone in her hand, with blue eyes, long hair, and a futuristic city background.',
    'A high detail portrait of a royal mansion by Michelangelo Merisi da Caravaggio.',
    "Animation keyframes featuring a wolf's walking motion.",
    'A painted portrait of Persephone in ancient Greece with intricate detail, iridescent coloring, and golden hour lighting.',
    'A digital painting by James Jean depicting a goddess in a strong pose surrounded by planets in a hyper-realistic style.',
    'A digital painting of an anthropomorphic corgi lifting weights in a dim gym with intricate details and a dynamic pose.',
    'A ps2 anime witch from madoka magicka is flying on a broom through New York causing people to run for their lives due to a terrorist attack.',
    'A cat wearing a war helmet.',
    'there is a woman that is cutting a white cake',
    'A white stove top oven inside of a kitchen.',
    'a jet airplane sitting on a runway next to a building',
    'A neon-colored frog in a cyberpunk setting.',
    'A woman in a bathing suit captured in an ink drawing by Sam Bosma with outlined and stippled details.',
    'A girl looks out from the edge of a mountain onto a large city at night.',
    'there is a very beautiful view out of this bathroom window',
    'A female Sonic the Hedgehog with black sclera and bright red pupils.',
    'A futuristic city with a lake, a reflection of utopia, and \njungle scenery, featuring drones and androids.',
    'A vivid and intricate depiction of a terrifying god-like creature with rich, bold colors and influences from various artists.',
]

# Combined prompt pool: partipromt_50 followed by hpsv2_50.
# Not consumed by the scoring loop further down, which iterates
# prompts.test_prompts.
all_prompts = partipromt_50 + hpsv2_50

# example_prompts = ['a stop sign with a large tree behind it', 'a stop sign knocked over on a sidewalk', 'teacups surounding a kettle', 'a dragon breathing fire', 'a dragon breathing fire on a castle', 'a dragon breathing fire onto a knight', 'a view of the Big Dipper in the night sky', 'a view of the Orion constellation in the night sky', 'a teddy bear to the right of a toy car', 'a toy car in front of a teddy bear', 'a large present with a red ribbon', 'a large present with a red ribbon to the left of a Christmas tree', 'a half empty bottle of red wine', 'a wine bottle with a lit candle stuck in its spout', 'a wine bottle with a red ribbon wrapped around it', "a kids' book cover with an illustration of white dog driving a red pickup truck", 'milk pouring into a large glass', 'milk pouring from a glass into a bowl', 'matching socks with cute cats on them', 'cash on a wooden table', 'a wood cabin with a fire pit in front of it', 'view of a clock tower on a cloudy day', 'a chimpanzee wearing a bowtie and playing a piano', 'a black baseball hat with a flame decal on it', 'black hi-top sneakers with the Nike swoosh', 'a roast turkey being taken out of the oven', 'a bamboo ladder propped up against an oak tree', 'a Tyrannosaurus Rex roaring in front of a palm tree', 'a Stegasaurus eating ferns', 'a Triceratops charging down a hill']


# partipromt_50 = [
#     'a family of four posing on the moon',
#     'a family of bears passing by the geyser Old Faithful',
#     'a high-quality photograph of an armadillo playing a bagpipe while standing on one leg',
#     'The Statue of Liberty with the Manhattan skyline in the background.',
#     'a milk container in a refrigerator',
#     'a Christmas tree',
#     'a dolphin in an astronaut suit on saturn',
#     "the saying 'do unto others as they would do unto you' written on a white background",
#     'a thumbnail image of a person skiing',
#     'an old-fashioned cocktail next to a napkin',
#     'a small kitchen with a white goat in it',
#     'view of a clock tower from above',
#     'A helicopter flies over Yosemite.',
#     'A punk rock platypus in a studded leather jacket shouting into a microphone while standing on a boulder',
#     'background pattern with alternating roses and skulls',
#     'molecule',
#     'Downtown Seattle at sunrise. detailed ink wash.',
#     'a Christmas tree on a toy train',
#     'a girl',
#     'a red train is coming down the beach',
#     'A shiny VW van that has flowers painted on it. A smiling sloth stands on grass in front of the van and is wearing a leather jacket, a cowboy hat, a kilt and a bowtie. The sloth is holding a quarterstaff and a big book. ink sketch.',
#     'the mona lisa',
#     'The mouse the cat watches is jumping in the air.',
#     'a white towel with a cartoon of a cat on it',
#     'A smiling sloth wearing a leather jacket, a cowboy hat, a kilt and a bowtie. The sloth is holding a quarterstaff and a big book. A shiny VW van with a cityscape painted on it and parked on grass.',
#     'a pen-and-ink crosshatched drawing of a sphere with dark square on it',
#     'a hamster dragon',
#     'A Vietnam map',
#     'a coffee mug floating in the sky',
#     'a flag',
#     'an owl standing on a wire',
#     'a car with tires that have yellow rims',
#     'teacup',
#     'a three quarters view of a man getting into a car',
#     'A giant cobra snake made from pancakes',
#     'a black dog jumping up to hug a woman wearing a red sweater',
#     'a half moon in the day sky',
#     'a tree growing out of the middle of an intersection',
#     'three chairs',
#     'a turkey',
#     'a snail made of harp',
#     'a sword in a stone',
#     'a doorknocker',
#     'an emoji of a baby penguin wearing a blue hat, red gloves, green shirt, and yellow pants',
#     'A castle made of cardboard.',
#     'ten wine bottles',
#     'a girl riding an ostrich',
#     'Mars rises on the horizon.',
#     'a yellow diamond-shaped sign with a turtle silhouette',
#     'A single beam of light enter the room from the ceiling. The beam of light is illuminating an easel. On the easel there is a Rembrandt painting of a raccoon'
# ]



# hpsv2 = ['A film still of Luke Skywalker as a Sith Lord.', 'A minimalistic heart drawing created using Adobe Illustrator.', 'Portrait of a male furry anthro mountain goat in a pinstripe suit and waistcoat, smoking a cigar.', 'Renaissance noblewoman with blue eyes and pale skin in a classical portrait pose in the art style of Ib Iwerks.', 'A close-up image of a woman wearing a samurai mask, fire dancing in a dirty cyberpunk alley with smoke and mist.', 'A lemon character with sunglasses on the beach.', 'Abstract yin yang representation by Ivan Bilibin.', 'An image of Akira, from the artist Simon Stalenhag.', 'Galactus devouring planet earth, depicted in an artwork by Francisco Goya.', 'A hyperrealistic mixed media image of a proportionally sized human hand undergoing particle teleportation, with perfect symmetry and dim volumetric lighting.', 'A dolphin swimming in front of a Studio Ghibli logo backdrop.', 'A man drinking cosmic energy in an anime-style digital art by Park Sung-woo.', 'Two messy toilet stalls with toilets where one lid is raised. \n', 'A lemon character wearing sunglasses on the beach.', 'Undertale character Spromple Sploop, third brother of Sans.', 'A plane is on display near the water.', 'A young woman smiling in the etheric hypothalamus of her mind.', 'A portrait painting of a Red Borzoi Dog wearing a red beret as an Overwatch character.', 'A giant guardian wearing road sign armor, a popular character design on Artstation.', 'A bicycle leaned against the hallway wall in a house', 'A white toilet sitting under a window next to a chair.', 'A manga-style illustration of Harry Potter as a Gundam mech.', 'A Nintendo 64 controller with anthropomorphic features consuming small children.', 'A portrait painting of Yondu Udonta in an asymmetrical profile shot, incorporating bold shapes and hard edges with a stylized street art aesthetic.', 'A surreal portrait of a young Spanish man wearing sock and titled "Super Spy Captain" with deep purple hair and green eyes on an orange background.', 'A portrait of a stylized business cat in sharp focus with a medium shot perspective, resembling boxart.', 'A man sitting on a black and yellow bench on the phone.', 'A pen illustration of a man wrestling his phone by Gustave Doré with crosshatching and pops of colorful Ben Day dots.', 'A masterpiece.', 'Albus Dumbledore dressed up as Wonder Woman.']

#hpsv2_50 = ['A teenage mutant ninja turtle, Leonardo, enjoys a cup of tea at a wooden desk in a sci-fi space station orbiting a large planet visible through a window.', 'The top of a steeped church building with clocks and small windows.  ', 'A bald general with an angry expression in an intricately detailed and elegant digital painting.', 'A furry cat girl.', 'Celine Dion appears angry at a kitten in a hot tub.', 'A vampire sits at a banquet table in a dungeon setting surrounded by plates of rats and spiders and red candles.', 'A bathroom with a small sink and toilet. ', 'Three small dinosaurs entering a grocery store painted by Thomas Kinkade.', 'Large sized kitchen with a dining room section.', '"Front centered symmetrical portrait of Elisha Cuthbert as a D&D paladin with cinematic lighting."', 'There are orange slices in canning jars without lids.', 'A portrait of an orc in a fantasy art style.', 'Exterior image of a small magic items and curios shop in a busy fantasy city.', 'A portrait of a character in a scenic environment.', 'A plane riding down a runway of an airport.', 'A stylized digital art image of a cherry tree overlooking a valley with a waterfall during sunset.', 'A girl sneaking behind a giant wooden door with archaic symbols embedded onto it, in a cave with the waterfall, illustrated in comics style.', 'Flag design for communist European Union featuring a hammer and sickle.', 'A digital painting of Teemo from League of Legends, wearing cyborg parts and a new skin, in a fantasy MMORPG style.', 'an empty bench sitting on the side of a sidewalk', 'Image of xqc with a distinctive underbite and big, long nose.', 'An abstract collage featuring grey and lilac colors with a touch of sparkle.', 'Patrick Bateman beating an anthropomorphic wolf cosplay.', 'An ultra-realistic illustration of a bird god swinging a gold metal stick weapon, with a blue man face and yellow bird mouth, and intricate traditional Chinese elements.', 'Several people standing next to \neach other that are snow skiing.', 'A little orange kitten sits on a pink heart-shaped pillow.', 'A pink bicycle leaning against a fence near a river.', 'A landscape painting of a China mountain village with a turbulent blood lake.', 'Abstract yin yang representation by Ivan Bilibin.', 'A half body portrait of an Asian cyberpunk mechanoid fashion idol wearing a neon jellyfish headdress and xenomorphic body suit.', 'A realistic anime painting of a cosmic woman wearing clothes made of universes with glowing red eyes.', 'A Walter White funko pop figurine.', 'A VTuber model concept art of a beautiful girl in a black and yellow hoodie looking on a smartphone in her hand, with blue eyes, long hair, and a futuristic city background.', 'A high detail portrait of a royal mansion by Michelangelo Merisi da Caravaggio.', "Animation keyframes featuring a wolf's walking motion.", 'A painted portrait of Persephone in ancient Greece with intricate detail, iridescent coloring, and golden hour lighting.', 'A digital painting by James Jean depicting a goddess in a strong pose surrounded by planets in a hyper-realistic style.', 'A digital painting of an anthropomorphic corgi lifting weights in a dim gym with intricate details and a dynamic pose.', 'A ps2 anime witch from madoka magicka is flying on a broom through New York causing people to run for their lives due to a terrorist attack.', 'A cat wearing a war helmet.', 'there is a woman that is cutting a white cake', 'A white stove top oven inside of a kitchen.', 'a jet airplane sitting on a runway next to a building', 'A neon-colored frog in a cyberpunk setting.', 'A woman in a bathing suit captured in an ink drawing by Sam Bosma with outlined and stippled details.', 'A girl looks out from the edge of a mountain onto a large city at night.', 'there is a very beautiful view out of this bathroom window', 'A female Sonic the Hedgehog with black sclera and bright red pupils.', 'A futuristic city with a lake, a reflection of utopia, and \njungle scenery, featuring drones and androids.', 'A vivid and intricate depiction of a terrifying god-like creature with rich, bold colors and influences from various artists.']

# all_prompts = partipromt_50 + hpsv2_50

import statistics

# Reward score of the single image generated for each test prompt, in order.
all_scores = []

# Disabled legacy loop over example_prompts, kept for reference:
# for p in example_prompts:
#     ims = gen(p) # could save these if desired
#     scores = ps_selector.score(ims, p)
#     all_scores.append(scores[0])
#     print(scores)

# Generate one image per test prompt (image ids start at 1) and score it.
for i, p in enumerate(prompts.test_prompts, start=1):
    ims = gen(image_id=i, prompt=p)
    scores = ps_selector.score(ims, p)
    all_scores.append(scores[0])
    print(scores)

print(all_scores)

# Compute each summary statistic once and reuse it for stdout and the file.
median_value = statistics.median(all_scores)
mean_value = statistics.mean(all_scores)
print(median_value)
print(mean_value)

# Guard the ratio so a run with --Lambda 0 does not crash with
# ZeroDivisionError after all images have already been generated and scored.
# NOTE(review): "inf" is used as the placeholder for Lambda == 0 — confirm that
# this is the desired value for downstream consumers of the results file.
beta_over_lambda = args.Beta / args.Lambda if args.Lambda != 0 else float("inf")

# --model_name may be a hub id or a checkpoint directory containing path
# separators; flatten them so the results file is always created in the
# current directory instead of pointing into a directory that does not exist.
# Names without separators produce the same filename as before.
safe_model_name = args.model_name.replace("\\", "/").strip("/").replace("/", "_")
results_path = "_sd15_final_" + safe_model_name + "_AES_Approximations"

# Append, so repeated runs with different Lambda/Beta accumulate in one file.
with open(results_path, "a", encoding="utf-8") as f:
    # Write the stats
    print(f"Beta {beta_over_lambda}", file=f)
    print(f"Median: {median_value}", file=f)
    print(f"Mean:   {mean_value}", file=f)
    # Two blank lines
    print(file=f)
    print(file=f)


# to get partiprompts captions
# from datasets import load_dataset
# dataset = load_dataset("nateraw/parti-prompts")
# print(dataset['train']['Prompt'])

