import torch

class FixedSignOptimizerWrapper(torch.optim.Optimizer):
    """Wrap an optimizer so each step may change parameter magnitudes but never signs.

    After the wrapped optimizer performs its update, every tracked parameter
    element is reset to ``sign(value_before) * |value_after|``. Consequences:

    * an element the update would have pushed across zero is reflected back to
      its original side (the new magnitude is kept, the sign is not);
    * an element that is exactly zero before the step has sign 0 and therefore
      stays zero.

    ``torch.optim.Optimizer.__init__`` is intentionally not called: the wrapper
    owns no parameter groups or state of its own. Reads of attributes such as
    ``param_groups``, ``state`` or ``defaults`` fall through ``__getattr__`` to
    the wrapped optimizer. Inherited ``Optimizer`` methods that would otherwise
    operate on the wrapper are overridden to forward explicitly. Subclassing
    ``Optimizer`` lets the wrapper pass ``isinstance(..., torch.optim.Optimizer)``
    checks.
    """

    def __init__(self, optimizer):
        """
        Args:
            optimizer (torch.optim.Optimizer): The base optimizer to wrap. Its
                parameter groups and state are used directly, not copied.
        """
        # Deliberately skip Optimizer.__init__: it would create a second,
        # independent set of param_groups/state on the wrapper.
        self.optimizer = optimizer

    def __getattr__(self, name):
        """Forward attribute reads that fail on the wrapper to the wrapped optimizer.

        Python only calls this when normal lookup fails. Dunder names are not
        forwarded, so protocol probes (e.g. ``copy``/``pickle`` looking for
        ``__deepcopy__``) see the wrapper itself rather than the inner optimizer.
        """
        # Read through __dict__ so that a missing "optimizer" (e.g. an instance
        # created via __new__ during unpickling) raises instead of recursing.
        optimizer = self.__dict__.get("optimizer")
        is_dunder = name.startswith("__") and name.endswith("__")
        if is_dunder or optimizer is None:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(optimizer, name)

    def __getstate__(self):
        """Return the wrapper's own attributes (i.e. the wrapped optimizer) for pickling/copying.

        The inherited Optimizer hook captures optimizer-level fields
        (defaults/state/param_groups) rather than the ``optimizer`` attribute
        this wrapper needs, which would yield an unusable copy.
        """
        return self.__dict__.copy()

    def __setstate__(self, state):
        """Restore the attributes produced by ``__getstate__``."""
        self.__dict__.update(state)

    def step(self, closure=None):
        """Run one step of the wrapped optimizer, then restore each parameter's sign.

        Args:
            closure (callable, optional): Re-evaluates the model and returns the
                loss; forwarded unchanged to the wrapped optimizer's ``step``.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns (for built-in
            torch optimizers, the closure's loss, or ``None`` without a closure).
        """
        # Snapshot signs BEFORE the base step. Without a closure, only
        # parameters that already have a gradient are candidates for an update
        # (built-in torch optimizers skip params whose .grad is None). With a
        # closure, gradients are produced inside step(), so track every param.
        with torch.no_grad():
            signs = [
                (param, torch.sign(param))
                for group in self.optimizer.param_groups
                for param in group["params"]
                if closure is not None or param.grad is not None
            ]

        # Exactly one base step per call. Stepping once per parameter would
        # apply the update, advance optimizer state, and run the closure many times.
        loss = self.optimizer.step(closure)

        # Restore signs in place, so the parameter's existing storage (and any
        # views of it) is updated rather than replaced.
        with torch.no_grad():
            for param, sign in signs:
                param.abs_().mul_(sign)
        return loss

    def zero_grad(self, *args, **kwargs):
        """Clear gradients in the wrapped optimizer.

        Arguments (e.g. ``set_to_none``) are passed through unchanged, so the
        wrapped optimizer's own defaults apply.
        """
        self.optimizer.zero_grad(*args, **kwargs)

    def state_dict(self):
        """Return the wrapped optimizer's ``state_dict()``."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped optimizer.

        Forwarded explicitly: the inherited implementation would store the
        loaded state on the wrapper, leaving the wrapped optimizer unchanged.
        """
        self.optimizer.load_state_dict(state_dict)

    def add_param_group(self, param_group):
        """Add ``param_group`` to the wrapped optimizer."""
        self.optimizer.add_param_group(param_group)