import torch
from torch.optim import Optimizer

class HomM(Optimizer):
    r"""Homogeneous Momentum optimizer.

    For every parameter element, with gradient ``g`` and a velocity buffer
    ``v`` (initialised to zero on first use), one step computes::

        s = lr * max(sqrt(g**2 + v**2), 1e-8) ** alpha
        v <- (v - s * gamma * g) / (1 + s * (1 - gamma))
        p <- p - s * (1 - beta) * g + s * beta * v

    The velocity update is semi-implicit: the ``(1 - gamma)`` damping term
    appears in the denominator instead of as an explicit decay factor, and
    the parameter update uses the freshly updated ``v``.
    """

    def __init__(self, params, lr=0.005, alpha=-0.5, beta=0.7, gamma=0.9):
        """
        Homogeneous Momentum Optimizer

        Args:
            params (iterable): model parameters or parameter-group dicts
            lr (float): learning rate; must be non-negative
            alpha (float): exponent applied to the element-wise norm
                (intended to be negative)
            beta (float): weight of the velocity term in the parameter
                update, intended in (0, 1); the gradient term gets 1 - beta
            gamma (float): velocity coupling factor, intended in (0, 1)

        Raises:
            ValueError: if ``lr`` is negative.
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = dict(lr=lr, alpha=alpha, beta=beta, gamma=gamma)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns
                the loss. It is run with gradient tracking enabled so that it
                may call ``loss.backward()``.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure has to build a
            # graph and backprop, so autograd is re-enabled just for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            alpha = group['alpha']
            beta = group['beta']
            gamma = group['gamma']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]
                # Lazily create the velocity buffer on first use.
                if len(state) == 0:
                    state['v'] = torch.zeros_like(p)
                v = state['v']

                # Element-wise Euclidean norm of the pair (grad_i, v_i),
                # clamped so the (negative) power below stays finite.
                norm_z = torch.sqrt(grad ** 2 + v ** 2).clamp(min=1e-8)

                # Per-element step size lr * norm_z ** alpha, with dedicated
                # kernels for the two common exponents.
                if alpha == -0.5:
                    scaling = lr * torch.rsqrt(norm_z)
                elif alpha == -0.25:
                    scaling = lr * torch.rsqrt(torch.sqrt(norm_z))
                else:
                    scaling = lr * norm_z.pow(alpha)

                # Semi-implicit velocity coefficients. scaling is already a
                # tensor on p's device/dtype, so c1 and c2 are as well.
                c1 = (1 + scaling * (1 - gamma)).reciprocal()
                c2 = scaling * gamma * c1

                # v = c1 * v - c2 * grad
                v.mul_(c1).addcmul_(grad, c2, value=-1.0)

                # p = p - scaling * (1 - beta) * grad + scaling * beta * v
                p.addcmul_(grad, scaling, value=-(1.0 - beta))
                p.addcmul_(v, scaling, value=beta)

        return loss