"""
Example script for generating videos with motion-optimized CogVideoX.
This version is simplified for ease of understanding and use.
"""

import os
import argparse
import logging
import torch
from datetime import datetime
from motion_optimization_utils import (
    create_motion_optimized_pipeline,
    generate_video_with_motion_optimization,
    plot_motion_variance
)

# Configure root logging (INFO level, timestamped format) and create the
# module-level logger used throughout this script.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

def parse_args(argv=None):
    """Parse command-line options and prepare the output directories.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for programmatic use and tests).

    Returns:
        argparse.Namespace holding the parsed options. When ``--comparison``
        is given, ``optimize_motion`` is forced to True because comparison
        mode always produces a motion-optimized video.

    Side effects:
        Creates ``output_dir``; creates ``log_dir`` when motion optimization
        is enabled (explicitly or via ``--comparison``).
    """
    parser = argparse.ArgumentParser(description="Generate videos with motion-optimized CogVideoX")

    # Basic generation parameters
    parser.add_argument("--prompt", type=str, required=True,
                        help="Text prompt for video generation")
    parser.add_argument("--negative_prompt", type=str, default=None,
                        help="Negative prompt to guide generation")
    parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5b",
                        help="Path to pretrained model (e.g., THUDM/CogVideoX-5b, THUDM/CogVideoX-2b)")
    parser.add_argument("--output_dir", type=str, default="./outputs",
                        help="Directory to save generated videos")
    # CogVideoX's native resolution is 720 wide x 480 tall (landscape); the
    # previous defaults had width and height swapped.
    parser.add_argument("--width", type=int, default=720, help="Video width")
    parser.add_argument("--height", type=int, default=480, help="Video height")
    parser.add_argument("--num_frames", type=int, default=81,
                        help="Number of frames to generate (should be divisible by 4 plus 1)")
    parser.add_argument("--num_inference_steps", type=int, default=50,
                        help="Number of denoising steps")
    parser.add_argument("--guidance_scale", type=float, default=6.0,
                        help="Classifier-free guidance scale")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for generation")
    parser.add_argument("--fps", type=int, default=16,
                        help="Frames per second for saved video")

    # Motion optimization parameters
    parser.add_argument("--optimize_motion", action="store_true",
                        help="Enable motion variance optimization")
    parser.add_argument("--optimization_lr", type=float, default=0.001,
                        help="Learning rate for motion optimization")
    parser.add_argument("--optimization_steps", type=int, default=1,
                        help="Number of optimization steps per diffusion step")
    parser.add_argument("--motion_weight", type=float, default=0.1,
                        help="Weight of motion loss in optimization")
    parser.add_argument("--log_dir", type=str, default="./motion_logs",
                        help="Directory to save motion optimization logs")
    parser.add_argument("--timesteps", type=int, nargs="+",
                        help="Custom timesteps to apply optimization (if not provided, default list will be used)")
    parser.add_argument("--comparison", action="store_true",
                        help="Generate both optimized and non-optimized videos for comparison (implies --optimize_motion)")

    args = parser.parse_args(argv)

    # Comparison mode always generates an optimized video. If optimize_motion
    # stayed False, the optimized run would be routed to the plain output path
    # without a plot, and log_dir would not be created here.
    if args.comparison:
        args.optimize_motion = True

    # Validate and create output directories
    os.makedirs(args.output_dir, exist_ok=True)
    if args.optimize_motion:
        os.makedirs(args.log_dir, exist_ok=True)

    return args

def setup_output_paths(args, prompt, seed, optimized=True):
    """Compute where a generated video (and, if applicable, its plot) goes.

    The prompt is made filename-friendly by replacing spaces and slashes with
    underscores and truncating to 50 characters. A second-resolution
    timestamp is appended to the video name.

    Returns:
        (video_path, plot_path). ``plot_path`` is None unless motion
        optimization is enabled on ``args`` AND ``optimized`` is True. In that
        case the video is placed in ``<output_dir>/lr_<optimization_lr>/``
        (created here) and the plot path points into
        ``<log_dir>/<optimization_lr>/``, which this function does NOT create.
    """
    safe_prompt = prompt.replace(" ", "_").replace("/", "_")[:50]
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    if not (args.optimize_motion and optimized):
        # Plain run: video sits directly in output_dir and there is no plot.
        plain_name = f"{safe_prompt}_{seed}_{stamp}.mp4"
        return os.path.join(args.output_dir, plain_name), None

    # Optimized runs are grouped into a per-learning-rate sub-directory.
    lr_folder = os.path.join(args.output_dir, f"lr_{args.optimization_lr}")
    os.makedirs(lr_folder, exist_ok=True)

    video_file = os.path.join(lr_folder, f"{safe_prompt}_{seed}_{stamp}_optimized.mp4")
    plot_file = os.path.join(
        args.log_dir,
        f"{args.optimization_lr}",
        f"{safe_prompt}_{seed}_plot.png",
    )
    return video_file, plot_file

def _prompt_slug(prompt):
    """Return the filename-safe prompt prefix this script uses for log/plot names."""
    return prompt.replace(" ", "_").replace("/", "_")[:50]


def _build_pipeline(args, device, torch_dtype, *, optimize_motion):
    """Create a pipeline, forwarding optimization settings only when enabled."""
    options = {}
    if optimize_motion:
        options = dict(
            optimization_lr=args.optimization_lr,
            optimization_steps=args.optimization_steps,
            motion_weight=args.motion_weight,
            # No --timesteps means "let the pipeline use its default list".
            specific_timesteps=args.timesteps or None,
            log_dir=args.log_dir,
        )
    return create_motion_optimized_pipeline(
        model_path=args.model_path,
        optimize_motion=optimize_motion,
        torch_dtype=torch_dtype,
        device=device,
        **options,
    )


def _generate(pipeline, args, output_path):
    """Run one generation with the sampling settings taken from the CLI."""
    generate_video_with_motion_optimization(
        pipeline=pipeline,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        seed=args.seed,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        output_path=output_path,
        fps=args.fps,
    )


def _plot_variance_if_logged(args, plot_path):
    """Plot the motion-variance log for this prompt/seed, if one was written.

    Does nothing when ``plot_path`` is falsy (optimization disabled). The log
    file name is the one this script expects the motion-optimization
    utilities to write; presumably it matches their naming — confirm there.
    """
    if not plot_path:
        return
    log_file = os.path.join(
        args.log_dir,
        f"{args.optimization_lr}",
        f"{_prompt_slug(args.prompt)}_{args.seed}.txt",
    )
    if os.path.exists(log_file):
        plot_motion_variance(log_file, plot_path, args.seed)
    else:
        logger.warning(f"Motion log not found at {log_file}; skipping variance plot")


def main():
    """Entry point: parse CLI args and generate one or two videos.

    In comparison mode, a baseline video is generated with motion optimization
    disabled, that pipeline is released, and then a motion-optimized video is
    generated with the same prompt and seed (plus a variance plot when the
    motion log exists). Otherwise a single video is generated, optimized or
    not depending on ``--optimize_motion``.
    """
    args = parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    # Half precision is only reliably supported on CUDA; use fp32 on CPU.
    torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

    if args.comparison:
        logger.info("Comparison mode: generating both optimized and non-optimized videos")

        logger.info("Generating non-optimized video...")
        non_opt_video_path, _ = setup_output_paths(args, args.prompt, args.seed, optimized=False)
        pipeline = _build_pipeline(args, device, torch_dtype, optimize_motion=False)
        _generate(pipeline, args, non_opt_video_path)
        # Release the first pipeline before loading the second one.
        del pipeline
        torch.cuda.empty_cache()

        logger.info("Generating motion-optimized video...")
        opt_video_path, plot_path = setup_output_paths(args, args.prompt, args.seed, optimized=True)
        pipeline = _build_pipeline(args, device, torch_dtype, optimize_motion=True)
        _generate(pipeline, args, opt_video_path)
        _plot_variance_if_logged(args, plot_path)
        del pipeline
        torch.cuda.empty_cache()

        logger.info(f"Non-optimized video saved to: {non_opt_video_path}")
        logger.info(f"Optimized video saved to: {opt_video_path}")
    else:
        # Single video; plot_path is only set when optimization is enabled.
        video_path, plot_path = setup_output_paths(args, args.prompt, args.seed)
        pipeline = _build_pipeline(args, device, torch_dtype, optimize_motion=args.optimize_motion)
        _generate(pipeline, args, video_path)
        _plot_variance_if_logged(args, plot_path)
        logger.info(f"Video saved to: {video_path}")

    logger.info("Processing complete!")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()