
import sys
import os
import json
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from collections import defaultdict
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch
import threading

# Absolute path of the directory one level above this file. It is put at the
# front of sys.path (presumably so the project-local modules imported below
# resolve) and is also used further down as the root for the `output/...` and
# `data/...` input paths.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from utils import load_jsonl, METAINFO_DIR
from evaluate_utils import get_duration_by_youtube_id
from evaluate import JIREvaluator
from model_config import get_all_model_ids
from similarity_cache import get_or_compute_matches

# Likelihood score returned by get_likelihood_score for every query whose
# "data_type" is "human".
HUMAN_ANNOTATION_LIKELIHOOD_SCORE = 9.0
# Inclusive lower bounds of the buckets produced by classify_likelihood,
# checked from highest to lowest; scores below MEDIUM_LOW fall into "low".
VERY_HIGH_LIKELIHOOD_THRESHOLD = 9.0
HIGH_LIKELIHOOD_THRESHOLD = 8.5
MEDIUM_HIGH_LIKELIHOOD_THRESHOLD = 8.0
MEDIUM_LIKELIHOOD_THRESHOLD = 7.5
MEDIUM_LOW_LIKELIHOOD_THRESHOLD = 7.0
# Minimum similarity for a time-overlapping candidate to be counted as a match
# by the analyze_failure_reason* helpers.
SIMILARITY_THRESHOLD = 0.55
# Unit used to widen time windows: empty/missing spans are extended by
# 2 * SENTENCE_DURATION and the fuzzy padding is
# fuzzy_sentence_interval * SENTENCE_DURATION (presumably seconds - TODO confirm).
SENTENCE_DURATION = 5

# Held while a JIREvaluator is being constructed.
_model_init_lock = threading.Lock()

# Evaluators keyed by device string ('cpu' or 'cuda:<n>'); filled by
# initialize_evaluator_pool and read by get_evaluator.
_evaluator_pool = {}
# Held by get_evaluator while it consults _evaluator_pool.
_evaluator_pool_lock = threading.Lock()

def initialize_evaluator_pool(num_gpus: int, similarity_threshold: float = 0.55):
    """Fill the module-level evaluator pool with one evaluator per device.

    Without CUDA only a 'cpu' evaluator is created. With CUDA, one evaluator is
    created for each of 'cuda:0' .. 'cuda:{num_gpus - 1}'; a device whose
    evaluator cannot be constructed is left out of the pool and reported on
    stderr. A 'cpu' evaluator is always present afterwards so get_evaluator
    has something to fall back to.

    Args:
        num_gpus: Number of CUDA devices to create evaluators for.
        similarity_threshold: Forwarded to every JIREvaluator.

    Raises:
        Exception: Whatever JIREvaluator raises when the CPU evaluator itself
            cannot be constructed.
    """
    def _load(device: str) -> None:
        # Model construction is serialised through the shared init lock.
        with _model_init_lock:
            _evaluator_pool[device] = JIREvaluator(
                similarity_threshold=similarity_threshold, device=device
            )

    if not torch.cuda.is_available():
        _load('cpu')
        return

    for gpu_id in range(num_gpus):
        device = f'cuda:{gpu_id}'
        try:
            _load(device)
        except Exception as exc:
            # Previously swallowed silently; surface it so a missing or broken
            # GPU is visible, then keep going with the CPU fallback.
            print(
                f"WARNING: could not initialize evaluator on {device}: {exc}; "
                "falling back to CPU for this device",
                file=sys.stderr,
            )
            if 'cpu' not in _evaluator_pool:
                _load('cpu')

    # Always keep a CPU evaluator around as the fallback for unknown devices.
    if 'cpu' not in _evaluator_pool:
        _load('cpu')
    

def get_evaluator(device: str, similarity_threshold: float = 0.55):
    """Return the pooled evaluator for `device`, falling back to the CPU one.

    Lookup order: the evaluator registered for `device`, then the 'cpu'
    evaluator. If the pool holds neither, a new evaluator is constructed for
    `device` and stored in the pool, so later calls reuse it instead of
    building yet another model.

    Args:
        device: Device string such as 'cpu' or 'cuda:0'.
        similarity_threshold: Only used when a new evaluator has to be built;
            pooled evaluators keep the threshold they were created with.

    Returns:
        A JIREvaluator instance.
    """
    with _evaluator_pool_lock:
        if device in _evaluator_pool:
            return _evaluator_pool[device]

        if 'cpu' in _evaluator_pool:
            return _evaluator_pool['cpu']

        # Pool was never initialised: build once and remember the result.
        with _model_init_lock:
            evaluator = JIREvaluator(similarity_threshold=similarity_threshold, device=device)
        _evaluator_pool[device] = evaluator
        return evaluator

def get_likelihood_score(query: Dict[str, Any]) -> Optional[float]:
    """Return the likelihood score of `query`, or None if it has none.

    Human-annotated queries (data_type == "human") get the fixed human
    annotation score. Otherwise the result is the mean of the "score" values
    found in the dict entries of query["likelihood_scores"]; entries that are
    not dicts or lack a "score" key are ignored.
    """
    is_human = query.get("data_type") == "human"
    if is_human:
        return HUMAN_ANNOTATION_LIKELIHOOD_SCORE

    raw_entries = query.get("likelihood_scores")
    if not raw_entries:
        return None

    values = []
    for entry in raw_entries:
        if isinstance(entry, dict) and "score" in entry:
            values.append(entry["score"])

    if not values:
        return None
    return sum(values) / len(values)

def classify_likelihood(score: Optional[float]) -> str:
    """Map a likelihood score to its bucket name.

    Returns "unknown" for None; otherwise the label of the first bucket, from
    highest to lowest, whose inclusive lower bound the score reaches, and
    "low" when it reaches none of them.
    """
    if score is None:
        return "unknown"

    # Ordered from the highest lower bound down; the first hit wins.
    buckets = (
        (VERY_HIGH_LIKELIHOOD_THRESHOLD, "very_high"),
        (HIGH_LIKELIHOOD_THRESHOLD, "high"),
        (MEDIUM_HIGH_LIKELIHOOD_THRESHOLD, "medium_high"),
        (MEDIUM_LIKELIHOOD_THRESHOLD, "medium"),
        (MEDIUM_LOW_LIKELIHOOD_THRESHOLD, "medium_low"),
    )
    for lower_bound, label in buckets:
        if score >= lower_bound:
            return label
    return "low"

def analyze_failure_reason_simplified(
    query: Dict[str, Any],
    candidate_queries: List[Dict[str, Any]],
    fuzzy_sentence_interval: int = 1
) -> Dict[str, Any]:
    """Classify an unmatched query using time overlap only (no similarity).

    The query window is padded on both sides by
    fuzzy_sentence_interval * SENTENCE_DURATION. If at least one candidate
    span touches the padded window the status is "semantic_mismatch",
    otherwise "time_mismatch".

    Returns:
        {"status": ..., "num_time_overlaps": <number of overlapping candidates>}
    """
    def _span(item):
        # A missing, None or zero-length end is replaced by a default width.
        begin = item.get('start_time', 0.0)
        finish = item.get('end_time', begin)
        if finish is None or finish == begin:
            finish = begin + 2 * SENTENCE_DURATION
        return begin, finish

    q_begin, q_finish = _span(query)
    window_lo = q_begin - (fuzzy_sentence_interval * SENTENCE_DURATION)
    window_hi = q_finish + (fuzzy_sentence_interval * SENTENCE_DURATION)

    candidate_spans = (_span(candidate) for candidate in candidate_queries)
    overlap_total = sum(
        1
        for c_begin, c_finish in candidate_spans
        if max(window_lo, c_begin) <= min(window_hi, c_finish)
    )

    status = "semantic_mismatch" if overlap_total > 0 else "time_mismatch"
    return {
        "status": status,
        "num_time_overlaps": overlap_total
    }

def analyze_failure_reason_with_similarities(
    query_idx: int,
    query: Dict[str, Any],
    candidate_queries: List[Dict[str, Any]],
    similarity_matrix: torch.Tensor,
    evaluator: JIREvaluator,
    fuzzy_sentence_interval: int = 1,
    similarity_threshold: Optional[float] = None
) -> Dict[str, Any]:
    """Classify a query against candidates using precomputed similarities.

    For every candidate whose time span overlaps the query (as decided by
    evaluator._check_time_overlap), the similarity is read from
    similarity_matrix[query_idx][j], where j is the candidate's position in
    `candidate_queries`.

    Args:
        query_idx: Row of `similarity_matrix` belonging to `query`.
        query: Query dict; reads 'start_time' and 'end_time'.
        candidate_queries: Candidate dicts; read 'start_time' and 'end_time'.
        similarity_matrix: Indexed as [query_idx][candidate_idx].
        evaluator: Provides _check_time_overlap.
        fuzzy_sentence_interval: Forwarded to _check_time_overlap.
        similarity_threshold: Similarity needed for a match. Defaults to the
            module-level SIMILARITY_THRESHOLD, so a caller that ran matching
            with a different threshold can keep this analysis consistent.

    Returns:
        A dict whose "status" is one of:
        - "matched": an overlapping candidate reaches the threshold
          (with "num_matches" and "best_similarity");
        - "semantic_mismatch": candidates overlap in time but none reaches the
          threshold (with "num_time_overlaps", "best_similarity" and
          "similarity_gap");
        - "time_mismatch": no candidate overlaps in time.
    """
    threshold = SIMILARITY_THRESHOLD if similarity_threshold is None else similarity_threshold

    query_start = query.get('start_time', 0.0)
    query_end = query.get('end_time', query_start)
    if query_end is None or query_end == query_start:
        query_end = query_start + 2 * SENTENCE_DURATION

    similarities = similarity_matrix[query_idx]

    num_time_overlaps = 0
    matched_similarities = []
    below_threshold_similarities = []

    for j, candidate in enumerate(candidate_queries):
        cand_start = candidate.get('start_time', 0.0)
        cand_end = candidate.get('end_time', cand_start)
        if cand_end is None or cand_end == cand_start:
            cand_end = cand_start + 2 * SENTENCE_DURATION

        if not evaluator._check_time_overlap(
            query_start, query_end,
            cand_start, cand_end,
            fuzzy_sentence_interval
        ):
            continue

        num_time_overlaps += 1
        similarity = similarities[j].item()
        if similarity >= threshold:
            matched_similarities.append(similarity)
        else:
            below_threshold_similarities.append(similarity)

    if matched_similarities:
        return {
            "status": "matched",
            "num_matches": len(matched_similarities),
            "best_similarity": max(matched_similarities)
        }

    if num_time_overlaps > 0:
        # Every overlapping candidate landed in exactly one of the two lists,
        # and the matched list is empty here, so this list is non-empty.
        best_similarity = max(below_threshold_similarities)
        return {
            "status": "semantic_mismatch",
            "num_time_overlaps": num_time_overlaps,
            "best_similarity": best_similarity,
            "similarity_gap": threshold - best_similarity
        }

    return {
        "status": "time_mismatch",
        "num_time_overlaps": 0
    }

def analyze_failure_reason(
    query: Dict[str, Any],
    candidate_queries: List[Dict[str, Any]],
    evaluator: JIREvaluator,
    fuzzy_sentence_interval: int = 1,
    similarity_threshold: Optional[float] = None
) -> Dict[str, Any]:
    """Classify a query against candidates, computing similarities on demand.

    For every candidate whose time span overlaps the query (as decided by
    evaluator._check_time_overlap), the similarity between the two 'question'
    strings is computed with evaluator._compute_similarity.

    Args:
        query: Query dict; reads 'start_time', 'end_time' and 'question'.
        candidate_queries: Candidate dicts with the same keys.
        evaluator: Provides _check_time_overlap and _compute_similarity.
        fuzzy_sentence_interval: Forwarded to _check_time_overlap.
        similarity_threshold: Similarity needed for a match. Defaults to the
            module-level SIMILARITY_THRESHOLD, so a caller that ran matching
            with a different threshold can keep this analysis consistent.

    Returns:
        A dict whose "status" is one of:
        - "matched": an overlapping candidate reaches the threshold
          (with "num_matches" and "best_similarity");
        - "semantic_mismatch": candidates overlap in time but none reaches the
          threshold (with "num_time_overlaps", "best_similarity" and
          "similarity_gap");
        - "time_mismatch": no candidate overlaps in time.
    """
    threshold = SIMILARITY_THRESHOLD if similarity_threshold is None else similarity_threshold

    query_start = query.get('start_time', 0.0)
    query_end = query.get('end_time', query_start)
    if query_end is None or query_end == query_start:
        query_end = query_start + 2 * SENTENCE_DURATION

    query_question = query.get('question', '')

    num_time_overlaps = 0
    matched_similarities = []
    below_threshold_similarities = []

    for candidate in candidate_queries:
        cand_start = candidate.get('start_time', 0.0)
        cand_end = candidate.get('end_time', cand_start)
        if cand_end is None or cand_end == cand_start:
            cand_end = cand_start + 2 * SENTENCE_DURATION

        if not evaluator._check_time_overlap(
            query_start, query_end,
            cand_start, cand_end,
            fuzzy_sentence_interval
        ):
            continue

        num_time_overlaps += 1
        # Similarity is only computed for candidates that overlap in time.
        similarity = evaluator._compute_similarity(query_question, candidate.get('question', ''))
        if similarity >= threshold:
            matched_similarities.append(similarity)
        else:
            below_threshold_similarities.append(similarity)

    if matched_similarities:
        return {
            "status": "matched",
            "num_matches": len(matched_similarities),
            "best_similarity": max(matched_similarities)
        }

    if num_time_overlaps > 0:
        # Every overlapping candidate landed in exactly one of the two lists,
        # and the matched list is empty here, so this list is non-empty.
        best_similarity = max(below_threshold_similarities)
        return {
            "status": "semantic_mismatch",
            "num_time_overlaps": num_time_overlaps,
            "best_similarity": best_similarity,
            "similarity_gap": threshold - best_similarity
        }

    return {
        "status": "time_mismatch",
        "num_time_overlaps": 0
    }

def process_single_video_error_analysis_with_results(
    model_id: str,
    youtube_id: str,
    queries_to_match: List[Dict[str, Any]],
    candidate_queries: List[Dict[str, Any]],
    match_results: Dict[int, List[Tuple[Dict[str, Any], float]]],
    similarity_threshold: float,
    fuzzy_sentence_interval: int
):
    """Build per-video error statistics from already-computed match results.

    A query counts as failed when `match_results` holds no matches for its
    index. Failed queries are classified with
    analyze_failure_reason_simplified (time overlap only).

    Args:
        model_id: Model identifier copied into every record.
        youtube_id: Video identifier copied into every record.
        queries_to_match: Ground-truth queries, in index order.
        candidate_queries: Candidates produced by the model for this video.
        match_results: Matches keyed by query index.
        similarity_threshold: Not used by this function; kept in the signature.
        fuzzy_sentence_interval: Forwarded to the failure analysis.

    Returns:
        {"queries": [...], "failures": [...], "stats": {...}} where "stats"
        holds "likelihood_distribution", "failure_by_likelihood",
        "failure_reasons", "failure_by_type" and "high_likelihood_failures".

    On any exception the error and traceback are written to stderr and the
    process exits with status 1.
    """
    try:
        video_queries = []
        video_failures = []
        video_stats = {
            "likelihood_distribution": defaultdict(int),
            "failure_by_likelihood": defaultdict(lambda: {"total": 0, "failed": 0}),
            "failure_reasons": defaultdict(int),
            "failure_by_type": defaultdict(lambda: {"total": 0, "failed": 0, "reasons": defaultdict(int)}),
            "high_likelihood_failures": []
        }

        for idx, query in enumerate(queries_to_match):
            likelihood_score = get_likelihood_score(query)
            likelihood_category = classify_likelihood(likelihood_score)

            query_record = {
                "youtube_id": youtube_id,
                "model_id": model_id,
                "query_idx": idx,
                "likelihood_score": likelihood_score,
                "likelihood_category": likelihood_category,
                "type": query.get("type", "Unknown"),
                "subtype": query.get("subtype", ""),
                "question": query.get("question", ""),
                "start_time": query.get("start_time", 0.0),
                "end_time": query.get("end_time", 0.0)
            }
            video_queries.append(query_record)

            if likelihood_score is not None:
                video_stats["likelihood_distribution"][likelihood_category] += 1
                video_stats["failure_by_likelihood"][likelihood_category]["total"] += 1

            # Count every query towards its type total. Counting only the
            # failed ones made total == failed, i.e. a 100% failure rate.
            need_type = f"{query.get('type', 'Unknown')} ({query.get('subtype', '')})"
            type_stats = video_stats["failure_by_type"][need_type]
            type_stats["total"] += 1

            matches = match_results.get(idx)
            if matches is None:
                # NOTE(review): if the cache is round-tripped through JSON the
                # integer keys come back as strings - confirm whether
                # load_match_results restores them; this lookup is harmless
                # either way.
                matches = match_results.get(str(idx), [])
            if len(matches) > 0:
                continue

            failure_analysis = analyze_failure_reason_simplified(
                query, candidate_queries, fuzzy_sentence_interval
            )

            failure_info = {
                **query_record,
                "failure_reason": failure_analysis["status"],
                "failure_details": failure_analysis
            }
            video_failures.append(failure_info)

            if likelihood_score is not None:
                video_stats["failure_by_likelihood"][likelihood_category]["failed"] += 1

            video_stats["failure_reasons"][failure_analysis["status"]] += 1
            type_stats["failed"] += 1
            type_stats["reasons"][failure_analysis["status"]] += 1

            if likelihood_category in ("very_high", "high"):
                video_stats["high_likelihood_failures"].append(failure_info)

        return {
            "queries": video_queries,
            "failures": video_failures,
            "stats": video_stats
        }
    except Exception as e:
        import traceback
        # The message used to be built and then dropped; emit it so the
        # failing video/model is identifiable next to the traceback.
        print(f"FATAL ERROR processing cached {youtube_id} for {model_id}: {e}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        sys.exit(1)

def process_single_video_error_analysis(args):
    """Compute (or load) matches for one video/model pair and analyse failures.

    Args:
        args: 7-tuple of (model_id, youtube_id, filename, youtube_ids_set,
            similarity_threshold, fuzzy_sentence_interval, gpu_id). `gpu_id`
            may be None, in which case the CPU evaluator is used.

    Returns:
        None when the video is not in `youtube_ids_set` or one of its input
        files is missing; otherwise
        {"queries": [...], "failures": [...], "stats": {...}} where "stats"
        holds "likelihood_distribution", "failure_by_likelihood",
        "failure_reasons", "failure_by_type" and "high_likelihood_failures".

    On any exception the error and traceback are written to stderr and
    sys.exit(1) is called.
    """
    model_id, youtube_id, filename, youtube_ids_set, similarity_threshold, fuzzy_sentence_interval, gpu_id = args

    try:
        if torch.cuda.is_available() and gpu_id is not None:
            device = f'cuda:{gpu_id}'
        else:
            device = 'cpu'

        if youtube_id not in youtube_ids_set:
            return None

        # NOTE(review): this absolute path is hard-coded and independent of
        # the evaluation_dir that analyze_evaluation_errors receives.
        evaluation_dir = "/home/key4/JIRArena-exp/evaluation_output"
        model_dir = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{model_id}")
        evaluation_file = os.path.join(model_dir, filename)

        # The evaluation file only gates processing; its contents were parsed
        # but never used, so it is no longer loaded.
        if not os.path.exists(evaluation_file):
            return None

        ground_truth_file = os.path.join(parent_dir, f"output/{youtube_id}/jir_references_relevance_score.jsonl")
        if not os.path.exists(ground_truth_file):
            return None

        candidate_file = os.path.join(parent_dir, f"data/baseline_stream_runs_{model_id}", filename)
        if not os.path.exists(candidate_file):
            return None

        queries_to_match = load_jsonl(ground_truth_file)
        with open(candidate_file, "r") as f:
            candidate_data = json.load(f)
        candidate_queries = candidate_data.get("needs", [])

        evaluator = get_evaluator(device, similarity_threshold)

        from similarity_cache import get_cache_file_path, load_match_results, save_match_results

        cache_file = get_cache_file_path(
            "/home/key4/JIRArena-exp/evaluation_output/.similarity_cache",
            model_id, youtube_id, similarity_threshold, fuzzy_sentence_interval
        )
        cached_data = load_match_results(cache_file)

        # Only trust a cache whose recorded inputs have the same sizes as the
        # current ones. A cache entry that fails this check is exactly what
        # routes a task to this worker, so using it unchecked would reuse
        # stale matches.
        cache_is_fresh = (
            cached_data is not None
            and len(cached_data.get("queries_to_match", [])) == len(queries_to_match)
            and len(cached_data.get("candidate_queries", [])) == len(candidate_queries)
        )

        if cache_is_fresh:
            results = cached_data["match_results"]
            similarity_matrix = None
        else:
            results, similarity_matrix = evaluator.batch_find_matches(
                queries_to_match, candidate_queries, fuzzy_sentence_interval,
                return_similarity_matrix=True
            )
            save_match_results(cache_file, queries_to_match, candidate_queries, results)

        video_queries = []
        video_failures = []
        video_stats = {
            "likelihood_distribution": defaultdict(int),
            "failure_by_likelihood": defaultdict(lambda: {"total": 0, "failed": 0}),
            "failure_reasons": defaultdict(int),
            "failure_by_type": defaultdict(lambda: {"total": 0, "failed": 0, "reasons": defaultdict(int)}),
            "high_likelihood_failures": []
        }

        for idx, query in enumerate(queries_to_match):
            likelihood_score = get_likelihood_score(query)
            likelihood_category = classify_likelihood(likelihood_score)

            query_record = {
                "youtube_id": youtube_id,
                "model_id": model_id,
                "query_idx": idx,
                "likelihood_score": likelihood_score,
                "likelihood_category": likelihood_category,
                "type": query.get("type", "Unknown"),
                "subtype": query.get("subtype", ""),
                "question": query.get("question", ""),
                "start_time": query.get("start_time", 0.0),
                "end_time": query.get("end_time", 0.0)
            }
            video_queries.append(query_record)

            if likelihood_score is not None:
                video_stats["likelihood_distribution"][likelihood_category] += 1
                video_stats["failure_by_likelihood"][likelihood_category]["total"] += 1

            # Count every query towards its type total. Counting only the
            # failed ones made total == failed, i.e. a 100% failure rate.
            need_type = f"{query.get('type', 'Unknown')} ({query.get('subtype', '')})"
            type_stats = video_stats["failure_by_type"][need_type]
            type_stats["total"] += 1

            matches = results.get(idx)
            if matches is None:
                # NOTE(review): if the cache is round-tripped through JSON the
                # integer keys come back as strings - confirm whether
                # load_match_results restores them; this lookup is harmless
                # either way.
                matches = results.get(str(idx), [])
            if len(matches) > 0:
                continue

            if similarity_matrix is not None:
                # Fresh computation: reuse the similarities already computed.
                failure_analysis = analyze_failure_reason_with_similarities(
                    idx, query, candidate_queries, similarity_matrix,
                    evaluator, fuzzy_sentence_interval
                )
            else:
                # Matches came from the cache: similarities are recomputed
                # per overlapping candidate.
                failure_analysis = analyze_failure_reason(
                    query, candidate_queries, evaluator, fuzzy_sentence_interval
                )

            failure_info = {
                **query_record,
                "failure_reason": failure_analysis["status"],
                "failure_details": failure_analysis
            }
            video_failures.append(failure_info)

            if likelihood_score is not None:
                video_stats["failure_by_likelihood"][likelihood_category]["failed"] += 1

            video_stats["failure_reasons"][failure_analysis["status"]] += 1
            type_stats["failed"] += 1
            type_stats["reasons"][failure_analysis["status"]] += 1

            if likelihood_category in ("very_high", "high"):
                video_stats["high_likelihood_failures"].append(failure_info)

        return {
            "queries": video_queries,
            "failures": video_failures,
            "stats": video_stats
        }
    except Exception as e:
        import traceback
        # The message used to be built and then dropped; emit it so the
        # failing video/model is identifiable next to the traceback.
        print(f"FATAL ERROR processing {youtube_id} for {model_id}: {e}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        sys.exit(1)

def _load_video_inputs(youtube_id: str, model_id: str, filename: str):
    """Load (ground-truth queries, candidate needs) for one video/model pair.

    Returns None when either input file does not exist.
    """
    ground_truth_file = os.path.join(parent_dir, f"output/{youtube_id}/jir_references_relevance_score.jsonl")
    candidate_file = os.path.join(parent_dir, f"data/baseline_stream_runs_{model_id}", filename)
    if not (os.path.exists(ground_truth_file) and os.path.exists(candidate_file)):
        return None

    queries_to_match = load_jsonl(ground_truth_file)
    with open(candidate_file, "r") as f:
        candidate_data = json.load(f)
    return queries_to_match, candidate_data.get("needs", [])


def _merge_video_result(aggregate: Dict[str, Any], result: Dict[str, Any]) -> None:
    """Fold one per-video result into the run-wide aggregate, in place."""
    aggregate["queries"].extend(result["queries"])
    aggregate["failures"].extend(result["failures"])

    stats = result["stats"]

    for category, count in stats["likelihood_distribution"].items():
        aggregate["likelihood_distribution"][category] += count

    for category, counts in stats["failure_by_likelihood"].items():
        aggregate["failure_by_likelihood"][category]["total"] += counts["total"]
        aggregate["failure_by_likelihood"][category]["failed"] += counts["failed"]

    for reason, count in stats["failure_reasons"].items():
        aggregate["failure_reasons"][reason] += count

    for need_type, counts in stats["failure_by_type"].items():
        merged = aggregate["failure_by_type"][need_type]
        merged["total"] += counts["total"]
        merged["failed"] += counts["failed"]
        for reason, count in counts["reasons"].items():
            merged["reasons"][reason] += count

    aggregate["high_likelihood_failures"].extend(stats["high_likelihood_failures"])


def analyze_evaluation_errors(
    evaluation_dir: str,
    output_dir: str,
    model_ids: List[str],
    similarity_threshold: float = 0.55,
    fuzzy_sentence_interval: int = 1,
    num_workers: int = None,
    use_gpu: bool = True,
    num_gpus: int = 4
) -> Dict[str, Any]:
    """Run the error analysis over all models and videos and write the reports.

    Videos are discovered as `<youtube_id>.json` files under
    `<evaluation_dir>/rebuttal_baseline_stream_runs_<model_id>` and restricted
    to the ids listed in the lecture/paper metainfo files. Pairs with a usable
    similarity cache are analysed directly; the rest are dispatched to a
    thread pool running process_single_video_error_analysis.

    Args:
        evaluation_dir: Root holding the per-model evaluation directories.
        output_dir: Where the output files are written (created if missing).
        model_ids: Models to analyse.
        similarity_threshold: Forwarded to the evaluators and the cache key.
        fuzzy_sentence_interval: Forwarded to matching and failure analysis.
        num_workers: Thread count; derived from the workload when None.
            Always raised to at least 1.
        use_gpu: Use CUDA evaluators when CUDA is available.
        num_gpus: Number of CUDA devices to spread the work over.

    Returns:
        The results dict that is also written to `error_analysis.json`.
        Additionally writes `high_likelihood_failures.json` and calls
        generate_markdown_report.

    Any exception while processing a video prints a traceback to stderr and
    exits the process with status 1.
    """
    import traceback

    os.makedirs(output_dir, exist_ok=True)

    youtube_ids = load_jsonl(os.path.join(METAINFO_DIR, "lecture_.jsonl")) + \
                  load_jsonl(os.path.join(METAINFO_DIR, "paper_.jsonl"))
    youtube_ids_set = {item["youtube_id"] for item in youtube_ids}

    from similarity_cache import get_cache_file_path, load_match_results

    # NOTE(review): hard-coded and independent of `evaluation_dir`; it has to
    # stay identical to the path used by process_single_video_error_analysis.
    cache_dir = "/home/key4/JIRArena-exp/evaluation_output/.similarity_cache"

    cached_tasks = []
    uncached_tasks = []

    for model_id in model_ids:
        model_dir = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{model_id}")
        if not os.path.exists(model_dir):
            continue

        for filename in os.listdir(model_dir):
            if not filename.endswith(".json"):
                continue

            youtube_id = filename[:-len(".json")]
            if youtube_id not in youtube_ids_set:
                continue

            task = (model_id, youtube_id, filename, youtube_ids_set,
                    similarity_threshold, fuzzy_sentence_interval, None)

            cache_file = get_cache_file_path(
                cache_dir, model_id, youtube_id, similarity_threshold, fuzzy_sentence_interval
            )
            cached_data = load_match_results(cache_file)

            if cached_data is not None:
                inputs = _load_video_inputs(youtube_id, model_id, filename)
                if inputs is not None:
                    queries_to_match, candidate_queries = inputs
                    # A cache is only usable when it was built from inputs of
                    # the same sizes as the current ones.
                    if (len(cached_data.get("queries_to_match", [])) == len(queries_to_match) and
                            len(cached_data.get("candidate_queries", [])) == len(candidate_queries)):
                        cached_tasks.append((task, cached_data["match_results"]))
                        continue

            uncached_tasks.append(task)

    # num_gpus <= 0 is treated as "no GPU" so the round-robin below cannot
    # divide by zero.
    gpu_enabled = use_gpu and torch.cuda.is_available() and num_gpus > 0

    if num_workers is None:
        # os.cpu_count() may return None.
        worker_cap = num_gpus * 4 if gpu_enabled else (os.cpu_count() or 1)
        num_workers = min(len(uncached_tasks), worker_cap)
    # ThreadPoolExecutor rejects max_workers <= 0, which is what the default
    # used to evaluate to when every task was served from the cache.
    num_workers = max(1, num_workers)

    if gpu_enabled:
        initialize_evaluator_pool(num_gpus, similarity_threshold)
        # Spread the uncached tasks round-robin over the GPUs.
        uncached_tasks = [
            task[:-1] + (i % num_gpus,)
            for i, task in enumerate(uncached_tasks)
        ]
    else:
        initialize_evaluator_pool(0, similarity_threshold)

    aggregate = {
        "queries": [],
        "failures": [],
        "likelihood_distribution": defaultdict(int),
        "failure_by_likelihood": defaultdict(lambda: {"total": 0, "failed": 0}),
        "failure_reasons": defaultdict(int),
        "failure_by_type": defaultdict(lambda: {"total": 0, "failed": 0, "reasons": defaultdict(int)}),
        "high_likelihood_failures": []
    }

    for task, cached_match_results in tqdm(cached_tasks, desc="Loading cached results"):
        task_model_id, task_youtube_id, task_filename = task[:3]

        try:
            inputs = _load_video_inputs(task_youtube_id, task_model_id, task_filename)
            if inputs is None:
                raise FileNotFoundError(
                    f"input files for {task_youtube_id} ({task_model_id}) disappeared"
                )
            queries_to_match, candidate_queries = inputs

            result = process_single_video_error_analysis_with_results(
                task_model_id, task_youtube_id, queries_to_match, candidate_queries,
                cached_match_results, similarity_threshold, fuzzy_sentence_interval
            )

            if result is None:
                continue

            _merge_video_result(aggregate, result)
        except Exception:
            traceback.print_exc(file=sys.stderr)
            sys.exit(1)

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(process_single_video_error_analysis, task): task for task in uncached_tasks}

        for future in tqdm(as_completed(futures), total=len(futures), desc="Computing with GPU"):
            try:
                result = future.result()
            except Exception:
                traceback.print_exc(file=sys.stderr)
                sys.exit(1)

            if result is None:
                continue

            _merge_video_result(aggregate, result)

    likelihood_distribution = aggregate["likelihood_distribution"]
    failure_by_likelihood = aggregate["failure_by_likelihood"]
    failure_reasons = aggregate["failure_reasons"]
    failure_by_type = aggregate["failure_by_type"]
    high_likelihood_failures = aggregate["high_likelihood_failures"]

    failure_rates_by_likelihood = {}
    for category, stats in failure_by_likelihood.items():
        if stats["total"] > 0:
            failure_rates_by_likelihood[category] = {
                "total": stats["total"],
                "failed": stats["failed"],
                "failure_rate": stats["failed"] / stats["total"]
            }

    failure_rates_by_type = {}
    for need_type, stats in failure_by_type.items():
        if stats["total"] > 0:
            failure_rates_by_type[need_type] = {
                "total": stats["total"],
                "failed": stats["failed"],
                "failure_rate": stats["failed"] / stats["total"],
                "failure_reasons": dict(stats["reasons"])
            }

    total_queries = len(aggregate["queries"])
    total_failures = len(aggregate["failures"])
    overall_failure_rate = total_failures / total_queries if total_queries > 0 else 0.0

    # "High likelihood" in the summary means the very_high and high buckets.
    # .get() is used so the defaultdict does not grow empty entries.
    high_likelihood_total = (failure_by_likelihood.get("very_high", {}).get("total", 0) +
                             failure_by_likelihood.get("high", {}).get("total", 0))
    high_likelihood_failed = (failure_by_likelihood.get("very_high", {}).get("failed", 0) +
                              failure_by_likelihood.get("high", {}).get("failed", 0))
    high_likelihood_failure_rate = high_likelihood_failed / high_likelihood_total if high_likelihood_total > 0 else 0.0

    results = {
        "likelihood_definition": {
            "very_high_threshold": VERY_HIGH_LIKELIHOOD_THRESHOLD,
            "high_threshold": HIGH_LIKELIHOOD_THRESHOLD,
            "medium_high_threshold": MEDIUM_HIGH_LIKELIHOOD_THRESHOLD,
            "medium_threshold": MEDIUM_LIKELIHOOD_THRESHOLD,
            "medium_low_threshold": MEDIUM_LOW_LIKELIHOOD_THRESHOLD,
            "human_annotation_score": HUMAN_ANNOTATION_LIKELIHOOD_SCORE,
            "very_high_count": likelihood_distribution.get("very_high", 0),
            "high_count": likelihood_distribution.get("high", 0),
            "medium_high_count": likelihood_distribution.get("medium_high", 0),
            "medium_count": likelihood_distribution.get("medium", 0),
            "medium_low_count": likelihood_distribution.get("medium_low", 0),
            "low_count": likelihood_distribution.get("low", 0),
            "high_likelihood_count": likelihood_distribution.get("very_high", 0) + likelihood_distribution.get("high", 0),
            "high_likelihood_failed": high_likelihood_failed,
            "high_likelihood_failure_rate": high_likelihood_failure_rate
        },
        "overall_statistics": {
            "total_queries": total_queries,
            "total_failures": total_failures,
            "overall_failure_rate": overall_failure_rate,
            "likelihood_distribution": dict(likelihood_distribution)
        },
        "failure_reasons": {
            "time_mismatch": failure_reasons.get("time_mismatch", 0),
            "semantic_mismatch": failure_reasons.get("semantic_mismatch", 0),
            "matched": total_queries - total_failures,
            "total": total_queries
        },
        "failure_by_likelihood": failure_rates_by_likelihood,
        "failure_by_type": failure_rates_by_type,
        # Only the first 100 are embedded here; the full list goes to its own file.
        "high_likelihood_failures": high_likelihood_failures[:100]
    }

    output_file = os.path.join(output_dir, "error_analysis.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    high_likelihood_file = os.path.join(output_dir, "high_likelihood_failures.json")
    with open(high_likelihood_file, "w", encoding="utf-8") as f:
        json.dump(high_likelihood_failures, f, indent=2, ensure_ascii=False)

    generate_markdown_report(results, output_dir)

    return results

def _escape_table_cell(value: Any) -> str:
    """Return *value* as text that is safe inside a Markdown table cell.

    Pipes are backslash-escaped and line breaks are collapsed to spaces,
    because either one would otherwise split the cell and corrupt the row.
    """
    text = str(value).replace("|", "\\|")
    return text.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")


def generate_markdown_report(results: Dict[str, Any], output_dir: str):
    """Write the error-analysis results to ``error_analysis_report.md``.

    The whole report is assembled in memory first and written in one call, so
    a ``results`` dict with a missing key raises before the file is touched
    instead of leaving a truncated report behind.

    Args:
        results: Analysis summary. The keys read here are
            ``likelihood_definition``, ``overall_statistics``,
            ``failure_reasons``, ``failure_by_likelihood``,
            ``failure_by_type`` and ``high_likelihood_failures``.
        output_dir: Directory that receives the report; it is created if it
            does not exist yet.

    Raises:
        KeyError: If ``results`` (or one of its nested dicts) lacks a key the
            report needs.
    """
    hl_def = results["likelihood_definition"]
    overall = results["overall_statistics"]
    reasons = results["failure_reasons"]

    parts = ["# Error Analysis Report\n\n"]

    # --- Likelihood category definitions (module-level thresholds) ---
    parts.append("## Likelihood Match Definition\n\n")
    parts.append("We define likelihood categories based on the `likelihood_score` field:\n\n")
    parts.append(f"- **Very High**: score >= {VERY_HIGH_LIKELIHOOD_THRESHOLD} (human annotations or extremely high likelihood)\n")
    parts.append(f"- **High**: {HIGH_LIKELIHOOD_THRESHOLD} <= score < {VERY_HIGH_LIKELIHOOD_THRESHOLD}\n")
    parts.append(f"- **Medium-High**: {MEDIUM_HIGH_LIKELIHOOD_THRESHOLD} <= score < {HIGH_LIKELIHOOD_THRESHOLD}\n")
    parts.append(f"- **Medium**: {MEDIUM_LIKELIHOOD_THRESHOLD} <= score < {MEDIUM_HIGH_LIKELIHOOD_THRESHOLD}\n")
    parts.append(f"- **Medium-Low**: {MEDIUM_LOW_LIKELIHOOD_THRESHOLD} <= score < {MEDIUM_LIKELIHOOD_THRESHOLD}\n")
    parts.append(f"- **Low**: score < {MEDIUM_LOW_LIKELIHOOD_THRESHOLD}\n")
    parts.append(f"- **Human annotation**: data_type == 'human', score = {HUMAN_ANNOTATION_LIKELIHOOD_SCORE}\n\n")

    # --- Per-category query counts ---
    parts.append("### Statistics\n\n")
    parts.append(f"- Very High likelihood queries: {hl_def['very_high_count']}\n")
    parts.append(f"- High likelihood queries: {hl_def['high_count']}\n")
    parts.append(f"- Medium-High likelihood queries: {hl_def['medium_high_count']}\n")
    parts.append(f"- Medium likelihood queries: {hl_def['medium_count']}\n")
    parts.append(f"- Medium-Low likelihood queries: {hl_def['medium_low_count']}\n")
    parts.append(f"- Low likelihood queries: {hl_def['low_count']}\n")
    parts.append(f"\n- High likelihood (>= {HIGH_LIKELIHOOD_THRESHOLD}) queries: {hl_def['high_likelihood_count']}\n")
    parts.append(f"- High likelihood failures: {hl_def['high_likelihood_failed']}\n")
    parts.append(f"- High likelihood failure rate: {hl_def['high_likelihood_failure_rate']:.2%}\n\n")

    # --- Overall totals ---
    parts.append("## Overall Statistics\n\n")
    parts.append(f"- Total queries analyzed: {overall['total_queries']}\n")
    parts.append(f"- Total failures: {overall['total_failures']}\n")
    parts.append(f"- Overall failure rate: {overall['overall_failure_rate']:.2%}\n\n")

    # --- Failure reasons as a share of all queries ---
    parts.append("## Failure Reasons Distribution\n\n")
    total = reasons["total"]
    if total > 0:
        parts.append(f"- **Time mismatch**: {reasons['time_mismatch']} ({reasons['time_mismatch']/total:.2%})\n")
        parts.append(f"- **Semantic mismatch**: {reasons['semantic_mismatch']} ({reasons['semantic_mismatch']/total:.2%})\n")
        parts.append(f"- **Matched**: {reasons['matched']} ({reasons['matched']/total:.2%})\n\n")
    else:
        # Avoid dividing by zero when nothing was analyzed.
        parts.append("- No data found\n\n")

    # --- Failure rate per likelihood category (in the dict's own order) ---
    parts.append("## Failure Rate by Likelihood Category\n\n")
    parts.append("| Category | Total | Failed | Failure Rate |\n")
    parts.append("|----------|-------|--------|--------------|\n")
    for category, stats in results["failure_by_likelihood"].items():
        parts.append(
            f"| {_escape_table_cell(category)} | {stats['total']} | {stats['failed']} | {stats['failure_rate']:.2%} |\n"
        )
    parts.append("\n")

    # --- Failure rate per need type: the 20 worst, highest rate first ---
    parts.append("## Failure Rate by Need Type\n\n")
    parts.append("| Need Type | Total | Failed | Failure Rate |\n")
    parts.append("|-----------|-------|--------|--------------|\n")
    sorted_types = sorted(
        results["failure_by_type"].items(),
        key=lambda x: x[1]["failure_rate"],
        reverse=True
    )
    for need_type, stats in sorted_types[:20]:
        parts.append(
            f"| {_escape_table_cell(need_type)} | {stats['total']} | {stats['failed']} | {stats['failure_rate']:.2%} |\n"
        )
    parts.append("\n")

    # --- Example failures: the first 10 entries, in the order given ---
    failures = results["high_likelihood_failures"]
    parts.append("## High Likelihood Failure Examples\n\n")
    # NOTE(review): this counts the entries present in `results`; if the
    # caller truncates the list before passing it in, this is an undercount
    # relative to `high_likelihood_failed` above -- confirm which is intended.
    parts.append(f"Total high likelihood failures: {len(failures)}\n\n")
    parts.append("### Top 10 Examples\n\n")
    for i, failure in enumerate(failures[:10], 1):
        parts.append(f"#### Example {i}\n\n")
        parts.append(f"- **Question**: {failure['question']}\n")
        parts.append(f"- **Type**: {failure['type']} ({failure['subtype']})\n")
        parts.append(f"- **Likelihood Score**: {failure['likelihood_score']:.2f}\n")
        parts.append(f"- **Failure Reason**: {failure['failure_reason']}\n")
        parts.append(f"- **Time**: [{failure['start_time']:.1f}s - {failure['end_time']:.1f}s]\n")
        # `or {}` also covers an explicit None stored under "failure_details".
        details = failure.get("failure_details") or {}
        if "best_similarity" in details:
            parts.append(f"- **Best Similarity**: {details['best_similarity']:.3f}\n")
        parts.append("\n")

    # An empty output_dir means "current directory"; makedirs("") would raise.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    report_file = os.path.join(output_dir, "error_analysis_report.md")
    with open(report_file, "w", encoding="utf-8") as f:
        f.write("".join(parts))

if __name__ == "__main__":
    # Script entry point: run the error analysis over every non-oracle model.
    #
    # Usage: python <this script> [evaluation_dir [output_dir]]
    # Without arguments the original experiment directories below are used,
    # so existing invocations behave exactly as before.
    cli_args = sys.argv[1:]
    evaluation_dir = cli_args[0] if len(cli_args) >= 1 else "/home/key4/JIRArena-exp/evaluation_output"
    output_dir = cli_args[1] if len(cli_args) >= 2 else "/home/key4/JIRArena-exp/iclr_rebuttal/error_analysis"

    # Skip every model whose id contains "oracle" (case-insensitive).
    all_model_ids = get_all_model_ids()
    model_ids = [m for m in all_model_ids if "oracle" not in m.lower()]

    results = analyze_evaluation_errors(
        evaluation_dir=evaluation_dir,
        output_dir=output_dir,
        model_ids=model_ids,
        # Use the module-level constant instead of repeating its value here.
        similarity_threshold=SIMILARITY_THRESHOLD,
        fuzzy_sentence_interval=1,
        num_workers=16,
        use_gpu=True,
        num_gpus=4
    )
    

