# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
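# with the following modifications:
# - It uses the patched version of `sde_step_with_logprob` from `sd3_sde_with_logprob_multis1.py`.
# - At the first denoising step it branches into a (preferred, rejected) latent pair scored by `reward_fn`.
#   At each selected training step it returns the latents before/after the step, plus the log probs,
#   timesteps and rewards of the selected samples.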
from typing import Any, Dict, List, Optional, Union
import torch
import random
from typing import Callable
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
from .sd3_sde_with_logprob_multis1 import sde_step_with_logprob

@torch.no_grad()
def pipeline_with_logprob(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    mini_num_image_per_prompt: int = 1,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 256,
    skip_layer_guidance_scale: float = 2.8,
    noise_level: float = 0.7,
    train_num_steps: int = 1,
    process_index: int = 0,
    random_timestep: Optional[int] = None,
    reward_fn: Optional[Callable] = None,
    reward_prompt: Optional[List[str]] = None,
    prompt_metadata: Optional[Dict[str, Any]] = None,
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._skip_layer_guidance_scale = skip_layer_guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps
    scheduler_kwargs = {}
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        **scheduler_kwargs,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    random.seed(process_index)
    if random_timestep is None:
        random_timestep = random.randint(0, num_inference_steps-train_num_steps)

    # 对num_inference_steps 均匀采 train_num_steps 个点
    # 要求第一个点为1，最后一个点为num_inference_steps-1的整数
    random_timesteps = torch.linspace(1, num_inference_steps-1, train_num_steps).long()

    # 6. Prepare image embeddings
    pre_latent = []
    post_latent = []
    all_log_probs = []
    all_timesteps = []
    rewards = []

    if self.do_classifier_free_guidance:
        tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
    # 7. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])
            
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=tem_prompt_embeds,
                pooled_projections=tem_pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]
            noise_pred = noise_pred.to(prompt_embeds.dtype)
            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                
            latents_dtype = latents.dtype

            if i == 0:
                cur_noise_level = noise_level
                # 初始步骤分出偏好对
                sde_latents = latents.repeat(2, 1, 1, 1)
                repeated_noise_pred = noise_pred.repeat(2, 1, 1, 1)
                sde_latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
                    self.scheduler, 
                    repeated_noise_pred.float(), 
                    t.unsqueeze(0), 
                    sde_latents.float(),
                    noise_level=cur_noise_level,
                )
                if sde_latents.dtype != latents_dtype:
                    sde_latents = sde_latents.to(latents_dtype)
                rewards = reward_fn(sde_latents, reward_prompt, prompt_metadata, t.unsqueeze(0))
                rewards = rewards[0]['latent_score_sd3']
                # 选择rewards最大和最下的2个sde_latents
                max_reward_latent = sde_latents[torch.argmax(rewards)]
                min_reward_latent = sde_latents[torch.argmin(rewards)]
                
                latents = torch.stack([max_reward_latent, min_reward_latent], dim=0)
                if self.do_classifier_free_guidance:
                    repeated_prompt_embeds = prompt_embeds.repeat(2, 1, 1)
                    repeated_negative_prompt_embeds = negative_prompt_embeds.repeat(2, 1, 1)
                    repeated_pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)
                    repeated_negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(2, 1)
                    tem_prompt_embeds = torch.cat([repeated_negative_prompt_embeds, repeated_prompt_embeds], dim=0)
                    tem_pooled_prompt_embeds = torch.cat([repeated_negative_pooled_prompt_embeds, repeated_pooled_prompt_embeds], dim=0)
                else:
                    tem_prompt_embeds = prompt_embeds.repeat(2, 1, 1)
                    tem_pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)

            elif i in random_timesteps:
                cur_noise_level= noise_level
                pre_latent.append(latents)

                # 将latents repeat mini_num_image_per_prompt次
                sde_latents = latents.repeat(mini_num_image_per_prompt, 1, 1, 1)
                repeated_noise_pred = noise_pred.repeat(mini_num_image_per_prompt, 1, 1, 1)
                sde_latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
                    self.scheduler, 
                    repeated_noise_pred.float(), 
                    t.unsqueeze(0), 
                    sde_latents.float(),
                    noise_level=cur_noise_level,
                )
                if sde_latents.dtype != latents_dtype:
                    sde_latents = sde_latents.to(latents_dtype)
            
                rewards = reward_fn(sde_latents, reward_prompt, prompt_metadata, t.unsqueeze(0))
                rewards_score = rewards[0]['latent_score_sd3']
                # 选择rewards最大和最下的2个sde_latents
                good_rewards,bad_rewards = torch.chunk(rewards_score, 2, dim=1)
                good_latents,bad_latents = sde_latents.chunk(2, dim=0)
                good_log_prob, bad_log_prob = log_prob.chunk(2, dim=0)
                
                max_good_reward = good_rewards[torch.argmax(good_rewards)]
                min_bad_reward = bad_rewards[torch.argmin(bad_rewards)]
                max_good_reward_latent = good_latents[torch.argmax(good_rewards)]
                min_bad_reward_latent = bad_latents[torch.argmin(bad_rewards)]
                max_good_log_prob = good_log_prob[torch.argmax(good_rewards)]
                min_bad_log_prob = bad_log_prob[torch.argmin(bad_rewards)]

                post_latent.append(torch.stack([max_good_reward_latent, min_bad_reward_latent], dim=0))
                all_log_probs.append(torch.stack([max_good_log_prob, min_bad_log_prob], dim=0))
                all_timesteps.append(t.repeat(len(sde_latents)//2))
                
                latents = torch.stack([max_good_reward_latent, min_bad_reward_latent], dim=0)
                rewards.append(torch.stack([max_good_reward, min_bad_reward], dim=0))

            else:
                cur_noise_level= 0
                latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
                    self.scheduler, 
                    noise_pred.float(), 
                    t.unsqueeze(0), 
                    latents.float(),
                    noise_level=cur_noise_level,
                )

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
            
    latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
    latents = latents.to(dtype=self.vae.dtype)
    image = self.vae.decode(latents, return_dict=False)[0]
    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()
    return image, pre_latent, post_latent, all_log_probs, all_timesteps, rewards
