
import sys
import os
import json
import csv
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from collections import defaultdict
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch

# Directory one level above this file. It is prepended to sys.path (only if not
# already present), presumably so the project-local imports below resolve; it is
# also the base for the output/ and data/ paths built later in this module.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if parent_dir not in sys.path:
    sys.path[:0] = [parent_dir]

from utils import load_jsonl, METAINFO_DIR
from evaluate_utils import get_duration_by_youtube_id
from evaluate import JIREvaluator

from model_config import get_all_model_ids
from similarity_cache import get_or_compute_matches
from need_type_utils import normalize_need_type_label, get_need_type_label, load_type_patch_cache

def process_single_video_need_type(args):
    """Compute per-need-type statistics for one (model, video) evaluation.

    ``args`` is a single 7-tuple so the function can be handed straight to
    ``ThreadPoolExecutor.submit``::

        (model_id, youtube_id, filename, youtube_ids_set,
         similarity_threshold, fuzzy_sentence_interval, gpu_id)

    ``gpu_id`` selects ``cuda:<gpu_id>`` when CUDA is available; ``None`` means CPU.

    Returns ``None`` when ``youtube_id`` is not in ``youtube_ids_set``, when the
    evaluation, ground-truth or candidate file is missing, or when processing
    raises (the error is reported on stderr so failures are not silent).
    Otherwise returns a dict with:

    - ``"type_stats"``: need-type label -> ``count``, ``matched``,
      ``similarity_scores``, ``relevance_scores`` and ``timeliness_scores`` for
      the ground-truth queries of that type in this video.
    - ``"recall"`` / ``"precision"``: video-level values read from the
      evaluation file (0.0 when absent).
    - ``"unknown_queries"``: ground-truth queries whose label is ``"Unknown"``.
    """
    (model_id, youtube_id, filename, youtube_ids_set,
     similarity_threshold, fuzzy_sentence_interval, gpu_id) = args

    try:
        if youtube_id not in youtube_ids_set:
            return None

        if gpu_id is not None and torch.cuda.is_available():
            device = f"cuda:{gpu_id}"
        else:
            device = "cpu"

        # NOTE(review): this root is hard-coded and independent of the
        # evaluation_dir given to analyze_need_type_performance; keep them in sync.
        evaluation_dir = "/home/key4/JIRArena-exp/evaluation_output"
        model_dir = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{model_id}")
        evaluation_file = os.path.join(model_dir, filename)
        if not os.path.exists(evaluation_file):
            return None

        ground_truth_file = os.path.join(parent_dir, f"output/{youtube_id}/jir_references_relevance_score.jsonl")
        if not os.path.exists(ground_truth_file):
            return None

        candidate_file = os.path.join(parent_dir, f"data/baseline_stream_runs_{model_id}", filename)
        if not os.path.exists(candidate_file):
            return None

        queries_to_match = load_jsonl(ground_truth_file)
        # Context managers so file handles are released even on parse errors.
        with open(candidate_file, "r", encoding="utf-8") as f:
            candidate_queries = json.load(f).get("needs", [])
        with open(evaluation_file, "r", encoding="utf-8") as f:
            evaluation_data = json.load(f)

        evaluator = None

        def get_evaluator():
            # Built lazily on first request, presumably so a cache hit inside
            # get_or_compute_matches never pays for constructing the evaluator.
            nonlocal evaluator
            if evaluator is None:
                evaluator = JIREvaluator(similarity_threshold=similarity_threshold, device=device)
            return evaluator

        results = get_or_compute_matches(
            model_id, youtube_id, queries_to_match, candidate_queries,
            get_evaluator, similarity_threshold, fuzzy_sentence_interval,
            use_cache=True,
        )

        recall = evaluation_data.get("recall", {})
        precision = evaluation_data.get("precision", {})
        # Per-query scores are keyed by the stringified query index.
        query_relevance_scores = evaluation_data.get("relevance", {}).get("query_scores", {})
        query_time_scores = evaluation_data.get("timeliness", {}).get("query_time_scores", {})

        video_type_stats = defaultdict(lambda: {
            "count": 0,
            "matched": 0,
            "similarity_scores": [],
            "relevance_scores": {"opensearch": [], "dense": [], "reranked": []},
            "timeliness_scores": {"start": [], "end": [], "avg": []},
        })
        unknown_queries_list = []

        for idx, query in enumerate(queries_to_match):
            need_type = get_need_type_label(query, youtube_id=youtube_id)
            if need_type == "Unknown":
                unknown_queries_list.append({
                    "youtube_id": youtube_id,
                    "model_id": model_id,
                    "question": query.get("question", ""),
                    "need": query.get("need", ""),
                    "type": query.get("type", ""),
                    "index": idx,
                })

            stats = video_type_stats[need_type]
            stats["count"] += 1

            matches = results.get(idx, [])
            if not matches:
                # Unmatched queries only count towards the denominator.
                continue

            stats["matched"] += 1
            # Assumes each match is a (candidate, similarity) pair; keep the best.
            stats["similarity_scores"].append(max(score for _, score in matches))

            # Only the "opensearch" relevance slot is filled here; "dense" and
            # "reranked" are left empty by this function.
            query_relevance = query_relevance_scores.get(str(idx), 0.0)
            if query_relevance is not None:
                stats["relevance_scores"]["opensearch"].append(query_relevance)

            query_timeliness = query_time_scores.get(str(idx), {})
            if query_timeliness:
                timeliness_stats = stats["timeliness_scores"]
                timeliness_stats["start"].append(query_timeliness.get("start_time_match", 0.0))
                timeliness_stats["end"].append(query_timeliness.get("end_time_match", 0.0))
                timeliness_stats["avg"].append(query_timeliness.get("avg_time_match", 0.0))

        return {
            "type_stats": video_type_stats,
            "recall": recall.get("recall", 0.0),
            "precision": precision.get("precision", 0.0),
            "unknown_queries": unknown_queries_list,
        }
    except Exception as exc:
        # Best-effort per video: report and skip rather than abort the whole run.
        print(f"[need_type] skipping {model_id}/{youtube_id}: {exc!r}", file=sys.stderr)
        return None

def _new_type_aggregate() -> Dict[str, Any]:
    """Return an empty cross-video accumulator for one need type."""
    return {
        "count": 0,
        "matched": 0,
        "recall_scores": [],
        "precision_scores": [],
        "relevance_scores": {"opensearch": [], "dense": [], "reranked": []},
        "timeliness_scores": {"start": [], "end": [], "avg": []},
        "similarity_scores": [],
    }


def _mean_or_zero(values: List[float]) -> float:
    """Return the arithmetic mean of ``values`` as a float, or 0.0 when empty."""
    return float(np.mean(values)) if values else 0.0


def _collect_tasks(
    evaluation_dir: str,
    model_ids: List[str],
    youtube_ids_set: set,
    similarity_threshold: float,
    fuzzy_sentence_interval: int,
) -> List[Tuple]:
    """Build one worker task tuple per known ``<youtube_id>.json`` file under each model's run directory."""
    tasks = []
    for model_id in model_ids:
        model_dir = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{model_id}")
        if not os.path.exists(model_dir):
            continue
        for filename in os.listdir(model_dir):
            if not filename.endswith(".json"):
                continue
            youtube_id = filename[:-len(".json")]
            if youtube_id in youtube_ids_set:
                # Last slot is the GPU id; filled in by _assign_gpus.
                tasks.append((model_id, youtube_id, filename, youtube_ids_set,
                              similarity_threshold, fuzzy_sentence_interval, None))
    return tasks


def _assign_gpus(tasks: List[Tuple], use_gpu: bool, num_gpus: int) -> List[Tuple]:
    """Round-robin a GPU index into the last slot of each task when CUDA is usable."""
    if not (use_gpu and torch.cuda.is_available()):
        return tasks
    # Never address more devices than are visible: an out-of-range cuda:<id>
    # would presumably fail once the evaluator is built, dropping those videos.
    gpu_count = min(num_gpus, torch.cuda.device_count())
    if gpu_count <= 0:
        return tasks
    return [task[:-1] + (i % gpu_count,) for i, task in enumerate(tasks)]


def _merge_video_result(type_stats: Dict[str, Dict[str, Any]], result: Dict[str, Any]) -> None:
    """Fold one per-video worker result into the cross-video ``type_stats`` accumulator."""
    for need_type, video_stats in result["type_stats"].items():
        agg = type_stats[need_type]
        agg["count"] += video_stats["count"]
        agg["matched"] += video_stats["matched"]
        agg["similarity_scores"].extend(video_stats["similarity_scores"])
        for group in ("relevance_scores", "timeliness_scores"):
            for key, values in video_stats[group].items():
                agg[group][key].extend(values)
        # Video-level recall/precision are attributed only to need types that
        # occur in this video, so the averages do not depend on which videos
        # happened to finish earlier.
        agg["recall_scores"].append(result["recall"])
        agg["precision_scores"].append(result["precision"])


def _summarize_type(stats: Dict[str, Any]) -> Dict[str, Any]:
    """Reduce one need type's accumulated lists to the averages reported per type."""
    return {
        "count": stats["count"],
        "percentage": 0.0,  # filled in once the grand total is known
        "matched": stats["matched"],
        "recall_rate": stats["matched"] / stats["count"] if stats["count"] > 0 else 0.0,
        "average_recall": _mean_or_zero(stats["recall_scores"]),
        "average_precision": _mean_or_zero(stats["precision_scores"]),
        "relevance_opensearch": _mean_or_zero(stats["relevance_scores"]["opensearch"]),
        "relevance_dense": _mean_or_zero(stats["relevance_scores"]["dense"]),
        "relevance_reranked": _mean_or_zero(stats["relevance_scores"]["reranked"]),
        "timeliness_start": _mean_or_zero(stats["timeliness_scores"]["start"]),
        "timeliness_end": _mean_or_zero(stats["timeliness_scores"]["end"]),
        "timeliness_avg": _mean_or_zero(stats["timeliness_scores"]["avg"]),
        "average_similarity": _mean_or_zero(stats["similarity_scores"]),
    }


def _rank_challenging_types(type_performance: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Score each need type by weighted shortfall in recall/relevance/timeliness, hardest first."""
    ranked = []
    for need_type, perf in type_performance.items():
        # NOTE(review): relevance uses the "reranked" average, but the per-video
        # worker in this file only fills the "opensearch" list, so this term is
        # likely a constant 0.3 -- confirm which relevance source is intended.
        challenge_score = (
            (1 - perf["recall_rate"]) * 0.4
            + (1 - perf["relevance_reranked"]) * 0.3
            + (1 - perf["timeliness_avg"]) * 0.3
        )
        ranked.append({
            "type": need_type,
            "challenge_score": challenge_score,
            "recall_rate": perf["recall_rate"],
            "relevance": perf["relevance_reranked"],
            "timeliness": perf["timeliness_avg"],
            "count": perf["count"],
        })
    ranked.sort(key=lambda x: x["challenge_score"], reverse=True)
    return ranked


def analyze_need_type_performance(
    evaluation_dir: str,
    output_dir: str,
    model_ids: List[str],
    similarity_threshold: float = 0.55,
    fuzzy_sentence_interval: int = 1,
    num_workers: Optional[int] = None,
    use_gpu: bool = True,
    num_gpus: int = 4
) -> Dict[str, Any]:
    """Aggregate evaluation metrics per need type across models and videos.

    Args:
        evaluation_dir: Root containing ``rebuttal_baseline_stream_runs_<model_id>/``
            directories; their ``*.json`` file names select the videos to process.
            NOTE(review): process_single_video_need_type reads the evaluation
            files from its own hard-coded root, not from this argument.
        output_dir: Created if missing. Receives ``need_type_analysis.json``,
            ``need_type_analysis_report.md`` and, when any query is labelled
            ``"Unknown"``, ``unknown_queries_found.json``.
        model_ids: Models whose run directories are scanned.
        similarity_threshold: Passed through to the matching step.
        fuzzy_sentence_interval: Passed through to the matching step.
        num_workers: Thread count; defaults to ``min(#tasks, 32, 2 * CPU count)``
            and at least 1.
        use_gpu: Spread tasks round-robin over GPUs when CUDA is available.
        num_gpus: Upper bound on GPUs used (also capped by visible devices).

    Returns:
        ``{"distribution": type -> averaged metrics, "challenging_types": top 20
        by challenge score, "summary": totals and most/least common type}``.
    """
    os.makedirs(output_dir, exist_ok=True)

    youtube_ids = (load_jsonl(os.path.join(METAINFO_DIR, "lecture_.jsonl"))
                   + load_jsonl(os.path.join(METAINFO_DIR, "paper_.jsonl")))
    youtube_ids_set = {item["youtube_id"] for item in youtube_ids}

    tasks = _collect_tasks(evaluation_dir, model_ids, youtube_ids_set,
                           similarity_threshold, fuzzy_sentence_interval)
    tasks = _assign_gpus(tasks, use_gpu, num_gpus)

    if num_workers is None:
        # ThreadPoolExecutor rejects max_workers <= 0 (e.g. no tasks), and
        # os.cpu_count() may return None.
        num_workers = max(1, min(len(tasks), 32, (os.cpu_count() or 1) * 2))

    type_stats = defaultdict(_new_type_aggregate)
    all_unknown_queries = []

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(process_single_video_need_type, task) for task in tasks]
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing videos"):
            result = future.result()
            if result is None:
                continue
            all_unknown_queries.extend(result.get("unknown_queries", []))
            _merge_video_result(type_stats, result)

    if all_unknown_queries:
        unknown_file = os.path.join(output_dir, "unknown_queries_found.json")
        with open(unknown_file, "w", encoding="utf-8") as f:
            json.dump(all_unknown_queries, f, indent=2, ensure_ascii=False)

    type_performance = {need_type: _summarize_type(stats) for need_type, stats in type_stats.items()}

    total_count = sum(perf["count"] for perf in type_performance.values())
    if total_count > 0:
        for perf in type_performance.values():
            perf["percentage"] = perf["count"] / total_count

    challenging_types = _rank_challenging_types(type_performance)

    results = {
        "distribution": type_performance,
        "challenging_types": challenging_types[:20],
        "summary": {
            "total_types": len(type_performance),
            "total_queries": total_count,
            "most_common_type": (
                max(type_performance, key=lambda t: type_performance[t]["count"])
                if type_performance else None
            ),
            "least_common_type": (
                min(type_performance, key=lambda t: type_performance[t]["count"])
                if type_performance else None
            ),
        },
    }

    output_file = os.path.join(output_dir, "need_type_analysis.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    generate_markdown_report(results, output_dir)

    return results

def generate_markdown_report(results: Dict[str, Any], output_dir: str):
    """Render an analysis ``results`` dict to ``need_type_analysis_report.md``.

    The file is written into ``output_dir`` and has four sections:
    1. the need-type distribution (ordered by query count, largest first),
    2. per-type performance in the same order,
    3. the first ten entries of ``results["challenging_types"]``,
    4. the summary totals.
    """
    by_count = sorted(
        results["distribution"].items(),
        key=lambda item: item[1]["count"],
        reverse=True,
    )
    summary = results["summary"]

    parts = ["# Need Type Analysis Report\n\n"]

    # Section 1: distribution
    parts += [
        "## Need Type Distribution\n\n",
        "| Need Type | Count | Percentage |\n",
        "|-----------|-------|------------|\n",
    ]
    parts += [
        f"| {name} | {perf['count']} | {perf['percentage']:.2%} |\n"
        for name, perf in by_count
    ]
    parts.append("\n")

    # Section 2: performance (relevance shown as opensearch/dense/reranked)
    parts += [
        "## Performance by Need Type\n\n",
        "| Need Type | Recall | Precision | R_relevance (O/D/R) | R_timeliness |\n",
        "|-----------|--------|-----------|---------------------|--------------|\n",
    ]
    for name, perf in by_count:
        rel = (
            f"{perf['relevance_opensearch']:.3f}/"
            f"{perf['relevance_dense']:.3f}/"
            f"{perf['relevance_reranked']:.3f}"
        )
        parts.append(
            f"| {name} | {perf['recall_rate']:.3f} | {perf['average_precision']:.3f} | "
            f"{rel} | {perf['timeliness_avg']:.3f} |\n"
        )
    parts.append("\n")

    # Section 3: top ten challenging types, ranks starting at 1
    parts += [
        "## Most Challenging Need Types\n\n",
        "| Rank | Need Type | Challenge Score | Recall | Relevance | Timeliness | Count |\n",
        "|------|-----------|-----------------|--------|-----------|------------|-------|\n",
    ]
    for rank, entry in enumerate(results["challenging_types"][:10], start=1):
        parts.append(
            f"| {rank} | {entry['type']} | {entry['challenge_score']:.3f} | "
            f"{entry['recall_rate']:.3f} | {entry['relevance']:.3f} | "
            f"{entry['timeliness']:.3f} | {entry['count']} |\n"
        )
    parts.append("\n")

    # Section 4: summary
    parts += [
        "## Summary\n\n",
        f"- Total need types: {summary['total_types']}\n",
        f"- Total queries: {summary['total_queries']}\n",
        f"- Most common type: {summary['most_common_type']}\n",
        f"- Least common type: {summary['least_common_type']}\n",
    ]

    report_path = os.path.join(output_dir, "need_type_analysis_report.md")
    with open(report_path, "w", encoding="utf-8") as report:
        report.write("".join(parts))

if __name__ == "__main__":
    # Compare every registered model except oracle runs.
    candidate_models = [
        model_id
        for model_id in get_all_model_ids()
        if "oracle" not in model_id.lower()
    ]

    results = analyze_need_type_performance(
        evaluation_dir="/home/key4/JIRArena-exp/evaluation_output",
        output_dir="/home/key4/JIRArena-exp/iclr_rebuttal/need_type_analysis",
        model_ids=candidate_models,
        similarity_threshold=0.55,
        fuzzy_sentence_interval=1,
        num_workers=32,
        use_gpu=True,
        num_gpus=4,
    )
    

