import inspect, copy
from typing import Any, Callable, Dict, List, Optional, Union
import torch, einops
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    DDIMInverseScheduler,
    PNDMScheduler,
)
from diffusers import LattePipeline
from diffusers.pipelines.latte.pipeline_latte import retrieve_timesteps


class LattePipeline_GN(LattePipeline):
    """Latte pipeline that refines the initial latents before regular sampling.

    ``__call__`` first runs ``recall_timesteps`` refinement rounds on the
    starting latents.  Each round performs one denoising step of a copy of the
    pipeline scheduler at the first timestep of a ``pre_num_inference_steps``
    schedule, followed by one ``DDIMInverseScheduler`` step back to that
    timestep.  The refined latents are then passed to
    ``LattePipeline.__call__`` as its ``latents`` argument.

    The refinement is configured through instance attributes set in
    ``__init__`` (see the comments there); change them on the instance before
    calling the pipeline.
    """

    def __init__(self,
        tokenizer,
        text_encoder,
        vae,
        transformer,
        scheduler,
    ):
        """Build the underlying ``LattePipeline`` and set refinement defaults."""
        super().__init__(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
        )

        # Number of denoise-then-invert refinement rounds.
        self.recall_timesteps = 1
        # Number of perturbed copies refined per round; 1 disables perturbation.
        self.ensemble = 1
        # Scale applied to the perturbation added to each ensemble member.
        self.ensemble_rate = 0.1
        # Length of the timestep schedule used only for the refinement rounds.
        self.pre_num_inference_steps = 50
        # True: members are chained and the last one is kept.
        # False: members are independent and their results are averaged.
        self.fast_ensemble = False
        # Fast mode only: perturbation is scaled by (1 - momentum) ** (j + 1).
        self.momentum = 0.
        # Fast mode only: mixing weight between the last result and earlier latents.
        self.traj_momentum = 0.05
        # True: one standard-normal draw is added to both guidance scales.
        self.ensemble_guidance_scale = False
        # "uniform", "truncated_gaussian"; any other value means plain Gaussian.
        self.noise_type = "uniform"

    @staticmethod
    def _gn_broadcast_timestep(timestep, latent_model_input):
        """Return ``timestep`` as a 1-D tensor with one entry per batch element."""
        device = latent_model_input.device
        if not torch.is_tensor(timestep):
            # mps does not support 64-bit dtypes, hence the 32-bit fallbacks.
            is_mps = device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timestep = torch.tensor([timestep], dtype=dtype, device=device)
        elif timestep.dim() == 0:
            timestep = timestep[None].to(device)
        # Broadcast to the batch dimension in a way that's compatible with ONNX/Core ML.
        return timestep.expand(latent_model_input.shape[0])

    def _gn_draw_guidance_scales(self, guidance_scale):
        """Return ``(denoise_scale, inversion_scale)`` for one round trip."""
        if not self.ensemble_guidance_scale:
            return guidance_scale, 1.0
        jitter = torch.randn(1).item()
        return guidance_scale + jitter, 1.0 + jitter

    def _gn_sample_perturbation(self, latents):
        """Draw a perturbation shaped like ``latents`` according to ``noise_type``."""
        if self.noise_type == "uniform":
            # Uniform on [-sqrt(3), sqrt(3)], i.e. zero mean and unit variance.
            return (torch.rand_like(latents) - 0.5) * 2 * (3 ** (1/2))
        if self.noise_type == "truncated_gaussian":
            return torch.randn_like(latents).clip_(-1, 1)
        return torch.randn_like(latents)

    def _gn_guided_noise_pred(self, sample, timestep, prompt_embeds, scale,
                              do_classifier_free_guidance, scheduler,
                              enable_temporal_attentions):
        """Run the transformer on ``sample`` at ``timestep`` and apply guidance.

        ``prompt_embeds`` must already hold the negative embeddings followed by
        the positive ones when ``do_classifier_free_guidance`` is true.
        """
        model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
        model_input = scheduler.scale_model_input(model_input, timestep)

        noise_pred = self.transformer(
            model_input,
            timestep=self._gn_broadcast_timestep(timestep, model_input),
            encoder_hidden_states=prompt_embeds,
            enable_temporal_attentions=enable_temporal_attentions,
        ).sample

        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond)

        # NOTE(review): Latte checkpoints appear to predict a learned variance
        # in extra output channels, which the stock sampling loop discards
        # unless the scheduler consumes it -- confirm against the parent
        # pipeline.  The split only happens when the channel count is exactly
        # double, so models with matching channels are untouched.
        learned_variance = (
            hasattr(scheduler.config, "variance_type")
            and scheduler.config.variance_type in ["learned", "learned_range"]
        )
        if not learned_variance and noise_pred.shape[1] == 2 * sample.shape[1]:
            noise_pred = noise_pred.chunk(2, dim=1)[0]
        return noise_pred

    @torch.no_grad()
    def __call__(self,
        prompt: Union[str, List[str]] = None,
        video_length: Optional[int] = 16,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_videos_per_prompt: Optional[int] = 1,
        decode_chunk_size = None,
        *args,**kwargs):
        """Refine the initial latents, then run ``LattePipeline.__call__``.

        Args:
            prompt: Prompt(s) used for the refinement and, unless ``prompt_2``
                is given, for the final sampling.
            video_length: Number of frames to generate.
            height, width: Output size; default to the transformer sample size
                times the VAE scale factor.
            num_inference_steps: Steps of the final sampling (the refinement
                uses ``self.pre_num_inference_steps`` instead).
            guidance_scale: Guidance scale; values > 1 enable classifier-free
                guidance.
            negative_prompt: Negative prompt(s).  ``None`` is treated as ``""``
                unless ``negative_prompt_embeds`` is supplied.
            num_videos_per_prompt: Samples per prompt; used when
                ``num_images_per_prompt`` is not given in ``kwargs``.
            decode_chunk_size: Forwarded to the parent pipeline when not None.
            *args: Accepted for signature compatibility and ignored; they
                cannot be forwarded positionally without colliding with the
                keyword arguments below.
            **kwargs: Read for ``prompt_embeds``, ``negative_prompt_embeds``,
                ``num_images_per_prompt``, ``clean_caption``, ``mask_feature``,
                ``generator``, ``latents``, ``eta``,
                ``enable_temporal_attentions`` and ``prompt_2`` (prompt for the
                final sampling only).  Everything except ``prompt_2`` is also
                forwarded to the parent pipeline.

        Returns:
            Whatever ``LattePipeline.__call__`` returns.

        Raises:
            ValueError: If neither a ``str``/``list`` prompt nor
                ``prompt_embeds`` is provided.
        """
        num_frames = video_length
        _height = height or self.transformer.config.sample_size * self.vae_scale_factor
        _width = width or self.transformer.config.sample_size * self.vae_scale_factor
        self._guidance_scale = guidance_scale

        prompt_embeds_in = kwargs.get("prompt_embeds", None)
        negative_prompt_embeds_in = kwargs.get("negative_prompt_embeds", None)
        enable_temporal_attentions = kwargs.get("enable_temporal_attentions", True)
        generator = kwargs.get("generator", None)

        num_per_prompt = kwargs.get("num_images_per_prompt", None)
        if num_per_prompt is None:
            num_per_prompt = 1 if num_videos_per_prompt is None else num_videos_per_prompt

        # The parent pipeline's default negative prompt is the empty string.
        # It must stay None when embeddings are supplied instead.
        if negative_prompt is None and negative_prompt_embeds_in is None:
            negative_prompt = ""

        if isinstance(prompt, str):
            _batch_size = 1
        elif isinstance(prompt, list):
            _batch_size = len(prompt)
        elif prompt_embeds_in is not None:
            _batch_size = prompt_embeds_in.shape[0]
        else:
            raise ValueError(
                "Either `prompt` (str or list of str) or `prompt_embeds` must be provided."
            )
        _device = self._execution_device
        _do_classifier_free_guidance = guidance_scale > 1

        _prompt_embeds, _negative_prompt_embeds = self.encode_prompt(
            prompt,
            _do_classifier_free_guidance,
            negative_prompt,
            num_images_per_prompt=num_per_prompt,
            device=_device,
            prompt_embeds=prompt_embeds_in,
            negative_prompt_embeds=negative_prompt_embeds_in,
            clean_caption=kwargs.get("clean_caption", True),
            mask_feature=kwargs.get("mask_feature", True),
        )

        if _do_classifier_free_guidance:
            _prompt_embeds = torch.cat([_negative_prompt_embeds, _prompt_embeds], dim=0)

        # Refinement runs on a private copy so the pipeline scheduler's state is
        # untouched.  It always uses its own evenly spaced schedule; custom
        # `timesteps` in kwargs only reach the final sampling call.
        _emergency_scheduler = copy.deepcopy(self.scheduler)
        _timesteps, _num_inference_steps = retrieve_timesteps(
            _emergency_scheduler, self.pre_num_inference_steps, _device
        )
        _num_channels_latents = self.transformer.config.in_channels

        _latents = self.prepare_latents(
            _batch_size * num_per_prompt,
            _num_channels_latents,
            num_frames,
            _height,
            _width,
            _prompt_embeds.dtype,
            _device,
            generator,
            kwargs.get("latents", None),
        )

        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, kwargs.get("eta", 0.))
        # NOTE(review): the inverse scheduler is configured from an SDXL
        # checkpoint rather than from this pipeline's scheduler config --
        # confirm the two noise schedules are meant to differ.
        _inverse_scheduler = DDIMInverseScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                                    subfolder='scheduler')
        _inverse_scheduler.set_timesteps(_num_inference_steps, device=_device)
        _lasting_t = _timesteps[0]
        # Timestep one schedule stride below the first one.
        _prev_lasting_t = _timesteps[0] - _emergency_scheduler.config.num_train_timesteps // _emergency_scheduler.num_inference_steps

        def denoise_then_invert(sample, positive_scale, negative_scale):
            """One denoising step from `_lasting_t`, then one inversion step back."""
            noise_pred = self._gn_guided_noise_pred(
                sample, _lasting_t, _prompt_embeds, positive_scale,
                _do_classifier_free_guidance, _emergency_scheduler,
                enable_temporal_attentions,
            )
            sample = _emergency_scheduler.step(
                noise_pred, _lasting_t, sample, **extra_step_kwargs
            ).prev_sample

            noise_pred = self._gn_guided_noise_pred(
                sample, _prev_lasting_t, _prompt_embeds, negative_scale,
                _do_classifier_free_guidance, _emergency_scheduler,
                enable_temporal_attentions,
            )
            # The inverse scheduler is given the timestep being inverted *to*.
            return _inverse_scheduler.step(noise_pred, _lasting_t, sample, return_dict=False)[0]

        for _ in range(self.recall_timesteps):
            if self.ensemble == 1:
                positive_scale, negative_scale = self._gn_draw_guidance_scales(guidance_scale)
                _latents = denoise_then_invert(_latents, positive_scale, negative_scale)
                continue

            results = []
            _prev_latents = _latents.clone()
            _prev_prev_latents = _latents.clone()

            for j in range(self.ensemble):
                additional_noise = self._gn_sample_perturbation(_latents)
                if self.fast_ensemble:
                    _prev_prev_latents = _prev_latents
                    _prev_latents = _latents
                    if results:
                        # Blend the newest result with the two previous starting points.
                        _latents = (
                            results[-1] * (1 - self.traj_momentum)
                            + self.traj_momentum * (1 - self.traj_momentum) * _prev_latents
                            + self.traj_momentum * self.traj_momentum * _prev_prev_latents
                        )
                    _latents_n = additional_noise * self.ensemble_rate * (1 - self.momentum) ** (j + 1) + _latents
                else:
                    _latents_n = additional_noise * self.ensemble_rate + _latents

                positive_scale, negative_scale = self._gn_draw_guidance_scales(guidance_scale)
                results.append(denoise_then_invert(_latents_n, positive_scale, negative_scale))

            if self.fast_ensemble:
                _latents = results[-1]
            else:
                _latents = torch.stack(results, 0).mean(0)

        print("Successfully Generate Optimium Noise")
        if kwargs.get("prompt_2", None) is not None:
            prompt = kwargs.get("prompt_2", None) # Second Prompt

        # Forward the caller's kwargs, letting the values computed here win, so
        # no keyword is passed twice.  `prompt_2` is consumed by this class and
        # is not forwarded.
        forwarded = {key: value for key, value in kwargs.items() if key != "prompt_2"}
        forwarded.update(
            latents=_latents,
            prompt=prompt,
            video_length=num_frames,
            height=_height,
            width=_width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_per_prompt,
            enable_temporal_attentions=enable_temporal_attentions,
        )
        if decode_chunk_size is not None:
            forwarded["decode_chunk_size"] = decode_chunk_size
        return super().__call__(**forwarded)
        
        
