import pandas as pd
import json
import os


class PointerExperimentAnalyzer:
    """
    Analyzes the output files from the SMILES Transformer interpretability experiments.

    Each public ``analyze_*`` / ``find_*`` method reads one result file from the
    results directory and prints a short report to stdout. When the file is
    missing, a warning is printed and that part of the analysis is skipped.
    """

    def __init__(self, results_dir: str) -> None:
        """
        Initializes the analyzer with the path to the main results directory.

        Args:
            results_dir: The root directory containing 'suite', 'stability', etc.

        Raises:
            FileNotFoundError: If ``results_dir`` is not an existing directory.
        """
        self.base_dir = results_dir
        if not os.path.isdir(self.base_dir):
            raise FileNotFoundError(
                f"The provided results directory does not exist: {self.base_dir}"
            )

        # Define paths to the key output files
        self.suite_dir = os.path.join(self.base_dir, "suite")
        self.stability_dir = os.path.join(self.base_dir, "stability")

        self.summary_file = os.path.join(self.suite_dir, "pointer_suite_summary.csv")
        self.scatter_file = os.path.join(
            self.suite_dir, "pointer_vs_global_scatter.csv"
        )
        self.redundancy_file = os.path.join(self.suite_dir, "redundancy_table.csv")
        self.robustness_file = os.path.join(self.suite_dir, "robustness_curves.json")
        self.probe_file = os.path.join(self.suite_dir, "value_probe.json")
        self.stability_file = os.path.join(self.stability_dir, "stability_summary.json")

    def _read_csv(self, file_path: str) -> "pd.DataFrame | None":
        """
        Safely reads a CSV file into a pandas DataFrame.

        Returns:
            The parsed DataFrame, or None (after printing a warning) when
            ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            print(f"Warning: File not found at {file_path}. Skipping related analysis.")
            return None
        return pd.read_csv(file_path)

    def _read_json(self, file_path: str):
        """
        Safely reads a JSON file into a dictionary.

        Returns:
            The decoded JSON content, or None (after printing a warning) when
            ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            print(f"Warning: File not found at {file_path}. Skipping related analysis.")
            return None
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def find_top_pointer_heads(self, top_n: int = 5):
        """
        Identifies the most important pointer heads based on pointer mass.

        The heads are ranked by the 'pointer_ring' and 'pointer_paren' columns
        of the summary CSV; the causal columns ('delta_margin_*', 'delta_acc_*')
        are printed next to each ranked head.

        Args:
            top_n: Number of heads to keep for each of the two rankings.

        Returns:
            A tuple ``(top_ring_heads, top_paren_heads)`` of DataFrames, or
            ``(None, None)`` when the summary file is missing.
        """
        print("\n--- 1. Identifying Top Pointer Heads ---")
        df = self._read_csv(self.summary_file)
        if df is None:
            return None, None

        # Add a column for head ID string 'L<layer>H<head>'
        df["head_id"] = "L" + df["layer"].astype(str) + "H" + df["head"].astype(str)

        # Find top heads by pointer mass
        top_ring_heads = df.sort_values("pointer_ring", ascending=False).head(top_n)
        top_paren_heads = df.sort_values("pointer_paren", ascending=False).head(top_n)

        print(f"\nTop {top_n} Ring Pointer Heads (by Pointer Mass):")
        print(
            top_ring_heads[
                ["head_id", "pointer_ring", "delta_margin_ring", "delta_acc_ring"]
            ].to_string(index=False)
        )

        print(f"\nTop {top_n} Parenthesis Pointer Heads (by Pointer Mass):")
        print(
            top_paren_heads[
                ["head_id", "pointer_paren", "delta_margin_paren", "delta_acc_paren"]
            ].to_string(index=False)
        )

        print("\nInsight:")
        print(
            "The tables above show the heads with the highest 'Pointer Mass'—the attention they pay to the opening token."
        )
        print(
            "High 'delta_margin_ring' or 'delta_acc_ring' values confirm these heads are causally necessary for correct predictions."
        )

        return top_ring_heads, top_paren_heads

    def analyze_stability(self) -> None:
        """
        Analyzes the stability of pointer heads across different training checkpoints.

        Prints the Spearman and Jaccard@5 scores stored in the stability
        summary JSON. Does nothing beyond a warning when the file is missing.
        """
        print("\n--- 2. Analyzing Pointer Head Stability ---")
        data = self._read_json(self.stability_file)
        if data is None:
            return

        print(f"Comparison between checkpoint {data['ckpt_a']} and {data['ckpt_b']}:")
        print(f"  - Spearman Correlation (Rings):   {data['spearman_ring']:.4f}")
        print(f"  - Spearman Correlation (Parens):  {data['spearman_paren']:.4f}")
        print(f"  - Jaccard Similarity @5 (Rings):  {data['jaccard5_ring']:.4f}")
        print(f"  - Jaccard Similarity @5 (Parens): {data['jaccard5_paren']:.4f}")

        print("\nInsight:")
        print(
            "High Spearman correlation (close to 1.0) means the ranking of heads is consistent across checkpoints."
        )
        print(
            "High Jaccard similarity (close to 1.0) means the set of top-performing heads is nearly identical."
        )
        print(
            "Together, high scores indicate that these pointer heads are a stable, convergent feature learned by the model."
        )

    @staticmethod
    def _select_heads(scatter_df, top_heads):
        """
        Returns the rows of ``scatter_df`` whose (layer, head) pair is in ``top_heads``.

        Layer and head are matched together as a pair. Testing the two columns
        independently would also keep unrelated heads that merely share a layer
        with one top head and a head index with another.
        """
        wanted = set(zip(top_heads["layer"], top_heads["head"]))
        keep = pd.Series(
            [pair in wanted for pair in zip(scatter_df["layer"], scatter_df["head"])],
            index=scatter_df.index,
            dtype=bool,
        )
        return scatter_df[keep]

    def _print_specificity(self, scatter_df, top_heads, margin_col: str) -> None:
        """
        Prints the specificity table for one task.

        Args:
            scatter_df: Contents of the pointer-vs-global scatter CSV.
            top_heads: DataFrame with 'layer' and 'head' columns naming the heads to report.
            margin_col: Task margin column, 'dmargin_ring' or 'dmargin_paren'.
        """
        selected = self._select_heads(scatter_df, top_heads)
        # Focus on task-related effects; copy so the new column is written to
        # an independent frame rather than to a slice of scatter_df.
        selected = selected[selected[margin_col] > 0].copy()
        # The small epsilon guards against division by an exactly-zero control margin.
        selected["specificity_ratio"] = selected[margin_col] / (
            selected["dmargin_ctrl"] + 1e-9
        )
        print(
            selected[
                ["layer", "head", margin_col, "dmargin_ctrl", "specificity_ratio"]
            ].to_string(index=False)
        )

    def analyze_specificity(self, top_ring_heads, top_paren_heads) -> None:
        """
        Analyzes if pointer heads are specialized for their task or have global effects.

        Args:
            top_ring_heads: DataFrame of top ring heads (needs 'layer' and
                'head' columns), or None to skip the ring table.
            top_paren_heads: DataFrame of top parenthesis heads (same columns),
                or None to skip the parenthesis table.
        """
        print("\n--- 3. Analyzing Head Specificity ---")
        df = self._read_csv(self.scatter_file)
        if df is None:
            return

        print(
            "Comparing performance drop on pointer events vs. random 'control' tokens."
        )

        if top_ring_heads is not None:
            print("\nTop Ring Heads:")
            self._print_specificity(df, top_ring_heads, "dmargin_ring")

        if top_paren_heads is not None:
            print("\nTop Parenthesis Heads:")
            self._print_specificity(df, top_paren_heads, "dmargin_paren")

        print("\nInsight:")
        print("A high 'specificity_ratio' (>> 1.0) indicates a specialized head.")
        print(
            "This means ablating the head has a much larger negative impact on pointer events than on other tokens,"
        )
        print("confirming it has a dedicated, rather than a general-purpose, role.")

    def analyze_redundancy(self) -> None:
        """
        Analyzes the redundancy between the top two pointer heads for each task.

        Prints the redundancy table CSV verbatim followed by a reading guide.
        """
        print("\n--- 4. Analyzing Head Redundancy ---")
        df = self._read_csv(self.redundancy_file)
        if df is None:
            return

        print(df.to_string(index=False))

        print("\nInsight:")
        print(
            "The 'redundancy_index' measures the overlap in function between the top two heads."
        )
        print(
            "  - An index close to 1.0 means their effects are additive (low redundancy)."
        )
        print(
            "  - An index close to 0.5 suggests strong redundancy, as ablating both is not much worse than ablating one."
        )
        print(
            "This tells us if the model has a single 'backup' head or multiple independent heads for the task."
        )

    def analyze_value_probe(self) -> None:
        """
        Analyzes what information is being moved by the pointer heads' value vectors.

        Prints the pre- and post-W_o probe accuracies stored in the value
        probe JSON for the ring head and the parenthesis head.
        """
        print("\n--- 5. Analyzing Value-Stream Probes ---")
        data = self._read_json(self.probe_file)
        if data is None:
            return

        print(f"Probing head '{data['ring_head']}' for Rings:")
        print(
            f"  - Pre-W_o Accuracy:  {data['ring_pre']['acc']:.3f} (Can we decode the ring ID from the value vector?)"
        )
        print(
            f"  - Post-W_o Accuracy: {data['ring_post']['acc']:.3f} (After projection by the output matrix)"
        )

        print(f"\nProbing head '{data['paren_head']}' for Parentheses:")
        print(
            f"  - Pre-W_o Accuracy:  {data['paren_pre']['acc']:.3f} (Can we decode paren depth from the value vector?)"
        )
        print(
            f"  - Post-W_o Accuracy: {data['paren_post']['acc']:.3f} (After projection)"
        )

        print("\nInsight:")
        print(
            "High probe accuracy demonstrates that the head's value vector contains specific information"
        )
        print(
            "about the *opening* token (e.g., ring digit '1' or '2'). This confirms the head is actively"
        )
        print(
            "copying this information to the closing position, not just pointing to it."
        )

    def analyze_robustness(self) -> None:
        """
        Summarizes how pointer head performance changes with distance/depth.

        Consecutive entries of the '*_bins' lists are treated as the edges of
        one bin; bin ``i`` is reported with the ``i``-th entry of the matching
        '*_pointer_mass' and '*_delta_margin' lists.
        """
        print("\n--- 6. Analyzing Robustness to Distance ---")
        data = self._read_json(self.robustness_file)
        if data is None:
            return

        print(f"Robustness for Ring Head '{data['ring_head']}':")
        for i, (start, end) in enumerate(
            zip(data["ring_bins"][:-1], data["ring_bins"][1:])
        ):
            mass = data["ring_pointer_mass"][i]
            margin = data["ring_delta_margin"][i]
            print(
                f"  - Span {start}-{end} tokens: Pointer Mass = {mass:.3f}, Delta Margin = {margin:.3f}"
            )

        print(f"\nRobustness for Parenthesis Head '{data['paren_head']}':")
        for i, (start, end) in enumerate(
            zip(data["paren_bins"][:-1], data["paren_bins"][1:])
        ):
            mass = data["paren_pointer_mass"][i]
            margin = data["paren_delta_margin"][i]
            print(
                f"  - Depth {start}-{end}: Pointer Mass = {mass:.3f}, Delta Margin = {margin:.3f}"
            )

        print("\nInsight:")
        print(
            "This analysis shows if the heads' performance degrades over long distances or deep nesting."
        )
        print(
            "A gradual decline in 'Pointer Mass' and 'Delta Margin' is expected. A sharp drop-off would indicate"
        )
        print(
            "a limitation in the model's ability to handle long-range dependencies for this task."
        )
        print("You can plot these values against the bins to visualize the trend.")

    def run_full_analysis(self) -> None:
        """
        Runs all analysis methods in sequence and prints a full report.

        The top heads found in step 1 are passed on to the specificity
        analysis; every other step reads its own file independently.
        """
        print("=" * 60)
        print("Running Full Analysis of Pointer Head Experiments")
        print("=" * 60)

        top_ring, top_paren = self.find_top_pointer_heads()
        self.analyze_stability()
        self.analyze_specificity(top_ring, top_paren)
        self.analyze_redundancy()
        self.analyze_value_probe()
        self.analyze_robustness()

        print("\n" + "=" * 60)
        print("Analysis Complete.")
        print("=" * 60)


if __name__ == "__main__":
    # --- HOW TO RUN ---
    # RESULTS_DIRECTORY must be the folder that holds the experiment outputs,
    # i.e. the parent of the 'suite' and 'stability' subdirectories.
    #
    # Example: with outputs in './outputs/suite' and './outputs/stability',
    # RESULTS_DIRECTORY should be './outputs'.
    RESULTS_DIRECTORY = "./"  # default: 'suite' and 'stability' are in the current dir

    try:
        PointerExperimentAnalyzer(RESULTS_DIRECTORY).run_full_analysis()
    except FileNotFoundError as missing:
        # Raised by the constructor when the results directory does not exist.
        for message in (
            f"\nError: {missing}",
            "Please ensure the RESULTS_DIRECTORY is set correctly and points to the",
            "folder containing the 'suite' and 'stability' subdirectories.",
        ):
            print(message)
    except Exception as unexpected:
        # Top-level boundary: report any other failure instead of a traceback.
        print(f"An unexpected error occurred: {unexpected}")
