from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from einops import rearrange

from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput
from diffusers.utils import replace_example_docstring, is_torch_xla_available
from diffusers.utils.torch_utils import randn_tensor

from lib.pipelines.pipeline_output import CosmosPipelineOutput
from lib.pipelines.pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline, \
    EXAMPLE_DOC_STRING, retrieve_timesteps, retrieve_latents
from lib.pipelines.pipeline_cosmos2_acwm import ACWMCosmos2Pipeline


# Optional PyTorch/XLA support. When torch_xla is installed, `xm` is imported and
# XLA_AVAILABLE is True; the denoising loop below then calls `xm.mark_step()` after each step.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


class CannyAugCosmos2Pipeline(ACWMCosmos2Pipeline):
    """Multi-view Cosmos2 video-to-world pipeline with an extra channel-wise condition.

    At every denoising step the transformer input is built as
    ``cat([memory_latents, noisy_future_latents], dim=frames)`` followed by
    ``cat([..., cond_to_concat], dim=channels)``. Batch entries are laid out as ``(b v)``
    (``n_view`` views per sample). Only view 0 of each sample (the "head" view) keeps its
    clean memory latents and is marked as conditioning; the memory of the other views is
    zero-filled.
    """

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: PipelineImageInput = None,
        video: List[PipelineImageInput] = None,  # input video is (b v) t c h w (t is ahead of c)
        cond_to_concat: Optional[torch.Tensor] = None,  # (b v) c t h w, latent resolution
        prompt: Union[str, List[str]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 704,
        width: int = 1280,
        num_frames: int = 93,  # equals chunk size
        num_inference_steps: int = 35,
        guidance_scale: float = 7.0,
        fps: int = 16,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
        max_sequence_length: int = 512,
        sigma_conditioning: float = 0.0001,
        n_view: int = 3,
        n_prev: int = 4,
        merge_view_into_width: bool = True,
        postprocess_video: bool = True,
        show_progress: bool = True,
        noise_to_condition_frames: float = 0.2,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
                An image used as the single memory (conditioning) frame. Takes precedence over `video`.
            video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
                The video whose first `n_prev` frames are used as memory (conditioning) frames, laid out as
                `(b v) t c h w`.
            cond_to_concat (`torch.Tensor`):
                Extra condition concatenated to the transformer input along the channel dimension. Required. Must
                be a `(b v) c t h w` tensor whose frame, height and width sizes match the latent input, i.e. `t`
                equals the number of memory latent frames plus the number of generated latent frames.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts used for the unconditional branch of classifier-free guidance.
            height (`int`, defaults to `704`):
                The height in pixels of the generated video.
            width (`int`, defaults to `1280`):
                The width in pixels of the generated video (per view).
            num_frames (`int`, defaults to `93`):
                The number of frames in the generated video, memory frames not included.
            num_inference_steps (`int`, defaults to `35`):
                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                expense of slower inference.
            guidance_scale (`float`, defaults to `7.0`):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
                of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`.
            fps (`int`, defaults to `16`):
                The frames per second of the generated video, forwarded to the transformer.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt. Only `1` is currently supported.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic. A list must have length `batch_size * n_view`.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                generation. If not provided, a latents tensor is generated by sampling using the supplied random
                `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. If not provided, negative_prompt_embeds will be generated from
                `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video, passed to `video_processor.postprocess_video`. Use
                `"latent"` to skip VAE decoding and return the denoised latents.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during the inference with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `512`):
                The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
                the prompt is shorter than this length, it will be padded.
            sigma_conditioning (`float`, defaults to `0.0001`):
                The sigma value used as the timestep of the memory (conditioning) frames. Ideally, it should not be
                changed or should be set to a small value close to zero.
            n_view (`int`, defaults to `3`):
                Number of views per sample; batch entries are laid out as `(b v)`.
            n_prev (`int`, defaults to `4`):
                Maximum number of input frames used as memory.
            merge_view_into_width (`bool`, defaults to `True`):
                After decoding, place the `n_view` views side by side along the width (`b c t h (v w)`).
            postprocess_video (`bool`, defaults to `True`):
                Whether to run `video_processor.postprocess_video` on the decoded video.
            show_progress (`bool`, defaults to `True`):
                Whether to show the denoising progress bar for this call.
            noise_to_condition_frames (`float`, defaults to `0.2`):
                Upper bound of the random attenuation applied to the head-view memory latents (see
                `prepare_latents`). `0` disables it.

        Examples:

        Returns:
            [`~CosmosPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a one-element `tuple`
                containing the generated video. With `output_type="latent"` the video is the denoised latents.
        """

        # NOTE: the upstream Cosmos pipeline refuses to run without a safety checker (NVIDIA Open Model
        # License); that check is intentionally disabled in this pipeline.

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
        elif callback_on_step_end_tensor_inputs is None:
            callback_on_step_end_tensor_inputs = ["latents"]

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
        if num_videos_per_prompt != 1:
            # TODO: only support num_videos_per_prompt=1 for now
            raise ValueError(f"Only `num_videos_per_prompt=1` is supported, got {num_videos_per_prompt}.")
        if not isinstance(cond_to_concat, torch.Tensor):
            # It is concatenated to the latent input with torch.cat at every step, so it must be a tensor.
            raise TypeError(
                "`cond_to_concat` must be a torch.Tensor of shape (b v) c t h w, "
                f"got {type(cond_to_concat).__name__}."
            )

        self._guidance_scale = guidance_scale
        self._current_timestep = None
        self._interrupt = False

        device = self._execution_device

        # 2. Define call parameters
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode input prompt
        (
            prompt_embeds,
            negative_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            device=device,
            max_sequence_length=max_sequence_length,
        )

        # 4. Prepare timesteps
        sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas)
        if self.scheduler.config.final_sigmas_type == "sigma_min":
            # Replace the last sigma (which is zero) with the minimum sigma value
            self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]

        # 5. Prepare latent variables
        vae_dtype = self.vae.dtype
        transformer_dtype = self.transformer.dtype

        if image is not None:
            # a single image becomes one memory frame: (b v) c 1 h w
            video = self.video_processor.preprocess(image, height, width).unsqueeze(2)
        else:
            # input video is (b v) t c h w, output is (b v) c t h w
            video = self.video_processor.preprocess_video(video, height, width)
        video = video.to(device=device, dtype=vae_dtype)

        num_channels_latents = self.vae.z_dim
        video = video[:, :, :n_prev]  # keep at most `n_prev` memory frames
        latents, conditioning_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask = self.prepare_latents(
            video,  # memory only! (b v) c t h w
            batch_size * n_view,
            num_channels_latents,
            height,
            width,
            num_frames,
            self.do_classifier_free_guidance,
            torch.float32,
            device,
            generator,
            latents,
            view_1_only=True,
            n_view=n_view,
            noise_to_condition_frames=noise_to_condition_frames,
        )
        # `latents` holds only the future frames; `conditioning_latents` holds the clean memory frames.

        # Number of memory latent frames prepended to `latents` in the transformer input. This can be
        # smaller than `n_prev` (e.g. 1 for image input), so the memory is stripped from the
        # transformer output with this count, not with `n_prev`.
        num_memory_frames = conditioning_latents.size(2)
        cond_to_concat = cond_to_concat.to(device=device)

        cond_mask = cond_mask.to(transformer_dtype)
        if self.do_classifier_free_guidance:
            uncond_mask = uncond_mask.to(transformer_dtype)

        padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)  # all-zero padding mask
        sigma_conditioning = torch.tensor(sigma_conditioning, dtype=torch.float32, device=device)
        # memory frames are fed to the transformer at this (near-zero) noise level
        t_conditioning = sigma_conditioning / (sigma_conditioning + 1)

        # 6. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        # The progress-bar config is stored on the pipeline instance (diffusers `set_progress_bar_config`),
        # so restore the previous one afterwards; otherwise one call with `show_progress=False` would
        # also silence every later call.
        previous_progress_bar_config = getattr(self, "_progress_bar_config", None)
        if not show_progress:
            self.set_progress_bar_config(**{**(previous_progress_bar_config or {}), "disable": True})
        try:
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    if self.interrupt:
                        continue

                    self._current_timestep = t
                    current_sigma = self.scheduler.sigmas[i]

                    current_t = current_sigma / (current_sigma + 1)
                    c_in = 1 - current_t
                    c_skip = 1 - current_t
                    c_out = -current_t
                    timestep = current_t.view(1, 1, 1, 1, 1).expand(
                        latents.size(0), -1, num_memory_frames + latents.size(2), -1, -1
                    )  # [B, 1, T, 1, 1]

                    cond_latent = latents * c_in  # all noise with a scale factor
                    # prepend the clean memory latents (frames), then append the extra condition (channels)
                    cond_latent = torch.cat([conditioning_latents, cond_latent], dim=2)
                    cond_latent = torch.cat([cond_latent, cond_to_concat], dim=1)
                    cond_latent = cond_latent.to(transformer_dtype)  # (b v) c t h w
                    cond_timestep = cond_indicator * t_conditioning + (1 - cond_indicator) * timestep
                    cond_timestep = cond_timestep.to(transformer_dtype)

                    noise_pred = self.transformer(
                        hidden_states=cond_latent,
                        timestep=cond_timestep,
                        encoder_hidden_states=prompt_embeds,
                        fps=fps,
                        condition_mask=cond_mask,
                        padding_mask=padding_mask,
                        return_dict=False,
                        n_view=n_view
                    )[0]

                    noise_pred = noise_pred[:, :, num_memory_frames:]  # remove memory
                    noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype)

                    if self.do_classifier_free_guidance:
                        uncond_latent = latents * c_in
                        # same input layout as the conditional branch, with the negative prompt
                        uncond_latent = torch.cat([conditioning_latents, uncond_latent], dim=2)  # frame
                        uncond_latent = torch.cat([uncond_latent, cond_to_concat], dim=1)  # channel
                        uncond_latent = uncond_latent.to(transformer_dtype)
                        uncond_timestep = uncond_indicator * t_conditioning + (1 - uncond_indicator) * timestep
                        uncond_timestep = uncond_timestep.to(transformer_dtype)

                        noise_pred_uncond = self.transformer(
                            hidden_states=uncond_latent,
                            timestep=uncond_timestep,
                            encoder_hidden_states=negative_prompt_embeds,
                            fps=fps,
                            condition_mask=uncond_mask,
                            padding_mask=padding_mask,
                            return_dict=False,
                            n_view=n_view
                        )[0]
                        noise_pred_uncond = noise_pred_uncond[:, :, num_memory_frames:]  # remove memory
                        noise_pred_uncond = (
                            c_skip * latents + c_out * noise_pred_uncond.float()
                        ).to(transformer_dtype)

                        noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_uncond)

                    # `noise_pred` is the preconditioned output c_skip * x + c_out * F; convert it into the
                    # model output passed to `scheduler.step`.
                    noise_pred = (latents - noise_pred) / current_sigma
                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                        latents = callback_outputs.pop("latents", latents)
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                    # advance the progress bar
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()

                    if XLA_AVAILABLE:
                        xm.mark_step()
        finally:
            if not show_progress:
                self.set_progress_bar_config(**(previous_progress_bar_config or {}))

        self._current_timestep = None

        if output_type != "latent":
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean)
                .view(1, self.vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = (
                torch.tensor(self.vae.config.latents_std)
                .view(1, self.vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean  # config.sigma_data=1.0

            video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
            if merge_view_into_width:
                # views side by side along the width; must be (v w), not (w v)
                video = rearrange(video, '(b v) c t h w -> b c t h (v w)', v=n_view)

            # NOTE: the upstream Cosmos safety check of the decoded video is disabled in this pipeline.
            if postprocess_video:
                video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return CosmosPipelineOutput(frames=video)

    def prepare_latents(
        self,
        video: torch.Tensor,  # (b v) c t h w
        batch_size: int,  # (b v)
        num_channels_latents: int,
        height: int = 704,
        width: int = 1280,
        num_frames: int = 93,  # memory not included
        do_classifier_free_guidance: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        view_1_only: bool = True,
        n_view: int = 3,
        noise_to_condition_frames: float = 0.2,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]
    ]:
        """Encode the memory frames and prepare the noisy future latents plus conditioning masks.

        Args:
            video: Memory frames, `(b v) c t h w`, already on the VAE device/dtype.
            batch_size: Effective batch size `b * v`.
            num_channels_latents: Number of latent channels of the generated latents.
            height, width: Pixel size; divided by `vae_scale_factor_spatial` for the latent size.
            num_frames: Number of generated pixel frames, memory not included.
            do_classifier_free_guidance: Whether to also build the unconditional mask/indicator.
            dtype, device: Dtype/device of the returned tensors.
            generator: Generator(s) for VAE sampling and initial noise; a list must have length `batch_size`.
            latents: Optional pre-generated noise for the future frames (scaled by `sigma_max` here).
            view_1_only: Keep the memory of view 0 only; the memory of other views is zero-filled and not
                marked as conditioning.
            n_view: Number of views per sample.
            noise_to_condition_frames: Upper bound of the random attenuation of the head-view memory; `0`
                disables it.

        Returns:
            `(latents, init_latents, cond_indicator, uncond_indicator, conditioning_mask, uncond_mask)`, where
            `latents` holds only the future frames, `init_latents` the normalized memory latents, and the
            `uncond_*` entries are `None` unless `do_classifier_free_guidance` is set.

        Raises:
            ValueError: If a list of generators does not have length `batch_size`.
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # Encode memory frames one at a time; the code below assumes one latent frame per memory frame.
        num_cond_latent_frames = video.size(2)
        init_latents = [
            retrieve_latents(self.vae.encode(video[:, :, it].unsqueeze(2)), generator)
            for it in range(num_cond_latent_frames)
        ]
        init_latents = torch.cat(init_latents, dim=2).to(dtype)

        latents_mean = (
            torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
        )
        latents_std = (
            torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
        )
        init_latents = (init_latents - latents_mean) / latents_std * self.scheduler.config.sigma_data  # sigma_data=1.0

        if view_1_only:
            # Keep the memory of view 0 (the head view) only; the other views' memory is filled
            # with zeros and ignored completely (rather than filled with noise).
            init_latents = rearrange(init_latents, '(b v) c m h w -> b v c m h w', v=n_view)
            head_view_mask = init_latents.new_zeros(
                (init_latents.shape[0], init_latents.shape[1], 1, num_cond_latent_frames, 1, 1)
            )
            head_view_mask[:, :1] = 1.0
            init_latents = rearrange(init_latents, 'b v c m h w -> (b v) c m h w')
            head_view_mask = rearrange(head_view_mask, 'b v c m h w -> (b v) c m h w')
            if noise_to_condition_frames > 0:
                # Scale each sample's memory by a factor in (1 - noise_to_condition_frames, 1] that ramps
                # linearly across the memory frames between two random values.
                # NOTE(review): drawn from the global torch RNG, not from `generator`, so seeded calls are
                # only reproducible if the global RNG is seeded as well.
                ramp_a = torch.rand(batch_size) * noise_to_condition_frames
                ramp_b = torch.rand(batch_size) * noise_to_condition_frames
                ramp_start, ramp_end = torch.minimum(ramp_a, ramp_b), torch.maximum(ramp_a, ramp_b)
                attenuation = torch.stack(
                    [
                        torch.linspace(ramp_start[idx], ramp_end[idx], num_cond_latent_frames)
                        for idx in range(batch_size)
                    ],
                    dim=0,
                )
                attenuation = attenuation.reshape(batch_size, 1, num_cond_latent_frames, 1, 1).to(
                    dtype=dtype, device=device
                )
                head_view_mask = head_view_mask * (1.0 - attenuation)
            init_latents = init_latents * head_view_mask

        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        latent_height = height // self.vae_scale_factor_spatial
        latent_width = width // self.vae_scale_factor_spatial
        shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)

        latents = latents * self.scheduler.config.sigma_max  # sigma_max = 80.0

        total_latent_frames = num_cond_latent_frames + num_latent_frames
        # pixel-resolution mask passed to the transformer as `condition_mask`
        conditioning_mask = self._build_memory_mask(
            latents,
            (batch_size, 1, total_latent_frames, latent_height, latent_width),
            num_cond_latent_frames,
            view_1_only,
            n_view,
        )
        # same layout, broadcast over h/w; used to pick the timestep of each frame
        cond_indicator = self._build_memory_mask(
            latents,
            (batch_size, 1, total_latent_frames, 1, 1),
            num_cond_latent_frames,
            view_1_only,
            n_view,
        )

        uncond_indicator = uncond_mask = None
        if do_classifier_free_guidance:
            # the unconditional branch uses the same memory layout as the conditional one
            uncond_mask = conditioning_mask.clone()
            uncond_indicator = cond_indicator.clone()

        # latents here only contains future
        return latents, init_latents, cond_indicator, uncond_indicator, conditioning_mask, uncond_mask

    @staticmethod
    def _build_memory_mask(
        template: torch.Tensor,
        shape: Tuple[int, ...],
        num_memory_frames: int,
        head_view_only: bool,
        n_view: int,
    ) -> torch.Tensor:
        """Return a `(b v) c f h w` mask that is 1 on the first `num_memory_frames` frames and 0 elsewhere.

        The mask uses `template`'s device and dtype. With `head_view_only`, it is also zeroed for every
        view except view 0 of each sample.
        """
        mask = template.new_zeros(shape)
        mask[:, :, :num_memory_frames] = 1.0
        if head_view_only:
            mask = rearrange(mask, '(b v) c f h w -> b v c f h w', v=n_view)
            mask[:, 1:] = 0.0
            mask = rearrange(mask, 'b v c f h w -> (b v) c f h w')
        return mask
