# experiments/downstream_eval.py
from __future__ import annotations
import argparse
import os
import time
from typing import Tuple, Dict, Any

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import xgboost as xgb
from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn.decomposition import PCA
from sklearn.metrics import (
    mean_squared_error,
    mean_absolute_error,
    roc_auc_score,
    average_precision_score,
)
from sklearn.preprocessing import LabelEncoder

# TDC
from tdc.single_pred import ADME

# --- Project imports ---
from lmkit.sparse.sae import SAEKit
from lmkit.sparse import utils as sae_utils
from lmkit.tools import data as data_tools

# ------------------------
# Feature generation
# ------------------------

def pad_and_stack(seqs, pad_value: int) -> np.ndarray:
    """
    Right-pad 1-D integer sequences into a single (B, T_max) int32 matrix.

    Row ``i`` holds ``seqs[i]`` followed by ``pad_value`` up to the length of
    the longest sequence. An empty ``seqs`` yields an empty (0, 0) int32 array.
    """
    if not seqs:
        return np.zeros((0, 0), dtype=np.int32)

    widest = max(map(len, seqs))
    padded = np.full((len(seqs), widest), pad_value, dtype=np.int32)
    # Iterating a 2-D array yields row views, so writes land in `padded`.
    for row, seq in zip(padded, seqs):
        values = np.asarray(seq, dtype=np.int32)
        row[: values.shape[0]] = values
    return padded


def _find_smiles_column(df: pd.DataFrame) -> str:
    """
    Return the name of the column holding SMILES strings.

    The candidates "Drug", "SMILES", "smiles" and "X" are tried in that order
    and the first one present in ``df`` wins.

    Raises:
        KeyError: if none of the candidate names is a column of ``df``.
    """
    known_names = ("Drug", "SMILES", "smiles", "X")
    match = next((name for name in known_names if name in df.columns), None)
    if match is None:
        raise KeyError(f"Could not find a SMILES column among: {list(df.columns)}")
    return match


def generate_ecfp_features(
    smiles_list: list[str], radius: int = 2, n_bits: int = 2048
) -> np.ndarray:
    """
    Compute Morgan bit-vector fingerprints for a list of SMILES strings.

    Args:
        smiles_list: SMILES strings, one per output row.
        radius: Morgan radius passed to RDKit.
        n_bits: length of each fingerprint (number of feature columns).

    Returns:
        A float32 array of shape (len(smiles_list), n_bits) holding 0.0/1.0.
        Rows whose SMILES RDKit cannot parse are left as all zeros.
    """
    fingerprints = np.zeros((len(smiles_list), n_bits), dtype=np.float32)
    # Each `row` is a view into `fingerprints`, so in-place writes persist.
    for row, smiles in zip(fingerprints, smiles_list):
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            bit_vect = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
            # Decode the "0101…" bit string one character at a time; float32
            # keeps the feature matrix in a dtype XGBoost handles directly.
            bit_chars = np.frombuffer(bit_vect.ToBitString().encode(), "S1")
            row[:] = (bit_chars == b"1").astype(np.float32, copy=False)
    return fingerprints


def generate_transformer_and_sae_features(
    smiles_list: list[str], sae_kit: SAEKit, sae_layer: int, batch_size: int = 64
) -> tuple[np.ndarray, np.ndarray]:
    """
    Build per-molecule dense and SAE feature matrices from a list of SMILES.

    The SMILES are processed in chunks of ``batch_size``. For each chunk the
    tokens are padded with the tokenizer's pad id and positions with -1, the
    residual stream at the SAE's hook point is captured, and the SAE encoding
    is computed. Padding tokens (position < 0) are masked out, then:

    * dense embeddings are mean-pooled over the real tokens;
    * SAE activations are max-pooled over the token axis.

    Args:
        smiles_list: SMILES strings to featurize.
        sae_kit: loaded LM + SAE bundle providing tokenizer, run function,
            parameters, hooks and per-layer SAE configs.
        sae_layer: key into ``sae_kit.sae_configs`` selecting the SAE.
        batch_size: number of molecules per forward pass.

    Returns:
        ``(dense, sae)`` float32 arrays, each with one row per input SMILES.

    Raises:
        ValueError: if ``smiles_list`` is empty (nothing to stack).
    """
    pooled_dense, pooled_sae = [], []

    for start in range(0, len(smiles_list), batch_size):
        chunk = smiles_list[start : start + batch_size]

        # Tokenize, then pad ids with the pad token and positions with -1 so
        # that padding can be recognised later via `positions >= 0`.
        encodings = [sae_kit.tokenizer.encode(smiles) for smiles in chunk]
        token_ids = pad_and_stack(
            [enc.ids for enc in encodings], sae_kit.tokenizer.pad_token_id
        )
        token_pos = pad_and_stack([np.arange(len(enc.ids)) for enc in encodings], -1)

        inputs = jnp.asarray(token_ids)
        positions = jnp.asarray(token_pos)

        # Residual stream at the hook point the selected SAE is attached to.
        layer_cfg = sae_kit.sae_configs[sae_layer]
        captured = sae_utils.run_and_capture(
            sae_kit.run_fn,
            inputs,
            positions,
            sae_kit.lm_params,
            sae_kit.lm_config,
            sae_kit.hooks,
        )
        # presumably (B, T, H) — TODO confirm against run_and_capture
        token_embeddings = captured[(layer_cfg.layer_id, layer_cfg.placement)]
        # presumably (B, T, K) — TODO confirm against SAEKit.get_encoded
        token_activations = sae_kit.get_encoded(inputs, positions, layer_id=sae_layer)

        # 1.0 for real tokens, 0.0 for padding.
        keep = (positions >= 0).astype(jnp.float32)
        n_real_tokens = keep.sum(axis=-1, keepdims=True)

        # Mean over real tokens; the clamp avoids 0/0 for all-padding rows.
        dense_sum = (token_embeddings * keep[..., None]).sum(axis=1)
        dense_mean = dense_sum / jnp.maximum(n_real_tokens, 1.0)
        pooled_dense.append(np.asarray(dense_mean, dtype=np.float32))

        # NOTE(review): padding is zeroed before the max, so this assumes SAE
        # activations are non-negative — confirm for the SAE in use.
        sae_max = (token_activations * keep[..., None]).max(axis=1)
        pooled_sae.append(np.asarray(sae_max, dtype=np.float32))

    return np.vstack(pooled_dense), np.vstack(pooled_sae)


# ------------------------
# Task helpers
# ------------------------


def infer_task_type(y: np.ndarray) -> str:
    """
    Guess whether labels describe a classification or a regression task.

    Missing values are dropped first. The decision then depends on dtype:

    * no labels left            -> "regression" (default when unknown)
    * signed/unsigned integers  -> "classification" if at most 10 distinct
                                   values, else "regression"
    * floats                    -> "classification" only if every distinct
                                   value rounds (6 decimals) to 0.0 or 1.0
    * anything else (str, bool, object) -> "classification"
    """
    labels = pd.Series(y).dropna()
    if labels.empty:
        return "regression"

    distinct = np.unique(labels.values)
    few_values = distinct.size <= 10
    dtype_kind = labels.dtype.kind

    if dtype_kind in ("i", "u"):
        return "classification" if few_values else "regression"

    if dtype_kind != "f":
        # Strings / objects / booleans are treated as class labels.
        return "classification"

    # Floats: only a {0., 1.}-valued target counts as classification.
    looks_binary = few_values and set(np.unique(np.round(distinct, 6))).issubset(
        {0.0, 1.0}
    )
    return "classification" if looks_binary else "regression"


def _prepare_labels_for_classification(y_train, y_valid, y_test):
    """
    Encode the labels of all three splits as consecutive integer class ids.

    The splits are concatenated, cast to strings and passed through a single
    ``LabelEncoder`` so that every split shares one label -> id mapping.

    Returns:
        ``(y_train_enc, y_valid_enc, y_test_enc, classes)`` where ``classes``
        is the encoder's ``classes_`` array (the string form of each label).
    """
    combined = pd.concat([pd.Series(part) for part in (y_train, y_valid, y_test)], axis=0)

    encoder = LabelEncoder()
    codes = encoder.fit_transform(combined.astype(str))

    # Cut the encoded vector back into the original three splits.
    train_end = len(y_train)
    valid_end = train_end + len(y_valid)
    return (
        codes[:train_end],
        codes[train_end:valid_end],
        codes[valid_end:],
        encoder.classes_,
    )


def _ecfp_dim(n_bits: int) -> int:
    """Return the ECFP feature width, which is simply the number of bits."""
    feature_width = n_bits
    return feature_width


# ------------------------
# XGBoost training utils
# ------------------------


def _fit_xgb_classifier(
    X_tr,
    y_tr,
    X_val,
    y_val,
    *,
    params: Dict[str, Any] | None = None,
    early_stopping_rounds: int = 50,
) -> xgb.XGBClassifier:
    """
    Fit an XGBoost classifier with early stopping on a validation split.

    With exactly two distinct training labels the model uses
    ``binary:logistic`` / AUC and weights the positive class by ``neg / pos``;
    otherwise it uses ``multi:softprob`` / ``mlogloss``.

    Early stopping is configured where the installed XGBoost expects it:
    XGBoost >= 1.6 takes ``early_stopping_rounds`` in the estimator
    constructor (the ``fit()`` keyword was deprecated there and later
    removed), while older releases only accept it as a ``fit()`` keyword.
    Probing the constructor signature avoids the previous behaviour of
    swallowing the resulting ``TypeError`` and silently training all rounds.

    Args:
        X_tr, y_tr: training features and integer-encoded labels.
        X_val, y_val: validation features/labels monitored for early stopping.
        params: optional overrides merged on top of the default parameters.
        early_stopping_rounds: patience, in boosting rounds.

    Returns:
        The fitted classifier; its ``best_iteration`` reflects early stopping.
    """
    import inspect  # local import: only needed for the version probe below

    n_classes = int(len(np.unique(y_tr)))
    is_binary = n_classes == 2

    # Class-imbalance helper (binary only): weight positives by neg/pos.
    spw = 1.0
    if is_binary:
        pos = (y_tr == 1).sum()
        neg = (y_tr == 0).sum()
        if pos > 0:
            spw = float(neg) / pos

    default_params = dict(
        n_estimators=2000,
        max_depth=6,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        reg_lambda=1.0,
        reg_alpha=0.0,
        objective="binary:logistic" if is_binary else "multi:softprob",
        eval_metric="auc" if is_binary else "mlogloss",
        random_state=42,
        tree_method="hist",
        n_jobs=-1,
        scale_pos_weight=spw,
    )
    if params:
        default_params.update(params)

    eval_set = [(X_val, y_val)]

    # `early_stopping_rounds` is declared on XGBModel.__init__ from 1.6 onwards;
    # older constructors would only swallow it into **kwargs without effect.
    ctor_args = inspect.signature(xgb.XGBModel.__init__).parameters
    if "early_stopping_rounds" in ctor_args:
        default_params.setdefault("early_stopping_rounds", early_stopping_rounds)
        clf = xgb.XGBClassifier(**default_params)
        clf.fit(X_tr, y_tr, eval_set=eval_set)
        # Callers clone get_params() to refit on train+valid without an
        # eval_set; a lingering early-stopping setting would make that raise.
        clf.set_params(early_stopping_rounds=None)
    else:
        clf = xgb.XGBClassifier(**default_params)
        clf.fit(
            X_tr,
            y_tr,
            eval_set=eval_set,
            early_stopping_rounds=early_stopping_rounds,
        )

    return clf

def _fit_xgb_regressor(
    X_tr,
    y_tr,
    X_val,
    y_val,
    *,
    params: Dict[str, Any] | None = None,
    early_stopping_rounds: int = 50,
) -> xgb.XGBRegressor:
    """
    Fit an XGBoost regressor with early stopping on a validation split.

    The model minimises squared error and monitors validation RMSE.

    Early stopping is configured where the installed XGBoost expects it:
    XGBoost >= 1.6 takes ``early_stopping_rounds`` in the estimator
    constructor (the ``fit()`` keyword was deprecated there and later
    removed), while older releases only accept it as a ``fit()`` keyword.
    Probing the constructor signature avoids the previous behaviour of
    swallowing the resulting ``TypeError`` and silently training all rounds.

    Args:
        X_tr, y_tr: training features and targets.
        X_val, y_val: validation features/targets monitored for early stopping.
        params: optional overrides merged on top of the default parameters.
        early_stopping_rounds: patience, in boosting rounds.

    Returns:
        The fitted regressor; its ``best_iteration`` reflects early stopping.
    """
    import inspect  # local import: only needed for the version probe below

    default_params = dict(
        n_estimators=2000,
        max_depth=6,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        reg_lambda=1.0,
        reg_alpha=0.0,
        objective="reg:squarederror",
        eval_metric="rmse",
        random_state=42,
        tree_method="hist",
        n_jobs=-1,
    )
    if params:
        default_params.update(params)

    eval_set = [(X_val, y_val)]

    # `early_stopping_rounds` is declared on XGBModel.__init__ from 1.6 onwards;
    # older constructors would only swallow it into **kwargs without effect.
    ctor_args = inspect.signature(xgb.XGBModel.__init__).parameters
    if "early_stopping_rounds" in ctor_args:
        default_params.setdefault("early_stopping_rounds", early_stopping_rounds)
        reg = xgb.XGBRegressor(**default_params)
        reg.fit(X_tr, y_tr, eval_set=eval_set)
        # Callers clone get_params() to refit on train+valid without an
        # eval_set; a lingering early-stopping setting would make that raise.
        reg.set_params(early_stopping_rounds=None)
    else:
        reg = xgb.XGBRegressor(**default_params)
        reg.fit(
            X_tr,
            y_tr,
            eval_set=eval_set,
            early_stopping_rounds=early_stopping_rounds,
        )

    return reg

    
def _best_n_estimators(model) -> int:
    """
    Return how many boosting rounds a refit of ``model`` should use.

    Resolution order:

    1. ``model.best_iteration`` (set by early stopping). It is a 0-based
       round index, so the number of rounds is ``best_iteration + 1``.
    2. ``model.get_booster().best_ntree_limit`` (legacy XGBoost), which is
       already a count.
    3. ``model.n_estimators``, or 200 if that is missing or ``None``.
    """
    best_iter = getattr(model, "best_iteration", None)
    if best_iter is not None:
        return int(best_iter) + 1

    get_booster = getattr(model, "get_booster", None)
    if get_booster is not None:
        try:
            return int(get_booster().best_ntree_limit)
        except (AttributeError, TypeError, ValueError):
            # Unfitted model, or a booster without the legacy attribute:
            # fall through to the configured number of estimators.
            pass

    configured = getattr(model, "n_estimators", None)
    return int(configured) if configured is not None else 200


def _refit_with_best_n(model, X, y, is_clf: bool) -> Any:
    """
    Train a fresh estimator on ``(X, y)`` using the best round count of ``model``.

    The new estimator copies ``model``'s parameters, with ``n_estimators``
    replaced by the value from ``_best_n_estimators`` (at least 1).

    Args:
        model: a fitted XGBoost sklearn-style estimator.
        X, y: the data to refit on (typically train + validation).
        is_clf: build an ``XGBClassifier`` if true, else an ``XGBRegressor``.

    Returns:
        The newly fitted estimator.
    """
    n_best = _best_n_estimators(model)
    params = model.get_params()
    params["n_estimators"] = max(1, n_best)

    # This refit has no validation set, so constructor-level early stopping
    # (or stateful callbacks such as EarlyStopping) inherited from `model`
    # would make fit() raise. Only reset keys the installed version exposes.
    for key in ("early_stopping_rounds", "callbacks"):
        if params.get(key) is not None:
            params[key] = None

    estimator_cls = xgb.XGBClassifier if is_clf else xgb.XGBRegressor
    new_model = estimator_cls(**params)
    new_model.fit(X, y, verbose=False)
    return new_model


# ------------------------
# Main evaluation
# ------------------------


# Datasets evaluated when --tasks is empty.
_DEFAULT_TASKS = ("BBB_Martins", "CYP2D6_Veith", "Solubility_AqSolDB", "HIA_Hou")

# Headline metric candidates for the summary table, in order of preference.
_PRIMARY_METRICS = ("AUROC", "AUROC_macro", "RMSE")

# Column order of the summary table.
_FEATURE_ORDER = (
    "SAE Features",
    "Transformer Embeddings",
    "PCA on Embeddings",
    "ECFP Fingerprints",
)


def _parse_task_names(tasks_arg: str) -> list[str]:
    """Split a comma-separated task string; fall back to the preset if empty."""
    names = [name.strip() for name in (tasks_arg or "").split(",")]
    names = [name for name in names if name]
    return names or list(_DEFAULT_TASKS)


def _primary_score(group: pd.DataFrame):
    """
    Pick the headline ``(metric, score)`` from the last row of ``group``.

    The results frame holds the union of metric columns over all tasks, so a
    column merely existing says nothing about this task; the first candidate
    with a non-null value is used. If every candidate is null (e.g. AUROC
    undefined on a single-class test set) the first existing column is
    reported with a NaN score. Returns ``None`` if no candidate column exists.
    """
    present = [metric for metric in _PRIMARY_METRICS if metric in group.columns]
    for metric in present:
        score = group[metric].iloc[-1]
        if pd.notna(score):
            return metric, score
    if present:
        return present[0], np.nan
    return None


def _summarize_results(results_df: pd.DataFrame) -> pd.DataFrame:
    """Build the Task x Features table of headline scores (empty if no rows)."""
    pivot_rows = []
    if not results_df.empty:
        for (task, feat), group in results_df.groupby(["Task", "Features"]):
            picked = _primary_score(group)
            if picked is None:
                continue
            metric, score = picked
            pivot_rows.append(
                {"Task": task, "Features": feat, "Score": score, "Metric": metric}
            )
    if not pivot_rows:
        return pd.DataFrame()

    wide = pd.DataFrame(pivot_rows).pivot(
        index="Task", columns="Features", values="Score"
    )
    return wide[[c for c in _FEATURE_ORDER if c in wide.columns]]


def run_evaluation(args):
    """
    Benchmark SAE, dense, PCA and ECFP features on TDC ADME tasks.

    For every task the data is split with ``args.split_method``, the four
    feature sets are generated for train/valid/test, and an XGBoost model is
    trained per feature set (early stopping on valid, then refit on
    train+valid) and scored on test: AUROC/AUPRC for binary classification,
    macro one-vs-rest AUROC for multiclass, RMSE/MAE for regression.

    Side effects: prints progress, writes ``<task>_results.csv`` per task and
    ``downstream_results_summary.csv`` into ``args.out_dir``.

    Args:
        args: parsed CLI namespace providing ``model_dir``, ``ckpt_id``,
            ``sae_dir``, ``sae_layer``, ``pca_dims``, ``ecfp_bits``,
            ``batch_size``, ``early_stop``, ``split_method``, ``tasks`` and
            ``out_dir``.
    """
    print("--- Loading LM + SAE ---")
    sae_kit = SAEKit.load(
        model_dir=args.model_dir, checkpoint_id=args.ckpt_id, sae_dir=args.sae_dir
    )
    print(f"Transformer: {args.model_dir} (ckpt {args.ckpt_id})")
    print(f"SAEs:        {args.sae_dir}")

    tasks = _parse_task_names(args.tasks)

    all_rows = []
    os.makedirs(args.out_dir, exist_ok=True)

    for task_name in tasks:
        print(f"\n=== Task: {task_name} ===")
        data = ADME(name=task_name)
        split = data.get_split(method=args.split_method)  # e.g. 'scaffold'

        # Detect SMILES column & labels.
        train_df = split["train"]
        valid_df = split["valid"]
        test_df = split["test"]

        smi_col = _find_smiles_column(train_df)
        y_train = train_df["Y"].values
        y_valid = valid_df["Y"].values
        y_test = test_df["Y"].values

        task_kind = infer_task_type(np.concatenate([y_train, y_valid, y_test]))
        print(f"Detected task type: {task_kind}")
        is_classification = task_kind == "classification"

        # Labels do not depend on the feature set, so encode them once per task.
        if is_classification:
            y_tr_fit, y_va_fit, y_te_fit, _classes = (
                _prepare_labels_for_classification(y_train, y_valid, y_test)
            )
        else:
            y_tr_fit, y_va_fit, y_te_fit = y_train, y_valid, y_test

        tr_smiles = train_df[smi_col].tolist()
        va_smiles = valid_df[smi_col].tolist()
        te_smiles = test_df[smi_col].tolist()

        # Build features.
        print("  Generating ECFP…")
        tr_ecfp = generate_ecfp_features(tr_smiles, n_bits=args.ecfp_bits)
        va_ecfp = generate_ecfp_features(va_smiles, n_bits=args.ecfp_bits)
        te_ecfp = generate_ecfp_features(te_smiles, n_bits=args.ecfp_bits)

        print("  Generating Transformer/SAE features…")
        tr_dense, tr_sae = generate_transformer_and_sae_features(
            tr_smiles, sae_kit, args.sae_layer, batch_size=args.batch_size
        )
        va_dense, va_sae = generate_transformer_and_sae_features(
            va_smiles, sae_kit, args.sae_layer, batch_size=args.batch_size
        )
        te_dense, te_sae = generate_transformer_and_sae_features(
            te_smiles, sae_kit, args.sae_layer, batch_size=args.batch_size
        )

        print("  PCA baseline on dense embeddings…")
        # PCA cannot extract more components than min(n_samples, n_features);
        # clamp so small datasets / narrow models do not abort the whole run.
        n_components = min(args.pca_dims, *tr_dense.shape)
        if n_components < args.pca_dims:
            print(f"  (PCA dims clamped from {args.pca_dims} to {n_components})")
        pca = PCA(n_components=n_components, random_state=42)
        tr_pca = pca.fit_transform(tr_dense)
        va_pca = pca.transform(va_dense)
        te_pca = pca.transform(te_dense)

        feature_sets = {
            "SAE Features": (tr_sae, va_sae, te_sae),
            "Transformer Embeddings": (tr_dense, va_dense, te_dense),
            "PCA on Embeddings": (tr_pca, va_pca, te_pca),
            "ECFP Fingerprints": (tr_ecfp, va_ecfp, te_ecfp),
        }

        # Train/Eval per feature set.
        for fname, (X_tr, X_va, X_te) in feature_sets.items():
            print(
                f"  → {fname}  (train {X_tr.shape}, valid {X_va.shape}, test {X_te.shape})"
            )
            t0 = time.time()

            if is_classification:
                clf = _fit_xgb_classifier(
                    X_tr,
                    y_tr_fit,
                    X_va,
                    y_va_fit,
                    early_stopping_rounds=args.early_stop,
                )
                # Refit on train+valid with best n.
                clf = _refit_with_best_n(
                    clf,
                    np.vstack([X_tr, X_va]),
                    np.concatenate([y_tr_fit, y_va_fit]),
                    is_clf=True,
                )

                # AUROC/AUPRC if binary; otherwise macro one-vs-rest AUROC.
                n_classes = len(np.unique(y_tr_fit))
                if n_classes == 2:
                    prob = clf.predict_proba(X_te)[:, 1]
                    try:
                        auroc = roc_auc_score(y_te_fit, prob)
                    except ValueError:
                        # Rare degenerate case: single class in test.
                        auroc = np.nan
                    auprc = average_precision_score(y_te_fit, prob)
                    metrics = {"AUROC": auroc, "AUPRC": auprc}
                else:
                    prob = clf.predict_proba(X_te)
                    auroc = roc_auc_score(
                        y_te_fit, prob, multi_class="ovr", average="macro"
                    )
                    metrics = {"AUROC_macro": auroc}
            else:  # regression
                reg = _fit_xgb_regressor(
                    X_tr, y_tr_fit, X_va, y_va_fit, early_stopping_rounds=args.early_stop
                )
                reg = _refit_with_best_n(
                    reg,
                    np.vstack([X_tr, X_va]),
                    np.concatenate([y_tr_fit, y_va_fit]),
                    is_clf=False,
                )
                preds = reg.predict(X_te)
                rmse = float(np.sqrt(mean_squared_error(y_te_fit, preds)))
                mae = float(mean_absolute_error(y_te_fit, preds))
                metrics = {"RMSE": rmse, "MAE": mae}

            train_secs = time.time() - t0

            row = {
                "Task": task_name,
                "Features": fname,
                **metrics,
                "Train Time (s)": train_secs,
            }
            print("     ", row)
            all_rows.append(row)

        # Save per-task intermediate CSV.
        task_df = pd.DataFrame([r for r in all_rows if r["Task"] == task_name])
        task_out = os.path.join(
            args.out_dir, f"{task_name.replace('/', '_')}_results.csv"
        )
        task_df.to_csv(task_out, index=False)
        print(f"  Saved → {task_out}")

    # Summary pivot: one headline score per (task, feature set).
    results_df = pd.DataFrame(all_rows)
    print("\n=== Summary ===")
    summary_wide = _summarize_results(results_df)
    print(summary_wide.to_string(float_format="%.4f"))

    out_csv = os.path.join(args.out_dir, "downstream_results_summary.csv")
    summary_wide.to_csv(out_csv)
    print(f"\nSaved summary → {out_csv}")


def _build_arg_parser() -> argparse.ArgumentParser:
    """
    Create the command-line parser for the downstream evaluation script.

    The help text is passed as ``description``; the first positional argument
    of ``ArgumentParser`` is ``prog``, so passing it positionally would have
    replaced the program name in the usage line instead.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Downstream evaluation for Transformer & SAE features on TDC ADME tasks"
        )
    )
    # Model / SAE checkpoints.
    parser.add_argument("--model_dir", required=True)
    parser.add_argument("--ckpt_id", required=True)
    parser.add_argument("--sae_dir", required=True)
    parser.add_argument("--sae_layer", type=int, default=5)
    # Feature-generation and training knobs.
    parser.add_argument("--pca_dims", type=int, default=256)
    parser.add_argument("--ecfp_bits", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--early_stop", type=int, default=50)
    parser.add_argument(
        "--split_method",
        type=str,
        default="scaffold",
        help="TDC split method, e.g. scaffold/random/butina",
    )
    parser.add_argument(
        "--tasks",
        type=str,
        default="",
        help="Comma-separated ADME dataset names (default: preset 4)",
    )
    parser.add_argument("--out_dir", type=str, default="experiments/downstream_outputs")
    return parser


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()
    run_evaluation(args)
