# causal_feature_mapping.py

import os
import sys
from collections import Counter
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from rdkit import Chem
from scipy.stats import fisher_exact
from tqdm.auto import tqdm

# Ensure lmkit is in the path.
# When the working directory path ends with "notebooks", move up one level
# first (presumably to the repository root -- confirm against the repo layout),
# then put the current working directory at the front of sys.path so that the
# `lmkit` imports below can be resolved relative to it.
if os.getcwd().endswith("notebooks"):
    os.chdir("../")
if "." not in sys.path:
    sys.path.insert(0, ".")

# Suppress RDKit errors for cleaner output.
# The attribution step below feeds RDKit partial SMILES fragments whose
# sanitization is expected to fail, so the "rdApp.*" log channels are disabled
# for the whole process.
from rdkit import RDLogger

RDLogger.DisableLog("rdApp.*")

# --- LMKit Imports ---
from lmkit.impl import caching, hooks as hooks_lib
from lmkit.impl import transformer
from lmkit.impl.hooks import HookRequest, capture
from lmkit.sparse import sae as sae_lib
from lmkit.sparse import utils as sae_utils
from lmkit.sparse.atlas_utils import NeuronSelector, load_collector
from lmkit.sparse.sae import SAEKit, normalize
from lmkit.tools import stem, train_utils

# --- Configuration ---
# Language-model directory; passed to SAEKit.load(model_dir=...).
MODEL_DIR = "models/transformer_sm"
# Checkpoint identifier; passed to SAEKit.load(checkpoint_id=...).
CKPT_ID = 59712
# SAE run name; used to build SAE_DIR and the collector pickle file name.
SAE_NAME = "relu_4x_e9a211"
# SAE directory; passed to SAEKit.load(sae_dir=...).
SAE_DIR = f"models/saes/{SAE_NAME}"
# Directory holding the collector pickles ("<SAE_NAME>_layer<N>.pkl").
COLLECTOR_DIR = "dump/atlas"
# Layer whose SAE features are attributed and steered.
LAYER_ID = 4
# Number of features requested from NeuronSelector.pick(topk=...).
TOP_N_FEATURES = 128
# Number of top-activating sequences attributed per feature (top_k_samples).
ATTRIBUTION_SAMPLES = 10
# Steering strength: multiplied by collector.std[feature_id] in make_steering_vector.
STEERING_ALPHA = 4.0
# Number of molecules sampled per generation run (num_samples).
GENERATION_SAMPLES = 2048



def _fragment_smarts_from_importance(tokens, token_importance):
    """Convert per-token importance scores into a SMARTS string, or None.

    Tokens whose score exceeds ``mean + 1.5 * std`` are concatenated in sequence
    order (they need not be adjacent in the original string), parsed as SMILES
    without sanitization, sanitized, and written out with ``Chem.MolToSmarts``.
    Returns None when the scores contain NaN or have zero spread, when no token
    clears the threshold, or when RDKit cannot parse/sanitize the fragment.
    """
    spread = np.std(token_importance)
    if np.any(np.isnan(token_importance)) or spread == 0:
        return None

    threshold = np.mean(token_importance) + 1.5 * spread
    important_indices = np.where(token_importance > threshold)[0]
    if len(important_indices) == 0:
        return None

    fragment_smiles = "".join(tokens[i] for i in important_indices)
    mol_frag = Chem.MolFromSmiles(fragment_smiles, sanitize=False)
    if not mol_frag:
        return None

    try:
        Chem.SanitizeMol(mol_frag)
        return Chem.MolToSmarts(mol_frag)
    except Exception:
        # RDKit sanitization can fail on partial fragments; this is expected,
        # so the fragment is dropped rather than aborting the whole feature.
        return None


def run_attribution_for_feature(
    sae_kit: SAEKit,
    collector,
    layer_id: int,
    feature_id: int,
    top_k_samples: int = 10,
) -> "str | None":
    """
    Derive a SMARTS hypothesis for one SAE feature from its top-activating sequences.

    For each of the first ``top_k_samples`` entries of
    ``collector.top_sequences_for(feature_id)`` the function:

    1. maps the tokens to ids, dropping tokens the tokenizer does not know;
    2. runs the language model, captures the activation at
       ``(layer_id, sae_config.placement)`` and finds the position where the
       feature's SAE activation is largest;
    3. backpropagates that single activation through the SAE and then through
       the language model to obtain the gradient w.r.t. ``embed_table``;
    4. scores every token by the L2 norm of its embedding row's gradient and
       turns the high-scoring tokens into a SMARTS string.

    NOTE: the scores come from rows of the embedding *table* gradient, so every
    occurrence of the same token id in a sequence receives the same score; the
    attribution is per token id, not strictly per position.

    Args:
        sae_kit: Bundle providing ``sae_params``, ``sae_configs``, ``tokenizer``,
            ``run_fn``, ``lm_params`` and ``lm_config``.
        collector: Object exposing ``top_sequences_for(feature_id)``, which yields
            3-tuples whose last element is the token sequence.
        layer_id: Layer whose SAE is analysed.
        feature_id: Index of the SAE feature.
        top_k_samples: Maximum number of top sequences to attribute.

    Returns:
        The most frequently extracted SMARTS string, or ``None`` when there are
        no top sequences or no fragment could be extracted.
    """
    top_sequences = collector.top_sequences_for(feature_id)[:top_k_samples]
    if not top_sequences:
        return None

    sae_params = sae_kit.sae_params[layer_id]
    sae_config = sae_kit.sae_configs[layer_id]
    token_to_id = sae_kit.tokenizer.token_to_id

    def sae_forward_scalar(residual_vec):
        """Takes one residual vector and returns one scalar SAE feature activation."""
        _, acts = sae_lib.run(
            residual_vec[None, None, :], sae_params, sae_config, return_act=True
        )
        return acts[0, 0, feature_id]

    extracted_fragments = []

    for _, _, tokens in top_sequences:
        if not tokens:
            continue

        # Look every token up once and keep only those the tokenizer knows.
        looked_up = [(tok, token_to_id(tok)) for tok in tokens]
        known = [(tok, tok_id) for tok, tok_id in looked_up if tok_id is not None]
        if not known:
            continue

        valid_tokens = [tok for tok, _ in known]
        ids = jnp.array([[tok_id for _, tok_id in known]])
        positions = jnp.arange(ids.shape[1])[None, :]

        # --- Causal Attribution via VJP Backpropagation ---

        # 1. Forward pass to get the activation the SAE reads from.
        specific_hook, _ = capture(
            HookRequest(layer=layer_id, kind=sae_config.placement)
        )
        residuals = sae_utils.run_and_capture(
            sae_kit.run_fn,
            ids,
            positions,
            sae_kit.lm_params,
            sae_kit.lm_config,
            specific_hook,
        )
        sae_input = residuals[(sae_config.layer_id, sae_config.placement)]

        # 2. Find the token position with the max activation for this feature.
        _, sae_acts = sae_lib.run(sae_input, sae_params, sae_config, return_act=True)
        target_token_idx = jnp.argmax(sae_acts[0, :, feature_id])
        sae_input_at_target_pos = sae_input[0, target_token_idx, :]

        # 3. Backward pass through the SAE at that position. The incoming
        #    gradient for a scalar output is 1.0.
        _, vjp_sae_fn = jax.vjp(sae_forward_scalar, sae_input_at_target_pos)
        grad_wrt_residual_vec = vjp_sae_fn(1.0)[0]

        # 4. Transformer forward pass as a function of the parameters only, so
        #    that it can be differentiated w.r.t. them. It closes over this
        #    sequence's ids/positions and therefore lives inside the loop.
        def transformer_forward_residuals(lm_params_for_grad):
            """Runs the transformer and captures the target residual stream."""
            hooks_to_return, _ = capture(
                HookRequest(layer=layer_id, kind=sae_config.placement)
            )
            cache = caching.TransformerCache.create(
                positions, sae_kit.lm_config, dtype=jnp.bfloat16, dynamic=False
            )
            _, _, captured = sae_kit.run_fn(
                ids,
                cache,
                lm_params_for_grad,
                sae_kit.lm_config,
                hooks_to_return=hooks_to_return,
            )
            return captured[0]

        # 5. Compute the VJP of the transformer pass.
        _, vjp_transformer_fn = jax.vjp(
            transformer_forward_residuals, sae_kit.lm_params
        )

        # 6. Cotangent for the transformer's backward pass: zero everywhere
        #    except at the target position, where it is the SAE gradient.
        cotangent_for_residuals = (
            jnp.zeros_like(sae_input)
            .at[0, target_token_idx, :]
            .set(grad_wrt_residual_vec)
        )

        # 7. Backward pass; returns gradients for all parameters.
        param_grads = vjp_transformer_fn(cotangent_for_residuals)[0]

        # 8. Gather the embedding-table gradient rows for the input token ids.
        #    Repeated token ids gather the same row (see the docstring note).
        grads = jnp.take(param_grads["embed_table"], ids, axis=0)

        # --- End: Causal Attribution ---

        # Per-token importance is the L2 norm of the gathered gradient row.
        token_importance = np.linalg.norm(np.asarray(grads[0]), axis=-1)

        fragment = _fragment_smarts_from_importance(valid_tokens, token_importance)
        if fragment is not None:
            extracted_fragments.append(fragment)

    if not extracted_fragments:
        return None

    # Return the most commonly extracted fragment for this feature.
    return Counter(extracted_fragments).most_common(1)[0][0]

# --- STEP 2: Generative Validation ---

def make_steering_vector(sae_kit, layer_id, feature_id, alpha_std, collector=None):
    """Build the steering vector for one SAE feature.

    The vector is ``W_dec[:, feature_id]`` of the layer's SAE, scaled by
    ``alpha_std * collector.std[feature_id]``.

    NOTE(review): run_steering_experiment decodes with ``act @ W_dec``, i.e. the
    first axis of ``W_dec`` is matched against the SAE latent axis. Under that
    layout ``[:, feature_id]`` indexes the second axis rather than selecting the
    decoder row of ``feature_id`` -- confirm which axis is intended.

    Args:
        sae_kit: Bundle providing ``sae_params`` keyed by layer id.
        layer_id: Layer whose SAE decoder is used.
        feature_id: Index of the SAE feature to steer.
        alpha_std: Steering strength as a multiple of ``collector.std[feature_id]``.
        collector: Optional pre-loaded collector exposing ``std``. When omitted,
            the collector pickle for ``layer_id`` is loaded from COLLECTOR_DIR on
            every call; pass one in to avoid re-reading it for each feature.

    Returns:
        The scaled direction vector.
    """
    direction = sae_kit.sae_params[layer_id]["W_dec"][:, feature_id]
    if collector is None:
        collector = load_collector(f"{COLLECTOR_DIR}/{SAE_NAME}_layer{layer_id}.pkl")
    magnitude = alpha_std * collector.std[feature_id]
    return direction * magnitude


def _clean_and_parse(raw_molecules):
    """Clean generated strings and parse the survivors with RDKit.

    Returns ``(n_valid, mols)``: the number of strings for which
    ``stem.clean_smiles`` returned something truthy, and the RDKit molecules
    that could be parsed from them. ``n_valid`` is the frequency denominator.
    """
    smiles = [s for s in (stem.clean_smiles(m) for m in raw_molecules) if s]
    mols = [mol for mol in (Chem.MolFromSmiles(s) for s in smiles) if mol]
    return len(smiles), mols


def _count_substructure_hits(mols, query_mol):
    """Count the molecules in ``mols`` that contain ``query_mol`` as a substructure."""
    return sum(1 for mol in mols if mol.HasSubstructMatch(query_mol))


def _make_latent_steering_callback(sae_kit, layer_id, steering_vector, alpha=1.0):
    """Build an activation-edit callback that steers through the layer's SAE.

    The callback optionally normalizes the activation, optionally subtracts
    ``b_dec``, multiplies by ``W_enc``, adds ``alpha * steering_vector`` to the
    resulting latents, applies the SAE activation function, decodes with
    ``W_dec``/``b_dec`` and undoes the normalization. The hooked activation is
    thus replaced by the SAE reconstruction of the shifted latents.

    NOTE(review): ``steering_vector`` is added in SAE latent space, so it must
    broadcast against the output of ``x @ W_enc``; make_steering_vector
    describes its result as a residual-stream vector -- confirm the intended
    space.
    """
    cfg = sae_kit.sae_configs[layer_id]
    prm = sae_kit.sae_params[layer_id]

    def _cb(x):
        if cfg.rescale_inputs:
            x_norm, x_mean, x_std = sae_lib.normalize(x)
        else:
            x_norm, x_mean, x_std = x, 0.0, 1.0

        x_enc = x_norm - prm["b_dec"] if cfg.pre_enc_bias else x_norm

        # Encode, then shift the latents by the scaled steering vector.
        z_added = x_enc @ prm["W_enc"] + alpha * steering_vector

        # Decode back to the activation space and undo the normalization.
        act = cfg.act_fn(z_added, cfg)
        rec_scaled = act @ prm["W_dec"] + prm["b_dec"]
        x_added = sae_lib.rescale(rec_scaled, x_mean, x_std).astype(z_added.dtype)
        return x_added.astype(x.dtype)

    return _cb


def run_steering_experiment(sae_kit, hypotheses):
    """
    Generates a baseline set and a steered set for each hypothesis, then
    compares the frequency of the target substructure.

    Both the baseline and every steered run sample with the same model config
    (the LM config with the tokenizer's bos/eos/pad ids filled in) and the same
    sampling settings, so the two sets differ only in the activation edit and
    the random key.

    Args:
        sae_kit: Bundle providing the tokenizer, LM params/config and SAE
            params/configs.
        hypotheses: Mapping ``feature_id -> SMARTS`` (``None`` values are skipped).

    Returns:
        A ``pandas.DataFrame`` with one row per validated feature: layer, feature
        id, SMARTS, baseline/steered frequency, enrichment factor and the
        one-sided Fisher exact p-value. The enrichment factor is ``inf`` when
        only the steered set contains the substructure and ``nan`` when neither
        set does.
    """
    print("\n--- Running Generative Validation ---")

    lm_config_mod = sae_kit.lm_config.copy(
        dict(
            bos_id=sae_kit.tokenizer.bos_token_id,
            eos_id=sae_kit.tokenizer.eos_token_id,
            pad_id=sae_kit.tokenizer.pad_token_id,
        )
    )

    def generate(run_fn, rng_key, verbose):
        """Sample GENERATION_SAMPLES molecules with the shared settings."""
        _, molecules = train_utils.get_eval_fn(
            tokenizer=sae_kit.tokenizer,
            model_config=lm_config_mod,
            num_samples=GENERATION_SAMPLES,
            batch_size=512,
            metrics_fn=stem.molstats,
            max_tokens=256,
            temp=0.7,
            log_metrics=False,
            return_results=True,
            run_fn=run_fn,
            verbose=verbose,
        )(rng_key, step=-1, params=sae_kit.lm_params)
        return molecules

    print(f"Generating {GENERATION_SAMPLES} baseline molecules...")
    baseline_raw = generate(transformer.run, jax.random.PRNGKey(42), verbose=True)

    # Parse the baseline once; it is re-queried for every hypothesis below.
    n_baseline, baseline_mols = _clean_and_parse(baseline_raw)
    print(f"Generated {n_baseline} valid baseline molecules.")

    results = []

    for i, (feature_id, smarts) in enumerate(sorted(hypotheses.items())):
        if smarts is None:
            continue

        print(
            f"\n({i + 1}/{len(hypotheses)}) Validating Feature {feature_id} -> SMARTS '{smarts}'"
        )

        query_mol = Chem.MolFromSmarts(smarts)
        if not query_mol:
            print(f"  Skipping invalid SMARTS: {smarts}")
            continue

        baseline_hits = _count_substructure_hits(baseline_mols, query_mol)
        baseline_freq = baseline_hits / n_baseline if n_baseline else 0

        steering_vector = make_steering_vector(
            sae_kit, LAYER_ID, feature_id, STEERING_ALPHA
        )
        callback = _make_latent_steering_callback(
            sae_kit, LAYER_ID, steering_vector, alpha=1.0
        )
        editor = hooks_lib.ActivationEditor(
            edits=(
                hooks_lib.Edit(
                    layer=LAYER_ID,
                    kind=sae_kit.sae_configs[LAYER_ID].placement,
                    op="call",
                    callback=callback,
                ),
            )
        )

        steered_raw = generate(
            partial(transformer.run, editor=editor),
            jax.random.PRNGKey(feature_id),
            verbose=False,
        )
        n_steered, steered_mols = _clean_and_parse(steered_raw)

        steered_hits = _count_substructure_hits(steered_mols, query_mol)
        steered_freq = steered_hits / n_steered if n_steered else 0

        if baseline_freq > 0:
            enrichment = steered_freq / baseline_freq
        elif steered_freq > 0:
            enrichment = float("inf")
        else:
            # 0/0: the substructure appears in neither set, so there is no
            # evidence of enrichment to report.
            enrichment = float("nan")

        table = [
            [steered_hits, n_steered - steered_hits],
            [baseline_hits, n_baseline - baseline_hits],
        ]
        # One-sided test: is the substructure more frequent in the steered set?
        _, p_value = fisher_exact(table, alternative="greater")

        results.append(
            {
                "Layer": LAYER_ID,
                "Feature ID": feature_id,
                "Hypothesized SMARTS": smarts,
                "Baseline Freq": baseline_freq,
                "Steered Freq": steered_freq,
                "Enrichment Factor": enrichment,
                "p-value": p_value,
            }
        )
        print(
            f"  Result: Baseline Freq={baseline_freq:.4f}, Steered Freq={steered_freq:.4f}, Enrichment={enrichment:.2f}x (p={p_value:.2e})"
        )

    return pd.DataFrame(results)


# --- Main Execution Logic ---


def main():
    """Load the model artefacts and derive SMARTS hypotheses for the top SAE features.

    Step 1 ranks features with ``NeuronSelector`` (metric ``"mean*max"``), runs
    ``run_attribution_for_feature`` on each and prints every hypothesis found.
    Step 2 (the steering experiment and report) is currently disabled; its code
    is kept below as comments.
    """
    print("--- Loading Models and Data ---")
    sae_kit = SAEKit.load(model_dir=MODEL_DIR, checkpoint_id=CKPT_ID, sae_dir=SAE_DIR)
    collector_path = f"{COLLECTOR_DIR}/{SAE_NAME}_layer{LAYER_ID}.pkl"
    atlas = load_collector(collector_path)
    print(f"Loaded SAE Kit '{SAE_NAME}' and Collector for Layer {LAYER_ID}.")

    print(
        f"\n--- Step 1: Generating Hypotheses via Causal Attribution for Layer {LAYER_ID} ---"
    )
    ranked_ids = NeuronSelector(atlas).pick(metric="mean*max", topk=TOP_N_FEATURES)

    hypotheses = {}
    for raw_id in tqdm(ranked_ids, desc="Attributing Features"):
        fid = int(raw_id)
        smarts = run_attribution_for_feature(
            sae_kit,
            atlas,
            LAYER_ID,
            fid,
            top_k_samples=ATTRIBUTION_SAMPLES,
        )
        # Features without an extractable fragment are left out of the mapping.
        if not smarts:
            continue
        hypotheses[fid] = smarts
        print(f"  Feature {raw_id} -> {smarts}")

    if not hypotheses:
        print("Could not generate any hypotheses. Exiting.")
        return

    # --- Step 2 (disabled): generative validation and final report ---

    # results_df = run_steering_experiment(sae_kit, hypotheses)

    # if results_df is None or results_df.empty:
    #     print("No results from steering experiment. Exiting.")
    #     return

    # print("\n\n" + "=" * 80)
    # print(" " * 20 + "FINAL CAUSAL ANALYSIS REPORT")
    # print("=" * 80)
    # print(f"Layer: {LAYER_ID}, SAE: {SAE_NAME}, Top {len(hypotheses)} Features Tested")
    # print("-" * 80)

    # sorted_results = results_df.sort_values("Enrichment Factor", ascending=False)

    # pd.set_option("display.max_rows", 200)
    # pd.set_option("display.width", 120)
    # print(sorted_results.head(20).to_string(index=False, float_format="%.4f"))
    # print("=" * 80)

    # output_file = f"causal_mapping_L{LAYER_ID}_{SAE_NAME}.csv"
    # sorted_results.to_csv(output_file, index=False)
    # print(f"Full results saved to {output_file}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
