# experiments/analyze_sae_importance.py
from __future__ import annotations
import argparse
import os
import joblib
from typing import Dict, Any, List, Tuple

import numpy as np
import pandas as pd
import xgboost as xgb
from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder

# --- TDC loaders ---
from tdc.single_pred import ADME, Tox, HTS

# --- Your project's imports ---
from lmkit.sparse.sae import SAEKit
from lmkit.sparse import utils as sae_utils


# ------------------------------
# Utilities (from tdc_benchmarks.py)
# ------------------------------
def pad_and_stack(
    seqs: List[List[int]] | List[np.ndarray], pad_value: int
) -> np.ndarray:
    """Right-pad variable-length integer sequences into one int32 matrix.

    Args:
        seqs: Sequences of integers (lists or 1-D arrays).
        pad_value: Value written into the cells past the end of each sequence.

    Returns:
        An int32 array with one row per sequence and as many columns as the
        longest sequence; a ``(0, 0)`` array when ``seqs`` is empty.
    """
    if not seqs:
        return np.zeros((0, 0), dtype=np.int32)

    width = max(map(len, seqs))
    padded = np.full((len(seqs), width), pad_value, dtype=np.int32)
    for row_idx, seq in enumerate(seqs):
        values = np.asarray(seq, dtype=np.int32)
        # Only the leading cells are overwritten; the tail keeps pad_value.
        padded[row_idx, : values.shape[0]] = values
    return padded


def _find_smiles_column(df: pd.DataFrame) -> str:
    """Return the name of the SMILES column in ``df``.

    Candidate names are checked in priority order: "Drug", "SMILES",
    "smiles", "X".

    Raises:
        KeyError: If none of the candidate names is a column of ``df``.
    """
    known_names = ("Drug", "SMILES", "smiles", "X")
    found = next((name for name in known_names if name in df.columns), None)
    if found is None:
        raise KeyError(f"Could not find a SMILES column among: {list(df.columns)}")
    return found


def generate_ecfp_features(smiles_list: List[str], radius=2, n_bits=2048) -> np.ndarray:
    """Compute Morgan (ECFP-style) bit fingerprints for a list of SMILES.

    Args:
        smiles_list: SMILES strings, one per molecule.
        radius: Morgan fingerprint radius passed to RDKit.
        n_bits: Length of each fingerprint bit vector.

    Returns:
        A float32 array of shape ``(len(smiles_list), n_bits)`` holding 0/1
        values. A SMILES that RDKit cannot parse keeps an all-zero row.
    """
    fingerprints = np.zeros((len(smiles_list), n_bits), dtype=np.float32)
    for row_idx, smiles in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            bitvect = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
            # The bit string is a run of "0"/"1" characters; view it one byte
            # at a time and compare against b"1" to get a boolean mask.
            chars = np.frombuffer(bitvect.ToBitString().encode(), "S1")
            fingerprints[row_idx, :] = (chars == b"1").astype(np.float32, copy=False)
    return fingerprints


def generate_transformer_and_sae_features_for_layer(
    smiles_list: List[str], sae_kit: SAEKit, sae_layer: int, batch_size: int = 64
) -> Tuple[np.ndarray, np.ndarray]:
    """Pool per-token residuals and SAE activations into per-molecule features.

    SMILES are tokenised and processed in batches of ``batch_size``. Padded
    token slots carry position ``-1`` and are masked out before pooling.

    Args:
        smiles_list: SMILES strings to featurise.
        sae_kit: Loaded LM + SAE bundle providing the tokenizer, run function,
            parameters, hooks and SAE configs.
        sae_layer: Key into ``sae_kit.sae_configs`` selecting the SAE to use.
        batch_size: Number of SMILES per forward pass.

    Returns:
        ``(dense, sae)`` float32 arrays stacked over all batches: ``dense`` is
        the masked mean of the captured residual, ``sae`` the max of the
        masked SAE activations over the token axis.
    """
    import jax.numpy as jnp

    dense_chunks, sae_chunks = [], []
    for start in range(0, len(smiles_list), batch_size):
        encodings = [
            sae_kit.tokenizer.encode(smiles)
            for smiles in smiles_list[start : start + batch_size]
        ]
        token_rows = [enc.ids for enc in encodings]
        position_rows = [np.arange(len(enc.ids), dtype=np.int32) for enc in encodings]
        inputs = jnp.asarray(
            pad_and_stack(token_rows, sae_kit.tokenizer.pad_token_id), dtype=jnp.int32
        )
        # -1 marks padding so the mask below can be derived from positions.
        positions = jnp.asarray(pad_and_stack(position_rows, -1), dtype=jnp.int32)

        captured = sae_utils.run_and_capture(
            sae_kit.run_fn,
            inputs,
            positions,
            sae_kit.lm_params,
            sae_kit.lm_config,
            sae_kit.hooks,
        )
        layer_cfg = sae_kit.sae_configs[sae_layer]
        resid = captured[(layer_cfg.layer_id, layer_cfg.placement)]

        sae_acts = sae_kit.get_encoded(inputs, positions, layer_id=sae_layer)

        # assumes resid and sae_acts are (batch, seq, dim) — TODO confirm
        valid = (positions >= 0).astype(jnp.float32)
        token_weights = valid[..., None]
        n_tokens = valid.sum(axis=-1, keepdims=True)
        # Mean over real tokens; the clamp avoids dividing by zero for an
        # empty sequence.
        dense_pool = (resid * token_weights).sum(axis=1) / jnp.maximum(n_tokens, 1.0)
        # Padded slots are zeroed before the max, so the pooled value is never
        # below 0 — presumably fine because SAE activations are non-negative;
        # verify against the SAE definition.
        sae_pool = (sae_acts * token_weights).max(axis=1)

        dense_chunks.append(np.asarray(dense_pool, dtype=np.float32))
        sae_chunks.append(np.asarray(sae_pool, dtype=np.float32))
    return np.vstack(dense_chunks), np.vstack(sae_chunks)


def concat_sae_across_layers(
    smiles: List[str], sae_kit: SAEKit, layers: List[int], batch_size=64
):
    """Featurise ``smiles`` with several SAE layers and concatenate the results.

    Args:
        smiles: SMILES strings to featurise.
        sae_kit: Loaded LM + SAE bundle.
        layers: SAE layer keys, in the order their features are concatenated.
        batch_size: Number of SMILES per forward pass.

    Returns:
        ``(dense_first, sae_concat)``: the dense features produced while
        processing the first entry of ``layers``, and the SAE features of all
        layers concatenated along axis 1 in ``layers`` order.

    Raises:
        ValueError: If ``layers`` is empty.
    """
    if len(layers) == 0:
        # Fail with a clear message instead of numpy's generic
        # "need at least one array to concatenate".
        raise ValueError("concat_sae_across_layers requires at least one SAE layer")

    dense_first = None
    sae_blocks = []
    for layer in layers:
        dense, sae = generate_transformer_and_sae_features_for_layer(
            smiles, sae_kit, layer, batch_size=batch_size
        )
        if dense_first is None:
            dense_first = dense
        sae_blocks.append(sae)
    return dense_first, np.concatenate(sae_blocks, axis=1)


def infer_task_type_from_y(y: np.ndarray) -> str:
    """Guess whether labels describe a classification or a regression task.

    Missing values are dropped first. The rules are:

    * no labels left -> "regression";
    * integer labels -> "classification" when there are at most 10 distinct
      values, otherwise "regression";
    * float labels -> "classification" only when there are at most 10
      distinct values and all of them round (6 decimals) to 0.0 or 1.0;
    * any other dtype -> "classification".
    """
    labels = pd.Series(y).dropna()
    if labels.empty:
        return "regression"

    distinct = np.unique(labels.values)
    few_values = distinct.size <= 10
    kind = labels.dtype.kind

    if kind in ("i", "u"):
        return "classification" if few_values else "regression"
    if kind != "f":
        return "classification"

    looks_binary = few_values and set(np.unique(np.round(distinct, 6))).issubset(
        {0.0, 1.0}
    )
    return "classification" if looks_binary else "regression"


def prepare_labels_for_classification(y_train, y_valid, y_test):
    """Encode the labels of all three splits with one shared ``LabelEncoder``.

    Labels are converted to strings and the encoder is fitted on the
    concatenation of train, valid and test so every split uses the same
    class-to-integer mapping.

    Returns:
        ``(y_train_enc, y_valid_enc, y_test_enc, classes)`` where ``classes``
        is the encoder's ``classes_`` array.
    """
    combined = pd.concat(
        [pd.Series(part) for part in (y_train, y_valid, y_test)], axis=0
    )
    encoder = LabelEncoder()
    encoded = encoder.fit_transform(combined.astype(str))

    # Split the encoded vector back at the original split boundaries.
    train_end = len(y_train)
    valid_end = train_end + len(y_valid)
    return (
        encoded[:train_end],
        encoded[train_end:valid_end],
        encoded[valid_end:],
        encoder.classes_,
    )


def build_feature_sets(
    tr_smiles,
    va_smiles,
    te_smiles,
    sae_kit,
    sae_layers: List[int],
    pca_dims: int,
    ecfp_bits: int,
    batch_size: int,
):
    """Build the SAE-based feature matrices for the train/valid/test splits.

    Args:
        tr_smiles, va_smiles, te_smiles: SMILES lists for the three splits.
        sae_kit: Loaded LM + SAE bundle.
        sae_layers: SAE layers whose features are concatenated.
        pca_dims: Accepted for signature compatibility; not used here.
        ecfp_bits: Fingerprint length for the ECFP part of the hybrid set.
        batch_size: Batch size for SAE feature extraction.

    Returns:
        A dict mapping feature-set name to a ``(train, valid, test)`` tuple of
        arrays: "SAE Features" (SAE only) and "SAE ⊕ ECFP" (SAE features
        followed by ECFP bits along axis 1).
    """
    splits = (tr_smiles, va_smiles, te_smiles)

    ecfp_by_split = [
        generate_ecfp_features(split_smiles, n_bits=ecfp_bits)
        for split_smiles in splits
    ]
    # The dense LM features (first element of the pair) are not needed for
    # this analysis, only the SAE activations.
    sae_by_split = [
        concat_sae_across_layers(
            split_smiles, sae_kit, sae_layers, batch_size=batch_size
        )[1]
        for split_smiles in splits
    ]
    hybrid_by_split = [
        np.concatenate([sae, ecfp], axis=1)
        for sae, ecfp in zip(sae_by_split, ecfp_by_split)
    ]

    # For analysis, we only focus on SAE-based feature sets
    return {
        "SAE Features": tuple(sae_by_split),
        "SAE ⊕ ECFP": tuple(hybrid_by_split),
    }


# ------------------------------
# XGBoost (from tdc_benchmarks.py)
# ------------------------------
# Hyper-parameters shared by the classifier and regressor defaults.
_XGB_COMMON_PARAMS: Dict[str, Any] = dict(
    n_estimators=2000,
    max_depth=6,
    learning_rate=0.05,
    subsample=0.8,
    colsample_bytree=0.8,
    reg_lambda=1.0,
    reg_alpha=0.0,
    random_state=42,
    tree_method="hist",
    n_jobs=-1,
)


def _fit_with_early_stopping(
    model, X_tr, y_tr, X_va, y_va, *, metric_name, early_stopping_rounds
):
    """Fit ``model`` with validation-based early stopping, best effort.

    The xgboost scikit-learn API has exposed early stopping in different
    places over time, so several call styles are tried in order:

    1. an ``EarlyStopping`` callback passed to ``fit`` (keeps the best model);
    2. the ``early_stopping_rounds`` keyword of ``fit``;
    3. ``early_stopping_rounds`` set as an estimator parameter;
    4. a plain ``fit`` without early stopping (a warning is emitted).

    Args:
        model: An unfitted xgboost scikit-learn estimator.
        X_tr, y_tr: Training data.
        X_va, y_va: Validation data used as the single evaluation set.
        metric_name: Metric monitored by the ``EarlyStopping`` callback.
        early_stopping_rounds: Patience, in boosting rounds.

    Returns:
        The fitted ``model`` (same object).
    """
    import warnings

    eval_set = [(X_va, y_va)]

    # 1) Callback passed to fit().
    try:
        from xgboost.callback import EarlyStopping

        callbacks = [
            EarlyStopping(
                rounds=early_stopping_rounds,
                metric_name=metric_name,
                data_name="validation_0",
                save_best=True,
            )
        ]
        model.fit(X_tr, y_tr, eval_set=eval_set, callbacks=callbacks, verbose=False)
        return model
    except (ImportError, TypeError):
        # This call style is not offered by the installed xgboost; move on.
        pass
    except Exception as err:
        # Kept best-effort on purpose, but no longer silent: an unexpected
        # failure here is worth seeing even though the next style is tried.
        warnings.warn(
            f"Callback-based early stopping failed ({err!r}); "
            "retrying with another xgboost API.",
            RuntimeWarning,
            stacklevel=2,
        )

    # 2) early_stopping_rounds keyword of fit().
    try:
        model.fit(
            X_tr,
            y_tr,
            eval_set=eval_set,
            early_stopping_rounds=early_stopping_rounds,
            verbose=False,
        )
        return model
    except TypeError:
        pass

    # 3) early_stopping_rounds as an estimator parameter.
    try:
        model.set_params(early_stopping_rounds=early_stopping_rounds)
    except (TypeError, ValueError):
        param_supported = False
    else:
        param_supported = True
    if param_supported:
        try:
            model.fit(X_tr, y_tr, eval_set=eval_set, verbose=False)
            return model
        except TypeError:
            pass
        finally:
            # Clear the parameter again so an estimator rebuilt from
            # get_params() can be fitted without an evaluation set.
            model.set_params(early_stopping_rounds=None)

    # 4) No early stopping available: train every boosting round.
    warnings.warn(
        "Early stopping is not available through the installed xgboost API; "
        "fitting all boosting rounds.",
        RuntimeWarning,
        stacklevel=2,
    )
    model.fit(X_tr, y_tr, eval_set=eval_set, verbose=False)
    return model


def _fit_xgb_classifier(
    X_tr,
    y_tr,
    X_va,
    y_va,
    *,
    params: Dict[str, Any] | None = None,
    early_stopping_rounds=50,
) -> xgb.XGBClassifier:
    """Train an XGBoost classifier with early stopping on the validation set.

    The objective and metric are chosen from the number of distinct training
    labels: binary logistic + AUC for two classes, softprob + mlogloss
    otherwise. For binary tasks ``scale_pos_weight`` is set to
    ``#negatives / #positives`` (labels 0 and 1), or 1.0 without positives.

    Args:
        X_tr, y_tr: Training features and integer-encoded labels.
        X_va, y_va: Validation features and labels.
        params: Optional overrides merged over the default hyper-parameters.
        early_stopping_rounds: Patience, in boosting rounds.

    Returns:
        The fitted classifier.
    """
    n_classes = int(len(np.unique(y_tr)))
    is_binary = n_classes == 2

    spw = 1.0
    if is_binary:
        pos = (y_tr == 1).sum()
        neg = (y_tr == 0).sum()
        spw = float(neg) / pos if pos > 0 else 1.0

    settings = dict(
        _XGB_COMMON_PARAMS,
        objective="binary:logistic" if is_binary else "multi:softprob",
        eval_metric="auc" if is_binary else "mlogloss",
        scale_pos_weight=spw,
    )
    if params:
        settings.update(params)

    return _fit_with_early_stopping(
        xgb.XGBClassifier(**settings),
        X_tr,
        y_tr,
        X_va,
        y_va,
        metric_name=settings["eval_metric"],
        early_stopping_rounds=early_stopping_rounds,
    )


def _fit_xgb_regressor(
    X_tr,
    y_tr,
    X_va,
    y_va,
    *,
    params: Dict[str, Any] | None = None,
    early_stopping_rounds=50,
) -> xgb.XGBRegressor:
    """Train an XGBoost regressor with early stopping on the validation set.

    Uses squared-error objective and RMSE as the monitored metric unless
    overridden through ``params``.

    Args:
        X_tr, y_tr: Training features and targets.
        X_va, y_va: Validation features and targets.
        params: Optional overrides merged over the default hyper-parameters.
        early_stopping_rounds: Patience, in boosting rounds.

    Returns:
        The fitted regressor.
    """
    settings = dict(
        _XGB_COMMON_PARAMS,
        objective="reg:squarederror",
        eval_metric="rmse",
    )
    if params:
        settings.update(params)

    return _fit_with_early_stopping(
        xgb.XGBRegressor(**settings),
        X_tr,
        y_tr,
        X_va,
        y_va,
        metric_name=settings["eval_metric"],
        early_stopping_rounds=early_stopping_rounds,
    )

def _best_n_estimators(model) -> int:
    """Return the number of boosting rounds to use when refitting ``model``.

    ``best_iteration`` is a 0-based round index, so the matching tree count is
    ``best_iteration + 1``. When the attribute is missing or ``None`` (no
    early stopping took place) a fixed fallback of 200 rounds is returned.
    """
    best_iteration = getattr(model, "best_iteration", None)
    if best_iteration is None:
        return 200
    return int(best_iteration) + 1


def _refit_with_best_n(model, X, y, is_clf: bool):
    """Train a fresh estimator on ``(X, y)`` using the tuned round count.

    The new estimator copies ``model``'s parameters, with ``n_estimators``
    replaced by the value from ``_best_n_estimators`` (at least 1).

    Args:
        model: Previously fitted xgboost estimator to copy parameters from.
        X, y: Data for the refit.
        is_clf: Build an ``XGBClassifier`` when true, else an ``XGBRegressor``.

    Returns:
        The newly fitted estimator.
    """
    estimator_cls = xgb.XGBClassifier if is_clf else xgb.XGBRegressor
    settings = model.get_params()
    settings["n_estimators"] = max(1, _best_n_estimators(model))
    refit = estimator_cls(**settings)
    refit.fit(X, y, verbose=False)
    return refit


# ------------------------------
# Dataset Loading (from tdc_benchmarks.py)
# ------------------------------
# TDC single-prediction loader families, tried in this order.
_FAMILIES = [("ADME", ADME), ("Tox", Tox), ("HTS", HTS)]


def load_single_pred_dataset(name: str):
    """Load a TDC single-prediction dataset by name, whatever its family.

    Each family loader in ``_FAMILIES`` is tried in turn; the first one that
    constructs successfully wins.

    Returns:
        ``(family_name, dataset)``.

    Raises:
        ValueError: If every family loader fails; the message lists each
            family together with its error text.
    """
    failures = []
    for family, loader_cls in _FAMILIES:
        try:
            dataset = loader_cls(name=name)
        except Exception as exc:
            failures.append((family, str(exc)))
        else:
            return family, dataset
    raise ValueError(f"Dataset '{name}' not found. Tried: {failures}")


# ------------------------------
# NEW: Feature Importance Logic
# ------------------------------
def get_feature_importances(model, top_n: int = 200) -> pd.DataFrame:
    """Return the ``top_n`` most important features of a fitted model.

    Args:
        model: Object exposing a ``feature_importances_`` array.
        top_n: Maximum number of features to keep.

    Returns:
        A DataFrame with columns ``global_feature_index`` and
        ``importance_score``, ordered by decreasing importance. Features whose
        score is not above 1e-6 are dropped.
    """
    scores = model.feature_importances_
    ranked = np.argsort(scores)[::-1]
    selected = ranked[:top_n]
    table = pd.DataFrame(
        {
            "global_feature_index": selected,
            "importance_score": scores[selected],
        }
    )
    # Filter out non-contributing features
    contributes = table["importance_score"] > 1e-6
    return table[contributes]


def map_global_to_local_sae_index(
    global_indices: np.ndarray,
    sae_layers: List[int],
    sae_kit: SAEKit,
    ecfp_bits: int,
    feature_set_name: str,
) -> pd.DataFrame:
    """Map global feature indices back to their SAE layer / ECFP bit.

    The global feature vector is laid out as the SAE features of each layer in
    ``sae_layers`` order (widths taken from ``latent_size`` of the matching
    SAE config), optionally followed by ``ecfp_bits`` fingerprint bits when
    ``feature_set_name`` contains "ECFP".

    Args:
        global_indices: Column indices into the concatenated feature matrix.
        sae_layers: SAE layer keys in concatenation order.
        sae_kit: Object whose ``sae_configs[layer].latent_size`` gives each
            layer's width.
        ecfp_bits: Number of ECFP columns appended after the SAE block.
        feature_set_name: Name of the feature set the indices belong to.

    Returns:
        A DataFrame with columns ``global_feature_index``, ``type`` ("SAE" or
        "ECFP"), ``layer_id`` (-1 for ECFP) and ``feature_in_layer``. Indices
        outside every known range are omitted. The four columns are present
        even when no index could be mapped.
    """
    columns = ["global_feature_index", "type", "layer_id", "feature_in_layer"]

    layer_dims = [sae_kit.sae_configs[layer].latent_size for layer in sae_layers]
    # layer_offsets[i] is the first global column of layer i; the final entry
    # equals the total SAE width.
    layer_offsets = np.cumsum([0] + layer_dims)
    total_sae_dims = sum(layer_dims)
    has_ecfp = "ECFP" in feature_set_name

    rows = []
    for global_idx in global_indices:
        if 0 <= global_idx < total_sae_dims:
            # Last offset that is <= global_idx identifies the owning layer.
            pos = int(np.searchsorted(layer_offsets, global_idx, side="right")) - 1
            rows.append(
                {
                    "global_feature_index": global_idx,
                    "type": "SAE",
                    "layer_id": sae_layers[pos],
                    "feature_in_layer": global_idx - layer_offsets[pos],
                }
            )
        elif has_ecfp and total_sae_dims <= global_idx < total_sae_dims + ecfp_bits:
            rows.append(
                {
                    "global_feature_index": global_idx,
                    "type": "ECFP",
                    "layer_id": -1,
                    "feature_in_layer": global_idx - total_sae_dims,
                }
            )
    # Passing the column list keeps the schema stable for callers that merge
    # on "global_feature_index", even when ``rows`` is empty.
    return pd.DataFrame(rows, columns=columns)


# ------------------------------
# NEW: Main Analysis Function for One Seed
# ------------------------------
def analyze_feature_importance_one_seed(
    task_name: str,
    smi_col: str,
    split: Dict[str, pd.DataFrame],
    sae_kit: SAEKit,
    sae_layers: List[int],
    ecfp_bits: int,
    batch_size: int,
    early_stop: int,
    seed: int,
    model_save_dir: str | None = None,
) -> pd.DataFrame:
    """Train models on SAE feature sets and collect their feature importances.

    For every feature set from ``build_feature_sets`` a model is fitted with
    early stopping on the validation split, refit on train + valid with the
    selected number of rounds, optionally saved, and its top-200 importances
    are mapped back to SAE layer / ECFP bit.

    Args:
        task_name: Task label written to the output and used in file names.
        smi_col: Name of the SMILES column in the split DataFrames.
        split: Dict with "train", "valid" and "test" DataFrames; each has a
            "Y" label column.
        sae_kit: Loaded LM + SAE bundle.
        sae_layers: SAE layers whose features are concatenated.
        ecfp_bits: Fingerprint length for the hybrid feature set.
        batch_size: Batch size for SAE feature extraction.
        early_stop: Early-stopping patience, in boosting rounds.
        seed: Split seed; used only in saved-model file names.
        model_save_dir: When set, each refit model is dumped there with
            joblib (classification and regression alike).

    Returns:
        One row per important feature and feature set, or an empty DataFrame
        when no model produced any mappable importance.
    """
    train_df, valid_df, test_df = split["train"], split["valid"], split["test"]
    tr_smiles = train_df[smi_col].tolist()
    va_smiles = valid_df[smi_col].tolist()
    te_smiles = test_df[smi_col].tolist()
    y_tr, y_va, y_te = train_df["Y"].values, valid_df["Y"].values, test_df["Y"].values
    task_kind = infer_task_type_from_y(np.concatenate([y_tr, y_va, y_te]))
    is_clf = task_kind == "classification"

    # Labels do not depend on the feature set, so prepare them once.
    if is_clf:
        y_fit_tr, y_fit_va, _, _ = prepare_labels_for_classification(y_tr, y_va, y_te)
        fit_model = _fit_xgb_classifier
    else:
        y_fit_tr, y_fit_va = y_tr, y_va
        fit_model = _fit_xgb_regressor

    # Generate only the necessary feature sets
    sae_feature_sets = build_feature_sets(
        tr_smiles, va_smiles, te_smiles, sae_kit, sae_layers, 0, ecfp_bits, batch_size
    )

    all_importances = []
    for fname, (X_tr, X_va, _) in sae_feature_sets.items():
        print(f"    - Training model on feature set: '{fname}'")
        model = fit_model(
            X_tr, y_fit_tr, X_va, y_fit_va, early_stopping_rounds=early_stop
        )
        model = _refit_with_best_n(
            model,
            np.vstack([X_tr, X_va]),
            np.concatenate([y_fit_tr, y_fit_va]),
            is_clf=is_clf,
        )

        # Persist the refit model regardless of task kind.
        if model_save_dir:
            os.makedirs(model_save_dir, exist_ok=True)
            model_path = os.path.join(
                model_save_dir,
                f"{task_name}_{fname.replace(' ', '')}_seed{seed}.joblib",
            )
            joblib.dump(model, model_path)
            print(f"    - Saved trained model to {model_path}")

        importances_df = get_feature_importances(model, top_n=200)
        if importances_df.empty:
            continue
        mapped_df = map_global_to_local_sae_index(
            importances_df["global_feature_index"].values,
            sae_layers,
            sae_kit,
            ecfp_bits,
            fname,
        )
        if mapped_df.empty:
            # Nothing could be mapped; merging on a missing key would fail.
            continue
        full_info = pd.merge(importances_df, mapped_df, on="global_feature_index")
        full_info["task"] = task_name
        full_info["feature_set"] = fname
        all_importances.append(full_info)

    return (
        pd.concat(all_importances, ignore_index=True)
        if all_importances
        else pd.DataFrame()
    )


# ------------------------------
# NEW: Main Runner
# ------------------------------
def run_analysis(args):
    """Run the feature-importance analysis for every task and seed.

    Reads the comma-separated ``sae_layers``, ``seeds`` and ``tasks`` options
    from ``args`` (blank entries are ignored), loads the LM + SAE kit, and for
    each task/seed pair collects feature importances. Results are written to
    ``args.out_dir`` as ``feature_importance_per_seed.csv`` and
    ``feature_importance_summary.csv`` (mean/std over seeds), and the top five
    SAE features per task are printed.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file).

    Raises:
        ValueError: If ``--sae_layers`` contains no layer ID or a
            non-integer entry.
    """
    # Parse the list options up front so bad input fails before the
    # model is loaded.
    sae_layers = [int(x) for x in args.sae_layers.split(",") if x.strip()]
    if not sae_layers:
        raise ValueError("--sae_layers must contain at least one layer ID")
    seeds = [int(s) for s in args.seeds.split(",") if s.strip()]
    tasks = [t.strip() for t in args.tasks.split(",") if t.strip()]

    print("Loading LM + SAE kit...")
    sae_kit = SAEKit.load(
        model_dir=args.model_dir, checkpoint_id=args.ckpt_id, sae_dir=args.sae_dir
    )

    os.makedirs(args.out_dir, exist_ok=True)
    per_seed_importances = []

    for name in tasks:
        print(f"\n=== Analyzing Task: {name} ===")
        _, ds = load_single_pred_dataset(name)

        for run_id, seed in enumerate(seeds, 1):
            print(f"  - Seed {seed} ({run_id}/{len(seeds)})")
            split = ds.get_split(method=args.split_method, seed=seed)
            smi_col = _find_smiles_column(split["train"])

            # Skip multi-label datasets
            if isinstance(split["train"]["Y"].iloc[0], (list, tuple, np.ndarray, dict)):
                print(f"    [WARN] {name} appears multi‑label. Skipping.")
                break

            df = analyze_feature_importance_one_seed(
                task_name=name,
                smi_col=smi_col,
                split=split,
                sae_kit=sae_kit,
                sae_layers=sae_layers,
                ecfp_bits=args.ecfp_bits,
                batch_size=args.batch_size,
                early_stop=args.early_stop,
                model_save_dir=args.model_save_dir,
                seed=seed,
            )
            if not df.empty:
                df["seed"] = seed
                per_seed_importances.append(df)

    if not per_seed_importances:
        print("\nNo feature importance results were generated.")
        return

    # Aggregate and save results
    final_df = pd.concat(per_seed_importances, ignore_index=True)
    agg_df = (
        final_df.groupby(
            ["task", "feature_set", "type", "layer_id", "feature_in_layer"]
        )
        .agg(
            mean_importance=("importance_score", "mean"),
            std_importance=("importance_score", "std"),
            num_seeds=("seed", "count"),
        )
        .reset_index()
        .sort_values(["task", "mean_importance"], ascending=[True, False])
    )

    final_df.to_csv(
        os.path.join(args.out_dir, "feature_importance_per_seed.csv"), index=False
    )
    agg_df.to_csv(
        os.path.join(args.out_dir, "feature_importance_summary.csv"), index=False
    )
    print(f"\nSaved feature importance analysis to '{args.out_dir}'")

    print("\n--- Top 5 Important SAE Features (averaged over seeds) ---")
    summary = agg_df[agg_df["type"] == "SAE"]
    for task in summary["task"].unique():
        print(f"\n--- Task: {task} ---")
        # agg_df is already sorted by descending mean importance within task.
        top5 = summary[summary["task"] == task].head(5)
        print(top5.to_string(index=False, float_format="%.4f"))


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the analysis script.

    The summary text is passed as ``description`` so that argparse keeps
    deriving the program name from the invoked script for the usage line.
    """
    ap = argparse.ArgumentParser(
        description="Analyze SAE feature importance on TDC single-pred benchmarks"
    )
    ap.add_argument(
        "--model_dir", required=True, help="Path to the base language model directory."
    )
    ap.add_argument(
        "--model_save_dir",
        type=str,
        default="experiments/tdc_models",
        help="Directory to save trained XGBoost models.",
    )
    ap.add_argument(
        "--ckpt_id", required=True, help="Checkpoint ID for the language model."
    )
    ap.add_argument(
        "--sae_dir",
        required=True,
        help="Path to the directory containing trained SAEs.",
    )
    ap.add_argument(
        "--sae_layers",
        type=str,
        default="5",
        help="Comma-separated SAE layer IDs to use.",
    )
    ap.add_argument(
        "--ecfp_bits",
        type=int,
        default=2048,
        help="Number of bits for ECFP fingerprints.",
    )
    ap.add_argument(
        "--batch_size", type=int, default=64, help="Batch size for feature generation."
    )
    ap.add_argument(
        "--early_stop", type=int, default=50, help="Early stopping rounds for XGBoost."
    )
    ap.add_argument(
        "--split_method",
        type=str,
        default="scaffold",
        help="TDC data splitting method.",
    )
    ap.add_argument(
        "--tasks",
        type=str,
        required=True,
        help="Comma-separated list of TDC task names.",
    )
    ap.add_argument(
        "--seeds",
        type=str,
        default="1,2,3",
        help="Comma-separated seeds for repeated splits.",
    )
    ap.add_argument(
        "--out_dir",
        type=str,
        default="experiments/sae_importance_outputs",
        help="Directory to save analysis results.",
    )
    return ap


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()
    run_analysis(args)
