# previous_methods.py
import numpy as np
from scipy import sparse
import sklearn
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.neural_network import MLPRegressor
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import r2_score, mean_squared_error
import tensorflow as tf
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from packaging.version import Version
import warnings
from sklearn.model_selection import StratifiedKFold, ShuffleSplit, KFold
from sklearn.neighbors import NearestNeighbors
from sklearn.base import clone
from sklearn.neighbors import KernelDensity
import typing as _typing

# -------------------------
# Baselines 
# -------------------------
def compute_ll(task, fitted_models, x, y, eps=1e-8):
    """Return the mean per-sample log-likelihood of every model on ``(x, y)``.

    Parameters
    ----------
    task : str
        ``"binary"`` (classification via ``predict_proba``) or
        ``"regression"`` (Gaussian likelihood with the residual variance
        plugged in as the noise variance).
    fitted_models : dict
        Mapping name -> fitted estimator. Classifiers must expose
        ``predict_proba`` and ``classes_``; regressors must expose ``predict``.
    x : array-like
        Features passed unchanged to the estimators.
    y : array-like or pandas object
        Targets; flattened to 1-D before use.
    eps : float
        Added inside the log (binary) or to the variance (regression) to
        avoid ``log(0)`` / division by zero.

    Returns
    -------
    tf.Tensor
        float32 tensor with one entry per model, in dict order.

    Raises
    ------
    ValueError
        If ``task`` is neither ``"binary"`` nor ``"regression"``.
    """
    if task not in ("binary", "regression"):
        raise ValueError("Unknown task (use 'binary' or 'regression').")

    # Accept pandas objects as well as plain arrays/lists.
    y_vec = (y.to_numpy() if hasattr(y, "to_numpy") else np.asarray(y)).ravel()

    ll_list = []
    if task == "binary":
        rows = np.arange(len(y_vec))
        for mdl in fitted_models.values():
            proba = mdl.predict_proba(x)
            # Map each label to its column in predict_proba's output.
            class_map = {c: i for i, c in enumerate(mdl.classes_)}
            idx = np.array([class_map[val] for val in y_vec], dtype=int)
            true_p = proba[rows, idx]
            ll_list.append(np.log(true_p + eps).mean())
    else:
        for mdl in fitted_models.values():
            yhat = np.asarray(mdl.predict(x)).ravel()
            resid = y_vec - yhat
            sigma2 = float(np.var(resid, ddof=1)) + eps
            ll = -0.5 * np.log(2.0 * np.pi * sigma2) - 0.5 * (resid ** 2) / sigma2
            ll_list.append(ll.mean())

    # Build the tensor once, after all models have been scored.
    return tf.constant(ll_list, dtype=tf.float32)

def baseline_methods(task, x_train, y_train, x_test, fitted_models, eps=1e-8):
    """Compute four baseline ensemble predictions on ``x_test``.

    The baselines are:

    * ``best``    - prediction of the single model with the highest training
      log-likelihood,
    * ``uniform`` - unweighted mean over models,
    * ``freq``    - mean weighted by a training score (accuracy for
      ``"binary"``; clipped R^2 for ``"regression"``, falling back to inverse
      MSE when every clipped R^2 is ~0),
    * ``bma``     - mean weighted by ``softmax`` of the training
      log-likelihoods.

    Parameters
    ----------
    task : str
        ``"binary"`` or ``"regression"``.
    x_train, y_train : array-like
        Training data used only to derive the model weights.
    x_test : array-like
        Data to predict on.
    fitted_models : dict
        Mapping name -> fitted estimator. For regression, a model whose
        ``predict(x_train)`` fails or has the wrong length may instead
        provide a ``yhat_train`` attribute.
    eps : float
        Numerical floor for variances / denominators.

    Returns
    -------
    tuple
        ``(best, uniform, freq, bma)``; each is ``[T, C]`` for binary and
        ``[T]`` for regression.

    Raises
    ------
    ValueError
        For an unknown task, or when a regression model can supply no valid
        training predictions.
    """
    if task not in ("binary", "regression"):
        raise ValueError("Unknown task (use 'binary' or 'regression').")

    models = list(fitted_models.values())

    if task == "binary":
        pred_test = np.stack([m.predict_proba(x_test) for m in models], axis=1)  # [T, M, C]

        # Log-likelihoods on the training data.
        base_ll = compute_ll(task, fitted_models, x_train, y_train, eps=eps)

        # Accuracy-proportional weights; uniform if every accuracy is zero
        # (otherwise the normalisation would produce NaN).
        train_accs = np.array(
            [accuracy_score(y_train, m.predict(x_train)) for m in models], dtype=float
        )
        acc_total = train_accs.sum()
        if acc_total > 0:
            w_f = train_accs / acc_total
        else:
            w_f = np.full(len(models), 1.0 / len(models))

    else:
        y_tr = np.asarray(y_train).ravel()
        yhat_train_list = []
        for m in models:
            yhat_tr = None
            try:
                yhat_tr = np.asarray(m.predict(x_train)).ravel()
                if len(yhat_tr) != len(y_tr):
                    yhat_tr = None
            except Exception:
                # Deliberate best-effort: fall back to cached yhat_train below.
                yhat_tr = None
            if yhat_tr is None and getattr(m, "yhat_train", None) is not None:
                yhat_tr = np.asarray(m.yhat_train).ravel()
                if len(yhat_tr) != len(y_tr):
                    raise ValueError("Provided yhat_train length does not match y_train.")
            if yhat_tr is None:
                raise ValueError("A model cannot produce training predictions and no yhat_train was provided.")
            yhat_train_list.append(yhat_tr)

        pred_test = np.stack(
            [np.asarray(m.predict(x_test)).ravel() for m in models], axis=1
        )  # [T, M]

        # Gaussian training log-likelihood with plug-in residual variance.
        ll_list = []
        for yhat in yhat_train_list:
            resid = y_tr - yhat
            sigma2 = float(np.var(resid, ddof=1)) + eps
            ll = -0.5 * np.log(2.0 * np.pi * sigma2) - 0.5 * (resid ** 2) / sigma2
            ll_list.append(ll.mean())
        base_ll = tf.constant(ll_list, dtype=tf.float32)

        # R^2-proportional weights, with an inverse-MSE fallback.
        r2s = np.array([max(0.0, r2_score(y_tr, yhat)) for yhat in yhat_train_list], dtype=float)
        if r2s.sum() <= eps:
            mses = np.array([mean_squared_error(y_tr, yhat) for yhat in yhat_train_list], dtype=float)
            inv = 1.0 / (mses + eps)
            w_f = inv / inv.sum()
        else:
            w_f = r2s / r2s.sum()

    def _combine(weights):
        # Broadcast per-model weights over [T, M] or [T, M, C] and sum over M.
        weights = np.asarray(weights, dtype=float)
        shape = (1, -1) + (1,) * (pred_test.ndim - 2)
        return (weights.reshape(shape) * pred_test).sum(axis=1)

    # Best single model
    best_idx = int(tf.argmax(base_ll).numpy())
    best = pred_test[:, best_idx]

    # Uniform ensemble
    uniform = pred_test.mean(axis=1)

    # Frequentist averaging
    freq = _combine(w_f)

    # BMA with a uniform prior: weights proportional to exp(train log-evidence).
    # Converted to numpy before mixing so both tasks use the same dtype path.
    w_bma = tf.exp(base_ll - tf.reduce_max(base_ll))
    w_bma = (w_bma / tf.reduce_sum(w_bma)).numpy()
    bma = _combine(w_bma)

    return best, uniform, freq, bma


def moe_gating(
    x_train,
    y_train_or_oh,
    fitted_models,
    feature_dim,
    num_models,
    MoEGatingNet,
    *,
    task="binary",
    base_pred_train=None,
    epochs=10,
    lr=1e-3,
    l2=0.0,
    batch_size=64,
    verbose=False,
    gate_features=None,
):
    """Train a mixture-of-experts gating network over frozen base models.

    The gate maps per-sample inputs to ``num_models`` logits; their softmax
    mixes the base predictions, and the mixture is fit with cross-entropy
    (``"binary"``) or mean squared error (``"regression"``) using Adam.

    Parameters
    ----------
    x_train : array-like
        Training features. Used as gate input when convertible to float32;
        otherwise the flattened base predictions are used instead.
    y_train_or_oh : array-like
        Targets. For ``"binary"``: integer labels (1-D or ``[N, 1]``) or an
        already one-hot ``[N, C]`` array. For ``"regression"``: real values.
    fitted_models : dict
        Mapping name -> fitted estimator; only used when
        ``base_pred_train`` is None.
    feature_dim : int
        Kept for backward compatibility. The gate input width is inferred
        from the tensor actually fed to the gate, so this value is unused.
    num_models : int
        Number of gate outputs.
    MoEGatingNet : callable
        Factory called as ``MoEGatingNet(input_dim, num_models)``; the result
        must be callable on a batch and expose ``trainable_variables``.
    task : str
        ``"binary"`` or ``"regression"``.
    base_pred_train : array-like, optional
        Precomputed base predictions, ``[N, M, C]`` (binary) or ``[N, M]``.
    epochs, lr, l2, batch_size : training hyper-parameters; ``l2`` scales
        ``tf.nn.l2_loss`` summed over all trainable variables.
    verbose : bool
        Print the sample-weighted mean loss after every epoch.
    gate_features : array-like, optional
        Explicit gate inputs overriding ``x_train``.

    Returns
    -------
    The trained gating network.

    Raises
    ------
    ValueError
        For an unknown ``task`` or malformed binary base predictions.
    """
    # Validate up front so an unknown task never falls through to regression
    # when base_pred_train is supplied.
    if task not in ("binary", "regression"):
        raise ValueError("task must be 'binary' or 'regression'")
    is_binary = task == "binary"

    # 1) Base predictions
    if base_pred_train is None:
        models = list(fitted_models.values())
        if is_binary:
            preds = [m.predict_proba(x_train) for m in models]                 # -> [N, M, C]
        else:
            preds = [np.asarray(m.predict(x_train)).ravel() for m in models]   # -> [N, M]
        base_pred_train = np.stack(preds, axis=1).astype(np.float32)
    else:
        base_pred_train = np.asarray(base_pred_train, dtype=np.float32)

    # 2) Gate inputs
    if gate_features is not None:
        x_tf = tf.convert_to_tensor(np.asarray(gate_features, dtype=np.float32))
    else:
        try:
            x_tf = tf.convert_to_tensor(x_train, dtype=tf.float32)
        except Exception:
            # Non-numeric features: gate on the flattened base predictions.
            bp_flat = base_pred_train.reshape(len(base_pred_train), -1)
            x_tf = tf.convert_to_tensor(bp_flat, dtype=tf.float32)

    use_feat_dim = int(x_tf.shape[-1])
    bp_tf = tf.convert_to_tensor(base_pred_train, dtype=tf.float32)

    net = MoEGatingNet(use_feat_dim, num_models)
    opt = tf.keras.optimizers.Adam(learning_rate=lr)

    N = int(x_tf.shape[0])

    # 3) Task-specific targets, mixture rule and loss
    if is_binary:
        C = int(bp_tf.shape[-1])
        if bp_tf.shape.rank != 3 or C <= 1:
            raise ValueError("For binary/classification, base_pred_train must be [N,M,C] with C>=2.")
        y_arr = y_train_or_oh
        if isinstance(y_arr, np.ndarray) and (y_arr.ndim == 1 or (y_arr.ndim == 2 and y_arr.shape[1] == 1)):
            target = tf.one_hot(tf.convert_to_tensor(y_arr.ravel(), dtype=tf.int32), depth=C)
        else:
            target = tf.convert_to_tensor(y_arr, dtype=tf.float32)
            if target.shape.rank == 1 or (target.shape.rank == 2 and target.shape[-1] == 1):
                target = tf.one_hot(tf.cast(tf.reshape(target, [-1]), tf.int32), depth=C)

        loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)

        def mix_fn(w, pb):
            # Clip so the cross-entropy never sees log(0).
            return tf.clip_by_value(tf.einsum('bm,bmc->bc', w, pb), 1e-7, 1.0)

        tag, digits = "binary", 4
    else:
        target = tf.convert_to_tensor(np.asarray(y_train_or_oh).ravel(), dtype=tf.float32)
        loss_fn = tf.keras.losses.MeanSquaredError()

        def mix_fn(w, pb):
            return tf.reduce_sum(w * pb, axis=1)

        tag, digits = "reg", 6

    ds = tf.data.Dataset.from_tensor_slices((x_tf, target, bp_tf)).shuffle(N).batch(batch_size)

    # 4) Training loop (shared by both tasks)
    for ep in range(1, epochs + 1):
        epoch_loss_sum = 0.0
        epoch_count = 0
        for xb, yb, pb in ds:
            with tf.GradientTape() as tape:
                w = tf.nn.softmax(net(xb), axis=1)
                loss = loss_fn(yb, mix_fn(w, pb))
                if l2 > 0:
                    loss += l2 * tf.add_n([tf.nn.l2_loss(v) for v in net.trainable_variables])
            grads = tape.gradient(loss, net.trainable_variables)
            opt.apply_gradients(zip(grads, net.trainable_variables))
            bsz = int(xb.shape[0])
            epoch_loss_sum += float(loss) * bsz
            epoch_count += bsz
        if verbose:
            print(f"[MoE-gate/{tag}] epoch {ep}/{epochs} loss={epoch_loss_sum/max(epoch_count,1):.{digits}f}")

    return net

def moe_gating_predict(net, x, base_pred, *, task="binary", return_weights=False):
    """Mix base predictions with the softmax weights produced by ``net``.

    ``x`` is fed to the gate when it converts to a float32 tensor; otherwise
    the flattened base predictions are used as gate input. Returns the mixed
    prediction as a numpy array, plus the gate weights when
    ``return_weights`` is true. Raises ``ValueError`` for an unknown task.
    """
    base_np = np.asarray(base_pred, dtype=np.float32)

    try:
        gate_in = tf.convert_to_tensor(x, dtype=tf.float32)
    except Exception:
        flattened = base_np.reshape(len(base_np), -1)
        gate_in = tf.convert_to_tensor(flattened, dtype=tf.float32)

    base_tf = tf.convert_to_tensor(base_np)
    gate_w = tf.nn.softmax(net(gate_in), axis=1)

    if task == "binary":
        blended = tf.einsum('nm,nmc->nc', gate_w, base_tf)
    elif task == "regression":
        blended = tf.reduce_sum(gate_w * base_tf, axis=1)
    else:
        raise ValueError("task must be 'binary' or 'regression'")

    if not return_weights:
        return blended.numpy()
    return blended.numpy(), gate_w.numpy()

def moe_prediction(
    x_test, gate_net, fitted_models, *,
    task="binary",
    base_pred_test=None,
    gate_features=None,
    return_weights=False
):
    """Predict with a trained gate by mixing the base models' test outputs.

    Base predictions are computed from ``fitted_models`` unless
    ``base_pred_test`` is given. Any task other than ``"binary"`` is handled
    as regression. Returns the mixture, and also the ``[T, M]`` gate weights
    when ``return_weights`` is true.
    """
    binary = task == "binary"
    experts = fitted_models.values() if base_pred_test is None else None

    if base_pred_test is not None:
        base_pred_test = np.asarray(base_pred_test, dtype=np.float32)
    elif binary:
        probas = [mdl.predict_proba(x_test) for mdl in experts]
        base_pred_test = np.stack(probas, axis=1).astype(np.float32)         # [T,M,C]
    else:
        columns = [np.asarray(mdl.predict(x_test)).ravel() for mdl in experts]
        base_pred_test = np.column_stack(columns).astype(np.float32)         # [T,M]

    # Gate inputs: explicit features, else x_test, else flattened base preds.
    if gate_features is None:
        try:
            gate_in = np.asarray(x_test, dtype=np.float32)
        except Exception:
            n_rows = base_pred_test.shape[0]
            gate_in = base_pred_test.reshape(n_rows, -1).astype(np.float32)
    else:
        gate_in = np.asarray(gate_features, dtype=np.float32)

    gate_logits = gate_net(gate_in, training=False)
    W = tf.nn.softmax(gate_logits, axis=1).numpy()  # [T,M]

    subscripts = 'tm,tmc->tc' if binary else 'tm,tm->t'
    mix = np.einsum(subscripts, W, base_pred_test)

    if return_weights:
        return mix, W
    return mix

# -----------------------
# Dynamic Local Accuracy
# -----------------------

def compute_cv_pred_probas(models, X, y, cv=5, n_jobs=None):
    """Return out-of-fold class probabilities for every model.

    Each model is re-instantiated per fold, fitted on the training part of a
    shuffled ``StratifiedKFold`` split and asked for probabilities on the
    validation part (``decision_function`` is converted via sigmoid/softmax
    when ``predict_proba`` fails). The result has shape ``[N, M, C]`` and
    dtype float32. ``n_jobs`` is accepted but not used.
    """
    import copy
    import numpy as np
    from sklearn.model_selection import StratifiedKFold

    def _take(data, rows):
        # Positional row selection for pandas objects and arrays alike.
        if hasattr(data, "iloc"):
            return data.iloc[rows]
        return data[rows]

    def _scores_to_proba(scores):
        # 1-D scores -> two-column sigmoid; 2-D scores -> row-wise softmax.
        scores = np.asarray(scores)
        if scores.ndim == 1:
            pos = 1 / (1 + np.exp(-scores))
            return np.c_[(1 - pos), pos].astype(np.float32)
        shifted = np.exp(scores - scores.max(axis=1, keepdims=True))
        normalised = shifted / shifted.sum(axis=1, keepdims=True)
        return normalised.astype(np.float32)

    def _fresh_estimator(est):
        # Attempt 1: rebuild from the estimator's own (shallow) parameters.
        try:
            if hasattr(est, "get_params"):
                return est.__class__(**est.get_params(deep=False))
        except Exception:
            pass
        # Attempt 2: default-construct the same class.
        try:
            return est.__class__()
        except Exception:
            pass
        # Attempt 3: deep copy, resetting it if the object knows how.
        try:
            duplicate = copy.deepcopy(est)
            if hasattr(duplicate, "reset"):
                duplicate.reset()
            elif hasattr(duplicate, "initialize"):
                duplicate.initialize()
            return duplicate
        except Exception:
            # Last resort: reuse the instance that was passed in.
            return est

    labels = np.asarray(y).ravel()
    n_rows = X.shape[0]
    names = list(models.keys())
    n_models = len(names)

    # Probe the number of probability columns by fitting on at most 8 rows.
    n_classes = None
    probe_rows = np.arange(min(n_rows, 8))
    for name in names:
        try:
            probe = _fresh_estimator(models[name])
            probe.fit(_take(X, probe_rows), labels[probe_rows])
            probe_out = probe.predict_proba(_take(X, probe_rows[:1]))
            n_classes = probe_out.shape[-1]
            break
        except Exception:
            continue
    if n_classes is None:
        n_classes = len(np.unique(labels))

    out = np.zeros((n_rows, n_models, n_classes), dtype=np.float32)
    splitter = StratifiedKFold(n_splits=cv, shuffle=True)

    for col, name in enumerate(names):
        template = models[name]
        oof = np.zeros((n_rows, n_classes), dtype=np.float32)

        for fit_rows, held_rows in splitter.split(X, labels):
            est = _fresh_estimator(template)
            est.fit(_take(X, fit_rows), labels[fit_rows])
            held_X = _take(X, held_rows)
            try:
                proba = est.predict_proba(held_X)
            except Exception:
                proba = _scores_to_proba(est.decision_function(held_X))

            if proba.shape[1] != n_classes:
                # Scatter the columns the fold did see into a full-width array.
                widened = np.zeros((proba.shape[0], n_classes), dtype=np.float32)
                seen = getattr(est, "classes_", np.arange(n_classes))
                for src, cls in enumerate(seen):
                    if int(cls) < n_classes:
                        widened[:, int(cls)] = proba[:, src]
                proba = widened

            oof[held_rows] = proba.astype(np.float32)

        out[:, col, :] = oof

    return out
 

def compute_oos_pred_values(
    models, X, y, *,
    mode="cv",              # "cv" or "holdout"
    cv=5,
    holdout_size=0.2,
    components_cap=320,
    mlp_context=None,       # expects {"mlp_params": {...}, "RSEED": int (optional)}
    pre_builder=None,
    sanitizer=None,
    infer_keep=None,
    svd_params=None,
):
    """Return out-of-sample regression predictions for every model.

    Parameters
    ----------
    models : dict
        Mapping name -> estimator or wrapper. Objects with ``fit`` are cloned
        and refitted per split. Objects without ``fit`` are replaced by an
        ``MLPRegressor`` trained on preprocessed features (requires
        ``mlp_context``), or, when ``mlp_context`` is None, their cached
        ``yhat_train`` values are copied for the validation rows.
    X, y : array-like or pandas objects
        Features and targets; rows are selected positionally.
    mode : {"cv", "holdout"}
        ``"cv"`` uses shuffled ``KFold``; ``"holdout"`` uses a single
        ``ShuffleSplit``, leaving non-validation rows as NaN.
    cv : int
        Number of folds for ``mode="cv"``.
    holdout_size : float
        Validation fraction for ``mode="holdout"``.
    components_cap, svd_params
        Accepted for interface compatibility; not used by this function.
    mlp_context : dict, optional
        ``{"mlp_params": {...}, "RSEED": int}``. ``RSEED`` (when present)
        seeds the splitter and, unless ``mlp_params`` already sets
        ``random_state``, the fallback MLP.
    pre_builder : callable, optional
        ``pre_builder(X_train)`` -> unfitted transformer for the MLP path.
    sanitizer : callable, optional
        Cleans the transformed matrix (default: non-finite values -> 0).
    infer_keep : callable, optional
        Returns the numeric columns for the default ``pre_builder``
        (default: every column except ``"compound"``).

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``[N, M]``; NaN where no prediction was made.

    Raises
    ------
    ValueError
        For an unknown ``mode`` or a model without ``fit`` when neither
        ``mlp_context`` nor ``yhat_train`` is available.
    """

    # ---------- built-in fallbacks ----------
    def _default_infer_keep(X_df):
        return [c for c in X_df.columns if c != "compound"]

    def _default_sanitizer(A):
        if sparse.issparse(A):
            A = A.tocsr(copy=True)
            d = A.data
            # NaN and +/-inf are all non-finite; zero them in one pass.
            d[~np.isfinite(d)] = 0.0
            return A
        return np.nan_to_num(A, nan=0.0, posinf=0.0, neginf=0.0)

    def _default_pre_builder(X_df, compound_col="compound"):
        ohe_kwargs = {"handle_unknown": "ignore"}
        # OneHotEncoder renamed `sparse` to `sparse_output` in scikit-learn 1.2.
        if Version(sklearn.__version__) >= Version("1.2"):
            ohe_kwargs["sparse_output"] = True
        else:
            ohe_kwargs["sparse"] = True
        ohe = OneHotEncoder(**ohe_kwargs)

        keep = infer_keep(X_df)
        return ColumnTransformer(
            transformers=[
                ("num", Pipeline([
                    ("impute", SimpleImputer(strategy="median")),
                    ("scale",  StandardScaler(with_mean=False)),  # safe on sparse
                ]), list(keep)),
                ("drug", ohe, [compound_col]),
            ],
            sparse_threshold=1.0
        )

    infer_keep = infer_keep or _default_infer_keep
    sanitizer = sanitizer or _default_sanitizer
    pre_builder = pre_builder or _default_pre_builder

    # ---------- utilities ----------
    def _row_subset(A, idx):
        try:
            return A.iloc[idx]
        except AttributeError:
            return A[idx]

    def _to_numpy(y_like):
        try:
            return y_like.to_numpy().ravel()
        except AttributeError:
            return np.asarray(y_like).ravel()

    # ---------- splits ----------
    y_flat = _to_numpy(y)
    N = len(y_flat)
    names = list(models.keys())
    M = len(names)
    out = np.full((N, M), np.nan, dtype=np.float32)

    # RSEED makes the splits (and the fallback MLP) reproducible; None keeps
    # the previous unseeded behaviour.
    seed = (mlp_context or {}).get("RSEED")

    if mode == "cv":
        splitter = KFold(n_splits=cv, shuffle=True, random_state=seed)
        splits = list(splitter.split(np.arange(N), y_flat))
    elif mode == "holdout":
        splitter = ShuffleSplit(n_splits=1, test_size=holdout_size, random_state=seed)
        splits = list(splitter.split(np.arange(N)))
    else:
        raise ValueError("mode must be 'cv' or 'holdout'")

    mlp_params = dict((mlp_context or {}).get("mlp_params", {}))
    if seed is not None:
        # An explicit random_state in mlp_params takes precedence.
        mlp_params.setdefault("random_state", seed)

    for m_idx, name in enumerate(names):
        mdl_orig = models[name]
        oos = np.full((N,), np.nan, dtype=np.float32)
        use_mlp_path = not hasattr(mdl_orig, "fit")

        if use_mlp_path and mlp_context is None:
            cached = getattr(mdl_orig, "yhat_train", None)
            if cached is None:
                raise ValueError(
                    f"Model '{name}' has no .fit; provide mlp_context with mlp_params and RSEED, "
                    "or supply yhat_train on the wrapper for a fast fallback."
                )
            # Fast fallback: reuse the cached training predictions for the
            # validation rows (these are not truly out-of-sample).
            cached = np.asarray(cached).ravel()
            for _tr_idx, va_idx in splits:
                oos[va_idx] = cached[va_idx]
            out[:, m_idx] = oos
            continue

        for tr_idx, va_idx in splits:
            X_tr = _row_subset(X, tr_idx)
            y_tr = _to_numpy(_row_subset(y, tr_idx))
            X_va = _row_subset(X, va_idx)

            if not use_mlp_path:
                try:
                    mdl = clone(mdl_orig)
                except Exception:
                    warnings.warn(f"[compute_oos_pred_values] Could not clone '{name}', using original instance.")
                    mdl = mdl_orig
                mdl.fit(X_tr, y_tr)
                oos[va_idx] = np.asarray(mdl.predict(X_va)).ravel().astype(np.float32)
                continue

            # MLP path: preprocess -> sanitise -> densify -> standardise -> fit.
            pre = pre_builder(X_tr)
            Xtr_feat = sanitizer(pre.fit_transform(X_tr))
            Xva_feat = sanitizer(pre.transform(X_va))

            if sparse.issparse(Xtr_feat):
                Xtr_dense = Xtr_feat.toarray()
                Xva_dense = Xva_feat.toarray()
            else:
                Xtr_dense = np.asarray(Xtr_feat)
                Xva_dense = np.asarray(Xva_feat)

            sc = StandardScaler(with_mean=True)
            Xtr_pred = sc.fit_transform(Xtr_dense)
            Xva_pred = sc.transform(Xva_dense)

            mlp = MLPRegressor(**mlp_params)
            mlp.fit(Xtr_pred, y_tr.astype(np.float32))
            oos[va_idx] = mlp.predict(Xva_pred).astype(np.float32)

        out[:, m_idx] = oos

    return out


def _ensure_numeric_feats(X, fitted_models, *, task="binary"):
    """Return a float32 feature matrix for ``X``.

    ``X`` is returned as a float32 array when it converts cleanly. Otherwise
    (e.g. it holds non-numeric columns) the base models' predictions on ``X``
    are stacked and flattened to ``[N, -1]`` and used as features instead:
    ``predict_proba`` outputs for ``task="binary"``, ``predict`` outputs for
    any other task.
    """
    try:
        return np.asarray(X, dtype=np.float32)
    except Exception:
        # Deliberate best-effort fallback: describe rows by model outputs.
        if task == "binary":
            bp = np.stack(
                [m.predict_proba(X) for m in fitted_models.values()], axis=1
            ).astype(np.float32)
        else:
            bp = np.stack(
                [np.asarray(m.predict(X)).ravel() for m in fitted_models.values()], axis=1
            ).astype(np.float32)
        return bp.reshape(bp.shape[0], -1)
        
def dynamic_local_accuracy(
    x_train, y_train, x_test,
    base_pred_train_cv,
    fitted_models,
    k=50, temp=1.0, alpha=1.0,
    *,
    task="binary",
    X_train_feat=None,
    X_test_feat=None,
    base_pred_train_oos=None,
    base_pred_test=None,
    return_weights=False
):
    """Combine base models with weights from their local (k-NN) performance.

    For each test point, the ``k`` nearest training points are found and each
    model is scored on them: smoothed accuracy of its cross-validated labels
    (``"binary"``) or smoothed mean squared error of its out-of-sample
    predictions (``"regression"``). A temperature softmax of the scores
    yields the mixing weights.

    Parameters
    ----------
    x_train, y_train, x_test : array-like
        Raw data. In binary mode ``x_test`` is passed to every model's
        ``predict_proba``; the raw features also serve as neighbour features
        unless ``X_train_feat`` / ``X_test_feat`` are given.
    base_pred_train_cv : array-like
        Binary only: cross-validated probabilities ``[N, M, C]``. If the class
        dimension differs from the test predictions it is expanded or
        truncated to match.
    fitted_models : dict
        Mapping name -> fitted estimator (binary mode).
    k : int
        Number of neighbours (capped at ``N``).
    temp : float
        Softmax temperature (floored at 1e-6).
    alpha : float
        Smoothing strength applied to the local score.
    task : {"binary", "regression"}
    X_train_feat, X_test_feat : array-like, optional
        Numeric neighbour features; required for regression.
    base_pred_train_oos, base_pred_test : array-like
        Regression only: ``[N, M]`` out-of-sample and ``[T, M]`` test
        predictions. NaN entries in the former are ignored.
    return_weights : bool
        Also return the ``[T, M]`` weight matrix.

    Returns
    -------
    numpy.ndarray or tuple
        ``[T, C]`` (binary) or ``[T]`` (regression) predictions, optionally
        followed by the weights.

    Raises
    ------
    ValueError
        For an unknown task, a malformed ``base_pred_train_cv``, or missing
        regression inputs.
    """
    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    if task == "binary":
        if X_train_feat is None:
            Xtr_np = _ensure_numeric_feats(x_train, fitted_models, task="binary")
        else:
            Xtr_np = np.asarray(X_train_feat, dtype=np.float32)

        if X_test_feat is None:
            Xte_np = _ensure_numeric_feats(x_test, fitted_models, task="binary")
        else:
            Xte_np = np.asarray(X_test_feat, dtype=np.float32)

        base_pred_test_bin = np.stack(
            [m.predict_proba(x_test) for m in fitted_models.values()],
            axis=1
        ).astype(np.float32)

        bpt = np.asarray(base_pred_train_cv)
        if bpt.ndim != 3:
            raise ValueError(f"base_pred_train_cv must be [N,M,C], got {bpt.shape}")
        N, M, C_cv = bpt.shape
        C_test = base_pred_test_bin.shape[-1]
        if C_cv != C_test:
            if C_test == 2 and C_cv == 1:
                # A single column is read as P(class 1); rebuild both columns.
                p = np.clip(bpt[..., 0], 0.0, 1.0)
                bpt2 = np.empty((N, M, 2), dtype=bpt.dtype)
                bpt2[..., 1] = p
                bpt2[..., 0] = 1.0 - p
            else:
                # Zero-pad or truncate the class axis to the test width.
                bpt2 = np.zeros((N, M, C_test), dtype=bpt.dtype)
                kcopy = min(C_cv, C_test)
                bpt2[..., :kcopy] = bpt[..., :kcopy]
            bpt = bpt2
        base_pred_train_cv = bpt

        y_tr = np.asarray(y_train).ravel().astype(int)
        T = Xte_np.shape[0]
        C = base_pred_test_bin.shape[-1]

        try:
            nn = NearestNeighbors(n_neighbors=min(k, N), algorithm="auto").fit(Xtr_np)

            def knn_indices(i, _k):
                return nn.kneighbors(Xte_np[i:i+1], n_neighbors=min(_k, N), return_distance=False)[0]
        except Exception:
            # Brute-force squared-Euclidean fallback if the index cannot be built.
            def knn_indices(i, _k):
                d = np.sum((Xtr_np - Xte_np[i])**2, axis=1)
                return np.argsort(d)[:min(_k, N)]

        cv_pred_labels = np.argmax(base_pred_train_cv, axis=2)  # [N,M]
        res = np.zeros((T, C), dtype=np.float32)
        ws = np.zeros((T, M), dtype=np.float64)
        for i in range(T):
            idx = knn_indices(i, k)
            correct = (cv_pred_labels[idx] == y_tr[idx, None]).mean(axis=0)   # [M]
            smoothed = (correct * k + alpha) / (k + 2 * alpha)                # light smoothing
            w = np.exp(smoothed / max(temp, 1e-6))
            w = w / w.sum()
            ws[i] = w
            res[i] = (w[None, :, None] * base_pred_test_bin[i:i+1]).sum(axis=1).squeeze(0)
        # Honour return_weights here too, matching the regression branch.
        if return_weights:
            return res, ws
        return res

    elif task == "regression":
        if X_train_feat is None or X_test_feat is None or base_pred_train_oos is None or base_pred_test is None:
            raise ValueError("For task='regression', please provide X_train_feat, X_test_feat, base_pred_train_oos, and base_pred_test.")
        base_pred_train_oos = np.asarray(base_pred_train_oos)
        base_pred_test = np.asarray(base_pred_test)
        y_tr = np.asarray(y_train).ravel().astype(np.float32)
        N, M = base_pred_train_oos.shape
        T = X_test_feat.shape[0]

        nn = NearestNeighbors(n_neighbors=min(k, N), algorithm="auto").fit(X_train_feat)
        res = np.zeros((T,), dtype=np.float32)
        ws = []

        for i in range(T):
            idx = nn.kneighbors(X_test_feat[i:i+1], return_distance=False)[0]
            y_loc = y_tr[idx]
            preds_loc = base_pred_train_oos[idx, :]
            mask = ~np.isnan(preds_loc)

            local_mse = np.empty((M,), dtype=np.float32)
            for m in range(M):
                m_mask = mask[:, m]
                if m_mask.any():
                    err = y_loc[m_mask] - preds_loc[m_mask, m]
                    mse = float(np.mean(err**2))
                else:
                    # No usable neighbour prediction: fall back to the global
                    # MSE of this model, or +inf if it has no predictions.
                    g_mask = ~np.isnan(base_pred_train_oos[:, m])
                    mse = float(np.mean((y_tr[g_mask] - base_pred_train_oos[g_mask, m])**2)) if g_mask.any() else np.inf
                local_mse[m] = (mse * k + alpha) / (k + alpha)

            scores = -local_mse / max(temp, 1e-6)
            scores -= np.nanmax(scores)
            w = np.exp(scores)
            s = w.sum()
            # Uniform weights when the softmax degenerates (e.g. all scores -inf).
            w = w / s if np.isfinite(s) and s > 0 else np.ones_like(w) / M
            ws.append(w)

            res[i] = float(np.dot(w, base_pred_test[i, :]))
        if return_weights:
            return res, np.array(ws)
        return res
    else:
        raise ValueError("task must be 'binary' or 'regression'")


# ---------------------------------
# Synthetic Model Combination (SMC)
# ---------------------------------
# Adapted from the official repository: https://github.com/XanderJC/synthetic-model-combination

def _auto_bandwidth(X):
    """Return half the mean per-column sample standard deviation of ``X``.

    Objects exposing ``toarray`` (e.g. scipy sparse matrices) are densified
    first; a 1e-12 offset is added to every column's deviation.
    """
    dense = X.toarray() if hasattr(X, "toarray") else X
    column_spread = np.std(dense, axis=0, ddof=1) + 1e-12
    mean_spread = float(np.mean(column_spread))
    return 0.5 * mean_spread

def _fit_kdes_per_model(x_train, mask_list, bandwidth=None, min_cov=20, bw_scale=1.0):
    """Fit one density estimate per coverage mask.

    Parameters
    ----------
    x_train : array-like or sparse matrix
        Feature matrix ``[N, D]``; densified if it exposes ``toarray``.
    mask_list : sequence of boolean arrays
        One length-``N`` mask per model selecting the rows it "covers".
    bandwidth : float, optional
        KDE bandwidth; defaults to half the mean per-column sample standard
        deviation of ``x_train``.
    min_cov : int
        Masks selecting fewer rows than this use a Gaussian instead of a KDE.
    bw_scale : float
        Multiplier applied to the bandwidth (result floored at 1e-3).

    Returns
    -------
    list of tuples
        One entry per mask, tagged by its first element:
        ``("sk_kde", fitted KernelDensity)``,
        ``("gaussian_fallback", mean, covariance)`` with the covariance always
        2-D and ridge-regularised by 1e-3, or ``("const_low", -1e6)`` for
        empty masks / non-finite means.
    """
    # Densify if sparse.
    X = x_train.toarray() if hasattr(x_train, "toarray") else np.asarray(x_train)

    if bandwidth is None:
        std = np.std(X, axis=0, ddof=1) + 1e-12
        bandwidth = 0.5 * float(np.mean(std))
    bw = bandwidth * float(bw_scale)
    # Floor the bandwidth; also catches NaN (e.g. from a single-row X), which
    # a plain max() would let through.
    if not np.isfinite(bw) or bw < 1e-3:
        bw = 1e-3

    kdes = []
    for mask in mask_list:
        idx = np.where(mask)[0]
        n = len(idx)
        if n == 0:
            kdes.append(("const_low", -1e6))
            continue

        X_m = X[idx]
        if n < min_cov or KernelDensity is None:
            mu = np.nanmean(X_m, axis=0)
            if not np.all(np.isfinite(mu)):
                kdes.append(("const_low", -1e6))
                continue
            # np.cov yields a 0-d array for single-feature data; force 2-D so
            # the ridge term and later matrix inversion work for any D.
            cov = np.atleast_2d(np.cov(X_m.T)) if n > 1 else np.eye(X.shape[1])
            cov = np.asarray(cov) + 1e-3 * np.eye(cov.shape[0])
            kdes.append(("gaussian_fallback", mu, cov))
        else:
            kde = KernelDensity(kernel="gaussian", bandwidth=bw).fit(X_m)
            kdes.append(("sk_kde", kde))
    return kdes


def _score_kdes(kdes, X):
    """Score every row of ``X`` under each fitted density entry.

    Returns a float32 array ``[T, len(kdes)]``. ``"sk_kde"`` entries use
    ``score_samples``; ``"gaussian_fallback"`` entries use minus half the
    squared Mahalanobis distance (no normalising constant); everything else,
    including Gaussians with non-finite parameters, scores -1e6.
    """
    X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)

    n_rows = X.shape[0]
    out = np.zeros((n_rows, len(kdes)), dtype=np.float32)

    for col, entry in enumerate(kdes):
        kind = entry[0]

        if kind == "sk_kde":
            out[:, col] = entry[1].score_samples(X)  # log-density
            continue

        if kind != "gaussian_fallback":
            out[:, col] = -1e6  # "const_low" and any unknown tag
            continue

        _kind, mu, cov = entry
        usable = np.all(np.isfinite(mu)) and np.all(np.isfinite(cov))
        if not usable:
            out[:, col] = -1e6
            continue

        try:
            precision = np.linalg.inv(cov)
        except np.linalg.LinAlgError:
            precision = np.linalg.pinv(cov)
        centred = X - mu
        out[:, col] = -0.5 * np.sum(centred @ precision * centred, axis=1)

    return out
    

def smc_ensemble(
    x_train, y_train, x_test,
    base_pred_train_cv,
    fitted_models,
    conf_threshold=0.6,
    bandwidth=None,
    temp=1.0,
    *,
    task="binary",
    # --- extra args used only for regression ---
    X_train_feat=None,
    X_test_feat=None,
    base_pred_train_oos=None,
    base_pred_test=None,
    cover_q=0.30,
    return_weights=False
):
    """Synthetic-Model-Combination style ensemble.

    Each model gets a density estimate over the training points it "covers";
    a test point's mixing weights are the temperature softmax of its
    log-score under every model's density.

    * ``"binary"``: a point is covered when the model's cross-validated
      label is correct with confidence >= ``conf_threshold``. Densities come
      from ``_fit_kdes_per_model`` on ``X_train_feat`` (built here with an
      impute/scale/one-hot encoder when not supplied).
    * ``"regression"``: a point is covered when the model's absolute
      out-of-sample residual is within its ``cover_q`` quantile (at least 50
      points are kept when available). Densities are Gaussians whose
      covariance is shrunk towards the global per-feature variance.

    Parameters
    ----------
    x_train, y_train, x_test : array-like
        Raw data; ``x_test`` is passed to ``predict_proba`` in binary mode.
    base_pred_train_cv : array-like
        Binary only: cross-validated probabilities ``[N, M, C]``.
    fitted_models : dict
        Mapping name -> fitted classifier (binary mode).
    conf_threshold : float
        Minimum confidence for coverage (binary).
    bandwidth : float, optional
        KDE bandwidth forwarded to ``_fit_kdes_per_model`` (binary).
    temp : float
        Softmax temperature, floored at 1e-6.
    task : {"binary", "regression"}
    X_train_feat, X_test_feat : array-like, optional
        Numeric density features; required for regression.
    base_pred_train_oos, base_pred_test : array-like
        Regression only: ``[N, M]`` out-of-sample and ``[T, M]`` test
        predictions.
    cover_q : float
        Residual quantile defining coverage (regression).
    return_weights : bool
        Also return the ``[T, M]`` weights.

    Returns
    -------
    numpy.ndarray or tuple
        ``[T, C]`` (binary) or ``[T]`` float32 (regression) predictions,
        optionally followed by the weights. For an unknown ``task`` the
        function returns ``None``.

    Raises
    ------
    ValueError
        When required regression inputs are missing.
    """
    if task == "binary":
        base_pred_test = np.stack([m.predict_proba(x_test) for m in fitted_models.values()], axis=1)

        # Flatten so an (N, 1) label column cannot broadcast to (N, N) below.
        y_tr = (y_train.to_numpy() if hasattr(y_train, "to_numpy") else np.asarray(y_train))
        y_tr = y_tr.ravel().astype(int)

        base_pred_train_cv = np.asarray(base_pred_train_cv)
        N, M, C = base_pred_train_cv.shape

        cv_pred_labels = np.argmax(base_pred_train_cv, axis=2)
        cv_pred_conf = np.max(base_pred_train_cv, axis=2)
        covered_masks = [(cv_pred_labels[:, m] == y_tr) & (cv_pred_conf[:, m] >= conf_threshold)
                         for m in range(M)]

        if X_train_feat is None or X_test_feat is None:
            import sklearn
            from packaging import version
            from sklearn.compose import ColumnTransformer
            from sklearn.pipeline import Pipeline
            from sklearn.preprocessing import OneHotEncoder, FunctionTransformer, StandardScaler
            from sklearn.impute import SimpleImputer
            import pandas as pd

            def _to_str_preserve_nan(X):
                # Stringify categories but leave missing values as they are.
                arr = np.asarray(X, dtype=object)
                mask = pd.isna(arr)
                arr[~mask] = arr[~mask].astype(str)
                return arr

            ohe_params = dict(handle_unknown="ignore")
            if version.parse(sklearn.__version__) >= version.parse("1.2"):
                ohe_params["sparse_output"] = True
            else:
                ohe_params["sparse"] = True
            if version.parse(sklearn.__version__) >= version.parse("1.1"):
                ohe_params["min_frequency"] = 10

            if hasattr(x_train, "select_dtypes"):
                cat_cols = x_train.select_dtypes(include=["object", "category"]).columns.tolist()
                num_cols = x_train.columns.difference(cat_cols).tolist()
            else:
                cat_cols, num_cols = [], list(range(x_train.shape[1]))

            cat_pipe = Pipeline([
                ("imp", SimpleImputer(strategy="most_frequent")),
                ("to_str", FunctionTransformer(_to_str_preserve_nan, feature_names_out="one-to-one")),
                ("ohe", OneHotEncoder(**ohe_params))
            ])
            num_pipe = Pipeline([
                ("imp", SimpleImputer(strategy="median")),
                ("scaler", StandardScaler(with_mean=False))  # sparse-friendly
            ])

            enc = ColumnTransformer(
                [("num", num_pipe, num_cols),
                 ("cat", cat_pipe, cat_cols)],
                remainder="drop"
            )

            X_train_feat = enc.fit_transform(x_train)  # can be sparse
            X_test_feat = enc.transform(x_test)

        kdes = _fit_kdes_per_model(X_train_feat, covered_masks, bandwidth=bandwidth)
        log_scores = _score_kdes(kdes, X_test_feat)

        # Shift by the row maximum before exponentiating: log-densities are
        # often so negative that exp() of the raw values underflows to 0 for
        # every model, which would zero out the whole prediction.
        scaled = log_scores / max(temp, 1e-6)
        scaled = scaled - np.max(scaled, axis=1, keepdims=True)
        w = np.exp(scaled)
        w = w / np.clip(w.sum(axis=1, keepdims=True), 1e-12, None)

        res = (w[:, :, None] * base_pred_test).sum(axis=1)  # [T, C]
        return (res, w) if return_weights else res

    elif task == "regression":
        if X_train_feat is None or X_test_feat is None or base_pred_train_oos is None or base_pred_test is None:
            raise ValueError("For task='regression', provide X_train_feat, X_test_feat, base_pred_train_oos, base_pred_test.")

        y_tr = np.asarray(y_train).ravel().astype(np.float32)
        N, M = base_pred_train_oos.shape
        T = X_test_feat.shape[0]

        cover_q = float(cover_q)
        min_cov = 50
        temp = float(max(temp, 1e-6))

        covered_masks = []
        for m in range(M):
            yhat_m = base_pred_train_oos[:, m]
            mask_m = ~np.isnan(yhat_m)
            if mask_m.sum() < 5:
                covered_masks.append(np.zeros(N, dtype=bool))
                continue
            valid_rows = np.where(mask_m)[0]
            resid = np.abs(y_tr[mask_m] - yhat_m[mask_m])
            thr = np.quantile(resid, cover_q)
            kept_idx = valid_rows[resid <= thr]
            if kept_idx.size < min_cov:
                # Too few covered points: keep the min_cov best-fit rows.
                order = np.argsort(resid)[:min_cov]
                kept_idx = valid_rows[order]
            cov_mask = np.zeros(N, dtype=bool)
            cov_mask[kept_idx] = True
            covered_masks.append(cov_mask)

        Xtr = np.asarray(X_train_feat, dtype=np.float32)
        Xte = np.asarray(X_test_feat, dtype=np.float32)

        lam = 0.9  # shrinkage towards the global diagonal covariance
        global_var = np.var(Xtr, axis=0, ddof=1) + 1e-6
        D = Xtr.shape[1]

        log_scores = np.full((T, M), -1e6, dtype=np.float32)
        for m in range(M):
            idx = np.where(covered_masks[m])[0]
            if idx.size < 3:
                continue
            Xm = Xtr[idx]
            mu = np.mean(Xm, axis=0)
            if idx.size > D:
                emp = np.cov(Xm.T)
            else:
                # Not enough points for a full covariance: diagonal only.
                emp = np.diag(np.var(Xm, axis=0, ddof=1) + 1e-6)
            cov = (1.0 - lam) * emp + lam * np.diag(global_var)
            try:
                inv = np.linalg.inv(cov)
            except np.linalg.LinAlgError:
                inv = np.linalg.pinv(cov, rcond=1e-6)
            diff = Xte - mu
            try:
                sign, logdet = np.linalg.slogdet(cov)
                if sign <= 0:
                    logdet = np.log(np.clip(np.linalg.det(cov), 1e-12, None))
            except Exception:
                logdet = np.log(np.clip(np.linalg.det(cov + 1e-6*np.eye(D)), 1e-12, None))
            # Full quadratic form diff^T inv diff. Distinct subscripts are
            # required: 'dd' would take only the diagonal of the precision
            # matrix and drop every cross-feature term.
            quad = np.einsum('nd,de,ne->n', diff, inv, diff)
            log_scores[:, m] = (-0.5 * quad - 0.5 * logdet).astype(np.float32)

        log_scores = log_scores - np.max(log_scores, axis=1, keepdims=True)
        w = np.exp(log_scores / temp)
        w_sum = np.clip(np.sum(w, axis=1, keepdims=True), 1e-12, None)
        w /= w_sum

        # Combine scalar predictions
        res = np.sum(w * base_pred_test, axis=1).astype(np.float32)
        return (res, w) if return_weights else res

# -----------------------------------
# Bayesian Hierarchical Stacking (BHS)
# -----------------------------------

class _BHSGate(tf.keras.Model):
    """Linear softmax gate over base models with a hierarchical Gaussian prior.

    For a feature matrix ``X`` the gate yields per-row mixing weights
    ``softmax((X @ beta + alpha) / temp)``.  :meth:`nlp` returns the negative
    log-prior over ``alpha``, ``beta`` and ``mu_alpha`` that the training
    losses add as a regularizer.

    Trainable variables (created in this order):
        alpha:      per-model intercepts, shape ``(M,)``, initialised to zero.
        beta:       feature-to-model slopes, shape ``(D, M)``, N(0, 0.01^2) init.
        mu_alpha:   scalar prior mean of ``alpha``.
        log_tau_a:  scalar, mapped through softplus to the prior scale of ``alpha``.
        log_tau_b:  scalar, mapped through softplus to the prior scale of ``beta``.
    """

    def __init__(self, feature_dim: int, num_models: int,
                 temp: float = 1.0, s0: float = 5.0):
        """Create the gate variables.

        Args:
            feature_dim: Number of gating features ``D``.
            num_models: Number of base models ``M``.
            temp: Softmax temperature; floored at 1e-6.
            s0: Standard deviation of the zero-mean prior on ``mu_alpha``.
        """
        super().__init__()
        self.D = int(feature_dim)
        self.M = int(num_models)
        # Floor the temperature so the division in weights_logits stays finite.
        self.temp = float(max(temp, 1e-6))
        self.s0 = tf.constant(float(s0), dtype=tf.float32)

        def _scalar(weight_name, init):
            # All hyper-parameters are trainable rank-0 variables.
            return self.add_weight(
                name=weight_name, shape=(),
                initializer=init, trainable=True,
            )

        self.alpha = self.add_weight(
            name="alpha",
            shape=(self.M,),
            initializer="zeros",
            trainable=True,
        )
        self.beta = self.add_weight(
            name="beta",
            shape=(self.D, self.M),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
            trainable=True,
        )
        self.mu_alpha = _scalar("mu_alpha", "zeros")
        self.log_tau_a = _scalar("log_tau_alpha", tf.keras.initializers.Constant(0.0))
        self.log_tau_b = _scalar("log_tau_beta", tf.keras.initializers.Constant(0.0))

    def weights_logits(self, X: tf.Tensor) -> tf.Tensor:
        """Return temperature-scaled gate logits ``(X @ beta + alpha) / temp``."""
        linear = tf.linalg.matmul(X, self.beta)
        shifted = linear + self.alpha[None, :]
        return shifted / self.temp

    def weights_softmax(self, X: tf.Tensor) -> tf.Tensor:
        """Return mixing weights: softmax of the gate logits along axis 1."""
        gate_logits = self.weights_logits(X)
        return tf.nn.softmax(gate_logits, axis=1)

    def nlp(self) -> tf.Tensor:
        """Negative log-prior"""
        # Positive prior scales; the raw variables go through softplus (plus a
        # small floor), not exp, despite their "log_tau" names.
        scale_alpha = tf.nn.softplus(self.log_tau_a) + 1e-6
        scale_beta = tf.nn.softplus(self.log_tau_b) + 1e-6

        # alpha ~ N(mu_alpha, scale_alpha^2): quadratic term + M * log(scale).
        alpha_term = (
            tf.reduce_sum((self.alpha - self.mu_alpha) ** 2) / (2.0 * scale_alpha**2)
            + tf.cast(self.M, tf.float32) * tf.math.log(scale_alpha)
        )
        # beta ~ N(0, scale_beta^2): quadratic term + D*M * log(scale).
        beta_term = (
            tf.reduce_sum(self.beta ** 2) / (2.0 * scale_beta**2)
            + tf.cast(self.D * self.M, tf.float32) * tf.math.log(scale_beta)
        )
        # mu_alpha ~ N(0, s0^2), constant terms dropped.
        mu_term = (self.mu_alpha ** 2) / (2.0 * self.s0 ** 2)
        return alpha_term + beta_term + mu_term


def _bhs_batch_elpd_loss(gate: _BHSGate,
                         Xb: tf.Tensor, yb: tf.Tensor,
                         pb_cv_b: tf.Tensor,
                         prior_weight: float = 1.0) -> tf.Tensor:
    """Mini-batch classification loss: mixture NLL plus scaled negative log-prior.

    Args:
        gate: Gate whose softmax weights mix the base models.
        Xb: Gating features for the batch.
        yb: Class labels for the batch; cast to int32 and one-hot encoded.
        pb_cv_b: Base-model class probabilities, indexed ``[batch, model, class]``
            (the rank is fixed by the einsum below).
        prior_weight: Multiplier on the gate's negative log-prior.

    Returns:
        Scalar tensor: mean negative log of the gated probability of the true
        class, plus ``prior_weight * gate.nlp()`` divided by the batch size.
    """
    num_classes = tf.shape(pb_cv_b)[-1]
    batch_rows = tf.cast(tf.shape(Xb)[0], tf.float32)

    # Pick out each base model's probability for the observed label.
    labels_onehot = tf.one_hot(tf.cast(yb, tf.int32), depth=num_classes, dtype=tf.float32)
    true_class_prob = tf.einsum('bmc,bc->bm', pb_cv_b, labels_onehot)

    # Gate-weighted mixture probability of the true class, per row.
    gate_w = gate.weights_softmax(Xb)
    mixture_prob = tf.reduce_sum(gate_w * true_class_prob, axis=1)

    # Clip before the log so zero-probability rows cannot produce -inf.
    safe_prob = tf.clip_by_value(mixture_prob, 1e-12, 1.0)
    data_term = -tf.reduce_mean(tf.math.log(safe_prob))
    return data_term + prior_weight * gate.nlp() / batch_rows


def _bhs_batch_mse_loss(gate: _BHSGate,
                        Xb: tf.Tensor, yb: tf.Tensor,
                        yhat_oos_b: tf.Tensor,
                        prior_weight: float = 1.0) -> tf.Tensor:
    """Mini-batch regression loss: gated-mixture MSE plus scaled negative log-prior.

    Non-finite base predictions are replaced by the mean of the finite
    predictions in the same row before mixing (a row with no finite entries is
    filled with 0, because the masked sum is 0 and the count is floored at 1e-6).

    Args:
        gate: Gate whose softmax weights mix the base models.
        Xb: Gating features for the batch.
        yb: Regression targets for the batch.
        yhat_oos_b: Base-model predictions for the batch; reduced over axis 1.
        prior_weight: Multiplier on the gate's negative log-prior.

    Returns:
        Scalar tensor: mean squared error of the gated prediction, plus
        ``prior_weight * gate.nlp()`` divided by the batch size.
    """
    finite = tf.math.is_finite(yhat_oos_b)

    # Row-wise mean over the finite entries only.
    finite_count = tf.reduce_sum(tf.cast(finite, tf.float32), axis=1, keepdims=True) + 1e-6
    finite_total = tf.reduce_sum(tf.where(finite, yhat_oos_b, 0.0), axis=1, keepdims=True)
    row_fill = finite_total / finite_count
    preds = tf.where(finite, yhat_oos_b, row_fill)

    gate_w = gate.weights_softmax(Xb)
    blended = tf.reduce_sum(gate_w * preds, axis=1)
    data_term = tf.reduce_mean((blended - yb) ** 2)

    batch_rows = tf.cast(tf.shape(Xb)[0], tf.float32)
    return data_term + prior_weight * gate.nlp() / batch_rows


def bhs_train(
    x_train: _typing.Union[np.ndarray, tf.Tensor],
    y_train: _typing.Union[np.ndarray, tf.Tensor],
    base_pred_train_cv: _typing.Union[np.ndarray, tf.Tensor],
    epochs: int = 300, lr: float = 5e-3, batch_size: int = 128,
    temp: float = 1.0,
    prior_weight: float = 1.0,
    s0: float = 5.0,
    verbose: bool = False,
    seed: int = 42,
    *,
    task: str = "binary"   
) -> dict:
    """Fit a ``_BHSGate`` by mini-batch Adam and return its parameters.

    Args:
        x_train: Gating features, expected 2-D ``[N, D]``.  If they cannot be
            converted to a float32 tensor, the flattened base predictions are
            used as features instead (a ``RuntimeWarning`` is emitted and the
            returned ``"feat_kind"`` is ``"bp_flat"``).
        y_train: Targets; flattened to 1-D.  Cast to int32 for ``"binary"``
            and float32 for ``"regression"``.
        base_pred_train_cv: Base-model predictions for the training rows:
            ``[N, M, C]`` class probabilities for ``"binary"``, ``[N, M]``
            point predictions for ``"regression"``.
        epochs: Number of passes over the training data.
        lr: Adam learning rate.
        batch_size: Mini-batch size (the last batch may be smaller).
        temp: Softmax temperature of the gate.
        prior_weight: Multiplier on the gate's negative log-prior in the loss.
        s0: Prior standard deviation of ``mu_alpha``.
        verbose: Print the last batch loss roughly every 10% of the epochs.
        seed: Passed to ``tf.random.set_seed`` before anything is created.
        task: ``"binary"`` or ``"regression"``.

    Returns:
        dict with keys ``"alpha"``, ``"beta"`` (NumPy arrays), ``"mu_alpha"``,
        ``"log_tau_alpha"``, ``"log_tau_beta"``, ``"temp"`` (floats),
        ``"task"`` and ``"feat_kind"`` (``"x"`` or ``"bp_flat"``).

    Raises:
        ValueError: Unknown ``task``; features that are not 2-D; base
            predictions of the wrong rank for the task; an empty training
            set; or inputs whose row counts disagree.
    """
    # Fail fast on a bad task name instead of after the tensor conversions.
    if task not in ("binary", "regression"):
        raise ValueError("task must be 'binary' or 'regression'")

    tf.random.set_seed(seed)

    use_feat_kind = "x"
    try:
        X_feat = tf.convert_to_tensor(x_train, dtype=tf.float32)
    except Exception as exc:
        # Best-effort fallback kept from the original behaviour, but made
        # visible: the gate is then conditioned on base predictions, not x.
        warnings.warn(
            "bhs_train: x_train could not be converted to a float32 tensor "
            f"({type(exc).__name__}); gating on flattened base predictions instead.",
            RuntimeWarning,
            stacklevel=2,
        )
        bp = np.asarray(base_pred_train_cv, dtype=np.float32)
        X_feat = tf.convert_to_tensor(bp.reshape(bp.shape[0], -1), dtype=tf.float32)
        use_feat_kind = "bp_flat"

    if X_feat.shape.rank != 2:
        raise ValueError(
            f"bhs_train: gating features must be 2-D [N, D], got shape {tuple(X_feat.shape)}"
        )

    if task == "binary":
        y = tf.convert_to_tensor(np.asarray(y_train).ravel().astype(np.int32))
        pb_cv = tf.convert_to_tensor(base_pred_train_cv, dtype=tf.float32)
        # Non-finite probabilities become 0 and are then floored at 1e-12.
        pb_cv = tf.where(tf.math.is_finite(pb_cv), pb_cv, tf.zeros_like(pb_cv))
        aux = tf.clip_by_value(pb_cv, 1e-12, 1.0)
        expected_rank, expected_layout = 3, "[N, M, C]"
        loss_fn = _bhs_batch_elpd_loss
    else:
        y = tf.convert_to_tensor(np.asarray(y_train).ravel().astype(np.float32))
        # Non-finite entries are handled inside the MSE loss (row-mean fill).
        aux = tf.convert_to_tensor(base_pred_train_cv, dtype=tf.float32)
        expected_rank, expected_layout = 2, "[N, M]"
        loss_fn = _bhs_batch_mse_loss

    # Validate here so shape problems surface with a readable message rather
    # than as an einsum/broadcast error from inside the traced train step.
    if aux.shape.rank != expected_rank:
        raise ValueError(
            f"bhs_train: base_pred_train_cv must be {expected_layout} for task "
            f"'{task}', got shape {tuple(aux.shape)}"
        )

    N = int(X_feat.shape[0])
    if N == 0:
        raise ValueError("bhs_train: training set is empty")
    if int(y.shape[0]) != N or int(aux.shape[0]) != N:
        raise ValueError(
            "bhs_train: row counts differ between features "
            f"({N}), targets ({int(y.shape[0])}) and base predictions ({int(aux.shape[0])})"
        )

    D = int(X_feat.shape[1])
    M = int(aux.shape[1])

    gate = _BHSGate(D, M, temp=temp, s0=s0)
    # One forward pass on a single row before training (kept from the
    # original; presumably a warm-up of the model -- TODO confirm it is needed).
    _ = gate.weights_softmax(X_feat[:1])

    opt = tf.keras.optimizers.Adam(learning_rate=lr)

    # Full-buffer shuffle so every epoch sees a fresh permutation of all rows.
    ds = (tf.data.Dataset.from_tensor_slices((X_feat, y, aux))
            .cache()
            .shuffle(N)
            .batch(batch_size, drop_remainder=False)
            .prefetch(tf.data.AUTOTUNE))

    @tf.function(jit_compile=False)
    def train_step(Xb, yb, aux_b):
        with tf.GradientTape() as tape:
            loss = loss_fn(gate, Xb, yb, aux_b, prior_weight=prior_weight)
        grads = tape.gradient(loss, gate.trainable_variables)
        opt.apply_gradients(zip(grads, gate.trainable_variables))
        return loss

    last_loss = tf.constant(0.0, dtype=tf.float32)
    for ep in range(1, epochs + 1):
        for Xb, yb, aux_b in ds:
            last_loss = train_step(Xb, yb, aux_b)
        # Reports the loss of the final batch of the epoch, not an epoch mean.
        if verbose and (ep % max(1, epochs // 10) == 0 or ep == 1):
            print(f"[BHS/{task}] epoch {ep:4d}/{epochs}  loss={float(last_loss):.6f}")

    params = {
        "alpha": gate.alpha.numpy(),
        "beta": gate.beta.numpy(),
        "mu_alpha": float(gate.mu_alpha.numpy()),
        "log_tau_alpha": float(gate.log_tau_a.numpy()),
        "log_tau_beta": float(gate.log_tau_b.numpy()),
        "temp": float(gate.temp),
        "task": task,
        "feat_kind": use_feat_kind,  
    }
    return params


def _bhs_weights_from_params(params: dict,
                             X: _typing.Union[np.ndarray, tf.Tensor],
                             base_pred: _typing.Optional[_typing.Union[np.ndarray, tf.Tensor]] = None) -> tf.Tensor:
    """Recompute gate mixing weights from a saved parameter dict.

    Args:
        params: Dict with ``"alpha"`` and ``"beta"``; optional ``"temp"``
            (default 1.0) and ``"feat_kind"`` (default ``"x"``).
        X: Gating features; used unless ``feat_kind`` is ``"bp_flat"``.
        base_pred: Base-model predictions; required when ``feat_kind`` is
            ``"bp_flat"``, in which case they are flattened per row and used
            as the gating features.

    Returns:
        Tensor of softmax weights ``softmax((X_used @ beta + alpha) / temp)``
        taken along axis 1, with ``temp`` floored at 1e-6.

    Raises:
        ValueError: ``feat_kind`` is ``"bp_flat"`` and ``base_pred`` is None.
    """
    # NOTE: the annotation uses _typing.Optional rather than ``... | None``;
    # the latter is evaluated at definition time and fails on Python < 3.10.
    feat_kind = params.get("feat_kind", "x")
    if feat_kind == "bp_flat":
        if base_pred is None:
            raise ValueError("bhs_predict: base_pred is required because gate was trained on flattened base preds.")
        bp = np.asarray(base_pred, dtype=np.float32)
        X_used = bp.reshape(bp.shape[0], -1)  # [N, M*C] or [N, M]
    else:
        X_used = X
    X_used = tf.convert_to_tensor(X_used, dtype=tf.float32)

    alpha = tf.convert_to_tensor(params["alpha"], dtype=tf.float32)
    beta  = tf.convert_to_tensor(params["beta"], dtype=tf.float32)
    temp  = float(params.get("temp", 1.0))
    logits = (tf.linalg.matmul(X_used, beta) + alpha[None, :]) / max(temp, 1e-6)
    return tf.nn.softmax(logits, axis=1)

def bhs_predict(params: dict,
                X: _typing.Union[np.ndarray, tf.Tensor],
                base_pred: _typing.Union[np.ndarray, tf.Tensor],
                *,
                return_weights: bool = False):
    """Combine base-model predictions with the trained gate's softmax weights.

    The gating features are chosen to match what the gate was trained on:
    ``params["feat_kind"] == "bp_flat"`` means flattened base predictions;
    otherwise ``X`` is used when it is a numeric 2-D array with as many
    columns as ``beta`` has rows.  If ``X`` is unusable, the flattened base
    predictions are used as a best-effort fallback and zero-padded/truncated
    to the trained width; both situations emit a ``RuntimeWarning`` because
    the gate then sees features it was not trained on.

    Args:
        params: Dict with ``"beta"`` ``[D, M]`` and ``"alpha"`` ``[M]``;
            optional ``"temp"`` (default 1.0) and ``"feat_kind"``
            (default ``"x"``).
        X: Candidate gating features.
        base_pred: Base-model predictions, ``[N, M]`` or ``[N, M, C]``.
        return_weights: If True, also return the gate weights.

    Returns:
        NumPy array of mixed predictions (``[N, C]`` for 3-D ``base_pred``,
        otherwise the weighted sum over axis 1), or the tuple
        ``(predictions, weights)`` when ``return_weights`` is True.

    Raises:
        ValueError: Base predictions are needed as features but are neither
            2-D nor 3-D.
    """
    beta = np.asarray(params["beta"], dtype=np.float32)    # [D, M]
    alpha = np.asarray(params["alpha"], dtype=np.float32)  # [M]
    temp = float(params.get("temp", 1.0))
    D_tr, _ = beta.shape  # unpacking also enforces a 2-D beta
    feat_kind = params.get("feat_kind", "x")

    # 1) Build numeric features Xf with D_tr columns
    Xf = None
    if feat_kind != "bp_flat":
        try:
            candidate = np.asarray(X, dtype=np.float32)
        except Exception:
            # Non-numeric / unconvertible X: fall through to the base-pred path.
            candidate = None
        if candidate is not None and candidate.ndim == 2 and candidate.shape[1] == D_tr:
            Xf = candidate
        else:
            warnings.warn(
                "bhs_predict: X is not a numeric 2-D array with "
                f"{D_tr} columns; gating on flattened base predictions instead.",
                RuntimeWarning,
                stacklevel=2,
            )

    if Xf is None:
        BP = np.asarray(base_pred)
        if BP.ndim == 3:
            X_alt = BP.reshape(BP.shape[0], -1)
        elif BP.ndim == 2:
            X_alt = BP
        else:
            raise ValueError(f"base_pred must be [N,M] or [N,M,C], got {BP.shape}")
        Xf = X_alt.astype(np.float32)

        if Xf.shape[1] != D_tr:
            warnings.warn(
                f"bhs_predict: feature width {Xf.shape[1]} does not match the "
                f"trained width {D_tr}; zero-padding/truncating columns.",
                RuntimeWarning,
                stacklevel=2,
            )
            if Xf.shape[1] < D_tr:
                pad = np.zeros((Xf.shape[0], D_tr - Xf.shape[1]), dtype=np.float32)
                Xf = np.concatenate([Xf, pad], axis=1)
            else:
                Xf = Xf[:, :D_tr]

    # 2) Gate weights: softmax over models of the temperature-scaled logits
    logits = (Xf @ beta + alpha[None, :]) / max(temp, 1e-6)
    W = tf.nn.softmax(tf.convert_to_tensor(logits, dtype=tf.float32), axis=1)

    # 3) Mix base predictions (per class for 3-D input, scalar otherwise)
    base = tf.convert_to_tensor(base_pred, dtype=tf.float32)
    if base.ndim == 3:
        res = tf.einsum('nm,nmc->nc', W, base)
    else:
        res = tf.reduce_sum(W * base, axis=1)

    return (res.numpy(), W.numpy()) if return_weights else res.numpy()
