import numpy as np
import torch
import torch.nn as nn
import inspect

# tabpfn is an optional dependency: when importing it fails, keep the module
# importable and remember the failure so it can be reported later, at the
# point where the classifier is actually needed.
try:
    from tabpfn import TabPFNClassifier
except Exception as import_failure:
    # The exception name is unbound when the handler exits, so the error is
    # copied to a module-level name alongside the ``None`` sentinel.
    TabPFNClassifier, _IMPORT_ERR = None, import_failure

class TabPFNv2Wrapper(nn.Module):
    """Adapter exposing ``fit``/``predict_proba``/``predict`` over ``TabPFNClassifier``.

    Constructor keyword names differ between tabpfn releases, so the wrapper
    builds a superset of candidate keyword arguments and forwards only those
    that the installed ``TabPFNClassifier.__init__`` declares.
    """

    def __init__(self, params, args):
        """Build the wrapped classifier.

        Args:
            params: Mapping of user options. Recognised keys (all optional):
                ``"ensembles"`` (default 4), ``"batch_size_inference"``
                (default 4096), ``"variant"`` (default ``"v2"``), ``"seed"``
                (default 42) and ``"max_context"`` (default 4096; ``0`` or
                ``None`` disables the training-context cap).
            args: Object providing ``use_gpu``, ``num_features`` and
                ``num_classes`` attributes.

        Raises:
            RuntimeError: If ``tabpfn`` could not be imported. The original
                import error is attached as ``__cause__``.
        """
        super().__init__()
        if TabPFNClassifier is None:
            raise RuntimeError(f"tabpfn not available: {_IMPORT_ERR}") from _IMPORT_ERR

        self.device = "cuda" if args.use_gpu else "cpu"
        self.num_features = args.num_features
        self.num_classes = args.num_classes

        # Kept on the wrapper so that sub-sampling in ``fit`` is reproducible
        # no matter how (or whether) the classifier itself stores a seed.
        self.seed = int(params.get("seed", 42))

        self.clf = TabPFNClassifier(
            **self._supported_clf_kwargs(TabPFNClassifier, params, self.device)
        )

        # optional train-context cap (falsy value => no cap)
        max_context = params.get("max_context", 4096)
        self.max_ctx = 0 if max_context is None else int(max_context)
        self._Xtr = None
        self._ytr = None

    @staticmethod
    def _supported_clf_kwargs(clf_cls, params, device):
        """Return the constructor kwargs that ``clf_cls.__init__`` declares.

        Every known spelling of each option is offered; names missing from
        the signature of ``clf_cls.__init__`` are dropped.
        """
        n_ensembles = int(params.get("ensembles", 4))
        bs_infer = int(params.get("batch_size_inference", 4096))
        variant = params.get("variant", "v2")
        seed = int(params.get("seed", 42))

        wanted = {
            "device": device,
            # ensemble size: names seen across releases
            "N_ensemble_configurations": n_ensembles,  # tabpfn 0.1.x
            "n_ensembles": n_ensembles,
            "n_estimators": n_ensembles,  # tabpfn 2.x
            "batch_size_inference": bs_infer,
            "model_type": variant,
            "model": variant,
            # RNG seed: names seen across releases
            "random_state": seed,  # tabpfn 2.x
            "seed": seed,  # tabpfn 0.1.x
        }
        accepted = inspect.signature(clf_cls.__init__).parameters
        return {name: value for name, value in wanted.items() if name in accepted}

    @staticmethod
    def _take_rows(data, idx):
        """Select rows ``idx`` from ``data``.

        Objects exposing ``.iloc`` (pandas) are indexed positionally through
        it, because plain ``data[idx]`` on a DataFrame selects columns.
        Anything else is indexed directly.
        """
        positional = getattr(data, "iloc", None)
        return data[idx] if positional is None else positional[idx]

    def fit(self, X_train, y_train, X_val=None, y_val=None):
        """Fit the wrapped classifier, sub-sampling the context if it is too large.

        When ``self.max_ctx`` is non-zero and ``X_train`` has more rows than
        that, ``self.max_ctx`` rows are drawn without replacement using a
        generator seeded with ``self.seed``.

        Args:
            X_train: Training features; must expose ``.shape`` and support
                row indexing with an integer array (or provide ``.iloc``).
            y_train: Training labels, indexable the same way as ``X_train``.
            X_val: Unused; accepted for interface compatibility.
            y_val: Unused; accepted for interface compatibility.
        """
        n_rows = X_train.shape[0]
        if self.max_ctx and n_rows > self.max_ctx:
            rng = np.random.default_rng(self.seed)
            idx = rng.choice(n_rows, self.max_ctx, replace=False)
            Xc = self._take_rows(X_train, idx)
            yc = self._take_rows(y_train, idx)
        else:
            Xc, yc = X_train, y_train
        self._Xtr, self._ytr = Xc, yc
        self.clf.fit(Xc, yc)

    @torch.no_grad()
    def predict_proba(self, X):
        """Return the wrapped classifier's ``predict_proba(X)`` output unchanged."""
        return self.clf.predict_proba(X)

    @torch.no_grad()
    def predict(self, X):
        """Return the argmax over axis 1 of ``predict_proba(X)``."""
        return self.predict_proba(X).argmax(axis=1)
