import torch
from torch.optim import Optimizer

class ZenGrad_M(Optimizer):
    """Adaptive-gradient optimizer with log-scaled accumulator and momentum.

    For every parameter the optimizer keeps a running sum of squared
    gradients (the *accumulator*) and scales the learning rate elementwise
    by ``1 / (log(1 + accumulator) + epsilon)``.  Weight decay is applied
    in decoupled form (the parameter is shrunk by ``1 - lr * weight_decay``
    before the gradient step).  Optionally a momentum buffer
    ``v <- momentum * v + grad`` is maintained and used as the update
    direction, either directly (classical) or as ``grad + momentum * v``
    (Nesterov).

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Base learning rate; must be > 0.
        initial_accumulator_value: Starting value of the squared-gradient
            accumulator; must be >= 0.
        weight_decay: Decoupled weight-decay coefficient; must be >= 0.
        epsilon: Term added to the denominator; must be > 0.
        momentum: Momentum factor in ``[0, 1)``; ``0`` disables momentum.
        nesterov: Use the Nesterov form of the update; requires
            ``momentum > 0``.

    Raises:
        ValueError: If any hyperparameter is outside its valid range.
    """

    def __init__(self, params, lr=0.01, initial_accumulator_value=0,
                 weight_decay=1e-4, epsilon=1e-8, momentum=0.0, nesterov=False):
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if initial_accumulator_value < 0.0:
            raise ValueError(f"Invalid accumulator value: {initial_accumulator_value}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight decay: {weight_decay}")
        if epsilon <= 0.0:
            raise ValueError(f"Invalid epsilon value: {epsilon}")
        if momentum < 0.0 or momentum >= 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if nesterov and momentum <= 0.0:
            raise ValueError("Nesterov momentum requires a positive momentum")

        defaults = dict(
            lr=lr,
            initial_accumulator_value=initial_accumulator_value,
            weight_decay=weight_decay,
            epsilon=epsilon,
            momentum=momentum,
            nesterov=nesterov,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss.  It is executed with gradient tracking
                enabled, even if ``step`` is called inside a
                ``torch.no_grad()`` context.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure
            was given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # The closure normally calls backward(), so it needs autograd on
            # regardless of the grad mode this method itself runs under.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            acc_init = group['initial_accumulator_value']
            weight_decay = group['weight_decay']
            epsilon = group['epsilon']
            momentum = group['momentum']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("ZenGrad does not support sparse gradients")

                state = self.state[p]

                # Lazy state initialization (also covers momentum being
                # switched on after the accumulator already exists).
                if 'accumulator' not in state:
                    state['accumulator'] = torch.full_like(p, acc_init)
                if momentum > 0 and 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(p)

                accumulator = state['accumulator']

                # Accumulate squared gradients in place (no temporary tensor).
                accumulator.addcmul_(grad, grad)

                # Decoupled weight decay
                if weight_decay != 0:
                    p.mul_(1 - lr * weight_decay)

                # Elementwise effective learning rate.  log1p keeps precision
                # when the accumulator is tiny, where log(1 + x) would round
                # 1 + x to exactly 1 in low-precision dtypes.
                effective_lr = lr / (torch.log1p(accumulator) + epsilon)

                if momentum > 0:
                    buf = state['momentum_buffer']
                    # v_t = mu * v_{t-1} + g_t
                    buf.mul_(momentum).add_(grad)
                    if nesterov:
                        # Nesterov update: use g_t + mu * v_t
                        direction = grad.add(buf, alpha=momentum)
                    else:
                        # Classical momentum update
                        direction = buf
                else:
                    # No momentum, vanilla ZenGrad update
                    direction = grad

                # p <- p - effective_lr * direction
                p.addcmul_(direction, effective_lr, value=-1)

        return loss
