"""
Motion-optimized scheduler for CogVideoX that extends CogVideoXDPMScheduler.
This implementation adds motion variance optimization capabilities to the diffusers scheduler.
"""

import gc
import logging
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from diffusers import CogVideoXDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from diffusers.schedulers.scheduling_utils import SchedulerOutput

logger = logging.getLogger(__name__)

class MotionOptimizedCogVideoXScheduler(CogVideoXDPMScheduler):
    """
    Extends CogVideoXDPMScheduler with motion variance optimization capabilities.

    ``step`` first runs the parent's denoising update. Then, at the timesteps listed
    in ``specific_timesteps``, it refines the resulting latents with a few Adam steps
    that minimize the maximum variance of frame-to-frame differences
    (see ``motion_loss``).
    """

    def __init__(
        self,
        *args,
        optimize_motion: bool = False,
        optimization_lr: float = 0.001,
        optimization_steps: int = 1,
        motion_weight: float = 0.1,
        specific_timesteps: Optional[List[int]] = None,
        log_file: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the scheduler with motion optimization parameters.

        Positional arguments and any other keyword arguments are forwarded unchanged
        to ``CogVideoXDPMScheduler.__init__``.

        Args:
            optimize_motion: Whether to apply motion optimization at all.
            optimization_lr: Adam learning rate used when optimizing the latents.
            optimization_steps: Number of Adam steps per optimized diffusion step.
            motion_weight: Intended weight of the motion loss. NOTE: stored but not
                currently applied; the optimization objective is the motion loss alone.
            specific_timesteps: Timestep values at which to optimize. ``None`` is
                treated as an empty list (never optimize).
            log_file: Optional path; optimization metrics are appended to it when set.
        """
        super().__init__(*args, **kwargs)

        # Motion optimization parameters.
        self.optimize_motion = optimize_motion
        self.optimization_lr = optimization_lr
        self.optimization_steps = optimization_steps
        self.motion_weight = motion_weight
        self.specific_timesteps = specific_timesteps or []
        self.log_file = log_file

        # Transformer reference; optimization is skipped until one is registered.
        self.transformer = None

        # Prompt info, set via register_prompt_info().
        self.prompt = None
        self.seed = None
        self.prompt_embeds = None

        # Input `sample` of the most recent step() call.
        self.prev_latents = None
        # Not populated by any method in this class.
        self.pred_original_samples = {}

    def register_transformer(self, transformer):
        """Register the transformer model; optimization is skipped while this is None."""
        self.transformer = transformer

    def register_prompt_info(self, prompt, seed, prompt_embeds=None):
        """Store the prompt, seed and (optionally) prompt embeddings on the scheduler."""
        self.prompt = prompt
        self.seed = seed
        self.prompt_embeds = prompt_embeds

    @staticmethod
    def _timestep_value(timestep: Union[int, float, torch.Tensor]) -> Union[int, float]:
        """Return ``timestep`` as a Python number, accepting plain numbers or 1-element tensors."""
        if isinstance(timestep, torch.Tensor):
            return timestep.item()
        return timestep

    @staticmethod
    def _batch_mean(values: torch.Tensor) -> float:
        """Reduce a per-batch metric tensor ``[B]`` to one float (its mean) for logging."""
        return values.float().mean().item()

    def _append_to_log(self, text: str) -> None:
        """Append ``text`` to ``self.log_file``; no-op when no log file is configured."""
        if self.log_file:
            with open(self.log_file, "a", encoding="utf-8") as f:
                f.write(text)

    def calculate_motion_variance(self, latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate the motion variance between consecutive frames in latent space.

        Args:
            latents: Latent representation with shape [batch_size, frames, channels, height, width]
                    or [frames, channels, height, width] (a batch dimension is then added).

        Returns:
            motion_max_variance: Per-batch [B] maximum (over H*W) of the channel-averaged
                variance across frames of |frame[t] - frame[t-1]|.
            appearance_max_variance: Per-batch [B] maximum (over H*W) of the
                channel-averaged variance of the latents across frames.

        Raises:
            ValueError: If ``latents`` is not 4-D/5-D or has fewer than 2 frames.
        """
        if latents.dim() == 4:  # [frames, channels, height, width]
            latents = latents.unsqueeze(0)  # Add batch dimension
        if latents.dim() != 5:
            raise ValueError(
                f"Expected 4-D or 5-D latents, got shape {tuple(latents.shape)}"
            )

        # For CogVideoX, the latent format is [batch_size, frames, channels, height, width]
        batch_size, frames = latents.shape[:2]
        if frames < 2:
            raise ValueError(
                f"Motion variance needs at least 2 frames, got {frames}"
            )

        # Absolute differences between consecutive frames: [B, F-1, C, H, W].
        motion_frames = (latents[:, 1:] - latents[:, :-1]).abs()

        # Population variance over the frame axis: [B, C, H, W].
        motion_variance = torch.var(motion_frames, dim=1, unbiased=False)
        appearance_variance = torch.var(latents, dim=1, unbiased=False)

        # Average over channels -> [B, H, W], then take the spatial maximum -> [B].
        # reshape (not view): reduction outputs are not guaranteed to be contiguous.
        motion_max_variance = motion_variance.mean(dim=1).reshape(batch_size, -1).max(dim=1).values
        appearance_max_variance = appearance_variance.mean(dim=1).reshape(batch_size, -1).max(dim=1).values

        return motion_max_variance, appearance_max_variance

    def motion_loss(self, latents: torch.Tensor) -> torch.Tensor:
        """
        Calculate motion loss for optimization.

        Args:
            latents: Latent representation (see ``calculate_motion_variance``).

        Returns:
            Scalar loss: the batch mean of the maximum motion variance.
        """
        motion_max_variance, _ = self.calculate_motion_variance(latents)
        return motion_max_variance.mean()

    def optimize_latents(
        self,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        prompt_embeds: torch.Tensor,
        old_pred_original_sample: Optional[torch.Tensor] = None,
        timestep_back: Optional[int] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        Optimize latents to reduce motion variance.

        Runs ``self.optimization_steps`` Adam steps on a detached copy of ``latents``,
        minimizing ``motion_loss`` only. The transformer must be registered, but it is
        not invoked here. ``prompt_embeds``, ``old_pred_original_sample``,
        ``timestep_back``, ``image_rotary_emb`` and ``attention_kwargs`` are accepted
        for interface compatibility and currently unused.

        Args:
            latents: Current latent representation.
            timestep: Current diffusion timestep (int or 1-element tensor); used for logging.
            prompt_embeds: Prompt embeddings (unused).
            old_pred_original_sample: Previous predicted original sample (unused).
            timestep_back: Previous timestep (unused).
            image_rotary_emb: Rotary positional embeddings (unused).
            attention_kwargs: Additional attention kwargs (unused).

        Returns:
            Optimized latents, detached from the autograd graph. If no transformer is
            registered, ``latents`` is returned unchanged.
        """
        if self.transformer is None:
            logger.warning("Transformer not registered. Cannot perform optimization.")
            return latents

        # Optimize a detached copy so the caller's tensor is never modified in place.
        optimizable_latents = latents.detach().clone()
        optimizable_latents.requires_grad_(True)
        optimizer = torch.optim.Adam([optimizable_latents], lr=self.optimization_lr)

        # Initial metrics are only needed for the log file.
        if self.log_file:
            with torch.no_grad():
                motion_variance, appearance_variance = self.calculate_motion_variance(optimizable_latents)
            self._append_to_log(
                f"In timestep: {self._timestep_value(timestep)}\n"
                f"Initial motion_variance: {self._batch_mean(motion_variance):.6f}\n"
                f"Initial motion_appearance_variance: {self._batch_mean(appearance_variance):.6f}\n"
            )

        # The objective is the motion loss alone; motion_weight and a
        # content-preservation term are not part of it.
        total_loss = None
        for _ in range(self.optimization_steps):
            optimizer.zero_grad()
            total_loss = self.motion_loss(optimizable_latents)
            total_loss.backward()
            optimizer.step()

        if total_loss is not None:
            with torch.no_grad():
                motion_variance, appearance_variance = self.calculate_motion_variance(optimizable_latents)
                # MSE between the optimized and input latents, measured after the final update.
                preservation_loss = F.mse_loss(optimizable_latents, latents)
            logger.info(
                "Motion optimization step %d/%d: motion variance: %.6f, appearance variance: %.6f",
                self.optimization_steps,
                self.optimization_steps,
                self._batch_mean(motion_variance),
                self._batch_mean(appearance_variance),
            )
            self._append_to_log(
                f"Final motion_variance: {self._batch_mean(motion_variance):.6f}\n"
                f"Final motion_appearance_variance: {self._batch_mean(appearance_variance):.6f}\n"
                f"preservation_loss: {preservation_loss.item():.6f}\n"
                f"total_loss: {total_loss.item():.6f}\n"
                "\n"
            )

        # Release optimizer state (Adam moments) and cached GPU memory.
        del optimizer
        torch.cuda.empty_cache()
        gc.collect()

        return optimizable_latents.detach()

    def step(
        self,
        model_output: torch.Tensor,
        old_pred_original_sample: torch.Tensor,
        timestep: int,
        timestep_back: int,
        sample: torch.Tensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        # Additional parameters for optimization
        prompt_embeds: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Run the CogVideoXDPMScheduler step, then optionally apply motion optimization.

        Optimization runs only when all of these hold:
        - ``optimize_motion`` is set;
        - the timestep is in ``specific_timesteps`` and greater than 0;
        - ``prompt_embeds`` is given;
        - a transformer is registered.

        Returns:
            ``(prev_sample, pred_original_sample)`` when ``return_dict`` is False,
            otherwise a ``DDIMSchedulerOutput`` with both fields.
        """
        should_optimize = False
        t_value = None
        if self.optimize_motion:
            t_value = self._timestep_value(timestep)
            should_optimize = (
                t_value in self.specific_timesteps
                and t_value > 0  # Don't optimize the last step
                and prompt_embeds is not None
                and self.transformer is not None
            )

        # Standard denoising step; with return_dict=False the parent returns a tuple.
        prev_sample, pred_original_sample = super().step(
            model_output=model_output,
            old_pred_original_sample=old_pred_original_sample,
            timestep=timestep,
            timestep_back=timestep_back,
            sample=sample,
            eta=eta,
            use_clipped_model_output=use_clipped_model_output,
            generator=generator,
            variance_noise=variance_noise,
            return_dict=False,
        )

        # Keep the input of this step for later inspection.
        self.prev_latents = sample

        if should_optimize:
            logger.info("Applying motion optimization at timestep %s", t_value)
            prev_sample = self.optimize_latents(
                prev_sample,
                timestep,
                prompt_embeds,
                old_pred_original_sample,
                timestep_back,
                image_rotary_emb,
                attention_kwargs,
            )

        if not return_dict:
            return (prev_sample, pred_original_sample)

        # SchedulerOutput only declares `prev_sample`; DDIMSchedulerOutput carries both.
        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)