#!/usr/bin/env python3
"""
Aesthetic Predictor V2.5 Batch Inference Script (Enhanced)
"""

import argparse
import json
import os
from pathlib import Path
from typing import List, Dict, Optional
from tqdm import tqdm

import torch
import cv2
import numpy as np
from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
from PIL import Image


def load_path_list(path_file: str) -> List[str]:
    """Return the non-blank lines of a UTF-8 text file, whitespace-stripped.

    Args:
        path_file: Text file holding one path per line.

    Returns:
        The stripped lines in file order; lines that are empty after
        stripping are dropped.
    """
    entries: List[str] = []
    with open(path_file, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            cleaned = raw_line.strip()
            if cleaned:
                entries.append(cleaned)
    return entries


def get_files_from_dir(directory: str, extensions: List[str]) -> List[str]:
    """Recursively collect files under ``directory`` with one of ``extensions``.

    Matching is case-insensitive, so ``"jpg"`` matches ``a.jpg``, ``a.JPG``
    and ``a.Jpg``. Each extension may be given with or without a leading dot.
    Only regular files are returned; a directory whose name happens to end in
    a listed extension is skipped.

    Args:
        directory: Root directory to search (searched recursively).
        extensions: File extensions to accept, e.g. ``["jpg", "png"]``.

    Returns:
        Sorted list of unique matching paths as strings. Empty when nothing
        matches or ``extensions`` is empty.
    """
    root = Path(directory)
    # Normalise once so the tree is walked a single time instead of twice per
    # extension; str.endswith accepts a tuple of candidate suffixes.
    suffixes = tuple("." + ext.lower().lstrip(".") for ext in extensions)
    if not suffixes:
        return []

    matches = {
        str(path)
        for path in root.rglob("*")
        if path.name.lower().endswith(suffixes) and path.is_file()
    }
    return sorted(matches)


def load_image_safely(image_path: str) -> Optional[Image.Image]:
    """Load an image as RGB, returning ``None`` instead of raising on failure.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The image converted to RGB mode, or ``None`` if it could not be
        opened or converted (a warning is printed in that case).
    """
    try:
        # Image.open() is lazy and keeps the underlying file open. convert()
        # loads the pixel data and returns a separate image, so the source can
        # be closed right away; this avoids leaking one handle per image when
        # scoring large batches.
        with Image.open(image_path) as source:
            return source.convert("RGB")
    except Exception as e:
        print(f"Warning: Cannot load image {image_path}: {e}")
        return None


def extract_keyframes_scene_change(
    video_path: str,
    max_frames: int = 10,
    threshold: float = 30.0,
    min_interval: int = 10
) -> List[Image.Image]:
    """Pick keyframes where the grayscale histogram changes sharply.

    The first decoded frame is always a keyframe. Each later frame is compared
    with the most recent keyframe using the Bhattacharyya distance between
    their normalised 256-bin grayscale histograms, multiplied by 100. A frame
    becomes a keyframe when that score exceeds ``threshold`` and at least
    ``min_interval`` frames have passed since the previous keyframe.

    If fewer than ``max_frames // 2`` keyframes are found, the function falls
    back to ``extract_keyframes_uniform``.

    Args:
        video_path: Path of the video file.
        max_frames: Upper bound on the number of keyframes returned.
        threshold: Minimum scaled histogram distance that counts as a scene
            change.
        min_interval: Minimum distance, in frames, between two keyframes.

    Returns:
        Keyframes as RGB PIL images, in playback order. Empty if the video
        cannot be opened, reports zero frames, or decoding raises.
    """
    try:
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                print(f"Warning: Cannot open video {video_path}")
                return []

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames == 0:
                print(f"Warning: Video {video_path} has no frames")
                return []

            keyframes = []
            # Histogram of the last selected keyframe; cached so it is not
            # recomputed for every frame it is compared against.
            reference_hist = None
            frame_idx = 0
            last_keyframe_idx = -min_interval

            # Checking the limit here (rather than only after a scene change)
            # keeps the first frame from pushing the result past max_frames.
            while len(keyframes) < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break

                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
                hist = cv2.normalize(hist, hist).flatten()

                if reference_hist is None:
                    is_keyframe = True
                else:
                    diff = cv2.compareHist(
                        reference_hist, hist, cv2.HISTCMP_BHATTACHARYYA
                    )
                    diff_score = diff * 100
                    is_keyframe = (
                        diff_score > threshold
                        and (frame_idx - last_keyframe_idx) >= min_interval
                    )

                if is_keyframe:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    keyframes.append(Image.fromarray(frame_rgb))
                    last_keyframe_idx = frame_idx
                    reference_hist = hist

                frame_idx += 1
        finally:
            # Release the capture on every exit path, including exceptions.
            cap.release()

        if len(keyframes) < max_frames // 2:
            print(f"Warning: Only extracted {len(keyframes)} frames, switching to uniform sampling")
            return extract_keyframes_uniform(video_path, max_frames)

        return keyframes

    except Exception as e:
        print(f"Warning: Scene change detection failed {video_path}: {e}")
        return []


def extract_keyframes_frame_diff(
    video_path: str,
    max_frames: int = 10,
    threshold: float = 25.0,
    sample_interval: int = 5
) -> List[Image.Image]:
    """Pick keyframes by mean absolute grayscale difference.

    Only every ``sample_interval``-th frame is examined. The first examined
    frame is always a keyframe. A later frame becomes a keyframe when the mean
    absolute pixel difference between it and the most recent keyframe (both
    in grayscale) exceeds ``threshold``.

    If fewer than ``max_frames // 2`` keyframes are found, the function falls
    back to ``extract_keyframes_uniform``.

    Args:
        video_path: Path of the video file.
        max_frames: Upper bound on the number of keyframes returned.
        threshold: Minimum mean absolute difference for a new keyframe.
        sample_interval: Examine one frame out of this many.

    Returns:
        Keyframes as RGB PIL images, in playback order. Empty if the video
        cannot be opened, reports zero frames, or decoding raises.
    """
    try:
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                print(f"Warning: Cannot open video {video_path}")
                return []

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames == 0:
                print(f"Warning: Video {video_path} has no frames")
                return []

            keyframes = []
            prev_frame = None
            frame_idx = 0

            while len(keyframes) < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break

                # Frames still have to be decoded to advance the stream, but
                # only the sampled ones are compared.
                if frame_idx % sample_interval != 0:
                    frame_idx += 1
                    continue

                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

                if prev_frame is None:
                    is_keyframe = True
                else:
                    frame_diff = cv2.absdiff(prev_frame, gray)
                    is_keyframe = np.mean(frame_diff) > threshold

                if is_keyframe:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    keyframes.append(Image.fromarray(frame_rgb))
                    prev_frame = gray

                frame_idx += 1
        finally:
            # Release the capture on every exit path, including exceptions.
            cap.release()

        if len(keyframes) < max_frames // 2:
            print(f"Warning: Only extracted {len(keyframes)} frames, switching to uniform sampling")
            return extract_keyframes_uniform(video_path, max_frames)

        return keyframes

    except Exception as e:
        print(f"Warning: Frame diff detection failed {video_path}: {e}")
        return []


def extract_keyframes_uniform(
    video_path: str,
    num_frames: int = 10
) -> List[Image.Image]:
    """Sample frames at evenly spaced positions across the video.

    When ``num_frames`` is at least the reported frame count, every frame
    index is requested. Otherwise frame ``int(i * total / num_frames)`` is
    requested for ``i`` in ``range(num_frames)``. Frames that fail to decode
    after seeking are skipped, so fewer frames than requested may be returned.

    Args:
        video_path: Path of the video file.
        num_frames: Number of frames to sample.

    Returns:
        Sampled frames as RGB PIL images, in playback order. Empty if the
        video cannot be opened, reports zero frames, or decoding raises.
    """
    try:
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                print(f"Warning: Cannot open video {video_path}")
                return []

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames == 0:
                print(f"Warning: Video {video_path} has no frames")
                return []

            if num_frames >= total_frames:
                frame_indices = list(range(total_frames))
            else:
                frame_indices = [
                    int(i * total_frames / num_frames)
                    for i in range(num_frames)
                ]

            frames = []
            for frame_idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()
                if ret:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(frame_rgb))

            return frames
        finally:
            # Release the capture on every exit path, including exceptions.
            cap.release()

    except Exception as e:
        print(f"Warning: Uniform extraction failed {video_path}: {e}")
        return []


def extract_keyframes_interval(
    video_path: str,
    num_frames: int = 10,
    frame_interval: float = 1.0
) -> List[Image.Image]:
    """Sample frames at a fixed time spacing, starting from the first frame.

    The spacing in frames is ``int(fps * frame_interval)``, clamped to at
    least 1. Sampling stops after ``num_frames`` positions or at the end of
    the video, whichever comes first. Frames that fail to decode after
    seeking are skipped.

    Args:
        video_path: Path of the video file.
        num_frames: Maximum number of frames to sample.
        frame_interval: Spacing between samples, in seconds.

    Returns:
        Sampled frames as RGB PIL images, in playback order. Empty if the
        video cannot be opened, reports zero frames or a non-positive FPS, or
        decoding raises.
    """
    try:
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                print(f"Warning: Cannot open video {video_path}")
                return []

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)

            if total_frames == 0 or fps <= 0:
                print(f"Warning: Video {video_path} has invalid info")
                return []

            frame_step = max(1, int(fps * frame_interval))
            # max(0, ...) keeps a negative num_frames from being read as a
            # "drop from the end" slice.
            frame_indices = range(0, total_frames, frame_step)[:max(0, num_frames)]

            frames = []
            for frame_idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()
                if ret:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(frame_rgb))

            return frames
        finally:
            # Release the capture on every exit path, including exceptions.
            cap.release()

    except Exception as e:
        print(f"Warning: Interval extraction failed {video_path}: {e}")
        return []


def extract_keyframes(
    video_path: str,
    num_frames: int = 10,
    frame_interval: Optional[float] = None,
    method: str = "uniform",
    scene_threshold: float = 30.0,
    diff_threshold: float = 25.0
) -> List[Image.Image]:
    """Extract keyframes from a video with the selected strategy.

    Args:
        video_path: Path of the video file.
        num_frames: Target (maximum) number of frames.
        frame_interval: Seconds between samples for the ``"interval"``
            method; ``None`` means 1.0.
        method: ``"scene_change"``, ``"frame_diff"``, ``"interval"``; any
            other value selects uniform sampling.
        scene_threshold: Threshold passed to the scene-change extractor.
        diff_threshold: Threshold passed to the frame-difference extractor.

    Returns:
        Whatever the chosen extractor returns: a list of PIL images.
    """
    if method == "scene_change":
        # Keyframes must be at least half of num_frames apart (minimum 1).
        min_gap = max(1, int(num_frames * 0.5))
        return extract_keyframes_scene_change(
            video_path,
            max_frames=num_frames,
            threshold=scene_threshold,
            min_interval=min_gap,
        )

    if method == "frame_diff":
        return extract_keyframes_frame_diff(
            video_path,
            max_frames=num_frames,
            threshold=diff_threshold,
            sample_interval=5,
        )

    if method == "interval":
        seconds = 1.0 if frame_interval is None else frame_interval
        return extract_keyframes_interval(
            video_path,
            num_frames=num_frames,
            frame_interval=seconds,
        )

    return extract_keyframes_uniform(video_path, num_frames=num_frames)


class AestheticPredictorBatch:
    """Scores images and videos with the Aesthetic Predictor V2.5 model.

    The model and its preprocessor come from ``convert_v2_5_from_siglip``.
    CUDA devices run the model in bfloat16; every other device uses float32.
    """

    def __init__(self, device: Optional[str] = None):
        """Load the model and move it to ``device``.

        Args:
            device: Torch device string such as ``"cpu"``, ``"cuda"`` or
                ``"cuda:1"``. ``None`` selects CUDA when available, else CPU.
        """
        print("Loading model...")
        self.model, self.preprocessor = convert_v2_5_from_siglip(
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        # Moving explicitly to ``device`` means indexed devices ("cuda:0")
        # are honoured rather than silently leaving the model on the CPU.
        self.model = self.model.to(self._model_dtype()).to(device)

        print(f"Model loaded to device: {device}")

    def _model_dtype(self) -> torch.dtype:
        """Return bfloat16 for any CUDA device, float32 otherwise."""
        if torch.device(self.device).type == "cuda":
            return torch.bfloat16
        return torch.float32

    def _score_image(self, image) -> float:
        """Run the model on one PIL image and return its score.

        Raises whatever the preprocessor or model raises; the public
        ``predict_*`` methods turn those errors into ``None``.
        """
        pixel_values = self.preprocessor(
            images=image, return_tensors="pt"
        ).pixel_values
        pixel_values = pixel_values.to(
            device=self.device, dtype=self._model_dtype()
        )

        with torch.inference_mode():
            logits = self.model(pixel_values).logits

        return float(logits.squeeze().float().item())

    @staticmethod
    def _score_statistics(scores: List[float]) -> Dict[str, float]:
        """Summarise scores as mean/min/max/std (all 0.0 for an empty list).

        The standard deviation is the sample standard deviation computed by
        ``torch.Tensor.std`` and is reported as 0.0 for a single score.
        """
        if not scores:
            return {
                "mean_score": 0.0,
                "min_score": 0.0,
                "max_score": 0.0,
                "std_score": 0.0,
            }
        return {
            "mean_score": float(sum(scores) / len(scores)),
            "min_score": float(min(scores)),
            "max_score": float(max(scores)),
            "std_score": (
                float(torch.tensor(scores).std().item())
                if len(scores) > 1 else 0.0
            ),
        }

    @staticmethod
    def _save_outputs(
        output_dir: str,
        output: Dict,
        stats: Dict,
        failed_paths: List[str],
        kind: str
    ) -> None:
        """Write results.json, metrics.json and, if needed, the failure list.

        Args:
            output_dir: Directory to write into (created if missing).
            output: Full result payload for ``results.json``.
            stats: Statistics payload for ``metrics.json``.
            failed_paths: Inputs that could not be scored.
            kind: ``"images"`` or ``"videos"``; used in the failure-list file
                name and message.
        """
        os.makedirs(output_dir, exist_ok=True)

        results_file = os.path.join(output_dir, "results.json")
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(output, f, indent=2, ensure_ascii=False)
        print(f"\nResults saved to: {results_file}")

        metrics_file = os.path.join(output_dir, "metrics.json")
        with open(metrics_file, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)
        print(f"Metrics saved to: {metrics_file}")

        if failed_paths:
            failed_file = os.path.join(output_dir, f"failed_{kind}.txt")
            with open(failed_file, 'w', encoding='utf-8') as f:
                f.writelines(f"{path}\n" for path in failed_paths)
            print(f"Failed {kind} list saved to: {failed_file}")

    def predict_single(self, image_path: str) -> Optional[float]:
        """Score the image at ``image_path``.

        Returns:
            The score, or ``None`` if the image cannot be loaded or scoring
            fails. Failures print a warning instead of raising, so one bad
            file cannot abort a whole batch.
        """
        image = load_image_safely(image_path)
        if image is None:
            return None

        try:
            return self._score_image(image)
        except Exception as e:
            print(f"Warning: Image prediction failed for {image_path}: {e}")
            return None

    def predict_image(self, image: "Image.Image") -> Optional[float]:
        """Score an already-loaded PIL image; return ``None`` on failure."""
        try:
            return self._score_image(image)
        except Exception as e:
            print(f"Warning: Image prediction failed: {e}")
            return None

    def predict_video(
        self,
        video_path: str,
        num_frames: int = 10,
        frame_interval: Optional[float] = None,
        frame_extraction_method: str = "uniform",
        scene_threshold: float = 30.0,
        diff_threshold: float = 25.0
    ) -> Optional[Dict]:
        """Score a video as the mean score of its extracted keyframes.

        Args:
            video_path: Path of the video file.
            num_frames: Target number of keyframes.
            frame_interval: Seconds between samples for the interval method.
            frame_extraction_method: Strategy name passed to
                ``extract_keyframes``.
            scene_threshold: Threshold for the scene-change method.
            diff_threshold: Threshold for the frame-difference method.

        Returns:
            A dict with the video path, keyframe counts, the mean score under
            ``"aesthetic_score"``, the per-frame scores and their min/max/std.
            ``None`` if no keyframe could be extracted or scored.
        """
        keyframes = extract_keyframes(
            video_path,
            num_frames=num_frames,
            frame_interval=frame_interval,
            method=frame_extraction_method,
            scene_threshold=scene_threshold,
            diff_threshold=diff_threshold
        )

        if not keyframes:
            return None

        frame_scores = []
        for frame in keyframes:
            score = self.predict_image(frame)
            if score is not None:
                frame_scores.append(score)

        if not frame_scores:
            return None

        summary = self._score_statistics(frame_scores)

        return {
            "video_path": str(video_path),
            "num_keyframes": len(keyframes),
            "num_scored_frames": len(frame_scores),
            "aesthetic_score": summary["mean_score"],
            "frame_scores": frame_scores,
            "min_frame_score": summary["min_score"],
            "max_frame_score": summary["max_score"],
            "std_frame_score": summary["std_score"]
        }

    def predict_batch(
        self,
        image_paths: List[str],
        output_dir: Optional[str] = None,
        save_results: bool = True
    ) -> Dict:
        """Score every image in ``image_paths`` and optionally save the results.

        Args:
            image_paths: Image files to score.
            output_dir: Where to write result files; nothing is written when
                this is empty or ``None``.
            save_results: Set to ``False`` to skip writing files.

        Returns:
            Dict with ``"statistics"``, ``"results"`` (one entry per scored
            image) and ``"failed_images"``.
        """
        results = []
        failed_images = []

        print(f"Processing {len(image_paths)} images...")

        for image_path in tqdm(image_paths, desc="Processing images"):
            score = self.predict_single(image_path)

            if score is not None:
                results.append({
                    "image_path": str(image_path),
                    "aesthetic_score": score
                })
            else:
                failed_images.append(str(image_path))

        stats = {
            "total_images": len(image_paths),
            "successful": len(results),
            "failed": len(failed_images),
            **self._score_statistics([r["aesthetic_score"] for r in results]),
        }

        output = {
            "statistics": stats,
            "results": results,
            "failed_images": failed_images
        }

        if save_results and output_dir:
            self._save_outputs(output_dir, output, stats, failed_images, "images")

        return output

    def predict_video_batch(
        self,
        video_paths: List[str],
        output_dir: Optional[str] = None,
        save_results: bool = True,
        num_frames: int = 10,
        frame_interval: Optional[float] = None,
        frame_extraction_method: str = "uniform",
        scene_threshold: float = 30.0,
        diff_threshold: float = 25.0
    ) -> Dict:
        """Score every video in ``video_paths`` and optionally save the results.

        Args:
            video_paths: Video files to score.
            output_dir: Where to write result files; nothing is written when
                this is empty or ``None``.
            save_results: Set to ``False`` to skip writing files.
            num_frames: Target number of keyframes per video.
            frame_interval: Seconds between samples for the interval method.
            frame_extraction_method: Keyframe extraction strategy name.
            scene_threshold: Threshold for the scene-change method.
            diff_threshold: Threshold for the frame-difference method.

        Returns:
            Dict with ``"statistics"``, ``"results"`` (one ``predict_video``
            dict per scored video) and ``"failed_videos"``.
        """
        results = []
        failed_videos = []

        print(f"Processing {len(video_paths)} videos...")
        print(f"Extraction method: {frame_extraction_method}")
        print(f"Target keyframes: {num_frames}")
        if frame_extraction_method == "scene_change":
            print(f"Scene threshold: {scene_threshold}")
        elif frame_extraction_method == "frame_diff":
            print(f"Diff threshold: {diff_threshold}")
        elif frame_interval:
            print(f"Frame interval: {frame_interval} seconds")

        for video_path in tqdm(video_paths, desc="Processing videos"):
            result = self.predict_video(
                video_path,
                num_frames=num_frames,
                frame_interval=frame_interval,
                frame_extraction_method=frame_extraction_method,
                scene_threshold=scene_threshold,
                diff_threshold=diff_threshold
            )

            if result is not None:
                results.append(result)
            else:
                failed_videos.append(str(video_path))

        stats = {
            "total_videos": len(video_paths),
            "successful": len(results),
            "failed": len(failed_videos),
            **self._score_statistics([r["aesthetic_score"] for r in results]),
            "num_frames_per_video": num_frames,
            "frame_interval": frame_interval,
            "extraction_method": frame_extraction_method
        }

        output = {
            "statistics": stats,
            "results": results,
            "failed_videos": failed_videos
        }

        if save_results and output_dir:
            self._save_outputs(output_dir, output, stats, failed_videos, "videos")

        return output


def _print_statistics(stats: Dict, kind: str) -> None:
    """Print the end-of-run summary for ``kind`` ("images" or "videos")."""
    total = stats[f"total_{kind}"]
    print("\n" + "="*50)
    print("Statistics:")
    print(f"  Total {kind}: {total}")
    print(f"  Successful: {stats['successful']}")
    print(f"  Failed: {stats['failed']}")
    if kind == "videos":
        print(f"  Keyframes per video: {stats['num_frames_per_video']}")
    if stats['successful'] > 0:
        print(f"  Mean score: {stats['mean_score']:.2f}")
        print(f"  Min score: {stats['min_score']:.2f}")
        print(f"  Max score: {stats['max_score']:.2f}")
        print(f"  Std score: {stats['std_score']:.2f}")
    print("="*50)


def main():
    """Command-line entry point: parse arguments, score inputs, print a summary.

    Exactly one of ``--image_dir``, ``--image_list``, ``--video_dir`` or
    ``--video_list`` selects the inputs. The model is loaded only after at
    least one input path has been found, so an empty or mistyped input fails
    fast without paying the model start-up cost.
    """
    parser = argparse.ArgumentParser(
        description="Aesthetic Predictor V2.5 Batch Inference Script",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python batch_inference.py --image_dir ./images --output_dir ./outputs
  python batch_inference.py --video_dir ./videos --output_dir ./outputs --num_frames 10
  python batch_inference.py --video_dir ./videos --output_dir ./outputs --frame_extraction_method scene_change --num_frames 15 --scene_threshold 30
  python batch_inference.py --video_dir ./videos --output_dir ./outputs --frame_extraction_method frame_diff --num_frames 20 --diff_threshold 25
  python batch_inference.py --video_dir ./videos --output_dir ./outputs --frame_extraction_method interval --frame_interval 0.5 --num_frames 20
  python batch_inference.py --video_list video_list.txt --output_dir ./outputs --frame_extraction_method scene_change
        """
    )

    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument(
        "--image_dir",
        type=str,
        help="Directory containing images"
    )
    input_group.add_argument(
        "--image_list",
        type=str,
        help="File containing image paths (one per line)"
    )
    input_group.add_argument(
        "--video_dir",
        type=str,
        help="Directory containing videos"
    )
    input_group.add_argument(
        "--video_list",
        type=str,
        help="File containing video paths (one per line)"
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory path"
    )

    parser.add_argument(
        "--extensions",
        type=str,
        nargs="+",
        default=["jpg", "jpeg", "png", "bmp", "webp"],
        help="Supported image extensions (default: jpg jpeg png bmp webp)"
    )

    parser.add_argument(
        "--video_extensions",
        type=str,
        nargs="+",
        default=["mp4", "avi", "mov", "mkv", "flv", "webm"],
        help="Supported video extensions (default: mp4 avi mov mkv flv webm)"
    )

    parser.add_argument(
        "--num_frames",
        type=int,
        default=10,
        help="Number of keyframes to extract per video (default: 10)"
    )

    parser.add_argument(
        "--frame_interval",
        type=float,
        default=None,
        help="Frame interval in seconds"
    )

    parser.add_argument(
        "--frame_extraction_method",
        type=str,
        choices=["uniform", "interval", "scene_change", "frame_diff"],
        default="uniform",
        help="Keyframe extraction method (default: uniform)"
    )

    parser.add_argument(
        "--scene_threshold",
        type=float,
        default=30.0,
        help="Scene change threshold (0-100) for scene_change method (default: 30.0)"
    )

    parser.add_argument(
        "--diff_threshold",
        type=float,
        default=25.0,
        help="Frame diff threshold (0-255) for frame_diff method (default: 25.0)"
    )

    parser.add_argument(
        "--device",
        type=str,
        choices=["cuda", "cpu"],
        default=None,
        help="Device to use (default: auto)"
    )

    parser.add_argument(
        "--no_save",
        action="store_true",
        help="Do not save results to file"
    )

    args = parser.parse_args()

    # The input group is required and mutually exclusive, so exactly one of
    # the four options is set; compare against None so that an empty-string
    # path is still routed to its own branch.
    if args.image_dir is not None:
        kind = "images"
        print(f"Loading images from directory: {args.image_dir}")
        paths = get_files_from_dir(args.image_dir, args.extensions)
        print(f"Found {len(paths)} images")
    elif args.image_list is not None:
        kind = "images"
        print(f"Loading images from file: {args.image_list}")
        paths = load_path_list(args.image_list)
        print(f"Found {len(paths)} image paths")
    elif args.video_dir is not None:
        kind = "videos"
        print(f"Loading videos from directory: {args.video_dir}")
        paths = get_files_from_dir(args.video_dir, args.video_extensions)
        print(f"Found {len(paths)} videos")
    else:
        kind = "videos"
        print(f"Loading videos from file: {args.video_list}")
        paths = load_path_list(args.video_list)
        print(f"Found {len(paths)} video paths")

    if not paths:
        if kind == "images":
            print("Error: No image files found")
        else:
            print("Error: No video files found")
        return

    # Loading the model is the expensive step; do it only once there is work.
    predictor = AestheticPredictorBatch(device=args.device)
    output_dir = None if args.no_save else args.output_dir

    if kind == "images":
        results = predictor.predict_batch(
            image_paths=paths,
            output_dir=output_dir,
            save_results=not args.no_save
        )
    else:
        results = predictor.predict_video_batch(
            video_paths=paths,
            output_dir=output_dir,
            save_results=not args.no_save,
            num_frames=args.num_frames,
            frame_interval=args.frame_interval,
            frame_extraction_method=args.frame_extraction_method,
            scene_threshold=args.scene_threshold,
            diff_threshold=args.diff_threshold
        )

    _print_statistics(results["statistics"], kind)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

