
import sys
import os
import json
import numpy as np
from typing import List, Dict, Any, Tuple
from collections import defaultdict
from tqdm import tqdm

# Prepend this file's parent directory to sys.path (only once) so the
# project-local imports below can be resolved; presumably utils,
# evaluate_utils and model_config live in that parent directory.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from utils import load_jsonl, METAINFO_DIR
from evaluate_utils import get_duration_by_youtube_id

from model_config import get_all_model_ids

def _empty_retrieval_buckets() -> Dict[str, Dict[str, list]]:
    """Return fresh, empty ``scores``/``durations`` lists for every retrieval model."""
    return {
        retrieval_model: {"scores": [], "durations": []}
        for retrieval_model in ("opensearch", "dense", "reranked")
    }


def _extract_retrieval_scores(relevance: Dict[str, Any]) -> Dict[str, Any]:
    """Read the weighted NDCG of each retrieval model from one ``relevance`` dict.

    The OpenSearch score is the top-level ``weighted_ndcg`` (0.0 if absent).
    The dense score comes from ``relevance["biencoder"]`` and the reranked
    score from ``relevance["reranked"]``. When either entry is present but is
    not a dict, it falls back to the previous stage's score
    (dense -> OpenSearch, reranked -> dense).

    Returns:
        ``{"opensearch": float, "dense": float | None, "reranked": float | None}``.
        ``None`` marks a retrieval model with no score in this file. Callers
        can then skip it without confusing "missing" with a genuine 0.0.
    """
    opensearch_score = relevance.get("weighted_ndcg", 0.0)

    dense_score = None
    if "biencoder" in relevance:
        biencoder = relevance["biencoder"]
        if isinstance(biencoder, dict):
            dense_score = biencoder.get("weighted_ndcg")
        else:
            dense_score = opensearch_score

    reranked_score = None
    if "reranked" in relevance:
        reranked = relevance["reranked"]
        if isinstance(reranked, dict):
            reranked_score = reranked.get("weighted_ndcg")
        else:
            reranked_score = dense_score

    return {
        "opensearch": opensearch_score,
        "dense": dense_score,
        "reranked": reranked_score,
    }


def _summarize(bucket: Dict[str, list]) -> Dict[str, Any]:
    """Compute duration-weighted / simple mean, std, count and total duration.

    Falls back to the simple mean when the total duration is not positive,
    because ``np.average`` cannot normalise by a zero weight sum.
    ``bucket["scores"]`` must be non-empty.
    """
    scores = np.asarray(bucket["scores"], dtype=float)
    durations = np.asarray(bucket["durations"])
    total_duration = durations.sum()
    simple_average = scores.mean()
    weighted_average = (
        np.average(scores, weights=durations) if total_duration > 0 else simple_average
    )
    return {
        "weighted_average": float(weighted_average),
        "simple_average": float(simple_average),
        "std": float(scores.std()),
        "num_videos": len(bucket["scores"]),
        "total_duration": float(total_duration),
    }


def _relative_improvement(new_stats: Dict[str, Any], base_stats: Dict[str, Any]) -> Dict[str, float]:
    """Absolute and percentage change of ``weighted_average`` from base to new.

    The percentage is reported as 0.0 when the baseline is not positive.
    """
    base = base_stats["weighted_average"]
    delta = new_stats["weighted_average"] - base
    percentage = delta / base * 100 if base > 0 else 0.0
    return {
        "improvement": float(delta),
        "improvement_percentage": float(percentage),
    }


def analyze_retrieval_models(
    evaluation_dir: str,
    output_dir: str,
    model_ids: List[str]
) -> Dict[str, Any]:
    """Aggregate per-video retrieval NDCG across generative models and write reports.

    For each id in ``model_ids``, reads every ``<youtube_id>.json`` in
    ``<evaluation_dir>/rebuttal_baseline_stream_runs_<model_id>`` whose id appears
    in the lecture/paper metainfo files. From each file it collects the
    OpenSearch, dense (``biencoder``) and reranked ``weighted_ndcg`` scores,
    together with the file's ``duration``.

    Args:
        evaluation_dir: Directory holding the per-model run directories.
        output_dir: Created if needed. Receives ``retrieval_model_analysis.json``
            and the markdown report.
        model_ids: Generative model ids to include. Ids without a run
            directory are skipped.

    Returns:
        A dict with ``overall_performance``, ``per_model_performance`` and
        ``improvements``. Averages are duration-weighted, falling back to the
        plain mean when the total duration is 0. Dense and reranked statistics
        count only videos that actually carry a score for that retrieval model.
    """
    os.makedirs(output_dir, exist_ok=True)

    metainfo = load_jsonl(os.path.join(METAINFO_DIR, "lecture_.jsonl")) + \
               load_jsonl(os.path.join(METAINFO_DIR, "paper_.jsonl"))
    # A set: membership is tested once per evaluation file.
    youtube_ids = {item["youtube_id"] for item in metainfo}

    # model_id -> retrieval model -> {"scores": [...], "durations": [...]}
    model_retrieval_performance = defaultdict(_empty_retrieval_buckets)

    for model_id in tqdm(model_ids, desc="Processing models"):
        model_dir = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{model_id}")
        if not os.path.isdir(model_dir):
            continue

        # Sorted so float accumulation order (and thus results) is reproducible.
        for filename in tqdm(sorted(os.listdir(model_dir)), desc=f"Processing {model_id}", leave=False):
            if not filename.endswith(".json"):
                continue

            youtube_id = filename[:-len(".json")]
            if youtube_id not in youtube_ids:
                continue

            evaluation_file = os.path.join(model_dir, filename)
            if not os.path.isfile(evaluation_file):
                continue

            with open(evaluation_file, "r", encoding="utf-8") as f:
                evaluation_data = json.load(f)
            duration = evaluation_data.get("duration", 0)
            scores = _extract_retrieval_scores(evaluation_data.get("relevance") or {})

            buckets = model_retrieval_performance[model_id]
            for retrieval_model, score in scores.items():
                if score is None:  # no score for this retrieval model in this file
                    continue
                buckets[retrieval_model]["scores"].append(score)
                buckets[retrieval_model]["durations"].append(duration)

    # The overall pool is every per-model bucket concatenated in processing order.
    retrieval_performance = _empty_retrieval_buckets()
    for buckets in model_retrieval_performance.values():
        for retrieval_model, bucket in buckets.items():
            retrieval_performance[retrieval_model]["scores"].extend(bucket["scores"])
            retrieval_performance[retrieval_model]["durations"].extend(bucket["durations"])

    overall_stats = {
        retrieval_model: _summarize(bucket)
        for retrieval_model, bucket in retrieval_performance.items()
        if bucket["scores"]
    }

    per_model_stats = {}
    for model_id, buckets in model_retrieval_performance.items():
        per_model_stats[model_id] = {}
        for retrieval_model, bucket in buckets.items():
            if not bucket["scores"]:
                continue
            summary = _summarize(bucket)
            per_model_stats[model_id][retrieval_model] = {
                key: summary[key]
                for key in ("weighted_average", "simple_average", "num_videos")
            }

    comparisons = (
        ("dense_vs_opensearch", "dense", "opensearch"),
        ("reranked_vs_dense", "reranked", "dense"),
        ("reranked_vs_opensearch", "reranked", "opensearch"),
    )
    # Every comparison key is always present; it stays {} when either side has no scores.
    improvements = {key: {} for key, _, _ in comparisons}
    for key, new, base in comparisons:
        if new in overall_stats and base in overall_stats:
            improvements[key] = _relative_improvement(overall_stats[new], overall_stats[base])

    results = {
        "overall_performance": overall_stats,
        "per_model_performance": per_model_stats,
        "improvements": improvements
    }

    output_file = os.path.join(output_dir, "retrieval_model_analysis.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    generate_markdown_report(results, output_dir)

    return results

def generate_markdown_report(results: Dict[str, Any], output_dir: str):
    """Write ``retrieval_model_analysis_report.md`` summarising ``results``.

    Args:
        results: The dict returned by ``analyze_retrieval_models``. It has the keys
            ``overall_performance``, ``per_model_performance`` and ``improvements``.
        output_dir: Existing directory that receives the report file.

    An improvement entry that is empty (its retrieval models had no scores) is
    omitted rather than crashing the report. Per-model cells for a retrieval
    model with no scores are shown as ``N/A``, so they cannot be mistaken for a
    real 0.0.
    """
    report_file = os.path.join(output_dir, "retrieval_model_analysis_report.md")

    def _cell(per_model: Dict[str, Any], retrieval_model: str) -> str:
        value = per_model.get(retrieval_model, {}).get("weighted_average")
        return "N/A" if value is None else f"{value:.4f}"

    lines = [
        "# Retrieval Model Analysis Report\n\n",
        "## Overall Performance by Retrieval Model\n\n",
        "| Retrieval Model | Weighted Average | Simple Average | Std | Num Videos |\n",
        "|-----------------|------------------|----------------|-----|------------|\n",
    ]
    for model, stats in results["overall_performance"].items():
        lines.append(
            f"| {model} | {stats['weighted_average']:.4f} | {stats['simple_average']:.4f} | "
            f"{stats['std']:.4f} | {stats['num_videos']} |\n"
        )
    lines.append("\n")

    lines.append("## Performance Improvements\n\n")
    improvements = results.get("improvements", {})
    labelled_comparisons = (
        ("dense_vs_opensearch", "Dense vs OpenSearch"),
        ("reranked_vs_dense", "Reranked vs Dense"),
        ("reranked_vs_opensearch", "Reranked vs OpenSearch"),
    )
    for key, label in labelled_comparisons:
        imp = improvements.get(key)
        # analyze_retrieval_models stores {} when a comparison could not be computed.
        if not imp:
            continue
        lines.append(
            f"- **{label}**: {imp['improvement']:+.4f} ({imp['improvement_percentage']:+.2f}%)\n"
        )
    lines.append("\n")

    lines.append("## Performance by Generative Model\n\n")
    lines.append("| Generative Model | OpenSearch | Dense | Reranked |\n")
    lines.append("|------------------|------------|-------|----------|\n")
    for model_id, stats in results["per_model_performance"].items():
        lines.append(
            f"| {model_id} | {_cell(stats, 'opensearch')} | {_cell(stats, 'dense')} | "
            f"{_cell(stats, 'reranked')} |\n"
        )
    lines.append("\n")

    with open(report_file, "w", encoding="utf-8") as f:
        f.writelines(lines)

if __name__ == "__main__":
    import argparse

    # Defaults reproduce the previous hard-coded run; flags allow other layouts.
    parser = argparse.ArgumentParser(
        description="Compare OpenSearch, dense and reranked retrieval scores across generative models."
    )
    parser.add_argument(
        "--evaluation-dir",
        default="/home/key4/JIRArena-exp/evaluation_output",
        help="Directory containing rebuttal_baseline_stream_runs_<model_id> folders.",
    )
    parser.add_argument(
        "--output-dir",
        default="/home/key4/JIRArena-exp/iclr_rebuttal/retrieval_model_analysis",
        help="Directory that receives the JSON and markdown reports.",
    )
    parser.add_argument(
        "--model-ids",
        nargs="+",
        default=None,
        help="Subset of generative model ids (default: every id from model_config).",
    )
    args = parser.parse_args()

    analyze_retrieval_models(
        evaluation_dir=args.evaluation_dir,
        output_dir=args.output_dir,
        model_ids=args.model_ids or get_all_model_ids(),
    )
    

