from __future__ import annotations

import math
import os
import tempfile as tmpf
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Generic, Literal, Optional, TypeVar

import gpytorch as gpth
import lightning as pl
import numpy as np
import sklearn.exceptions as skl_exceptions
import torch as th
import torch.distributions.utils
import torch.utils.dlpack
import tqdm
import xgboost as xgbst

import datasets.base

# Type of the value a concrete estimator's ``forward`` returns; the same value
# is what ``get_posterior_mean`` and ``sample_posterior`` accept.
OutputT = TypeVar("OutputT")


class ScoreEstimator(th.nn.Module, ABC, Generic[OutputT]):
    """Abstract base class for score estimators that store their training data.

    The training inputs/targets (and, optionally, the ``ys`` / ``pyhats``
    tensors taken from an ``EnvRewardInfo``) are kept on the CPU and are
    round-tripped through the module's extra state. Subclasses implement
    ``forward``, ``get_posterior_mean``, ``sample_posterior`` and ``fit_``.
    """

    # Class-level flag; it is not read anywhere in this class.
    # NOTE(review): presumably consulted by the training loop — confirm.
    _enable_lazy_fit: bool = True

    n_ctx_covs: int
    n_experts_per_fcomb: int
    _train_inputs: th.Tensor | None
    _train_targets: th.Tensor | None
    _ys: th.Tensor | None
    _pyhats: th.Tensor | None

    # Registered buffer whose only use is to report the module's device.
    _dummy: th.Tensor

    @property
    def device(self):
        """Device the module currently lives on (read from the ``_dummy`` buffer)."""
        return self._dummy.device

    @property
    def train_inputs(self) -> th.Tensor:
        """Stored training inputs; asserts that training data has been set."""
        assert self._train_inputs is not None
        return self._train_inputs

    @property
    def train_targets(self) -> th.Tensor:
        """Stored (flattened) training targets; asserts that they have been set."""
        assert self._train_targets is not None
        return self._train_targets

    @property
    def ys(self) -> th.Tensor:
        """Stored ``infos.ys``; asserts that it has been set."""
        assert self._ys is not None
        return self._ys

    @property
    def pyhats(self) -> th.Tensor:
        """Stored ``infos.pyhats``; asserts that it has been set."""
        assert self._pyhats is not None
        return self._pyhats

    def __init__(self, n_ctx_covs: int, n_experts_per_fcomb: int) -> None:
        """Create an estimator with no training data.

        Args:
            n_ctx_covs (int): number of context covariates; used by
                ``decompose_inputs`` to slice the input columns.
            n_experts_per_fcomb (int): when equal to 1 the inputs are split in
                two halves, otherwise the trailing columns are returned as
                expert indicators (see ``decompose_inputs``).
        """
        super().__init__()
        self.n_ctx_covs = n_ctx_covs
        self.n_experts_per_fcomb = n_experts_per_fcomb
        self._train_inputs = None
        self._train_targets = None
        self._ys = None
        self._pyhats = None
        self.register_buffer("_dummy", th.empty(()))

    def set_train_data_(
        self,
        inputs: th.Tensor,
        targets: th.Tensor,
        infos: Optional[datasets.base.EnvRewardInfo] = None,
    ):
        """Replace the stored training data with CPU copies of the arguments.

        Args:
            inputs (th.Tensor): training inputs, one row per sample.
            targets (th.Tensor): training targets; stored flattened.
            infos: optional reward info; when given, its ``ys`` and ``pyhats``
                replace the stored ones, otherwise the stored ones are kept.
        """
        assert len(inputs) == len(targets.flatten())
        self._train_inputs = inputs.clone().to(device="cpu")
        self._train_targets = targets.clone().flatten().to(device="cpu")
        if infos is not None:
            self._ys = infos.ys.clone().to(device="cpu")
            self._pyhats = infos.pyhats.clone().to(device="cpu")

    def add_to_train_data_(
        self,
        inputs: th.Tensor,
        targets: th.Tensor,
        infos: Optional[datasets.base.EnvRewardInfo] = None,
    ):
        """Append new samples to the stored training data.

        Requires that training data has already been set (the properties
        assert this). ``targets`` may have any shape; it is flattened before
        being appended, matching how ``set_train_data_`` stores targets.
        """
        new_inputs = th.cat((self.train_inputs, inputs.to(device="cpu")), dim=0)
        # The stored targets are 1-D, so the new ones must be flattened before
        # concatenation (an (n, 1) batch would otherwise fail in th.cat).
        new_targets = th.cat(
            (self.train_targets, targets.flatten().to(device="cpu")), dim=0
        )
        if infos is not None:
            new_ys = th.cat((self.ys, infos.ys.to(device="cpu")), dim=0)
            new_pyhats = th.cat((self.pyhats, infos.pyhats.to(device="cpu")), dim=0)
            infos = datasets.base.EnvRewardInfo(new_pyhats, new_ys)
        self.set_train_data_(new_inputs, new_targets, infos)

    def initialize(self, xs: th.Tensor, ys: th.Tensor, *args, **kwargs):
        """Optional initialization hook; the base implementation does nothing."""
        pass

    def decompose_inputs(
        self, inputs: th.Tensor
    ) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Decompose inputs into ctxs, cinds, exinds.

        Args:
            inputs (th.Tensor): (bsz, n_covs)

        Returns:
            th.Tensor: contexts. With ``n_experts_per_fcomb == 1`` this is the
                first half of the columns, otherwise the first ``n_ctx_covs``
                columns.
            th.Tensor: context indicator masks. The second half of the columns,
                or the next ``n_ctx_covs`` columns respectively.
            th.Tensor: expert indicators. With ``n_experts_per_fcomb == 1`` an
                *uninitialized* float32 placeholder of shape (bsz,); otherwise
                all columns from ``2 * n_ctx_covs`` onwards.
        """
        if self.n_experts_per_fcomb == 1:
            ctxs, cinds = th.chunk(inputs, chunks=2, dim=1)
            # Placeholder only: there are no expert columns in this layout.
            exinds: th.Tensor = th.empty(
                (len(ctxs),), dtype=th.float32, device=inputs.device
            )
            return ctxs, cinds, exinds
        ctxs: th.Tensor = inputs[:, : self.n_ctx_covs]
        cinds: th.Tensor = inputs[:, self.n_ctx_covs : 2 * self.n_ctx_covs]
        exinds: th.Tensor = inputs[:, 2 * self.n_ctx_covs :]
        return ctxs, cinds, exinds

    @abstractmethod
    def forward(self, inputs: th.Tensor) -> OutputT: ...

    @abstractmethod
    def get_posterior_mean(self, forward_outs: OutputT) -> th.Tensor: ...

    @abstractmethod
    def sample_posterior(
        self, forward_outs: OutputT, sample_shape: th.Size = th.Size()
    ) -> th.Tensor: ...

    @abstractmethod
    def fit_(self, plf: pl.Fabric) -> dict[str, float]: ...

    def get_extra_state(self) -> dict[str, Any]:
        """Return the stored training data so it is saved with the state dict."""
        extra_state: dict[str, Any] = {
            "train_inputs": self._train_inputs,
            "train_targets": self._train_targets,
            "ys": self._ys,
            "pyhats": self._pyhats,
        }
        return extra_state

    def set_extra_state(self, state: Any) -> None:
        """Restore the training data produced by ``get_extra_state``."""
        self._train_inputs = state["train_inputs"]
        self._train_targets = state["train_targets"]
        self._ys = state["ys"]
        self._pyhats = state["pyhats"]


class GPScoreEst(ScoreEstimator[gpth.distributions.MultivariateNormal]):
    """Score estimator backed by an exact Gaussian process (GPyTorch).

    Targets are standardised with their mean/std before fitting; the statistics
    are kept in the ``_mu`` / ``_sigma`` buffers and undone again in
    ``get_posterior_mean`` and ``sample_posterior``.
    """

    _enable_lazy_fit = False

    class _GP(gpth.models.ExactGP):
        """Exact GP whose mean and covariance modules are supplied by the caller."""

        mean_module: gpth.means.Mean
        covar_module: gpth.kernels.Kernel
        likelihood: gpth.likelihoods.GaussianLikelihood

        def __init__(
            self,
            mean_module: gpth.means.Mean,
            covar_module: gpth.kernels.Kernel,
            train_inputs: Optional[th.Tensor],
            train_targets: Optional[th.Tensor],
            likelihood: gpth.likelihoods.GaussianLikelihood,
        ):
            super().__init__(train_inputs, train_targets, likelihood)
            self.mean_module = mean_module
            self.covar_module = covar_module

        def forward(self, xs: th.Tensor) -> gpth.distributions.MultivariateNormal:
            """Return the multivariate normal given by the mean/covariance at ``xs``."""
            return gpth.distributions.MultivariateNormal(
                self.mean_module(xs), self.covar_module(xs)
            )

    n_fit_iter: int
    opt_kwargs: dict[str, Any]
    n_init_fit_iter: int

    _gp_est: _GP
    _is_first_fit: bool
    _opt_step: int
    _mu: th.Tensor
    _sigma: th.Tensor
    _opt: th.optim.Optimizer

    def __init__(
        self,
        n_ctx_covs: int,
        n_experts_per_fcomb: int,
        mean_module: gpth.means.Mean,
        covar_module: gpth.kernels.Kernel,
        n_fit_iter: int,
        opt_kwargs: dict[str, Any],
        n_init_fit_iter: Optional[int] = None,
    ) -> None:
        """Build the GP and its Adam optimizer.

        Args:
            n_ctx_covs (int): forwarded to ``ScoreEstimator``.
            n_experts_per_fcomb (int): forwarded to ``ScoreEstimator``.
            mean_module: GP mean module.
            covar_module: GP kernel.
            n_fit_iter (int): optimizer steps per ``fit_`` call.
            opt_kwargs (dict): keyword arguments for ``th.optim.Adam``.
            n_init_fit_iter (Optional[int]): optimizer steps for the very first
                ``fit_`` call; defaults to ``n_fit_iter``.
        """
        super().__init__(n_ctx_covs=n_ctx_covs, n_experts_per_fcomb=n_experts_per_fcomb)
        self._gp_est = GPScoreEst._GP(
            mean_module, covar_module, None, None, gpth.likelihoods.GaussianLikelihood()
        )
        self.n_fit_iter = n_fit_iter
        self.n_init_fit_iter = (
            n_fit_iter if n_init_fit_iter is None else n_init_fit_iter
        )
        self.opt_kwargs = opt_kwargs
        self._is_first_fit = True
        self._opt_step = 0
        self.register_buffer("_mu", th.tensor(0.0))
        self.register_buffer("_sigma", th.tensor(1.0))
        self._opt = th.optim.Adam(self._gp_est.parameters(), **self.opt_kwargs)

    def forward(self, inputs: th.Tensor) -> gpth.distributions.MultivariateNormal:
        """Evaluate the GP at ``inputs`` (in standardised target space)."""
        return self._gp_est(inputs)

    def get_posterior_mean(
        self, forward_outs: gpth.distributions.MultivariateNormal
    ) -> th.Tensor:
        """Mean of ``forward_outs`` mapped back to the original target scale."""
        return (forward_outs.mean * self._sigma) + self._mu

    def sample_posterior(
        self,
        forward_outs: gpth.distributions.MultivariateNormal,
        sample_shape: th.Size = th.Size(),
    ) -> th.Tensor:
        """Samples from ``forward_outs`` mapped back to the original target scale."""
        return (forward_outs.sample(sample_shape) * self._sigma) + self._mu

    def fit_(self, plf: pl.Fabric) -> dict[str, float]:
        """Fit the GP hyper-parameters on the stored training data.

        Moves the module to ``plf.device``, standardises the targets, runs
        ``n_init_fit_iter`` (first call) or ``n_fit_iter`` Adam steps on the
        exact marginal log-likelihood and logs per-step metrics through ``plf``.
        The module is left in eval mode when this method returns.

        Returns:
            dict[str, float]: ``{"est_mse": ...}``, the MSE between the
            de-standardised posterior mean and the training targets.
        """
        self.train().to(device=plf.device)
        gp_est = self._gp_est
        xs: th.Tensor = self.train_inputs.to(self.device)
        ys: th.Tensor = self.train_targets.to(self.device)
        sigma, mu = th.std_mean(ys)
        # A single sample gives a NaN std and constant targets give 0; either
        # would turn every standardised target into NaN/inf. Fall back to unit
        # scale so that only the mean is removed in those degenerate cases.
        if not bool(th.isfinite(sigma)) or float(sigma) <= 0.0:
            sigma = th.ones_like(sigma)
        self._sigma.copy_(sigma)
        self._mu.copy_(mu)
        zs = (ys - self._mu) / self._sigma
        gp_est.set_train_data(xs, zs, strict=False)
        assert gp_est.train_inputs is not None
        assert gp_est.train_targets is not None
        mll = gpth.mlls.ExactMarginalLogLikelihood(gp_est.likelihood, gp_est)
        n_iter: int = self.n_fit_iter
        if self._is_first_fit:
            n_iter = self.n_init_fit_iter
            self._opt_step = 0
            self._is_first_fit = False
        pbar = tqdm.trange(n_iter, leave=False, dynamic_ncols=True)
        for _ in pbar:
            self._opt.zero_grad()
            output = gp_est(xs)
            nll: th.Tensor = -mll(output, zs)
            nll.backward()
            self._opt.step()
            nll_val: float = nll.item()
            plf.log_dict(
                {
                    "train_gp/nll": nll_val,
                    "train_gp/mse_zs": gpth.metrics.mean_squared_error(
                        output, zs
                    ).item(),
                },
                step=self._opt_step,
            )
            # Show the same quantity that is logged under "train_gp/nll".
            pbar.set_postfix({"nll": nll_val})
            self._opt_step = self._opt_step + 1
        pbar.close()
        # An ExactGP in train mode returns its prior, so the posterior-mean
        # metric has to be computed in eval mode.
        self.eval()
        mse: float
        with th.no_grad():
            mse = th.nn.functional.mse_loss(
                self.get_posterior_mean(self(xs)), ys
            ).item()
        metrics_d = {"est_mse": mse}
        return metrics_d


class XGBoostScoreEst(ScoreEstimator[th.Tensor]):
    """Ensemble of ``XGBRegressor`` models, each fit on a random resample.

    ``forward`` returns one prediction column per model; the posterior mean is
    the average over the columns and a posterior sample picks, per row, the
    prediction of a uniformly chosen model.
    """

    fraction_training_data_per_split: float
    n_splits: int

    _models: list[xgbst.XGBRegressor]
    # CPU generator used to draw the resampling indices in ``fit_``.
    _rg: th.Generator

    def __init__(
        self,
        n_ctx_covs: int,
        n_experts_per_fcomb: int,
        fraction_training_data_per_split: float,
        n_splits: int,
        xgbr_kwargs: dict[str, Any] = dict(),
        rseed: Optional[int] = None,
    ) -> None:
        """Create ``n_splits`` unfitted regressors.

        Args:
            n_ctx_covs (int): forwarded to ``ScoreEstimator``.
            n_experts_per_fcomb (int): forwarded to ``ScoreEstimator``.
            fraction_training_data_per_split (float): size of each model's
                resample relative to the training set (rounded up).
            n_splits (int): number of models in the ensemble.
            xgbr_kwargs (dict): keyword arguments for every ``XGBRegressor``;
                only unpacked here, never mutated.
            rseed (Optional[int]): seed for the resampling generator.
        """
        super().__init__(n_ctx_covs=n_ctx_covs, n_experts_per_fcomb=n_experts_per_fcomb)
        self.fraction_training_data_per_split = fraction_training_data_per_split
        self.n_splits = n_splits
        self._models = [xgbst.XGBRegressor(**xgbr_kwargs) for _ in range(n_splits)]
        self._rg = th.Generator()
        if rseed is not None:
            self._rg.manual_seed(rseed)

    def forward(self, inputs: th.Tensor) -> th.Tensor:
        """Return the per-model predictions, shape (len(inputs), n_models)."""
        _inputs: np.ndarray = inputs.numpy(force=True)
        outs_l: list[th.Tensor] = [
            th.as_tensor(m.predict(_inputs), device=self.device) for m in self._models
        ]
        outs: th.Tensor = th.stack(outs_l, dim=1)
        return outs

    def get_posterior_mean(self, forward_outs: th.Tensor) -> th.Tensor:
        """Average the per-model predictions."""
        return th.mean(forward_outs, dim=1)

    def sample_posterior(
        self, forward_outs: th.Tensor, sample_shape: th.Size = th.Size()
    ) -> th.Tensor:
        """For every row draw ``prod(sample_shape)`` model predictions at random.

        Returns:
            th.Tensor: shape ``(len(forward_outs), *sample_shape)``.
        """
        n: int = len(forward_outs)
        n_samps: int = math.prod(sample_shape)
        idxs: th.Tensor = th.randint(
            0, len(self._models), (n, n_samps), dtype=th.long, device=self.device
        )
        outs: th.Tensor = th.gather(forward_outs, dim=1, index=idxs).reshape(
            (n, *sample_shape)
        )
        return outs

    def fit_(self, plf: pl.Fabric) -> dict[str, float]:
        """Fit every model on its own resample (drawn with replacement).

        Returns:
            dict[str, float]: the R^2 score and the MSE on the full training
            set, each averaged over the models.
        """
        self.train()
        rsquares_l: list[float] = list()
        mses_l: list[float] = list()
        # These do not depend on the model, so compute them once.
        n_total: int = len(self.train_targets)
        n_data: int = math.ceil(n_total * self.fraction_training_data_per_split)
        train_inputs: np.ndarray = self.train_inputs.numpy(force=True)
        train_targets: np.ndarray = self.train_targets.numpy(force=True)
        for m in self._models:
            idxs: th.Tensor = th.randint(
                0, n_total, (n_data,), dtype=th.long, generator=self._rg
            )
            xs: np.ndarray = self.train_inputs[idxs].numpy(force=True)
            ys: np.ndarray = self.train_targets[idxs].numpy(force=True)
            m.fit(xs, ys)
            rsquares_l.append(m.score(train_inputs, train_targets))
            mses_l.append(
                th.nn.functional.mse_loss(
                    th.as_tensor(m.predict(train_inputs), dtype=th.float32),
                    self.train_targets,
                ).item()
            )
        metrics = {
            "est_rsquared": th.mean(th.as_tensor(rsquares_l)).item(),
            "est_mse": th.mean(th.as_tensor(mses_l)).item(),
        }
        return metrics

    def get_extra_state(self) -> Any:
        """Extend the base extra state with the serialized models.

        Every model is written to a temporary JSON file and stored as the list
        of lines of that file. If the models have not been fitted yet the list
        of model states is left empty.
        """
        extra_state: dict[str, Any] = super().get_extra_state()
        model_states_l: list[list[str]] = list()
        extra_state.update(
            {
                "model_states_l": model_states_l,
                "fraction_training_data_per_split": self.fraction_training_data_per_split,
                "n_splits": self.n_splits,
            }
        )
        try:
            with tmpf.TemporaryDirectory() as td:
                for i, model in enumerate(self._models):
                    p = os.path.join(td, f"m{i}.json")
                    model.save_model(p)
                    with open(p, mode="r", encoding="utf-8") as f:
                        model_states_l.append(f.readlines())
        except skl_exceptions.NotFittedError:
            # Store all models or none, never a partial ensemble.
            model_states_l.clear()
        return extra_state

    def set_extra_state(self, state: Any) -> None:
        """Restore the state produced by ``get_extra_state``.

        When the state holds no serialized models (it was saved before
        fitting), the current unfitted models are kept instead of being
        discarded, so the estimator can still be fitted afterwards.
        """
        super().set_extra_state(state)
        self.fraction_training_data_per_split = state[
            "fraction_training_data_per_split"
        ]
        self.n_splits = state["n_splits"]
        model_states_l = state["model_states_l"]
        if not model_states_l:
            return
        loaded: list[xgbst.XGBRegressor] = list()
        with tmpf.TemporaryDirectory() as td:
            for i, model_states in enumerate(model_states_l):
                p = os.path.join(td, f"m{i}.json")
                with open(p, mode="w", encoding="utf-8") as f:
                    f.writelines(model_states)
                model = xgbst.XGBRegressor()
                model.load_model(p)
                loaded.append(model)
        # Swap in place only after every model has been loaded successfully.
        self._models[:] = loaded
        return


class StructureScoreEstBase(ScoreEstimator[th.Tensor]):
    """Shared machinery for estimators built from two XGBoost classifier sets.

    * ``base_models`` are fit on the contexts alone (first part of the inputs)
      against the stored ``ys``, together with the extra ``xs_etrain`` /
      ``ys_etrain`` samples.
    * ``mimic_models`` are fit on the masked contexts ``ctxs * cinds`` plus the
      indicator columns, with every label weighted by the stored ``pyhats``.

    Subclasses define ``forward`` (how the two sets of class probabilities are
    combined), ``get_posterior_mean`` and ``sample_posterior``.
    """

    alpha: th.Tensor
    fraction_training_data_per_split: float
    mimic_models: list[xgbst.XGBClassifier]
    base_models: list[xgbst.XGBClassifier]
    xs_etrain: th.Tensor | None
    ys_etrain: th.Tensor | None
    # CPU generator used to draw the resampling indices while fitting.
    _rg: th.Generator

    def __init__(
        self,
        n_ctx_covs: int,
        n_experts_per_fcomb: int,
        fraction_training_data_per_split: float,
        n_mimic_models: int,
        n_base_models: int,
        mimic_xgbc_kwargs: dict[str, Any] = dict(),
        base_xgbc_kwargs: dict[str, Any] = dict(),
        alpha: float = 0.0,
        xs_etrain: Optional[th.Tensor] = None,
        ys_etrain: Optional[th.Tensor] = None,
        rseed: Optional[int] = None,
    ) -> None:
        """Create the unfitted classifiers.

        Args:
            n_ctx_covs (int): forwarded to ``ScoreEstimator``.
            n_experts_per_fcomb (int): forwarded to ``ScoreEstimator``.
            fraction_training_data_per_split (float): size of each model's
                resample relative to the available data (rounded up).
            n_mimic_models (int): number of mimic classifiers.
            n_base_models (int): number of base classifiers.
            mimic_xgbc_kwargs (dict): keyword arguments of the mimic
                classifiers; only unpacked, never mutated.
            base_xgbc_kwargs (dict): keyword arguments of the base
                classifiers; only unpacked, never mutated.
            alpha (float): stored as the ``alpha`` buffer for use by subclasses.
            xs_etrain / ys_etrain: extra samples appended to the base models'
                training set; can also be provided later via ``initialize``.
            rseed (Optional[int]): seed for the resampling generator.
        """
        super().__init__(n_ctx_covs=n_ctx_covs, n_experts_per_fcomb=n_experts_per_fcomb)
        self.register_buffer("alpha", th.tensor(alpha))
        self.fraction_training_data_per_split = fraction_training_data_per_split
        self.xs_etrain = xs_etrain
        self.ys_etrain = ys_etrain
        self.mimic_models = [
            xgbst.XGBClassifier(**mimic_xgbc_kwargs) for _ in range(n_mimic_models)
        ]
        self.base_models = [
            xgbst.XGBClassifier(**base_xgbc_kwargs) for _ in range(n_base_models)
        ]
        self._rg = th.Generator()
        if rseed is not None:
            self._rg.manual_seed(rseed)

    def initialize(self, xs: th.Tensor, ys: th.Tensor, *args, **kwargs):
        """Store CPU copies of the extra base-model training samples.

        Extra positional/keyword arguments are accepted (and ignored) so the
        signature stays compatible with ``ScoreEstimator.initialize``.
        """
        self.xs_etrain = xs.cpu()
        self.ys_etrain = ys.cpu()
        return

    def _forward(self, inputs: th.Tensor) -> tuple[list[th.Tensor], list[th.Tensor]]:
        """Return the class probabilities of all mimic and all base models."""
        mimic_outs_l: list[th.Tensor] = self._forward_mimic(inputs)
        base_outs_l: list[th.Tensor] = self._forward_base(inputs)
        return mimic_outs_l, base_outs_l

    def _forward_mimic(self, inputs: th.Tensor) -> list[th.Tensor]:
        """Class probabilities of every mimic model on the masked contexts."""
        xs, cinds, exinds = self.decompose_inputs(inputs)
        inputs_: th.Tensor = (
            th.cat((xs * cinds, cinds), dim=1)
            if self.n_experts_per_fcomb == 1
            else th.cat((xs * cinds, cinds, exinds), dim=1)
        )
        inputs_n: np.ndarray = inputs_.numpy(force=True)
        mimic_outs_l: list[th.Tensor] = [
            th.as_tensor(
                m.predict_proba(inputs_n), dtype=th.float32, device=self.device
            )
            for m in self.mimic_models
        ]
        return mimic_outs_l

    def _forward_base(self, inputs: th.Tensor) -> list[th.Tensor]:
        """Class probabilities of every base model on the unmasked contexts."""
        ctxs_n: np.ndarray = self.decompose_inputs(inputs)[0].numpy(force=True)
        base_outs_l: list[th.Tensor] = [
            th.as_tensor(m.predict_proba(ctxs_n), dtype=th.float32, device=self.device)
            for m in self.base_models
        ]
        return base_outs_l

    def fit_(self, plf: pl.Fabric) -> dict[str, float]:
        """Fit the base and mimic models, then score the training set.

        Returns:
            dict[str, float]: ``{"est_mse": ...}``, the MSE between the
            posterior mean on the training inputs and the training targets.
        """
        self.train()
        self._fit_base()
        self._fit_mimic()
        self.eval()
        mse_loss: th.Tensor
        with th.no_grad():
            mse_loss = th.nn.functional.mse_loss(
                self.get_posterior_mean(self(self.train_inputs)),
                self.train_targets.to(device=self.device),
            )
        metrics_d: dict[str, float] = {"est_mse": mse_loss.item()}
        return metrics_d

    def _fit_base(self):
        """Fit the base models on (context, label) pairs.

        The pairs are the unique rows of the stored contexts joined with the
        stored ``ys``, followed by ``xs_etrain`` / ``ys_etrain``. A single base
        model is fit on all pairs; with several models each one is fit on its
        own resample drawn with replacement.
        """
        assert self.xs_etrain is not None
        assert self.ys_etrain is not None
        ctxs: th.Tensor = self.decompose_inputs(self.train_inputs)[0]
        # get rid of repeat elements
        tmp_t: th.Tensor = th.unique(th.cat((ctxs, self.ys[:, None]), dim=1), dim=0)
        xs: th.Tensor = th.cat((tmp_t[:, :-1], self.xs_etrain), dim=0)
        ys: th.Tensor = th.cat((tmp_t[:, -1].to(dtype=th.long), self.ys_etrain), dim=0)
        if len(self.base_models) == 1:
            self.base_models[0].fit(xs.numpy(force=True), ys.numpy(force=True))
            return
        n_data: int = math.ceil(len(xs) * self.fraction_training_data_per_split)
        for bm in self.base_models:
            idxs: th.Tensor = th.randint(
                0, len(xs), (n_data,), dtype=th.long, generator=self._rg
            )
            bm.fit(xs[idxs].numpy(force=True), ys[idxs].numpy(force=True))
        return

    def _fit_mimic(self):
        """Fit every mimic model on its own resample of the training inputs.

        Each sampled row is repeated once per label (``pyhats.shape[1]``
        labels) and the copy for label ``k`` is weighted by the sample's
        ``pyhats[:, k]`` value.
        """
        n_labels: int = self.pyhats.shape[1]
        n_data: int = math.ceil(
            len(self.train_targets) * self.fraction_training_data_per_split
        )
        for m in self.mimic_models:
            idxs: th.Tensor = th.randint(
                0, len(self.train_targets), (n_data,), dtype=th.long, generator=self._rg
            )
            _tctxs, _tcinds, _texinds = self.decompose_inputs(self.train_inputs[idxs])
            _xs: th.Tensor = (
                th.cat((_tctxs * _tcinds, _tcinds), dim=1)
                if self.n_experts_per_fcomb == 1
                else th.cat((_tctxs * _tcinds, _tcinds, _texinds), dim=1)
            )
            # (n_data, n_labels, n_features): one copy of each row per label
            _xs = _xs[:, None, :].expand(-1, n_labels, -1)
            _ys: th.Tensor = th.arange(0, n_labels, dtype=th.long)
            _ys = _ys[None, :].expand(n_data, -1)
            _xs = _xs.flatten(0, 1)
            _ys = _ys.flatten(0, 1)
            _weight = self.pyhats[idxs].flatten(0, 1)
            m.fit(
                _xs.numpy(force=True),
                _ys.numpy(force=True),
                sample_weight=_weight.numpy(force=True),
            )
        return

    @staticmethod
    def _dump_models(models: list[xgbst.XGBClassifier], prefix: str) -> list[list[str]]:
        """Serialize ``models`` to lists of JSON lines; ``[]`` if any is unfitted."""
        states_l: list[list[str]] = list()
        try:
            with tmpf.TemporaryDirectory() as td:
                for i, model in enumerate(models):
                    p = os.path.join(td, f"{prefix}_m{i}.json")
                    model.save_model(p)
                    with open(p, mode="r", encoding="utf-8") as f:
                        states_l.append(f.readlines())
        except skl_exceptions.NotFittedError:
            return list()
        return states_l

    @staticmethod
    def _load_models(
        states_l: list[list[str]], prefix: str
    ) -> list[xgbst.XGBClassifier]:
        """Rebuild classifiers from the line lists written by ``_dump_models``."""
        models: list[xgbst.XGBClassifier] = list()
        with tmpf.TemporaryDirectory() as td:
            for i, model_states in enumerate(states_l):
                p = os.path.join(td, f"{prefix}_m{i}.json")
                with open(p, mode="w", encoding="utf-8") as f:
                    f.writelines(model_states)
                model = xgbst.XGBClassifier()
                model.load_model(p)
                models.append(model)
        return models

    def get_extra_state(self) -> Any:
        """Extend the base extra state with the serialized classifiers.

        A group of models (base or mimic) that has not been fitted yet is
        stored as an empty list.
        """
        extra_state: dict[str, Any] = super().get_extra_state()
        extra_state.update(
            {
                "base_model_states_l": self._dump_models(self.base_models, "base"),
                "mimic_model_states_l": self._dump_models(self.mimic_models, "mimic"),
                "fraction_training_data_per_split": self.fraction_training_data_per_split,
            }
        )
        return extra_state

    def set_extra_state(self, state: Any) -> None:
        """Restore the state produced by ``get_extra_state``.

        A group whose saved list is empty (saved before fitting) keeps its
        current unfitted models, so the estimator can still be fitted after
        loading such a state.
        """
        super().set_extra_state(state)
        self.fraction_training_data_per_split = state[
            "fraction_training_data_per_split"
        ]
        base_states_l = state["base_model_states_l"]
        if base_states_l:
            self.base_models[:] = self._load_models(base_states_l, "base")
        mimic_states_l = state["mimic_model_states_l"]
        if mimic_states_l:
            self.mimic_models[:] = self._load_models(mimic_states_l, "mimic")
        return


class ZipStructureScoreEst(StructureScoreEstBase):
    """Structure estimator that pairs the i-th mimic model with the i-th base model.

    The same number of mimic and base models is created; each pair contributes
    one output column.
    """

    def __init__(
        self,
        n_ctx_covs: int,
        n_experts_per_fcomb: int,
        fraction_training_data_per_split: float,
        n_models: int,
        mimic_xgbc_kwargs: dict[str, Any] = dict(),
        base_xgbc_kwargs: dict[str, Any] = dict(),
        alpha: float = 0.0,
        xs_etrain: Optional[th.Tensor] = None,
        ys_etrain: Optional[th.Tensor] = None,
        rseed: Optional[int] = None,
    ) -> None:
        """Forward everything to the base class, using ``n_models`` for both sets."""
        super().__init__(
            n_ctx_covs=n_ctx_covs,
            n_experts_per_fcomb=n_experts_per_fcomb,
            fraction_training_data_per_split=fraction_training_data_per_split,
            n_base_models=n_models,
            n_mimic_models=n_models,
            base_xgbc_kwargs=base_xgbc_kwargs,
            mimic_xgbc_kwargs=mimic_xgbc_kwargs,
            alpha=alpha,
            xs_etrain=xs_etrain,
            ys_etrain=ys_etrain,
            rseed=rseed,
        )

    def forward(self, inputs: th.Tensor) -> th.Tensor:
        """Score every input with every (mimic, base) pair.

        A column is the negated cross entropy between the mimic model's logits
        and the base model's class probabilities; ``alpha`` times the row sum
        of the context indicators is then subtracted from every column.

        Returns:
            th.Tensor: shape (len(inputs), number of model pairs).
        """
        inputs = inputs.to(self.device)
        mimic_probs_l, base_probs_l = self._forward(inputs)
        columns: list[th.Tensor] = []
        for mimic_probs, base_probs in zip(mimic_probs_l, base_probs_l):
            mimic_logits = torch.distributions.utils.probs_to_logits(mimic_probs)
            xent = th.nn.functional.cross_entropy(
                mimic_logits, base_probs, reduction="none"
            )
            columns.append(-xent)
        scores: th.Tensor = th.stack(columns, dim=1).to(device=inputs.device)
        cinds = self.decompose_inputs(inputs)[1]
        penalty: th.Tensor = self.alpha * th.sum(cinds, dim=1)
        return scores - penalty[:, None]

    def get_posterior_mean(self, forward_outs: th.Tensor) -> th.Tensor:
        """Average the output columns."""
        return th.mean(forward_outs, dim=1)

    def sample_posterior(
        self, forward_outs: th.Tensor, sample_shape: th.Size = th.Size()
    ) -> th.Tensor:
        """For every row draw ``prod(sample_shape)`` columns uniformly at random.

        Returns:
            th.Tensor: shape ``(len(forward_outs), *sample_shape)``.
        """
        n_rows: int = len(forward_outs)
        n_draws: int = math.prod(sample_shape)
        n_cols: int = forward_outs.shape[1]
        picks: th.Tensor = th.randint(
            0, n_cols, (n_rows, n_draws), dtype=th.long, device=self.device
        )
        drawn: th.Tensor = th.gather(forward_outs, dim=1, index=picks)
        return drawn.reshape((n_rows, *sample_shape))


class OneToManyStructureScoreEst(StructureScoreEstBase):
    """Structure estimator that compares every mimic model with one base model.

    Exactly one base model is created; each mimic model contributes one output
    column.
    """

    def __init__(
        self,
        n_ctx_covs: int,
        n_experts_per_fcomb: int,
        fraction_training_data_per_split: float,
        n_mimic_models: int,
        mimic_xgbc_kwargs: dict[str, Any] = dict(),
        base_xgbc_kwargs: dict[str, Any] = dict(),
        alpha: float = 0.0,
        xs_etrain: Optional[th.Tensor] = None,
        ys_etrain: Optional[th.Tensor] = None,
        rseed: Optional[int] = None,
    ) -> None:
        """Forward everything to the base class with a single base model."""
        super().__init__(
            n_ctx_covs=n_ctx_covs,
            n_experts_per_fcomb=n_experts_per_fcomb,
            fraction_training_data_per_split=fraction_training_data_per_split,
            n_base_models=1,
            n_mimic_models=n_mimic_models,
            base_xgbc_kwargs=base_xgbc_kwargs,
            mimic_xgbc_kwargs=mimic_xgbc_kwargs,
            alpha=alpha,
            xs_etrain=xs_etrain,
            ys_etrain=ys_etrain,
            rseed=rseed,
        )

    def forward(self, inputs: th.Tensor) -> th.Tensor:
        """Score every input with every mimic model against the base model.

        A column is the negated cross entropy between a mimic model's logits
        and the first base model's class probabilities; ``alpha`` times the row
        sum of the context indicators is then subtracted from every column.

        Returns:
            th.Tensor: shape (len(inputs), number of mimic models).
        """
        inputs = inputs.to(self.device)
        mimic_probs_l, base_probs_l = self._forward(inputs)
        columns: list[th.Tensor] = []
        for mimic_probs in mimic_probs_l:
            mimic_logits = torch.distributions.utils.probs_to_logits(mimic_probs)
            xent = th.nn.functional.cross_entropy(
                mimic_logits, base_probs_l[0], reduction="none"
            )
            columns.append(-xent)
        scores: th.Tensor = th.stack(columns, dim=1).to(device=inputs.device)
        cinds = self.decompose_inputs(inputs)[1]
        penalty: th.Tensor = self.alpha * th.sum(cinds, dim=1)
        return scores - penalty[:, None]

    def get_posterior_mean(self, forward_outs: th.Tensor) -> th.Tensor:
        """Average the output columns."""
        return th.mean(forward_outs, dim=1)

    def sample_posterior(
        self, forward_outs: th.Tensor, sample_shape: th.Size = th.Size()
    ) -> th.Tensor:
        """For every row draw ``prod(sample_shape)`` columns uniformly at random.

        Returns:
            th.Tensor: shape ``(len(forward_outs), *sample_shape)``.
        """
        n_rows: int = len(forward_outs)
        n_draws: int = math.prod(sample_shape)
        n_cols: int = forward_outs.shape[1]
        picks: th.Tensor = th.randint(
            0, n_cols, (n_rows, n_draws), dtype=th.long, device=self.device
        )
        drawn: th.Tensor = th.gather(forward_outs, dim=1, index=picks)
        return drawn.reshape((n_rows, *sample_shape))


class ModisteScoreEstimatorBase(th.nn.Module, ABC):
    """Abstract base for estimators that group training samples by action.

    Besides CPU copies of the training data, the class maintains
    ``_xact_to_aidxs``: a map from an action (the indicator columns of an
    input row, as a tuple of ints) to the indices of the training rows that
    carry that action.
    """

    # Class-level flag; it is not read anywhere in this class.
    # NOTE(review): presumably consulted by the training loop — confirm.
    _enable_lazy_fit: bool = True

    n_ctx_covs: int
    n_experts_per_fcomb: int

    _train_inputs: th.Tensor | None
    _train_targets: th.Tensor | None
    _ys: th.Tensor | None
    _pyhats: th.Tensor | None
    _xact_to_aidxs: defaultdict[tuple[int, ...], list[int]]

    # Registered buffer whose only use is to report the module's device.
    _dummy: th.Tensor

    @property
    def device(self):
        """Device the module currently lives on (read from the ``_dummy`` buffer)."""
        return self._dummy.device

    @property
    def train_inputs(self) -> th.Tensor:
        """Stored training inputs; asserts that training data has been set."""
        assert self._train_inputs is not None
        return self._train_inputs

    @property
    def train_targets(self) -> th.Tensor:
        """Stored (flattened) training targets; asserts that they have been set."""
        assert self._train_targets is not None
        return self._train_targets

    @property
    def ys(self) -> th.Tensor:
        """Stored ``infos.ys``; asserts that it has been set."""
        assert self._ys is not None
        return self._ys

    @property
    def pyhats(self) -> th.Tensor:
        """Stored ``infos.pyhats``; asserts that it has been set."""
        assert self._pyhats is not None
        return self._pyhats

    def __init__(
        self,
        n_ctx_covs: int,
        n_experts_per_fcomb: int,
    ) -> None:
        """Create an estimator with no training data.

        Args:
            n_ctx_covs (int): number of context covariates; used by
                ``decompose_inputs`` to slice the input columns.
            n_experts_per_fcomb (int): when equal to 1 the inputs are split in
                two halves, otherwise the trailing columns are returned as
                expert indicators (see ``decompose_inputs``).
        """
        super().__init__()
        self.n_ctx_covs = n_ctx_covs
        self.n_experts_per_fcomb = n_experts_per_fcomb
        self._train_inputs = None
        self._train_targets = None
        self._ys = None
        self._pyhats = None
        self.register_buffer("_dummy", th.empty(()))
        self._xact_to_aidxs = defaultdict(list)

    def _record_xacts(self, inputs: th.Tensor, start_idx: int) -> None:
        """Register the rows of ``inputs`` in ``_xact_to_aidxs``.

        Row ``i`` of ``inputs`` is recorded under training index
        ``start_idx + i``. The action of a row is its context-indicator columns,
        followed by the expert-indicator columns when
        ``n_experts_per_fcomb != 1``, cast to integers.
        """
        _, cinds, exinds = self.decompose_inputs(inputs)
        xacts: th.Tensor = (
            cinds if self.n_experts_per_fcomb == 1 else th.cat((cinds, exinds), dim=1)
        ).to(dtype=th.long)
        uxacts_t, invidxs_t = th.unique(xacts, dim=0, return_inverse=True)
        for i, xact_t in enumerate(uxacts_t):
            xact: tuple[int, ...] = tuple(xact_t.tolist())
            row_idxs: th.Tensor = th.argwhere(invidxs_t == i).flatten() + start_idx
            self._xact_to_aidxs[xact].extend(row_idxs.tolist())

    def set_train_data_(
        self,
        inputs: th.Tensor,
        targets: th.Tensor,
        infos: Optional[datasets.base.EnvRewardInfo] = None,
    ):
        """Replace the stored training data and rebuild the action index.

        Args:
            inputs (th.Tensor): training inputs, one row per sample.
            targets (th.Tensor): training targets; stored flattened.
            infos: optional reward info; when given, its ``ys`` and ``pyhats``
                replace the stored ones, otherwise the stored ones are kept.
        """
        assert len(inputs) == len(targets.flatten())
        self._train_inputs = inputs.clone().to(device="cpu")
        self._train_targets = targets.clone().flatten().to(device="cpu")
        if infos is not None:
            self._ys = infos.ys.clone().to(device="cpu")
            self._pyhats = infos.pyhats.clone().to(device="cpu")
        # The data is replaced wholesale, so indices recorded for any previous
        # training set are stale and must be dropped first.
        self._xact_to_aidxs.clear()
        # record indices of ctx-act pairs that have the same action
        self._record_xacts(inputs, start_idx=0)
        return

    def add_to_train_data_(
        self,
        inputs: th.Tensor,
        targets: th.Tensor,
        infos: Optional[datasets.base.EnvRewardInfo] = None,
    ):
        """Append new samples to the training data and to the action index.

        Requires that training data has already been set. ``targets`` may have
        any shape; it is flattened before being appended. ``infos`` is only
        appended when given.
        """
        # New rows are indexed after the rows that are already stored. The
        # actions are extracted with ``decompose_inputs`` (exactly as in
        # ``set_train_data_``) so that the multi-expert column layout is
        # handled as well.
        self._record_xacts(inputs, start_idx=len(self.train_inputs))
        # update train_inputs and train_targets
        self._train_inputs = th.cat((self.train_inputs, inputs.to(device="cpu")), dim=0)
        self._train_targets = th.cat(
            (self.train_targets, targets.flatten().to(device="cpu")), dim=0
        )
        if infos is not None:
            self._ys = th.cat((self.ys, infos.ys.to(device="cpu")), dim=0)
            self._pyhats = th.cat((self.pyhats, infos.pyhats.to(device="cpu")), dim=0)
        return

    def decompose_inputs(
        self, inputs: th.Tensor
    ) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Split ``inputs`` into contexts, context indicators and expert indicators.

        With ``n_experts_per_fcomb == 1`` the columns are split into two halves
        (contexts, indicators) and the third value is an *uninitialized*
        float32 placeholder of shape (len(inputs),). Otherwise the first
        ``n_ctx_covs`` columns are the contexts, the next ``n_ctx_covs`` the
        context indicators and all remaining columns the expert indicators.
        """
        if self.n_experts_per_fcomb == 1:
            ctxs, cinds = th.chunk(inputs, chunks=2, dim=1)
            exinds: th.Tensor = th.empty(
                (len(ctxs),), dtype=th.float32, device=inputs.device
            )
            return ctxs, cinds, exinds
        ctxs: th.Tensor = inputs[:, : self.n_ctx_covs]
        cinds: th.Tensor = inputs[:, self.n_ctx_covs : 2 * self.n_ctx_covs]
        exinds: th.Tensor = inputs[:, 2 * self.n_ctx_covs :]
        return ctxs, cinds, exinds

    @abstractmethod
    def forward(self, inputs: th.Tensor) -> th.Tensor: ...

    @abstractmethod
    def fit_(self, plf: pl.Fabric) -> dict[str, float]: ...


class ModisteKNNScoreEst(ModisteScoreEstimatorBase):
    """Per-action k-nearest-neighbour estimator.

    For every query row, the training rows that share the query's action are
    ranked by the distance between contexts and the targets of the
    ``n_neighbors`` closest ones are averaged.
    """

    n_neighbors: int
    init_strat: Literal["min", "mean"]

    def __init__(
        self,
        n_ctx_covs: int,
        n_experts_per_fcomb: int,
        n_neighbors: int,
        init_strat: Literal["min", "mean"] = "min",
    ) -> None:
        """Store the neighbourhood size and the fallback strategy.

        Args:
            n_ctx_covs (int): forwarded to the base class.
            n_experts_per_fcomb (int): forwarded to the base class.
            n_neighbors (int): maximum number of neighbours to average.
            init_strat: value given to queries whose action has no training
                rows: the minimum or the mean of all training targets.
        """
        super().__init__(n_ctx_covs=n_ctx_covs, n_experts_per_fcomb=n_experts_per_fcomb)
        self.n_neighbors = n_neighbors
        self.init_strat = init_strat

    def forward(self, inputs: th.Tensor) -> th.Tensor:
        """Estimate one score per input row.

        Args:
            inputs (th.Tensor): query rows in the same column layout as the
                training inputs; moved to the module's device if necessary.

        Returns:
            th.Tensor: float32 tensor of shape (len(inputs),) on the module's
            device.

        Raises:
            ValueError: if ``init_strat`` is neither ``"min"`` nor ``"mean"``.
        """
        # The stored data is moved to the module's device below, so the
        # queries have to live there as well.
        inputs = inputs.to(device=self.device)
        train_inputs: th.Tensor = self.train_inputs.to(device=self.device)
        train_targets: th.Tensor = self.train_targets.to(device=self.device)
        # default value for queries whose action never occurs in the training set
        outputs: th.Tensor = th.empty(
            (len(inputs),), dtype=th.float32, device=self.device
        )
        if self.init_strat == "mean":
            outputs.fill_(th.mean(train_targets))
        elif self.init_strat == "min":
            outputs.fill_(th.min(train_targets))
        else:
            raise ValueError("self.init_strat must be either one of min or mean.")
        tctxs, _, _ = self.decompose_inputs(train_inputs)
        ctxs, cinds, exinds = self.decompose_inputs(inputs)
        xacts: th.Tensor = (
            cinds if self.n_experts_per_fcomb == 1 else th.cat((cinds, exinds), dim=1)
        ).to(dtype=th.long)
        uxacts_t, invidxs_t = th.unique(xacts, dim=0, return_inverse=True)
        xacts_l: list[tuple[int, ...]] = [tuple(xact.tolist()) for xact in uxacts_t]
        for i, xact in enumerate(xacts_l):
            if self.device.type == "cuda":
                th.cuda.empty_cache()
            # indices of the training rows that carry the current action;
            # ``get`` avoids inserting empty entries into the defaultdict
            aidxs_l: list[int] = self._xact_to_aidxs.get(xact, [])
            if len(aidxs_l) == 0:
                continue
            # indices of inputs that have the same action
            idxs: th.Tensor = th.argwhere(invidxs_t == i).flatten()
            n_neighbors: int = min(self.n_neighbors, len(aidxs_l))
            aidxs: th.Tensor = th.as_tensor(aidxs_l, dtype=th.long, device=self.device)
            # pairwise distances between the query contexts and the training
            # contexts of this action: (len(idxs), len(aidxs))
            ds: th.Tensor = th.cdist(ctxs[idxs], tctxs[aidxs])
            _, didxs_sorted = th.sort(ds, dim=1)
            # estimated rewards: mean target of the nearest neighbours
            # (len(idxs), )
            outputs[idxs] = th.mean(
                train_targets[aidxs][didxs_sorted[:, :n_neighbors]], dim=1
            )
        return outputs

    def fit_(self, plf: pl.Fabric) -> dict[str, float]:
        """Nothing to fit; returns an empty metrics dict."""
        return dict()


class ModisteUnifiedKNNScoreEst(ModisteScoreEstimatorBase):
    """k-nearest-neighbour estimator with a combined context/action distance.

    The distance between a query and a training row is
    ``d_ctx / alpha + d_action / beta``; the targets of the ``n_neighbors``
    closest training rows are averaged.
    """

    n_neighbors: int
    alpha: th.Tensor
    beta: th.Tensor

    def __init__(
        self,
        n_ctx_covs: int,
        n_experts_per_fcomb: int,
        n_neighbors: int,
        alpha: float,
        beta: float,
    ) -> None:
        """Store the neighbourhood size and the two distance scales.

        Args:
            n_ctx_covs (int): forwarded to the base class.
            n_experts_per_fcomb (int): forwarded to the base class.
            n_neighbors (int): maximum number of neighbours to average.
            alpha (float): positive scale dividing the context distance.
            beta (float): positive scale dividing the action distance.
        """
        super().__init__(n_ctx_covs=n_ctx_covs, n_experts_per_fcomb=n_experts_per_fcomb)
        self.n_neighbors = n_neighbors
        assert alpha > 0 and beta > 0
        self.register_buffer("alpha", th.tensor(alpha, dtype=th.float32))
        self.register_buffer("beta", th.tensor(beta, dtype=th.float32))

    def forward(self, inputs: th.Tensor) -> th.Tensor:
        """Estimate one score per input row.

        Returns:
            th.Tensor: shape (len(inputs),), the mean target of each row's
            nearest training rows under the combined distance.
        """
        train_targets: th.Tensor = self.train_targets.to(device=self.device)
        # compute distance among contexts and actions
        ds: th.Tensor = self._compute_distances(inputs)
        _, didxs_sorted = th.sort(ds, dim=1)
        # compute outputs
        n_neighbors: int = min(self.n_neighbors, len(train_targets))
        outputs: th.Tensor = th.mean(
            train_targets[didxs_sorted[:, :n_neighbors]], dim=1
        )
        return outputs

    def _compute_distances(self, inputs: th.Tensor) -> th.Tensor:
        """Combined distance between every input row and every training row.

        Returns:
            th.Tensor: shape (len(inputs), len(train_inputs)).
        """
        # The stored data is moved to the module's device, so the queries
        # have to live there as well.
        inputs = inputs.to(device=self.device)
        train_inputs: th.Tensor = self.train_inputs.to(device=self.device)
        ctxs, cinds, exinds = self.decompose_inputs(inputs)
        xacts: th.Tensor = (
            cinds if self.n_experts_per_fcomb == 1 else th.cat((cinds, exinds), dim=1)
        )
        tctxs, tcinds, texinds = self.decompose_inputs(train_inputs)
        # The training actions must come from the *training* indicators
        # (tcinds); using the query indicators here would compare the queries
        # with themselves and yield a matrix of the wrong shape.
        txacts: th.Tensor = (
            tcinds
            if self.n_experts_per_fcomb == 1
            else th.cat((tcinds, texinds), dim=1)
        )
        # (len(ctxs), len(train_inputs))
        ds_ctx: th.Tensor = th.cdist(ctxs, tctxs)
        if self.device.type == "cuda":
            th.cuda.empty_cache()
        ds_xact: th.Tensor = th.cdist(xacts, txacts)
        if self.device.type == "cuda":
            th.cuda.empty_cache()
        # combine the two distances
        ds: th.Tensor = 1 / self.alpha * ds_ctx + 1 / self.beta * ds_xact
        return ds

    def fit_(self, plf: pl.Fabric) -> dict[str, float]:
        """Nothing to fit; returns an empty metrics dict."""
        return dict()
