from typing import Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from diffusers.image_processor import PipelineImageInput
from diffusers.models import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from diffusers.utils import BaseOutput, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    StableVideoDiffusionPipeline,
    StableVideoDiffusionPipelineOutput,
    _resize_with_antialiasing,
    retrieve_timesteps,
)

from model.track_unet import TrackUnet
from model.video_gen_unet import VideoGenUnet
import copy

logger = logging.get_logger(__name__)

def _append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]
    

class DualUnetVideoDiffusionPipeline(StableVideoDiffusionPipeline):
    r"""
    Pipeline for video generation using a dual-UNet architecture.

    This pipeline uses two UNets:
    1. TrackUnet: processes `condition_video` once to extract tracking features, which are then
       injected as additional down/mid-block residuals into the generation UNet.
    2. VideoGenUnet: generates the final video from `condition_image`, conditioned on those
       tracking features.

    This model inherits from [`StableVideoDiffusionPipeline`]. Check the superclass documentation
    for the generic methods implemented for all pipelines.

    Args:
        vae ([`AutoencoderKLTemporalDecoder`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder.
        condition_track_unet ([`TrackUnet`]):
            A `TrackUnet` to extract tracking information from condition video.
        unet ([`VideoGenUnet`]):
            A `VideoGenUnet` to generate final video using tracking information.
        scheduler ([`EulerDiscreteScheduler`]):
            A scheduler to be used in combination with unet to denoise the encoded image latents.
            A deep copy of it is used to derive the timestep for the single track-UNet pass.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images.
    """

    model_cpu_offload_seq = "image_encoder->condition_track_unet->unet->vae"
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        vae: AutoencoderKLTemporalDecoder,
        image_encoder: CLIPVisionModelWithProjection,
        condition_track_unet: TrackUnet,
        unet: VideoGenUnet,
        scheduler: EulerDiscreteScheduler,
        feature_extractor: CLIPImageProcessor,
    ):
        # Intentionally skip the parent's __init__ and initialize DiffusionPipeline directly;
        # every component (including `condition_track_unet`) is registered below.
        from diffusers.pipelines.pipeline_utils import DiffusionPipeline

        DiffusionPipeline.__init__(self)

        # Register all components (a former separate `noise_track_scheduler` was removed).
        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            condition_track_unet=condition_track_unet,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        # VAE spatial downsampling factor and the video pre/post-processor.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)

    @torch.inference_mode()
    def encode_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ) -> torch.Tensor:
        """
        Encode video frames to CLIP image embeddings, `chunk_size` frames at a time.
        Copied from DepthCrafterPipeline.

        :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
        :param chunk_size: the chunk size to encode video
        :return: image_embeddings in shape of [b, 1024]
        """
        video_224 = _resize_with_antialiasing(video.float(), (224, 224))
        video_224 = (video_224 + 1.0) / 2.0  # [-1, 1] -> [0, 1]

        embeddings = []
        for i in range(0, video_224.shape[0], chunk_size):
            # Inputs are already resized and in [0, 1]; only CLIP normalization is applied here.
            tmp = self.feature_extractor(
                images=video_224[i : i + chunk_size],
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values.to(video.device, dtype=video.dtype)
            embeddings.append(self.image_encoder(tmp).image_embeds)  # [b, 1024]

        embeddings = torch.cat(embeddings, dim=0)  # [t, 1024]
        return embeddings

    @torch.inference_mode()
    def encode_vae_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ):
        """
        Encode video frames with the VAE, `chunk_size` frames at a time, using the mode of the
        latent distribution.
        Copied from DepthCrafterPipeline.

        :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
        :param chunk_size: the chunk size to encode video
        :return: vae latents in shape of [b, c, h, w]
        """
        video_latents = []
        for i in range(0, video.shape[0], chunk_size):
            video_latents.append(
                self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
            )
        video_latents = torch.cat(video_latents, dim=0)
        return video_latents

    def check_inputs(self, condition_video, condition_image, height, width, num_frames=None):
        """
        Validate the inputs of `__call__`.

        Args:
            condition_video: Condition video as `np.ndarray` or `torch.Tensor`, frames on axis 0.
            condition_image: Condition image as `torch.Tensor`, `PIL.Image.Image` or a list.
            height: Output height in pixels; must be divisible by 8.
            width: Output width in pixels; must be divisible by 8.
            num_frames: Optional requested number of output frames. When given, it must equal the
                number of frames in `condition_video`, because the track-UNet input is built by
                concatenating per-output-frame noise with the condition-video latents.

        Raises:
            ValueError: If any of the checks fails.
        """
        # Check condition_video
        if not isinstance(condition_video, torch.Tensor) and not isinstance(condition_video, np.ndarray):
            raise ValueError(
                f"Expected `condition_video` to be a `torch.Tensor` or `np.ndarray`, but got a {type(condition_video)}"
            )

        # Check condition_image
        if (
            not isinstance(condition_image, torch.Tensor)
            and not isinstance(condition_image, PIL.Image.Image)
            and not isinstance(condition_image, list)
        ):
            raise ValueError(
                "`condition_image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
                f" {type(condition_image)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if num_frames is not None and num_frames != condition_video.shape[0]:
            raise ValueError(
                f"`num_frames` ({num_frames}) must match the number of frames in `condition_video` "
                f"({condition_video.shape[0]})."
            )

    @torch.no_grad()
    def __call__(
        self,
        condition_video: Union[np.ndarray, torch.Tensor],
        condition_image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 25,
        sigmas: Optional[List[float]] = None,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        return_dict: bool = True,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            condition_video (`np.ndarray` or `torch.Tensor`):
                Condition video frames. If np.ndarray, expected shape is [t, h, w, c]; if
                torch.Tensor, expected shape is [t, c, h, w]. Pixel values are expected in
                [0, 255]: they are divided by 255 and mapped to [-1, 1]. The spatial size should
                match `height` x `width`, since its VAE latents are concatenated with latents of
                that size.
            condition_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
                Condition image to guide video generation. If you provide a tensor, the expected value range is between `[0, 1]`.
                Only a single image is supported (see `num_videos_per_prompt`).
            height (`int`, *optional*, defaults to 576):
                The height in pixels of the generated video.
            width (`int`, *optional*, defaults to 1024):
                The width in pixels of the generated video.
            num_frames (`int`, *optional*):
                The number of video frames to generate. Defaults to, and must equal, the number
                of frames in `condition_video`.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps of the generation UNet.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process.
            min_guidance_scale (`float`, *optional*, defaults to 1.0):
                The guidance scale for the first frame (linearly interpolated across frames).
            max_guidance_scale (`float`, *optional*, defaults to 3.0):
                The guidance scale for the last frame (linearly interpolated across frames).
            fps (`int`, *optional*, defaults to 7):
                Currently ignored: both UNets are conditioned on a fixed fps value of 7.
            motion_bucket_id (`int`, *optional*, defaults to 127):
                Used for conditioning the amount of motion for the generation.
            noise_aug_strength (`float`, *optional*, defaults to 0.02):
                The amount of noise added to the condition image.
            decode_chunk_size (`int`, *optional*):
                The number of frames to encode/decode at a time. Defaults to `num_frames`.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt. Together with the condition-image
                batch size, the product must currently be 1.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A torch generator for reproducible generation. A single generator on a different
                device is replaced by a new one on the execution device with the same initial seed.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video; `"latent"` skips VAE decoding.
            callback_on_step_end (`Callable`, *optional*):
                A function called at the end of each denoising step as
                `callback_on_step_end(pipeline, step_index, timestep, callback_kwargs)`; it must
                return a dict, from which an updated `"latents"` entry is taken if present.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the callback function.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a StableVideoDiffusionPipelineOutput.

        Returns:
            [`StableVideoDiffusionPipelineOutput`] or the frames:
                If `return_dict` is `True`, [`StableVideoDiffusionPipelineOutput`] is returned,
                otherwise the generated frames are returned directly.

        Raises:
            ValueError: If the inputs fail validation, including a `num_frames` that differs from
                the condition-video length or more than one requested output video.
        """

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs first, so that a wrong `condition_video` type is reported clearly
        # before `.shape` is accessed below.
        self.check_inputs(condition_video, condition_image, height, width, num_frames=num_frames)

        # Frames are on axis 0 for both the ndarray and the tensor layout.
        num_frames = num_frames if num_frames is not None else condition_video.shape[0]
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        # 2. Define call parameters
        if isinstance(condition_image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(condition_image, list):
            batch_size = len(condition_image)
        else:
            batch_size = condition_image.shape[0]

        # The condition video is encoded once with batch size 1 and concatenated with the track
        # noise of batch `batch_size * num_videos_per_prompt` (step 10), so any other batch size
        # cannot work; fail here instead of after all encoders have run.
        if batch_size * num_videos_per_prompt != 1:
            raise ValueError(
                "Only a single output video per call is supported, but got "
                f"batch_size={batch_size} (from `condition_image`) and "
                f"num_videos_per_prompt={num_videos_per_prompt}."
            )

        device = self._execution_device
        # Provisional scalar guidance scale; presumably read by the inherited
        # `do_classifier_free_guidance` property. Replaced by a per-frame tensor in step 9.
        self._guidance_scale = max_guidance_scale

        # Move a single generator to the execution device, re-seeded with its initial seed.
        if generator is not None and hasattr(generator, "device"):
            if str(generator.device) != str(device):
                generator = torch.Generator(device=device).manual_seed(generator.initial_seed())

        # 3. Encode condition_image with CLIP
        image_embeddings = self._encode_image(condition_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)

        # 4. Encode condition_video: [0, 255] -> [-1, 1], channels-first
        if isinstance(condition_video, np.ndarray):
            condition_video = torch.from_numpy(condition_video.transpose(0, 3, 1, 2))
        condition_video = condition_video.to(device=device, dtype=self.dtype)
        condition_video = condition_video / 255.0
        condition_video = condition_video * 2.0 - 1.0  # [0,1] -> [-1,1]
        condition_video_embeddings = self.encode_video(
            condition_video, chunk_size=decode_chunk_size
        ).unsqueeze(0)  # [1, t, 1024]

        # Encode condition_video using the VAE (upcast to fp32 if the VAE requires it)
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        condition_video_latents = self.encode_vae_video(
            condition_video.to(self.vae.dtype),
            chunk_size=decode_chunk_size,
        ).unsqueeze(0)  # [1, t, c, h, w]
        condition_video_latents = condition_video_latents.to(image_embeddings.dtype)

        # 5. Encode noise-augmented condition_image using the VAE
        condition_image = self.video_processor.preprocess(condition_image, height=height, width=width).to(device)
        noise = randn_tensor(condition_image.shape, generator=generator, device=device, dtype=condition_image.dtype)
        condition_image = condition_image + noise_aug_strength * noise

        condition_image_latents = self._encode_vae_image(
            condition_image,
            device=device,
            num_videos_per_prompt=num_videos_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
        )
        condition_image_latents = condition_image_latents.to(image_embeddings.dtype)

        # Cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # Repeat the image latents for each frame
        condition_image_latents = condition_image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)

        # 6. Get added time ids.
        # NOTE: both UNets are conditioned on a fixed fps value of 7, so the `fps` argument is
        # currently ignored (upstream SVD conditions on `fps - 1`).
        # Track UNet: no noise augmentation and no CFG duplication.
        added_time_ids1 = self._get_add_time_ids(
            7,
            motion_bucket_id,
            0.0,
            image_embeddings.dtype,
            batch_size,
            num_videos_per_prompt,
            False,
        )
        added_time_ids1 = added_time_ids1.to(device)

        # Generation UNet: real noise augmentation strength, CFG-duplicated when enabled.
        added_time_ids2 = self._get_add_time_ids(
            7,
            motion_bucket_id,
            noise_aug_strength,
            image_embeddings.dtype,
            batch_size,
            num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )
        added_time_ids2 = added_time_ids2.to(device)

        # 7. Prepare timesteps. The track UNet runs once, at the first timestep of a separate
        # 25-step schedule computed on a copy of the scheduler (so the main scheduler is untouched).
        scheduler_track = copy.deepcopy(self.scheduler)
        track_timesteps, _ = retrieve_timesteps(scheduler_track, 25, device, None, None)
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)

        # 8. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels

        # Track noise, used only for the single track-UNet pass.
        track_noise_shape = (
            batch_size * num_videos_per_prompt,
            num_frames,
            # The track UNet input is noise concatenated with condition-video latents on the
            # channel axis, so subtract the condition-video channel count.
            self.condition_track_unet.config.in_channels - condition_video_latents.shape[2],
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        track_noise = randn_tensor(track_noise_shape, generator=generator, device=device, dtype=image_embeddings.dtype)
        track_noise = track_noise * scheduler_track.init_noise_sigma

        # Main latents for video generation (uses `latents` if provided)
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 9. Per-frame guidance scale, linearly interpolated from min to max
        guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
        guidance_scale = guidance_scale.to(device, latents.dtype)
        guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
        guidance_scale = _append_dims(guidance_scale, latents.ndim)

        self._guidance_scale = guidance_scale

        # 10. Run the track UNet once to get the tracking condition
        # (already under the @torch.no_grad() decorator of this method).
        track_timestep = track_timesteps[0]
        track_noise = scheduler_track.scale_model_input(track_noise, track_timestep)
        track_input = torch.cat([track_noise, condition_video_latents], dim=2)

        track_output = self.condition_track_unet(
            track_input,
            track_timestep,
            encoder_hidden_states=condition_video_embeddings,
            added_time_ids=added_time_ids1,
            return_dict=True,
        )

        # Scale the track features before injecting them: the first six down-block residuals are
        # zeroed, the next six are multiplied by `feature_scale`; any further residuals are passed
        # through unscaled. The mid-block residual is zeroed.
        track_condition_down_res_raw = track_output.down_block_res_samples
        track_condition_mid_res_raw = track_output.mid_block_sample
        feature_scale = 1.5
        down_res_scales = [0] * 6 + [feature_scale] * 6
        if len(track_condition_down_res_raw) < len(down_res_scales):
            raise ValueError(
                f"`condition_track_unet` returned {len(track_condition_down_res_raw)} down-block residuals, "
                f"but at least {len(down_res_scales)} are required."
            )
        track_condition_down_res = list(track_condition_down_res_raw)
        for idx, scale in enumerate(down_res_scales):
            track_condition_down_res[idx] = track_condition_down_res_raw[idx] * scale
        # NOTE(review): `mid_block_sample` is indexed with [0]; confirm against TrackUnet's output
        # type (if it is a plain tensor, this drops its leading batch dimension).
        track_condition_mid_res = track_condition_mid_res_raw[0] * 0
        if self.do_classifier_free_guidance:
            # Same tracking condition for the unconditional and conditional halves of the batch.
            track_condition_down_res = tuple(torch.cat([res, res]) for res in track_condition_down_res)
            track_condition_mid_res = torch.cat([track_condition_mid_res, track_condition_mid_res])

        # 11. Denoising loop (main UNet only; the tracking condition is fixed)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Expand latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Main UNet forward pass with the pre-computed tracking condition
                unet_input = torch.cat([latent_model_input, condition_image_latents], dim=2)

                noise_pred = self.unet(
                    unet_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids2,
                    down_block_additional_residuals=track_condition_down_res,
                    mid_block_additional_residual=track_condition_mid_res,
                    return_dict=False,
                )[0]

                # Perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # Update latents using the main scheduler
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        if k == "latents":
                            callback_kwargs[k] = latents
                        else:
                            callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # 12. Decode latents
        if not output_type == "latent":
            if needs_upcasting:
                # Redundant with the cast back after encoding above, but harmless.
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents, num_frames, decode_chunk_size)
            frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
        else:
            frames = latents

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames

        return StableVideoDiffusionPipelineOutput(frames=frames)

    @classmethod
    def from_pretrained_with_two_unets(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        variant: Optional[str] = None,
        **kwargs
    ):
        """
        Load the pipeline from a base SVD checkpoint plus the two custom UNets.

        Args:
            pretrained_model_name_or_path: Path to the base SVD model. The `vae`, `image_encoder`,
                `feature_extractor` and `scheduler` are loaded from its subfolders.
            torch_dtype: Data type for the model components (also applied to the whole pipeline).
            variant: Model variant (e.g., "fp16").
            **kwargs:
                track_unet_path (str, required): Path to the TrackUnet weights.
                video_gen_unet_path (str, optional): Path to the VideoGenUnet weights. Defaults to
                    `pretrained_model_name_or_path`, in which case its `unet` subfolder is used.
                Any remaining keyword arguments are forwarded to the `from_pretrained` calls of
                the VAE, the image encoder and both UNets.

        Returns:
            The assembled pipeline.

        Raises:
            ValueError: If `track_unet_path` is not provided.
        """
        # Custom UNet paths are taken from kwargs; validate before loading anything.
        track_unet_path = kwargs.pop("track_unet_path", None)
        video_gen_unet_path = kwargs.pop("video_gen_unet_path", pretrained_model_name_or_path)

        if track_unet_path is None:
            raise ValueError("track_unet_path must be provided")

        # Shared loading kwargs for the model components. Leftover kwargs are forwarded here
        # (the pipeline constructor accepts no extra arguments).
        loading_kwargs = dict(kwargs)
        if torch_dtype is not None:
            loading_kwargs["torch_dtype"] = torch_dtype
        if variant is not None:
            loading_kwargs["variant"] = variant

        # Base components from the SVD checkpoint
        vae = AutoencoderKLTemporalDecoder.from_pretrained(pretrained_model_name_or_path, subfolder="vae", **loading_kwargs)
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", **loading_kwargs)
        feature_extractor = CLIPImageProcessor.from_pretrained(pretrained_model_name_or_path, subfolder="feature_extractor")
        scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")

        # Custom UNets
        condition_track_unet = TrackUnet.from_pretrained(track_unet_path, **loading_kwargs)
        unet = VideoGenUnet.from_pretrained(
            video_gen_unet_path,
            subfolder="unet" if video_gen_unet_path == pretrained_model_name_or_path else None,
            **loading_kwargs
        )

        pipeline = cls(
            vae=vae,
            image_encoder=image_encoder,
            condition_track_unet=condition_track_unet,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        # Set dtype for the entire pipeline if specified
        if torch_dtype is not None:
            pipeline = pipeline.to(torch_dtype)

        return pipeline
