#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Generates a suite of advanced, publication-quality plots from fragment scan
and downstream task importance summaries to tell a story about the learned
SAE feature space.
"""

from __future__ import annotations

import argparse
import pickle
import re
import warnings
from pathlib import Path

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import umap.umap_ as umap
from flax.core import unfreeze

# Silence every UserWarning for the whole process. NOTE: this is broader than
# matplotlib font-cache chatter -- no message or module restriction is given,
# so UserWarnings from pandas, seaborn, umap, etc. are hidden too, which can
# mask real plotting problems.
warnings.simplefilter("ignore", category=UserWarning)

# =============================================================================
# === FRAGMENT CATEGORIZATION (CRITICAL USER INPUT) ===========================
# =============================================================================
# This dictionary maps specific fragment names (from your scan) to broader
# chemical concept categories. It is ESSENTIAL for the Hierarchy and
# Geometry plots: fragments missing from the map are dropped from the
# Hierarchy plot and labelled "Unclassified" in the Geometry plot. Review and
# customize this list to match the fragments included in your scan.
# Category labels shared by several fragments.
_SIMPLE_FG = "Simple Functional Group"
_AROMATIC_N_RING = "Aromatic N-Ring (Simple)"
_FUSED_RING = "Fused Ring"
_PRIVILEGED = "Privileged Scaffold"

# Built from ordered (fragment name, category) pairs; insertion order follows
# the grouping below.
FRAGMENT_CATEGORIES = dict(
    [
        # Simple atoms / functional groups
        ("fg:primary_amine", _SIMPLE_FG),
        ("fg:carboxylic_acid", _SIMPLE_FG),
        ("fg:amide", _SIMPLE_FG),
        ("fg:sulfonamide", _SIMPLE_FG),
        ("fg:ketone", _SIMPLE_FG),
        ("fg:ether", _SIMPLE_FG),
        ("fg:halogen_F", "Halogen"),
        ("fg:halogen_Cl", "Halogen"),
        ("fg:nitro", _SIMPLE_FG),
        # Simple rings
        ("ring:benzene", "Aromatic Ring (Simple)"),
        ("ring:pyridine", _AROMATIC_N_RING),
        ("ring:thiophene", "Aromatic S-Ring (Simple)"),
        ("ring:cyclohexane", "Aliphatic Ring"),
        ("ring:piperidine", "Aliphatic N-Ring"),
        # Complex / fused rings and larger scaffolds
        ("ring:indole", _FUSED_RING),
        ("ring:naphthalene", _FUSED_RING),
        ("ring:quinoline", _FUSED_RING),
        # NOTE(review): unlike the entries above, these names carry no
        # "fg:"/"ring:" prefix -- confirm they match the scan's fragment names.
        ("oxadiazole_1_3_4", _AROMATIC_N_RING),
        ("aryl_sulfonamide", _PRIVILEGED),
        ("terminal_acylated_heteroaromatic", _PRIVILEGED),
    ]
)
# =============================================================================


class AdvancedPlotter:
    """Generate a suite of summary plots from fragment-scan analysis outputs.

    The constructor eagerly loads every summary table referenced by
    ``data_paths``; each ``plot_*`` method renders one PNG into ``out_dir``.

    Args:
        out_dir: Directory the figures are written to (created if missing).
        data_paths: Mapping with the keys ``"layer_summary"``,
            ``"sae_summary"``, ``"all_scores"``, ``"downstream_importance"``
            and ``"fragment_ranking"``, each pointing at the matching file.
    """

    def __init__(self, out_dir: str, data_paths: dict):
        self.out_dir = Path(out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.paths = data_paths
        sns.set_theme(style="whitegrid", context="paper")

        print("Loading data sources...")
        try:
            self.df_layer_summary = pd.read_csv(self.paths["layer_summary"])
            self.df_sae_summary = pd.read_csv(self.paths["sae_summary"])
            self.df_all_scores = pd.read_parquet(self.paths["all_scores"])
            self.df_downstream = pd.read_csv(self.paths["downstream_importance"])
            self.df_frag_ranking = pd.read_csv(self.paths["fragment_ranking"])
        except FileNotFoundError as e:
            # Some readers raise FileNotFoundError without a filename; fall
            # back to the exception text so the message is never "None".
            print(f"ERROR: Could not find a required data file: {e.filename or e}")
            print(
                "Please ensure the --summary_dir and --downstream_dir arguments "
                "point to the directories containing the summary files."
            )
            # ``exit`` is injected by the ``site`` module and is not guaranteed
            # to exist; raising SystemExit is the portable equivalent.
            raise SystemExit(1) from e
        print("All data loaded successfully.")

    def _load_sae_weights(self, sae_root_dir: str, sae_tag: str, layer: int) -> np.ndarray:
        """Load the decoder matrix ``W_dec`` from a trained SAE checkpoint.

        The checkpoint is read from
        ``<sae_root_dir>/<sae_tag>/sae_<layer>/checkpoint_final.pkl`` and must
        be a pickled mapping whose ``"params"`` entry holds ``"W_dec"``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
            KeyError: If the checkpoint mapping lacks ``"params"`` or
                ``"W_dec"``.
        """
        sae_path = Path(sae_root_dir) / sae_tag / f"sae_{layer}" / "checkpoint_final.pkl"
        if not sae_path.exists():
            raise FileNotFoundError(f"Could not find SAE weights at: {sae_path}")

        print(f"Loading W_dec from {sae_path}...")
        # SECURITY: unpickling can execute arbitrary code; only point this at
        # checkpoints you produced or otherwise trust.
        with open(sae_path, "rb") as f:
            checkpoint = pickle.load(f)
        params = checkpoint["params"]

        # Params may be a Flax FrozenDict rather than a plain dict; normalise.
        if not isinstance(params, dict):
            params = unfreeze(params)

        return np.asarray(params["W_dec"])

    def _save_figure(self, fig, filename: str) -> Path:
        """Write ``fig`` to ``out_dir/filename`` at 300 dpi, close it, and return the path."""
        save_path = self.out_dir / filename
        try:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)
        print(f"Saved to {save_path}")
        return save_path

    @staticmethod
    def _place_legend_outside(ax, title: str) -> None:
        """Move the seaborn-generated legend on ``ax`` to the right of the plot.

        A bare ``plt.legend()`` rebuilds the legend from labelled artists on
        the axes. ``sns.kdeplot`` builds its hue legend from proxy handles that
        are not added to the axes, so ``plt.legend()`` would replace it with an
        empty one. ``sns.move_legend`` re-creates the legend from the existing
        handles instead.
        """
        if ax.get_legend() is None:
            return
        sns.move_legend(ax, "upper left", bbox_to_anchor=(1.05, 1), title=title)

    @staticmethod
    def _sae_size_key(sae_tag: str) -> tuple[int, int]:
        """Sort key ordering SAE tags by expansion factor (``relu_8x_95aa15`` -> 8).

        The first ``<digits>x`` run in the tag is used. Tags without one sort
        after every parseable tag (stable sort keeps their original order)
        instead of aborting the run with ``ValueError``.
        """
        match = re.search(r"(\d+)x", str(sae_tag))
        if match is None:
            return (1, 0)
        return (0, int(match.group(1)))

    def plot_hierarchy(self):
        """(Plot 1) Shows distribution of peak layers for different fragment complexities.

        For each fragment, the layer with the highest ``wsd_top10_med__median``
        is its peak layer. Peak layers are drawn as per-category KDEs using
        ``FRAGMENT_CATEGORIES``; fragments without a category are dropped.
        """
        print("\n--- Generating Plot 1: Hierarchy Plot ---")

        score_col = "wsd_top10_med__median"
        # A fragment whose scores are all NaN would make idxmax yield a NaN
        # label, which then breaks the .loc lookup; drop missing scores first.
        scored = self.df_layer_summary.dropna(subset=[score_col])
        if scored.empty:
            print("Could not generate Hierarchy Plot: no non-missing WSD scores found.")
            return

        peak_idx = scored.groupby("fragment")[score_col].idxmax()
        peak_layers = scored.loc[peak_idx, ["fragment", "layer"]].rename(
            columns={"layer": "peak_layer"}
        )

        peak_layers["category"] = peak_layers["fragment"].map(FRAGMENT_CATEGORIES)
        peak_layers = peak_layers.dropna(subset=["category"])

        if peak_layers.empty:
            print(
                "Could not generate Hierarchy Plot: No fragments matched the defined categories."
            )
            return

        fig, ax = plt.subplots(figsize=(8, 5))
        sns.kdeplot(
            data=peak_layers,
            x="peak_layer",
            hue="category",
            fill=True,
            common_norm=False,
            palette="viridis",
            alpha=0.5,
            linewidth=2,
            ax=ax,
        )
        ax.set_title(
            "Hierarchy of Chemical Concepts Across Model Layers", fontsize=16, pad=20
        )
        ax.set_xlabel("Model Layer (Peak WSD Score)", fontsize=12)
        ax.set_ylabel("Density of Fragments", fontsize=12)
        ax.set_xticks(np.arange(0, peak_layers["peak_layer"].max() + 1, 1))
        self._place_legend_outside(ax, "Fragment Category")

        self._save_figure(fig, "1_hierarchy_peak_layer_distribution.png")

    def plot_geometry(self, sae_root_dir: str, umap_sae_tag: str, umap_layer: int):
        """(Plot 2) UMAP of feature space, colored by chemical concept.

        Embeds the rows of ``W_dec`` with UMAP (cosine metric) and colours each
        point by the category of the fragment with the highest ``wsd_mean``
        score for that feature. Features with no categorised best fragment are
        labelled ``"Unclassified"``.
        """
        print(
            f"\n--- Generating Plot 2: Geometry Plot (UMAP for L{umap_layer}, {umap_sae_tag}) ---"
        )
        try:
            w_dec = self._load_sae_weights(sae_root_dir, umap_sae_tag, umap_layer)
        except FileNotFoundError as e:
            print(f"Skipping Geometry Plot. {e}")
            return

        # Find the best fragment for each feature in this specific run.
        df_wsd = self.df_all_scores[self.df_all_scores["metric"] == "wsd_mean"]
        run_filter = (df_wsd["sae_tag"] == umap_sae_tag) & (
            df_wsd["layer"] == umap_layer
        )
        # NaN scores cannot be "best" and would make idxmax return NaN labels.
        df_run = df_wsd[run_filter].dropna(subset=["score"])

        if df_run.empty:
            print(
                f"Skipping Geometry Plot: No WSD scores found for {umap_sae_tag} Layer {umap_layer}."
            )
            return

        # .copy(): we add a column below, and the selection derives from a
        # filtered frame (avoids SettingWithCopyWarning / silent no-op writes).
        best_frags = df_run.loc[df_run.groupby("feature_id")["score"].idxmax()].copy()
        best_frags["category"] = best_frags["fragment"].map(FRAGMENT_CATEGORIES)

        # NOTE(review): assumes row i of W_dec is the decoder vector of
        # feature_id i (i.e. W_dec is (n_features, d_model)) -- confirm against
        # the SAE implementation.
        feature_ids = best_frags["feature_id"]
        if pd.api.types.is_numeric_dtype(feature_ids) and feature_ids.max() >= w_dec.shape[0]:
            print(
                f"WARNING: scored feature_id {feature_ids.max()} exceeds the "
                f"{w_dec.shape[0]} rows of W_dec; such features cannot appear in the UMAP."
            )

        reducer = umap.UMAP(
            n_neighbors=15, min_dist=0.1, metric="cosine", random_state=42
        )
        embedding = reducer.fit_transform(w_dec)

        plot_df = pd.DataFrame(embedding, columns=["UMAP 1", "UMAP 2"])
        plot_df["feature_id"] = plot_df.index
        plot_df = plot_df.merge(
            best_frags[["feature_id", "category"]], on="feature_id", how="left"
        )
        plot_df["category"] = plot_df["category"].fillna("Unclassified")

        fig, ax = plt.subplots(figsize=(10, 8))
        sns.scatterplot(
            data=plot_df,
            x="UMAP 1",
            y="UMAP 2",
            hue="category",
            palette="tab20",
            s=10,
            alpha=0.7,
            ax=ax,
        )
        ax.set_title(
            f"UMAP of SAE Feature Space (L{umap_layer}, {umap_sae_tag})",
            fontsize=16,
            pad=20,
        )
        ax.set_xlabel("UMAP Dimension 1", fontsize=12)
        ax.set_ylabel("UMAP Dimension 2", fontsize=12)
        self._place_legend_outside(ax, "Most Selective Fragment Type")

        self._save_figure(fig, f"2_geometry_umap_L{umap_layer}_{umap_sae_tag}.png")

    def plot_scalability(self):
        """(Plot 3) Shows how feature quality (WSD) changes with SAE size.

        Draws violins of ``wsd_top10_med__median`` per SAE tag for a fixed set
        of benchmark fragments, with tags ordered by expansion factor.
        """
        print("\n--- Generating Plot 3: Scalability Plot ---")

        # Select a few representative benchmark fragments.
        benchmark_frags = [
            "ring:benzene",
            "fg:sulfonamide",
            "ring:indole",
            "fg:carboxylic_acid",
        ]
        df_filtered = self.df_sae_summary[
            self.df_sae_summary["fragment"].isin(benchmark_frags)
        ]

        if df_filtered.empty:
            print(
                "Could not generate Scalability Plot: No data for benchmark fragments."
            )
            return

        # Order tags by expansion factor (4x, 8x, 16x, ...).
        sae_order = sorted(df_filtered["sae_tag"].unique(), key=self._sae_size_key)

        fig, ax = plt.subplots(figsize=(10, 6))
        sns.violinplot(
            data=df_filtered,
            x="sae_tag",
            y="wsd_top10_med__median",
            hue="fragment",
            order=sae_order,
            palette="muted",
            cut=0,
            ax=ax,
        )
        ax.set_title("SAE Size vs. Feature Specificity (WSD)", fontsize=16, pad=20)
        ax.set_xlabel("SAE Dictionary Size", fontsize=12)
        ax.set_ylabel("Median WSD of Top 10 Features", fontsize=12)
        plt.setp(ax.get_xticklabels(), rotation=30, ha="right")

        self._save_figure(fig, "3_scalability_sae_size_vs_wsd.png")

    def run_all(self, sae_root_dir, umap_sae_tag, umap_layer):
        """Run every plot in order; individual plots may skip themselves on missing data."""
        self.plot_hierarchy()
        self.plot_geometry(sae_root_dir, umap_sae_tag, umap_layer)
        self.plot_scalability()
        print("\nAll plots generated successfully.")


if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description="Generate advanced summary plots for SAE fragment scan analysis."
    )

    # Required input directories, registered in the same order as before so
    # the --help listing is unchanged.
    for flag, help_text in (
        (
            "--summary_dir",
            "Path to the directory containing summary CSVs (e.g., 'fragment_scan_summaries').",
        ),
        (
            "--downstream_dir",
            "Path to the directory containing downstream task importance summaries.",
        ),
        (
            "--sae_root_dir",
            "Path to the root directory containing the trained SAE models (e.g., 'models/saes').",
        ),
    ):
        cli.add_argument(flag, required=True, type=str, help=help_text)

    # Optional settings with defaults.
    cli.add_argument(
        "--out_dir",
        type=str,
        default="advanced_plots",
        help="Directory to save the generated plots.",
    )
    cli.add_argument(
        "--umap_layer",
        type=int,
        default=4,
        help="Layer to use for the UMAP geometry plot.",
    )
    cli.add_argument(
        "--umap_sae_tag",
        type=str,
        default="relu_8x_95aa15",
        help="SAE tag (size/ID) to use for the UMAP geometry plot.",
    )

    cli_args = cli.parse_args()

    # Each input lives at a fixed filename under one of the two directories.
    summary_root = Path(cli_args.summary_dir)
    downstream_root = Path(cli_args.downstream_dir)
    input_files = {
        "layer_summary": summary_root / "per_fragment_layer_summary.csv",
        "sae_summary": summary_root / "per_fragment_sae_summary.csv",
        "all_scores": summary_root / "all_scores.parquet",
        "fragment_ranking": summary_root / "fragment_difficulty_by_WSD.csv",
        "downstream_importance": downstream_root / "feature_importance_summary.csv",
    }

    AdvancedPlotter(out_dir=cli_args.out_dir, data_paths=input_files).run_all(
        sae_root_dir=cli_args.sae_root_dir,
        umap_sae_tag=cli_args.umap_sae_tag,
        umap_layer=cli_args.umap_layer,
    )
