import os
import json
import yaml
import numpy as np
import sys
import argparse
import matplotlib.pyplot as plt
import networkx as nx

from pathlib import Path
from sklearn.metrics import PrecisionRecallDisplay, f1_score
from scipy.stats import spearmanr

import sys
# Appends the relative entry '../' to the module search path. A relative entry is
# resolved against the current working directory at import time, not against this
# file's location.
# NOTE(review): presumably this is what makes the `src.xgrag` imports below resolvable
# when the script is started from a direct subdirectory of the repository root — confirm.
sys.path.append('../')

from src.xgrag.explainer.encoder import Encoder
from src.xgrag.explainer.embedding_comparator import EmbeddingComparator

# Exclude the loopback address from any configured proxy, for this process and its
# children. Both the lower- and upper-case spellings of the variable are set.
# NOTE(review): presumably needed for a locally hosted model/embedding service — confirm.
os.environ['no_proxy'] = '127.0.0.1'
os.environ['NO_PROXY'] = '127.0.0.1'

# Absolute path of config.yaml, located one directory above the directory that contains this file.
CONFIG_PATH = Path(__file__).resolve().parent.parent / "config.yaml"

def load_config(config_path):
    """Read the UTF-8 encoded YAML file at ``config_path`` and return the parsed content."""
    with open(config_path, mode='r', encoding='utf-8') as config_file:
        parsed_config = yaml.safe_load(config_file)
    return parsed_config

def importance(embedding_comparator, output_original, output_perturbed):
    """
    Compute the importance (similarity) scores of the perturbed outputs.

    Delegates to ``embedding_comparator.compare(output_original, output_perturbed)``
    and returns its result unchanged.
    """
    return embedding_comparator.compare(output_original, output_perturbed)

def evaluation(embedding_comparator, input_features, output_original):
    """
    Compute the evaluation (ground truth) similarity scores.

    Delegates to ``embedding_comparator.compare_eval`` with the original output as
    first argument and the input features as second argument, and returns its
    result unchanged.
    """
    return embedding_comparator.compare_eval(output_original, input_features)

def create_graph_from_context(context_dict: dict) -> nx.DiGraph:
    """
    Build a directed networkx graph out of a context dictionary.

    Nodes come from the 'entity' field of every item under 'entities_context';
    edges run from 'entity1' to 'entity2' of every item under 'relations_context'.
    Items with a missing or empty name are ignored, and a warning is printed for
    each of the two keys that is missing or empty.
    """
    graph = nx.DiGraph()

    entity_records = context_dict.get('entities_context', [])
    if not entity_records:
        print("Warning: 'entities_context' key not found or is empty in the context dict.")

    for record in entity_records:
        name = record.get('entity')
        if not name:
            continue
        graph.add_node(name)

    relation_records = context_dict.get('relations_context', [])
    if not relation_records:
        print("Warning: 'relations_context' key not found or is empty in the context dict.")

    for record in relation_records:
        head, tail = record.get('entity1'), record.get('entity2')
        if not (head and tail):
            continue
        # add_edge creates endpoints that are not yet part of the graph.
        graph.add_edge(head, tail)

    return graph

def interpret_correlation(r_value: float) -> str:
    """
    Interpret the strength of a correlation coefficient based on its absolute value.

    Args:
        r_value: Correlation coefficient; only its magnitude is considered.

    Returns:
        "Very weak" (|r| <= 0.19), "Weak" (<= 0.39), "Moderate" (<= 0.59),
        "Strong" (<= 0.79), "Very strong" (anything larger), or "Undefined"
        when ``r_value`` is NaN.
    """
    import math  # function-local so the block does not depend on a new file-level import

    r_abs = abs(r_value)
    # NaN fails every `<=` comparison and would otherwise fall through to
    # "Very strong"; report it explicitly instead.
    if math.isnan(r_abs):
        return "Undefined"

    # Upper bound (inclusive) of each band, in ascending order.
    bands = (
        (0.19, "Very weak"),
        (0.39, "Weak"),
        (0.59, "Moderate"),
        (0.79, "Strong"),
    )
    for upper_bound, label in bands:
        if r_abs <= upper_bound:
            return label
    return "Very strong"

def evaluate_top_k(input_features, predicted_scores, ground_truth_scores, k_percent):
    """
    Compare the top-k percent of features by predicted score with the top-k percent by ground truth score.

    For both score lists a higher score means a more important feature. The two
    top-k selections are compared as sets, so the comparison is order-agnostic.
    The selections, their intersection and Precision@k are printed.

    Args:
        input_features: Feature identifiers (must be hashable), aligned with both score sequences.
        predicted_scores: Predicted importance score per feature.
        ground_truth_scores: Ground truth score per feature.
        k_percent: Percentage of features to select, in the range (0, 100].

    Returns:
        float | None: Precision@k (size of the intersection divided by k), or None
        when the evaluation is skipped because ``k_percent`` is out of range or
        there are no features.
    """
    num_features = len(input_features)
    if not (0 < k_percent <= 100):
        print(f"Warning: k_percent ({k_percent}%) is not in the range (0, 100]. Skipping Top-K evaluation.")
        return None
    if num_features == 0:
        print("Warning: No input features provided. Skipping Top-K evaluation.")
        return None

    # Absolute number of items for the top-k percentage, at least 1.
    # Multiply before dividing: `n * (k_percent / 100.0)` loses exactness
    # (e.g. 100 * 0.29 == 28.999999999999996, which truncates to 28 instead of 29).
    k = max(1, int(num_features * k_percent / 100.0))

    print(f"\n--- Top-{k_percent}% (Top-{k}) Feature Comparison ---")

    # argsort is ascending, so reverse it to get the highest scores first, then keep k.
    top_k_predicted_indices = np.argsort(predicted_scores)[::-1][:k]
    top_k_predicted_features = {input_features[i] for i in top_k_predicted_indices}

    top_k_ground_truth_indices = np.argsort(ground_truth_scores)[::-1][:k]
    top_k_ground_truth_features = {input_features[i] for i in top_k_ground_truth_indices}

    print("\nTop-k Predicted Feature Ranks (0 is best, higher rank is worse):")
    for rank, feature_idx in enumerate(top_k_predicted_indices):
        print(f"  - Rank {rank}: Feature '{input_features[feature_idx]}' (Index: {feature_idx}, Score: {predicted_scores[feature_idx]:.4f})")

    print("\nTop-k Ground Truth Feature Ranks (0 is best, higher rank is worse):")
    for rank, feature_idx in enumerate(top_k_ground_truth_indices):
        print(f"  - Rank {rank}: Feature '{input_features[feature_idx]}' (Index: {feature_idx}, Score: {ground_truth_scores[feature_idx]:.4f})")

    print(f"\nTop-{k} ({k_percent}%) Predicted Important Features (Set): {sorted(top_k_predicted_features)}")
    print(f"Top-{k} ({k_percent}%) Ground Truth Important Features (Set): {sorted(top_k_ground_truth_features)}")

    # Set intersection: which predicted top-k features are also ground truth top-k features.
    intersection = top_k_predicted_features & top_k_ground_truth_features
    print(f"Intersection ({len(intersection)} items): {sorted(intersection)}")

    # Precision@k: fraction of the predicted top-k items that are correct (k >= 1 here).
    precision_at_k = len(intersection) / k
    print(f"Precision@{k}: {precision_at_k:.2f}")
    return precision_at_k

def evaluate_degree_centrality_correlation(input_features, predicted_scores, graph: nx.Graph):
    """
    Report the rank correlation between feature importance and degree centrality.

    Prints Spearman's correlation (with a verbal interpretation) and its p-value
    between ``predicted_scores`` and the degree centrality of each feature's node
    in ``graph``. Features that are not nodes of the graph count as centrality 0.
    The evaluation is skipped with a message when any feature contains " -> "
    (relation-style features), when the graph has no nodes, or when fewer than
    two data points are available.
    """
    print("\n--- Degree Centrality vs. Importance Score Correlation ---")

    # A " -> " inside any feature marks the features as relations rather than nodes.
    features_are_relations = any(" -> " in str(item) for item in input_features)
    if features_are_relations:
        print("Skipping degree centrality evaluation: Features appear to be relations, not nodes.")
        return

    if not graph.nodes:
        print("Skipping degree centrality evaluation: The graph has no nodes.")
        return

    # Degree centrality of every node in the graph, keyed by node.
    centrality_by_node = nx.degree_centrality(graph)
    centrality_scores = [centrality_by_node.get(item, 0) for item in input_features]

    enough_points = len(centrality_scores) >= 2 and len(predicted_scores) >= 2
    if not enough_points:
        print("Skipping correlation: Not enough data points.")
        return

    rho, p_val = spearmanr(predicted_scores, centrality_scores)
    strength = interpret_correlation(rho)
    print(f"Spearman's Rank Correlation between Importance and Degree Centrality: {rho:.4f} ({strength})")
    print(f"P-value: {p_val:.4f}")

    if p_val < 0.05:  # type: ignore
        verdict = "The correlation is statistically significant, suggesting a relationship between a node's connectedness and its importance score.\n"
    else:
        verdict = "The correlation is not statistically significant.\n"
    print(verdict)

def evaluate_pagerank_correlation(input_features, predicted_scores, graph: nx.DiGraph):
    """
    Calculates the correlation between feature importance and graph PageRank.
    This is only applicable when features are graph nodes (i.e., entities).

    Prints Spearman's rank correlation (with a verbal interpretation) and its
    p-value between ``predicted_scores`` and the PageRank of each feature's node.
    Features that are not nodes of the graph count as PageRank 0. The evaluation
    is skipped with a message when any feature contains " -> ", when the graph has
    no nodes, when PageRank does not converge, or when fewer than two data points
    are available.

    Args:
        input_features: Feature identifiers, expected to be node names of ``graph``.
        predicted_scores: Importance score per feature, aligned with ``input_features``.
        graph: Graph whose PageRank values are compared against the scores.
    """
    print("\n--- PageRank vs. Importance Score Correlation ---")

    # Check if features are likely entities (nodes) by checking for relation format
    if any(" -> " in str(feature) for feature in input_features):
        print("Skipping PageRank evaluation: Features appear to be relations, not nodes.")
        return

    if not graph.nodes:
        print("Skipping PageRank evaluation: The graph has no nodes.")
        return

    # Only the PageRank computation can raise the convergence error, so the
    # try block is limited to that single call.
    try:
        pagerank = nx.pagerank(graph)
    except nx.PowerIterationFailedConvergence:
        print("Skipping PageRank evaluation: PageRank algorithm did not converge.")
        return

    pagerank_scores = [pagerank.get(feature, 0) for feature in input_features]

    if len(pagerank_scores) < 2 or len(predicted_scores) < 2:
        print("Skipping correlation: Not enough data points.")
        return

    correlation, p_value = spearmanr(predicted_scores, pagerank_scores)
    interpretation = interpret_correlation(correlation)
    print(f"Spearman's Rank Correlation between Importance and PageRank: {correlation:.4f} ({interpretation})")
    print(f"P-value: {p_value:.4f}")
    # Only claim a relationship when the result is significant; a non-significant
    # result must not be reported as "suggesting a relationship".
    if p_value < 0.05:  # type: ignore
        print("The correlation is statistically significant, suggesting a relationship between a node's PageRank and its importance score.\n")
    else:
        print("The correlation is not statistically significant.\n")

def calculate_reciprocal_rank_for_top_feature(input_features, predicted_scores, ground_truth_scores):
    """
    Calculates the Reciprocal Rank for the single most important ground truth feature.
    This helps evaluate how well the predicted ranking places the most important item.
    A higher score is considered more important (higher rank).

    Tied predicted scores share a rank (competition ranking, 0-indexed: 0, 1, 1, 3).
    The full predicted ranking and the result are printed.

    Args:
        input_features: Feature identifiers, aligned with both score sequences.
        predicted_scores: Predicted importance score per feature.
        ground_truth_scores: Ground truth score per feature.

    Returns:
        tuple[float, int | None]: A tuple containing:
            - The reciprocal rank of the top ground truth feature, i.e. 1 / (rank + 1),
              or 0.0 when no scores were provided.
            - The 0-indexed predicted rank of the top feature, or None when no scores were provided.
    """
    print("\n--- Reciprocal Rank Calculation for Top Feature ---")

    # np.argmax raises on an empty sequence; there is no top feature to rank in that case.
    if len(ground_truth_scores) == 0 or len(predicted_scores) == 0:
        print("No scores provided. Cannot calculate reciprocal rank.")
        return 0.0, None

    # Find the index of the feature with the highest ground truth score.
    ground_truth_top_feature_idx = np.argmax(ground_truth_scores)
    print(f"\nTop ground truth feature: '{input_features[ground_truth_top_feature_idx]}' (Index: {ground_truth_top_feature_idx}, Score: {ground_truth_scores[ground_truth_top_feature_idx]:.4f})")

    # `argsort` sorts ascending, so reverse it to get the best-scoring feature first.
    predicted_indices_sorted_desc = np.argsort(predicted_scores)[::-1]

    # Rank of every feature, stored at the feature's original index. The best
    # feature keeps the initial rank 0; every later feature either inherits the
    # rank of its predecessor (same score) or takes its position in the sorted order.
    feature_ranks = np.zeros(len(predicted_scores), dtype=int)
    for position in range(1, len(predicted_indices_sorted_desc)):
        current_idx = predicted_indices_sorted_desc[position]
        previous_idx = predicted_indices_sorted_desc[position - 1]
        if predicted_scores[current_idx] == predicted_scores[previous_idx]:
            feature_ranks[current_idx] = feature_ranks[previous_idx]
        else:
            feature_ranks[current_idx] = position

    print("\nFull Predicted Feature Ranking (0 is best):")
    for feature_idx in predicted_indices_sorted_desc:
        rank = feature_ranks[feature_idx]
        print(f"  - Rank {rank}: Feature '{input_features[feature_idx]}' (Index: {feature_idx}, Score: {predicted_scores[feature_idx]:.4f})")

    print() # Add a newline for better spacing before the summary.
    # Convert to built-in types so the return value matches the documented float / int.
    predicted_rank = int(feature_ranks[ground_truth_top_feature_idx])
    reciprocal_rank = 1.0 / (predicted_rank + 1)
    print(f"Top ground truth feature is at predicted rank (0-indexed, 0 is best): {predicted_rank}")

    print(f"Reciprocal Rank: {reciprocal_rank:.4f}")
    return reciprocal_rank, predicted_rank

def calculate_and_report_metrics(input_features, scores, scores_eval, experiment_dir: Path, f1_thresholds):
    """
    Print the evaluation report for the importance scores and save a precision-recall curve.

    The report consists of the per-feature scores, a confusion matrix and F1 score
    obtained by thresholding both score lists, and the reciprocal rank of the top
    ground truth feature. The precision-recall curve is written to
    ``explainer_pr_curve.png`` inside ``experiment_dir``; a failure to create or
    save the plot is reported but does not abort the evaluation.

    Args:
        input_features: Feature identifiers, aligned with ``scores`` and ``scores_eval``.
        scores: Predicted importance score per feature.
        scores_eval: Ground truth score per feature.
        experiment_dir: Directory the plot is saved into.
        f1_thresholds: Mapping with the keys ``"importance"`` and ``"evaluation"``. A
            score strictly above the respective threshold is labelled positive.
    """
    # TODO: find best thresholds (an earlier idea was a multiple of the mean score,
    # e.g. np.mean(scores) * 1.2).
    positive_threshold = f1_thresholds["importance"]
    correct_threshold = f1_thresholds["evaluation"]

    fp, fn, tp, tn = 0, 0, 0, 0
    y_true_for_plot = []
    y_pred_for_plot = scores

    print("Feature, Importance Score, Ground Truth Score")
    for feature, score, score_eval in zip(input_features, scores, scores_eval):
        print(f"  - {feature}: {score:.4f}, {score_eval:.4f}")

        # Ground truth label, used for both the F1 score and the plot.
        true_label = 1 if score_eval > correct_threshold else 0
        y_true_for_plot.append(true_label)

        # Predicted label, used for the F1 score only (the plot uses the raw scores).
        predicted_label = 1 if score > positive_threshold else 0

        # Update the confusion matrix.
        if predicted_label == 1 and true_label == 1:
            tp += 1
        elif predicted_label == 1 and true_label == 0:
            fp += 1
        elif predicted_label == 0 and true_label == 1:
            fn += 1
        else:
            tn += 1

    print(f"\nConfusion Matrix (at importance threshold {positive_threshold}, ground truth threshold {correct_threshold}):")
    print(f"TP: {tp}, FP: {fp}, FN: {fn}, TN: {tn}")

    # F1 = 2TP / (2TP + FP + FN); undefined when there are no positives at all.
    denominator = 2 * tp + fp + fn
    if denominator > 0:
        f1 = (2 * tp) / denominator
        print(f"F1 Score: {f1:.4f}")
    else:
        print("Cannot calculate F1-score, denominator is zero.")

    # --- Reciprocal Rank Calculation ---
    calculate_reciprocal_rank_for_top_feature(input_features, scores, scores_eval)

    # --- Plotting ---
    # The curve is created exactly once and inside the try block, so that a plotting
    # failure is reported instead of aborting the evaluation and no extra figure is
    # left open.
    print("\nPlotting Precision-Recall Curve...")
    try:
        display = PrecisionRecallDisplay.from_predictions(
            y_true=y_true_for_plot, y_pred=y_pred_for_plot, name="RAG-EX Feature Importance"
        )
        try:
            display.ax_.set_title("Feature Importance Precision-Recall Curve")
            plot_path = experiment_dir / "explainer_pr_curve.png"
            display.figure_.savefig(plot_path)
            print(f"Precision-Recall curve saved to {plot_path}")
        finally:
            # Close the figure even when saving fails, to free its memory and keep
            # it from being displayed in non-UI contexts.
            plt.close(display.figure_)
    except Exception as e:
        print(f"Could not generate or save plot: {e}")


def explainer(experiment_dir: Path):
    """
    Run the complete explainer evaluation for one experiment and write the report to a file.

    While the evaluation runs, ``sys.stdout`` is redirected to
    ``explainer_results_final.txt`` inside ``experiment_dir``, so everything the
    evaluation functions print ends up in that report. The original stdout is
    restored afterwards, also when an error occurs (the error is then propagated).

    Steps: load the configuration from ``CONFIG_PATH``, read the original answer and
    the perturbed answers (a JSON object mapping each perturbed feature to the answer
    produced for it), compute importance and ground truth scores with the embedding
    comparator, then report graph correlations (only if the deduplicated context file
    exists), threshold metrics, and the top-k comparisons.

    Args:
        experiment_dir: Directory of the experiment to analyze; the report is written into it.
    """
    output_file_path = experiment_dir / "explainer_results_final.txt"
    print(f"INFO: Saving explainer output to {output_file_path.resolve()}")

    original_stdout = sys.stdout
    # Explicit UTF-8 so that printing non-ASCII features or answers does not depend
    # on the platform's default encoding.
    with open(output_file_path, 'w', encoding='utf-8') as report_file:
        sys.stdout = report_file
        try:
            config = load_config(CONFIG_PATH)
            explainer_cfg = config['explainer']
            # Each path template is formatted with experiment_dir and the result is joined onto experiment_dir.
            initial_answer_path = experiment_dir / config['initial_answer_path_template'].format(experiment_dir=experiment_dir)
            perturbed_answers_path = experiment_dir / config['perturbed_answers_path_template'].format(experiment_dir=experiment_dir)
            # The input handles get their own names so they do not shadow the report file handle.
            with open(initial_answer_path, 'r', encoding='utf-8') as answer_file:
                original_output = answer_file.read()
            print(f'Output Original: {original_output}\n')

            with open(perturbed_answers_path, 'r', encoding='utf-8') as perturbed_file:
                perturbed_results = json.load(perturbed_file)
            # Keys are the perturbed features, values the answers generated for them.
            input_features = list(perturbed_results.keys())
            perturbed_outputs = list(perturbed_results.values())

            for perturbed_context, perturbed_output in perturbed_results.items():
                print(f"Perturbed context: {perturbed_context} \nOutput Perturbed: {perturbed_output}\n\n")

            encoder_model = config['lightrag']['embedding']['model']
            encoder = Encoder(model_name=encoder_model)
            embedding_comparator = EmbeddingComparator(encoder)
            scores = importance(embedding_comparator, original_output, perturbed_outputs)
            scores_eval = evaluation(embedding_comparator, input_features, original_output)

            # --- Degree Centrality and PageRank Correlation ---
            initial_context_deduplicated_path = experiment_dir / config['initial_context_deduplicated_path_template'].format(experiment_dir=experiment_dir)
            if initial_context_deduplicated_path.exists():
                with open(initial_context_deduplicated_path, 'r', encoding='utf-8') as f_ctx:
                    context_dict = json.load(f_ctx)
                graph = create_graph_from_context(context_dict)
                evaluate_degree_centrality_correlation(input_features, scores, graph)
                evaluate_pagerank_correlation(input_features, scores, graph)
            else:
                print(f"Warning: Deduplicated context file not found at {initial_context_deduplicated_path}. Skipping degree centrality evaluation.")

            f1_thresholds = explainer_cfg['f1_thresholds']
            calculate_and_report_metrics(input_features, scores, scores_eval, experiment_dir, f1_thresholds)

            # --- Top-K Comparison ---
            top_k_values = explainer_cfg.get('top_k_eval', [5]) # Default to [5] if not in config
            if isinstance(top_k_values, int): # Handle case where it's a single number for backward compatibility
                top_k_values = [top_k_values]
            for k in top_k_values:
                evaluate_top_k(input_features, scores, scores_eval, k)
        finally:
            # Always restore stdout, also when the evaluation raised.
            sys.stdout = original_stdout

if __name__ == "__main__":
    # Command-line entry point: validate the experiment directory, then run the explainer on it.
    parser = argparse.ArgumentParser(description="Run the RAG-Ex explainer on experiment results.")
    parser.add_argument("experiment_dir", type=str, help="The path to the experiment directory to be analyzed.")
    args = parser.parse_args()

    exp_dir_path = Path(args.experiment_dir)
    if not exp_dir_path.is_dir():
        print(f"Error: Experiment directory not found at '{exp_dir_path}'", file=sys.stderr)
        # sys.exit instead of the builtin exit(): the latter is injected by the
        # `site` module for interactive use and is not guaranteed to exist.
        sys.exit(1)

    explainer(exp_dir_path)
