#!/usr/bin/env python3
"""
Parameter Sensitivity Analysis Script for JIR Evaluation

This script addresses Reviewer B3: Evaluation Parameter Selection Rationale
- How were parameters such as similarity threshold 0.55, deduplication threshold 0.75, 
  and temporal balancing coefficient 0.9 selected?
- Was parameter sensitivity analysis conducted?
"""

import sys
import os
import json
import numpy as np
import random
from typing import List, Dict, Any, Tuple, Optional
from collections import defaultdict
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch
import threading
import time
import traceback

# Put the parent directory first on sys.path so that the project modules imported
# below (utils, evaluate_utils, evaluate, model_config, similarity_cache) can be
# found; they presumably live there. ``parent_dir`` is also reused as the root for
# the output/ and data/ paths built later in this file.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from utils import load_jsonl, METAINFO_DIR
from evaluate_utils import get_duration_by_youtube_id
from evaluate import JIREvaluator
from model_config import get_all_model_ids
from similarity_cache import get_cache_file_path, load_match_results, save_match_results

# Serializes JIREvaluator construction (model loading) across worker threads.
_model_init_lock = threading.Lock()

# Evaluator pool: device string ('cuda:<i>' or 'cpu') -> JIREvaluator.
# Pre-filled by initialize_evaluator_pool; entries are (re)created on demand by
# get_evaluator when a different similarity threshold is requested.
_evaluator_pool = {}
# Guards pool lookups/replacements (held by get_evaluator).
_evaluator_pool_lock = threading.Lock()


def initialize_evaluator_pool(num_gpus: int, similarity_threshold: float = 0.55):
    """Pre-load one JIREvaluator per GPU so worker threads can reuse them.

    Evaluators are stored in ``_evaluator_pool`` keyed by device string
    (``'cuda:<i>'`` or ``'cpu'``), which is where ``get_evaluator`` looks them up.

    Args:
        num_gpus: Number of GPUs to pre-load. It is clamped to the number of
            visible CUDA devices, so asking for more GPUs than exist does not
            try to load models onto non-existent devices.
        similarity_threshold: Threshold given to every pre-loaded evaluator.

    A CPU evaluator is pre-loaded in these cases:
    - CUDA is unavailable.
    - Any GPU fails to initialise.
    - No GPU is requested.

    When every requested GPU loads, no CPU model is loaded eagerly.
    ``get_evaluator`` still creates one lazily if a CPU task asks for it.
    """
    print(f"Initializing evaluator pool for {num_gpus} GPUs...")

    def _load(device: str) -> None:
        # Construction is serialized by _model_init_lock. The pool write takes
        # _evaluator_pool_lock afterwards (never nested), so this cannot
        # deadlock against get_evaluator's pool-lock -> model-lock ordering.
        with _model_init_lock:
            evaluator = JIREvaluator(similarity_threshold=similarity_threshold, device=device)
        with _evaluator_pool_lock:
            _evaluator_pool[device] = evaluator

    if not torch.cuda.is_available():
        _load('cpu')
        print("  Initialized CPU evaluator")
        return

    visible = torch.cuda.device_count()
    if num_gpus > visible:
        print(f"  Warning: requested {num_gpus} GPUs but only {visible} visible; using {visible}", file=sys.stderr)
        num_gpus = visible

    need_cpu_fallback = num_gpus <= 0
    for gpu_id in range(num_gpus):
        device = f'cuda:{gpu_id}'
        try:
            _load(device)
            print(f"  Initialized evaluator for {device}")
        except Exception as e:
            print(f"  Warning: Failed to initialize evaluator for {device}: {e}", file=sys.stderr)
            need_cpu_fallback = True

    if need_cpu_fallback and 'cpu' not in _evaluator_pool:
        _load('cpu')

    print(f"Evaluator pool initialized: {list(_evaluator_pool.keys())}")


def get_evaluator(device: str, similarity_threshold: float = 0.55):
    """Return the pooled evaluator for ``device`` built with ``similarity_threshold``.

    The pool keeps a single evaluator per device. If there is none yet, or the
    cached one uses a different threshold, a new ``JIREvaluator`` is built and
    replaces the pool entry for that device.
    """
    with _evaluator_pool_lock:
        cached = _evaluator_pool.get(device)
        if cached is not None and cached.similarity_threshold == similarity_threshold:
            return cached

        # Cache miss or threshold mismatch: build a replacement for this device.
        with _model_init_lock:
            fresh = JIREvaluator(similarity_threshold=similarity_threshold, device=device)
            _evaluator_pool[device] = fresh
            return fresh


def process_single_video_parameter(args, evaluation_dir: Optional[str] = None):
    """Evaluate one (model, video) pair under one parameter setting.

    Designed to be submitted to a thread pool. Every failure is logged to
    stderr (with traceback) and turned into a ``None`` result, so a single bad
    video cannot abort a whole sweep.

    Args:
        args: 9-tuple ``(model_id, youtube_id, filename, youtube_ids_set,
            threshold, fuzzy_interval, balancing_coefficient, gpu_id,
            queries_cache)``.
            - ``gpu_id``: ``None`` (or no CUDA) means the CPU is used.
            - ``queries_cache``: optional dict youtube_id -> ground-truth
              queries, used instead of re-reading the JSONL file.
        evaluation_dir: Root under which the ``.similarity_cache`` directory
            lives. If omitted, the function uses, in order:
            1. the module-level ``evaluation_dir`` set by the ``__main__`` block;
            2. ``$EVALUATION_DIR`` (default ``../evaluation_output``).
            This lets the function work when the module is imported rather
            than run as a script.

    Returns:
        A dict with ``recall``, ``precision``, ``relevance`` and ``timeliness``
        floats. Returns ``None`` when the video is not in ``youtube_ids_set``,
        an input file is missing, or any error occurs.
    """
    (model_id, youtube_id, filename, youtube_ids_set, threshold, fuzzy_interval,
     balancing_coefficient, gpu_id, queries_cache) = args

    try:
        if torch.cuda.is_available() and gpu_id is not None:
            device = f'cuda:{gpu_id}'
        else:
            device = 'cpu'

        if youtube_id not in youtube_ids_set:
            return None

        # Prefer the pre-loaded ground truth; otherwise read it from disk.
        if queries_cache is not None and youtube_id in queries_cache:
            queries_to_match = queries_cache[youtube_id]
        else:
            ground_truth_file = os.path.join(parent_dir, f"output/{youtube_id}/jir_references_relevance_score.jsonl")
            if not os.path.exists(ground_truth_file):
                return None
            queries_to_match = load_jsonl(ground_truth_file)

        # NOTE(review): candidates are read from <parent_dir>/data/baseline_stream_runs_<model>,
        # while analyze_parameter_sensitivity lists filenames from
        # <evaluation_dir>/rebuttal_baseline_stream_runs_<model>. Confirm the two are meant to differ.
        candidate_file = os.path.join(parent_dir, f"data/baseline_stream_runs_{model_id}", filename)
        if not os.path.exists(candidate_file):
            return None
        # Context manager so the handle is closed promptly for every task.
        with open(candidate_file, "r", encoding="utf-8") as f:
            candidate_data = json.load(f)
        candidate_queries = candidate_data.get("needs", [])

        if evaluation_dir is None:
            evaluation_dir = globals().get("evaluation_dir") or os.getenv("EVALUATION_DIR", "../evaluation_output")
        # The cache key includes threshold and fuzzy interval. It does not need
        # the balancing coefficient, which only affects the timeliness metric
        # computed after matching.
        cache_dir = os.path.join(evaluation_dir, ".similarity_cache")
        cache_file = get_cache_file_path(cache_dir, model_id, youtube_id, threshold, fuzzy_interval)
        cached = load_match_results(cache_file)

        # The metric computations below need the evaluator whether or not the
        # matches come from the cache, so fetch it once up front.
        evaluator = get_evaluator(device, threshold)
        if getattr(evaluator, 'model', None) is None:
            raise RuntimeError(f"Failed to initialize JIREvaluator model for {model_id}/{youtube_id}")

        # Cache entries are validated only by list lengths. A same-length change
        # to either query list would reuse stale matches.
        cache_is_usable = (
            cached is not None
            and len(cached.get("queries_to_match", [])) == len(queries_to_match)
            and len(cached.get("candidate_queries", [])) == len(candidate_queries)
        )
        if cache_is_usable:
            results = cached["match_results"]
        else:
            results = evaluator.batch_find_matches(queries_to_match, candidate_queries, fuzzy_interval)
            save_match_results(cache_file, queries_to_match, candidate_queries, results)

        recall = evaluator.compute_recall(queries_to_match, results)
        precision = evaluator.compute_precision(candidate_queries, results)
        relevance = evaluator.evaluate_relevance(results, queries_to_match)
        # The keyword spelling ``balancing_coeffient`` is JIREvaluator's own API.
        timeliness = evaluator.evaluate_timeliness(
            queries_to_match, results,
            fuzzy_sentence_interval=fuzzy_interval,
            balancing_coeffient=balancing_coefficient
        )

        # "weighted" (not "matched_weighted") keeps unmatched queries in the
        # average, where they contribute 0.
        return {
            "recall": recall.get("recall", 0.0),
            "precision": precision.get("precision", 0.0),
            "relevance": relevance.get("weighted_ndcg", 0.0),
            "timeliness": timeliness.get("weighted_time_match", 0.0)
        }
    except Exception as e:
        print(f"Error processing {youtube_id} for {model_id} (threshold={threshold}, fuzzy={fuzzy_interval}, balance={balancing_coefficient}): {e}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        return None


def load_sample_videos(num_samples: int = 5, seed: int = 42) -> List[str]:
    """
    Load and randomly sample videos from scene_metainfo.

    Reads ``lecture.jsonl`` then ``paper.jsonl`` from ``$SCENE_METAINFO_DIR``
    (default ``<parent_dir>/data/scene_metainfo``). Missing files are skipped.

    Args:
        num_samples: Number of videos to sample.
        seed: Random seed for reproducibility. A private ``random.Random`` is
            used. It yields the same sample as ``random.seed(seed);
            random.sample(...)`` but does not reseed the process-wide RNG.

    Returns:
        List of unique youtube_ids. If fewer than ``num_samples`` are
        available, all of them are returned in file order.
    """
    scene_metainfo_dir = os.getenv("SCENE_METAINFO_DIR", os.path.join(parent_dir, "data", "scene_metainfo"))

    all_videos = []
    for name in ("lecture.jsonl", "paper.jsonl"):
        path = os.path.join(scene_metainfo_dir, name)
        if os.path.exists(path):
            all_videos.extend(load_jsonl(path))

    # Order-preserving de-duplication: a video listed in both files must not
    # occupy two sample slots.
    youtube_ids = list(dict.fromkeys(item["youtube_id"] for item in all_videos if "youtube_id" in item))

    if len(youtube_ids) < num_samples:
        return youtube_ids

    sampled = random.Random(seed).sample(youtube_ids, num_samples)
    print(f"Randomly sampled {len(sampled)} videos (seed={seed}): {sampled}")
    return sampled


def get_selected_models() -> List[str]:
    """
    Pick the models analysed in the sensitivity study (per model_sizes.txt):
    - DeepSeek-V3 (1)
    - Phi-4 (4)
    - Qwen3 (5)
    - plus each one's ``<model>_multimodal`` variant, where registered.

    A registered id counts as a base model when it contains one of the family
    names above. Ids containing ``_multimodal`` or, case-insensitively,
    ``oracle`` are excluded. The result is de-duplicated and sorted.
    """
    registered = get_all_model_ids()
    families = ("DeepSeek-V3", "Phi-4", "Qwen3")

    def is_base_model(candidate: str) -> bool:
        if "_multimodal" in candidate or "oracle" in candidate.lower():
            return False
        return any(family in candidate for family in families)

    base_models = [m for m in registered if is_base_model(m)]
    multimodal = [f"{m}_multimodal" for m in base_models if f"{m}_multimodal" in registered]
    selected_models = sorted({*base_models, *multimodal})

    print(f"Selected {len(selected_models)} models: {selected_models}")
    return selected_models


# Parameter values currently used by the evaluation. Each sweep varies one of
# them while holding the other two at these values.
_CURRENT_SIMILARITY_THRESHOLD = 0.55
_CURRENT_FUZZY_INTERVAL = 1
_CURRENT_BALANCING_COEFFICIENT = 0.9

# Default grids swept by analyze_parameter_sensitivity.
_DEFAULT_SIMILARITY_THRESHOLDS = (0.45, 0.50, 0.55, 0.60, 0.65)
_DEFAULT_FUZZY_INTERVALS = (0, 1, 2, 3)
_DEFAULT_BALANCING_COEFFICIENTS = (0.7, 0.8, 0.9, 1.0)

# Metrics returned by process_single_video_parameter, in reporting order.
_METRICS = ("recall", "precision", "relevance", "timeliness")


def _empty_metric_lists() -> Dict[str, List[float]]:
    """Return a fresh ``{metric: []}`` accumulator."""
    return {metric: [] for metric in _METRICS}


def _list_sampled_model_files(evaluation_dir, model_ids, sample_videos_set):
    """List ``(model_id, youtube_id, filename)`` for every sampled result file.

    Files are read from ``<evaluation_dir>/rebuttal_baseline_stream_runs_<model_id>``
    in ``os.listdir`` order. Models whose directory is missing are skipped.
    """
    entries = []
    for model_id in model_ids:
        model_dir = os.path.join(evaluation_dir, f"rebuttal_baseline_stream_runs_{model_id}")
        if not os.path.exists(model_dir):
            continue
        for filename in os.listdir(model_dir):
            if not filename.endswith(".json"):
                continue
            youtube_id = filename[:-len(".json")]
            if youtube_id in sample_videos_set:
                entries.append((model_id, youtube_id, filename))
    return entries


def _run_parameter_sweep(tasks, param_index, num_workers, desc):
    """Run tasks on a thread pool and group metric values by the swept parameter.

    Args:
        tasks: Argument tuples for ``process_single_video_parameter``.
        param_index: Position of the swept parameter inside each task tuple.
        num_workers: Thread-pool size.
        desc: Progress-bar label.

    Returns:
        ``(pooled, per_model)``:
        - ``pooled[value][metric]``: list of per-video scores.
        - ``per_model[value][model_id][metric]``: the same, split by model.
        Failed tasks (a ``None`` result or a raised exception) contribute nothing.
    """
    pooled = defaultdict(_empty_metric_lists)
    per_model = defaultdict(lambda: defaultdict(_empty_metric_lists))
    error_count = 0

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(process_single_video_parameter, task): task for task in tasks}
        for future in tqdm(as_completed(futures), total=len(futures), desc=desc, mininterval=2.0):
            task = futures[future]
            try:
                # as_completed only yields finished futures, so no timeout is needed.
                result = future.result()
            except Exception as e:
                error_count += 1
                if error_count <= 5:
                    print(f"\n  Error in task: {e}", file=sys.stderr, flush=True)
                continue
            if not result:
                continue
            param_value = task[param_index]
            model_id = task[0]
            for metric in _METRICS:
                pooled[param_value][metric].append(result[metric])
                per_model[param_value][model_id][metric].append(result[metric])

    return pooled, per_model


def _calc_stats(sensitivity_dict):
    """Compute the mean and population std of each metric per parameter value.

    Parameter values are emitted in ascending order. This makes the JSON
    output, and tie-breaking when picking an optimum, independent of thread
    completion order.
    """
    stats = {}
    for param_value in sorted(sensitivity_dict):
        data = sensitivity_dict[param_value]
        if not data["recall"]:
            continue
        stats[param_value] = {
            metric: {
                "mean": float(np.mean(data[metric])),
                "std": float(np.std(data[metric]))
            }
            for metric in _METRICS
        }
    return stats


def _normalize_param_key(pv):
    """Format a parameter value as a stable string key.

    Integral values become e.g. "1"; other numbers use up to 6 significant
    digits ("0.55", "0.9"). Six digits, unlike 2 decimals, keeps distinct
    values such as 0.555 and 0.56 from collapsing onto one key.
    """
    try:
        pv_float = float(pv)
        if abs(pv_float - int(pv_float)) < 1e-6:
            return str(int(pv_float))
        return f"{pv_float:.6g}"
    except (ValueError, TypeError):
        return str(pv)


def _analyze_ranking_stability(per_model_data):
    """Measure how consistently models are ranked across values of one parameter.

    For each metric, models are ranked (1 = best) by their mean score at each
    parameter value. Ties are broken by model id, so ranks do not depend on
    thread completion order. Spearman correlations are then computed for every
    pair of parameter values.

    Args:
        per_model_data: Mapping param_value -> model_id -> metric -> list of scores.

    Returns:
        Mapping metric -> dict with these keys:
        ``average_spearman_correlation``, ``pairwise_correlations``,
        ``max_rank_changes``, ``average_max_rank_change``,
        ``rankings_by_parameter``.
        Metrics with fewer than two parameter values are omitted.
    """
    from scipy.stats import spearmanr

    stability_results = {}
    for metric in _METRICS:
        # param_value -> {model_id: mean score}, only for models with data.
        param_model_scores = {
            param_value: {
                model_id: np.mean(scores[metric])
                for model_id, scores in model_data.items()
                if scores[metric]
            }
            for param_value, model_data in per_model_data.items()
        }

        param_values = sorted(param_model_scores)
        if len(param_values) < 2:
            continue

        # param_value -> {model_id: rank}. Pre-sorting by id lets the stable
        # descending sort break score ties alphabetically.
        rankings = {}
        for param_value in param_values:
            ordered = sorted(sorted(param_model_scores[param_value].items()),
                             key=lambda kv: kv[1], reverse=True)
            rankings[param_value] = {model_id: rank for rank, (model_id, _) in enumerate(ordered, start=1)}

        all_models = sorted({m for pv in param_values for m in rankings[pv]})
        # A model missing at some parameter value gets the mean rank of
        # 1..n, i.e. (n + 1) / 2.
        missing_rank = (len(all_models) + 1) / 2
        ranking_matrix = [
            [rankings[pv].get(model_id, missing_rank) for pv in param_values]
            for model_id in all_models
        ]

        correlations = []
        for i in range(len(param_values)):
            ranks_i = [row[i] for row in ranking_matrix]
            for j in range(i + 1, len(param_values)):
                ranks_j = [row[j] for row in ranking_matrix]
                # Spearman is undefined for a constant ranking.
                if len(set(ranks_i)) > 1 and len(set(ranks_j)) > 1:
                    corr, p_value = spearmanr(ranks_i, ranks_j)
                    correlations.append({
                        "param1": float(param_values[i]),
                        "param2": float(param_values[j]),
                        "correlation": float(corr),
                        "p_value": float(p_value)
                    })

        avg_correlation = np.mean([c["correlation"] for c in correlations]) if correlations else 0.0

        # Spread (max - min) of each model's rank over the values where it is present.
        max_rank_changes = {}
        for model_id in all_models:
            present = [rankings[pv][model_id] for pv in param_values if model_id in rankings[pv]]
            if present:
                max_rank_changes[model_id] = max(present) - min(present)
        avg_max_rank_change = np.mean(list(max_rank_changes.values())) if max_rank_changes else 0.0

        stability_results[metric] = {
            "average_spearman_correlation": float(avg_correlation),
            "pairwise_correlations": correlations,
            "max_rank_changes": {k: int(v) for k, v in max_rank_changes.items()},
            "average_max_rank_change": float(avg_max_rank_change),
            "rankings_by_parameter": {
                _normalize_param_key(pv): {m_id: int(rank) for m_id, rank in rankings[pv].items()}
                for pv in param_values
            }
        }

    return stability_results


def _print_ranking_stability_summary(ranking_stability):
    """Print the average Spearman correlation per parameter/metric and an overall verdict.

    A metric is included when it produced at least one pairwise correlation.
    Negative and zero averages are kept so they pull the overall figure down
    rather than being silently dropped.
    """
    print(f"\nRanking Stability Summary:")
    all_avg_corrs = []
    for param_name, param_stability in ranking_stability.items():
        if not param_stability:
            continue
        for metric, metric_stability in param_stability.items():
            if not metric_stability.get("pairwise_correlations"):
                continue
            avg_corr = metric_stability.get("average_spearman_correlation", 0.0)
            all_avg_corrs.append(avg_corr)
            print(f"  {param_name.replace('_', ' ').title()} - {metric}: {avg_corr:.3f}")

    if not all_avg_corrs:
        return
    overall_avg_corr = np.mean(all_avg_corrs)
    print(f"\n  Overall Average Spearman Correlation: {overall_avg_corr:.3f}")
    if overall_avg_corr > 0.9:
        print(f"  ✅ Excellent stability: Rankings are highly consistent across parameter values")
    elif overall_avg_corr > 0.7:
        print(f"  ✅ Good stability: Rankings are generally consistent across parameter values")
    elif overall_avg_corr > 0.5:
        print(f"  ⚠️  Moderate stability: Rankings show some variation across parameter values")
    else:
        print(f"  ❌ Low stability: Rankings vary significantly across parameter values")


def analyze_parameter_sensitivity(
    evaluation_dir: str,
    output_dir: str,
    model_ids: List[str],
    sample_videos: List[str],
    similarity_thresholds: Optional[List[float]] = None,
    fuzzy_intervals: Optional[List[int]] = None,
    balancing_coefficients: Optional[List[float]] = None,
    num_workers: int = 20,
    use_gpu: bool = True,
    num_gpus: int = 4
) -> Dict[str, Any]:
    """
    Analyze parameter sensitivity.

    Three one-at-a-time sweeps are run, each holding the other two parameters
    at their current values (0.55 / 1 / 0.9):
    - similarity threshold
    - fuzzy interval
    - balancing coefficient

    Results are written to ``<output_dir>/parameter_sensitivity.json``.

    Args:
        evaluation_dir: Directory containing evaluation results
        output_dir: Output directory for analysis results
        model_ids: List of model IDs to analyze
        sample_videos: List of youtube_ids to use as samples
        similarity_thresholds: Similarity thresholds to test
            (None -> 0.45, 0.50, 0.55, 0.60, 0.65)
        fuzzy_intervals: Fuzzy sentence intervals to test (None -> 0, 1, 2, 3)
        balancing_coefficients: Temporal balancing coefficients to test
            (None -> 0.7, 0.8, 0.9, 1.0)
        num_workers: Thread-pool size for each sweep
        use_gpu: Spread tasks over GPUs when CUDA is available
        num_gpus: Maximum number of GPUs to use. It is clamped to the number
            of visible devices; 0 means CPU.

    Returns:
        Dictionary with parameter sensitivity analysis results
    """
    similarity_thresholds = list(_DEFAULT_SIMILARITY_THRESHOLDS if similarity_thresholds is None else similarity_thresholds)
    fuzzy_intervals = list(_DEFAULT_FUZZY_INTERVALS if fuzzy_intervals is None else fuzzy_intervals)
    balancing_coefficients = list(_DEFAULT_BALANCING_COEFFICIENTS if balancing_coefficients is None else balancing_coefficients)

    os.makedirs(output_dir, exist_ok=True)
    sample_videos_set = set(sample_videos)

    # Pre-load ground truth once so worker threads do not re-read it for every task.
    print("Pre-loading queries_to_match to memory...")
    queries_cache = {}
    for youtube_id in sample_videos:
        ground_truth_file = os.path.join(parent_dir, f"output/{youtube_id}/jir_references_relevance_score.jsonl")
        if os.path.exists(ground_truth_file):
            try:
                queries_cache[youtube_id] = load_jsonl(ground_truth_file)
            except Exception as e:
                print(f"  Warning: Failed to load queries for {youtube_id}: {e}")
    print(f"  Pre-loaded queries for {len(queries_cache)} videos")

    # Only schedule work on GPUs that are actually visible; this also avoids
    # `i % 0` when no GPU is usable.
    gpu_count = 0
    if use_gpu and torch.cuda.is_available():
        gpu_count = max(0, min(num_gpus, torch.cuda.device_count()))
        if gpu_count < num_gpus:
            print(f"  Warning: requested {num_gpus} GPUs but only {gpu_count} usable", file=sys.stderr)
    if gpu_count > 0:
        initialize_evaluator_pool(gpu_count, similarity_threshold=_CURRENT_SIMILARITY_THRESHOLD)

    # NOTE(review): evaluation_dir is used here only to enumerate result files.
    # process_single_video_parameter resolves its similarity-cache directory
    # independently of this argument. Confirm that is intended.
    model_files = _list_sampled_model_files(evaluation_dir, model_ids, sample_videos_set)

    # Task tuple: (model_id, youtube_id, filename, sample_videos_set,
    #              threshold, fuzzy_interval, balancing_coefficient, gpu_id, queries_cache)
    similarity_tasks = [
        (model_id, youtube_id, filename, sample_videos_set, threshold,
         _CURRENT_FUZZY_INTERVAL, _CURRENT_BALANCING_COEFFICIENT, None, queries_cache)
        for threshold in similarity_thresholds
        for model_id, youtube_id, filename in model_files
    ]
    fuzzy_tasks = [
        (model_id, youtube_id, filename, sample_videos_set, _CURRENT_SIMILARITY_THRESHOLD,
         fuzzy_interval, _CURRENT_BALANCING_COEFFICIENT, None, queries_cache)
        for fuzzy_interval in fuzzy_intervals
        for model_id, youtube_id, filename in model_files
    ]
    balancing_tasks = [
        (model_id, youtube_id, filename, sample_videos_set, _CURRENT_SIMILARITY_THRESHOLD,
         _CURRENT_FUZZY_INTERVAL, balancing_coefficient, None, queries_cache)
        for balancing_coefficient in balancing_coefficients
        for model_id, youtube_id, filename in model_files
    ]

    if gpu_count > 0:
        # Round-robin GPUs over the concatenated task list (slot 7 is gpu_id).
        offset = 0
        assigned = []
        for tasks in (similarity_tasks, fuzzy_tasks, balancing_tasks):
            assigned.append([task[:7] + ((offset + i) % gpu_count,) + task[8:] for i, task in enumerate(tasks)])
            offset += len(tasks)
        similarity_tasks, fuzzy_tasks, balancing_tasks = assigned

    print(f"\n{'='*60}")
    print(f"Analyzing similarity threshold sensitivity ({len(similarity_tasks)} tasks)...")
    print(f"Using {gpu_count} GPUs for parallel processing")
    print(f"{'='*60}")
    similarity_sensitivity, similarity_per_model = _run_parameter_sweep(
        similarity_tasks, 4, num_workers, "Processing similarity tasks")

    print(f"\n{'='*60}")
    print(f"Analyzing fuzzy interval sensitivity ({len(fuzzy_tasks)} tasks)...")
    print(f"{'='*60}")
    fuzzy_sensitivity, fuzzy_per_model = _run_parameter_sweep(
        fuzzy_tasks, 5, num_workers, "Processing fuzzy tasks")

    print(f"\n{'='*60}")
    print(f"Analyzing balancing coefficient sensitivity ({len(balancing_tasks)} tasks)...")
    print(f"{'='*60}")
    balancing_sensitivity, balancing_per_model = _run_parameter_sweep(
        balancing_tasks, 6, num_workers, "Processing balancing tasks")

    similarity_stats = _calc_stats(similarity_sensitivity)
    fuzzy_stats = _calc_stats(fuzzy_sensitivity)
    balancing_stats = _calc_stats(balancing_sensitivity)

    def best(stats, score, default):
        # Keys are in ascending order, so ties resolve to the smallest value.
        return max(stats, key=lambda pv: score(stats[pv])) if stats else default

    optimal_similarity = best(similarity_stats, lambda s: s["recall"]["mean"] + s["precision"]["mean"],
                              _CURRENT_SIMILARITY_THRESHOLD)
    optimal_fuzzy = best(fuzzy_stats, lambda s: s["recall"]["mean"] + s["precision"]["mean"],
                         _CURRENT_FUZZY_INTERVAL)
    optimal_balancing = best(balancing_stats, lambda s: s["timeliness"]["mean"],
                             _CURRENT_BALANCING_COEFFICIENT)

    print("\nAnalyzing ranking stability...")
    results = {
        "similarity_threshold_sensitivity": similarity_stats,
        "fuzzy_interval_sensitivity": fuzzy_stats,
        "balancing_coefficient_sensitivity": balancing_stats,
        "ranking_stability": {
            "similarity_threshold": _analyze_ranking_stability(similarity_per_model),
            "fuzzy_interval": _analyze_ranking_stability(fuzzy_per_model),
            "balancing_coefficient": _analyze_ranking_stability(balancing_per_model)
        },
        "optimal_parameters": {
            "similarity_threshold": float(optimal_similarity),
            "fuzzy_interval": int(optimal_fuzzy),
            "balancing_coefficient": float(optimal_balancing)
        },
        "current_parameters": {
            "similarity_threshold": _CURRENT_SIMILARITY_THRESHOLD,
            "fuzzy_interval": _CURRENT_FUZZY_INTERVAL,
            "balancing_coefficient": _CURRENT_BALANCING_COEFFICIENT
        },
        "parameter_rationale": {
            "similarity_threshold": "Selected based on balance between recall and precision",
            "fuzzy_interval": "Selected to allow flexible temporal matching",
            "balancing_coefficient": "Selected to balance start and end time importance"
        },
        "sample_videos": sample_videos,
        "models_analyzed": model_ids
    }

    output_file = os.path.join(output_dir, "parameter_sensitivity.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\n{'='*60}")
    print(f"Parameter sensitivity analysis complete!")
    print(f"Results saved to: {output_file}")
    print(f"\nCurrent parameters:")
    print(f"  Similarity threshold: {results['current_parameters']['similarity_threshold']}")
    print(f"  Fuzzy interval: {results['current_parameters']['fuzzy_interval']}")
    print(f"  Balancing coefficient: {results['current_parameters']['balancing_coefficient']}")

    _print_ranking_stability_summary(results["ranking_stability"])

    return results


if __name__ == "__main__":
    # Configuration
    evaluation_dir = os.getenv("EVALUATION_DIR", "../evaluation_output")
    output_dir = os.getenv("OUTPUT_DIR", "../iclr_rebuttal/parameter_sensitivity")
    
    # Get selected models
    model_ids = get_selected_models()
    
    # Load sample videos (randomly select 5)
    sample_videos = load_sample_videos(num_samples=34, seed=42)
    
    print(f"Analyzing {len(model_ids)} models for parameter sensitivity: {model_ids}")
    print(f"Using {len(sample_videos)} sample videos: {sample_videos}")
    
    print("\nStarting parameter sensitivity analysis...")
    print(f"Using GPU: {torch.cuda.is_available()}, GPUs available: {torch.cuda.device_count() if torch.cuda.is_available() else 0}")
    
    results = analyze_parameter_sensitivity(
        evaluation_dir=evaluation_dir,
        output_dir=output_dir,
        model_ids=model_ids,
        sample_videos=sample_videos,
        similarity_thresholds=[0.45, 0.50, 0.55, 0.60, 0.65],
        fuzzy_intervals=[0, 1, 2, 3],
        balancing_coefficients=[0.7, 0.8, 0.9, 1.0],
        num_workers=20,
        use_gpu=True,
        num_gpus=4
    )
