# -*- coding: utf-8 -*-
"""DeRaDiff_SDXL

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1MAYFdJE_M-p42tRRR9ZLTEtTeDz6ZImT
"""

from diffusers import StableDiffusionPipeline, UNet2DConditionModel, StableDiffusionXLPipeline
import torch
import os
import argparse
import json, sys
import prompts
torch.set_grad_enabled(False)

parser = argparse.ArgumentParser(
    description="Run SDXL lambda sampling with a custom λ value."
)
parser.add_argument(
    "--Lambda",
    type=float,
    default=1.0,
    help="Lambda value for the sampling (e.g. 0.125)."
)

parser.add_argument(
    "--Beta",
    type=float,
    default=1.0,
    help="Lambda value for the sampling (e.g. 0.125)."
)


parser.add_argument(
    "--model_name",
    type=str,
    default=None,
    help='Name of the model, e.g. "beta2000_64acc_600epoch".'
)


args = parser.parse_args()


dpo_unet = UNet2DConditionModel.from_pretrained(
                            #  'mhdang/dpo-sd1.5-text2image-v1',
                            # 'mhdang/dpo-sdxl-text2image-v1',
                            # alternatively use local ckptdir (*/checkpoint-n/)
                            # 'beta2000_64acc_600epoch',
                            # 'tmp-sdxl-1000',
                            args.model_name,
                            subfolder='unet',
                            torch_dtype=torch.float16
).to('cuda')

# pretrained_model_name = "CompVis/stable-diffusion-v1-4"
# pretrained_model_name = "runwayml/stable-diffusion-v1-5"
pretrained_model_name = "stabilityai/stable-diffusion-xl-base-1.0"
gs = (5 if 'stable-diffusion-xl' in pretrained_model_name else 7.5)

if 'stable-diffusion-xl' in pretrained_model_name:
    pipe = StableDiffusionXLPipeline.from_pretrained(
        pretrained_model_name, torch_dtype=torch.float16, device_map="balanced",
        variant="fp16", use_safetensors=True
    )#.to("cuda")
else:
    pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name, device_map="balanced",
                                                   torch_dtype=torch.float16)
#pipe = pipe.to('cuda')
pipe.safety_checker = None # Trigger-happy, blacks out >50% of "robot tiger"

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional[Union[str, "torch.device"]] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
    is always created on the CPU.
    """
    # device on which tensor is created defaults to device
    if isinstance(device, str):
        device = torch.device(device)
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            if device != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        shape = (1,) + shape[1:]
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents

from diffusers import EulerAncestralDiscreteScheduler
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
# from diffusers.utils import randn_tensor # For noise generation
import PIL.Image

class CustomEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
    def step(
        self,
        Lambda : float,
        model_output_ref: torch.Tensor,
        model_output_beta: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)


        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample_ref = sample - sigma * model_output_ref
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample_ref = model_output_ref * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        if self.config.prediction_type == "epsilon":
            pred_original_sample_beta = sample - sigma * model_output_beta
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample_beta = model_output_beta * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )


        sigma_from = self.sigmas[self.step_index]
        sigma_to = self.sigmas[self.step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5


        # 2. Convert to an ODE derivative
        dt = sigma_down - sigma

        derivative_ref = (sample - pred_original_sample_ref) / sigma

        prev_sample_ref = sample + derivative_ref * dt

        derivative_beta = (sample - pred_original_sample_beta) / sigma

        prev_sample_beta = sample + derivative_beta * dt


        # Crucial segment
        if timestep > 1:
            v1 = sigma_up ** 2
            v2 = sigma_up ** 2 # v1, v2 need not be the same, but in this case it is.
            v_new = 1/((1-Lambda)/v1 + Lambda/v2)
            std_new = v_new**0.5

            mean = v_new * ((1-Lambda) * prev_sample_ref/v1 + Lambda * prev_sample_beta/v2)

            device = model_output_ref.device
            noise = randn_tensor(model_output_ref.shape, dtype=model_output_ref.dtype, device=device, generator=generator)

            prev_sample = mean + noise * std_new


            # Cast sample back to model compatible dtype
            prev_sample = prev_sample.to(model_output_ref.dtype)
            self.save = prev_sample
        else:
            prev_sample = self.save

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample_ref,
            )

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample_ref
        )

def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):

    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.embeddings import ImageProjection
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput


from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)


from diffusers import StableDiffusionXLPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection, CLIPImageProcessor
from typing import Optional, List, Union, Dict, Any

class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
    # 1. Define the _components attribute for your custom pipeline
    # Start with the components from the parent class and add your new ones.
    # Do NOT add 'Lambda' here.
    _components = [
        "vae",
        "text_encoder",
        "text_encoder_2",
        "tokenizer",
        "tokenizer_2",
        "unet",                # This refers to the main UNet passed to super().__init__
        "scheduler",
        "image_encoder",       # Standard optional component for SDXL
        "feature_extractor",   # Standard optional component for SDXL
        "unet_beta"            # Your additional UNet model
    ]
    # To be more explicit, you can list them all out:
    # _components = [
    #     "vae", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2",
    #     "unet", "scheduler", "image_encoder", "feature_extractor", "unet_beta"
    # ]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel, # This is the original/reference unet
        unet_beta: UNet2DConditionModel, # Your new DPO unet
        scheduler: KarrasDiffusionSchedulers,
        Lambda: float, # Your custom hyperparameter
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        # 2. Call super().__init__ with the arguments the parent class expects.
        # The parent (StableDiffusionXLPipeline) will handle registering its own components (vae, text_encoders, unet, etc.)
        # and config values (force_zeros_for_empty_prompt, add_watermarker).
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet, # Pass the main unet to the parent
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
            add_watermarker=add_watermarker,
        )

        # 3. Register your new module(s).
        # This makes unet_beta known to the pipeline (e.g., for .to(device) calls).
        # It also sets self.unet_beta = unet_beta.
        self.register_modules(unet_beta=unet_beta)

        # 4. Store Lambda and register it to the pipeline's config.
        # This makes Lambda part of the savable configuration of the pipeline.
        self.Lambda = Lambda
        self.register_to_config(Lambda=Lambda)

        # Optional: alias for clarity if you need to refer to the original unet
        self.unet_ref = self.unet # self.unet is already set by super().__init__
                                  # and refers to the 'unet' you passed to super()

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )

        # if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        #     callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)


        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._interrupt = False


        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device


        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )


        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )


        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )


        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)


        # 8.1 Apply denoising_end
        if (
            self.denoising_end is not None
            and isinstance(self.denoising_end, float)
            and self.denoising_end > 0
            and self.denoising_end < 1
        ):
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]


        # 9. Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                    added_cond_kwargs["image_embeds"] = image_embeds

                noise_pred_ref = self.unet_ref(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                noise_pred_beta = self.unet_beta(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond_ref, noise_pred_text_ref = noise_pred_ref.chunk(2)
                    noise_pred_ref = noise_pred_uncond_ref + self.guidance_scale * (noise_pred_text_ref - noise_pred_uncond_ref)

                    noise_pred_uncond_beta, noise_pred_text_beta = noise_pred_beta.chunk(2)
                    noise_pred_beta = noise_pred_uncond_beta + self.guidance_scale * (noise_pred_text_beta - noise_pred_uncond_beta)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://huggingface.co/papers/2305.08891
                    noise_pred_ref = rescale_noise_cfg(noise_pred_ref, noise_pred_text_ref, guidance_rescale=self.guidance_rescale)

                    noise_pred_beta = rescale_noise_cfg(noise_pred_beta, noise_pred_text_beta, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype


                latents = self.scheduler.step(Lambda=self.Lambda, model_output_ref=noise_pred_ref, model_output_beta=noise_pred_beta, timestep=t, sample=latents, **extra_step_kwargs, return_dict=False)[0]
                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
            elif latents.dtype != self.vae.dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    self.vae = self.vae.to(latents.dtype)

            # unscale/denormalize the latents
            # denormalize with the mean and std if available and not None
            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
            if has_latents_mean and has_latents_std:
                latents_mean = (
                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                latents_std = (
                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
            else:
                latents = latents / self.vae.config.scaling_factor

            image = self.vae.decode(latents, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
        else:
            image = latents

        if not output_type == "latent":
            # apply watermark if available
            if self.watermark is not None:
                image = self.watermark.apply_watermark(image)

            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)

# # Can do clip_utils, aes_utils, hps_utils
from utils.clip_utils import Selector
# # Score generations automatically w/ reward model
ps_selector = Selector('cuda')

custom_scheduler = CustomEulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config, steps_offset=0)

# passed_image_encoder = getattr(pipe, 'image_encoder', None)

custom_pipe = CustomStableDiffusionXLPipeline(
    vae=pipe.vae,
    text_encoder=pipe.text_encoder,
    text_encoder_2=pipe.text_encoder_2,
    tokenizer=pipe.tokenizer,
    tokenizer_2=pipe.tokenizer_2,
    unet=pipe.unet,
    unet_beta=dpo_unet,
    scheduler=custom_scheduler,
    feature_extractor=pipe.feature_extractor,
    image_encoder=pipe.image_encoder,
    Lambda=args.Lambda
)

# combined_pipe = CustomCombinedPipeline(
#     vae=pipe.vae,
#     text_encoder=pipe.text_encoder,
#     tokenizer=pipe.tokenizer,
#     unet=pipe.unet,          # Base UNet
#     dpo_unet=dpo_unet,       # DPO UNet
#     scheduler=custom_scheduler,
#     safety_checker=pipe.safety_checker,
#     feature_extractor=pipe.feature_extractor,
#     image_encoder=passed_image_encoder,
#     requires_safety_checker=(pipe.safety_checker is not None),
#     lmbda=0.4   # Example weight, you can make this configurable
# ).to("cuda")

unets = [pipe.unet, dpo_unet]
names = ["Orig. SDXL", "DPO SDXL"]

# def gen(prompt, seed=0, run_baseline=True):
#     ims = []
#     generator = torch.Generator(device='cuda')
#     for unet_i in ([0, 1] if run_baseline else [1]):
#         print(f"Prompt: {prompt}\nSeed: {seed}\n{names[unet_i]}")
#         pipe.unet = unets[unet_i]
#         generator = generator.manual_seed(seed)

#         im = pipe(prompt=prompt, generator=generator, guidance_scale=gs).images[0]
#         display(im)
#         ims.append(im)
#     return ims

def gen(image_id, prompt, seed=0, run_baseline=True, output_dir="lambda_semantic_500_to_2000"):
    ims = []
    generator = torch.Generator(device='cuda')

    # print(f"Prompt: {prompt}\nSeed: {seed}\n{names[unet_i]}")

    print(prompt)

    generator = generator.manual_seed(seed)

    im = custom_pipe(prompt=prompt, generator=generator, guidance_scale=gs, num_inference_steps=50).images[0]


    filename = f"{output_dir}/{image_id}_seed{seed}.png"
    im.save(filename)
    ims.append(im)
    print(f"Saved to {filename}")
    return ims

example_prompts = [
    "A pile of sand swirling in the wind forming the shape of a dancer",
    "A giant dinosaur frozen into a glacier and recently discovered by scientists, cinematic still",
    "a smiling beautiful sorceress with long dark hair and closed eyes wearing a dark top surrounded by glowing fire sparks at night, magical light fog, deep focus+closeup, hyper-realistic, volumetric lighting, dramatic lighting, beautiful composition, intricate details, instagram, trending, photograph, film grain and noise, 8K, cinematic, post-production",
    "A purple raven flying over big sur, light fog, deep focus+closeup, hyper-realistic, volumetric lighting, dramatic lighting, beautiful composition, intricate details, instagram, trending, photograph, film grain and noise, 8K, cinematic, post-production",
    "a smiling beautiful sorceress wearing a modest high necked blue suit surrounded by swirling rainbow aurora, hyper-realistic, cinematic, post-production",
    "Anthro humanoid turtle skydiving wearing goggles, gopro footage",
    "A man in a suit surfing in a river",
    "photo of a zebra dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography",
    "A typhoon in a tea cup, digital render",
    "A cute puppy leading a session of the United Nations, newspaper photography",
    "Worm eye view of rocketship",
    "Glass spheres in the desert, refraction render",
    "anthropmorphic coffee bean drinking coffee",
    "A baby kangaroo in a trenchcoat",
    "A towering hurricane of rainbow colors towering over a city, cinematic digital art",
    "A redwood tree rising up out of the ocean",
]


partipromt_50 = [
    'a family of four posing on the moon',
    'a family of bears passing by the geyser Old Faithful',
    'a high-quality photograph of an armadillo playing a bagpipe while standing on one leg',
    'The Statue of Liberty with the Manhattan skyline in the background.',
    'a milk container in a refrigerator',
    'a Christmas tree',
    'a dolphin in an astronaut suit on saturn',
    "the saying 'do unto others as they would do unto you' written on a white background",
    'a thumbnail image of a person skiing',
    'an old-fashioned cocktail next to a napkin',
    'a small kitchen with a white goat in it',
    'view of a clock tower from above',
    'A helicopter flies over Yosemite.',
    'A punk rock platypus in a studded leather jacket shouting into a microphone while standing on a boulder',
    'background pattern with alternating roses and skulls',
    'molecule',
    'Downtown Seattle at sunrise. detailed ink wash.',
    'a Christmas tree on a toy train',
    'a girl',
    'a red train is coming down the beach',
    'A shiny VW van that has flowers painted on it. A smiling sloth stands on grass in front of the van and is wearing a leather jacket, a cowboy hat, a kilt and a bowtie. The sloth is holding a quarterstaff and a big book. ink sketch.',
    'the mona lisa',
    'The mouse the cat watches is jumping in the air.',
    'a white towel with a cartoon of a cat on it',
    'A smiling sloth wearing a leather jacket, a cowboy hat, a kilt and a bowtie. The sloth is holding a quarterstaff and a big book. A shiny VW van with a cityscape painted on it and parked on grass.',
    'a pen-and-ink crosshatched drawing of a sphere with dark square on it',
    'a hamster dragon',
    'A Vietnam map',
    'a coffee mug floating in the sky',
    'a flag',
    'an owl standing on a wire',
    'a car with tires that have yellow rims',
    'teacup',
    'a three quarters view of a man getting into a car',
    'A giant cobra snake made from pancakes',
    'a black dog jumping up to hug a woman wearing a red sweater',
    'a half moon in the day sky',
    'a tree growing out of the middle of an intersection',
    'three chairs',
    'a turkey',
    'a snail made of harp',
    'a sword in a stone',
    'a doorknocker',
    'an emoji of a baby penguin wearing a blue hat, red gloves, green shirt, and yellow pants',
    'A castle made of cardboard.',
    'ten wine bottles',
    'a girl riding an ostrich',
    'Mars rises on the horizon.',
    'a yellow diamond-shaped sign with a turtle silhouette',
    'A single beam of light enter the room from the ceiling. The beam of light is illuminating an easel. On the easel there is a Rembrandt painting of a raccoon'
]



# hpsv2 = ['A film still of Luke Skywalker as a Sith Lord.', 'A minimalistic heart drawing created using Adobe Illustrator.', 'Portrait of a male furry anthro mountain goat in a pinstripe suit and waistcoat, smoking a cigar.', 'Renaissance noblewoman with blue eyes and pale skin in a classical portrait pose in the art style of Ib Iwerks.', 'A close-up image of a woman wearing a samurai mask, fire dancing in a dirty cyberpunk alley with smoke and mist.', 'A lemon character with sunglasses on the beach.', 'Abstract yin yang representation by Ivan Bilibin.', 'An image of Akira, from the artist Simon Stalenhag.', 'Galactus devouring planet earth, depicted in an artwork by Francisco Goya.', 'A hyperrealistic mixed media image of a proportionally sized human hand undergoing particle teleportation, with perfect symmetry and dim volumetric lighting.', 'A dolphin swimming in front of a Studio Ghibli logo backdrop.', 'A man drinking cosmic energy in an anime-style digital art by Park Sung-woo.', 'Two messy toilet stalls with toilets where one lid is raised. ', 'A lemon character wearing sunglasses on the beach.', 'Undertale character Spromple Sploop, third brother of Sans.', 'A plane is on display near the water.', 'A young woman smiling in the etheric hypothalamus of her mind.', 'A portrait painting of a Red Borzoi Dog wearing a red beret as an Overwatch character.', 'A giant guardian wearing road sign armor, a popular character design on Artstation.', 'A bicycle leaned against the hallway wall in a house', 'A white toilet sitting under a window next to a chair.', 'A manga-style illustration of Harry Potter as a Gundam mech.', 'A Nintendo 64 controller with anthropomorphic features consuming small children.', 'A portrait painting of Yondu Udonta in an asymmetrical profile shot, incorporating bold shapes and hard edges with a stylized street art aesthetic.', 'A surreal portrait of a young Spanish man wearing sock and titled "Super Spy Captain" with deep purple hair and green eyes on an orange background.', 'A portrait of a stylized business cat in sharp focus with a medium shot perspective, resembling boxart.', 'A man sitting on a black and yellow bench on the phone.', 'A pen illustration of a man wrestling his phone by Gustave Doré with crosshatching and pops of colorful Ben Day dots.', 'A masterpiece.', 'Albus Dumbledore dressed up as Wonder Woman.']

hpsv2_50 = ['A teenage mutant ninja turtle, Leonardo, enjoys a cup of tea at a wooden desk in a sci-fi space station orbiting a large planet visible through a window.', 'The top of a steeped church building with clocks and small windows.  ', 'A bald general with an angry expression in an intricately detailed and elegant digital painting.', 'A furry cat girl.', 'Celine Dion appears angry at a kitten in a hot tub.', 'A vampire sits at a banquet table in a dungeon setting surrounded by plates of rats and spiders and red candles.', 'A bathroom with a small sink and toilet. ', 'Three small dinosaurs entering a grocery store painted by Thomas Kinkade.', 'Large sized kitchen with a dining room section.', '"Front centered symmetrical portrait of Elisha Cuthbert as a D&D paladin with cinematic lighting."', 'There are orange slices in canning jars without lids.', 'A portrait of an orc in a fantasy art style.', 'Exterior image of a small magic items and curios shop in a busy fantasy city.', 'A portrait of a character in a scenic environment.', 'A plane riding down a runway of an airport.', 'A stylized digital art image of a cherry tree overlooking a valley with a waterfall during sunset.', 'A girl sneaking behind a giant wooden door with archaic symbols embedded onto it, in a cave with the waterfall, illustrated in comics style.', 'Flag design for communist European Union featuring a hammer and sickle.', 'A digital painting of Teemo from League of Legends, wearing cyborg parts and a new skin, in a fantasy MMORPG style.', 'an empty bench sitting on the side of a sidewalk', 'Image of xqc with a distinctive underbite and big, long nose.', 'An abstract collage featuring grey and lilac colors with a touch of sparkle.', 'Patrick Bateman beating an anthropomorphic wolf cosplay.', 'An ultra-realistic illustration of a bird god swinging a gold metal stick weapon, with a blue man face and yellow bird mouth, and intricate traditional Chinese elements.', 'Several people standing next to each other that are snow skiing.', 'A little orange kitten sits on a pink heart-shaped pillow.', 'A pink bicycle leaning against a fence near a river.', 'A landscape painting of a China mountain village with a turbulent blood lake.', 'Abstract yin yang representation by Ivan Bilibin.', 'A half body portrait of an Asian cyberpunk mechanoid fashion idol wearing a neon jellyfish headdress and xenomorphic body suit.', 'A realistic anime painting of a cosmic woman wearing clothes made of universes with glowing red eyes.', 'A Walter White funko pop figurine.', 'A VTuber model concept art of a beautiful girl in a black and yellow hoodie looking on a smartphone in her hand, with blue eyes, long hair, and a futuristic city background.', 'A high detail portrait of a royal mansion by Michelangelo Merisi da Caravaggio.', "Animation keyframes featuring a wolf's walking motion.", 'A painted portrait of Persephone in ancient Greece with intricate detail, iridescent coloring, and golden hour lighting.', 'A digital painting by James Jean depicting a goddess in a strong pose surrounded by planets in a hyper-realistic style.', 'A digital painting of an anthropomorphic corgi lifting weights in a dim gym with intricate details and a dynamic pose.', 'A ps2 anime witch from madoka magicka is flying on a broom through New York causing people to run for their lives due to a terrorist attack.', 'A cat wearing a war helmet.', 'there is a woman that is cutting a white cake', 'A white stove top oven inside of a kitchen.', 'a jet airplane sitting on a runway next to a building', 'A neon-colored frog in a cyberpunk setting.', 'A woman in a bathing suit captured in an ink drawing by Sam Bosma with outlined and stippled details.', 'A girl looks out from the edge of a mountain onto a large city at night.', 'there is a very beautiful view out of this bathroom window', 'A female Sonic the Hedgehog with black sclera and bright red pupils.', 'A futuristic city with a lake, a reflection of utopia, and jungle scenery, featuring drones and androids.', 'A vivid and intricate depiction of a terrifying god-like creature with rich, bold colors and influences from various artists.']

all_prompts = partipromt_50 + hpsv2_50

quick_prompts = ["faeces", "urine"];


# for p in example_prompts:
#     ims = gen(p) # could save these if desired
#     # scores = ps_selector.score(ims, p)
#     # print(scores)

# for i, p in enumerate(example_prompts, start=1):
#     ims = gen(image_id=i, prompt=p)

all_scores = []

"""
my_prompt = [
        "A cyberpunk samurai standing on a rain‑soaked street,   neon signs flickering overhead,   dramatic chiaroscuro lighting, digital matte painting",
        "A noble archer princess in a forest‑camouflage tunic, enveloped by swirling autumn leaves, hyper‑realistic, cinematic, post‑production"];
"""

my_prompt = [
        "A typhoon in a tea cup, digital render"
        ];

test_prompt = ['A cat wearing a war helmet.'];


for i, p in enumerate(prompts.semantic_prompts_2, start=1):
    ims = gen(image_id=i, prompt=p)
    scores = ps_selector.score(ims, p)
    all_scores.append(scores[0])
    print(scores)

all_scores.sort()
#print(all_scores)
import statistics
#print(statistics.median(all_scores))
#print(statistics.mean(all_scores))

median_value = statistics.median(all_scores)
mean_value = statistics.mean(all_scores)

#json.dump({"mean_score": mean_value}, sys.stdout)
#json.dump({"mean_score": mean_value}, sys.stdout)

"""
with open("_xXx_final_" + args.model_name + "_HPS_Approximations", "a") as f:
    # Write the stats
    print(f"Beta {args.Beta/args.Lambda}", file=f)
    print(f"Median: {median_value}", file=f)
    print(f"Mean:   {mean_value}", file=f)
    # Two blank lines
    print(file=f)
    print(file=f)
"""

# # to get partiprompts captions
# from datasets import load_dataset
# dataset = load_dataset("nateraw/parti-prompts")
# print(dataset['train']['Prompt'])
