# maml_rl/anil_optimizer.py
import torch
import torch.nn as nn

class DifferentiableSGD:
    """Plain SGD optimizer used for the ANIL inner loop.

    The optimizer manages ``model.anil_update_params`` when the model has that
    attribute and it is non-empty; otherwise it manages every parameter
    returned by ``model.parameters()``.

    Note: despite the class name, ``step`` updates the parameters in place
    under ``torch.no_grad()``, so the update itself is *not* recorded in the
    autograd graph (gradients do not flow through the inner-loop step).

    Args:
        model: Module whose parameters are updated.
        lr: Step size applied to each gradient.
    """

    def __init__(self, model: nn.Module, lr: float = 1e-3):
        self.model = model
        self.lr = lr

    def _iter_params(self):
        """Yield the parameters managed by this optimizer.

        The attribute is looked up on every call, so changes made to
        ``model.anil_update_params`` after construction are picked up.
        """
        head_params = getattr(self.model, "anil_update_params", None)
        if head_params:
            yield from head_params
        else:
            yield from self.model.parameters()

    def step(self):
        """Apply one SGD update, ``p <- p - lr * p.grad``, in place.

        Parameters whose ``grad`` is ``None`` are skipped.
        """
        # Update under no_grad instead of writing through ``p.data``: an
        # in-place write via ``.data`` bypasses autograd's version counter, so
        # a later backward through a graph that saved ``p`` would silently use
        # the modified values. With no_grad the values written are the same,
        # but autograd can detect (and report) such stale-graph usage.
        with torch.no_grad():
            for p in self._iter_params():
                if p.grad is not None:
                    p.add_(p.grad, alpha=-self.lr)

    def zero_grad(self, set_to_none: bool = False):
        """Reset the gradients of the managed parameters.

        Args:
            set_to_none: If True, drop each gradient tensor (``p.grad = None``);
                otherwise detach it from any graph and fill it with zeros
                in place.
        """
        for p in self._iter_params():
            if p.grad is not None:
                if set_to_none:
                    p.grad = None
                else:
                    # Detach first so zeroing does not touch a graph that the
                    # gradient may be part of (e.g. after create_graph=True).
                    p.grad.detach_()
                    p.grad.zero_()
