"""Softly gradient-coupling optimizers.

This module provides a wrapper around any PyTorch optimizer that applies
gradient coupling at the beginning of each optimization step.
"""

import torch
import torch.optim as optim
from torch.optim.optimizer import required

class CoupledSGD(optim.SGD):
    """SGD optimizer that couples the model's gradients before each update.

    Behaves like ``torch.optim.SGD`` except that every call to :meth:`step`
    first invokes ``model.couple_gradients(coupling, method)`` and only then
    performs the regular SGD parameter update.

    Args:
        params: Iterable of parameters (or parameter groups) to optimize;
            forwarded unchanged to ``torch.optim.SGD``.
        model: Object exposing ``couple_gradients(coupling, method)``. It is
            called once per optimization step, before the update.
        lr: Learning rate, forwarded to ``torch.optim.SGD``.
        momentum: Momentum factor, forwarded to ``torch.optim.SGD``.
        dampening: Momentum dampening, forwarded to ``torch.optim.SGD``.
        weight_decay: Weight decay, forwarded to ``torch.optim.SGD``.
        nesterov: Whether to use Nesterov momentum, forwarded to
            ``torch.optim.SGD``.
        coupling: First argument passed to ``model.couple_gradients``.
            NOTE(review): presumably the coupling strength -- confirm against
            the model's implementation.
        method: Second argument passed to ``model.couple_gradients``.
            NOTE(review): presumably selects the coupling scheme -- confirm
            against the model's implementation.
    """

    def __init__(self, params, model, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, coupling=0, method='uniform'):
        """Configure the underlying SGD optimizer and store coupling settings."""
        super().__init__(params=params, lr=lr, momentum=momentum,
                         dampening=dampening, weight_decay=weight_decay,
                         nesterov=nesterov)
        self.coupling = coupling
        self.method = method
        self.model = model

    def step(self, closure=None):
        """Couple the model's gradients, then perform one SGD update.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss, as accepted by ``torch.optim.Optimizer.step``.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.
        """
        loss = None
        if closure is not None:
            # The closure (re)computes the gradients, so it has to run before
            # the coupling; otherwise the coupling would act on stale or
            # missing gradients and then be overwritten by the closure.
            with torch.enable_grad():
                loss = closure()
        self.model.couple_gradients(self.coupling, self.method)
        # The closure has already been evaluated above; do not pass it on,
        # or the freshly coupled gradients would be recomputed and lost.
        super().step()
        return loss
