
import sys
import os
import json
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from collections import defaultdict
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch

# Absolute path of the directory one level above this file. It is put at the
# front of sys.path before the project-local imports below run (presumably so
# that utils, evaluate, etc. resolve from there -- confirm against the repo
# layout), and it is also the root for the "output/..." and "data/..." paths
# built later in this file.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from utils import load_jsonl, METAINFO_DIR
from evaluate_utils import get_duration_by_youtube_id
from evaluate import JIREvaluator
from similarity_cache import get_or_compute_matches
import csv
import threading

# Held around every JIREvaluator(...) constructor call in this file, so that
# evaluators are built one at a time.
_model_init_lock = threading.Lock()

# Evaluators keyed by device string ('cpu' or 'cuda:<gpu_id>').
_evaluator_pool = {}
# Held by get_evaluator_from_pool for the duration of a pool lookup.
_evaluator_pool_lock = threading.Lock()

def _ensure_cpu_evaluator(similarity_threshold: float) -> None:
    """Register a CPU evaluator in the pool unless one is already present."""
    with _model_init_lock:
        if 'cpu' not in _evaluator_pool:
            _evaluator_pool['cpu'] = JIREvaluator(similarity_threshold=similarity_threshold, device='cpu')


def initialize_evaluator_pool(num_gpus: int, similarity_threshold: float = 0.55):
    """Fill the module-level evaluator pool with one evaluator per device.

    Without CUDA, a single evaluator is (re)created under the key ``'cpu'``.
    With CUDA, one evaluator is created for each ``cuda:<gpu_id>`` with
    ``gpu_id`` in ``range(num_gpus)``; a device whose evaluator cannot be
    constructed is skipped and reported on stderr. A ``'cpu'`` evaluator is
    always present afterwards so lookups for a missing device can fall back
    to it.

    Args:
        num_gpus: Number of CUDA devices to create evaluators for. ``0``
            yields a CPU-only pool.
        similarity_threshold: Forwarded to every ``JIREvaluator``.
    """
    if not torch.cuda.is_available():
        # CPU-only machine: (re)build the CPU evaluator unconditionally.
        with _model_init_lock:
            _evaluator_pool['cpu'] = JIREvaluator(similarity_threshold=similarity_threshold, device='cpu')
        return

    for gpu_id in range(num_gpus):
        device = f'cuda:{gpu_id}'
        try:
            with _model_init_lock:
                _evaluator_pool[device] = JIREvaluator(similarity_threshold=similarity_threshold, device=device)
        except Exception as exc:
            # Best effort: a broken or missing GPU must not abort the run, but
            # the failure should be visible instead of silently discarded.
            print(
                f"[initialize_evaluator_pool] could not create evaluator on {device}: {exc}; "
                "falling back to CPU",
                file=sys.stderr,
            )
            _ensure_cpu_evaluator(similarity_threshold)

    # Always keep a CPU evaluator available as the fallback target.
    _ensure_cpu_evaluator(similarity_threshold)
    

def get_evaluator_from_pool(device: str, similarity_threshold: float = 0.55):
    """Return the pooled evaluator for ``device``, falling back to CPU.

    Lookup order: the evaluator registered for ``device``, then the one
    registered for ``'cpu'``. If neither exists, a new evaluator is built for
    ``device`` and stored in the pool, so later calls reuse it instead of
    constructing another one.

    Args:
        device: Device string such as ``'cpu'`` or ``'cuda:0'``.
        similarity_threshold: Only used when a new evaluator has to be built.

    Returns:
        A ``JIREvaluator`` instance shared through the pool.
    """
    with _evaluator_pool_lock:
        evaluator = _evaluator_pool.get(device)
        if evaluator is None:
            evaluator = _evaluator_pool.get('cpu')
        if evaluator is None:
            with _model_init_lock:
                evaluator = JIREvaluator(similarity_threshold=similarity_threshold, device=device)
            # Cache it: the pool was never initialised for this device, and
            # rebuilding the evaluator on every call would be wasteful.
            _evaluator_pool[device] = evaluator
        return evaluator

def load_video_durations_from_csv(csv_file: str = "/home/key4/JIRArena-exp/data/metainfo/video_classess_patch.csv") -> Dict[str, float]:
    """Read per-video durations from a CSV file.

    The file must have a header row with ``youtube_id`` and ``duration_sec``
    columns. Rows with an empty id, an empty duration, or a duration that
    cannot be parsed as a float are skipped.

    Args:
        csv_file: Path of the CSV file to read.

    Returns:
        Mapping from ``youtube_id`` to its duration as a float. Empty when the
        file does not exist.
    """
    durations: Dict[str, float] = {}
    if not os.path.exists(csv_file):
        return durations

    # newline='' lets the csv module do its own newline handling, which is
    # required for quoted fields that contain line breaks.
    with open(csv_file, 'r', encoding='utf-8', newline='') as f:
        for row in csv.DictReader(f):
            # Short rows yield None for the missing columns; treat as empty.
            youtube_id = row.get('youtube_id') or ''
            duration_sec = row.get('duration_sec') or ''
            if not youtube_id or not duration_sec:
                continue
            try:
                durations[youtube_id] = float(duration_sec)
            except ValueError:
                continue
    return durations

def load_oracle_data(oracle_dir: str, youtube_id: str) -> Optional[List[Dict[str, Any]]]:
    """Load the oracle ``needs`` list for a video.

    Candidate locations are tried in order: ``<oracle_dir>/<id>.json``, a
    fixed evaluation-output directory, and
    ``<oracle_dir>/rebuttal_baseline_stream_runs_oracle/<id>.json``. The first
    existing file that parses decides the result; files that cannot be read
    or parsed are skipped.

    Args:
        oracle_dir: Directory to search.
        youtube_id: Video id; the file name is ``<youtube_id>.json``.

    Returns:
        The value stored under ``"needs"``, or ``None`` when no candidate file
        exists, none can be parsed, or the first parsable one has no
        ``"needs"`` key.
    """
    filename = f"{youtube_id}.json"
    possible_paths = [
        os.path.join(oracle_dir, filename),
        os.path.join("/home/key4/JIRArena-exp/evaluation_output/rebuttal_baseline_stream_runs_oracle", filename),
        os.path.join(oracle_dir, "rebuttal_baseline_stream_runs_oracle", filename),
    ]

    for oracle_file in possible_paths:
        if not os.path.exists(oracle_file):
            continue
        try:
            # Context manager so the handle is closed even if parsing fails.
            with open(oracle_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            if "needs" in data:
                return data.get("needs", [])
            # A readable file without "needs" ends the search.
            return None
        except (OSError, ValueError, TypeError, AttributeError):
            # Unreadable, malformed, or unexpectedly shaped: try the next path.
            continue

    return None

def process_single_video_oracle(args):
    """Compare oracle and baseline metrics for one (model, video) pair.

    Oracle metrics are computed here by matching the ground-truth queries
    against the oracle candidate queries (reusing cached match results when
    their sizes agree with the current inputs). Baseline metrics are read
    from the baseline evaluation file already on disk.

    Args:
        args: Tuple of ``(model_id, youtube_id, filename, youtube_ids_set,
            oracle_dir, similarity_threshold, fuzzy_sentence_interval, gpu_id,
            queries_cache, duration_cache)``. ``oracle_dir`` is accepted for
            interface compatibility but not used. ``queries_cache`` and
            ``duration_cache`` may be ``None``.

    Returns:
        Dict with ``oracle_*`` and ``baseline_*`` recall, precision,
        relevance and timeliness, plus ``relevance_improvement`` and
        ``duration``. ``None`` when a required file is missing or any step
        fails; failures are reported on stderr.
    """
    (model_id, youtube_id, filename, youtube_ids_set, oracle_dir,
     similarity_threshold, fuzzy_sentence_interval, gpu_id,
     queries_cache, duration_cache) = args

    try:
        if torch.cuda.is_available() and gpu_id is not None:
            device = f'cuda:{gpu_id}'
        else:
            device = 'cpu'

        if youtube_id not in youtube_ids_set:
            return None

        evaluation_dir = "/home/key4/JIRArena-exp/evaluation_output"
        baseline_file = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{model_id}", filename)
        if not os.path.exists(baseline_file):
            return None

        # Ground-truth queries: prefer the preloaded cache, else read from disk.
        if queries_cache is not None and youtube_id in queries_cache:
            queries_to_match = queries_cache[youtube_id]
        else:
            ground_truth_file = os.path.join(parent_dir, f"output/{youtube_id}/jir_references_relevance_score.jsonl")
            if not os.path.exists(ground_truth_file):
                return None
            queries_to_match = load_jsonl(ground_truth_file)

        # The oracle evaluation file, the oracle candidate file, and the
        # baseline candidate file must all exist.
        oracle_eval_file = os.path.join(evaluation_dir, "rebuttal_baseline_stream_runs_oracle", filename)
        if not os.path.exists(oracle_eval_file):
            return None
        oracle_candidate_file = os.path.join(parent_dir, "data/baseline_stream_runs_oracle", filename)
        if not os.path.exists(oracle_candidate_file):
            return None
        baseline_candidate_file = os.path.join(parent_dir, f"data/baseline_stream_runs_{model_id}", filename)
        if not os.path.exists(baseline_candidate_file):
            return None

        with open(oracle_candidate_file, "r", encoding="utf-8") as f:
            oracle_queries = json.load(f).get("needs", [])
        # The baseline candidates themselves are not needed: every baseline
        # metric comes from the evaluation file.
        with open(baseline_file, "r", encoding="utf-8") as f:
            baseline_data = json.load(f)

        from similarity_cache import get_cache_file_path, load_match_results
        cache_dir = "/home/key4/JIRArena-exp/evaluation_output/.similarity_cache"

        oracle_cache_file = get_cache_file_path(cache_dir, "oracle", youtube_id, similarity_threshold, fuzzy_sentence_interval)
        oracle_cached = load_match_results(oracle_cache_file)

        def get_evaluator():
            return get_evaluator_from_pool(device, similarity_threshold)

        # Trust the cached matches only when they were computed for inputs of
        # the same size as the current ones.
        cache_is_usable = (
            oracle_cached is not None
            and len(oracle_cached.get("queries_to_match", [])) == len(queries_to_match)
            and len(oracle_cached.get("candidate_queries", [])) == len(oracle_queries)
        )
        if cache_is_usable:
            oracle_results = oracle_cached["match_results"]
        else:
            oracle_results = get_or_compute_matches(
                "oracle", youtube_id, queries_to_match, oracle_queries,
                get_evaluator, similarity_threshold, fuzzy_sentence_interval,
                use_cache=True
            )

        evaluator = get_evaluator()

        oracle_recall = evaluator.compute_recall(queries_to_match, oracle_results)
        oracle_precision = evaluator.compute_precision(oracle_queries, oracle_results)
        oracle_relevance = evaluator.evaluate_relevance(oracle_results, queries_to_match)
        oracle_timeliness = evaluator.evaluate_timeliness(queries_to_match, oracle_results)

        baseline_recall = baseline_data.get("recall", {}).get("recall", 0.0)
        baseline_precision = baseline_data.get("precision", {}).get("precision", 0.0)
        baseline_relevance = baseline_data.get("relevance", {}).get("weighted_ndcg", 0.0)
        baseline_timeliness = baseline_data.get("timeliness", {}).get("weighted_time_match", 0.0)

        # Duration: CSV cache first, then the baseline evaluation file, then
        # the metadata lookup. `or 0` guards against an explicit null.
        if duration_cache is not None and youtube_id in duration_cache:
            duration = duration_cache[youtube_id]
        elif (baseline_data.get("duration") or 0) > 0:
            duration = baseline_data["duration"]
        else:
            duration = get_duration_by_youtube_id(youtube_id)

        oracle_relevance_score = oracle_relevance.get("weighted_ndcg", 0.0)

        return {
            "oracle_recall": oracle_recall.get("recall", 0.0),
            "oracle_precision": oracle_precision.get("precision", 0.0),
            "oracle_relevance": oracle_relevance_score,
            "oracle_timeliness": oracle_timeliness.get("weighted_time_match", 0.0),
            "baseline_recall": baseline_recall,
            "baseline_precision": baseline_precision,
            "baseline_relevance": baseline_relevance,
            "baseline_timeliness": baseline_timeliness,
            "relevance_improvement": oracle_relevance_score - baseline_relevance,
            "duration": duration
        }
    except Exception as exc:
        # Best effort per video: one bad video must not stop the batch, but
        # the reason it was dropped should be visible.
        print(
            f"[process_single_video_oracle] skipping {youtube_id} (model {model_id}): "
            f"{type(exc).__name__}: {exc}",
            file=sys.stderr,
        )
        return None

def _classify_bottleneck(oracle_score, baseline_score, thresholds):
    """Return ``(category, relative_improvement)`` for one video.

    ``relative_improvement`` is ``None`` unless the need-generation branch
    computed it (baseline below its threshold and strictly positive).
    """
    good_oracle = thresholds["good_oracle_threshold"]
    excellent_oracle = thresholds["excellent_oracle_threshold"]
    good_baseline = thresholds["good_baseline_threshold"]

    if oracle_score < good_oracle and baseline_score < thresholds["both_bottleneck_baseline_threshold"]:
        return "both_bottleneck", None
    if oracle_score >= excellent_oracle and baseline_score >= good_baseline:
        return "no_bottleneck", None
    if baseline_score < thresholds["need_gen_baseline_threshold"]:
        if baseline_score > 0:
            relative_improvement = (oracle_score - baseline_score) / baseline_score
            if (oracle_score >= good_oracle
                    or relative_improvement > thresholds["relative_improvement_threshold"]):
                return "need_generation_bottleneck", relative_improvement
            return "other_bottleneck", relative_improvement
        # Baseline of exactly zero: a relative improvement is undefined.
        if oracle_score >= good_oracle:
            return "need_generation_bottleneck", None
        return "other_bottleneck", None
    if baseline_score >= good_baseline and oracle_score < excellent_oracle:
        return "retrieval_bottleneck", None
    return "other_bottleneck", None


def _duration_weighted_stats(metric_values, durations):
    """Return per-metric weighted average, simple average and std."""
    weights = np.array(durations)
    # Fall back to the plain mean when there is no positive total duration.
    use_weights = len(durations) > 0 and weights.sum() > 0
    stats = {}
    for metric, raw_values in metric_values.items():
        values = np.array(raw_values)
        if len(values) == 0:
            continue
        weighted_avg = np.average(values, weights=weights) if use_weights else np.mean(values)
        stats[metric] = {
            "weighted_average": float(weighted_avg),
            "simple_average": float(np.mean(values)),
            "std": float(np.std(values))
        }
    return stats


def _relevance_distribution(scores):
    """Return mean/median/std/min/max of ``scores`` (all 0.0 when empty)."""
    if not scores:
        return {"mean": 0.0, "median": 0.0, "std": 0.0, "min": 0.0, "max": 0.0}
    return {
        "mean": float(np.mean(scores)),
        "median": float(np.median(scores)),
        "std": float(np.std(scores)),
        "min": float(np.min(scores)),
        "max": float(np.max(scores))
    }


def analyze_oracle_performance(
    evaluation_dir: str,
    oracle_dir: str,
    output_dir: str,
    model_ids: List[str],
    similarity_threshold: float = 0.55,
    fuzzy_sentence_interval: int = 1,
    num_workers: int = 20,
    use_gpu: bool = True,
    num_gpus: int = 4
) -> Dict[str, Any]:
    """Run the oracle study over all models and videos and write the reports.

    For every baseline evaluation file under
    ``<evaluation_dir>/rebuttal_baseline_stream_runs_<model_id>`` that also
    has an oracle evaluation file, a worker thread computes oracle and
    baseline metrics. The results are aggregated into duration-weighted
    averages, per-metric improvements, and a per-video bottleneck
    classification based on fixed relevance thresholds. The outcome is
    written to ``oracle_study.json`` and ``oracle_study_report.md`` in
    ``output_dir``.

    Args:
        evaluation_dir: Root directory holding the per-model evaluation files.
        oracle_dir: Forwarded to the workers (they currently do not use it).
        output_dir: Directory for the JSON and markdown outputs; created if
            missing.
        model_ids: Model identifiers to analyse.
        similarity_threshold: Matching threshold passed to the evaluators.
        fuzzy_sentence_interval: Passed through to the match computation.
        num_workers: Size of the worker thread pool.
        use_gpu: Use CUDA devices when available.
        num_gpus: Number of CUDA devices to spread tasks over; values <= 0
            fall back to CPU.

    Returns:
        The results dictionary that is also written to ``oracle_study.json``.
    """
    os.makedirs(output_dir, exist_ok=True)

    video_entries = load_jsonl(os.path.join(METAINFO_DIR, "lecture_.jsonl")) + \
                    load_jsonl(os.path.join(METAINFO_DIR, "paper_.jsonl"))
    youtube_ids = [item["youtube_id"] for item in video_entries]
    youtube_ids_set = set(youtube_ids)

    # Preload ground-truth queries once so workers do not re-read them for
    # every model.
    queries_cache = {}
    for youtube_id in youtube_ids:
        ground_truth_file = os.path.join(parent_dir, f"output/{youtube_id}/jir_references_relevance_score.jsonl")
        if not os.path.exists(ground_truth_file):
            continue
        try:
            queries_cache[youtube_id] = load_jsonl(ground_truth_file)
        except Exception as exc:
            # Not fatal: the worker retries the load for this video itself.
            print(
                f"[analyze_oracle_performance] could not preload queries for {youtube_id}: {exc}",
                file=sys.stderr,
            )

    duration_cache = load_video_durations_from_csv()

    # num_gpus <= 0 means CPU-only; this also keeps the round-robin modulo
    # below from dividing by zero.
    gpu_enabled = use_gpu and torch.cuda.is_available() and num_gpus > 0
    initialize_evaluator_pool(num_gpus if gpu_enabled else 0, similarity_threshold)

    oracle_eval_dir = os.path.join(evaluation_dir, "rebuttal_baseline_stream_runs_oracle")
    tasks = []
    for model_id in model_ids:
        model_dir = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{model_id}")
        if not os.path.exists(model_dir):
            continue

        # Sorted so that the GPU assignment is reproducible between runs.
        for filename in sorted(os.listdir(model_dir)):
            if not filename.endswith(".json"):
                continue

            youtube_id = filename[:-len(".json")]
            if youtube_id not in youtube_ids_set:
                continue
            if not os.path.exists(os.path.join(oracle_eval_dir, filename)):
                continue

            # Round-robin over the GPUs; None selects the CPU in the worker.
            gpu_id = len(tasks) % num_gpus if gpu_enabled else None
            tasks.append((model_id, youtube_id, filename, youtube_ids_set, oracle_dir,
                          similarity_threshold, fuzzy_sentence_interval, gpu_id,
                          queries_cache, duration_cache))

    # (youtube_id, result) for every task that produced a result.
    records = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(process_single_video_oracle, task): task for task in tasks}
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing videos", mininterval=2.0):
            task = futures[future]
            try:
                # as_completed only yields finished futures, so this returns
                # immediately.
                result = future.result()
            except Exception as exc:
                tqdm.write(
                    f"[analyze_oracle_performance] worker failed for {task[1]} (model {task[0]}): {exc}",
                    file=sys.stderr,
                )
                continue
            if result is None:
                continue
            records.append((task[1], result))

    metric_names = ["recall", "precision", "relevance", "timeliness"]
    durations = [result.get("duration", 0) for _, result in records]
    all_youtube_ids = [youtube_id for youtube_id, _ in records]
    all_oracle_relevances = [result["oracle_relevance"] for _, result in records]
    all_baseline_relevances = [result["baseline_relevance"] for _, result in records]

    thresholds = {
        "bad_relevance_threshold": 0.1,
        "need_gen_baseline_threshold": 0.2,
        "both_bottleneck_baseline_threshold": 0.2,
        "good_baseline_threshold": 0.6,
        "good_oracle_threshold": 0.6,
        "excellent_oracle_threshold": 0.75,
        "relative_improvement_threshold": 0.6,
    }

    categories = [
        "need_generation_bottleneck",
        "retrieval_bottleneck",
        "both_bottleneck",
        "no_bottleneck",
        "other_bottleneck",
    ]
    bottleneck_analysis = {category: 0 for category in categories}
    bottleneck_samples = {category: [] for category in categories}

    for youtube_id, oracle_score, baseline_score in zip(
            all_youtube_ids, all_oracle_relevances, all_baseline_relevances):
        category, relative_improvement = _classify_bottleneck(oracle_score, baseline_score, thresholds)
        sample = {
            "youtube_id": youtube_id,
            "oracle_relevance": float(oracle_score),
            "baseline_relevance": float(baseline_score)
        }
        if relative_improvement is not None:
            sample["relative_improvement"] = float(relative_improvement)
        bottleneck_analysis[category] += 1
        bottleneck_samples[category].append(sample)

    oracle_stats = _duration_weighted_stats(
        {metric: [result[f"oracle_{metric}"] for _, result in records] for metric in metric_names},
        durations,
    )
    baseline_stats = _duration_weighted_stats(
        {metric: [result[f"baseline_{metric}"] for _, result in records] for metric in metric_names},
        durations,
    )

    improvements = {}
    for metric in metric_names:
        oracle_avg = oracle_stats.get(metric, {}).get("weighted_average", 0.0)
        baseline_avg = baseline_stats.get(metric, {}).get("weighted_average", 0.0)
        improvement = oracle_avg - baseline_avg
        # A percentage is undefined for a zero baseline; report 0.0 then.
        improvement_pct = (improvement / baseline_avg * 100) if baseline_avg > 0 else 0.0

        improvements[metric] = {
            "improvement": float(improvement),
            "improvement_percentage": float(improvement_pct),
            "oracle_score": float(oracle_avg),
            "baseline_score": float(baseline_avg)
        }

    all_relevance_pairs = [
        {
            "youtube_id": youtube_id,
            "oracle_relevance": float(result["oracle_relevance"]),
            "baseline_relevance": float(result["baseline_relevance"]),
            "relevance_improvement": float(result["relevance_improvement"])
        }
        for youtube_id, result in records
    ]

    # With no analysed videos every count is zero, so there is no meaningful
    # "largest" category to report.
    if records:
        primary_bottleneck = max(bottleneck_analysis.items(), key=lambda x: x[1])[0]
    else:
        primary_bottleneck = "unknown"

    results = {
        "oracle_performance": oracle_stats,
        "baseline_performance": baseline_stats,
        "improvements": improvements,
        "bottleneck_analysis": bottleneck_analysis,
        "bottleneck_samples": bottleneck_samples,
        "all_relevance_pairs": all_relevance_pairs,
        "bottleneck_thresholds": {
            **thresholds,
            "oracle_relevance_stats": _relevance_distribution(all_oracle_relevances),
            "baseline_relevance_stats": _relevance_distribution(all_baseline_relevances)
        },
        "summary": {
            "total_videos_analyzed": len(records),
            "primary_bottleneck": primary_bottleneck
        }
    }

    output_file = os.path.join(output_dir, "oracle_study.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    generate_markdown_report(results, output_dir)

    return results

def generate_markdown_report(results: Dict[str, Any], output_dir: str):
    """Write ``oracle_study_report.md`` summarising an oracle-study result.

    The report contains the oracle-vs-baseline metric table, the
    classification thresholds and relevance distributions (when
    ``bottleneck_thresholds`` is present), the classification rules, the
    bottleneck counts, and up to ten samples of the ``other_bottleneck``
    category.

    Args:
        results: Dictionary with the keys ``improvements``,
            ``bottleneck_analysis`` and ``summary``; ``bottleneck_thresholds``
            and ``bottleneck_samples`` are optional.
        output_dir: Existing directory the report is written into.
    """
    report_file = os.path.join(output_dir, "oracle_study_report.md")
    lines = []

    lines.append("# Oracle Study Report\n\n")
    lines.append("## Overview\n\n")
    lines.append("This study uses ground truth questions for retrieval to identify bottlenecks.\n\n")

    lines.append("## Oracle vs Baseline Performance\n\n")
    lines.append("| Metric | Oracle | Baseline | Improvement | Improvement % |\n")
    lines.append("|--------|--------|----------|-------------|---------------|\n")
    for metric, imp in results["improvements"].items():
        lines.append(f"| {metric} | {imp['oracle_score']:.4f} | {imp['baseline_score']:.4f} | "
                     f"{imp['improvement']:+.4f} | {imp['improvement_percentage']:+.2f}% |\n")
    lines.append("\n")

    lines.append("## Bottleneck Analysis\n\n")

    # Values used in the "Classification Logic" text when the results carry
    # no thresholds of their own.
    logic = {
        "need_gen_baseline_threshold": 0.2,
        "both_bottleneck_baseline_threshold": 0.2,
        "good_baseline_threshold": 0.6,
        "good_oracle_threshold": 0.6,
        "excellent_oracle_threshold": 0.75,
        "relative_improvement_threshold": 0.6,
    }

    if "bottleneck_thresholds" in results:
        thresholds = results["bottleneck_thresholds"]
        logic.update(thresholds)
        # ':g' prints each threshold without rounding it to one decimal
        # (0.75 must not be shown as 0.8).
        lines.append("### Classification Thresholds (Absolute Values)\n\n")
        lines.append("Using absolute thresholds (not relative percentiles):\n\n")
        lines.append(f"- **Bad Relevance**: < {thresholds['bad_relevance_threshold']:g}\n")
        lines.append(f"- **Need Generation Baseline Threshold**: < {thresholds['need_gen_baseline_threshold']:g}\n")
        lines.append(f"- **Both Bottleneck Baseline Threshold**: < {thresholds['both_bottleneck_baseline_threshold']:g}\n")
        lines.append(f"- **Good Baseline**: >= {thresholds['good_baseline_threshold']:g}\n")
        lines.append(f"- **Good Oracle**: >= {thresholds['good_oracle_threshold']:g}\n")
        lines.append(f"- **Excellent Oracle**: >= {thresholds['excellent_oracle_threshold']:g}\n")
        lines.append(f"- **Relative Improvement Threshold**: > {thresholds['relative_improvement_threshold']:g}\n\n")

        if "oracle_relevance_stats" in thresholds and "baseline_relevance_stats" in thresholds:
            oracle_stats = thresholds["oracle_relevance_stats"]
            baseline_stats = thresholds["baseline_relevance_stats"]
            lines.append("**Relevance Score Distributions:**\n\n")
            lines.append("| Metric | Oracle Relevance | Baseline Relevance |\n")
            lines.append("|--------|------------------|-------------------|\n")
            for label, key in (("Mean", "mean"), ("Median", "median"), ("Std", "std"),
                               ("Min", "min"), ("Max", "max")):
                lines.append(f"| {label} | {oracle_stats[key]:.4f} | {baseline_stats[key]:.4f} |\n")
            lines.append("\n")

    need_gen = logic["need_gen_baseline_threshold"]
    both_baseline = logic["both_bottleneck_baseline_threshold"]
    good_baseline = logic["good_baseline_threshold"]
    good_oracle = logic["good_oracle_threshold"]
    excellent_oracle = logic["excellent_oracle_threshold"]
    rel_threshold = logic["relative_improvement_threshold"]

    lines.append("### Classification Logic\n\n")
    lines.append(f"- **need_generation_bottleneck**: baseline_relevance < {need_gen:g} AND "
                 f"(oracle_relevance >= {good_oracle:g} OR relative_improvement > {rel_threshold:g})\n")
    lines.append("  → Baseline is bad but oracle is good or has high relative improvement, problem in query generation\n\n")
    lines.append(f"- **retrieval_bottleneck**: baseline_relevance >= {good_baseline:g} AND "
                 f"oracle_relevance < {excellent_oracle:g}\n")
    lines.append("  → Baseline is good but oracle is not excellent, problem in retriever\n\n")
    lines.append(f"- **both_bottleneck**: oracle_relevance < {good_oracle:g} AND "
                 f"baseline_relevance < {both_baseline:g}\n")
    lines.append("  → Oracle is not good and baseline is bad, both have problems\n\n")
    lines.append(f"- **no_bottleneck**: oracle_relevance >= {excellent_oracle:g} AND "
                 f"baseline_relevance >= {good_baseline:g}\n")
    lines.append("  → Both are good, no problems\n\n")
    lines.append("- **other_bottleneck**: Other cases\n\n")

    bottleneck = results["bottleneck_analysis"]
    total = sum(bottleneck.values())
    if total > 0:
        lines.append("### Results\n\n")
        lines.append("| Bottleneck Type | Count | Percentage |\n")
        lines.append("|----------------|-------|------------|\n")
        for btype, count in bottleneck.items():
            lines.append(f"| {btype} | {count} | {count/total:.2%} |\n")
    lines.append("\n")

    lines.append(f"**Primary Bottleneck**: {results['summary']['primary_bottleneck']}\n\n")
    lines.append(f"**Total Videos Analyzed**: {results['summary']['total_videos_analyzed']}\n\n")

    other_samples = results.get("bottleneck_samples", {}).get("other_bottleneck", [])
    if len(other_samples) > 0:
        lines.append("### Other Bottleneck Sample Details\n\n")
        lines.append("First 10 samples in `other_bottleneck` category:\n\n")
        lines.append("| YouTube ID | Oracle Relevance | Baseline Relevance | Relative Improvement |\n")
        lines.append("|-----------|------------------|---------------------|---------------------|\n")
        for sample in other_samples[:10]:
            youtube_id = sample.get("youtube_id", "N/A")
            oracle_rel = sample.get("oracle_relevance", 0.0)
            baseline_rel = sample.get("baseline_relevance", 0.0)
            rel_improvement = sample.get("relative_improvement")
            # Recompute when the sample carries no usable value; the ratio is
            # undefined for a zero baseline.
            if not isinstance(rel_improvement, (int, float)):
                if baseline_rel > 0:
                    rel_improvement = (oracle_rel - baseline_rel) / baseline_rel
                else:
                    rel_improvement = None
            rel_text = "N/A" if rel_improvement is None else f"{rel_improvement:.4f}"
            lines.append(f"| {youtube_id} | {oracle_rel:.4f} | {baseline_rel:.4f} | {rel_text} |\n")
        lines.append("\n")

    # Written in one go after all content was built, so a formatting error
    # never leaves a half-written report behind.
    with open(report_file, "w", encoding="utf-8") as f:
        f.writelines(lines)

if __name__ == "__main__":
    from model_config import get_all_model_ids

    # Evaluation files and oracle files live under the same root directory.
    eval_root = "/home/key4/JIRArena-exp/evaluation_output"
    study_output = "/home/key4/JIRArena-exp/iclr_rebuttal/oracle_study"

    # Every configured model except those whose id mentions "oracle"
    # (case-insensitive).
    compared_models = []
    for candidate in get_all_model_ids():
        if "oracle" in candidate.lower():
            continue
        compared_models.append(candidate)

    results = analyze_oracle_performance(
        evaluation_dir=eval_root,
        oracle_dir=eval_root,
        output_dir=study_output,
        model_ids=compared_models,
        similarity_threshold=0.55,
        fuzzy_sentence_interval=1,
        num_workers=4,
        use_gpu=True,
        num_gpus=4,
    )
    

