# experiments/summarize_ablation.py
import argparse
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
import warnings

# Silence UserWarning-category warnings for the whole process. The original
# motivation was matplotlib's font-cache messages, but this filter is not
# restricted to matplotlib: UserWarnings raised by any other module (pandas,
# seaborn, this script) are hidden as well.
warnings.filterwarnings("ignore", category=UserWarning)


class AblationAnalyzer:
    """
    Analyzes and visualizes the results from the feature ablation experiments.

    Result files named ``ablation_impact_<task>_seed<N>.csv`` are discovered
    recursively under ``base_dir``. The name of each file's parent directory
    is recorded as the ablated feature, and ``<task>`` / ``<N>`` are parsed
    from the file name. All files are concatenated into ``self.data``.

    Rows with a non-null ``prob_drop`` value are treated as classification
    results; every analysis in this class operates on those rows only.
    """

    # Naming convention of the per-run result files.
    _FILE_PREFIX = "ablation_impact_"
    _SEED_MARKER = "_seed"

    # Columns that identify a single ablation run.
    _RUN_KEYS = ["ablated_feature", "task", "seed"]

    def __init__(self, base_dir: str, output_dir: str):
        """
        Load every ablation CSV under ``base_dir`` and prepare ``output_dir``.

        Args:
            base_dir: Directory searched recursively for result CSVs.
            output_dir: Directory for generated tables/plots; created if
                it does not exist.

        Raises:
            FileNotFoundError: If no result file could be loaded.
        """
        self.base_dir = Path(base_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        print(f"Loading ablation results from: {self.base_dir}")
        self.data = self._load_and_parse_data()

        if self.data.empty:
            raise FileNotFoundError(
                f"No valid ablation CSV files found in subdirectories of {self.base_dir}"
            )

        # The reported number is the count of distinct ablated features.
        n_features = self.data["ablated_feature"].nunique()
        print(
            f"Successfully loaded and parsed data from {n_features} ablation runs."
        )

        # Set a professional plot style
        sns.set_theme(style="whitegrid", context="paper", font_scale=1.2)

    def _load_and_parse_data(self) -> pd.DataFrame:
        """
        Load all ablation CSVs and attach metadata parsed from their paths.

        Returns:
            One DataFrame with the rows of every readable file plus the
            columns ``ablated_feature`` (parent directory name), ``task`` and
            ``seed`` (parsed from the file name). Empty if nothing was loaded.
        """
        frames = []
        # Sort so the concatenated row order does not depend on the
        # filesystem's directory enumeration order.
        csv_paths = sorted(self.base_dir.rglob(f"{self._FILE_PREFIX}*.csv"))

        for csv_path in csv_paths:
            # Split on the LAST "_seed" so task names that themselves contain
            # "_seed" are kept intact.
            head, marker, seed_text = csv_path.stem.rpartition(self._SEED_MARKER)
            if not marker:
                continue  # Skip files that don't match the pattern

            try:
                seed = int(seed_text)
                frame = pd.read_csv(csv_path)
            except Exception as e:
                # Best effort: one unreadable or misnamed file must not abort
                # the whole analysis.
                print(f"Warning: Could not parse file {csv_path}. Error: {e}")
                continue

            frame["ablated_feature"] = csv_path.parent.name
            # The task name is everything between the prefix and the seed part
            frame["task"] = head[len(self._FILE_PREFIX):]
            frame["seed"] = seed
            frames.append(frame)

        return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

    def _classification_rows(self) -> pd.DataFrame:
        """Return a copy of the rows that have a ``prob_drop`` value.

        If the loaded data has no ``prob_drop`` column at all, an empty frame
        is returned instead of raising ``KeyError``.
        """
        if "prob_drop" not in self.data.columns:
            return self.data.iloc[0:0].copy()
        return self.data.dropna(subset=["prob_drop"]).copy()

    @staticmethod
    def _strip_feature_prefixes(names: pd.Series) -> pd.Series:
        """Remove the ``specialist_`` / ``generalist_`` markers from feature names."""
        cleaned = names
        for marker in ("specialist_", "generalist_"):
            # Both removals must go through the .str accessor: a plain
            # Series.replace only matches whole values, not substrings.
            cleaned = cleaned.str.replace(marker, "", regex=False)
        return cleaned

    def summarize_overall_impact(self) -> None:
        """
        Answers Q1: Quantifies the overall impact of each ablation.

        Averages ``prob_drop`` within each (feature, task, seed) run, then
        reports the mean and standard deviation of those per-seed means for
        each (feature, task) pair. Prints the table and writes it to
        ``1_summary_classification_impact.csv`` in the output directory.
        Only rows with a ``prob_drop`` value are considered.
        """
        print("\n--- Question 1: Summarizing Causal Impact of Ablation ---")

        clf_data = self._classification_rows()
        if clf_data.empty:
            print("No classification results with a 'prob_drop' value found.")
            return

        # Calculate the mean within each seed first so every seed carries the
        # same weight regardless of how many rows it contributed.
        per_seed = (
            clf_data.groupby(self._RUN_KEYS)["prob_drop"]
            .mean()
            .reset_index(name="seed_mean")
        )
        final_summary = (
            per_seed.groupby(["ablated_feature", "task"])["seed_mean"]
            .agg(["mean", "std"])
            .reset_index()
        )
        final_summary.columns = [
            "Ablated Feature",
            "Task",
            "Mean Drop (across seeds)",
            "Std Dev of Mean Drop",
        ]

        print("\nSummary for Classification Tasks (Average Probability Drop):")
        print(final_summary.to_string(index=False, float_format="{:.5f}".format))

        # Save to CSV
        output_path = self.output_dir / "1_summary_classification_impact.csv"
        final_summary.to_csv(output_path, index=False)
        print(f"\nSaved classification summary to {output_path}")

    def plot_impact_distribution(self) -> None:
        """
        Answers Q2: Visualizes the specialist vs. generalist behavior.

        Draws one violin per (task, ablated feature) over the rows whose
        ``prob_drop`` exceeds 1e-4, on a logarithmic y-axis, and saves the
        figure as ``2_impact_distribution.pdf`` in the output directory.
        """
        print("\n--- Question 2: Visualizing Specialist vs. Generalist Behavior ---")

        clf_data_significant = self._classification_rows()
        if not clf_data_significant.empty:
            # We only want to plot meaningful drops, not the vast majority of
            # zeros (which also could not be shown on a log axis).
            clf_data_significant = clf_data_significant[
                clf_data_significant["prob_drop"] > 1e-4
            ].copy()

        if clf_data_significant.empty:
            print("No significant probability drops found to plot.")
            return

        # Clean up feature names for plotting
        clf_data_significant["Ablated Feature"] = self._strip_feature_prefixes(
            clf_data_significant["ablated_feature"]
        )

        plt.figure(figsize=(12, 8))
        ax = sns.violinplot(
            data=clf_data_significant,
            x="task",
            y="prob_drop",
            hue="Ablated Feature",
            inner="quart",
            cut=0,  # Don't extend violins beyond data range
            palette="muted",
            hue_order=sorted(
                clf_data_significant["Ablated Feature"].unique()
            ),  # Consistent color mapping
        )

        ax.set_title(
            "Distribution of Performance Drop After Feature Ablation (Classification Tasks)",
            fontsize=16,
            pad=20,
        )
        ax.set_ylabel("Probability Drop for Correct Class", fontsize=12)
        ax.set_xlabel("TDC Task", fontsize=12)
        ax.legend(title="Ablated Feature", loc="upper right")

        # Use a log scale to better visualize the range of impacts
        ax.set_yscale("log")

        # Rotate and right-align the x tick labels so long task names fit.
        plt.xticks(rotation=30, ha="right")

        plt.tight_layout()

        output_path_pdf = self.output_dir / "2_impact_distribution.pdf"
        plt.savefig(output_path_pdf, bbox_inches="tight")
        plt.close()

        print(f"Saved impact distribution plot to {output_path_pdf}")

    def identify_most_sensitive_molecules(self, top_n: int = 10) -> None:
        """
        Answers Q3: Finds the molecules most affected by each ablation.

        For every (feature, task, seed) run the single row with the largest
        ``prob_drop`` is selected. The selection is sorted by feature, task
        and descending drop, and the first ``top_n * 5`` rows are written to
        ``3_most_sensitive_molecules.csv``; the first 15 are also printed.
        Runs without any ``prob_drop`` value are ignored.

        NOTE(review): because the cut-off is applied after sorting by feature
        name, features late in the ordering can be truncated entirely when
        there are more than ``top_n * 5`` runs.

        Args:
            top_n: Controls how many rows are kept (``top_n * 5``).
        """
        print(f"\n--- Question 3: Identifying Top {top_n} Most Sensitive Molecules ---")

        # Drop rows without prob_drop first: idxmax over an all-NaN group does
        # not produce a usable row label.
        clf_data = self._classification_rows()
        if clf_data.empty:
            print("No classification results with a 'prob_drop' value found.")
            return

        # We want the single most-affected example from each seed to see the diversity
        top_row_labels = clf_data.groupby(self._RUN_KEYS)["prob_drop"].idxmax()
        most_sensitive = clf_data.loc[top_row_labels].sort_values(
            by=["ablated_feature", "task", "prob_drop"], ascending=[True, True, False]
        )

        # Select relevant columns for a clean output
        output_cols = [
            "ablated_feature",
            "task",
            "seed",
            "SMILES",
            "y_true",
            "prob_drop",
            "original_prob_correct",
            "ablated_prob_correct",
        ]
        missing_cols = [c for c in output_cols if c not in most_sensitive.columns]
        if missing_cols:
            # Report rather than crash when the input CSVs lack a column.
            print(f"Warning: columns not present in the results, omitted: {missing_cols}")
        present_cols = [c for c in output_cols if c in most_sensitive.columns]

        final_table = most_sensitive[present_cols].head(
            top_n * 5
        )  # show a good number of examples

        print("\nTop Examples of Molecules Most Impacted by Ablation:")
        print(final_table.head(15).to_string(index=False, float_format="{:.4f}".format))

        output_path = self.output_dir / "3_most_sensitive_molecules.csv"
        final_table.to_csv(output_path, index=False)
        print(f"\nSaved list of most sensitive molecules to {output_path}")

    def run_all_analyses(self) -> None:
        """Executes the full analysis pipeline."""
        self.summarize_overall_impact()
        self.plot_impact_distribution()
        self.identify_most_sensitive_molecules()
        print("\n\nAnalysis complete. PDFs and CSVs saved to the output directory.")


def main(argv=None) -> int:
    """
    Command-line entry point: parse arguments and run every analysis.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        0 when the analysis completed, 1 when it failed, so that shell
        scripts and schedulers can detect the failure.
    """
    parser = argparse.ArgumentParser(
        description="Summarize and analyze feature ablation experiment results."
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="The base directory containing the subdirectories for each ablation run.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="experiments/ablation_summary",
        help="Directory to save the final summary plots and CSVs.",
    )
    args = parser.parse_args(argv)

    try:
        analyzer = AblationAnalyzer(base_dir=args.input_dir, output_dir=args.output_dir)
        analyzer.run_all_analyses()
    except Exception as e:
        # Top-level boundary: report the problem, but signal it through the
        # exit status instead of finishing as if the run had succeeded.
        print(f"\nAn error occurred during analysis: {e}")
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
