"""
Utility functions to add motion optimization to the CogVideoX pipeline.
"""

import os
import logging
from typing import List, Optional, Dict, Any, Tuple
import torch
from diffusers import CogVideoXPipeline
from motion_optimized_cogvideox_scheduler import MotionOptimizedCogVideoXScheduler

logger = logging.getLogger(__name__)

def create_motion_optimized_pipeline(
    model_path: str,
    optimize_motion: bool = True,
    optimization_lr: float = 0.001,
    optimization_steps: int = 1,
    motion_weight: float = 0.1,
    specific_timesteps: Optional[List[int]] = None,
    log_dir: str = "./motion_logs",
    torch_dtype: torch.dtype = torch.float16,
    device: Optional[torch.device] = None,
):
    """
    Create a CogVideoX pipeline with motion-optimized scheduler.

    The pipeline's scheduler is replaced by a ``MotionOptimizedCogVideoXScheduler``
    built from the original scheduler's config, and the pipeline's transformer is
    registered with it. Two helper methods are attached to the returned pipeline:
    ``_extract_pipeline_params`` and ``prepare_optimization_logging``.

    When ``optimize_motion`` is true, calling the pipeline is additionally wrapped
    so that, for the duration of each call, a forward hook on the transformer
    records the prompt embeddings (plus ``image_rotary_emb`` / ``attention_kwargs``
    when present) and forwards them as keyword arguments to ``scheduler.step``.

    Args:
        model_path: Path to the pretrained model
        optimize_motion: Whether to apply motion optimization
        optimization_lr: Learning rate for optimization
        optimization_steps: Number of optimization steps per diffusion step
        motion_weight: Weight of motion loss in optimization
        specific_timesteps: List of timesteps to apply optimization at (if None, use defaults)
        log_dir: Directory to save motion optimization logs
        torch_dtype: Data type for model components
        device: Device to load model on

    Returns:
        CogVideoX pipeline with motion-optimized scheduler. When ``optimize_motion``
        is true the object's class is a dynamically created subclass of the
        original pipeline class (same name), so ``isinstance`` checks still hold.
    """
    import functools

    # Create the standard pipeline
    pipeline = CogVideoXPipeline.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
    )

    if device is not None:
        pipeline = pipeline.to(device)

    # Use default timesteps if none provided
    if specific_timesteps is None:
        specific_timesteps = [999, 995, 991, 987, 982, 978, 973, 968, 963, 957, 952, 946]

    # Create and configure motion-optimized scheduler
    motion_scheduler = MotionOptimizedCogVideoXScheduler.from_config(
        pipeline.scheduler.config,
        optimize_motion=optimize_motion,
        optimization_lr=optimization_lr,
        optimization_steps=optimization_steps,
        motion_weight=motion_weight,
        specific_timesteps=specific_timesteps,
    )

    # Register the transformer with the scheduler
    motion_scheduler.register_transformer(pipeline.transformer)

    # Replace the standard scheduler with our motion-optimized one
    pipeline.scheduler = motion_scheduler

    logger.info(f"Created motion-optimized pipeline with lr={optimization_lr}, steps={optimization_steps}")
    if optimize_motion:
        logger.info(f"Motion optimization will be applied at timesteps: {specific_timesteps}")

    # Add a method to the pipeline to make it easier to get needed state from it
    def _extract_pipeline_params(self, latent_model_input, t, prompt_embeds, attention_kwargs=None, image_rotary_emb=None):
        """Extract parameters needed for optimization from pipeline"""
        return {
            'prompt_embeds': prompt_embeds,
            'image_rotary_emb': image_rotary_emb,
            'attention_kwargs': attention_kwargs,
        }

    # Ordinary (non-dunder) methods can be bound per instance.
    pipeline._extract_pipeline_params = _extract_pipeline_params.__get__(pipeline)

    def prepare_optimization_logging(self, prompt, seed, prompt_embeds=None):
        """
        Point the scheduler at a per-prompt log file and register prompt info.

        Returns the log file path, or None when motion optimization is disabled
        on the scheduler.
        """
        if not self.scheduler.optimize_motion:
            return

        # Logs are grouped in a sub-directory named after the learning rate;
        # makedirs creates log_dir itself as a parent if needed.
        lr_dir = os.path.join(log_dir, f"{self.scheduler.optimization_lr}")
        os.makedirs(lr_dir, exist_ok=True)

        # Format the prompt for filename
        formatted_prompt = prompt.replace(" ", "_").replace("/", "_")[:50]
        log_file = os.path.join(lr_dir, f"{formatted_prompt}_{seed}.txt")

        # Set log file in scheduler
        self.scheduler.log_file = log_file
        self.scheduler.register_prompt_info(prompt, seed, prompt_embeds)

        logger.info(f"Motion optimization logs will be saved to: {log_file}")
        return log_file

    pipeline.prepare_optimization_logging = prepare_optimization_logging.__get__(pipeline)

    # Take the *unbound* __call__ from the class so that passing ``self``
    # explicitly below does not supply the pipeline twice.
    base_cls = type(pipeline)
    original_call = base_cls.__call__

    def motion_optimized_call(self, *args, **kwargs):
        """
        Run the original pipeline call while feeding transformer inputs to
        ``scheduler.step`` as extra keyword arguments.
        """
        # Set up logging first, before anything is patched, so a failure here
        # cannot leave a hook or a replaced step method behind.
        prompt = kwargs.get('prompt')
        if isinstance(prompt, str):
            seed = None
            generator = kwargs.get('generator')
            if isinstance(generator, torch.Generator):
                seed = generator.initial_seed()
            elif isinstance(generator, list) and len(generator) > 0:
                seed = generator[0].initial_seed()

            if seed is not None:
                self.prepare_optimization_logging(prompt, seed)

        original_step = self.scheduler.step

        # functools.wraps sets __wrapped__, so inspect.signature() on the
        # interceptor still reports the real step() parameters. diffusers
        # pipelines inspect scheduler.step to decide which optional arguments
        # (e.g. eta, generator) to pass.
        @functools.wraps(original_step)
        def step_interceptor(*step_args, **step_kwargs):
            """Intercept step calls to add motion optimization parameters"""
            step_kwargs.update(getattr(self, "_current_model_params", {}))
            return original_step(*step_args, **step_kwargs)

        def capture_params_hook(module, hook_args, hook_kwargs, output):
            """Record the transformer inputs of the most recent forward pass."""
            encoder_hidden_states = hook_kwargs.get('encoder_hidden_states')
            extras = hook_kwargs

            # Positional fallback kept from the earlier implementation, which
            # assumed (latents, timestep, encoder_hidden_states[, kwargs dict]).
            # NOTE(review): confirm this layout against the transformer's
            # forward signature if positional calls are ever used.
            if encoder_hidden_states is None and len(hook_args) >= 3:
                encoder_hidden_states = hook_args[2]
                if len(hook_args) >= 4 and isinstance(hook_args[3], dict):
                    extras = hook_args[3]

            if encoder_hidden_states is None:
                return

            captured = {'prompt_embeds': encoder_hidden_states}
            for key in ('image_rotary_emb', 'attention_kwargs'):
                if key in extras:
                    captured[key] = extras[key]
            self._current_model_params = captured

        # Track the current model parameters - these will be updated during denoising
        self._current_model_params = {}
        self.scheduler.step = step_interceptor
        hook_handle = None
        try:
            # with_kwargs=True (PyTorch >= 2.0) is required for the hook to see
            # keyword arguments; a plain forward hook only receives positionals.
            hook_handle = self.transformer.register_forward_hook(
                capture_params_hook, with_kwargs=True
            )
            return original_call(self, *args, **kwargs)
        finally:
            # Remove the hook and restore the original step method
            if hook_handle is not None:
                hook_handle.remove()
            self.scheduler.step = original_step
            if hasattr(self, "_current_model_params"):
                del self._current_model_params

    # ``obj(...)`` looks up __call__ on type(obj), not on the instance, so an
    # instance-level assignment would never be invoked. Swap in a subclass that
    # overrides __call__ instead; the original class name is kept so anything
    # that reads ``type(pipeline).__name__`` sees the same value.
    if optimize_motion:
        pipeline.__class__ = type(
            base_cls.__name__,
            (base_cls,),
            {
                "__call__": motion_optimized_call,
                "__module__": base_cls.__module__,
                "__qualname__": base_cls.__qualname__,
                "__doc__": base_cls.__doc__,
            },
        )

    return pipeline

def generate_video_with_motion_optimization(
    pipeline,
    prompt: str,
    seed: int = 42,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_frames: int = 81,
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    negative_prompt: Optional[str] = None,
    output_path: Optional[str] = None,
    fps: int = 16,
    **additional_kwargs,
):
    """
    Run the pipeline once with a seeded generator and optionally export the result.

    Args:
        pipeline: Callable pipeline exposing a ``device`` attribute; its result
            must expose ``frames``
        prompt: Text prompt for generation
        seed: Seed for the ``torch.Generator`` created on the pipeline's device
        height: Video height, forwarded unchanged (None is passed through)
        width: Video width, forwarded unchanged (None is passed through)
        num_frames: Number of frames to generate
        num_inference_steps: Number of denoising steps
        guidance_scale: Classifier-free guidance scale
        negative_prompt: Negative prompt for guidance
        output_path: Where to write the video; nothing is written when None
        fps: Frames per second used when exporting
        **additional_kwargs: Extra keyword arguments forwarded to the pipeline

    Returns:
        The ``frames`` attribute of the pipeline output
    """
    # Seeded generator on the pipeline's device for reproducibility
    rng = torch.Generator(device=pipeline.device).manual_seed(seed)

    logger.info(f"Generating video with prompt: '{prompt}'")
    logger.info(f"Using seed: {seed}")

    # Named arguments are gathered first; unpacking both mappings in the call
    # still raises TypeError if additional_kwargs repeats one of these keys.
    generation_args = dict(
        prompt=prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
        generator=rng,
    )
    result = pipeline(**generation_args, **additional_kwargs)

    # Nothing to export: hand the frames straight back
    if output_path is None:
        return result.frames

    # Imported lazily so the export helper is only needed when saving
    from diffusers.utils import export_to_video

    first_video = result.frames[0]
    export_to_video(first_video, output_path, fps=fps)
    logger.info(f"Video saved to {output_path}")

    return result.frames

def plot_motion_variance(log_file, output_image_path, seed=None):
    """
    Generate a plot of motion variance optimization from a log file.

    The log is scanned line by line for these prefixes: ``In timestep:``,
    ``Initial motion_variance:``, ``Initial motion_appearance_variance:``,
    ``Final motion_variance:`` and ``Final motion_appearance_variance:``.
    Values are paired by order of appearance. If the log is incomplete (for
    example a run that stopped part-way through), only the leading records for
    which every value is present are plotted and a warning is logged.

    Args:
        log_file: Path to the log file
        output_image_path: Path to save the plot
        seed: Seed value to display in the plot

    Returns:
        The matplotlib figure that was saved. It is created through pyplot and
        is not closed here.
    """
    import matplotlib.pyplot as plt
    import pandas as pd

    variance_columns = [
        "Initial Motion Variance",
        "Final Motion Variance",
        "Initial Appearance Variance",
        "Final Appearance Variance",
    ]
    # Log-line prefix -> column that receives the parsed float.
    prefix_to_column = {
        "Initial motion_variance:": "Initial Motion Variance",
        "Initial motion_appearance_variance:": "Initial Appearance Variance",
        "Final motion_variance:": "Final Motion Variance",
        "Final motion_appearance_variance:": "Final Appearance Variance",
    }
    columns = {"Timestep": []}
    columns.update({name: [] for name in variance_columns})

    # Read the log file
    with open(log_file, "r") as file:
        for raw_line in file:
            line = raw_line.strip()
            if line.startswith("In timestep:"):
                columns["Timestep"].append(int(line.split(": ")[1]))
                continue
            for prefix, column in prefix_to_column.items():
                if line.startswith(prefix):
                    columns[column].append(float(line.split(": ")[1]))
                    break

    # pandas refuses columns of unequal length, so trim everything to the
    # number of complete records instead of failing on a truncated log.
    lengths = [len(values) for values in columns.values()]
    complete = min(lengths)
    if max(lengths) != complete:
        logger.warning(
            "Log %s has incomplete entries; plotting only the first %d complete records",
            log_file,
            complete,
        )
        columns = {name: values[:complete] for name, values in columns.items()}

    # Create DataFrame for plotting; the cast keeps the variance columns
    # numeric even when no records were found.
    df = pd.DataFrame(columns).astype({name: "float64" for name in variance_columns})

    # Sort by timestep (descending)
    df = df.sort_values("Timestep", ascending=False)

    # Create plot
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

    # Motion variance plot
    ax1.plot(df["Timestep"], df["Initial Motion Variance"], label="Before Optimization",
             marker="o", linestyle="-", color="blue")
    ax1.plot(df["Timestep"], df["Final Motion Variance"], label="After Optimization",
             marker="s", linestyle="-", color="green")
    ax1.set_title("Motion Variance Over Time")
    ax1.set_ylabel("Variance")
    ax1.grid(True)
    ax1.legend()

    # Calculate and display reduction percentage. Rows whose initial variance
    # is zero are masked to NaN so they are skipped by mean() rather than
    # turning the average into inf/NaN.
    initial = df["Initial Motion Variance"]
    reduction_pct = (initial - df["Final Motion Variance"]) / initial.where(initial != 0) * 100
    avg_reduction = reduction_pct.mean()
    ax1.text(0.02, 0.95, f"Avg reduction: {avg_reduction:.2f}%",
             transform=ax1.transAxes, fontsize=10,
             bbox=dict(facecolor='white', alpha=0.7))

    # Appearance variance plot
    ax2.plot(df["Timestep"], df["Initial Appearance Variance"], label="Before Optimization",
             marker="o", linestyle="-", color="purple")
    ax2.plot(df["Timestep"], df["Final Appearance Variance"], label="After Optimization",
             marker="s", linestyle="-", color="red")
    ax2.set_title("Appearance Variance Over Time")
    ax2.set_xlabel("Timestep")
    ax2.set_ylabel("Variance")
    ax2.grid(True)
    ax2.legend()

    # Add seed info if provided
    if seed is not None:
        fig.text(0.15, 0.95, f"Seed: {seed}", fontsize=12,
                 bbox=dict(facecolor="white", alpha=0.7))

    # Adjust layout and save through the figure object, so the result does not
    # depend on which figure pyplot currently considers active.
    fig.tight_layout()
    fig.savefig(output_image_path, dpi=300, bbox_inches="tight")
    logger.info(f"Variance plot saved to {output_image_path}")

    return fig