
import sys
import os
import json
import numpy as np
import pandas as pd
from typing import List, Dict, Any, Tuple
from collections import defaultdict
from tqdm import tqdm
from scipy.stats import pearsonr
from itertools import combinations
from concurrent.futures import ThreadPoolExecutor, as_completed

# Put the directory one level above this file at the front of sys.path (only
# once) before the imports that follow run. `parent_dir` is also reused by
# process_single_video_matching to locate the `output/` and `data/` trees.
# NOTE(review): presumably the modules imported below (utils, evaluate, ...)
# live in that parent directory -- confirm against the repository layout.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from utils import load_jsonl, METAINFO_DIR
from evaluate_utils import get_duration_by_youtube_id
from evaluate import JIREvaluator
from visualization_utils import (
    plot_performance_comparison_bar,
    plot_correlation_heatmap,
    plot_scatter_with_trend,
    plot_model_performance_radar
)
from model_config import get_all_model_ids, categorize_models
from similarity_cache import get_or_compute_matches

# Model ids grouped for analyze_proprietary_vs_open_source. Membership is an
# exact string match on the model id; ids that appear in neither list are left
# out of that comparison.
PROPRIETARY_MODELS = ["gpt-4o", "claude-3-7", "gemini"]
OPEN_SOURCE_MODELS = ["Phi-4", "Qwen3-4B-Instruct-2507"]

def load_model_sizes(model_sizes_file: str = "../data/metainfo/model_sizes.txt") -> Dict[str, int]:
    """Parse a ``<model name> (<size>B)`` listing into a name -> size mapping.

    Only lines containing exactly one ``(`` and at least one ``)`` are
    considered. The text after the ``(`` is stripped of trailing ``)`` and
    ``B`` characters and converted with ``int(float(...))``, so fractional
    sizes are truncated. Lines whose size cannot be parsed are skipped.

    Besides the name as written in the file, a fixed set of alias ids is
    registered whenever the name contains the corresponding marker substring.

    Args:
        model_sizes_file: Path to the listing; resolved relative to the
            current working directory when relative.

    Returns:
        Mapping from model name / alias to the parsed integer size. Empty
        when the file does not exist.
    """
    # (substring looked for in the parsed name, alias id that receives the size)
    alias_by_marker = (
        ("DeepSeek", "DeepSeek-V3-0324"),
        ("Llama", "Llama-3"),
        ("Qwen3", "Qwen3-4B-Instruct-2507"),
        ("Cohere", "Cohere"),
        ("Phi-4", "Phi-4"),
    )

    sizes_by_name = {}
    if not os.path.exists(model_sizes_file):
        return sizes_by_name

    with open(model_sizes_file, 'r', encoding='utf-8') as handle:
        for raw in handle:
            entry = raw.strip()
            # Blank lines and anything without a single "(...)" group are ignored.
            if entry.count('(') != 1 or ')' not in entry:
                continue

            name_part, _, size_part = entry.partition('(')
            name = name_part.strip()
            digits = size_part.rstrip(')').strip().rstrip('B').strip()
            try:
                parsed_size = int(float(digits))
            except ValueError:
                continue

            sizes_by_name[name] = parsed_size
            for marker, alias in alias_by_marker:
                if marker in name:
                    sizes_by_name[alias] = parsed_size
    return sizes_by_name

def get_small_and_large_models(model_ids: List[str], size_threshold: int = 50) -> Tuple[List[str], List[str]]:
    """Split model ids into "small" and "large" groups by parameter count.

    The size of each id is the value of the first entry returned by
    ``load_model_sizes()`` whose key is a substring of the id or vice versa.
    Ids with a size at or below ``size_threshold`` are small, the rest large.
    Ids with no size entry fall back to two hard-coded lists; ids found in
    neither list are omitted from both results.

    Args:
        model_ids: Model ids to classify.
        size_threshold: Inclusive upper bound for the "small" group, in the
            same unit as the sizes file.

    Returns:
        ``(small_models, large_models)``, each in the order of ``model_ids``.
    """
    fallback_large = ("gpt-4o", "claude-3-7", "gemini", "Grok-3", "Mistral")
    fallback_small = ("Phi-4", "Qwen3-4B-Instruct-2507")

    known_sizes = load_model_sizes()
    small_models, large_models = [], []

    for model_id in model_ids:
        # First matching entry wins, in the insertion order of the sizes dict.
        matched_size = next(
            (
                value
                for name, value in known_sizes.items()
                if name in model_id or model_id in name
            ),
            None,
        )

        if matched_size is None:
            if model_id in fallback_large:
                large_models.append(model_id)
            elif model_id in fallback_small:
                small_models.append(model_id)
            continue

        bucket = small_models if matched_size <= size_threshold else large_models
        bucket.append(model_id)

    return small_models, large_models

def load_single_model_performance(args):
    """Tuple-argument adapter around ``load_model_performance``.

    Args:
        args: An ``(evaluation_dir, model_id)`` pair, as submitted to the
            thread pool by ``comprehensive_analysis``.

    Returns:
        Whatever ``load_model_performance`` returns for that pair.
    """
    eval_root, target_model = args
    return load_model_performance(evaluation_dir=eval_root, model_id=target_model)

def _extract_video_metrics(data: Dict[str, Any]) -> Dict[str, Any]:
    """Pull the per-video scalar metrics out of one parsed evaluation file.

    Missing sections or fields fall back to ``0`` / ``0.0``. The dense and
    reranked relevance scores are only read when their section is a dict.
    """
    precision_data = data.get("precision", {})
    relevance = data.get("relevance", {})
    timeliness = data.get("timeliness", {})
    biencoder = relevance.get("biencoder")
    reranked = relevance.get("reranked")
    return {
        "recall": data.get("recall", {}).get("recall", 0.0),
        "precision": precision_data.get("precision", 0.0),
        "relevance_opensearch": relevance.get("weighted_ndcg", 0.0),
        "relevance_dense": (
            biencoder.get("weighted_ndcg", 0.0) if isinstance(biencoder, dict) else 0.0
        ),
        "relevance_reranked": (
            reranked.get("weighted_ndcg", 0.0) if isinstance(reranked, dict) else 0.0
        ),
        "timeliness_start": timeliness.get("weighted_start_match", 0.0),
        "timeliness_end": timeliness.get("weighted_end_match", 0.0),
        "timeliness_avg": timeliness.get("weighted_time_match", 0.0),
        "duration": data.get("duration", 0),
        "total_candidates": precision_data.get("total_candidates", 0),
        "matched_candidates": precision_data.get("matched_candidates", 0),
    }


def load_model_performance(evaluation_dir: str, model_id: str) -> Dict[str, Any]:
    """Aggregate one model's per-video evaluation files into summary metrics.

    Reads every ``<youtube_id>.json`` in
    ``<evaluation_dir>/rebuttal_baseline_stream_runs_<model_id>`` whose id is
    listed in the ``lecture_.jsonl`` / ``paper_.jsonl`` metainfo files, and
    averages the score metrics weighted by each video's ``duration``.

    Args:
        evaluation_dir: Root directory holding the per-model result folders.
        model_id: Model identifier used in the folder name.

    Returns:
        A dict with the duration-weighted ``recall``, ``precision``,
        ``relevance_*`` and ``timeliness_*`` scores, the summed
        ``total_candidates`` / ``matched_candidates``, ``num_videos`` and
        ``total_duration``. ``None`` when the model folder is missing, no
        file could be read, or the total duration is zero.
    """
    model_dir = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{model_id}")
    # isdir (not exists) so a stray file with that name cannot crash listdir.
    if not os.path.isdir(model_dir):
        return None

    metainfo = load_jsonl(os.path.join(METAINFO_DIR, "lecture_.jsonl")) + \
               load_jsonl(os.path.join(METAINFO_DIR, "paper_.jsonl"))
    # A set keeps the per-file membership test O(1).
    known_ids = {item["youtube_id"] for item in metainfo}

    weighted_keys = (
        "recall", "precision", "relevance_opensearch", "relevance_dense",
        "relevance_reranked", "timeliness_start", "timeliness_end", "timeliness_avg",
    )
    summed_keys = ("total_candidates", "matched_candidates")
    metrics = {key: [] for key in weighted_keys + ("duration",) + summed_keys}

    for filename in os.listdir(model_dir):
        if not filename.endswith(".json"):
            continue
        youtube_id = filename[:-len(".json")]
        if youtube_id not in known_ids:
            continue
        evaluation_file = os.path.join(model_dir, filename)
        try:
            with open(evaluation_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            row = _extract_video_metrics(data)
        except (OSError, ValueError, AttributeError, TypeError) as exc:
            # Best effort: one unreadable or malformed file must not abort the
            # model, but it should not disappear without a trace either.
            print(f"Skipping {evaluation_file}: {exc!r}", file=sys.stderr)
            continue
        # Append only after the whole row was extracted so the lists stay aligned.
        for key, value in row.items():
            metrics[key].append(value)

    if not metrics["duration"]:
        return None

    durations = np.array(metrics["duration"])
    total_duration = durations.sum()
    if total_duration == 0:
        # np.average would divide by zero with all-zero weights.
        return None

    weighted_metrics = {
        key: np.average(np.array(metrics[key]), weights=durations)
        for key in weighted_keys
    }
    for key in summed_keys:
        weighted_metrics[key] = sum(metrics[key])
    weighted_metrics["num_videos"] = len(metrics["duration"])
    weighted_metrics["total_duration"] = total_duration
    return weighted_metrics

def analyze_cross_metric_correlation(model_performances: Dict[str, Dict[str, float]]) -> Dict[str, Any]:
    """Compute pairwise Pearson correlations between metrics across models.

    Each model with a non-``None`` performance dict contributes one sample
    per metric (``0.0`` when the metric is absent).

    Args:
        model_performances: Mapping of model id to its metric dict, or
            ``None`` for models without results.

    Returns:
        Mapping ``"<metric1>_<metric2>"`` to ``{"correlation", "p_value"}``
        for every pair of the four tracked metrics. Empty when fewer than two
        models have results.
    """
    tracked = ("recall", "precision", "relevance_opensearch", "timeliness_avg")
    usable = [perf for perf in model_performances.values() if perf is not None]

    # Every column has one entry per usable model, so one length check suffices.
    if len(usable) <= 1:
        return {}

    columns = {name: [perf.get(name, 0.0) for perf in usable] for name in tracked}

    pairwise = {}
    for left, right in combinations(tracked, 2):
        coefficient, significance = pearsonr(columns[left], columns[right])
        pairwise[f"{left}_{right}"] = {
            "correlation": float(coefficient),
            "p_value": float(significance),
        }
    return pairwise

def analyze_model_rankings(model_performances: Dict[str, Dict[str, float]]) -> Dict[str, List[Tuple[str, float]]]:
    """Rank models from best to worst on each tracked metric.

    Models whose performance is ``None`` are left out; a missing metric
    counts as ``0.0``. Ties keep the iteration order of ``model_performances``.

    Args:
        model_performances: Mapping of model id to its metric dict or ``None``.

    Returns:
        Mapping of metric name to a list of ``(model_id, score)`` tuples
        sorted by descending score.
    """
    scored = [
        (name, perf) for name, perf in model_performances.items() if perf is not None
    ]
    return {
        metric: sorted(
            ((name, perf.get(metric, 0.0)) for name, perf in scored),
            key=lambda pair: pair[1],
            reverse=True,
        )
        for metric in ("recall", "precision", "relevance_opensearch", "timeliness_avg")
    }

def load_model_display_names(model_sizes_file: str = "../data/metainfo/model_sizes.txt") -> Dict[str, str]:
    """Build ``"<name> (<size>)"`` display labels for the known model aliases.

    Reads the same ``<model name> (<size>)`` listing as ``load_model_sizes``.
    Only lines with exactly one ``(`` and at least one ``)`` are used. A label
    is stored under an alias id whenever the listed name contains that
    alias's marker substring; names matching no marker produce no entry.

    Args:
        model_sizes_file: Path to the listing; resolved relative to the
            current working directory when relative.

    Returns:
        Mapping from alias model id to display label. Empty when the file
        does not exist.
    """
    # (substring looked for in the listed name, alias id that receives the label)
    alias_by_marker = (
        ("DeepSeek", "DeepSeek-V3-0324"),
        ("Llama", "Llama-3"),
        ("Qwen3", "Qwen3-4B-Instruct-2507"),
        ("Cohere", "Cohere"),
        ("Phi-4", "Phi-4"),
    )

    labels = {}
    if not os.path.exists(model_sizes_file):
        return labels

    with open(model_sizes_file, 'r', encoding='utf-8') as handle:
        for raw in handle:
            entry = raw.strip()
            if entry.count('(') != 1 or ')' not in entry:
                continue

            head, _, tail = entry.partition('(')
            name = head.strip()
            size_text = tail.rstrip(')').strip()
            label = f"{name} ({size_text})"

            labels.update(
                {alias: label for marker, alias in alias_by_marker if marker in name}
            )
    return labels

def analyze_model_size_trend(
    model_performances: Dict[str, Dict[str, float]],
    evaluation_dir: str
) -> Dict[str, Any]:
    """Relate model size to each tracked metric via Pearson correlation.

    A model is included when it has results and a size can be found for it:
    the first ``load_model_sizes()`` entry whose key is a substring of the
    model id or vice versa. Included models are ordered by ascending size.

    Args:
        model_performances: Mapping of model id to its metric dict or ``None``.
        evaluation_dir: Not used by this function.

    Returns:
        ``{"trend_data": {...}, "models_analyzed": [...]}`` where
        ``trend_data[metric]`` holds the parallel ``sizes``, ``scores`` and
        ``model_ids`` lists plus ``correlation`` and ``p_value``. With fewer
        than two sized models the correlation is reported as ``0.0`` with
        p-value ``1.0``.
    """
    known_sizes = load_model_sizes()

    def lookup_size(model_id):
        # First matching entry wins, in the insertion order of the sizes dict.
        for name, value in known_sizes.items():
            if name in model_id or model_id in name:
                return value
        return None

    sized = []
    for model_id, perf in model_performances.items():
        if perf is None:
            continue
        size = lookup_size(model_id)
        if size is not None:
            sized.append((size, model_id, perf))

    # Stable sort on size only, so equally sized models keep their input order.
    sized.sort(key=lambda entry: entry[0])

    enough_samples = len(sized) > 1
    trend_data = {}
    for metric in ("recall", "precision", "relevance_opensearch", "timeliness_avg"):
        sizes = [size for size, _, _ in sized]
        model_ids = [model_id for _, model_id, _ in sized]
        scores = [perf.get(metric, 0.0) for _, _, perf in sized]

        if enough_samples:
            corr, p_value = pearsonr(sizes, scores)
        else:
            corr, p_value = 0.0, 1.0

        trend_data[metric] = {
            "sizes": sizes,
            "scores": scores,
            "model_ids": model_ids,
            "correlation": float(corr),
            "p_value": float(p_value),
        }

    return {
        "trend_data": trend_data,
        "models_analyzed": [model_id for _, model_id, _ in sized],
    }

def analyze_proprietary_vs_open_source(
    model_performances: Dict[str, Dict[str, float]]
) -> Dict[str, Any]:
    """Compare proprietary and open-source model groups metric by metric.

    Group membership is an exact match against ``PROPRIETARY_MODELS`` and
    ``OPEN_SOURCE_MODELS`` (proprietary is checked first). Models with no
    results, or in neither list, are ignored.

    Args:
        model_performances: Mapping of model id to its metric dict or ``None``.

    Returns:
        Mapping of metric name to a dict with ``proprietary`` and
        ``open_source`` summaries (``mean``, population ``std``, ``models``),
        the absolute ``gap`` (proprietary minus open source) and
        ``gap_percentage`` relative to the proprietary mean. Gaps are ``0.0``
        when either group is empty; the percentage is also ``0.0`` when the
        proprietary mean is not positive.
    """
    scored = [
        (name, perf) for name, perf in model_performances.items() if perf is not None
    ]
    proprietary_perfs = {
        name: perf for name, perf in scored if name in PROPRIETARY_MODELS
    }
    open_source_perfs = {
        name: perf
        for name, perf in scored
        if name not in PROPRIETARY_MODELS and name in OPEN_SOURCE_MODELS
    }

    def summarize(scores, members):
        if scores:
            return {
                "mean": float(np.mean(scores)),
                "std": float(np.std(scores)),
                "models": list(members),
            }
        return {"mean": 0.0, "std": 0.0, "models": list(members)}

    comparison = {}
    for metric in ("recall", "precision", "relevance_opensearch", "timeliness_avg"):
        closed_scores = [perf.get(metric, 0.0) for perf in proprietary_perfs.values()]
        open_scores = [perf.get(metric, 0.0) for perf in open_source_perfs.values()]

        gap = 0.0
        gap_percentage = 0.0
        if closed_scores and open_scores:
            closed_mean = np.mean(closed_scores)
            difference = closed_mean - np.mean(open_scores)
            gap = float(difference)
            if closed_mean > 0:
                gap_percentage = float(difference / closed_mean * 100)

        comparison[metric] = {
            "proprietary": summarize(closed_scores, proprietary_perfs),
            "open_source": summarize(open_scores, open_source_perfs),
            "gap": gap,
            "gap_percentage": gap_percentage,
        }
    return comparison

def analyze_multimodal_vs_text(
    model_performances: Dict[str, Dict[str, float]]
) -> Dict[str, Any]:
    """Compare ``*_multimodal`` model variants against text-only models.

    A model id containing ``_multimodal`` (case-insensitive check) goes to
    the multimodal group, keyed by the id with the literal ``_multimodal``
    removed. Every other model except ``"oracle"`` goes to the text group:
    first those without a ``<id>_multimodal`` key in ``model_performances``,
    then those with one. Models with ``None`` results are ignored.

    Args:
        model_performances: Mapping of model id to its metric dict or ``None``.

    Returns:
        Mapping of metric name to a dict with ``multimodal`` and ``text``
        summaries (``mean``, population ``std``, ``models``), the absolute
        ``gap`` (multimodal minus text) and ``gap_percentage`` relative to
        the text mean. Gaps are ``0.0`` when either group is empty; the
        percentage is also ``0.0`` when the text mean is not positive.
    """
    suffix = "_multimodal"
    scored = [
        (name, perf) for name, perf in model_performances.items() if perf is not None
    ]

    multimodal_perfs = {
        name.replace(suffix, ""): perf
        for name, perf in scored
        if suffix in name.lower()
    }

    text_candidates = [
        (name, perf)
        for name, perf in scored
        if suffix not in name.lower() and name != "oracle"
    ]
    # Keep the original grouping order: unpaired text models first, then the
    # text counterparts of models that also have a multimodal variant.
    unpaired = {
        name: perf
        for name, perf in text_candidates
        if f"{name}{suffix}" not in model_performances
    }
    paired = {
        name: perf
        for name, perf in text_candidates
        if f"{name}{suffix}" in model_performances
    }
    text_perfs = {**unpaired, **paired}

    def summarize(scores, members):
        if scores:
            return {
                "mean": float(np.mean(scores)),
                "std": float(np.std(scores)),
                "models": list(members),
            }
        return {"mean": 0.0, "std": 0.0, "models": list(members)}

    comparison = {}
    for metric in ("recall", "precision", "relevance_opensearch", "timeliness_avg"):
        visual_scores = [perf.get(metric, 0.0) for perf in multimodal_perfs.values()]
        plain_scores = [perf.get(metric, 0.0) for perf in text_perfs.values()]

        gap = 0.0
        gap_percentage = 0.0
        if visual_scores and plain_scores:
            plain_mean = np.mean(plain_scores)
            difference = np.mean(visual_scores) - plain_mean
            gap = float(difference)
            if plain_mean > 0:
                gap_percentage = float(difference / plain_mean * 100)

        comparison[metric] = {
            "multimodal": summarize(visual_scores, multimodal_perfs),
            "text": summarize(plain_scores, text_perfs),
            "gap": gap,
            "gap_percentage": gap_percentage,
        }
    return comparison

def process_single_video_matching(args):
    """Count reference-to-candidate matching patterns for one (model, video).

    Args:
        args: 7-tuple ``(model_id, youtube_id, filename, youtube_ids_set,
            similarity_threshold, fuzzy_sentence_interval, gpu_id)``.
            ``gpu_id`` selects ``cuda:<gpu_id>`` when CUDA is available;
            ``None`` forces CPU.

    Returns:
        Dict with the counters ``1_to_1``, ``1_to_many``, ``many_to_1`` and
        ``many_to_many``, or ``None`` when the video is not in
        ``youtube_ids_set``, a required input file is missing, or processing
        fails (the failure is reported on stderr).

    Raises:
        ValueError / TypeError: if ``args`` cannot be unpacked into 7 items.
    """
    (model_id, youtube_id, filename, youtube_ids_set,
     similarity_threshold, fuzzy_sentence_interval, gpu_id) = args
    try:
        # Cheap rejections first, before torch is imported or CUDA is probed.
        if youtube_id not in youtube_ids_set:
            return None

        # NOTE(review): the evaluation root comes from the environment here,
        # not from the caller's `evaluation_dir` argument -- the two can disagree.
        evaluation_dir = os.getenv("EVALUATION_DIR", "../evaluation_output")
        model_dir = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{model_id}")
        evaluation_file = os.path.join(model_dir, filename)
        if not os.path.exists(evaluation_file):
            return None
        ground_truth_file = os.path.join(parent_dir, f"output/{youtube_id}/jir_references_relevance_score.jsonl")
        if not os.path.exists(ground_truth_file):
            return None
        candidate_file = os.path.join(parent_dir, f"data/baseline_stream_runs_{model_id}", filename)
        if not os.path.exists(candidate_file):
            return None

        import torch
        if gpu_id is not None and torch.cuda.is_available():
            device = f'cuda:{gpu_id}'
        else:
            device = 'cpu'

        queries_to_match = load_jsonl(ground_truth_file)
        with open(candidate_file, "r") as f:
            candidate_data = json.load(f)
        candidate_queries = candidate_data.get("needs", [])

        def get_evaluator():
            # Passed as a factory; construction is left to get_or_compute_matches.
            return JIREvaluator(similarity_threshold=similarity_threshold, device=device)

        results = get_or_compute_matches(
            model_id, youtube_id, queries_to_match, candidate_queries,
            get_evaluator, similarity_threshold, fuzzy_sentence_interval,
            use_cache=True
        )

        # "many_to_many" is reported for completeness but never incremented.
        matching_patterns = {"1_to_1": 0, "1_to_many": 0, "many_to_1": 0, "many_to_many": 0}
        candidate_match_count = defaultdict(int)
        for matches in results.values():
            if not matches:
                continue
            if len(matches) == 1:
                matching_patterns["1_to_1"] += 1
            else:
                matching_patterns["1_to_many"] += 1
            for match, _ in matches:
                # NOTE(review): candidates are told apart by object identity;
                # this presumably relies on the same candidate object being
                # shared across queries (may not hold for results rebuilt
                # from a cache) -- confirm against get_or_compute_matches.
                candidate_match_count[id(match)] += 1

        # Each extra reference pointing at an already-matched candidate counts once.
        matching_patterns["many_to_1"] += sum(
            count - 1 for count in candidate_match_count.values() if count > 1
        )
        return matching_patterns
    except Exception as exc:
        # Worker boundary: keep the best-effort contract (None on failure) but
        # make the failure visible instead of dropping the video silently.
        print(
            f"Matching analysis failed for model={model_id} video={youtube_id}: {exc!r}",
            file=sys.stderr,
        )
        return None

def analyze_many_to_many_matching(
    evaluation_dir: str,
    model_ids: List[str],
    similarity_threshold: float = 0.55,
    fuzzy_sentence_interval: int = 1,
    num_workers: int = 32,
    use_gpu: bool = True,
    num_gpus: int = 4
) -> Dict[str, Any]:
    """Aggregate matching-pattern shares over all models and videos.

    One task is created per ``<youtube_id>.json`` found in each model's
    ``rebuttal_baseline_stream_runs_<model_id>`` folder (restricted to ids
    from the metainfo files) and run through
    ``process_single_video_matching`` on a thread pool.

    Args:
        evaluation_dir: Root directory scanned for per-model result folders.
            NOTE(review): the worker resolves its own evaluation root from
            the ``EVALUATION_DIR`` environment variable, so the two should
            point at the same place.
        model_ids: Models to include.
        similarity_threshold: Forwarded to the worker.
        fuzzy_sentence_interval: Forwarded to the worker.
        num_workers: Thread pool size (at least one thread is used).
        use_gpu: When true and CUDA is available, tasks are spread
            round-robin over the GPUs.
        num_gpus: Upper bound on the number of GPUs to use; clamped to the
            number of devices torch reports.

    Returns:
        Dict with keys ``1_to_1``, ``1_to_many``, ``many_to_1`` and
        ``many_to_many``. Values are fractions of the total count, or the
        raw zero counts when nothing was matched.
    """
    metainfo = load_jsonl(os.path.join(METAINFO_DIR, "lecture_.jsonl")) + \
               load_jsonl(os.path.join(METAINFO_DIR, "paper_.jsonl"))
    youtube_ids_set = {item["youtube_id"] for item in metainfo}

    tasks = []
    for model_id in model_ids:
        model_dir = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{model_id}")
        if not os.path.isdir(model_dir):
            continue
        for filename in os.listdir(model_dir):
            if not filename.endswith(".json"):
                continue
            youtube_id = filename[:-len(".json")]
            if youtube_id not in youtube_ids_set:
                continue
            # Last slot is the GPU id; None means "run on CPU".
            tasks.append((model_id, youtube_id, filename, youtube_ids_set,
                          similarity_threshold, fuzzy_sentence_interval, None))

    if use_gpu:
        import torch
        if torch.cuda.is_available():
            # Never hand out a device index the machine does not have: the
            # worker would fail on it and the video would be dropped.
            usable_gpus = min(num_gpus, torch.cuda.device_count())
            if usable_gpus > 0:
                tasks = [
                    task[:-1] + (index % usable_gpus,)
                    for index, task in enumerate(tasks)
                ]

    matching_patterns = {
        "1_to_1": 0,
        "1_to_many": 0,
        "many_to_1": 0,
        "many_to_many": 0
    }

    # ThreadPoolExecutor rejects max_workers <= 0.
    with ThreadPoolExecutor(max_workers=max(1, num_workers)) as executor:
        futures = {executor.submit(process_single_video_matching, task): task for task in tasks}
        for future in tqdm(as_completed(futures), total=len(futures), desc="Analyzing matching patterns"):
            result = future.result()
            if result:
                for key in matching_patterns:
                    matching_patterns[key] += result.get(key, 0)

    total = sum(matching_patterns.values())
    if total > 0:
        matching_patterns = {k: v / total for k, v in matching_patterns.items()}
    return matching_patterns

def comprehensive_analysis(
    evaluation_dir: str,
    output_dir: str,
    model_ids: List[str]
) -> Dict[str, Any]:
    """Run every analysis and write ``comprehensive_analysis.json``.

    Loads each model's aggregated performance in parallel, computes the
    cross-metric correlations, rankings, size trend, proprietary/open-source
    and multimodal/text comparisons, the matching-pattern shares and the
    pattern summary, then dumps everything to
    ``<output_dir>/comprehensive_analysis.json``.

    The matching-pattern analysis is reused from an existing output file when
    that file already contains a non-null ``many_to_many_matching`` entry;
    delete the file to force a recomputation.

    Args:
        evaluation_dir: Root directory holding the per-model result folders.
        output_dir: Directory for the JSON report (created if needed).
        model_ids: Models to analyse; may be empty.

    Returns:
        The results dict that was written to disk.
    """
    os.makedirs(output_dir, exist_ok=True)

    loaded = {}
    tasks = [(evaluation_dir, model_id) for model_id in model_ids]
    # max_workers must be >= 1 even when there is nothing to load.
    with ThreadPoolExecutor(max_workers=max(1, min(len(tasks), 8))) as executor:
        futures = {executor.submit(load_single_model_performance, task): task[1] for task in tasks}
        for future in tqdm(as_completed(futures), total=len(futures), desc="Loading models"):
            model_id = futures[future]
            try:
                perf = future.result()
            except Exception as exc:
                # Best effort per model, but report what was skipped.
                print(f"Failed to load performance for {model_id}: {exc!r}", file=sys.stderr)
                continue
            if perf is not None:
                loaded[model_id] = perf

    # Re-key in the caller's order: completion order varies between runs and
    # would otherwise change ranking ties and the order of the report.
    model_performances = {
        model_id: loaded[model_id] for model_id in model_ids if model_id in loaded
    }

    correlations = analyze_cross_metric_correlation(model_performances)
    rankings = analyze_model_rankings(model_performances)
    size_trend = analyze_model_size_trend(model_performances, evaluation_dir)
    proprietary_vs_open = analyze_proprietary_vs_open_source(model_performances)
    multimodal_vs_text = analyze_multimodal_vs_text(model_performances)

    output_file = os.path.join(output_dir, "comprehensive_analysis.json")
    matching_patterns = None
    if os.path.exists(output_file):
        try:
            # The report is written as UTF-8 below, so read it back the same way.
            with open(output_file, "r", encoding="utf-8") as f:
                existing_results = json.load(f)
        except (OSError, ValueError) as exc:
            print(f"Ignoring unreadable {output_file}: {exc!r}", file=sys.stderr)
        else:
            if isinstance(existing_results, dict):
                matching_patterns = existing_results.get("many_to_many_matching")

    if matching_patterns is None:
        matching_patterns = analyze_many_to_many_matching(
            evaluation_dir, model_ids,
            num_workers=32, use_gpu=True, num_gpus=4
        )

    pattern_summary = generate_performance_pattern_summary(
        model_performances, rankings, correlations, size_trend, proprietary_vs_open
    )

    results = {
        # NumPy scalars are not JSON serialisable; convert them to float.
        "model_performances": {
            k: {
                m: float(v) if isinstance(v, (np.floating, np.integer)) else v
                for m, v in p.items()
            }
            for k, p in model_performances.items()
        },
        "cross_metric_correlation": correlations,
        "model_rankings": {k: [(m, float(s)) for m, s in v] for k, v in rankings.items()},
        "model_size_trend": size_trend,
        "proprietary_vs_open_source": proprietary_vs_open,
        "multimodal_vs_text": multimodal_vs_text,
        "many_to_many_matching": matching_patterns,
        "performance_pattern_summary": pattern_summary
    }

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    return results

def generate_performance_pattern_summary(
    model_performances: Dict[str, Dict[str, float]],
    rankings: Dict[str, List[Tuple[str, float]]],
    correlations: Dict[str, Any],
    size_trend: Dict[str, Any],
    proprietary_vs_open: Dict[str, Any]
) -> Dict[str, Any]:
    """Condense the individual analyses into a short pattern summary.

    Args:
        model_performances: Not used by this function.
        rankings: Metric -> list of ``(model_id, score)``, best first.
        correlations: Output of the cross-metric correlation analysis; only
            the ``"recall_precision"`` entry is inspected.
        size_trend: Output of the size-trend analysis; only ``"trend_data"``
            is inspected.
        proprietary_vs_open: Metric -> comparison dict that contains
            ``"gap_percentage"``.

    Returns:
        Dict with ``model_strengths`` (top-ranked model -> metrics it leads),
        ``trade_offs`` (a recall/precision note when their correlation is
        negative), ``size_trend_correlations`` (metric -> correlation and
        p-value) and ``proprietary_open_source_gap`` (metric -> gap
        percentage).
    """
    # Credit each metric to whichever model tops its ranking.
    strengths = {}
    for metric, ranking in rankings.items():
        if not ranking:
            continue
        leader = ranking[0][0]
        strengths.setdefault(leader, []).append(metric)

    trade_offs = []
    if "recall_precision" in correlations:
        if correlations["recall_precision"]["correlation"] < 0:
            trade_offs.append("Recall vs Precision (negative correlation)")

    size_correlations = {
        metric: {"correlation": trend["correlation"], "p_value": trend["p_value"]}
        for metric, trend in size_trend.get("trend_data", {}).items()
    }

    gap_by_metric = {
        metric: comp["gap_percentage"] for metric, comp in proprietary_vs_open.items()
    }

    return {
        "model_strengths": strengths,
        "trade_offs": trade_offs,
        "size_trend_correlations": size_correlations,
        "proprietary_open_source_gap": gap_by_metric,
    }

if __name__ == "__main__":
    # Script entry point: both locations can be overridden through the
    # environment; every model id reported by get_all_model_ids() is analysed.
    evaluation_dir = os.environ.get("EVALUATION_DIR", "../evaluation_output")
    output_dir = os.environ.get("OUTPUT_DIR", "../iclr_rebuttal/comprehensive_analysis")

    model_ids = all_model_ids = get_all_model_ids()
    results = comprehensive_analysis(
        model_ids=model_ids,
        output_dir=output_dir,
        evaluation_dir=evaluation_dir,
    )
