import pdb
from typing import Dict, List, Optional, Tuple, Union
import torch
import wandb
import PIL.Image
from torch.nn import MSELoss, KLDivLoss
from diffusers import DiTPipeline
from diffusers.configuration_utils import register_to_config
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.models import AutoencoderKL, DiTTransformer2DModel
from diffusers.schedulers import KarrasDiffusionSchedulers, DDIMScheduler
from diffusers.utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Generator]=None, sample_mode: str='sample'):
    """Pull latents out of an encoder result.

    Args:
        encoder_output: Object returned by an encoder call. It is probed for
            a ``latent_dist`` attribute first and a ``latents`` attribute
            second.
        generator: Passed straight to ``latent_dist.sample`` when sampling.
        sample_mode: ``'sample'`` draws from ``latent_dist``; ``'argmax'``
            returns ``latent_dist.mode()``. Any other value skips the
            distribution and falls back to ``latents``.

    Returns:
        The value produced by the selected accessor.

    Raises:
        AttributeError: If neither accessor is applicable.
    """
    has_distribution = hasattr(encoder_output, 'latent_dist')
    if has_distribution:
        if sample_mode == 'sample':
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == 'argmax':
            return encoder_output.latent_dist.mode()
    # Unknown sample_mode, or no distribution at all: use precomputed latents.
    if hasattr(encoder_output, 'latents'):
        return encoder_output.latents
    raise AttributeError('Could not access latents of provided encoder_output')
class DiTPipelineDAP(DiTPipeline):
    model_cpu_offload_seq = 'transformer->vae'
    def __init__(self, transformer: DiTTransformer2DModel, vae: AutoencoderKL, scheduler: DDIMScheduler, id2label: Optional[Dict[int, str]]=None, guidance_loss_func=None, wandb_run=None):
        """Set up the pipeline components, label lookup and guidance settings.

        Args:
            transformer: Forwarded to the parent pipeline constructor.
            vae: Forwarded to the parent pipeline constructor.
            scheduler: Forwarded to the parent pipeline constructor.
            id2label: Optional mapping from class id to a comma-separated
                string of label names. Each name becomes a key of
                ``self.labels`` pointing at ``int(class id)``.
            guidance_loss_func: Callable used as the guidance loss. A falsy
                value selects ``MSELoss()``.
            wandb_run: Optional run object; when it is not ``None`` the
                pipeline marks logging as enabled.
        """
        super().__init__(transformer=transformer, vae=vae, scheduler=scheduler, id2label=id2label)

        # Invert id -> "name a, name b" into a name -> id table sorted by name.
        name_to_id = {}
        if id2label is not None:
            for class_id, joined_names in id2label.items():
                for raw_name in joined_names.split(','):
                    name_to_id[raw_name.strip()] = int(class_id)
            name_to_id = dict(sorted(name_to_id.items()))
        self.labels = name_to_id

        num_block_transitions = len(self.vae.config.block_out_channels) - 1
        self.vae_scale_factor = 2 ** num_block_transitions
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)

        self.guidance_loss_func = guidance_loss_func if guidance_loss_func else MSELoss()

        self.wandb_run = wandb_run
        self.wandb_enabled = wandb_run is not None
    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        if not isinstance(label, list):
            label = list(label)
        for l in label:
            if l not in self.labels:
                raise ValueError(f'{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}.')
        return [self.labels[l] for l in label]
    def get_layer_idx(self, layer_idx):
        num_layers = 28
        original_idx = layer_idx
        if layer_idx < 0:
            layer_idx = num_layers + layer_idx
        if layer_idx >= num_layers:
            import warnings
            warnings.warn(f'指定的层索引 {original_idx} 超出了可用层范围 (0-{num_layers - 1})。将使用最后一层 ({num_layers - 1}) 的特征。')
            layer_idx = num_layers - 1
        return layer_idx
    def prepare_image_latents(self, image, timestep, batch_size, dtype, device, generator=None):
        """Turn reference images into latents noised to ``timestep``.

        Args:
            image: A ``torch.Tensor``, ``PIL.Image.Image`` or list. Non-tensor
                inputs are converted with ``self.image_processor.preprocess``.
                A tensor whose dimension 1 has size 4 is used directly as
                latents (no VAE encode, no scaling factor).
            timestep: Timestep handed to ``self.scheduler.add_noise``.
            batch_size: Requested number of latents. Fewer latents are
                repeated to reach it.
            dtype: Target dtype for the image tensor and the noise.
            device: Target device for the image tensor and the noise.
            generator: Optional generator, or list of generators (one per
                batch element), used for VAE sampling and for the noise.

        Returns:
            The result of ``self.scheduler.add_noise`` on the latents.

        Raises:
            ValueError: If ``image`` has an unsupported type, if a generator
                list does not have ``batch_size`` entries, or if (with a
                generator list) the images cannot be tiled evenly to
                ``batch_size``.
        """
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(f'`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}')
        if not isinstance(image, torch.Tensor):
            # PIL images and lists pass the check above but have no `.to()`;
            # convert them with the pipeline's own image processor first.
            # NOTE(review): that processor is built with do_normalize=False,
            # so confirm the resulting value range is what the VAE expects.
            image = self.image_processor.preprocess(image)
        image = image.to(device=device, dtype=dtype)

        if image.shape[1] == 4:
            # Input is used as latents directly.
            init_latents = image
        else:
            if isinstance(generator, list):
                if len(generator) != batch_size:
                    raise ValueError(f'You have passed a list of generators of length {len(generator)}, but requested an effective batch size of {batch_size}. Make sure the batch size matches the length of the generators.')
                num_images = image.shape[0]
                if num_images < batch_size:
                    if batch_size % num_images != 0:
                        raise ValueError(f'Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} ')
                    image = torch.cat([image] * (batch_size // num_images), dim=0)
                # Encode one image at a time so each gets its own generator.
                per_image = [retrieve_latents(self.vae.encode(image[i:i + 1]), generator=generator[i]) for i in range(batch_size)]
                init_latents = torch.cat(per_image, dim=0)
            else:
                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
            init_latents = self.vae.config.scaling_factor * init_latents

        num_latents = init_latents.shape[0]
        if batch_size > num_latents and batch_size % num_latents == 0:
            deprecation_message = f'You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial images (`image`). Initial images are now duplicating to match the number of text prompts. Note that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update your script to pass as many initial images as text prompts to suppress this warning.'
            deprecate('len(prompt) != len(image)', '1.0.0', deprecation_message, standard_warn=False)
            init_latents = torch.cat([init_latents] * (batch_size // num_latents), dim=0)
        elif batch_size > num_latents:
            # Not evenly divisible: repeat with a ceiling count, then trim.
            repeats = (batch_size + num_latents - 1) // num_latents
            init_latents = init_latents.repeat(repeats, 1, 1, 1)[:batch_size]

        noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
        return self.scheduler.add_noise(init_latents, noise, timestep)
    @torch.no_grad()
    def __call__(self, class_labels: List[int], reference_images: PipelineImageInput=None, guidance_scale: float=4.0, guidance_step_size: float=1.0, time_travel: int=25, feature_layer_idx: int=-2, generator: Optional[Union[torch.Generator, List[torch.Generator]]]=None, num_inference_steps: int=50, output_type: Optional[str]='pil', return_dict: bool=True) -> Union[ImagePipelineOutput, Tuple]:
        """Generate class-conditional images, optionally steered by reference images.

        On every denoising step with index ``i <= time_travel`` the noise
        prediction is shifted by the gradient of ``self.guidance_loss_func``
        between a transformer hidden state of the current latents and the
        mean hidden state of the noised reference images.

        Args:
            class_labels: One class id per image to generate.
            reference_images: Reference images for feature guidance. Tensors
                are used as-is; other inputs are converted with
                ``self.image_processor.preprocess``. ``None`` disables the
                feature guidance entirely.
            guidance_scale: Values above 1 enable classifier-free guidance,
                using class id 1000 as the null class.
            guidance_step_size: Multiplier on the feature-guidance gradient.
            time_travel: Last step index (inclusive) that receives guidance.
            feature_layer_idx: Layer whose hidden state is compared; resolved
                through ``self.get_layer_idx``.
            generator: Generator(s) for the initial noise and the reference
                noising.
            num_inference_steps: Number of scheduler steps.
            output_type: ``'pil'`` converts the result to PIL images;
                anything else returns the numpy array.
            return_dict: When false, return a one-element tuple instead of an
                ``ImagePipelineOutput``.

        Returns:
            ``ImagePipelineOutput(images=...)`` or ``(images,)``.

        Raises:
            ValueError: If the feature layer index cannot be resolved.
            RuntimeError: If no reference chunk yields features on a guided
                step.
        """
        device = self._execution_device
        batch_size = len(class_labels)
        latent_size = self.transformer.config.sample_size
        latent_channels = self.transformer.config.in_channels
        use_cfg = guidance_scale > 1

        latents = randn_tensor(shape=(batch_size, latent_channels, latent_size, latent_size), generator=generator, device=device, dtype=self.transformer.dtype)
        latent_model_input = torch.cat([latents] * 2) if use_cfg else latents
        class_labels = torch.tensor(class_labels, device=device).reshape(-1)
        class_null = torch.tensor([1000] * batch_size, device=device)
        class_labels_input = torch.cat([class_labels, class_null], 0) if use_cfg else class_labels

        # Reference guidance is optional: the parameter defaults to None.
        reference = reference_images
        if reference is not None and not torch.is_tensor(reference):
            # `.split` below needs a tensor; convert PIL / list inputs.
            reference = self.image_processor.preprocess(reference)
        use_reference_guidance = reference is not None and self.guidance_loss_func is not None

        layer_index = None
        if use_reference_guidance:
            # Loop-invariant, so resolve it once up front.
            try:
                layer_index = self.get_layer_idx(feature_layer_idx)
            except Exception as err:
                raise ValueError('Hidden states out of range.') from err

        if self.wandb_enabled:
            self.metrics = {'total_loss': 0.0, 'avg_loss': 0.0, 'step_losses': []}

        self.scheduler.set_timesteps(num_inference_steps)
        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            guide_this_step = use_reference_guidance and i <= time_travel

            image_features = None
            if guide_this_step:
                # Reference images are re-noised to the current timestep on
                # every guided step.
                image_features = self._reference_features(reference, t, latents.dtype, device, generator, class_labels, layer_index)

            if use_cfg:
                # Both halves of the CFG batch must carry the same latents.
                half = latent_model_input[:len(latent_model_input) // 2]
                latent_model_input = torch.cat([half, half], dim=0)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            timesteps = self._timestep_batch(t, latent_model_input.shape[0], latent_model_input.device)

            noise_pred = self.transformer(latent_model_input, timestep=timesteps, class_labels=class_labels_input).sample
            if use_cfg:
                eps, rest = (noise_pred[:, :latent_channels], noise_pred[:, latent_channels:])
                cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
                half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
                eps = torch.cat([half_eps, half_eps], dim=0)
                noise_pred = torch.cat([eps, rest], dim=1)

            if guide_this_step:
                # The method runs under no_grad; re-enable autograd only for
                # the forward pass that the guidance gradient is taken through.
                with torch.enable_grad():
                    current_latents = latent_model_input[:batch_size].detach().clone().requires_grad_(True)
                    guided_timesteps = self._timestep_batch(t, current_latents.shape[0], current_latents.device)
                    generated_features = self.transformer(current_latents, timestep=guided_timesteps, class_labels=class_labels, return_hidden_state_layer=layer_index).hidden_states
                    loss = self.guidance_loss_func(generated_features, image_features)
                    # Single backward pass, so the graph need not be retained.
                    guide_grad = torch.autograd.grad(loss, current_latents)[0]

                if self.wandb_enabled:
                    # Logging is best-effort and must not abort sampling.
                    try:
                        loss_value = loss.item()
                        self.metrics['step_losses'].append(loss_value)
                        self.metrics['total_loss'] += loss_value
                        self.metrics['avg_loss'] = self.metrics['total_loss'] / (i + 1)
                        self.wandb_run.log({'step_loss': loss_value, 'average_loss': self.metrics['avg_loss']})
                    except Exception as e:
                        print(f'Error logging to wandb: {e}')

                if use_cfg:
                    guide_grad = torch.cat([guide_grad, guide_grad], dim=0)
                correction = torch.sqrt(1 - self.scheduler.alphas_cumprod[t]) * guidance_step_size * guide_grad
                if noise_pred.shape[1] > guide_grad.shape[1]:
                    # Extra output channels are left untouched; only the first
                    # `latent_channels` channels receive the correction.
                    noise_part, var_part = torch.split(noise_pred, latent_channels, dim=1)
                    noise_pred = torch.cat([noise_part + correction, var_part], dim=1)
                else:
                    noise_pred = noise_pred + correction
                # Release the autograd graph before the next step.
                del generated_features, image_features, loss, guide_grad

            if self.transformer.config.out_channels // 2 == latent_channels:
                model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
            else:
                model_output = noise_pred
            latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample

        if use_cfg:
            latents, _ = latent_model_input.chunk(2, dim=0)
        else:
            latents = latent_model_input
        latents = 1 / self.vae.config.scaling_factor * latents
        samples = self.vae.decode(latents).sample
        samples = (samples / 2 + 0.5).clamp(0, 1)
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
        if output_type == 'pil':
            samples = self.numpy_to_pil(samples)
        self.maybe_free_model_hooks()
        if not return_dict:
            return (samples,)
        return ImagePipelineOutput(images=samples)

    @staticmethod
    def _timestep_batch(t, count, device):
        """Return timestep ``t`` as a 1-D tensor of length ``count`` on ``device``."""
        if not torch.is_tensor(t):
            is_mps = device.type == 'mps'
            if isinstance(t, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            t = torch.tensor([t], dtype=dtype, device=device)
        else:
            t = t.to(device)
            if t.dim() == 0:
                t = t[None]
        return t.expand(count)

    def _reference_features(self, reference, t, dtype, device, generator, class_labels, layer_index):
        """Return the mean hidden state (batch dim kept) of ``reference`` noised to ``t``."""
        features = []
        # Process the references in chunks of 10.
        for chunk in reference.split(10, dim=0):
            chunk_latents = self.prepare_image_latents(chunk, t, chunk.shape[0], dtype, device, generator)
            chunk_input = self.scheduler.scale_model_input(chunk_latents, t)
            chunk_timesteps = self._timestep_batch(t, chunk_input.shape[0], chunk_input.device)
            # NOTE(review): class_labels has one entry per generated sample,
            # not per reference image - confirm the transformer accepts the
            # size mismatch when the two counts differ.
            try:
                chunk_output = self.transformer(chunk_input, timestep=chunk_timesteps, class_labels=class_labels, return_hidden_state_layer=layer_index)
                features.append(chunk_output.hidden_states)
            except Exception as e:
                # Best-effort per chunk: report and move on to the next one.
                print(f'Error processing batch: {e}')
                continue
        if not features:
            # Without this check torch.cat fails on the empty list with a
            # message that hides the real cause printed above.
            raise RuntimeError('No reference image features could be computed; every reference batch failed.')
        return torch.cat(features, dim=0).mean(0, keepdim=True)