def full_pipeline_tabm(X_train, Y_train, X_valid, Y_valid, params=None):
    """
    Train a TabM classifier and return its accuracy on the validation set.

    Args:
        X_train: Training features, convertible to a 2-D float32 tensor
            (the second dimension is used as the number of numeric features).
        Y_train: Training labels, convertible to an int64 tensor. Labels are
            used directly as class indices, so they must be non-negative
            integers.
        X_valid: Validation features, same layout as ``X_train``.
        Y_valid: Validation labels, same encoding as ``Y_train``.
        params: Optional dict of hyperparameters. Recognized keys (defaults):
            ``n_blocks`` (3), ``d_block`` (128), ``dropout`` (0.1), ``k`` (32),
            ``lr`` (2e-3), ``weight_decay`` (3e-4), ``batch_size`` (128),
            ``n_epochs`` (20), ``eval_batch_size`` (128).

    Returns:
        Accuracy on the validation set as computed by
        ``sklearn.metrics.accuracy_score``.

    Raises:
        ValueError: If ``Y_train`` is empty.
    """
    import torch
    import torch.nn.functional as F
    from tabm_reference import Model, make_parameter_groups
    from sklearn.metrics import accuracy_score

    # --- Hyperparameters and Device ---
    if params is None:
        params = {}
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Model config (with defaults)
    n_blocks = params.get("n_blocks", 3)
    d_block = params.get("d_block", 128)
    dropout = params.get("dropout", 0.1)
    k = params.get("k", 32)
    lr = params.get("lr", 2e-3)
    weight_decay = params.get("weight_decay", 3e-4)
    batch_size = params.get("batch_size", 128)
    n_epochs = params.get("n_epochs", 20)
    eval_batch_size = params.get("eval_batch_size", 128)

    # --- Data Conversion ---
    X_train = torch.tensor(X_train, dtype=torch.float32, device=device)
    Y_train = torch.tensor(Y_train, dtype=torch.long, device=device)
    X_valid = torch.tensor(X_valid, dtype=torch.float32, device=device)
    Y_valid = torch.tensor(Y_valid, dtype=torch.long, device=device)

    if Y_train.numel() == 0:
        raise ValueError("Y_train must contain at least one label")

    # Labels are used as class indices by cross_entropy, so the output head
    # must cover every index up to the largest label. Counting unique labels
    # would under-size the head whenever the labels are not exactly 0..C-1.
    n_classes = int(Y_train.max().item()) + 1

    # --- Model ---
    model = Model(
        n_num_features=X_train.shape[1],
        cat_cardinalities=[],
        n_classes=n_classes,
        backbone={
            "type": "MLP",
            "n_blocks": n_blocks,
            "d_block": d_block,
            "dropout": dropout,
        },
        bins=None,
        num_embeddings=None,
        arch_type="tabm",
        k=k,
    ).to(device)
    optimizer = torch.optim.AdamW(
        make_parameter_groups(model), lr=lr, weight_decay=weight_decay
    )

    # --- Training Loop ---
    for _ in range(n_epochs):
        model.train()
        # Fresh random permutation every epoch, split into mini-batches.
        idx_batches = torch.randperm(len(Y_train), device=device).split(batch_size)
        for batch_idx in idx_batches:
            xb = X_train[batch_idx]
            yb = Y_train[batch_idx]
            # Assumes the model returns (batch, k, n_classes) -- TODO confirm
            # against tabm_reference.Model. No squeeze(-1) here: that is only
            # meaningful for a single-output (regression) head and would
            # corrupt the shape if the class dimension had size 1.
            logits = model(xb, None).float()
            # Each of the k ensemble members is trained independently: flatten
            # (batch, k) into one axis and repeat every target k times so the
            # targets line up with the flattened predictions.
            loss = F.cross_entropy(
                logits.flatten(0, 1), yb.repeat_interleave(logits.shape[-2])
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # --- Evaluation ---
    model.eval()
    with torch.no_grad():
        # Process in batches to avoid OOM
        probs = []
        for start in range(0, len(X_valid), eval_batch_size):
            xb = X_valid[start:start + eval_batch_size]
            logits = model(xb, None).float()
            # For classification the ensemble average must be taken in
            # probability space, not logit space: softmax first, then mean
            # over the k members.
            probs.append(F.softmax(logits, dim=-1).mean(1).cpu())
        y_pred = torch.cat(probs, dim=0).argmax(1).numpy()
        y_true = Y_valid.cpu().numpy()
        acc = accuracy_score(y_true, y_pred)
    return acc
