from __future__ import annotations

import functools
from abc import ABC, abstractmethod
from functools import partial
import itertools
from typing import Any, Callable, Final, Generic, Iterator, Optional, TypeVar

import numpy as np
import sklearn.cluster as skl_cluster
import sklearn.linear_model as skl_linear_model
import sklearn.neighbors as skl_neighbors
import torch as th
import xgboost as xgbst

import models.common
import models.protocols

# Type of the per-action model held by a SubsetFeatureClassifier
# (values of its ``_xact_to_classifier`` cache, keyed by action-feature tuples).
MT = TypeVar("MT")
# Model type for the sklearn-style subclasses; bound to the project protocol
# ``models.protocols.ModuleHasPredictProba`` (presumably fit/predict_proba —
# confirm in models.protocols).
SKLMT = TypeVar("SKLMT", bound=models.protocols.ModuleHasPredictProba)


class SubsetFeatureClassifier(th.nn.Module, ABC, Generic[MT]):
    """Abstract expert that predicts labels from a subset of the covariates.

    Each integer *action* indexes one row of ``_xacts_f``. The first ``n_covs``
    entries of that row are a 0/1 indicator of the selected features; when
    ``n_experts_per_act > 1`` the remaining entries one-hot encode the expert.
    The integer tuple of a row (its *key*) identifies the underlying model.
    """

    @staticmethod
    def make_cact_fcomb_map(
        n_covs: int,
        min_features: int,
        max_features: Optional[int],
    ) -> tuple[list[tuple[int, ...]], dict[tuple[int, ...], int]]:
        """Make the comb-action <-> feature-combination maps.

        Combinations are listed by increasing size and, within one size, in
        ``itertools.combinations`` order; the list position is the comb-action.

        Args:
            n_covs (int): number of covariates
            min_features (int): minimum number of selected features (>= 1)
            max_features (Optional[int]): maximum number of selected features;
                ``None`` means ``n_covs``. It may equal ``min_features`` to
                enumerate only combinations of exactly that size.

        Returns:
            list[tuple[int, ...]]: combination action to feature combination
            dict[tuple[int, ...], int]: feature combination to combination action
        """
        max_features = n_covs if max_features is None else max_features
        assert 0 < min_features <= n_covs
        # Equality is valid: range(min_features, max_features + 1) is non-empty.
        assert min_features <= max_features <= n_covs
        cact_to_fcomb = list(
            itertools.chain.from_iterable(
                itertools.combinations(range(n_covs), size)
                for size in range(min_features, max_features + 1)
            )
        )
        fcomb_to_cact = {fcomb: cact for cact, fcomb in enumerate(cact_to_fcomb)}
        return cact_to_fcomb, fcomb_to_cact

    @staticmethod
    def make_full_action_features(
        n_covs: int,
        min_features: int,
        max_features: Optional[int],
        n_experts_per_comb: int,
    ) -> tuple[th.Tensor, th.Tensor, list[tuple[int, ...]], dict[tuple[int, ...], int]]:
        """Build the feature-indicator and full-action feature matrices.

        Args:
            n_covs (int): number of covariates
            min_features (int): minimum number of selected features
            max_features (Optional[int]): maximum number of selected features
            n_experts_per_comb (int): number of experts per feature combination

        Returns:
            th.Tensor: ``cinds_f``, (n_combs, n_covs) float32 0/1 indicators
            th.Tensor: ``xacts_f``; the same tensor as ``cinds_f`` when
                ``n_experts_per_comb <= 1``, otherwise
                (n_experts_per_comb * n_combs, n_covs + n_experts_per_comb) where
                row ``e * n_combs + c`` is combination ``c`` paired with expert ``e``
            list[tuple[int, ...]]: combination action to feature combination
            dict[tuple[int, ...], int]: feature combination to combination action
        """
        cact_to_fcomb, fcomb_to_cact = SubsetFeatureClassifier.make_cact_fcomb_map(
            n_covs=n_covs, min_features=min_features, max_features=max_features
        )
        # make all combination action features
        cinds_f = th.zeros((len(cact_to_fcomb), n_covs), dtype=th.float32)
        for cact, fcomb in enumerate(cact_to_fcomb):
            cinds_f[cact, fcomb] = 1.0
        # make all expert combination features
        xacts_f: th.Tensor = cinds_f
        if n_experts_per_comb > 1:
            # (n_experts_per_comb, n_combs, n_covs)
            cinds: th.Tensor = cinds_f[None, :, :].expand(n_experts_per_comb, -1, -1)
            # (n_experts_per_comb, n_combs, n_experts_per_comb) one-hot expert ids
            exinds: th.Tensor = th.eye(n_experts_per_comb, dtype=th.float32)[
                :, None, :
            ].expand(-1, len(cact_to_fcomb), -1)
            # (n_experts_per_comb * n_combs, n_covs + n_experts_per_comb)
            xacts_f = th.cat((cinds, exinds), dim=2).flatten(0, 1)
        return cinds_f, xacts_f, cact_to_fcomb, fcomb_to_cact

    n_experts_per_act: int
    xs_train: Final[np.ndarray]
    ys_train: Final[np.ndarray]

    n_covs: int
    n_labels: int

    _cinds_f: th.Tensor
    _xacts_f: th.Tensor

    # cache of fitted models keyed by the integer tuple of an action row
    _xact_to_classifier: dict[tuple[int, ...], MT]

    _dummy: th.Tensor

    @property
    def device(self):
        """Device of this module (tracked through the ``_dummy`` buffer)."""
        return self._dummy.device

    def __init__(
        self,
        n_experts_per_act: int,
        cinds_f: th.Tensor,
        xacts_f: th.Tensor,
        xs_train: np.ndarray,
        ys_train: np.ndarray,
    ):
        """Store action features and training data.

        Args:
            n_experts_per_act (int): number of experts per feature combination
            cinds_f (th.Tensor): feature-indicator rows, one per combination
            xacts_f (th.Tensor): one feature row per action
            xs_train (np.ndarray): (n_train, n_covs) training covariates
            ys_train (np.ndarray): training labels (presumably 1-D)
        """
        super().__init__()
        self.n_experts_per_act = n_experts_per_act
        self._cinds_f = cinds_f
        self._xacts_f = xacts_f
        self.xs_train = xs_train
        self.ys_train = ys_train
        # shape information
        self.n_covs = xs_train.shape[1]
        self.n_labels = len(np.unique(self.ys_train))
        self._xact_to_classifier = dict()
        # dummy buffer so ``.to(device)`` moves something we can query
        self.register_buffer("_dummy", th.empty(()))

    def act_to_key_fcomb_exidx(
        self, act: int
    ) -> tuple[tuple[int, ...], tuple[int, ...], int]:
        """Transform an action into its key, feature combination and expert index.

        Args:
            act (int): an action of interest

        Returns:
            tuple[int, ...]: key (integer tuple of the action row)
            tuple[int, ...]: feature combination (selected column indices)
            int: expert index (0 when there is one expert per combination)
        """
        xact: th.Tensor = self._xacts_f[act].to(dtype=th.long)
        key: tuple[int, ...] = tuple(xact.tolist())
        if self.n_experts_per_act == 1:
            comb = tuple(th.argwhere(xact == 1).flatten().tolist())
            return key, comb, 0
        comb_t: th.Tensor = xact[: self.n_covs]
        exind_t: th.Tensor = xact[self.n_covs :]
        comb: tuple[int, ...] = tuple(th.argwhere(comb_t == 1).flatten().tolist())
        exidx: int = int(th.argwhere(exind_t == 1).item())
        return key, comb, exidx

    def get_classifier_from_act(self, act: int) -> MT:
        """Get classifier from action

        Args:
            act (int): the action selected

        Returns:
            MT: the model corresponding to this action.
        """
        key: tuple[int, ...] = tuple(self._xacts_f[act].to(dtype=th.long).tolist())
        return self[key]

    @abstractmethod
    def predict_proba(self, ctxs: th.Tensor, acts: th.Tensor) -> th.Tensor:
        """Predict probability given context and action

        Args:
            ctxs (th.Tensor): (bsz, n_covs + n_act_feats) the context
            acts (th.Tensor): (bsz, ) action taken for each context

        Returns:
            th.Tensor: (n_ctxs, n_labels) the probability
        """

    @abstractmethod
    def _predict_proba_same_act(self, ctxs: th.Tensor, act: int) -> th.Tensor:
        """Predict probability for a batch of contexts sharing one action

        Args:
            ctxs (th.Tensor): (bsz, n_covs + n_act_feats) the context
            act (int): the action taken for this batch of ctxs

        Returns:
            th.Tensor: (n_ctxs, n_labels) the probability
        """

    @abstractmethod
    def __getitem__(self, key: tuple[int, ...]) -> MT: ...

    def __iter__(self) -> Iterator[tuple[int, ...]]:
        """Iterate over the keys (integer action rows) of all actions."""
        return iter([tuple(xact.to(dtype=th.long).tolist()) for xact in self._xacts_f])

    def __len__(self) -> int:
        """Number of actions."""
        return len(self._xacts_f)


class _SubsetFeatureSKLClassifier(SubsetFeatureClassifier[SKLMT]):
    """Single-expert base for estimators exposing ``fit`` / ``predict_proba``.

    One model per action is fitted lazily on the selected training columns and,
    when ``to_cache_model`` is set, cached under the action key.
    """

    make_model_func: Callable[[], SKLMT]
    use_cp: bool
    to_cache_model: bool

    def __init__(
        self,
        xacts_f: th.Tensor,
        make_model_func: Callable[[], SKLMT],
        use_cp: bool,
        xs_train: np.ndarray,
        ys_train: np.ndarray,
        to_cache_model: bool,
    ):
        """
        Args:
            xacts_f (th.Tensor): (n_actions, n_covs) 0/1 feature indicators
            make_model_func (Callable[[], SKLMT]): factory for an unfitted model
            use_cp (bool): convert inputs with ``models.common.to_cp_or_np`` on
                ``self.device`` instead of plain NumPy
            xs_train (np.ndarray): (n_train, n_covs) training covariates
            ys_train (np.ndarray): training labels
            to_cache_model (bool): keep fitted models in ``_xact_to_classifier``
        """
        super().__init__(
            n_experts_per_act=1,
            cinds_f=xacts_f,
            xacts_f=xacts_f,
            xs_train=xs_train,
            ys_train=ys_train,
        )
        self.make_model_func = make_model_func
        self.use_cp = use_cp
        self.to_cache_model = to_cache_model

    def predict_proba(self, ctxs: th.Tensor, acts: th.Tensor) -> th.Tensor:
        """Predict class probabilities with one model call per distinct action.

        Args:
            ctxs (th.Tensor): (bsz, d) contexts; columns are selected per action
            acts (th.Tensor): (bsz,) action of each context

        Returns:
            th.Tensor: (bsz, n_labels) float32 probabilities
        """
        n: int = len(ctxs)
        n_labels: int = self.n_labels
        pyhats: th.Tensor = th.empty((n, n_labels), dtype=th.float32)
        for act in th.unique(acts):
            _act: int = int(act.item())
            _curr_idxs: th.Tensor = th.argwhere(acts == _act).flatten()
            _ctxs: th.Tensor = ctxs[_curr_idxs]
            pyhats[_curr_idxs,] = self._predict_proba_same_act(_ctxs, _act)
        return pyhats

    def _predict_proba_same_act(self, ctxs: th.Tensor, act: int) -> th.Tensor:
        """Predict with the model of ``act`` on that action's feature columns."""
        key, comb, _ = self.act_to_key_fcomb_exidx(act)
        classifier = self[key]
        ctxs_ = (
            models.common.to_cp_or_np(ctxs.to(device=self.device))
            if self.use_cp
            else ctxs.numpy(force=True)
        )
        pyhats_n: np.ndarray = classifier.predict_proba(ctxs_[:, comb])
        pyhats: th.Tensor = th.as_tensor(pyhats_n, dtype=th.float32)
        return pyhats

    def __getitem__(self, key: tuple[int, ...]) -> SKLMT:
        """Return the model for an action key, fitting (and caching) it if needed.

        Args:
            key (tuple[int, ...]): integer row of ``_xacts_f``, i.e. a 0/1
                indicator of the selected features (single-expert layout).

        Returns:
            SKLMT: model fitted on the selected training columns.
        """
        assert self.n_experts_per_act == 1
        if key in self._xact_to_classifier:
            return self._xact_to_classifier[key]
        # ``key`` is an indicator vector, not a list of column indices: convert it
        # so training uses the same columns ``_predict_proba_same_act`` predicts on.
        comb: tuple[int, ...] = tuple(i for i, flag in enumerate(key) if flag == 1)
        model = self.make_model_func()
        xs = (
            models.common.to_cp_or_np(
                th.as_tensor(
                    self.xs_train[:, comb], dtype=th.float32, device=self.device
                )
            )
            if self.use_cp
            else self.xs_train[:, comb].astype(np.float32)
        )
        ys = (
            models.common.to_cp_or_np(
                th.as_tensor(self.ys_train, dtype=th.long, device=self.device)
            )
            if self.use_cp
            else self.ys_train.astype(np.int64)
        )
        model.fit(xs, ys)
        if self.to_cache_model:
            self._xact_to_classifier[key] = model
        return model


class SubsetFeatureXGBClassifier(_SubsetFeatureSKLClassifier[xgbst.XGBClassifier]):
    """Subset-feature expert backed by one ``xgboost.XGBClassifier`` per action."""

    xgbc_kwargs: dict[str, Any]

    def __init__(
        self,
        xacts_f: th.Tensor,
        xs_train: np.ndarray,
        ys_train: np.ndarray,
        to_cache_model: bool = True,
        xgbc_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Args:
            xacts_f (th.Tensor): (n_actions, n_covs) 0/1 feature indicators
            xs_train (np.ndarray): (n_train, n_covs) training covariates
            ys_train (np.ndarray): training labels
            to_cache_model (bool, optional): cache fitted models. Defaults to True.
            xgbc_kwargs (Optional[dict[str, Any]], optional): hyper-parameters for
                every ``XGBClassifier``; ``None`` means library defaults.
        """
        # copy so instances never share a default dict or alias the caller's one
        xgbc_kwargs = dict() if xgbc_kwargs is None else dict(xgbc_kwargs)
        super().__init__(
            xacts_f=xacts_f,
            make_model_func=partial(xgbst.XGBClassifier, **xgbc_kwargs),
            use_cp=False,
            xs_train=xs_train,
            ys_train=ys_train,
            to_cache_model=to_cache_model,
        )
        # hparam for xgbc (the model cache is initialised by the base class)
        self.xgbc_kwargs = xgbc_kwargs


class SubsetFeatureLogisticRegressionClassifier(
    _SubsetFeatureSKLClassifier[skl_linear_model.LogisticRegression]  # type:ignore
):
    """Subset-feature expert backed by one logistic regression per action.

    Uses cuml's ``LogisticRegression`` when it can be imported and built with
    ``lrc_kwargs``, otherwise sklearn's.
    """

    lrc_kwargs: dict[str, Any]

    def __init__(
        self,
        xacts_f: th.Tensor,
        xs_train: np.ndarray,
        ys_train: np.ndarray,
        to_cache_model: bool = True,
        lrc_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Args:
            xacts_f (th.Tensor): (n_actions, n_covs) 0/1 feature indicators
            xs_train (np.ndarray): (n_train, n_covs) training covariates
            ys_train (np.ndarray): training labels
            to_cache_model (bool, optional): cache fitted models. Defaults to True.
            lrc_kwargs (Optional[dict[str, Any]], optional): kwargs for every
                logistic regression; ``None`` means library defaults.
        """
        super().__init__(
            xacts_f=xacts_f,
            # bound method: only called lazily, after lrc_kwargs is set below
            make_model_func=self._make_model_func,
            use_cp=False,
            xs_train=xs_train,
            ys_train=ys_train,
            to_cache_model=to_cache_model,
        )
        # copy so instances never share a default dict or alias the caller's one
        self.lrc_kwargs = dict() if lrc_kwargs is None else dict(lrc_kwargs)

    def _make_model_func(self) -> skl_linear_model.LogisticRegression:
        """Build an unfitted logistic regression, preferring cuml."""
        try:
            import cuml.linear_model

            return cuml.linear_model.LogisticRegression(
                **self.lrc_kwargs, output_type="numpy"
            )
        except Exception:
            # deliberate best-effort: cuml missing or rejecting the kwargs
            return skl_linear_model.LogisticRegression(**self.lrc_kwargs)


class SubsetFeatureKNNLogisticRegressionClassifier(SubsetFeatureClassifier[None]):
    """Local expert: per query, fit a logistic regression on the query's nearest
    training neighbours within the selected feature columns."""

    knn_kwargs: dict[str, Any]
    lrc_kwargs: dict[str, Any]

    NearestNeighbors: Callable[..., skl_neighbors.NearestNeighbors]
    LogisticRegression: Callable[..., skl_linear_model.LogisticRegression]

    def __init__(
        self,
        xacts_f: th.Tensor,
        xs_train: np.ndarray,
        ys_train: np.ndarray,
        knn_kwargs: Optional[dict[str, Any]] = None,
        lrc_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Args:
            xacts_f (th.Tensor): (n_actions, n_covs) 0/1 feature indicators
            xs_train (np.ndarray): (n_train, n_covs) training covariates
            ys_train (np.ndarray): training labels
            knn_kwargs (Optional[dict[str, Any]], optional): kwargs for the
                nearest-neighbour search; ``None`` means library defaults.
            lrc_kwargs (Optional[dict[str, Any]], optional): kwargs for each local
                logistic regression; ``None`` means library defaults.
        """
        super().__init__(
            n_experts_per_act=1,
            cinds_f=xacts_f,
            xacts_f=xacts_f,
            xs_train=xs_train,
            ys_train=ys_train,
        )
        # copy so instances never share a default dict or alias the caller's one
        self.knn_kwargs = dict() if knn_kwargs is None else dict(knn_kwargs)
        self.lrc_kwargs = dict() if lrc_kwargs is None else dict(lrc_kwargs)
        # choose between sklearn and cuml during runtime
        nn_factory: Callable[..., Any] = skl_neighbors.NearestNeighbors
        lr_factory: Callable[..., Any] = skl_linear_model.LogisticRegression
        try:
            import cuml.linear_model
            import cuml.neighbors

            cuml_nn = functools.partial(
                cuml.neighbors.NearestNeighbors, output_type="numpy"
            )
            cuml_lr = functools.partial(
                cuml.linear_model.LogisticRegression, output_type="numpy"
            )
            # probe construction with the given kwargs before switching backend
            cuml_nn(**self.knn_kwargs)
            cuml_lr(**self.lrc_kwargs)
        except Exception:
            # cuml missing or rejecting the kwargs: keep sklearn for both
            pass
        else:
            # only commit to cuml once both probes succeeded
            nn_factory = cuml_nn
            lr_factory = cuml_lr
        self.NearestNeighbors = nn_factory
        self.LogisticRegression = lr_factory

    def _fit_local_lrc(
        self, knnidxs: np.ndarray, comb: tuple[int, ...]
    ) -> tuple[skl_linear_model.LogisticRegression, np.ndarray, np.ndarray]:
        """Fit a logistic regression on training rows ``knnidxs``, columns ``comb``.

        Returns:
            the fitted model and the (possibly label-padded) training set used
        """
        lrc = self.LogisticRegression(**self.lrc_kwargs)
        # ensure all labels show up at least once
        xs_lrc, ys_lrc = self._make_valid_lrc_train_set(
            self.xs_train[knnidxs][:, comb], self.ys_train[knnidxs]
        )
        lrc.fit(xs_lrc, ys_lrc)
        return lrc, xs_lrc, ys_lrc

    def fit_predict_proba(self, ctxs: th.Tensor, acts: th.Tensor) -> tuple[
        list[np.ndarray],
        list[np.ndarray],
        list[skl_linear_model.LogisticRegression],
        th.Tensor,
    ]:
        """Like ``predict_proba`` (one context at a time) but also return the
        local training sets and fitted models.

        Args:
            ctxs (th.Tensor): (bsz, d) contexts
            acts (th.Tensor): (bsz,) action of each context

        Returns:
            list[np.ndarray]: per-context local training inputs
            list[np.ndarray]: per-context local training labels
            list[LogisticRegression]: per-context fitted models
            th.Tensor: (bsz, n_labels) float32 probabilities
        """
        n: int = len(ctxs)
        lrcs_l: list[skl_linear_model.LogisticRegression] = list()
        xs_train_lrc_l: list[np.ndarray] = list()
        ys_train_lrc_l: list[np.ndarray] = list()
        pyhats: th.Tensor = th.empty((n, self.n_labels), dtype=th.float32)
        for i, (ctx, act) in enumerate(zip(ctxs, acts)):
            _, comb, _ = self.act_to_key_fcomb_exidx(int(act.item()))
            knn = self.NearestNeighbors(**self.knn_kwargs)
            knn.fit(self.xs_train[:, comb])
            # (1, len(comb)) query restricted to the selected features
            ctx_sub: np.ndarray = ctx[None, comb].numpy(force=True)
            knnidxs: np.ndarray = knn.kneighbors(
                ctx_sub, return_distance=False
            )  # type:ignore
            knnidxs = knnidxs.flatten()
            del knn
            lrc, xs_lrc, ys_lrc = self._fit_local_lrc(knnidxs, comb)
            pyhats[i] = th.as_tensor(lrc.predict_proba(ctx_sub), dtype=th.float32)
            lrcs_l.append(lrc)
            xs_train_lrc_l.append(xs_lrc)
            ys_train_lrc_l.append(ys_lrc)
        return xs_train_lrc_l, ys_train_lrc_l, lrcs_l, pyhats

    def predict_proba(self, ctxs: th.Tensor, acts: th.Tensor) -> th.Tensor:
        """Predict with a local logistic regression per context.

        The neighbour search is fitted once per distinct action and queried for
        all contexts sharing that action.

        Args:
            ctxs (th.Tensor): (bsz, d) contexts
            acts (th.Tensor): (bsz,) action of each context

        Returns:
            th.Tensor: (bsz, n_labels) float32 probabilities
        """
        n: int = len(ctxs)
        pyhats: th.Tensor = th.empty((n, self.n_labels), dtype=th.float32)
        uacts: th.Tensor
        invidxs_t: th.Tensor
        # batch nearest neighbors computations for inputs with same feature combinations
        uacts, invidxs_t = th.unique(acts, return_inverse=True)
        for ai, uact in enumerate(uacts):
            _, comb, _ = self.act_to_key_fcomb_exidx(int(uact.item()))
            knn = self.NearestNeighbors(**self.knn_kwargs)
            knn.fit(self.xs_train[:, comb])
            # query knn in training set for input contexts with the same action
            idxs: th.Tensor = th.argwhere(invidxs_t == ai).flatten()
            ctxs_sub: th.Tensor = ctxs[idxs][:, comb]
            knnidxs: np.ndarray = knn.kneighbors(
                ctxs_sub.numpy(force=True), return_distance=False
            )  # type:ignore
            del knn
            # for each context, fit local personalized logistic regression model
            for _i, (_knnidxs, _ctx) in enumerate(zip(knnidxs, ctxs_sub)):
                _lrc, _, _ = self._fit_local_lrc(_knnidxs, comb)
                _idx: int = int(idxs[_i].item())
                pyhats[_idx] = th.as_tensor(
                    _lrc.predict_proba(_ctx[None, :].numpy(force=True)),
                    dtype=th.float32,
                )
                del _lrc
        return pyhats

    def _predict_proba_same_act(self, ctxs: th.Tensor, act: int) -> th.Tensor:
        # unused: predict_proba is overridden and never calls this
        return NotImplemented

    def __getitem__(self, key: tuple[int, ...]) -> None:
        # models are fitted per query, so there is nothing to look up
        return None

    def _make_valid_lrc_train_set(
        self, xs_train: np.ndarray, ys_train: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Pad a local training set so every label occurs at least once.

        If some label is missing, ``n_labels`` copies of the feature mean are
        appended, labelled ``0 .. n_labels - 1``.
        NOTE(review): assumes labels are exactly ``0 .. n_labels - 1``.
        """
        n_labels: int = self.n_labels
        if len(np.unique(ys_train)) == n_labels:
            return xs_train, ys_train
        dummy: np.ndarray = np.mean(xs_train, axis=0, keepdims=True)
        xs_train = np.concatenate((xs_train, np.tile(dummy, (n_labels, 1))), axis=0)
        ys_train = np.concatenate((ys_train, np.arange(n_labels, dtype=np.int64)))
        return xs_train, ys_train


class SubsetFeatureKNNClassifier(
    _SubsetFeatureSKLClassifier[skl_neighbors.KNeighborsClassifier]  # type:ignore
):
    """Subset-feature expert backed by one k-NN classifier per action.

    Uses cuml's ``KNeighborsClassifier`` when it can be imported and built with
    ``knc_kwargs``, otherwise sklearn's.
    """

    knc_kwargs: dict[str, Any]

    def __init__(
        self,
        xacts_f: th.Tensor,
        xs_train: np.ndarray,
        ys_train: np.ndarray,
        to_cache_model: bool = True,
        knc_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Args:
            xacts_f (th.Tensor): (n_actions, n_covs) 0/1 feature indicators
            xs_train (np.ndarray): (n_train, n_covs) training covariates
            ys_train (np.ndarray): training labels
            to_cache_model (bool, optional): cache fitted models. Defaults to True.
            knc_kwargs (Optional[dict[str, Any]], optional): kwargs for every
                k-NN classifier; ``None`` means library defaults.
        """
        super().__init__(
            xacts_f=xacts_f,
            # bound method: only called lazily, after knc_kwargs is set below
            make_model_func=self._make_model_func,
            use_cp=True,
            xs_train=xs_train,
            ys_train=ys_train,
            to_cache_model=to_cache_model,
        )
        # copy so instances never share a default dict or alias the caller's one
        self.knc_kwargs = dict() if knc_kwargs is None else dict(knc_kwargs)

    def _make_model_func(self) -> skl_neighbors.KNeighborsClassifier:
        """Build an unfitted k-NN classifier, preferring cuml."""
        try:
            import cuml.neighbors

            return cuml.neighbors.KNeighborsClassifier(
                **self.knc_kwargs, output_type="numpy"
            )
        except Exception:
            # deliberate best-effort: cuml missing or rejecting the kwargs
            return skl_neighbors.KNeighborsClassifier(**self.knc_kwargs)


class SubsetFeatureNadarayaWatsonClassifier(SubsetFeatureClassifier[None]):
    """Binary Nadaraya-Watson expert with a squared-exponential kernel restricted
    to the selected features.

    The bandwidth of a feature subset is the sum of the selected features'
    training variances times ``sig_mult``.
    """

    _prior_exp: float
    _x_vars: np.ndarray

    def __init__(
        self,
        xacts_f: th.Tensor,
        xs_train: np.ndarray,
        ys_train: np.ndarray,
        sig_mult: float = 0.15,
    ):
        """
        Args:
            xacts_f (th.Tensor): (n_actions, n_covs) 0/1 feature indicators
            xs_train (np.ndarray): (n_train, n_covs) training covariates
            ys_train (np.ndarray): (n_train,) binary labels (assumed 0/1)
            sig_mult (float, optional): scale applied to per-feature variances
        """
        super().__init__(
            n_experts_per_act=1,
            cinds_f=xacts_f,
            xacts_f=xacts_f,
            xs_train=xs_train,
            ys_train=ys_train,
        )
        assert len(np.unique(ys_train)) == 2
        # mean label, i.e. P(y=1) when labels are 0/1
        self._prior_exp = np.mean(ys_train).item()
        # (1, n_covs) per-feature bandwidth contribution
        self._x_vars = np.var(xs_train, axis=0, keepdims=True) * sig_mult

    def predict_proba(self, ctxs: th.Tensor, acts: th.Tensor) -> th.Tensor:
        """Predict [P(y=0), P(y=1)] for each context under its action's mask.

        Args:
            ctxs (th.Tensor): (bsz, n_covs) contexts
            acts (th.Tensor): (bsz,) action of each context

        Returns:
            th.Tensor: (bsz, 2) float32 probabilities
        """
        xqrys: np.ndarray = ctxs.numpy(force=True)
        bs: np.ndarray = self._xacts_f[acts].numpy(force=True)
        return th.as_tensor(self._expert(xqrys, bs), dtype=th.float32)

    def _predict_proba_same_act(self, ctxs: th.Tensor, act: int) -> th.Tensor:
        # unused: predict_proba is overridden and never calls this
        return NotImplemented

    def __getitem__(self, key: tuple[int, ...]) -> None:
        # no per-action model objects
        return None

    def _nw_pred(
        self,
        Xtrn: th.Tensor,
        Ytrn: th.Tensor,
        Xq: th.Tensor,
        B: th.Tensor,
        sigmas: th.Tensor,
    ) -> th.Tensor:
        """Kernel-weighted average of ``Ytrn`` around the query.

        Args:
            Xtrn: N x d Train Instances
            Ytrn: N x nclass train targets (called with the (N, 1) 0/1 label)
            Xq: 1 x d Query Instances
            B: d x R binary masks to try
            sigmas: 1 x R bandwidth to use on each mask

        Returns:
            R x nclass weighted averages
        """
        Xtrn2 = Xtrn**2
        XtrnXq = Xtrn * Xq
        # N x R masked squared distance minus the query's own ||Xq||^2 term; that
        # term is constant over the N rows, and softmax over dim 0 is
        # shift-invariant, so it is not needed.
        d2 = th.matmul(Xtrn2, B) - 2.0 * th.matmul(XtrnXq, B)
        kerns = th.softmax(-d2 / sigmas, dim=0)
        # R x nclass
        Y_neighbors = th.matmul(kerns.T, Ytrn)
        return Y_neighbors

    def _expert(
        self, X: np.ndarray, B: np.ndarray, alpha: float = 0.0, minprob: float = 0.001
    ) -> th.Tensor:
        """
        Implement some black box truth expert, should perform better when given
        relevant features, worse when not.

        Args:
            X: numpy array of shape (N, d) query features (masking is via ``B``)
            B: numpy array of shape (N, d) corresponding mask {0, 1}
            alpha: weight of the shrinkage towards the label prior
            minprob: P(y=1) is clipped to [minprob, 1 - minprob]
        Returns:
            th.Tensor of shape (N, 2): [P(y=0), P(y=1)] per row
        """
        # (N, 1) bandwidth per row: sum of the selected features' variances
        sig_masks = np.matmul(B, self._x_vars.T)
        # Avoid div by zero for empty mask (only effective if no feature has
        # zero training variance)
        sig_masks = np.maximum(sig_masks, np.min(self._x_vars))
        featspen = th.as_tensor(1.0 - alpha * np.mean(B, axis=1))
        # training data is identical for every query: move it to the device once
        xs_trn = th.as_tensor(self.xs_train, dtype=th.float32, device=self.device)
        ys_trn = th.as_tensor(
            self.ys_train[:, None], dtype=th.float32, device=self.device
        )
        # per-query loop; each query has its own mask and bandwidth
        PY = th.zeros((X.shape[0],), dtype=th.float32)
        for i in range(X.shape[0]):
            PY[i] = self._nw_pred(
                xs_trn,
                ys_trn,
                th.as_tensor(X[i, None, :], dtype=th.float32, device=self.device),
                th.as_tensor(B[i, None, :].T, dtype=th.float32, device=self.device),
                th.as_tensor(sig_masks[i], dtype=th.float32, device=self.device),
            ).to(device="cpu")
        PY = th.minimum(
            th.maximum(PY, th.as_tensor(minprob)), th.as_tensor(1.0 - minprob)
        )
        # TODO: adjust based on number of given feats
        pyhats: th.Tensor = featspen * PY + (1 - featspen) * self._prior_exp
        pyhats = th.cat((1.0 - pyhats[:, None], pyhats[:, None]), dim=1)
        return pyhats


class SubsetFeatureMultiExpertNadarayaWatsonClassifier(SubsetFeatureClassifier[None]):
    """Binary Nadaraya-Watson experts, one per KMeans cluster of the training set.

    An action selects a feature mask and (via its one-hot tail) an expert; that
    expert only uses its own cluster's training points and variances.
    """

    _prior_exp: float
    _xs_train_multi: list[np.ndarray]
    _ys_train_multi: list[np.ndarray]
    _x_vars: np.ndarray

    def __init__(
        self,
        n_experts_per_act: int,
        cinds_f: th.Tensor,
        xacts_f: th.Tensor,
        xs_train: np.ndarray,
        ys_train: np.ndarray,
        kmeans_kwargs: Optional[dict[str, Any]] = None,
        sig_mult: float = 0.15,
    ):
        """
        Args:
            n_experts_per_act (int): number of experts (= KMeans clusters)
            cinds_f (th.Tensor): feature-indicator rows, one per combination
            xacts_f (th.Tensor): action rows (features + one-hot expert index)
            xs_train (np.ndarray): (n_train, n_covs) training covariates
            ys_train (np.ndarray): (n_train,) binary labels (assumed 0/1)
            kmeans_kwargs (Optional[dict[str, Any]], optional): extra KMeans
                kwargs; ``None`` means library defaults.
            sig_mult (float, optional): scale applied to per-feature variances
        """
        super().__init__(
            n_experts_per_act=n_experts_per_act,
            cinds_f=cinds_f,
            xacts_f=xacts_f,
            xs_train=xs_train,
            ys_train=ys_train,
        )
        kmeans_kwargs = dict() if kmeans_kwargs is None else kmeans_kwargs
        assert len(np.unique(ys_train)) == 2
        self._prior_exp = np.mean(ys_train).item()

        # prefer cuml KMeans if it can be built with these kwargs
        KMeans: Callable[..., Any] = skl_cluster.KMeans
        try:
            import cuml.cluster

            KMeans = functools.partial(cuml.cluster.KMeans, output_type="numpy")
            KMeans(n_clusters=n_experts_per_act, **kmeans_kwargs)
        except Exception:
            KMeans = skl_cluster.KMeans

        # Refit until every cluster has >= 2 members; points of undersized
        # clusters are excluded from the next fit (but are still assigned).
        valid_clusters = False
        include_for_clustering = np.ones(xs_train.shape[0], dtype=bool)
        while not valid_clusters:
            cluster_model = KMeans(n_clusters=n_experts_per_act, **kmeans_kwargs)
            cluster_model = cluster_model.fit(xs_train[include_for_clustering])
            cluster_labels: np.ndarray = cluster_model.predict(xs_train)

            valid_clusters = True
            for i in range(n_experts_per_act):
                if np.sum(cluster_labels == i) < 2:
                    valid_clusters = False
                    include_for_clustering *= cluster_labels != i

        # every id 0..n_experts_per_act-1 is present here, so list position
        # equals cluster id
        self._xs_train_multi = list()
        self._ys_train_multi = list()
        self._x_vars = np.ones((n_experts_per_act, 1, self.n_covs), dtype=np.float32)
        for cid in np.unique(cluster_labels):
            _cms: np.ndarray = np.argwhere(cid == cluster_labels).flatten()
            self._xs_train_multi.append(xs_train[_cms])
            self._ys_train_multi.append(ys_train[_cms])
            # small offset keeps bandwidths strictly positive
            self._x_vars[cid] = (
                np.var(self._xs_train_multi[cid], axis=0, keepdims=True) * sig_mult + 1e-5
            )
            assert np.all(self._x_vars[cid] > 0)

    def predict_proba(self, ctxs: th.Tensor, acts: th.Tensor) -> th.Tensor:
        """Predict [P(y=0), P(y=1)] for each context under its action.

        Args:
            ctxs (th.Tensor): (bsz, n_covs) contexts
            acts (th.Tensor): (bsz,) action of each context

        Returns:
            th.Tensor: (bsz, 2) float32 probabilities
        """
        xqrys: np.ndarray = ctxs.numpy(force=True)
        bs: np.ndarray = self._xacts_f[acts].numpy(force=True)
        return th.as_tensor(self._expert(xqrys, bs), dtype=th.float32)

    def _predict_proba_same_act(self, ctxs: th.Tensor, act: int) -> th.Tensor:
        # unused: predict_proba is overridden and never calls this
        return NotImplemented

    def __getitem__(self, key: tuple[int, ...]) -> None:
        # no per-action model objects
        return None

    def _nw_pred(
        self,
        Xtrn: th.Tensor,
        Ytrn: th.Tensor,
        Xq: th.Tensor,
        B: th.Tensor,
        sigmas: th.Tensor,
    ) -> th.Tensor:
        """Kernel-weighted average of ``Ytrn`` around the query.

        Args:
            Xtrn: N x d Train Instances
            Ytrn: N x nclass train targets (called with the (N, 1) 0/1 label)
            Xq: 1 x d Query Instances
            B: d x R binary masks to try
            sigmas: 1 x R bandwidth to use on each mask

        Returns:
            R x nclass weighted averages
        """
        Xtrn2 = Xtrn**2
        XtrnXq = Xtrn * Xq
        # N x R masked squared distance minus the query's own ||Xq||^2 term; that
        # term is constant over the N rows, and softmax over dim 0 is
        # shift-invariant, so it is not needed.
        d2 = th.matmul(Xtrn2, B) - 2.0 * th.matmul(XtrnXq, B)
        kerns = th.softmax(-d2 / sigmas, dim=0)
        # R x nclass
        Y_neighbors = th.matmul(kerns.T, Ytrn)
        return Y_neighbors

    def _expert(
        self, X: np.ndarray, B: np.ndarray, alpha: float = 0.0, minprob: float = 0.001
    ) -> th.Tensor:
        """
        Implement some black box truth expert, should perform better when given
        relevant features, worse when not.

        Args:
            X: numpy array of shape (N, n_covs) query features (masking is via B)
            B: numpy array of shape (N, n_covs [+ n_experts_per_act]): feature
                mask {0, 1}, optionally followed by a one-hot expert index
            alpha: weight of the shrinkage towards the label prior
            minprob: P(y=1) is clipped to [minprob, 1 - minprob]
        Returns:
            th.Tensor of shape (N, 2): [P(y=0), P(y=1)] per row
        """
        B_feat: np.ndarray = B[:, : self.n_covs]
        B_exp: np.ndarray = B[:, self.n_covs :]
        if B_exp.shape[1] == 0:
            # single-expert action layout has no expert columns: use expert 0
            expert_index: np.ndarray = np.zeros(B.shape[0], dtype=np.int64)
        else:
            expert_index = np.argmax(B_exp, axis=1)
        # (bsz, 1): bandwidth of each row's mask under its own expert's variances
        sig_masks = np.take_along_axis(
            # (n_experts_per_fcomb, bsz, 1)
            np.stack(
                [
                    # (bsz, 1)
                    np.matmul(B_feat, self._x_vars[i].T)
                    for i in range(self.n_experts_per_act)
                ]
            ),
            # (1, bsz, 1)
            indices=expert_index[None, ..., None],
            axis=0,
        )[0]
        sig_masks = np.maximum(
            sig_masks, np.min(self._x_vars)
        )  # Avoid div by zero for empty mask
        # penalty depends on the share of selected *features* only
        featspen = th.as_tensor(1.0 - alpha * np.mean(B_feat, axis=1))
        # convert each used expert's training data to device tensors once
        trn_tensors: dict[int, tuple[th.Tensor, th.Tensor]] = {
            int(e): (
                th.as_tensor(
                    self._xs_train_multi[int(e)],
                    dtype=th.float32,
                    device=self.device,
                ),
                th.as_tensor(
                    self._ys_train_multi[int(e)][:, None],
                    dtype=th.float32,
                    device=self.device,
                ),
            )
            for e in np.unique(expert_index)
        }
        PY = th.zeros((X.shape[0],), dtype=th.float32)
        for i in range(X.shape[0]):
            xs_t, ys_t = trn_tensors[int(expert_index[i])]
            PY[i] = self._nw_pred(
                xs_t,
                ys_t,
                th.as_tensor(X[i, None, :], dtype=th.float32, device=self.device),
                th.as_tensor(
                    B_feat[i, None, :].T, dtype=th.float32, device=self.device
                ),
                th.as_tensor(sig_masks[i], dtype=th.float32, device=self.device),
            ).to(device="cpu")
        PY = th.minimum(
            th.maximum(PY, th.as_tensor(minprob)), th.as_tensor(1.0 - minprob)
        )
        # TODO: adjust based on number of given feats
        pyhats: th.Tensor = featspen * PY + (1 - featspen) * self._prior_exp
        pyhats = th.cat((1.0 - pyhats[:, None], pyhats[:, None]), dim=1)
        return pyhats


class SubsetFeatureClassifierWrapperBase(SubsetFeatureClassifier[None], ABC):
    """Abstract wrapper that reuses the action space and training data of a
    single-expert base classifier."""

    base_classifier: SubsetFeatureClassifier

    def __init__(self, base_classifier: SubsetFeatureClassifier):
        """SubsetFeatureClassifierWrapperBase

        Args:
            base_classifier (SubsetFeatureClassifier): classifier to wrap; it must
                have exactly one expert per action.
        """
        base = base_classifier
        assert base.n_experts_per_act == 1
        shared = dict(
            n_experts_per_act=base.n_experts_per_act,
            cinds_f=base._cinds_f,
            xacts_f=base._xacts_f,
            xs_train=base.xs_train,
            ys_train=base.ys_train,
        )
        super().__init__(**shared)
        # assigned after Module.__init__ so it is registered as a submodule
        self.base_classifier = base

    def __getitem__(self, key: tuple[int, ...]) -> None:
        """Wrappers keep no per-action models; always ``None``."""
        return None


class SubsetFeatureBiasedClassifierWrapperBase(SubsetFeatureClassifierWrapperBase):
    """Wrapper that post-processes the binary probabilities of a base classifier
    through ``_apply_bias``."""

    bias_level: float

    def __init__(self, base_classifier: SubsetFeatureClassifier, bias_level: float):
        """SubsetFeatureBiasedClassifierWrapperBase

        Args:
            base_classifier (SubsetFeatureClassifier): binary base classifier
            bias_level (float): strength of the bias effect within 0.0 and 1.0
        """
        super().__init__(base_classifier=base_classifier)
        assert base_classifier.n_labels == 2
        self.bias_level = bias_level

    @abstractmethod
    def _apply_bias(
        self, ctxs: th.Tensor, xacts: th.Tensor, pyhats: th.Tensor
    ) -> th.Tensor:
        """Map base probabilities to biased ones.

        Args:
            ctxs (th.Tensor): contexts passed to ``predict_proba``
            xacts (th.Tensor): action feature rows, one per context
            pyhats (th.Tensor): (bsz, 2) base probabilities
        """

    def predict_proba(self, ctxs: th.Tensor, acts: th.Tensor) -> th.Tensor:
        """Base prediction followed by ``_apply_bias``."""
        base_probs: th.Tensor = self.base_classifier.predict_proba(ctxs, acts)
        assert base_probs.shape[1] == 2
        return self._apply_bias(ctxs, self._xacts_f[acts], base_probs)

    def _predict_proba_same_act(self, ctxs: th.Tensor, act: int) -> th.Tensor:
        # unused: predict_proba is overridden and never calls this
        return NotImplemented


class SubsetFeatureOverloadClassifierWrapper(SubsetFeatureBiasedClassifierWrapperBase):
    """Expert that gets less certain the more features it is shown."""

    min_temp: float
    bias_mult: float

    def __init__(
        self,
        base_classifier: SubsetFeatureClassifier,
        bias_level: float,
        min_temp: float = 1.0,
        bias_mult: float = 5.0,
    ):
        """SubsetFeatureOverloadClassifierWrapper

        Expert whose outputs become more uncertain as more features are displayed.
        Probabilities are tempered with
        ``temp = min_temp + bias_mult * bias_level * sqrt(fraction of features shown)``.

        Args:
            base_classifier (SubsetFeatureClassifier): base classifier to be wrapped
            bias_level (float): strength of the bias effect within 0.0 and 1.0
            min_temp (float, optional): temperature when no features are shown.
                Defaults to 1.0.
            bias_mult (float, optional): scale factor for bias level. Defaults to 5.0.
        """
        super().__init__(base_classifier, bias_level=bias_level)
        self.min_temp = min_temp
        self.bias_mult = bias_mult

    def _apply_bias(
        self, ctxs: th.Tensor, xacts: th.Tensor, pyhats: th.Tensor
    ) -> th.Tensor:
        """Temper ``pyhats`` row-wise and renormalise."""
        # (bsz, 1) fraction of features displayed for each context
        shown_frac = xacts.mean(dim=1, keepdim=True)
        temp = self.min_temp + self.bias_mult * self.bias_level * shown_frac.sqrt()
        tempered = pyhats.pow(1 / temp)
        return tempered / tempered.sum(dim=1, keepdim=True)


class SubsetFeatureRiskAverseClassifierWrapper(
    SubsetFeatureBiasedClassifierWrapperBase
):
    """Expert biased towards predicting the positive class (class 1)."""

    def __init__(self, base_classifier: SubsetFeatureClassifier, bias_level: float):
        """SubsetFeatureRiskAverseClassifierWrapper

        P(y=1) is moved linearly towards 1: a bias_level of 0 leaves predictions
        unchanged, a bias_level of 1 yields P(y=1) = 1.0 everywhere.

        Args:
            base_classifier (SubsetFeatureClassifier): base classifier
            bias_level (float): strength of the bias effect within 0.0 and 1.0
        """
        super().__init__(base_classifier, bias_level=bias_level)

    def _apply_bias(
        self, ctxs: th.Tensor, xacts: th.Tensor, pyhats: th.Tensor
    ) -> th.Tensor:
        """Return ``[1 - p', p']`` with ``p' = (1 - bias) * p + bias``."""
        lam = self.bias_level
        p_pos = (1 - lam) * pyhats[:, 1] + lam
        out: th.Tensor = th.zeros_like(pyhats)
        out[:, 1] = p_pos
        out[:, 0] = 1 - out[:, 1]
        return out


class SubsetFeaturePoisonFeatureClassifierWrapper(
    SubsetFeatureBiasedClassifierWrapperBase
):
    """Expert pulled towards a univariate model of one "poison" feature.

    For contexts whose displayed features include the poison feature, the base
    P(y=1) is mixed with the prediction of an XGBoost model fitted on that
    feature alone, with weight ``bias_level``; other contexts are unchanged.
    """

    univar_model: xgbst.XGBClassifier
    poison_feature_idx: int

    def __init__(
        self,
        base_classifier: SubsetFeatureClassifier,
        bias_level: float,
        poison_feature_idx: int,
        xgbc_kwargs: Optional[dict[str, Any]] = None,
    ):
        """SubsetFeaturePoisonFeatureClassifierWrapper

        Args:
            base_classifier (SubsetFeatureClassifier): binary base classifier
            bias_level (float): mixing weight of the univariate model (0.0 to 1.0)
            poison_feature_idx (int): column index of the poison feature
            xgbc_kwargs (Optional[dict[str, Any]], optional): kwargs for the
                univariate ``XGBClassifier``; ``None`` means library defaults.
        """
        super().__init__(base_classifier, bias_level)
        self.poison_feature_idx = poison_feature_idx
        self.univar_model = xgbst.XGBClassifier(
            **(dict() if xgbc_kwargs is None else xgbc_kwargs)
        )
        self.univar_model.fit(
            base_classifier.xs_train[:, poison_feature_idx][:, None],
            base_classifier.ys_train,
        )

    def _apply_bias(
        self, ctxs: th.Tensor, xacts: th.Tensor, pyhats: th.Tensor
    ) -> th.Tensor:
        """Mix base and univariate P(y=1) where the poison feature is displayed.

        Args:
            ctxs (th.Tensor): (bsz, d) contexts
            xacts (th.Tensor): (bsz, n_covs) 0/1 displayed-feature indicators
            pyhats (th.Tensor): (bsz, 2) base probabilities

        Returns:
            th.Tensor: (bsz, 2) biased probabilities
        """
        shown: th.Tensor = xacts.to(dtype=th.bool)[:, self.poison_feature_idx]
        # (bsz,) mixing weight: bias_level where the poison feature is shown
        bias_levels: th.Tensor = th.where(shown, self.bias_level, 0.0).to(
            device=pyhats.device
        )
        # predict_proba only takes the feature matrix; the NumPy result is then
        # converted to a float32 tensor on pyhats' device
        pyhats_u: th.Tensor = th.as_tensor(
            self.univar_model.predict_proba(
                ctxs[:, self.poison_feature_idx][:, None].numpy(force=True)
            ),
            dtype=th.float32,
            device=pyhats.device,
        )
        pyhats_: th.Tensor = th.zeros_like(pyhats)
        pyhats_[:, 1] = (1 - bias_levels) * pyhats[:, 1] + bias_levels * pyhats_u[:, 1]
        pyhats_[:, 0] = 1.0 - pyhats_[:, 1]
        return pyhats_
