# stability_prediction.py
"""
Calculs des métriques sur les prédictions :
- Ambiguity
- Discrepancy
- Moyenne et écart-type des stds
- Jaccard, RBO, KL pour les prédictions
- Extraction et agrégation du MRR à partir des fichiers metrics.json des runs
"""
import json
import os
import warnings
from typing import Dict, List, Sequence

import numpy as np
from scipy.special import softmax
from scipy.stats import entropy

def compute_ambiguity(runs: List[Dict], k_values: Sequence[int] = (1, 5, 10)) -> Dict[str, float]:
    """
    Compute ambiguity@K of a group of runs for several K values.

    A query is *ambiguous* at K when the runs disagree on whether its true entity
    is ranked within the top K (``rank <= K``; presumably 1-based ranks): some
    runs place it there and others do not. ambiguity@K is the fraction of
    ambiguous queries.

    Args:
        runs: Run dictionaries; each must provide ``run['preds']['truth_ranks']``,
            a sequence with one rank per query (same length for every run).
        k_values: Cut-offs K to evaluate.

    Returns:
        ``{'ambiguity@K': value}`` for each K. The value is the string ``"NaN"``
        when fewer than two runs are given or the runs contain no queries.
    """
    ranks = [np.asarray(run['preds']['truth_ranks']) for run in runs]
    if len(ranks) < 2:
        return {f'ambiguity@{K}': "NaN" for K in k_values}
    ranks = np.stack(ranks)  # (n_runs, n_queries)
    n_queries = ranks.shape[1]
    if n_queries == 0:
        # Nothing to measure; avoids dividing by zero below.
        return {f'ambiguity@{K}': "NaN" for K in k_values}

    results = {}
    for K in k_values:
        # A query is stable when every run ranks the truth within the top K,
        # or every run ranks it beyond K; otherwise it is ambiguous.
        always_in = np.all(ranks <= K, axis=0)
        always_out = np.all(ranks > K, axis=0)
        ambiguous = ~(always_in | always_out)
        results[f'ambiguity@{K}'] = float(np.count_nonzero(ambiguous) / n_queries)

    return results

def compute_discrepancy(runs: List[Dict], k_values: Sequence[int] = (1, 5, 10)) -> Dict[str, float]:
    """
    Compute discrepancy@K of a group of runs for several K values.

    For every pair of runs, the fraction of queries on which exactly one of the
    two runs ranks the true entity within the top K (``rank <= K``) is computed.
    discrepancy@K is the maximum of that fraction over all pairs of runs.

    Args:
        runs: Run dictionaries; each must provide ``run['preds']['truth_ranks']``,
            a sequence with one rank per query (same length for every run).
        k_values: Cut-offs K to evaluate.

    Returns:
        ``{'discrepancy@K': value}`` for each K. The value is the string ``"NaN"``
        when fewer than two runs are given or the runs contain no queries.
    """
    ranks = [np.asarray(run['preds']['truth_ranks']) for run in runs]
    if len(ranks) < 2:
        return {f'discrepancy@{K}': "NaN" for K in k_values}
    ranks = np.stack(ranks)  # (n_runs, n_queries)
    n_runs, n_queries = ranks.shape
    if n_queries == 0:
        # Nothing to measure; avoids dividing by zero below.
        return {f'discrepancy@{K}': "NaN" for K in k_values}

    # All unordered pairs (i, j) with i < j.
    pairs = [(i, j) for i in range(n_runs) for j in range(i + 1, n_runs)]

    results = {}
    for K in k_values:
        # hits[r, q] is True when run r ranks the truth of query q within the top K.
        # Computed once per K instead of once per pair.
        hits = ranks <= K
        max_disc = 0.0
        for i, j in pairs:
            disagreement = np.count_nonzero(hits[i] != hits[j]) / n_queries
            max_disc = max(max_disc, disagreement)
        results[f'discrepancy@{K}'] = float(max_disc)

    return results

def compute_rank_stats(runs: List[Dict]) -> Dict[str, float]:
    """
    Summarise how much the rank of the true entity varies across runs.

    For every query, the standard deviation (ddof=0) of its truth rank across
    the runs is computed. The function returns the mean and the standard
    deviation (ddof=0) of those per-query values.

    Truth ranks are read from ``run['truth_ranks']``. When the first run has no
    such key, the nested ``run['preds']['truth_ranks']`` layout (the one read by
    ``compute_ambiguity`` and ``compute_discrepancy``) is used instead.

    Args:
        runs: Run dictionaries holding one truth rank per query (same length
            for every run).

    Returns:
        ``{'mean_of_rank_stds': ..., 'std_of_rank_stds': ...}``, or an empty dict
        when ``runs`` is empty or the first run has no (or empty) truth ranks.
    """
    if not runs:
        return {}

    def _truth_ranks(run):
        # Flat layout first (original behaviour), then the nested 'preds' layout.
        if 'truth_ranks' in run:
            return run['truth_ranks']
        preds = run.get('preds')
        return preds.get('truth_ranks') if isinstance(preds, dict) else None

    first = _truth_ranks(runs[0])
    # Use len() rather than truthiness: `if array` raises for numpy arrays with >1 element.
    if first is None or len(first) == 0:
        return {}

    ranks_array = np.stack([np.asarray(_truth_ranks(run)) for run in runs])  # (n_runs, n_queries)
    per_triple_rank_std = np.std(ranks_array, axis=0, ddof=0)
    return {
        'mean_of_rank_stds': float(np.mean(per_triple_rank_std)),
        'std_of_rank_stds': float(np.std(per_triple_rank_std, ddof=0)),
    }

def compute_pred_jaccard(preds1, preds2, k=10):
    """
    Mean per-query Jaccard similarity between two models' top-k entity sets.

    For each query, the first ``k`` predicted entities of each model are turned
    into sets and compared with |A & B| / |A | B|. Two empty sets count as
    identical (similarity 1.0).

    Args:
        preds1: Dictionary containing 'top_k_entities' (one ranked list per
            query) from the first model.
        preds2: Same structure for the second model.
        k: Number of top predictions to consider (must be <= the k used during
            prediction).

    Returns:
        The mean similarity over queries (numpy float).
    """
    first = np.array([row[:k] for row in preds1['top_k_entities']])
    second = np.array([row[:k] for row in preds2['top_k_entities']])

    def _jaccard(a, b):
        union = a | b
        if not union:
            return 1.0
        return len(a & b) / len(union)

    similarities = [_jaccard(set(first[q]), set(second[q])) for q in range(len(first))]
    return np.mean(similarities)



def compute_pred_rbo(preds1, preds2, k=10):
    """
    Compare two models' top-k rankings with a uniformly weighted overlap score.

    NOTE: despite its name, this is not the geometrically weighted Rank-Biased
    Overlap of Webber et al., since there is no persistence parameter ``p``.
    For each query it computes the *average overlap* of the ranked lists A and B:

        AO@k = (1 / k) * sum_{d=1..k} |set(A[:d]) & set(B[:d])| / d

    and returns the mean over queries. The value lies in [0, 1]. It is 1 when
    both lists hold the same k distinct entities in the same order.

    Args:
        preds1: Dictionary containing 'top_k_entities' (one ranked list per
            query) from the first model.
        preds2: Same structure for the second model; must have at least as
            many queries as ``preds1``.
        k: Number of top predictions to consider (should be <= the k used at
            prediction time; a shorter list simply stops growing).

    Returns:
        The mean average-overlap over queries, or 0.0 when there are no queries.
    """
    lists1 = preds1['top_k_entities']
    lists2 = preds2['top_k_entities']

    scores = []
    for q in range(len(lists1)):
        # Plain Python lists rather than one stacked 2-D array, so that queries
        # may carry top-k lists of different lengths.
        ranked1 = list(lists1[q][:k])
        ranked2 = list(lists2[q][:k])
        total = 0.0
        for d in range(1, k + 1):
            # Overlap of the depth-d prefixes, normalised by the depth.
            total += len(set(ranked1[:d]) & set(ranked2[:d])) / d
        scores.append(total / k)

    return np.mean(scores) if scores else 0.0

def compute_pred_kl(preds1, preds2, temperature=1.0):
    """
    Mean per-query KL divergence KL(P1 || P2) between two models' score distributions.

    For each query, both models are compared on the union of their top-k
    entities. An entity missing from a model's top-k gets score -inf, which
    becomes probability 0 after the softmax. Scores are turned into
    probabilities with a softmax at the given temperature. The divergence is
    ``sum(p * (log(p + eps) - log(q + eps)))`` with ``eps = 1e-10``, so that
    zero probabilities stay finite.

    Args:
        preds1: Dictionary with 'top_k_entities' and 'top_k_scores' (one list per
            query) for the first model; it defines P1.
        preds2: Same structure for the second model; it defines P2.
        temperature: Softmax temperature (shifted scores are divided by it).

    Returns:
        The mean KL divergence over queries, or 0.0 when there are no queries.
    """
    eps = 1e-10

    def tempered_softmax(logits):
        # Shift by the max for numerical stability before exponentiating.
        weights = np.exp((logits - np.max(logits)) / temperature)
        return weights / weights.sum()

    def dense_scores(preds, query, position):
        # Entities absent from this model's top-k keep -inf (probability 0).
        vector = np.full(len(position), -np.inf)
        entities = preds['top_k_entities'][query]
        scores = preds['top_k_scores'][query]
        for entity, score in zip(entities, scores):
            if entity in position:
                vector[position[entity]] = score
        return vector

    n_queries = len(preds1['top_k_entities'])
    # Union of both models' top-k entities, one set per query.
    vocabularies = [
        set(preds1['top_k_entities'][q]) | set(preds2['top_k_entities'][q])
        for q in range(n_queries)
    ]

    divergences = []
    for q, vocabulary in enumerate(vocabularies):
        # Sorted order gives both models the same, deterministic entity layout.
        position = {entity: idx for idx, entity in enumerate(sorted(vocabulary))}
        p = tempered_softmax(dense_scores(preds1, q, position))
        r = tempered_softmax(dense_scores(preds2, q, position))
        divergences.append(np.sum(p * (np.log(p + eps) - np.log(r + eps))))

    return np.mean(divergences) if divergences else 0.0

def compute_all_prediction_metrics(preds_list, k_values=(1, 5, 10)):
    """
    Compute pairwise prediction-agreement metrics over a group of runs.

    For every unordered pair of runs, the Jaccard similarity and the
    average-overlap ("RBO") score of their top-k predictions are computed for
    each k in ``k_values``, together with the KL divergence of their score
    distributions. For each metric, the mean and population standard deviation
    over pairs are reported. The result dict is also printed to stdout.

    Args:
        preds_list: Prediction dictionaries, one per run (with the keys needed by
            ``compute_pred_jaccard``, ``compute_pred_rbo`` and ``compute_pred_kl``).
        k_values: Cut-offs k to evaluate.

    Returns:
        Dict with ``pred_jaccard@k``, ``pred_rbo@k``, ``pred_kl@k`` and their
        ``_std`` counterparts for each k, or an empty dict when fewer than two
        runs are given. ``pred_kl@k`` does not depend on k: the same value is
        repeated for every k.
    """
    results = {}
    n_models = len(preds_list)

    if n_models < 2:
        return {}

    # All unordered pairs of runs (i < j).
    pairs = [(i, j) for i in range(n_models) for j in range(i + 1, n_models)]

    # The KL divergence has no k parameter, so compute it once per pair instead
    # of once per (pair, k) as a loop-invariant.
    kls = [compute_pred_kl(preds_list[i], preds_list[j]) for i, j in pairs]
    kl_mean = float(np.mean(kls))
    kl_std = float(np.std(kls))

    for k in k_values:
        jaccards = [compute_pred_jaccard(preds_list[i], preds_list[j], k) for i, j in pairs]
        rbos = [compute_pred_rbo(preds_list[i], preds_list[j], k) for i, j in pairs]

        results[f'pred_jaccard@{k}'] = float(np.mean(jaccards))
        results[f'pred_jaccard@{k}_std'] = float(np.std(jaccards))
        results[f'pred_rbo@{k}'] = float(np.mean(rbos))
        results[f'pred_rbo@{k}_std'] = float(np.std(rbos))
        results[f'pred_kl@{k}'] = kl_mean
        results[f'pred_kl@{k}_std'] = kl_std
    print("results:", results)
    return results

def compute_mrr_stats_from_runs(runs: List[Dict]) -> Dict[str, float]:
    """
    Aggregate the pessimistic ranking metrics stored in each run's metrics.json.

    For every run with a ``run_dir`` containing ``metrics.json``, the values
    found under ``metrics["pessimistic"]`` for the keys ``"MRR"``, ``"Hit@1"``,
    ``"Hit@10"`` and ``"MR"`` are collected. Their mean and population standard
    deviation (ddof=0) are then computed over the group.

    Runs without ``run_dir`` or without ``metrics.json`` are skipped silently.
    A metrics.json that cannot be read or parsed, or a metric value that cannot
    be converted to float, is skipped with a ``UserWarning``.

    Args:
        runs: List of dictionaries containing at least the key 'run_dir'.

    Returns:
        Dict with 'mrr_mean', 'mrr_std', 'mrr_count' (number of MRR values found),
        'hit@1_mean', 'hit@1_std', 'hit@10_mean', 'hit@10_std', 'MR_mean' and
        'MR_std'. A statistic for which no run provided a value is the string "NaN".
    """
    # metrics.json key -> prefix of the corresponding output keys.
    output_prefix = {'MRR': 'mrr', 'Hit@1': 'hit@1', 'Hit@10': 'hit@10', 'MR': 'MR'}
    collected = {key: [] for key in output_prefix}

    for run in runs:
        run_dir = run.get('run_dir')
        if not run_dir:
            continue
        metrics_path = os.path.join(run_dir, 'metrics.json')
        if not os.path.exists(metrics_path):
            # No metrics.json for this run.
            continue
        try:
            with open(metrics_path, 'r', encoding='utf-8') as f:
                metrics = json.load(f)
        except (OSError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            warnings.warn(f"Skipping unreadable metrics file {metrics_path}: {err}")
            continue

        pessimistic = metrics.get('pessimistic') if isinstance(metrics, dict) else None
        if not isinstance(pessimistic, dict):
            continue
        for key, values in collected.items():
            value = pessimistic.get(key)
            if value is None:
                continue
            try:
                # Cast explicitly in case the value is stored as an int or string.
                values.append(float(value))
            except (TypeError, ValueError):
                warnings.warn(f"Ignoring non-numeric {key}={value!r} in {metrics_path}")

    results = {}
    for key, prefix in output_prefix.items():
        values = collected[key]
        if values:
            results[f"{prefix}_mean"] = float(np.mean(values))
            results[f"{prefix}_std"] = float(np.std(values, ddof=0))
        else:
            # Same "NaN" string convention as the other metrics of this module,
            # instead of a float nan from np.mean([]).
            results[f"{prefix}_mean"] = "NaN"
            results[f"{prefix}_std"] = "NaN"
        if key == 'MRR':
            results["mrr_count"] = len(values)
    return results
