
import sys
import os
import json
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from collections import defaultdict
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch
import threading

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from utils import load_jsonl, METAINFO_DIR
from evaluate_utils import get_duration_by_youtube_id
from evaluate import JIREvaluator
from model_config import get_all_model_ids
from similarity_cache import get_or_compute_matches
from need_type_utils import get_need_type_label, normalize_need_type_label
import csv

# Defaults that the __main__ entry point forwards to analyze_contextual_impact.
SIMILARITY_THRESHOLD = 0.55
FUZZY_SENTENCE_INTERVAL = 1

# Held around every JIREvaluator construction in this file, so two
# constructions never run at the same time.
_model_init_lock = threading.Lock()

# Evaluators keyed by device string ('cpu' or 'cuda:<index>').
_evaluator_pool = {}
# Held by get_evaluator for the duration of its pool lookup.
_evaluator_pool_lock = threading.Lock()

def initialize_evaluator_pool(num_gpus: int, similarity_threshold: float = 0.55):
    """Populate the module-level evaluator pool with one evaluator per device.

    When CUDA is unavailable a single CPU evaluator is (re)built and stored
    under ``'cpu'``. Otherwise one evaluator is built for each of
    ``cuda:0 .. cuda:<num_gpus - 1>``; a device whose evaluator cannot be
    built is reported on stderr and skipped. A ``'cpu'`` entry is always
    present afterwards so that ``get_evaluator`` has a fallback.

    Args:
        num_gpus: Number of CUDA devices to create evaluators for.
        similarity_threshold: Threshold passed to every ``JIREvaluator``.
    """
    def _register(device: str) -> None:
        # Constructions are serialised through the shared init lock.
        with _model_init_lock:
            _evaluator_pool[device] = JIREvaluator(
                similarity_threshold=similarity_threshold, device=device
            )

    if not torch.cuda.is_available():
        _register('cpu')
        return

    for gpu_id in range(num_gpus):
        device = f'cuda:{gpu_id}'
        try:
            _register(device)
        except Exception as e:
            # Best effort: keep going with the remaining devices, but make the
            # failure visible instead of silently dropping the device.
            print(
                f"Warning: could not initialize evaluator on {device} "
                f"({type(e).__name__}: {e}); falling back to CPU for it.",
                file=sys.stderr,
            )

    # Requests for devices that failed above are served from this entry.
    if 'cpu' not in _evaluator_pool:
        _register('cpu')
    

def get_evaluator(device: str, similarity_threshold: float = 0.55):
    """Return the evaluator to use for ``device``.

    Lookup order: the pooled evaluator for ``device``, then the pooled
    ``'cpu'`` evaluator, and only if neither exists a newly constructed
    evaluator for ``device``. The newly constructed evaluator is stored in the
    pool so later calls reuse it instead of building another one.

    Args:
        device: Device string such as ``'cpu'`` or ``'cuda:0'``.
        similarity_threshold: Used only when a new evaluator has to be built;
            pooled evaluators are returned as they were created.

    Returns:
        A ``JIREvaluator`` instance.
    """
    with _evaluator_pool_lock:
        if device in _evaluator_pool:
            return _evaluator_pool[device]

        if 'cpu' in _evaluator_pool:
            return _evaluator_pool['cpu']

        # Pool is empty (initialize_evaluator_pool was not run): build once
        # and remember the result for subsequent callers.
        with _model_init_lock:
            evaluator = JIREvaluator(similarity_threshold=similarity_threshold, device=device)
        _evaluator_pool[device] = evaluator
        return evaluator

def get_temporal_position(start_time: float, duration: float) -> str:
    """Classify where ``start_time`` falls within a video of length ``duration``.

    Returns ``"unknown"`` for a non-positive duration. Otherwise the ratio
    ``start_time / duration`` is bucketed: below 0.33 is ``"early"``, below
    0.67 is ``"middle"``, anything else is ``"late"``.
    """
    if duration <= 0:
        return "unknown"

    ratio = start_time / duration
    # Upper bounds are exclusive and checked in ascending order.
    for upper_bound, label in ((0.33, "early"), (0.67, "middle")):
        if ratio < upper_bound:
            return label
    return "late"

def load_video_classifications(csv_file: str = "/home/key4/JIRArena-exp/data/metainfo/video_classess_patch.csv") -> Dict[str, Dict[str, str]]:
    """Read per-video classification labels from a CSV file.

    The CSV is expected to have a ``youtube_id`` column plus
    ``content_format``, ``content_focus`` and ``production_style`` columns.
    Rows without a ``youtube_id`` are skipped. A label whose column is absent,
    or whose cell is missing because the row is shorter than the header, is
    reported as ``'Unknown'``.

    Args:
        csv_file: Path of the CSV file to read.

    Returns:
        Mapping ``youtube_id -> {label_name: label_value}``; empty when
        ``csv_file`` does not exist.
    """
    label_fields = ('content_format', 'content_focus', 'production_style')
    classifications = {}
    if not os.path.exists(csv_file):
        return classifications

    # newline='' lets the csv module do its own newline handling, which is
    # required for quoted fields that contain line breaks.
    with open(csv_file, 'r', encoding='utf-8', newline='') as f:
        for row in csv.DictReader(f):
            youtube_id = row.get('youtube_id', '')
            if not youtube_id:
                continue
            # DictReader fills cells missing from short rows with None, so a
            # plain dict default would not apply; check for None explicitly.
            classifications[youtube_id] = {
                field: 'Unknown' if row.get(field) is None else row[field]
                for field in label_fields
            }
    return classifications

def process_single_video_contextual(args):
    """Build per-query contextual records for one (model, video) pair.

    Args:
        args: A 9-tuple ``(model_id, youtube_id, filename, video_metadata,
            evaluation_dir, youtube_ids_set, similarity_threshold,
            fuzzy_sentence_interval, gpu_id)``. ``video_metadata`` maps a
            youtube id to a dict that is read for its ``"duration"`` entry;
            ``gpu_id`` selects ``cuda:<gpu_id>`` when CUDA is available and is
            ignored otherwise.

    Returns:
        A dict with ``youtube_id``, ``model_id``, ``duration``,
        ``need_density`` (reference needs per minute) and ``query_results``
        (one record per reference query), or ``None`` when the video is not in
        ``youtube_ids_set``, has no metadata, any of its three input files is
        missing, or processing raised (the traceback is printed to stderr).
    """
    (model_id, youtube_id, filename, video_metadata, evaluation_dir,
     youtube_ids_set, similarity_threshold, fuzzy_sentence_interval, gpu_id) = args

    try:
        if youtube_id not in youtube_ids_set:
            return None

        evaluation_file = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{model_id}", filename)
        if not os.path.exists(evaluation_file):
            return None

        ground_truth_file = os.path.join(parent_dir, f"output/{youtube_id}/jir_references_relevance_score.jsonl")
        if not os.path.exists(ground_truth_file):
            return None

        candidate_file = os.path.join(parent_dir, f"data/baseline_stream_runs_{model_id}", filename)
        if not os.path.exists(candidate_file):
            return None

        # Checked before any file is parsed so unknown videos cost nothing.
        metadata = video_metadata.get(youtube_id)
        if metadata is None:
            return None

        # Context managers close the handles promptly; this function runs in
        # many worker threads, so leaked handles would add up.
        with open(evaluation_file, "r", encoding="utf-8") as f:
            evaluation_data = json.load(f)
        queries_to_match = load_jsonl(ground_truth_file)
        with open(candidate_file, "r", encoding="utf-8") as f:
            candidate_data = json.load(f)
        candidate_queries = candidate_data.get("needs", [])

        duration = metadata.get("duration", 0)
        num_needs = len(queries_to_match)
        # Needs per minute; 0.0 when the duration is unknown or non-positive.
        need_density = num_needs / (duration / 60) if duration > 0 else 0.0

        # NOTE(review): the classification CSV is re-read for every task;
        # consider loading it once in the caller if this becomes a bottleneck.
        video_class = load_video_classifications().get(youtube_id, {})
        content_format = video_class.get('content_format', 'Unknown')
        content_focus = video_class.get('content_focus', 'Unknown')
        production_style = video_class.get('production_style', 'Unknown')

        query_scores = evaluation_data.get("relevance", {}).get("query_scores", {})

        evaluator = None

        def get_evaluator_lazy():
            # Resolved on first use only, so the evaluator is never fetched
            # unless get_or_compute_matches actually calls this.
            nonlocal evaluator
            if evaluator is None:
                if torch.cuda.is_available() and gpu_id is not None:
                    device = f'cuda:{gpu_id}'
                else:
                    device = 'cpu'
                evaluator = get_evaluator(device, similarity_threshold)
            return evaluator

        match_results = get_or_compute_matches(
            model_id, youtube_id, queries_to_match, candidate_queries,
            get_evaluator_lazy, similarity_threshold, fuzzy_sentence_interval,
            use_cache=True
        )

        query_results = []
        for idx, query in enumerate(queries_to_match):
            # Scores are keyed by the stringified query index.
            query_relevance = query_scores.get(str(idx), 0.0)
            start_time = query.get("start_time", 0.0)
            is_matched = len(match_results.get(idx, [])) > 0

            query_results.append({
                "query_index": idx,
                "need_type": get_need_type_label(query, youtube_id=youtube_id),
                "content_format": content_format,
                "content_focus": content_focus,
                "production_style": production_style,
                "duration": duration,
                "need_density": need_density,
                "temporal_position": get_temporal_position(start_time, duration),
                "start_time": start_time,
                "relevance": query_relevance,
                "is_matched": is_matched,
                # "Vague": a candidate matched the query but retrieval
                # relevance stayed below 0.1.
                "is_vague": is_matched and query_relevance < 0.1,
                "question": query.get("question", ""),
                "need": query.get("need", "")
            })

        return {
            "youtube_id": youtube_id,
            "model_id": model_id,
            "duration": duration,
            "need_density": need_density,
            "query_results": query_results
        }
    except Exception:
        # Best effort per video: report which one failed and let the caller
        # continue with the remaining tasks.
        import traceback
        print(f"Failed to process model={model_id} video={youtube_id}:", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        return None

def _new_group_stats() -> Dict[str, Any]:
    """Return an empty accumulator for one context group."""
    return {"relevance_scores": [], "vague_count": 0, "matched_count": 0, "total_count": 0}


def _new_breakdown_stats() -> Dict[str, Any]:
    """Return an empty accumulator for one per-query-type breakdown cell."""
    return {"relevance_scores": [], "count": 0}


def _new_query_type_stats() -> Dict[str, Any]:
    """Return an empty accumulator for one query type, including its breakdowns."""
    stats = _new_group_stats()
    for axis in ("by_duration", "by_need_density", "by_temporal"):
        stats[axis] = defaultdict(_new_breakdown_stats)
    return stats


def _record_query(bucket: Dict[str, Any], relevance: float, is_matched: bool, is_vague: bool) -> None:
    """Add one query outcome to a group accumulator."""
    bucket["relevance_scores"].append(relevance)
    bucket["total_count"] += 1
    if is_matched:
        bucket["matched_count"] += 1
    if is_vague:
        bucket["vague_count"] += 1


def _record_breakdown(cell: Dict[str, Any], relevance: float) -> None:
    """Add one relevance score to a breakdown cell."""
    cell["relevance_scores"].append(relevance)
    cell["count"] += 1


def _duration_group(duration: float) -> str:
    """Bucket a duration in seconds: under 10 min short, under 30 min medium, else long."""
    minutes = duration / 60 if duration > 0 else 0
    if minutes < 10:
        return "short"
    if minutes < 30:
        return "medium"
    return "long"


def _need_density_group(need_density: float, low_cut: float, high_cut: float) -> str:
    """Bucket a need density against the two percentile cut points."""
    if need_density < low_cut:
        return "low"
    if need_density < high_cut:
        return "medium"
    return "high"


def _summarize_groups(groups: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Turn group accumulators into summary statistics; empty groups are dropped."""
    summary = {}
    for key, data in groups.items():
        scores = data["relevance_scores"]
        if not scores:
            continue
        non_zero_scores = [s for s in scores if s > 0]
        matched_count = data.get("matched_count", 0)
        total_count = data["total_count"]
        summary[key] = {
            "mean_relevance": float(np.mean(scores)),
            # Median over non-zero scores only; the overall median is kept
            # separately under "median_relevance_all".
            "median_relevance": float(np.median(non_zero_scores)) if non_zero_scores else 0.0,
            "median_relevance_all": float(np.median(scores)),
            "std_relevance": float(np.std(scores)),
            # Share of matched queries that are vague; the share of all
            # queries is kept under "vague_rate_all".
            "vague_rate": data["vague_count"] / matched_count if matched_count > 0 else 0.0,
            "vague_rate_all": data["vague_count"] / total_count if total_count > 0 else 0.0,
            "vague_count": data["vague_count"],
            "matched_count": matched_count,
            "total_count": total_count,
            "non_zero_relevance_count": len(non_zero_scores)
        }
    return summary


def _summarize_breakdown(cells: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Reduce breakdown cells to mean relevance and count."""
    return {
        name: {
            "mean_relevance": float(np.mean(cell["relevance_scores"])) if cell["relevance_scores"] else 0.0,
            "count": cell["count"]
        }
        for name, cell in cells.items()
    }


def _relevance_range(impact: Dict[str, Dict[str, Any]]) -> float:
    """Return max minus min of the groups' mean relevance (0.0 when empty)."""
    means = [stats["mean_relevance"] for stats in impact.values()]
    return max(means) - min(means) if means else 0.0


def analyze_contextual_impact(
    evaluation_dir: str,
    output_dir: str,
    model_ids: List[str],
    similarity_threshold: float = 0.55,
    fuzzy_sentence_interval: int = 1,
    num_workers: Optional[int] = None,
    use_gpu: bool = True,
    num_gpus: int = 4
) -> Dict[str, Any]:
    """Aggregate retrieval relevance by video/query context and write reports.

    Every ``<youtube_id>.json`` found under
    ``<evaluation_dir>/rebuttal_baseline_stream_runs_<model_id>`` whose id is
    listed in the metainfo files is processed by
    ``process_single_video_contextual`` in a thread pool. Per-query outcomes
    are grouped by content format, content focus, production style, duration,
    need density, temporal position and query (need) type.

    Writes ``contextual_analysis.json``, ``vague_queries.json`` and (through
    ``generate_markdown_report``) ``contextual_analysis_report.md`` into
    ``output_dir``.

    Args:
        evaluation_dir: Directory holding the per-model evaluation folders.
        output_dir: Destination directory; created if missing.
        model_ids: Model identifiers to include.
        similarity_threshold: Forwarded to the evaluators and match cache.
        fuzzy_sentence_interval: Forwarded to the match computation.
        num_workers: Thread count; derived from the task count and available
            hardware when ``None``.
        use_gpu: Use CUDA evaluators when CUDA is available.
        num_gpus: Number of CUDA devices to spread tasks over.

    Returns:
        The results dict that is also written to ``contextual_analysis.json``.
        Query-type summaries use the same statistics as every other group
        (``vague_rate`` is relative to matched queries, ``median_relevance``
        covers non-zero scores), plus ``by_duration``, ``by_need_density`` and
        ``by_temporal`` breakdowns.
    """
    os.makedirs(output_dir, exist_ok=True)

    videos = load_jsonl(os.path.join(METAINFO_DIR, "lecture_.jsonl")) + \
             load_jsonl(os.path.join(METAINFO_DIR, "paper_.jsonl"))
    video_metadata = {item["youtube_id"]: item for item in videos}
    youtube_ids_set = set(video_metadata)

    # num_gpus <= 0 cannot be used for round-robin assignment, so it is
    # treated like "no GPU".
    gpu_enabled = use_gpu and torch.cuda.is_available() and num_gpus > 0

    tasks = []
    for model_id in model_ids:
        model_dir = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{model_id}")
        if not os.path.isdir(model_dir):
            continue

        for filename in os.listdir(model_dir):
            if not filename.endswith(".json"):
                continue

            youtube_id = filename[:-len(".json")]
            if youtube_id not in youtube_ids_set:
                continue

            # Tasks are spread over the GPUs round-robin.
            gpu_id = len(tasks) % num_gpus if gpu_enabled else None
            tasks.append((model_id, youtube_id, filename, video_metadata, evaluation_dir,
                          youtube_ids_set, similarity_threshold, fuzzy_sentence_interval, gpu_id))

    if num_workers is None:
        # os.cpu_count() may return None, and ThreadPoolExecutor rejects
        # max_workers <= 0, so clamp to at least one worker.
        worker_cap = num_gpus * 4 if gpu_enabled else (os.cpu_count() or 1)
        num_workers = max(1, min(len(tasks), worker_cap))

    initialize_evaluator_pool(num_gpus if gpu_enabled else 0, similarity_threshold)

    content_format_stats = defaultdict(_new_group_stats)
    content_focus_stats = defaultdict(_new_group_stats)
    production_style_stats = defaultdict(_new_group_stats)
    duration_stats = defaultdict(_new_group_stats)
    temporal_stats = defaultdict(_new_group_stats)
    need_density_stats = defaultdict(_new_group_stats)
    query_type_stats = defaultdict(_new_query_type_stats)

    # Density groups depend on percentiles over all videos, so per-query
    # records are kept here and grouped after the pool has finished.
    density_records = []
    # Distinct positive video-level densities; a set makes the de-duplication
    # O(1) per video.
    unique_need_densities = set()
    vague_queries = []

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(process_single_video_contextual, task) for task in tasks]
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing videos"):
            result = future.result()
            if result is None:
                continue

            need_density = result.get("need_density", 0.0)
            duration_group = _duration_group(result.get("duration", 0))

            if need_density > 0:
                unique_need_densities.add(need_density)

            for query_result in result["query_results"]:
                relevance = query_result["relevance"]
                need_type = query_result["need_type"]
                temporal_position = query_result["temporal_position"]
                is_vague = query_result["is_vague"]
                is_matched = query_result.get("is_matched", relevance > 0)

                type_stats = query_type_stats[need_type]
                buckets = (
                    content_format_stats[query_result.get("content_format", "Unknown")],
                    content_focus_stats[query_result.get("content_focus", "Unknown")],
                    production_style_stats[query_result.get("production_style", "Unknown")],
                    duration_stats[duration_group],
                    temporal_stats[temporal_position],
                    type_stats,
                )
                for bucket in buckets:
                    _record_query(bucket, relevance, is_matched, is_vague)
                if is_vague:
                    vague_queries.append(query_result)

                _record_breakdown(type_stats["by_duration"][duration_group], relevance)
                _record_breakdown(type_stats["by_temporal"][temporal_position], relevance)
                density_records.append((need_type, relevance, need_density, is_matched, is_vague))

    low_cut = None
    high_cut = None
    if unique_need_densities:
        density_values = sorted(unique_need_densities)
        # Cut points are the 20th and 80th percentiles; they are exported
        # under the historical keys "p33"/"p67" further down.
        low_cut = float(np.percentile(density_values, 20))
        high_cut = float(np.percentile(density_values, 80))

        for need_type, relevance, density, is_matched, is_vague in density_records:
            group = _need_density_group(density, low_cut, high_cut)
            _record_query(need_density_stats[group], relevance, is_matched, is_vague)
            _record_breakdown(query_type_stats[need_type]["by_need_density"][group], relevance)

    content_format_impact = _summarize_groups(content_format_stats)
    content_focus_impact = _summarize_groups(content_focus_stats)
    production_style_impact = _summarize_groups(production_style_stats)
    duration_impact = _summarize_groups(duration_stats)
    need_density_impact = _summarize_groups(need_density_stats)
    temporal_impact = _summarize_groups(temporal_stats)

    # Query types share the group summary (so matched_count and the
    # matched-based rates are present) and add their breakdowns on top.
    query_type_impact = _summarize_groups(query_type_stats)
    for need_type, summary in query_type_impact.items():
        stats = query_type_stats[need_type]
        summary["by_duration"] = _summarize_breakdown(stats["by_duration"])
        summary["by_need_density"] = _summarize_breakdown(stats["by_need_density"])
        summary["by_temporal"] = _summarize_breakdown(stats["by_temporal"])

    factor_impacts = [
        ("content_focus", content_focus_impact),
        ("content_format", content_format_impact),
        ("production_style", production_style_impact),
        ("duration", duration_impact),
        ("temporal_position", temporal_impact),
        ("need_density", need_density_impact),
    ]
    # A factor matters more the wider the spread of mean relevance across
    # its groups.
    context_importance_ranking = sorted(
        (
            {"factor": factor, "range": _relevance_range(impact), "impact": impact}
            for factor, impact in factor_impacts
        ),
        key=lambda entry: entry["range"],
        reverse=True
    )

    results = {
        "content_format_impact": content_format_impact,
        "content_focus_impact": content_focus_impact,
        "production_style_impact": production_style_impact,
        "duration_impact": duration_impact,
        "need_density_impact": need_density_impact,
        "temporal_impact": temporal_impact,
        "query_type_impact": query_type_impact,
        "context_importance_ranking": context_importance_ranking,
        "vague_queries_sample": vague_queries[:100],
        "need_density_percentiles": {
            "p33": low_cut,
            "p67": high_cut
        } if unique_need_densities else {}
    }

    output_file = os.path.join(output_dir, "contextual_analysis.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    vague_queries_file = os.path.join(output_dir, "vague_queries.json")
    with open(vague_queries_file, "w", encoding="utf-8") as f:
        json.dump(vague_queries[:500], f, indent=2, ensure_ascii=False)

    generate_markdown_report(results, output_dir)

    return results

_IMPACT_TABLE_COLUMNS = (
    "Mean Relevance",
    "Median Relevance (Matched)",
    "Std",
    "Vague Rate (of Matched)",
    "Matched Queries",
    "Total Queries",
)


def _ordered_items(groups, preferred_order):
    """Return ``groups`` items with ``preferred_order`` keys first, then the rest."""
    known = [(name, groups[name]) for name in preferred_order if name in groups]
    extra = [(name, stats) for name, stats in groups.items() if name not in preferred_order]
    return known + extra


def _write_impact_table(f, first_column, rows):
    """Write one Markdown impact table (header, rows, trailing blank line)."""
    headers = (first_column,) + _IMPACT_TABLE_COLUMNS
    f.write("| " + " | ".join(headers) + " |\n")
    f.write("|" + "|".join("-" * (len(header) + 2) for header in headers) + "|\n")
    for label, stats in rows:
        f.write(
            f"| {label} | {stats['mean_relevance']:.4f} | {stats.get('median_relevance', 0.0):.4f} | "
            f"{stats['std_relevance']:.4f} | {stats.get('vague_rate', 0.0):.2%} | "
            f"{stats.get('matched_count', 0)} | {stats['total_count']} |\n"
        )
    f.write("\n")


def generate_markdown_report(results: Dict[str, Any], output_dir: str):
    """Render ``results`` as ``contextual_analysis_report.md`` in ``output_dir``.

    Args:
        results: Mapping with ``context_importance_ranking``,
            ``duration_impact``, ``need_density_impact``, ``temporal_impact``
            and ``query_type_impact``; ``content_focus_impact``,
            ``content_format_impact``, ``production_style_impact`` and
            ``need_density_percentiles`` are optional. Each impact entry must
            provide ``mean_relevance``, ``std_relevance`` and ``total_count``;
            ``median_relevance``, ``vague_rate`` and ``matched_count`` default
            to zero when absent.
        output_dir: Existing directory the report is written to.

    Duration, need-density and temporal rows are written in a fixed order so
    the report does not depend on the insertion order of ``results``.
    """
    report_file = os.path.join(output_dir, "contextual_analysis_report.md")

    def by_mean_relevance(groups):
        return sorted(groups.items(), key=lambda kv: kv[1]["mean_relevance"], reverse=True)

    with open(report_file, "w", encoding="utf-8") as f:
        f.write("# Contextual Analysis Report\n\n")
        f.write("## Overview\n\n")
        f.write("This analysis focuses on how contextual information affects retrieval quality (relevance).\n")
        f.write("The core finding is that queries without contextual information appear vague, leading to low retrieval quality.\n\n")

        f.write("## Context Importance Ranking\n\n")
        f.write("Context factors ranked by their impact on retrieval quality (relevance range):\n\n")
        f.write("| Rank | Context Factor | Relevance Range | Impact Description |\n")
        f.write("|------|----------------|-----------------|-------------------|\n")
        for rank, item in enumerate(results["context_importance_ranking"], 1):
            impact = item["impact"]
            if impact:
                best = max(impact.items(), key=lambda x: x[1]["mean_relevance"])
                worst = min(impact.items(), key=lambda x: x[1]["mean_relevance"])
                desc = f"Best: {best[0]} ({best[1]['mean_relevance']:.3f}), Worst: {worst[0]} ({worst[1]['mean_relevance']:.3f})"
            else:
                desc = "N/A"
            f.write(f"| {rank} | {item['factor']} | {item['range']:.4f} | {desc} |\n")
        f.write("\n")

        # Optional sections: emitted only when the corresponding data exists.
        for key, label in (
            ("content_focus_impact", "Content Focus"),
            ("content_format_impact", "Content Format"),
            ("production_style_impact", "Production Style"),
        ):
            groups = results.get(key)
            if groups:
                f.write(f"## {label} Impact on Retrieval Quality\n\n")
                _write_impact_table(f, label, by_mean_relevance(groups))

        f.write("## Duration Impact on Retrieval Quality\n\n")
        _write_impact_table(
            f, "Duration Group",
            _ordered_items(results["duration_impact"], ("short", "medium", "long")),
        )

        f.write("## Need Density Impact on Retrieval Quality\n\n")
        percentiles = results.get("need_density_percentiles")
        if percentiles:
            p33 = percentiles.get("p33", 0.0)
            p67 = percentiles.get("p67", 0.0)
            f.write(f"*Need density groups (queries per minute): Low (<{p33:.1f}), Medium ({p33:.1f}-{p67:.1f}), High (>{p67:.1f})*\n\n")
        _write_impact_table(
            f, "Density Group",
            _ordered_items(results["need_density_impact"], ("low", "medium", "high")),
        )

        f.write("## Temporal Position Impact on Retrieval Quality\n\n")
        _write_impact_table(
            f, "Temporal Position",
            _ordered_items(results["temporal_impact"], ("early", "middle", "late", "unknown")),
        )

        f.write("## Query Type Impact on Retrieval Quality\n\n")
        sorted_types = sorted(
            results["query_type_impact"].items(),
            key=lambda x: x[1].get("vague_rate", 0),
            reverse=True
        )
        _write_impact_table(f, "Query Type", sorted_types)

        f.write("## Query Types That Need Context Most\n\n")
        f.write("Query types with highest vague rate (low relevance despite being matched):\n\n")
        # Same ordering as the table above, limited to the five vaguest types.
        for need_type, stats in sorted_types[:5]:
            f.write(f"### {need_type}\n\n")
            f.write(f"- **Vague Rate**: {stats.get('vague_rate', 0.0):.2%}\n")
            f.write(f"- **Mean Relevance**: {stats['mean_relevance']:.4f}\n")
            f.write(f"- **Total Queries**: {stats['total_count']}\n")
            f.write(f"- **Vague Queries**: {stats.get('vague_count', 0)}\n\n")

            f.write("**Context Dependency:**\n\n")
            for key, title in (
                ("by_duration", "By Duration"),
                ("by_need_density", "By Need Density"),
                ("by_temporal", "By Temporal Position"),
            ):
                breakdown = stats.get(key)
                if breakdown:
                    f.write(f"- {title}:\n")
                    for group, data in breakdown.items():
                        f.write(f"  - {group}: {data['mean_relevance']:.4f} (n={data['count']})\n")
            f.write("\n")

        f.write("## Key Findings\n\n")
        f.write("1. **Context Importance**: The context factors are ranked by their impact on retrieval quality (relevance range).\n")
        f.write("2. **Vague Queries**: Queries that are matched but have low relevance scores (< 0.1) are considered vague.\n")
        f.write("   - Vague Rate is calculated as: vague_queries / matched_queries (not all queries)\n")
        f.write("   - This indicates the percentage of matched queries that have poor retrieval quality.\n")
        f.write("3. **Median Relevance**: Median relevance is calculated only for matched queries (non-zero relevance scores).\n")
        f.write("   - Most queries have relevance = 0 (unmatched), so overall median is 0.\n")
        f.write("   - Median relevance (matched) shows the typical relevance for queries that were successfully matched.\n")
        f.write("4. **Query Type Dependency**: Some query types (e.g., Missing Context, Ambiguous Language) need more contextual information.\n")
        f.write("5. **Need Density**: Need density groups are based on 20th and 80th percentiles of video-level need density (queries per minute).\n")
        f.write("6. **Contextualized Retrieval**: Future JIR systems should incorporate contextual background during retrieval.\n\n")

if __name__ == "__main__":
    # Model ids containing "oracle" (case-insensitive) are left out of the run.
    model_ids = [
        candidate
        for candidate in get_all_model_ids()
        if "oracle" not in candidate.lower()
    ]

    # num_workers=None lets analyze_contextual_impact choose the thread count.
    run_options = dict(
        evaluation_dir="/home/key4/JIRArena-exp/evaluation_output",
        output_dir="/home/key4/JIRArena-exp/iclr_rebuttal/contextual_analysis",
        model_ids=model_ids,
        similarity_threshold=SIMILARITY_THRESHOLD,
        fuzzy_sentence_interval=FUZZY_SENTENCE_INTERVAL,
        num_workers=None,
        use_gpu=True,
        num_gpus=4,
    )
    results = analyze_contextual_impact(**run_options)
    
