# experiments/summarize_importance.py
import argparse
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path


class ImportanceAnalyzer:
    """
    Analyzes the aggregated feature importance results to answer key research questions.

    The summary CSV must provide the columns read by the analyses below:
    ``task``, ``mean_importance``, ``layer_id``, ``feature_in_layer``,
    ``type`` (with values such as ``"SAE"`` / ``"ECFP"``) and ``feature_set``
    (with values such as ``"SAE ⊕ ECFP"``).
    """

    def __init__(self, summary_file: str, output_dir: str, top_n: int = 20):
        """
        Initializes the analyzer with the summary data.

        Creates ``output_dir`` (including parents) if it does not exist and
        sets the global seaborn theme used by every plot.

        Args:
            summary_file: Path to the feature_importance_summary.csv file.
            output_dir: Directory to save plots and summary CSVs.
            top_n: The number of top features to focus on for most analyses.
        """
        self.df = pd.read_csv(summary_file)
        self.output_dir = Path(output_dir)
        self.top_n = top_n
        self.output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Loaded summary from '{summary_file}'")
        print(f"Outputs will be saved to '{self.output_dir}'")

        # Set plot style
        sns.set_theme(style="whitegrid")

    def _top_features_per_task(self) -> pd.DataFrame:
        """
        Return the ``top_n`` rows with the highest ``mean_importance`` per task.

        Rows are ordered by task, then by descending importance; ties keep
        their original row order. Rows with a missing ``task`` or
        ``mean_importance`` are excluded. This avoids ``groupby().apply``,
        whose handling of the grouping column is deprecated in recent pandas.
        """
        ranked = (
            self.df.dropna(subset=["task", "mean_importance"])
            # Two stable sorts: importance descending, then task ascending,
            # so within each task the rows stay ranked by importance.
            .sort_values("mean_importance", ascending=False, kind="mergesort")
            .sort_values("task", kind="mergesort")
        )
        return (
            ranked.groupby("task", sort=False)
            .head(self.top_n)
            .reset_index(drop=True)
        )

    def list_top_features_per_task(self, preview_task: str = "BBB_Martins"):
        """
        Answers Q2: Identifies the most important features for each task.

        Writes the top ``top_n`` features of every task to
        ``1_top_features_per_task.csv`` and prints a 5-row preview.

        Args:
            preview_task: Task to preview. If it is absent from the data, the
                first task in the output is previewed instead.
        """
        print("\n--- Q2: Analyzing Top Features Per Task ---")
        top_features = self._top_features_per_task()

        output_path = self.output_dir / "1_top_features_per_task.csv"
        top_features.to_csv(output_path, index=False)
        print(f"Saved top {self.top_n} features for each task to '{output_path}'")

        if top_features.empty:
            print("No features available to preview.")
            return

        tasks = top_features["task"]
        if not (tasks == preview_task).any():
            # Fall back to a task that exists so the preview is never empty.
            preview_task = tasks.iloc[0]

        # Display a snippet for quick review
        print(f"Top 5 features for task '{preview_task}':")
        print(top_features[tasks == preview_task].head(5).to_string(index=False))

    def analyze_feature_generality(self):
        """
        Answers Q1: Investigates if features are task-specific or general.

        A feature counts as "important" for a task when it is among that task's
        top ``top_n`` rows. Features important for more than one distinct task
        are written to ``2_general_features_summary.csv``. If any exist, a
        histogram of the per-feature task counts is also saved.
        """
        print("\n--- Q1: Analyzing Feature Generality vs. Specificity ---")
        top_features = self._top_features_per_task()

        # Create a unique identifier for each feature from its layer and index.
        # NOTE(review): rows without a layer (e.g. ECFP bits, if they carry a
        # NaN layer_id) yield ids like "Lnan_F..."; confirm they belong here.
        top_features["feature_id"] = (
            "L"
            + top_features["layer_id"].astype(str)
            + "_F"
            + top_features["feature_in_layer"].astype(str)
        )

        # Count *distinct* tasks per feature. The same feature can appear more
        # than once within a single task (e.g. if the summary holds several
        # feature sets per task), which must not inflate its task count.
        generality_counts = (
            top_features.groupby("feature_id")["task"]
            .nunique()
            .sort_values(ascending=False, kind="mergesort")
            .rename("num_tasks_important_in")
            .reset_index()
        )

        general_features = generality_counts[
            generality_counts["num_tasks_important_in"] > 1
        ]

        output_path = self.output_dir / "2_general_features_summary.csv"
        general_features.to_csv(output_path, index=False)
        print(
            f"Found {len(general_features)} features important for more than one task."
        )
        print(f"Saved summary of general features to '{output_path}'")

        if not general_features.empty:
            plt.figure(figsize=(10, 6))
            sns.countplot(
                data=generality_counts, x="num_tasks_important_in", palette="viridis"
            )
            plt.title(
                f"Distribution of Feature Generality (Top {self.top_n} Features per Task)"
            )
            plt.xlabel("Number of Tasks a Feature is Important For")
            plt.ylabel("Count of Unique SAE Features")
            plot_path = self.output_dir / "2_feature_generality_distribution.png"
            plt.savefig(plot_path, dpi=150, bbox_inches="tight")
            plt.close()
            print(f"Saved generality distribution plot to '{plot_path}'")

    def analyze_layer_contribution(self):
        """
        Answers Q3: Aggregates importance scores by layer to see which are most influential.

        Sums ``mean_importance`` of all ``type == "SAE"`` rows per
        ``(task, layer_id)``. Writes the table to
        ``3_layer_contribution_summary.csv`` and saves a grouped bar chart.
        Skips the analysis when there are no SAE rows.
        """
        print("\n--- Q3: Analyzing Layer-wise Contribution ---")
        # Filter for only SAE features
        sae_df = self.df[self.df["type"] == "SAE"]
        if sae_df.empty:
            print("No SAE features found. Skipping this analysis.")
            return

        # Sum the importance of all features within a layer for each task.
        # NOTE(review): if several feature sets are present, importances from
        # different models are summed together; consider grouping by feature_set.
        layer_importance = (
            sae_df.groupby(["task", "layer_id"])["mean_importance"].sum().reset_index()
        )

        output_path = self.output_dir / "3_layer_contribution_summary.csv"
        layer_importance.to_csv(output_path, index=False)
        print(f"Saved layer contribution summary to '{output_path}'")

        plt.figure(figsize=(14, 8))
        sns.barplot(
            data=layer_importance,
            x="task",
            y="mean_importance",
            hue="layer_id",
            palette="rocket",
        )
        plt.title("Total SAE Feature Importance Contribution by Layer for Each Task")
        plt.ylabel("Sum of Mean Feature Importance")
        plt.xlabel("TDC Task")
        plt.xticks(rotation=45, ha="right")
        plt.legend(title="SAE Layer")
        plot_path = self.output_dir / "3_layer_contribution_by_task.png"
        plt.savefig(plot_path, dpi=150, bbox_inches="tight")
        plt.close()
        print(f"Saved layer contribution plot to '{plot_path}'")

    def analyze_synergy_with_ecfp(self):
        """
        Answers Q4: Compares SAE and ECFP feature importance in the hybrid model.

        Restricted to rows whose ``feature_set`` is ``"SAE ⊕ ECFP"``.
        Writes the per-(task, type) importance sums to
        ``4_feature_type_importance_summary.csv`` and saves a split violin plot
        of the per-feature importances on a log y-axis.
        """
        print("\n--- Q4: Analyzing Synergy with ECFP Features ---")
        hybrid_df = self.df[self.df["feature_set"] == "SAE ⊕ ECFP"].copy()

        if hybrid_df.empty:
            print("No data for 'SAE ⊕ ECFP' feature set found. Skipping this analysis.")
            return

        # Get total importance for each feature type (SAE vs ECFP) per task
        type_importance = (
            hybrid_df.groupby(["task", "type"])["mean_importance"].sum().reset_index()
        )

        output_path = self.output_dir / "4_feature_type_importance_summary.csv"
        type_importance.to_csv(output_path, index=False)
        print(f"Saved feature type (SAE vs ECFP) importance summary to '{output_path}'")

        plt.figure(figsize=(14, 8))
        sns.violinplot(
            data=hybrid_df,
            x="task",
            y="mean_importance",
            hue="type",
            split=True,
            inner="quart",
            palette={"SAE": "lightblue", "ECFP": "salmon"},
        )
        plt.title("Distribution of Feature Importances (SAE vs. ECFP) in Hybrid Model")
        plt.ylabel("Mean Feature Importance")
        plt.xlabel("TDC Task")
        plt.xticks(rotation=45, ha="right")
        # NOTE(review): the log axis is applied after the violins are computed
        # in linear space; zero or negative importances cannot be shown on it.
        plt.yscale("log")
        plt.legend(title="Feature Type")
        plot_path = self.output_dir / "4_importance_distribution_sae_vs_ecfp.png"
        plt.savefig(plot_path, dpi=150, bbox_inches="tight")
        plt.close()
        print(f"Saved SAE vs. ECFP importance distribution plot to '{plot_path}'")

    def run_all(self):
        """Runs all analysis steps (Q2, Q1, Q3, Q4) in that order."""
        self.list_top_features_per_task()
        self.analyze_feature_generality()
        self.analyze_layer_contribution()
        self.analyze_synergy_with_ecfp()
        print("\nAnalysis complete.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Summarize and analyze SAE feature importance results."
    )
    parser.add_argument(
        "--summary_file",
        type=str,
        required=True,
        help="Path to the 'feature_importance_summary.csv' file generated by the analysis script.",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="experiments/sae_importance_summary",
        help="Directory to save the summary plots and CSVs.",
    )
    parser.add_argument(
        "--top_n",
        type=int,
        default=50,
        help="Number of top features to consider for generality analysis.",
    )
    args = parser.parse_args()

    analyzer = ImportanceAnalyzer(
        summary_file=args.summary_file, output_dir=args.out_dir, top_n=args.top_n
    )
    analyzer.run_all()
