import torch
import torch.nn.functional as F

from einops import rearrange
import math
from typing import Any, Dict, List, Optional, Union

class PyramidDiTForVideoGeneration:
    def __init__(
        self,
        model_path,
        model_dtype='bf16',
        use_gradient_checkpointing=False,
        return_log=True,
        model_variant="diffusion_transformer_768p",
        timestep_shift=1.0,
        stage_range=[0, 1/3, 2/3, 1],
        sample_ratios=[1, 1, 1],
        scheduler_gamma=1/3,
        use_mixed_training=False,
        use_flash_attn=False,
        load_text_encoder=True,
        load_vae=True,
        max_temporal_length=31,
        frame_per_unit=1,
        use_temporal_causal=True,
        corrupt_ratio=1/3,
        interp_condition_pos=True,
        stages=[1, 2, 4],
        video_sync_group=8,
        gradient_checkpointing_ratio=0.6,
        audio_model_path=None,
        use_audio=False,
        use_audio_joint_attn=False,
        load_dit=True,
        load_audio_encoder=False,
        init_additional_args: Dict[str, Any] = {},
        use_motion_loss: bool = False,
        **kwargs,
    ):
        """Initialise the object and record whether the motion loss is enabled.

        Only ``use_motion_loss`` is stored by this constructor body. Every
        other parameter, and any extra keyword captured by ``**kwargs``, is
        accepted but not read here.

        NOTE(review): ``stage_range``, ``sample_ratios``, ``stages`` and
        ``init_additional_args`` have mutable defaults. They are never touched
        in this body, so nothing is shared today; if they are ever stored or
        modified, copy them first.
        """
        super().__init__()

        # The only option kept as instance state by this constructor.
        self.use_motion_loss = use_motion_loss
        
    def calculate_loss(self, model_preds_list, targets_list, gt_motion_list=None):
        if gt_motion_list:
            base_loss_list = []
            motion_loss_list = []
        else:
            base_loss_list = None
            motion_loss_list = None
        loss_list = []
    
        for idx, (model_pred, target) in enumerate(zip(model_preds_list, targets_list)):
            # Compute the loss.
            loss_weight = torch.ones_like(target)
            base_loss = torch.mean( 
                (loss_weight.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                1,
            ) 
            if gt_motion_list:
                base_loss_list.append(base_loss)
                
            # Compute the motion loss.
            if gt_motion_list is not None:
                gt_motion = gt_motion_list[idx]
                motion_term = self.motion_loss(model_pred, target, gt_motion)
                motion_loss_list.append(motion_term)
                total_loss = base_loss + motion_term
                loss_list.append(total_loss)
            else:
                loss_list.append(base_loss)                     

        diffusion_loss = torch.cat(loss_list, dim=0).mean()
        return diffusion_loss, {}, loss_list, base_loss_list, motion_loss_list
        
    def motion_loss(self, model_pred, target, gt_motion, lambda_weight=1.0):
        pred_error = torch.abs(model_pred.float() - target.float())
        motion_term = lambda_weight * torch.mean((pred_error * gt_motion.float()) ** 2)
        
        return motion_term

    @torch.no_grad()
    def generate_one_unit(
        self,
        latents,
        past_conditions, # List of past conditions, contains the conditions of each stage
        prompt_embeds,
        prompt_attention_mask,
        pooled_prompt_embeds,
        num_inference_steps,
        height,
        width,
        temp,
        device,
        dtype,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        is_first_frame: bool = False,
        audio: Optional[torch.Tensor] = None,
        audio_temperature: float = 1.0,
        additional_args: Optional[Dict[str, Any]] = {},
    ):
        stages = self.stages
        intermed_latents = []

        for i_s in range(len(stages)):
            self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device)
            timesteps = self.scheduler.timesteps

            if i_s > 0:
                height *= 2; width *= 2
                latents = rearrange(latents, 'b c t h w -> (b t) c h w')
                latents = F.interpolate(latents, size=(height, width), mode='nearest')
                latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
                # Fix the stage
                ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s]  
                gamma = self.scheduler.config.gamma
                alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
                beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)

                bs, ch, temp, height, width = latents.shape
                noise = self.sample_block_noise(bs, ch, temp, height, width)
                noise = noise.to(device=device, dtype=dtype)
                latents = alpha * latents + beta * noise   

            for idx, t in enumerate(timesteps):
                if self.do_classifier_free_guidance and not self.do_audio_skip_guidance:
                    latent_model_input = torch.cat([latents] * 2)
                elif self.do_classifier_free_guidance and self.do_audio_skip_guidance:
                    latent_model_input = torch.cat([latents] * 3)
                else:
                    latent_model_input = latents
                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

                latent_model_input = past_conditions[i_s] + [latent_model_input]

                noise_pred = self.dit(
                    sample=[latent_model_input], 
                    timestep_ratio=timestep,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    pooled_projections=pooled_prompt_embeds,
                    audio=audio, 
                    audio_temperature=audio_temperature,
                    additional_args=additional_args,
                )

                noise_pred = noise_pred[0]
                
                # perform guidance
                if self.do_classifier_free_guidance and not self.do_audio_skip_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    if is_first_frame:
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                    else:
                        noise_pred = noise_pred_uncond + self.video_guidance_scale * (noise_pred_text - noise_pred_uncond)
                elif self.do_classifier_free_guidance and self.do_audio_skip_guidance:
                    noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
                    if is_first_frame:
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                    else:
                        noise_pred = noise_pred_uncond + self.video_guidance_scale * (noise_pred_text - noise_pred_uncond) \
                            + self.audio_guidance_scale * (noise_pred_text - noise_pred_perturb)
                
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    model_output=noise_pred,
                    timestep=timestep,
                    sample=latents,
                    generator=generator,
                ).prev_sample

            intermed_latents.append(latents)

        return intermed_latents