import os
import json
import pandas as pd
import numpy as np
import glob
import sys
from collections import defaultdict

# --- Helper Functions ---


def find_file(base_path, pattern):
    """Return the first file under ``base_path`` matching a glob ``pattern``.

    ``pattern`` is joined onto ``base_path`` and may contain glob wildcards and
    subdirectories (e.g. ``"localize/localize_head_alignment_L*.csv"``).

    Matches are sorted before the first one is chosen. ``glob.glob`` returns
    paths in filesystem-dependent order, so without sorting the selected file
    could differ between machines or runs when several files match.

    Returns:
        The matching path as a string, or ``None`` if nothing matches.
    """
    matches = sorted(glob.glob(os.path.join(base_path, pattern)))
    return matches[0] if matches else None


def print_header(title):
    """Print ``title`` as a top-level section banner framed by 80 '=' characters.

    A blank line precedes the banner so consecutive sections stay separated.
    """
    rule = "=" * 80
    print("\n" + rule, f"## {title}", rule, sep="\n")


def print_subheader(title):
    """Print ``title`` as a sub-section banner framed by 30 '-' characters.

    A blank line precedes the banner, mirroring ``print_header``'s layout.
    """
    rule = "-" * 30
    print(f"\n{rule}\n# {title}\n{rule}")


# --- Analysis Functions ---


def analyze_representation(base_path):
    """
    Analyzes probe accuracy to determine WHERE and HOW WELL valence is represented.

    Reads ``budget/valence_probe_by_layer.csv`` under ``base_path``. The
    ``layer`` and ``acc`` columns are required; ``f1_macro`` is shown when
    present. Prints the per-layer table and the layer with the highest
    accuracy.

    Returns:
        The probe DataFrame, or ``None`` if the file is missing or empty.
    """
    print_header("1. Representation: How well is valence encoded across layers?")

    probe_file = find_file(base_path, "budget/valence_probe_by_layer.csv")
    if not probe_file:
        print("!! Valence probe file not found. Skipping representation analysis.")
        return None

    probe_df = pd.read_csv(probe_file)
    if probe_df.empty:
        print("!! Valence probe file is empty.")
        return None

    # idxmax() returns an index *label*, so the row must be fetched with the
    # label-based .loc. Positional .iloc only coincides with it for a default
    # RangeIndex.
    best_layer = probe_df.loc[probe_df["acc"].idxmax()]

    print(
        "Linear probes were trained to predict remaining atomic valence from each layer's activations."
    )
    print("\nProbe Accuracy by Layer:")
    # Show only columns that exist, so a CSV without f1_macro still prints
    # instead of raising KeyError.
    shown_cols = [c for c in ("layer", "acc", "f1_macro") if c in probe_df.columns]
    print(probe_df[shown_cols].to_string(index=False))

    print(
        f"\n*   Key Finding: The model's representation of valence is most linearly accessible in Layer {int(best_layer['layer'])},"
    )
    print(f"    achieving a probe accuracy of {best_layer['acc']:.2%}.")

    return probe_df


def analyze_localization(base_path, num_layers):
    """
    Analyzes head alignment and ablation results to find key 'valence heads'.

    For every layer index ``i`` in ``range(num_layers)``, this looks under
    ``<base_path>/L{i}/localize/`` for two CSVs:

    * a head-alignment CSV (columns used: ``head``, ``cos_mean``,
      ``frac_positive``);
    * a head-ablation CSV (columns used: ``head``, ``drop_minus``,
      ``drop_eq``, ``drop_hash``).

    Layers missing either file are skipped. Each head is scored as::

        align_score    = cos_mean * frac_positive
        ablation_score = mean(drop_minus, drop_eq, drop_hash)
        combined_score = align_score * ablation_score

    The top 3 heads by ``combined_score`` across all layers are printed.

    Returns:
        A DataFrame with one row per (layer, head) including the score columns,
        or ``None`` if no layer produced any scored heads.
    """
    print_header("2. Localization: Which attention heads are responsible for valence?")
    print(
        "This analysis identifies heads that are both aligned with the valence direction and causally important."
    )

    all_head_metrics = []

    for i in range(num_layers):
        layer_path = os.path.join(base_path, f"L{i}")

        # Load alignment and ablation data
        align_file = find_file(layer_path, "localize/localize_head_alignment_L*.csv")
        ablate_file = find_file(layer_path, "localize/localize_head_ablation_L*.csv")

        if not align_file or not ablate_file:
            continue

        align_df = pd.read_csv(align_file)
        ablate_df = pd.read_csv(ablate_file)

        # The ablation file contains drop in logits for different bond types.
        # We create a single metric for causal importance: the average drop across bond types.
        ablate_df["avg_logit_drop"] = ablate_df[
            ["drop_minus", "drop_eq", "drop_hash"]
        ].mean(axis=1)

        # Merge the dataframes (inner join: only heads present in both files survive).
        head_df = pd.merge(align_df, ablate_df[["head", "avg_logit_drop"]], on="head")

        # The report prints and ranks by "layer". When the alignment CSV does
        # not carry that column, fall back to this directory's index.
        if "layer" not in head_df.columns:
            head_df["layer"] = i

        # Raw, un-normalized scores; the two factors may be on different scales.
        head_df["align_score"] = head_df["cos_mean"] * head_df["frac_positive"]
        head_df["ablation_score"] = head_df["avg_logit_drop"]

        # Create a combined score to rank heads.
        # NOTE(review): both factors are signed. A head that is anti-aligned
        # (negative align_score) and whose ablation *raises* logits (negative
        # drop) also scores high and positive. Confirm this ranking is intended.
        head_df["combined_score"] = head_df["align_score"] * head_df["ablation_score"]

        all_head_metrics.append(head_df)

    if not all_head_metrics:
        print("\n!! Localization files not found. Skipping analysis.")
        return None

    full_df = pd.concat(all_head_metrics, ignore_index=True)
    if full_df.empty:
        # Files existed, but no head appeared in both alignment and ablation data.
        print("\n!! Localization files contain no matching heads. Skipping analysis.")
        return None

    # Find and print top 3 heads overall
    top_overall = full_df.sort_values("combined_score", ascending=False).head(3)

    print_subheader("Top 3 Valence Heads (Overall)")
    print(
        top_overall[
            [
                "layer",
                "head",
                "combined_score",
                "align_score",
                "ablation_score",
                "cos_mean",
            ]
        ].to_string(index=False)
    )

    print("\n*   Key Finding: A small number of heads appear to specialize in valence.")
    best_head = top_overall.iloc[0]
    print(
        f"    The most influential is L{int(best_head['layer'])}H{int(best_head['head'])}, showing strong alignment and high causal importance."
    )

    return full_df


def analyze_causality(base_path, num_layers):
    """
    Analyzes causal intervention results (steering experiments).

    For every layer index ``i`` in ``range(num_layers)``, this reads two files
    under ``<base_path>/L{i}/``:

    * ``causality/valence_causality_L*.json``: an object with a ``layer``
      field and a ``results`` list. Each result needs ``alpha``,
      ``dlogit_single``, ``dlogit_double`` and ``dlogit_triple``.
    * ``decisions/decision_metrics_L*.csv``: only the row with the largest
      ``|alpha|`` (the strongest intervention) is kept.

    The best causal layer is the one where positive-alpha steering most
    strongly shifts logits away from single bonds and toward double/triple
    bonds.

    Returns:
        ``(causality_df, decision_df)``. When either kind of input is missing
        for every layer, returns ``(None, None)``. The result is always a
        2-tuple, so callers can unpack it unconditionally.
    """
    print_header("3. Causality: Can we control valence-related predictions?")
    print(
        "This analysis measures the effect of directly manipulating the valence representation."
    )

    causality_results = []
    decision_results = []

    for i in range(num_layers):
        layer_path = os.path.join(base_path, f"L{i}")

        causality_file = find_file(layer_path, "causality/valence_causality_L*.json")
        decision_file = find_file(layer_path, "decisions/decision_metrics_L*.csv")

        if causality_file:
            with open(causality_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            for res in data["results"]:
                res["layer"] = data["layer"]
                causality_results.append(res)

        if decision_file:
            layer_decisions = pd.read_csv(decision_file)
            if not layer_decisions.empty:
                # Keep only the strongest intervention (largest |alpha|) for this layer.
                strongest = layer_decisions.loc[
                    layer_decisions["alpha"].abs().idxmax()
                ].copy()
                # The report filters and prints by "layer". When the CSV does
                # not carry that column, fall back to this directory's index.
                if "layer" not in strongest.index:
                    strongest["layer"] = i
                decision_results.append(strongest)

    if not causality_results or not decision_results:
        print("\n!! Causality or decision metrics files not found. Skipping analysis.")
        return None, None

    causality_df = pd.DataFrame(causality_results)
    decision_df = pd.DataFrame(decision_results)

    # effect_strength = mean(dlogit_double, dlogit_triple) - dlogit_single.
    # Positive alpha is expected to suppress single bonds (dlogit_single < 0)
    # and promote double/triple bonds (dlogit_double/triple > 0). That pattern
    # makes effect_strength large and POSITIVE, so the strongest layer is the
    # argmax over positive-alpha rows.
    causality_df["effect_strength"] = (
        causality_df["dlogit_double"] + causality_df["dlogit_triple"]
    ) / 2 - causality_df["dlogit_single"]

    positive_alpha_effects = causality_df.loc[
        causality_df["alpha"] > 0, "effect_strength"
    ].dropna()
    if positive_alpha_effects.empty:
        print("\n!! No positive-alpha steering results found. Skipping causal ranking.")
        return causality_df, decision_df
    best_causal_layer = int(causality_df.loc[positive_alpha_effects.idxmax(), "layer"])

    print_subheader("Causal Effect of Steering Valence Representation (at alpha=2.0)")
    summary_table = causality_df[causality_df["alpha"] == 2.0][
        ["layer", "dlogit_single", "dlogit_double", "dlogit_triple"]
    ]
    print(summary_table.to_string(index=False))

    print(
        f"\n*   Key Finding: The causal mechanism is strongest in Layer {best_causal_layer}."
    )
    print(
        "    Directly increasing the 'remaining valence' signal in this layer systematically"
    )
    print(
        "    suppresses single bond predictions and promotes double/triple bond predictions."
    )

    print_subheader("Impact on Model Decisions")
    print(
        "This shows the change in logit margin and prediction 'switch rate' when steering."
    )
    print(
        decision_df[
            ["layer", "alpha", "dmargin3_event", "switch_rate_event", "dmargin3_ctrl"]
        ].to_string(index=False)
    )

    best_decision_rows = decision_df[decision_df["layer"] == best_causal_layer]
    if best_decision_rows.empty:
        # The causality JSON exists for this layer, but its decision CSV does not.
        print(f"\n!! No decision metrics found for Layer {best_causal_layer}.")
        return causality_df, decision_df
    best_decision_row = best_decision_rows.iloc[0]

    print(
        f"\n*   At Layer {best_causal_layer}, this intervention increased the model's prediction margin by {best_decision_row['dmargin3_event']:.3f}"
    )
    print(
        f"    and caused it to change its top prediction {best_decision_row['switch_rate_event']:.1%} of the time on valence-critical decisions."
    )
    print(
        f"    The effect on control tokens was minimal ({best_decision_row['dmargin3_ctrl']:.3f}), confirming specificity."
    )

    return causality_df, decision_df


def summarize_findings(probe_df, loc_df, cause_df):
    """
    Creates a final summary synthesizing all results.

    Args:
        probe_df: Probe table from ``analyze_representation``, or ``None``.
            Needs ``layer`` and ``acc``.
        loc_df: Head table from ``analyze_localization``, or ``None``.
            Needs ``layer``, ``head`` and ``combined_score``.
        cause_df: Steering table from ``analyze_causality``, or ``None``.
            Needs ``layer``, ``alpha`` and
            ``dlogit_single``/``dlogit_double``/``dlogit_triple``.

    The input DataFrames are not modified.
    """
    print_header("4. Overall Summary and Narrative")

    if probe_df is not None:
        # idxmax() returns a row label, not a layer number. Read the "layer"
        # value from that row, as analyze_representation does.
        best_probe_row = probe_df.loc[probe_df["acc"].idxmax()]
        best_probe_layer = int(best_probe_row["layer"])
        best_probe_acc = best_probe_row["acc"]
        print(
            "1.  **Representation**: The model clearly learns a representation of atomic valence. This concept is most"
        )
        print(
            f"    explicitly and linearly encoded in **Layer {best_probe_layer}**, where a simple probe can predict the"
        )
        print(f"    remaining valence with **{best_probe_acc:.2%} accuracy**.")
    else:
        print("1.  **Representation**: Analysis incomplete (probe data missing).")

    if loc_df is not None and not loc_df.empty:
        top_head = loc_df.sort_values("combined_score", ascending=False).iloc[0]
        th_layer, th_head = int(top_head["layer"]), int(top_head["head"])
        print(
            "\n2.  **Localization**: This function is not distributed; it is localized to specific attention heads."
        )
        print(
            f"    The most critical 'valence head' identified is **L{th_layer}H{th_head}**, which shows a powerful combination"
        )
        print(
            "    of alignment to the valence concept and causal impact on predictions."
        )
    else:
        print(
            "\n2.  **Localization**: Analysis incomplete (localization data missing)."
        )

    best_causal_layer = None
    if cause_df is not None:
        # Computed as a local Series so the caller's DataFrame is left untouched.
        effect_strength = (
            cause_df["dlogit_double"] + cause_df["dlogit_triple"]
        ) / 2 - cause_df["dlogit_single"]
        # The expected pattern (single down, double/triple up) gives a positive
        # effect_strength, so the strongest layer is the argmax.
        positive_effects = effect_strength[cause_df["alpha"] > 0].dropna()
        if not positive_effects.empty:
            best_causal_layer = int(cause_df.loc[positive_effects.idxmax(), "layer"])

    if best_causal_layer is not None:
        print(
            "\n3.  **Causality**: The identified mechanism is causal. By manipulating a 'valence direction' vector,"
        )
        print(
            f"    we can predictably control the model's outputs. This effect is strongest in **Layer {best_causal_layer}**,"
        )
        print(
            "    confirming that this layer is a key locus for valence-based reasoning."
        )
    else:
        print("\n3.  **Causality**: Analysis incomplete (causality data missing).")

    # NOTE(review): the narrative below is fixed text. It is printed whatever the
    # results above show (e.g. it always says "middle layers").
    print(
        "\n**Narrative**: The transformer learns a sophisticated, multi-layer mechanism for handling chemical valence."
    )
    print(
        "Valence information is progressively refined, becoming most explicit and causally tractable in the middle layers"
    )
    print(
        "of the network. Within these layers, a small subset of specialized attention heads are responsible for writing"
    )
    print(
        "valence information to the residual stream, directly influencing the model's decisions to follow chemical rules."
    )


def main():
    """
    Main function to run the analysis script.

    Usage: ``python analyze_results.py <path_to_valence_suite_all_layers>``

    The directory must contain per-layer subdirectories named ``L0``, ``L1``,
    and so on. Exits with status 1 on a missing argument, a missing directory,
    or when no layer directories are found.
    """
    if len(sys.argv) < 2:
        print("Usage: python analyze_results.py <path_to_valence_suite_all_layers>")
        sys.exit(1)

    base_path = sys.argv[1]
    if not os.path.isdir(base_path):
        print(f"Error: Directory not found at '{base_path}'")
        sys.exit(1)

    # Infer layers from directories named exactly "L<integer>". Files and
    # names such as "L3_old" are ignored. The analysis range extends to the
    # highest index, so gaps (e.g. L0, L1, L5) don't cause high layers to be
    # skipped; the analyses already skip layers whose directory is absent.
    layer_indices = []
    for path in glob.glob(os.path.join(base_path, "L[0-9]*")):
        suffix = os.path.basename(path)[1:]
        if suffix.isdecimal() and os.path.isdir(path):
            layer_indices.append(int(suffix))
    if not layer_indices:
        print(f"Error: No layer directories (e.g., 'L0', 'L1') found in '{base_path}'")
        sys.exit(1)
    num_layers = max(layer_indices) + 1
    print(f"Detected {len(layer_indices)} layers of experimental results.")

    # Run analyses
    probe_results = analyze_representation(base_path)
    localization_results = analyze_localization(base_path, num_layers)
    # analyze_causality may return a bare None when its inputs are missing.
    # Normalize to a pair so the unpacking below can never raise TypeError.
    causality_results = analyze_causality(base_path, num_layers)
    if causality_results is None:
        causality_results = (None, None)
    causality_causality_df, _causality_decision_df = causality_results

    # Final summary
    summarize_findings(probe_results, localization_results, causality_causality_df)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
