from src.models.bundle.foca import FOCA
from src.models.components.argmin import ArgminLayer


class ImplicitFOCA(FOCA):
    """FOCA variant whose context is inferred by an ``ArgminLayer``.

    Context inference is delegated to an ``ArgminLayer`` built around
    ``self.model``. Judging by the layer's name and its ``(ctx, loss)``
    return, it presumably finds the context by optimisation rather than
    with an explicit encoder (TODO confirm against ``ArgminLayer``).
    """

    def __init__(self,
                 bundling_k: int = None,
                 pushforward_n: int = None,
                 pred_config: dict = None,
                 **kwargs):
        """Build the model and its argmin-based context inference layer.

        Args:
            bundling_k: Stored as ``self.bundling_k``; not used inside this
                class (presumably read by ``FOCA`` or training code).
            pushforward_n: Stored as ``self.pushforward_n``; not used inside
                this class.
            pred_config: Prediction configuration stored as
                ``self.pred_config``. Defaults to a fresh empty dict per
                instance.
            **kwargs: Forwarded unchanged to ``FOCA.__init__``.
        """
        # Must run first: the ArgminLayer below reads self.model,
        # self.ctx_dim and self.ctx_opt_config, which are presumably set
        # by FOCA.__init__ (not visible here -- confirm).
        super().__init__(**kwargs)

        self.bundling_k = bundling_k
        self.pushforward_n = pushforward_n
        # Use a None sentinel instead of a `{}` default so instances never
        # share (and cannot mutate) one module-level dict object.
        self.pred_config = {} if pred_config is None else pred_config
        # NOTE(review): "armgin" is a typo for "argmin", kept deliberately.
        # Renaming the attribute could break external access and, if FOCA
        # is an nn.Module, change state_dict keys of saved checkpoints.
        self.armgin = ArgminLayer(self.model, self.ctx_dim, **self.ctx_opt_config)

    def infer_context(self, x, y):
        """Infer a context for ``(x, y)`` via the argmin layer.

        Returns:
            A ``(ctx, loss)`` pair as produced by ``ArgminLayer``. Unpacking
            enforces that the layer returns exactly two values.
        """
        ctx, loss = self.armgin(x, y)
        return ctx, loss

    def model_forward(self, x, ctx, *args, **kwargs):
        """Run the wrapped model with context ``ctx`` on input ``x``.

        ``x`` is passed to the model as the ``us`` keyword. Extra positional
        and keyword arguments are accepted for interface compatibility but
        are ignored.
        """
        return self.model(ctx=ctx, us=x)
