from typing import Optional

import torch
import torch.nn.functional as F

from src.models.bundle.foca import FOCA


class FO_FOCA(FOCA):
    """FOCA variant that refines the inferred context with one extra,
    differentiable gradient step.

    The context produced by ``FOCA.infer_context`` is cut from its autograd
    history. A single gradient step on the MSE prediction loss is then taken
    with ``create_graph=True``, so backpropagation reaches the model only
    through that final step.
    """

    def __init__(self,
                 bundling_k: Optional[int] = None,
                 pushforward_n: Optional[int] = None,
                 pred_config: Optional[dict] = None,
                 **kwargs):
        """
        Args:
            bundling_k: stored as ``self.bundling_k``. Not read in this class;
                presumably consumed by the base class or callers (verify).
            pushforward_n: stored as ``self.pushforward_n``. Not read in this
                class.
            pred_config: stored as ``self.pred_config``. When omitted, each
                instance gets its own fresh empty dict.
            **kwargs: forwarded unchanged to ``FOCA.__init__``.
        """
        super().__init__(**kwargs)

        self.bundling_k = bundling_k
        self.pushforward_n = pushforward_n
        # Build a fresh dict per instance. A ``{}`` literal in the signature
        # would be a single object shared by every instance using the default.
        self.pred_config = {} if pred_config is None else pred_config

    def infer_context(self, x, y):
        """Infer a context with FOCA, then apply one differentiable refinement step.

        Args:
            x: model inputs, passed to ``model_forward`` (and on to the model
                as ``us``).
            y: targets compared against the model prediction with MSE.

        Returns:
            ``(ctx, info)``: the refined context and the ``info`` object
            returned unchanged by ``FOCA.infer_context``.
        """
        ctx, info = super().infer_context(x, y)

        # Disengage from the base class's autograd chain so that only the
        # single step below is differentiated through. Detaching before
        # cloning keeps the copy itself out of the autograd graph.
        ctx = ctx.detach().clone()
        # Force grad mode on, so the step runs even if the caller has
        # disabled gradients (e.g. an enclosing torch.no_grad()).
        with torch.set_grad_enabled(True):
            # Re-engage the chain for exactly one step.
            ctx.requires_grad_()
            pred = self.model_forward(x, ctx)
            loss = F.mse_loss(pred, y)
            # create_graph=True records the gradient computation itself, so
            # the refined ctx stays differentiable w.r.t. the model's
            # parameters (including second-order terms).
            grad = torch.autograd.grad(loss, ctx, create_graph=True)[0]
            # NOTE(review): ``ctx_opt_config`` is presumably defined by FOCA.
            ctx = ctx - self.ctx_opt_config['lr'] * grad
        return ctx, info

    def model_forward(self, x, ctx, *args, **kwargs):
        """Run ``self.model`` with context ``ctx`` and inputs ``x`` (passed as ``us``).

        Extra positional/keyword arguments are accepted for signature
        compatibility but ignored.
        """
        return self.model(ctx=ctx, us=x)
