#!/usr/bin/env python3
"""
LatentSync batch inference script - evaluate lip-sync quality

Measure the synchronization between lip movements and audio in videos

Usage example:
python batch_inference_lipsync.py \
    --video_dir /path/to/videos \
    --output_file results_lipsync.json \
    --device cuda
"""

import argparse
import json
import os
import shutil
import warnings
from datetime import datetime
from pathlib import Path
from statistics import fmean
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from tqdm import tqdm

from eval.syncnet import SyncNetEval
from eval.syncnet_detect import SyncNetDetector
from latentsync.utils.util import red_text

# Process-wide: suppresses every warning category from every module (torch,
# numpy, the eval package, ...), not just this script -- presumably to keep
# the tqdm progress output readable.
warnings.filterwarnings(action='ignore')


def find_video_files(video_dir: str, extensions: Optional[List[str]] = None) -> List[str]:
    """Return the sorted paths of video files located directly in *video_dir*.

    Extension matching is case-insensitive for any casing (``a.mp4``,
    ``b.MP4`` and ``c.Mp4`` all match ``.mp4``). Only regular files are
    returned, and subdirectories are not searched.

    Args:
        video_dir: Directory to scan (non-recursively).
        extensions: File-name endings to accept, normally including the
            leading dot. Defaults to common video container extensions.

    Returns:
        Sorted list of matching file paths as strings. The list is empty when
        *video_dir* does not exist or is not a directory.
    """
    if extensions is None:
        extensions = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.m4v']

    directory = Path(video_dir)
    if not directory.is_dir():
        return []

    # Lower-case both sides so every casing matches, not only all-lower and
    # all-upper. str.endswith accepts a tuple, so one check covers all endings.
    endings = tuple(ext.lower() for ext in extensions)
    return sorted(
        str(entry)
        for entry in directory.iterdir()
        if entry.is_file() and entry.name.lower().endswith(endings)
    )


def evaluate_single_video(
    video_path: str,
    syncnet: SyncNetEval,
    syncnet_detector: SyncNetDetector,
    temp_dir: str = "temp_lipsync",
    detect_results_dir: str = "detect_results_lipsync"
) -> Dict[str, Any]:
    """Evaluate lip-sync for a single video and return a result record.

    Runs ``syncnet_detector`` on *video_path*. It then scores every file found
    under ``<detect_results_dir>/crop`` with ``syncnet.evaluate`` and averages
    the per-clip AV offsets and confidences.

    Args:
        video_path: Path of the video to evaluate.
        syncnet: Loaded SyncNet evaluator.
        syncnet_detector: Face detector/cropper, called once per video.
        temp_dir: Scratch directory passed to ``syncnet.evaluate``.
        detect_results_dir: Directory whose ``crop`` subfolder is read. This
            is expected to be where ``syncnet_detector`` writes its output
            (batch_inference builds both from the same value).

    Returns:
        Dict with the keys ``video_path``, ``video_name``, ``success``,
        ``error``, ``sync_confidence``, ``av_offset`` and ``num_faces``
        (the number of cropped clips scored). On failure, ``success`` is False,
        ``error`` describes the exception, and the metric fields stay None.
        ``Exception`` subclasses are never propagated.
    """
    result: Dict[str, Any] = {
        'video_path': video_path,
        'video_name': Path(video_path).name,
        'success': False,
        'error': None,
        'sync_confidence': None,
        'av_offset': None,
        # Present on every record (not only successes) so the JSON schema is uniform.
        'num_faces': None,
    }

    try:
        # Detect faces and write cropped clips under <detect_results_dir>/crop.
        syncnet_detector(video_path=video_path, min_track=50)

        crop_dir = os.path.join(detect_results_dir, "crop")
        # Sorted so clips are scored and averaged in a deterministic order.
        crop_videos = sorted(os.listdir(crop_dir)) if os.path.isdir(crop_dir) else []
        if not crop_videos:
            raise RuntimeError(f"Face not detected in {video_path}")

        offsets = []
        confidences = []
        for clip_name in crop_videos:
            av_offset, _, conf = syncnet.evaluate(
                video_path=os.path.join(crop_dir, clip_name),
                temp_dir=temp_dir,
            )
            offsets.append(av_offset)
            confidences.append(conf)

        # round() rather than int(): int() truncates toward zero, so a mean
        # offset of 1.9 frames would be reported as 1, biased toward "in sync".
        result['av_offset'] = int(round(fmean(offsets)))
        result['sync_confidence'] = float(fmean(confidences))
        result['num_faces'] = len(crop_videos)
        result['success'] = True
    except Exception as e:  # per-video boundary: record the failure, keep the batch going
        # Some exceptions stringify to '' -- fall back to the type name so the
        # recorded reason is never blank.
        result['error'] = str(e) or type(e).__name__

    return result


def calculate_statistics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate per-video result records into run-level summary statistics.

    Args:
        results: Per-video records. Only entries whose ``success`` flag is
            truthy contribute to the metric summaries.

    Returns:
        Dict with ``timestamp`` (naive local time, ISO-8601), ``total_count``,
        ``successful_count`` and ``failed_count``. When at least one
        successful record carries a non-None value, ``sync_confidence`` and/or
        ``av_offset`` are added. Each maps to ``mean``/``std``/``min``/``max``/
        ``median``, where ``std`` is numpy's default population std and the
        ``av_offset`` min/max are ints.
    """
    succeeded = [record for record in results if record.get('success', False)]
    stats: Dict[str, Any] = {
        'timestamp': datetime.now().isoformat(),
        'total_count': len(results),
        'successful_count': len(succeeded),
        'failed_count': len(results) - len(succeeded),
    }
    if not succeeded:
        return stats

    def summarize(key: str, extreme_type: type) -> None:
        # Add a mean/std/min/max/median summary of `key` when any success has a value.
        values = [record[key] for record in succeeded if record.get(key) is not None]
        if not values:
            return
        stats[key] = {
            'mean': float(np.mean(values)),
            'std': float(np.std(values)),
            'min': extreme_type(np.min(values)),
            'max': extreme_type(np.max(values)),
            'median': float(np.median(values)),
        }

    summarize('sync_confidence', float)
    summarize('av_offset', int)
    return stats


def _remove_dirs(*paths: str) -> None:
    """Recursively delete each directory in *paths* that exists; skip missing ones."""
    for path in paths:
        if os.path.exists(path):
            shutil.rmtree(path)


def batch_inference(
    video_dir: str,
    model_path: str,
    device: str = 'cuda',
    temp_dir: str = "temp_lipsync",
    detect_results_dir: str = "detect_results_lipsync"
) -> List[Dict[str, Any]]:
    """Evaluate lip-sync for every video in *video_dir*.

    Args:
        video_dir: Directory scanned by ``find_video_files``.
        model_path: SyncNet weights passed to ``SyncNetEval.loadParameters``.
        device: Torch device string. Any ``'cuda...'`` value falls back to
            ``'cpu'`` when CUDA is unavailable.
        temp_dir: Scratch directory for SyncNet. It is DELETED before the run
            and after every video.
        detect_results_dir: Detector output directory. It is DELETED before
            the run and after every video.

    Returns:
        One ``evaluate_single_video`` record per video, in sorted path order.
        The list is empty if no video files were found.
    """
    if device.startswith('cuda') and not torch.cuda.is_available():
        print("Warning: CUDA not available, using CPU")
        device = "cpu"

    # Load SyncNet model
    print(f"Loading SyncNet model (device: {device})...")
    if not os.path.exists(model_path):
        # No fallback path is tried: loadParameters below still receives model_path.
        print(f"Warning: Model file not found: {model_path}")

    syncnet = SyncNetEval(device=device)
    syncnet.loadParameters(model_path)
    print("Model loaded!\n")

    # Start from a clean slate. evaluate_single_video scores everything listed
    # under <detect_results_dir>/crop, so leftovers from an interrupted earlier
    # run could otherwise be counted for the first video (unless the detector
    # clears them itself -- not visible here).
    # NOTE(review): both directories are removed with rmtree; never point
    # --temp_dir at a directory holding data you want to keep.
    _remove_dirs(detect_results_dir, temp_dir)

    syncnet_detector = SyncNetDetector(device=device, detect_results_dir=detect_results_dir)

    video_files = find_video_files(video_dir)
    if not video_files:
        print(f"Error: No video files found in {video_dir}")
        return []

    print(f"Found {len(video_files)} video files\n")

    results = []
    for video_path in tqdm(video_files, desc="Evaluating lip-sync"):
        try:
            results.append(
                evaluate_single_video(
                    video_path,
                    syncnet,
                    syncnet_detector,
                    temp_dir,
                    detect_results_dir,
                )
            )
        finally:
            # Per-video scratch output. It is removed even on Ctrl-C so it
            # cannot leak into the next video or the next run.
            _remove_dirs(detect_results_dir, temp_dir)

    return results


def _print_summary(statistics: Dict[str, Any]) -> None:
    """Print a human-readable summary of a ``calculate_statistics`` result."""
    print(f"\n{'='*80}")
    print("Statistics Summary")
    print(f"{'='*80}")
    print(f"Total: {statistics['total_count']}")
    print(f"Successful: {statistics['successful_count']}")
    print(f"Failed: {statistics['failed_count']}")

    if 'sync_confidence' in statistics:
        sync_stats = statistics['sync_confidence']
        print("\nLip-sync Confidence:")
        print(f"  Mean: {sync_stats['mean']:.4f}")
        print(f"  Std: {sync_stats['std']:.4f}")
        print(f"  Min: {sync_stats['min']:.4f}")
        print(f"  Max: {sync_stats['max']:.4f}")
        print(f"  Median: {sync_stats['median']:.4f}")
        print("\nNote: Higher confidence indicates better lip-sync quality")
        print("      Typical range: 0-10, recommended > 3 for good sync")

    if 'av_offset' in statistics:
        offset_stats = statistics['av_offset']
        print("\nAudio-Video Offset (frames):")
        print(f"  Mean: {offset_stats['mean']:.2f}")
        print(f"  Std: {offset_stats['std']:.2f}")
        print(f"  Min: {offset_stats['min']}")
        print(f"  Max: {offset_stats['max']}")
        print(f"  Median: {offset_stats['median']:.2f}")
        print("\nNote: Values closer to 0 indicate better synchronization")


def main() -> int:
    """CLI entry point: parse arguments, evaluate all videos, save JSON, print a summary.

    Returns:
        Process exit status. It is 0 on success and 1 when ``--video_dir`` is
        not a directory or no videos were processed.
    """
    parser = argparse.ArgumentParser(
        description='LatentSync batch lip-sync evaluation',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument(
        '--video_dir',
        type=str,
        required=True,
        help='Directory containing video files'
    )
    parser.add_argument(
        '--model_path',
        type=str,
        default='checkpoints/auxiliary/syncnet_v2.model',
        help='SyncNet model file path (default: checkpoints/auxiliary/syncnet_v2.model)'
    )
    parser.add_argument(
        '--output_file',
        type=str,
        default='results_lipsync.json',
        help='Output JSON file path (default: results_lipsync.json)'
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda',
        choices=['cuda', 'cpu'],
        help='Compute device (default: cuda)'
    )
    parser.add_argument(
        '--temp_dir',
        type=str,
        default='temp_lipsync',
        help='Temporary file directory; it is deleted after every video (default: temp_lipsync)'
    )

    args = parser.parse_args()

    # isdir rather than exists: a regular file here is not a usable video directory.
    if not os.path.isdir(args.video_dir):
        print(f"Error: Video directory does not exist or is not a directory: {args.video_dir}")
        return 1

    results = batch_inference(
        video_dir=args.video_dir,
        model_path=args.model_path,
        device=args.device,
        temp_dir=args.temp_dir
    )

    if not results:
        print("No videos processed successfully")
        return 1

    statistics = calculate_statistics(results)

    output_data = {
        'config': {
            'video_dir': args.video_dir,
            'model_path': args.model_path,
            # Requested device; batch_inference may have fallen back to CPU.
            'device': args.device
        },
        'statistics': statistics,
        'results': results
    }

    output_path = Path(args.output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False)

    print(f"\nResults saved to: {args.output_file}")

    _print_summary(statistics)
    return 0


if __name__ == '__main__':
    # Use main()'s return value as the process exit status (None maps to 0).
    raise SystemExit(main())
