#!/usr/bin/env python3
"""
Multimodal Analysis Script for JIR Evaluation

This script addresses Reviewer B2: Multimodal Anomaly
- Why do multimodal models exhibit lower recall than text-only models (0.412 vs 0.429)?
- Analyze in which scenarios visual information enhances performance
- Analyze in which scenarios visual information may introduce noise
"""

import sys
import os
import json
import numpy as np
import time
from typing import List, Dict, Any, Tuple
from collections import defaultdict
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch

# Add parent directory to path
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from utils import load_jsonl, METAINFO_DIR
from evaluate_utils import get_duration_by_youtube_id
from evaluate import JIREvaluator
# Visualization functions are now in visualize_multimodal_analysis.py
from model_config import get_all_model_ids, get_text_model_pairs
from similarity_cache import get_or_compute_matches
from need_type_utils import get_need_type_label


def get_text_model_id(multimodal_model_id: str) -> str:
    """Derive the text-only model ID that pairs with a multimodal model ID.

    Every occurrence of the ``"_multimodal"`` marker is stripped from the ID.

    Args:
        multimodal_model_id: Model ID expected to contain ``"_multimodal"``.

    Returns:
        The ID with the marker removed, or ``None`` when the marker is absent
        (i.e. the given ID is not a multimodal one).
    """
    marker = "_multimodal"
    if marker not in multimodal_model_id:
        return None
    return multimodal_model_id.replace(marker, "")


def _read_json_file(path):
    """Parse the JSON document at *path*, closing the file handle afterwards."""
    with open(path, "r") as handle:
        return json.load(handle)


def _new_time_error_lists():
    """Return a fresh ``{"start": [], "end": [], "duration": []}`` container."""
    return {"start": [], "end": [], "duration": []}


def _time_window(item):
    """Return ``(start, end)`` read from a query/candidate dict.

    A missing ``start_time`` defaults to 0.0. A missing, ``None`` or zero-length
    ``end_time`` is replaced by ``start + 10.0``.
    """
    start = item.get("start_time", 0.0)
    end = item.get("end_time", start)
    if end is None or end == start:
        end = start + 10.0
    return start, end


def _position_bucket(gt_start, video_duration):
    """Classify a query start as "early"/"middle"/"late" by its relative position."""
    ratio = gt_start / video_duration if video_duration > 0 else 0.0
    if ratio < 0.33:
        return "early"
    if ratio < 0.67:
        return "middle"
    return "late"


def _duration_bucket(gt_duration):
    """Classify a query duration as "short" (<10), "medium" (<30) or "long"."""
    if gt_duration < 10:
        return "short"
    if gt_duration < 30:
        return "medium"
    return "long"


def _normalized_time_offsets(matched_window, gt_window, video_duration):
    """Return absolute (start, end, duration) offsets between two time windows.

    Offsets are divided by ``video_duration`` when it is positive; otherwise the
    raw differences are returned unchanged.
    """
    matched_start, matched_end = matched_window
    gt_start, gt_end = gt_window
    offsets = (
        abs(matched_start - gt_start),
        abs(matched_end - gt_end),
        abs((matched_end - matched_start) - (gt_end - gt_start)),
    )
    if video_duration > 0:
        return tuple(offset / video_duration for offset in offsets)
    return offsets


def _mean_or_zero(values):
    """Return ``np.mean(values)``, or ``0.0`` when *values* is empty."""
    return np.mean(values) if values else 0.0


def _is_cache_usable(cached, queries, candidates):
    """Check that a cached match record has the same sizes as the current inputs."""
    return (
        cached is not None
        and len(cached.get("queries_to_match", [])) == len(queries)
        and len(cached.get("candidate_queries", [])) == len(candidates)
    )


def process_single_video_multimodal(args):
    """Compare text-only and multimodal results for one video (worker entry point).

    Args:
        args: 8-tuple ``(text_model_id, youtube_id, filename, youtube_ids_set,
            similarity_threshold, fuzzy_sentence_interval, gpu_id, queries_cache)``.
            ``gpu_id`` may be ``None``, in which case the evaluator runs on CPU.
            ``queries_cache`` may be ``None`` or a dict mapping a youtube_id to
            its pre-loaded ground-truth queries.

    Returns:
        A dict with video-level recall / timeliness / semantic-similarity
        differences (text minus multimodal), per-need-type details under
        ``"need_type_performance"``, enhancement / noise / all scenarios, and
        ``"deep_timeliness_data"`` (time errors bucketed by query duration and
        position). Returns ``None`` when the video is not in ``youtube_ids_set``,
        when a required evaluation, candidate or ground-truth file is missing,
        or when any exception is raised during processing (the error is printed).
    """
    (text_model_id, youtube_id, filename, youtube_ids_set, similarity_threshold,
     fuzzy_sentence_interval, gpu_id, queries_cache) = args

    # Ensure patch cache is loaded in this worker thread
    # (global cache might not be initialized in worker threads)
    from need_type_utils import load_type_patch_cache
    load_type_patch_cache()

    try:
        if torch.cuda.is_available() and gpu_id is not None:
            device = f'cuda:{gpu_id}'
        else:
            device = 'cpu'

        if youtube_id not in youtube_ids_set:
            return None

        # NOTE(review): the evaluation directory is taken from the environment,
        # not from the task tuple — confirm it matches what the caller scanned.
        evaluation_dir = os.getenv("EVALUATION_DIR", "../evaluation_output")
        multimodal_model_id = f"{text_model_id}_multimodal"

        text_file = os.path.join(
            evaluation_dir, f"rebuttal_baseline_stream_runs_{text_model_id}", filename)
        multimodal_file = os.path.join(
            evaluation_dir, f"rebuttal_baseline_stream_runs_{multimodal_model_id}", filename)
        if not os.path.exists(text_file) or not os.path.exists(multimodal_file):
            return None

        # Use pre-loaded queries cache if available, otherwise load from file
        if queries_cache is not None and youtube_id in queries_cache:
            queries_to_match = queries_cache[youtube_id]
        else:
            ground_truth_file = os.path.join(parent_dir, f"output/{youtube_id}/jir_references_relevance_score.jsonl")
            if not os.path.exists(ground_truth_file):
                return None
            queries_to_match = load_jsonl(ground_truth_file)

        text_candidate_file = os.path.join(parent_dir, f"data/baseline_stream_runs_{text_model_id}", filename)
        multimodal_candidate_file = os.path.join(parent_dir, f"data/baseline_stream_runs_{multimodal_model_id}", filename)
        if not os.path.exists(text_candidate_file) or not os.path.exists(multimodal_candidate_file):
            return None

        # Files are read through a context manager so worker threads do not
        # accumulate open descriptors across many videos.
        text_data = _read_json_file(text_file)
        multimodal_data = _read_json_file(multimodal_file)
        text_candidates = _read_json_file(text_candidate_file).get("needs", [])
        multimodal_candidates = _read_json_file(multimodal_candidate_file).get("needs", [])

        # Check cache first to avoid creating evaluator if not needed
        from evaluation.similarity_cache import get_cache_file_path, load_match_results
        cache_dir = os.path.join(evaluation_dir, ".similarity_cache")
        text_cached = load_match_results(get_cache_file_path(
            cache_dir, text_model_id, youtube_id, similarity_threshold, fuzzy_sentence_interval))
        multimodal_cached = load_match_results(get_cache_file_path(
            cache_dir, multimodal_model_id, youtube_id, similarity_threshold, fuzzy_sentence_interval))

        # Only create evaluator if cache miss (need to compute)
        evaluator = None

        def get_evaluator():
            nonlocal evaluator
            if evaluator is None:
                evaluator = JIREvaluator(similarity_threshold=similarity_threshold, device=device)
                if not hasattr(evaluator, 'model') or evaluator.model is None:
                    raise RuntimeError(f"Failed to initialize JIREvaluator model for {text_model_id}/{youtube_id}")
            return evaluator

        def resolve_matches(model_id, cached, candidates):
            # Text and multimodal share the same cache mechanism, keyed by model_id.
            if _is_cache_usable(cached, queries_to_match, candidates):
                return cached["match_results"]
            return get_or_compute_matches(
                model_id, youtube_id, queries_to_match, candidates,
                get_evaluator, similarity_threshold, fuzzy_sentence_interval,
                use_cache=True
            )

        text_results = resolve_matches(text_model_id, text_cached, text_candidates)
        multimodal_results = resolve_matches(multimodal_model_id, multimodal_cached, multimodal_candidates)

        text_recall = text_data.get("recall", {}).get("recall", 0.0)
        multimodal_recall = multimodal_data.get("recall", {}).get("recall", 0.0)
        # Use weighted (not matched_weighted) to include all queries (unmatched contribute 0)
        # matched_weighted only considers matched queries, which overestimates performance
        text_timeliness = text_data.get("timeliness", {})
        multimodal_timeliness = multimodal_data.get("timeliness", {})
        text_timeliness_avg = text_timeliness.get("weighted_time_match", 0.0)
        multimodal_timeliness_avg = multimodal_timeliness.get("weighted_time_match", 0.0)
        text_timeliness_start = text_timeliness.get("weighted_start_match", 0.0)
        multimodal_timeliness_start = multimodal_timeliness.get("weighted_start_match", 0.0)
        text_timeliness_end = text_timeliness.get("weighted_end_match", 0.0)
        multimodal_timeliness_end = multimodal_timeliness.get("weighted_end_match", 0.0)

        # Per-side inputs and accumulators; insertion order (text, multimodal)
        # fixes the processing order inside the per-query loop.
        sides = {
            "text": {
                "results": text_results,
                "relevance": text_data.get("relevance", {}).get("query_scores", {}),
                "timeliness": text_timeliness.get("query_time_scores", {}),
                "similarities": [],
            },
            "multimodal": {
                "results": multimodal_results,
                "relevance": multimodal_data.get("relevance", {}).get("query_scores", {}),
                "timeliness": multimodal_timeliness.get("query_time_scores", {}),
                "similarities": [],
            },
        }

        # Collect detailed performance by need type (for aggregation across all videos)
        need_type_performance = defaultdict(lambda: {
            side: {
                "matched": 0,
                "total": 0,
                "relevance_scores": [],
                "timeliness_scores": [],
                "time_errors": _new_time_error_lists(),
            }
            for side in ("text", "multimodal")
        })

        # Deep timeliness analysis data (time errors by query duration and position)
        deep_timeliness_data = {
            "by_duration": {
                bucket: {"text": _new_time_error_lists(), "multimodal": _new_time_error_lists()}
                for bucket in ("short", "medium", "long")
            },
            "by_position": {
                bucket: {"text": _new_time_error_lists(), "multimodal": _new_time_error_lists()}
                for bucket in ("early", "middle", "late")
            },
        }

        # Get video duration for time error analysis
        video_duration = get_duration_by_youtube_id(youtube_id)

        # Single pass over the ground-truth queries.
        # NOTE(review): match results are looked up by integer query index while
        # the evaluation score dicts are looked up by str(index) — confirm the
        # cached match results really use integer keys.
        for idx, query in enumerate(queries_to_match):
            idx_str = str(idx)
            need_type = get_need_type_label(query, youtube_id=youtube_id)

            gt_window = _time_window(query)
            position = _position_bucket(gt_window[0], video_duration)
            duration_cat = _duration_bucket(gt_window[1] - gt_window[0])

            for side, state in sides.items():
                perf = need_type_performance[need_type][side]
                perf["total"] += 1

                matches = state["results"].get(idx, [])
                if matches:
                    perf["matched"] += 1
                    state["similarities"].append(max(score for _, score in matches))

                if idx_str in state["relevance"]:
                    perf["relevance_scores"].append(state["relevance"][idx_str])

                if idx_str not in state["timeliness"]:
                    continue
                perf["timeliness_scores"].append(
                    state["timeliness"][idx_str].get("avg_time_match", 0.0))

                # Time errors are only defined for queries that were matched.
                if not matches:
                    continue
                best_match_candidate, _ = max(matches, key=lambda pair: pair[1])
                offsets = _normalized_time_offsets(
                    _time_window(best_match_candidate), gt_window, video_duration)
                for key, offset in zip(("start", "end", "duration"), offsets):
                    perf["time_errors"][key].append(offset)
                    deep_timeliness_data["by_duration"][duration_cat][side][key].append(offset)
                    deep_timeliness_data["by_position"][position][side][key].append(offset)

        avg_text_sim = _mean_or_zero(sides["text"]["similarities"])
        avg_multimodal_sim = _mean_or_zero(sides["multimodal"]["similarities"])

        # Per-need-type details plus the older enhancement/noise scenario lists
        # (kept for backward compatibility).
        need_type_details = {}
        enhancement_scenarios = []
        noise_scenarios = []
        all_scenarios = []
        error_keys = ("start", "end", "duration")
        for need_type, perf in need_type_performance.items():
            text_perf = perf["text"]
            multimodal_perf = perf["multimodal"]

            text_rate = text_perf["matched"] / text_perf["total"] if text_perf["total"] > 0 else 0.0
            multimodal_rate = (
                multimodal_perf["matched"] / multimodal_perf["total"]
                if multimodal_perf["total"] > 0 else 0.0
            )
            text_relevance_avg = _mean_or_zero(text_perf["relevance_scores"])
            multimodal_relevance_avg = _mean_or_zero(multimodal_perf["relevance_scores"])
            # Distinct names so the video-level timeliness values are not overwritten
            text_timeliness_avg_by_type = _mean_or_zero(text_perf["timeliness_scores"])
            multimodal_timeliness_avg_by_type = _mean_or_zero(multimodal_perf["timeliness_scores"])

            text_time_errors = text_perf["time_errors"]
            multimodal_time_errors = multimodal_perf["time_errors"]

            need_type_details[need_type] = {
                "text_recall": text_rate,
                "multimodal_recall": multimodal_rate,
                "recall_drop": text_rate - multimodal_rate,
                "text_relevance": text_relevance_avg,
                "multimodal_relevance": multimodal_relevance_avg,
                "relevance_drop": text_relevance_avg - multimodal_relevance_avg,
                "text_timeliness": text_timeliness_avg_by_type,
                "multimodal_timeliness": multimodal_timeliness_avg_by_type,
                "timeliness_drop": text_timeliness_avg_by_type - multimodal_timeliness_avg_by_type,
                "text_time_errors": {
                    f"{key}_mean": float(_mean_or_zero(text_time_errors[key]))
                    for key in error_keys
                },
                "multimodal_time_errors": {
                    f"{key}_mean": float(_mean_or_zero(multimodal_time_errors[key]))
                    for key in error_keys
                },
                # A difference is only reported when both sides have samples.
                "time_error_difference": {
                    f"{key}_diff": (
                        float(np.mean(multimodal_time_errors[key]) - np.mean(text_time_errors[key]))
                        if (multimodal_time_errors[key] and text_time_errors[key]) else 0.0
                    )
                    for key in error_keys
                },
                "total_queries": text_perf["total"]
            }

            improvement = multimodal_rate - text_rate
            scenario_data = {
                "model": text_model_id,
                "youtube_id": youtube_id,
                "need_type": need_type,
                "text_rate": text_rate,
                "multimodal_rate": multimodal_rate,
                "improvement": improvement
            }
            all_scenarios.append(scenario_data)

            if improvement > 0.1:
                enhancement_scenarios.append(scenario_data)
            elif improvement < -0.1:
                noise_scenarios.append({
                    "model": text_model_id,
                    "youtube_id": youtube_id,
                    "need_type": need_type,
                    "text_rate": text_rate,
                    "multimodal_rate": multimodal_rate,
                    "degradation": abs(improvement)
                })

        return {
            "text_recall": text_recall,
            "multimodal_recall": multimodal_recall,
            "recall_drop": text_recall - multimodal_recall,
            "timeliness_avg_difference": text_timeliness_avg - multimodal_timeliness_avg,
            "timeliness_start_difference": text_timeliness_start - multimodal_timeliness_start,
            "timeliness_end_difference": text_timeliness_end - multimodal_timeliness_end,
            "time_matching_difference": text_timeliness_avg - multimodal_timeliness_avg,  # Keep for backward compatibility
            "semantic_similarity_difference": avg_text_sim - avg_multimodal_sim,
            "enhancement_scenarios": enhancement_scenarios,
            "noise_scenarios": noise_scenarios,
            "all_scenarios": all_scenarios,  # Store all scenarios for statistics
            "youtube_id": youtube_id,
            "model": text_model_id,
            "need_type_performance": need_type_details,  # Detailed performance by need type for this video
            "text_timeliness": text_timeliness_avg,  # For video-level aggregation
            "multimodal_timeliness": multimodal_timeliness_avg,  # For video-level aggregation
            "deep_timeliness_data": deep_timeliness_data  # Deep timeliness analysis data
        }
    except Exception as e:
        # Best-effort worker: a failing video is reported and skipped, not fatal.
        print(f"Error processing {youtube_id} for {text_model_id}: {e}")
        return None


def analyze_multimodal_performance(
    evaluation_dir: str,
    output_dir: str,
    text_model_ids: List[str],
    similarity_threshold: float = 0.55,
    fuzzy_sentence_interval: int = 1,
    num_workers: int = 4,
    use_gpu: bool = True,
    num_gpus: int = 4
) -> Dict[str, Any]:
    """
    Analyze multimodal vs text model performance.
    
    Args:
        evaluation_dir: Directory containing evaluation results
        output_dir: Output directory for analysis results
        text_model_ids: List of text model IDs
        similarity_threshold: Similarity threshold
        fuzzy_sentence_interval: Fuzzy time interval
        
    Returns:
        Dictionary with multimodal analysis results
    """
    os.makedirs(output_dir, exist_ok=True)
    
    # Load video IDs
    youtube_ids = load_jsonl(os.path.join(METAINFO_DIR, "lecture_.jsonl")) + \
                  load_jsonl(os.path.join(METAINFO_DIR, "paper_.jsonl"))
    youtube_ids_set = set([item["youtube_id"] for item in youtube_ids])
    
    # Pre-load all queries_to_match to memory (optimization: reduce file I/O)
    print("Pre-loading queries_to_match to memory...")
    queries_cache = {}
    for item in youtube_ids:
        youtube_id = item["youtube_id"]
        ground_truth_file = os.path.join(parent_dir, f"output/{youtube_id}/jir_references_relevance_score.jsonl")
        if os.path.exists(ground_truth_file):
            try:
                queries_cache[youtube_id] = load_jsonl(ground_truth_file)
            except Exception as e:
                pass

    tasks = []
    for text_model_id in text_model_ids:
        multimodal_model_id = f"{text_model_id}_multimodal"
        text_dir = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{text_model_id}")
        multimodal_dir = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{multimodal_model_id}")
        
        if not os.path.exists(text_dir) or not os.path.exists(multimodal_dir):
            print(f"  Skipping {text_model_id}: directories not found")
            continue
        
        video_count = 0
        for filename in os.listdir(text_dir):
            if not filename.endswith(".json"):
                continue
            
            youtube_id = filename[:-len(".json")]
            if youtube_id not in youtube_ids_set:
                continue
            
            tasks.append((text_model_id, youtube_id, filename, youtube_ids_set, 
                         similarity_threshold, fuzzy_sentence_interval, None, queries_cache))
            video_count += 1
        
        print(f"  {text_model_id}: {video_count} videos")
    
    # Assign GPU IDs
    if use_gpu and torch.cuda.is_available():
        tasks_with_gpu = []
        for i, task in enumerate(tasks):
            gpu_id = i % num_gpus
            # task structure: (text_model_id, youtube_id, filename, youtube_ids_set, similarity_threshold, fuzzy_sentence_interval, None, queries_cache)
            # Replace None (6th element) with gpu_id, keep queries_cache (7th element)
            task_list = list(task)
            task_list[6] = gpu_id  # Replace None with gpu_id
            tasks_with_gpu.append(tuple(task_list))
        tasks = tasks_with_gpu
    
    # Collect results
    recall_drop_analysis = {
        "text_models_recall": [],
        "multimodal_models_recall": [],
        "recall_drop": [],
        "timeliness_avg_difference": [],
        "timeliness_start_difference": [],
        "timeliness_end_difference": [],
        "time_matching_difference": [],  # Keep for backward compatibility
        "semantic_similarity_difference": []
    }
    enhancement_scenarios = []
    noise_scenarios = []
    all_scenarios = []  # Store all scenarios for statistics
    video_performances = []  # Store per-video performance for aggregation
    need_type_aggregated = defaultdict(lambda: {
        "text_recall": [],
        "multimodal_recall": [],
        "text_relevance": [],
        "multimodal_relevance": [],
        "text_timeliness": [],
        "multimodal_timeliness": [],
        "text_time_errors": {"start": [], "end": [], "duration": []},
        "multimodal_time_errors": {"start": [], "end": [], "duration": []},
        "total_queries": 0
    })
    
    # Deep timeliness analysis: collect by duration, position
    timeliness_deep_analysis = {
        "by_duration": {
            "short": {"text_time_errors": {"start": [], "end": [], "duration": []}, "multimodal_time_errors": {"start": [], "end": [], "duration": []}},
            "medium": {"text_time_errors": {"start": [], "end": [], "duration": []}, "multimodal_time_errors": {"start": [], "end": [], "duration": []}},
            "long": {"text_time_errors": {"start": [], "end": [], "duration": []}, "multimodal_time_errors": {"start": [], "end": [], "duration": []}}
        },
        "by_position": {
            "early": {"text_time_errors": {"start": [], "end": [], "duration": []}, "multimodal_time_errors": {"start": [], "end": [], "duration": []}},
            "middle": {"text_time_errors": {"start": [], "end": [], "duration": []}, "multimodal_time_errors": {"start": [], "end": [], "duration": []}},
            "late": {"text_time_errors": {"start": [], "end": [], "duration": []}, "multimodal_time_errors": {"start": [], "end": [], "duration": []}}
        }
    }
    
    # Process in parallel
    print(f"\nProcessing {len(tasks)} videos with {num_workers} workers...")
    print(f"Using {num_gpus} GPUs for parallel processing")
    print(f"Note: Both text and multimodal models use cache. First run will compute and cache multimodal results.")
    error_count = 0
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(process_single_video_multimodal, task): task for task in tasks}
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing videos", mininterval=2.0, file=sys.stdout):
            try:
                result = future.result(timeout=300)  # 5 minute timeout per task
            except Exception as e:
                error_count += 1
                if error_count <= 5:  # Only print first 5 errors
                    print(f"\n  Error in task: {e}", file=sys.stderr, flush=True)
                result = None
            
            if result is None:
                continue
            
            recall_drop_analysis["text_models_recall"].append(result["text_recall"])
            recall_drop_analysis["multimodal_models_recall"].append(result["multimodal_recall"])
            recall_drop_analysis["recall_drop"].append(result["recall_drop"])
            recall_drop_analysis["timeliness_avg_difference"].append(result["timeliness_avg_difference"])
            recall_drop_analysis["timeliness_start_difference"].append(result["timeliness_start_difference"])
            recall_drop_analysis["timeliness_end_difference"].append(result["timeliness_end_difference"])
            recall_drop_analysis["time_matching_difference"].append(result["time_matching_difference"])
            recall_drop_analysis["semantic_similarity_difference"].append(result["semantic_similarity_difference"])
            enhancement_scenarios.extend(result["enhancement_scenarios"])
            noise_scenarios.extend(result["noise_scenarios"])
            if "all_scenarios" in result:
                all_scenarios.extend(result["all_scenarios"])
            
            # Collect video-level performance
            if "youtube_id" in result and "need_type_performance" in result:
                video_performances.append({
                    "youtube_id": result["youtube_id"],
                    "model": result.get("model", ""),
                    "text_recall": result["text_recall"],
                    "multimodal_recall": result["multimodal_recall"],
                    "recall_drop": result["recall_drop"],
                    "text_timeliness": result.get("text_timeliness", 0.0),
                    "multimodal_timeliness": result.get("multimodal_timeliness", 0.0),
                    "timeliness_drop": result.get("timeliness_avg_difference", 0.0),
                    "need_type_performance": result["need_type_performance"]
                })
                
                # Aggregate by need type across all videos
                for need_type, perf in result["need_type_performance"].items():
                    need_type_aggregated[need_type]["text_recall"].append(perf["text_recall"])
                    need_type_aggregated[need_type]["multimodal_recall"].append(perf["multimodal_recall"])
                    need_type_aggregated[need_type]["text_relevance"].append(perf["text_relevance"])
                    need_type_aggregated[need_type]["multimodal_relevance"].append(perf["multimodal_relevance"])
                    need_type_aggregated[need_type]["text_timeliness"].append(perf["text_timeliness"])
                    need_type_aggregated[need_type]["multimodal_timeliness"].append(perf["multimodal_timeliness"])
                    need_type_aggregated[need_type]["total_queries"] += perf["total_queries"]
                    
                    # Aggregate time errors
                    if "text_time_errors" in perf and perf["text_time_errors"]["start_mean"] > 0:
                        need_type_aggregated[need_type]["text_time_errors"]["start"].append(perf["text_time_errors"]["start_mean"])
                        need_type_aggregated[need_type]["text_time_errors"]["end"].append(perf["text_time_errors"]["end_mean"])
                        need_type_aggregated[need_type]["text_time_errors"]["duration"].append(perf["text_time_errors"]["duration_mean"])
                    if "multimodal_time_errors" in perf and perf["multimodal_time_errors"]["start_mean"] > 0:
                        need_type_aggregated[need_type]["multimodal_time_errors"]["start"].append(perf["multimodal_time_errors"]["start_mean"])
                        need_type_aggregated[need_type]["multimodal_time_errors"]["end"].append(perf["multimodal_time_errors"]["end_mean"])
                        need_type_aggregated[need_type]["multimodal_time_errors"]["duration"].append(perf["multimodal_time_errors"]["duration_mean"])
                
                # Aggregate deep timeliness analysis data
                if "deep_timeliness_data" in result:
                    deep_data = result["deep_timeliness_data"]
                    # By duration
                    for duration_cat in ["short", "medium", "long"]:
                        for model_type in ["text", "multimodal"]:
                            for error_type in ["start", "end", "duration"]:
                                timeliness_deep_analysis["by_duration"][duration_cat][f"{model_type}_time_errors"][error_type].extend(
                                    deep_data["by_duration"][duration_cat][model_type][error_type]
                                )
                    # By position
                    for position in ["early", "middle", "late"]:
                        for model_type in ["text", "multimodal"]:
                            for error_type in ["start", "end", "duration"]:
                                timeliness_deep_analysis["by_position"][position][f"{model_type}_time_errors"][error_type].extend(
                                    deep_data["by_position"][position][model_type][error_type]
                                )
    
    # Calculate overall statistics
    avg_text_recall = np.mean(recall_drop_analysis["text_models_recall"]) if recall_drop_analysis["text_models_recall"] else 0.0
    avg_multimodal_recall = np.mean(recall_drop_analysis["multimodal_models_recall"]) if recall_drop_analysis["multimodal_models_recall"] else 0.0
    avg_recall_drop = avg_text_recall - avg_multimodal_recall
    
    avg_timeliness_avg_diff = np.mean(recall_drop_analysis["timeliness_avg_difference"]) if recall_drop_analysis["timeliness_avg_difference"] else 0.0
    avg_timeliness_start_diff = np.mean(recall_drop_analysis["timeliness_start_difference"]) if recall_drop_analysis["timeliness_start_difference"] else 0.0
    avg_timeliness_end_diff = np.mean(recall_drop_analysis["timeliness_end_difference"]) if recall_drop_analysis["timeliness_end_difference"] else 0.0
    avg_time_diff = avg_timeliness_avg_diff  # Keep for backward compatibility
    avg_sim_diff = np.mean(recall_drop_analysis["semantic_similarity_difference"]) if recall_drop_analysis["semantic_similarity_difference"] else 0.0
    
    # Analyze recall drop reasons
    # Calculate relative contribution of time matching vs semantic matching differences
    # Note: This shows the relative magnitude of differences, not direct causation of recall drop
    total_diff = abs(avg_time_diff) + abs(avg_sim_diff)
    if total_diff > 0:
        time_contribution = abs(avg_time_diff) / total_diff
        semantic_contribution = abs(avg_sim_diff) / total_diff
    else:
        time_contribution = 0.0
        semantic_contribution = 0.0
    
    # Compile results
    results = {
        "recall_drop_analysis": {
            "text_models_recall": float(avg_text_recall),
            "multimodal_models_recall": float(avg_multimodal_recall),
            "recall_drop": float(avg_recall_drop),
            "drop_reason": {
                "time_matching": {
                    "contribution": float(time_contribution),
                    "explanation": "Time matching accuracy difference between text and multimodal models"
                },
                "semantic_matching": {
                    "contribution": float(semantic_contribution),
                    "explanation": "Semantic similarity difference between text and multimodal models"
                }
            },
            "time_matching_difference": float(avg_time_diff),
            "semantic_similarity_difference": float(avg_sim_diff)
        },
        "timeliness_drop_analysis": {
            "timeliness_avg_difference": float(avg_timeliness_avg_diff),
            "timeliness_start_difference": float(avg_timeliness_start_diff),
            "timeliness_end_difference": float(avg_timeliness_end_diff),
            "explanation": "Multimodal models show lower timeliness scores across all metrics (avg, start, end)"
        },
        "visual_enhancement_scenarios": enhancement_scenarios[:50],  # Top 50
        "noise_introduction_scenarios": noise_scenarios[:50],  # Top 50
        "enhancement_by_type": aggregate_by_type(enhancement_scenarios),
        "noise_by_type": aggregate_by_type(noise_scenarios),
        "all_scenarios": all_scenarios,  # Store all scenarios for statistics
        "video_performances": video_performances,  # Per-video performance
        "need_type_aggregated": {k: {
            "text_recall": float(np.mean(v["text_recall"])) if v["text_recall"] else 0.0,
            "multimodal_recall": float(np.mean(v["multimodal_recall"])) if v["multimodal_recall"] else 0.0,
            "recall_drop": float(np.mean(v["text_recall"]) - np.mean(v["multimodal_recall"])) if v["text_recall"] and v["multimodal_recall"] else 0.0,
            "text_relevance": float(np.mean(v["text_relevance"])) if v["text_relevance"] else 0.0,
            "multimodal_relevance": float(np.mean(v["multimodal_relevance"])) if v["multimodal_relevance"] else 0.0,
            "relevance_drop": float(np.mean(v["text_relevance"]) - np.mean(v["multimodal_relevance"])) if v["text_relevance"] and v["multimodal_relevance"] else 0.0,
            "text_timeliness": float(np.mean(v["text_timeliness"])) if v["text_timeliness"] else 0.0,
            "multimodal_timeliness": float(np.mean(v["multimodal_timeliness"])) if v["multimodal_timeliness"] else 0.0,
            "timeliness_drop": float(np.mean(v["text_timeliness"]) - np.mean(v["multimodal_timeliness"])) if v["text_timeliness"] and v["multimodal_timeliness"] else 0.0,
            "text_time_errors": {
                "start_mean": float(np.mean(v["text_time_errors"]["start"])) if v["text_time_errors"]["start"] else 0.0,
                "end_mean": float(np.mean(v["text_time_errors"]["end"])) if v["text_time_errors"]["end"] else 0.0,
                "duration_mean": float(np.mean(v["text_time_errors"]["duration"])) if v["text_time_errors"]["duration"] else 0.0
            },
            "multimodal_time_errors": {
                "start_mean": float(np.mean(v["multimodal_time_errors"]["start"])) if v["multimodal_time_errors"]["start"] else 0.0,
                "end_mean": float(np.mean(v["multimodal_time_errors"]["end"])) if v["multimodal_time_errors"]["end"] else 0.0,
                "duration_mean": float(np.mean(v["multimodal_time_errors"]["duration"])) if v["multimodal_time_errors"]["duration"] else 0.0
            },
            "time_error_difference": {
                "start_diff": float(np.mean(v["multimodal_time_errors"]["start"]) - np.mean(v["text_time_errors"]["start"])) if (v["multimodal_time_errors"]["start"] and v["text_time_errors"]["start"]) else 0.0,
                "end_diff": float(np.mean(v["multimodal_time_errors"]["end"]) - np.mean(v["text_time_errors"]["end"])) if (v["multimodal_time_errors"]["end"] and v["text_time_errors"]["end"]) else 0.0,
                "duration_diff": float(np.mean(v["multimodal_time_errors"]["duration"]) - np.mean(v["text_time_errors"]["duration"])) if (v["multimodal_time_errors"]["duration"] and v["text_time_errors"]["duration"]) else 0.0
            },
            "total_queries": v["total_queries"],
            "num_videos": len(v["text_recall"])
        } for k, v in need_type_aggregated.items()},
        "timeliness_deep_analysis": {
            "by_duration": {
                duration_cat: {
                    "text_time_errors": {
                        "start_mean": float(np.mean(timeliness_deep_analysis["by_duration"][duration_cat]["text_time_errors"]["start"])) if timeliness_deep_analysis["by_duration"][duration_cat]["text_time_errors"]["start"] else 0.0,
                        "end_mean": float(np.mean(timeliness_deep_analysis["by_duration"][duration_cat]["text_time_errors"]["end"])) if timeliness_deep_analysis["by_duration"][duration_cat]["text_time_errors"]["end"] else 0.0,
                        "duration_mean": float(np.mean(timeliness_deep_analysis["by_duration"][duration_cat]["text_time_errors"]["duration"])) if timeliness_deep_analysis["by_duration"][duration_cat]["text_time_errors"]["duration"] else 0.0
                    },
                    "multimodal_time_errors": {
                        "start_mean": float(np.mean(timeliness_deep_analysis["by_duration"][duration_cat]["multimodal_time_errors"]["start"])) if timeliness_deep_analysis["by_duration"][duration_cat]["multimodal_time_errors"]["start"] else 0.0,
                        "end_mean": float(np.mean(timeliness_deep_analysis["by_duration"][duration_cat]["multimodal_time_errors"]["end"])) if timeliness_deep_analysis["by_duration"][duration_cat]["multimodal_time_errors"]["end"] else 0.0,
                        "duration_mean": float(np.mean(timeliness_deep_analysis["by_duration"][duration_cat]["multimodal_time_errors"]["duration"])) if timeliness_deep_analysis["by_duration"][duration_cat]["multimodal_time_errors"]["duration"] else 0.0
                    }
                } for duration_cat in ["short", "medium", "long"]
            },
            "by_position": {
                position: {
                    "text_time_errors": {
                        "start_mean": float(np.mean(timeliness_deep_analysis["by_position"][position]["text_time_errors"]["start"])) if timeliness_deep_analysis["by_position"][position]["text_time_errors"]["start"] else 0.0,
                        "end_mean": float(np.mean(timeliness_deep_analysis["by_position"][position]["text_time_errors"]["end"])) if timeliness_deep_analysis["by_position"][position]["text_time_errors"]["end"] else 0.0,
                        "duration_mean": float(np.mean(timeliness_deep_analysis["by_position"][position]["text_time_errors"]["duration"])) if timeliness_deep_analysis["by_position"][position]["text_time_errors"]["duration"] else 0.0
                    },
                    "multimodal_time_errors": {
                        "start_mean": float(np.mean(timeliness_deep_analysis["by_position"][position]["multimodal_time_errors"]["start"])) if timeliness_deep_analysis["by_position"][position]["multimodal_time_errors"]["start"] else 0.0,
                        "end_mean": float(np.mean(timeliness_deep_analysis["by_position"][position]["multimodal_time_errors"]["end"])) if timeliness_deep_analysis["by_position"][position]["multimodal_time_errors"]["end"] else 0.0,
                        "duration_mean": float(np.mean(timeliness_deep_analysis["by_position"][position]["multimodal_time_errors"]["duration"])) if timeliness_deep_analysis["by_position"][position]["multimodal_time_errors"]["duration"] else 0.0
                    }
                } for position in ["early", "middle", "late"]
            }
        }
    }
    
    # Save results
    output_file = os.path.join(output_dir, "multimodal_analysis.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    
    
    
    return results


def aggregate_by_type(scenarios: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Summarize scenario records per need type.

    Scenarios are grouped by their ``"need_type"`` entry (``"Unknown"`` when
    the key is absent). Each group reports how many scenarios it contains and
    the mean of the ``"improvement"`` and ``"degradation"`` values present in
    that group; a mean over no values is reported as ``0.0``.

    Args:
        scenarios: Scenario dictionaries, each optionally carrying the keys
            ``"need_type"``, ``"improvement"`` and ``"degradation"``.

    Returns:
        A dict keyed by need type (in order of first appearance) whose values
        are dicts with ``"count"``, ``"avg_improvement"`` and
        ``"avg_degradation"``.
    """
    def _mean_or_zero(values):
        # Guard the empty case: np.mean([]) would produce NaN plus a warning.
        return float(np.mean(values)) if values else 0.0

    # Parallel per-type accumulators; all three share the same key set.
    counts: Dict[str, int] = {}
    improvements: Dict[str, list] = {}
    degradations: Dict[str, list] = {}

    for entry in scenarios:
        label = entry.get("need_type", "Unknown")
        if label not in counts:
            counts[label] = 0
            improvements[label] = []
            degradations[label] = []

        counts[label] += 1
        # Only scenarios that actually carry a key contribute to its mean.
        if "improvement" in entry:
            improvements[label].append(entry["improvement"])
        if "degradation" in entry:
            degradations[label].append(entry["degradation"])

    return {
        label: {
            "count": total,
            "avg_improvement": _mean_or_zero(improvements[label]),
            "avg_degradation": _mean_or_zero(degradations[label]),
        }
        for label, total in counts.items()
    }


if __name__ == "__main__":
    # Script entry point: run the multimodal analysis over every
    # text/multimodal model pair and print a short recall summary.

    # Input and output locations; both can be overridden via the environment.
    evaluation_dir = os.getenv("EVALUATION_DIR", "../evaluation_output")
    output_dir = os.getenv("OUTPUT_DIR", "../iclr_rebuttal/multimodal_analysis")

    # Build the model pairs and keep the first member of each pair as the
    # text model id handed to the analysis.
    text_model_pairs = get_text_model_pairs(get_all_model_ids())
    text_model_ids = [entry[0] for entry in text_model_pairs]

    print(f"Found {len(text_model_ids)} text-multimodal pairs: {text_model_pairs}")

    print("Starting multimodal analysis...")
    # Query CUDA once and reuse the answer for both parts of the message.
    cuda_ready = torch.cuda.is_available()
    gpu_total = torch.cuda.device_count() if cuda_ready else 0
    print(f"Using GPU: {cuda_ready}, GPUs available: {gpu_total}")

    run_options = {
        "evaluation_dir": evaluation_dir,
        "output_dir": output_dir,
        "text_model_ids": text_model_ids,
        "similarity_threshold": 0.55,
        "fuzzy_sentence_interval": 1,
        # Optimized: reduced from 32 to 20 to reduce I/O competition
        "num_workers": 20,
        "use_gpu": True,
        "num_gpus": 4,
    }
    results = analyze_multimodal_performance(**run_options)

    print("\nMultimodal analysis complete!")
    print(f"Results saved to: {output_dir}")
    print("\nSummary:")
    recall_summary = results['recall_drop_analysis']
    summary_rows = (
        ("Text models recall", "text_models_recall"),
        ("Multimodal models recall", "multimodal_models_recall"),
        ("Recall drop", "recall_drop"),
    )
    for label, key in summary_rows:
        print(f"  {label}: {recall_summary[key]:.3f}")

