import torch
from typing import Iterable

class SAM_SPS(torch.optim.Optimizer):
    r"""
    Stochastic SAM with an SPS (stochastic Polyak) step size and L2 weight decay.

    Each call to :meth:`step` evaluates the closure twice (the closure is
    expected to use the same mini-batch ``i`` both times) and performs

        e^t      = x^t + \rho (1 - \lambda + \lambda / \|\nabla f_i(x^t)\|) \nabla f_i(x^t)
        \gamma_t = \max\{0, \min\{ (f_i(e^t) - f^* - \langle\nabla f_i(e^t), e^t - x^t\rangle)
                                   / \|\nabla f_i(e^t)\|^2, \gamma_b \}\}
        x^{t+1}  = x^t - \gamma_t (\nabla f_i(e^t) + wd \cdot x^t)

    ``lambd = 1`` gives the normalized (SAM) perturbation, ``lambd = 0`` the
    unnormalized one. The per-group weight decay ``wd`` only enters the final
    update; it is not used for the perturbation or for ``\gamma_t``.

    Args:
        params: parameters (or param-group dicts) to optimize.
        weight_decay: L2 coefficient added to the gradient in the update (>= 0).
        rho: perturbation radius (>= 0).
        lambd: interpolation between unnormalized (0) and normalized (1) ascent.
        f_star: lower bound f^* on the loss used in the Polyak step.
        gamma_b: upper bound on the step size \gamma_t (>= 0).
    """

    # Floor applied to gradient norms before dividing by them.
    _EPS = 1e-12

    def __init__(self,
                 params: Iterable[torch.nn.Parameter],
                 weight_decay: float = 5e-4,
                 rho: float = 0.1,
                 lambd: float = 1.0,
                 f_star: float = 0.0,
                 gamma_b: float = 1.0):
        if weight_decay < 0:
            raise ValueError("weight_decay must be >= 0")
        if rho < 0:
            raise ValueError("rho must be >= 0")
        if gamma_b < 0:
            raise ValueError("gamma_b must be >= 0")

        # group['lr'] starts at gamma_b; step() overwrites it with \gamma_t on
        # every call, and _sgd_update() uses it as the step size (loggers can
        # read it from there as well).
        defaults = dict(lr=gamma_b, weight_decay=weight_decay)
        super().__init__(params, defaults)

        self.rho = rho
        self.lambd = lambd
        self.f_star = f_star
        self.gamma_b = gamma_b

        # Bookkeeping for the logger: eps-clamped \|\nabla f_i(x^t)\| of the
        # last step. `extra` is not read anywhere in this class.
        self.grad_norm = torch.tensor(0.0)
        self.extra: dict = {}

    @torch.no_grad()
    def _sgd_update(self, group):
        """Apply p <- p - group['lr'] * (p.grad + group['weight_decay'] * p) in place."""
        lr = group['lr']
        wd = group['weight_decay']

        for p in group['params']:
            if p.grad is None:
                continue
            d_p = p.grad
            if wd != 0.0:
                # Out-of-place add so p.grad itself is left untouched.
                d_p = d_p.add(p, alpha=wd)
            p.add_(d_p, alpha=-lr)

    @torch.no_grad()
    def _perturb(self):
        r"""Move every parameter that has a gradient from x^t to e^t.

        An exact copy of x^t is saved in ``self.state[p]["x^t"]``, and the
        eps-clamped norm of \nabla f_i(x^t) is stored in ``self.grad_norm``.

        Raises:
            RuntimeError: if no parameter has a gradient.
        """
        with_grad = [p for group in self.param_groups
                     for p in group['params'] if p.grad is not None]
        if not with_grad:
            raise RuntimeError("closure did not produce any gradients at x^t")

        norm_sq = sum(p.grad.detach().to(torch.float32).pow(2).sum() for p in with_grad)
        self.grad_norm = norm_sq.sqrt().clamp(min=self._EPS)
        scale = self.rho * ((1 - self.lambd) + self.lambd / self.grad_norm)

        for p in with_grad:
            # Keep x^t itself rather than the offset: undoing the move with
            # p.sub_(offset) is not exact in floating point (notably in
            # fp16/bf16), whereas copy_ restores x^t bit-for-bit.
            self.state[p]['x^t'] = p.detach().clone()
            p.add_((p.grad * scale).to(dtype=p.dtype))

    @torch.no_grad()
    def _sps_step_size(self, loss_e):
        r"""Return the clipped Polyak step \gamma_t; parameters must be at e^t.

        The denominator sums \|\nabla f_i(e^t)\|^2 over every parameter that
        will be updated. The inner product only gets contributions from
        perturbed parameters, because e^t - x^t is zero elsewhere.

        Raises:
            RuntimeError: if no parameter has a gradient at e^t.
        """
        dot = None
        norm_sq = None
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad_e = p.grad.detach().to(torch.float32)
                sq = grad_e.pow(2).sum()
                norm_sq = sq if norm_sq is None else norm_sq + sq

                x_t = self.state.get(p, {}).get('x^t')
                if x_t is not None:
                    # Use the displacement actually applied (after rounding
                    # into p.dtype), not the requested one.
                    shift = p.detach().to(torch.float32) - x_t.to(torch.float32)
                    term = torch.sum(grad_e * shift)
                    dot = term if dot is None else dot + term

        if norm_sq is None:
            raise RuntimeError("closure did not produce any gradients at e^t")

        dot_value = 0.0 if dot is None else dot.item()
        num = float(loss_e) - self.f_star - dot_value
        gamma = num / norm_sq.clamp(min=self._EPS).item()
        # A NaN ratio falls through to 0.0 here (no update).
        return max(0.0, min(gamma, self.gamma_b))

    @torch.no_grad()
    def _restore(self):
        """Copy the saved x^t back into every perturbed parameter and drop the copy."""
        for group in self.param_groups:
            for p in group['params']:
                x_t = self.state.get(p, {}).pop('x^t', None)
                if x_t is not None:
                    p.copy_(x_t)

    def step(self, closure=None):
        """Perform one SAM + SPS step.

        Args:
            closure: callable that recomputes the loss, calls ``backward()``
                so that ``p.grad`` is populated, and returns the loss. It is
                called twice: at x^t and at the perturbed point e^t.

        Returns:
            The value returned by the second closure call, i.e. f_i(e^t).

        Raises:
            ValueError: if ``closure`` is None.
            RuntimeError: if a closure call leaves every ``p.grad`` as None.
        """
        if closure is None:
            raise ValueError(f"{type(self).__name__}.step() requires a closure")

        # Make sure gradients are tracked inside the closure even if step()
        # itself is called under torch.no_grad().
        closure = torch.enable_grad()(closure)

        # Pass 1: gradients at x^t, then move the parameters to e^t.
        self.zero_grad(set_to_none=True)
        closure()
        self._perturb()

        try:
            # Pass 2: loss and gradients at e^t, then the Polyak step size.
            self.zero_grad(set_to_none=True)
            loss_e = closure()
            gamma = self._sps_step_size(loss_e)
        finally:
            # Always return to x^t, so a failing closure does not leave the
            # model at the perturbed point.
            self._restore()

        # group['lr'] carries \gamma_t into the SGD update on \nabla f_i(e^t)
        # (p.grad still holds the pass-2 gradients).
        for group in self.param_groups:
            group['lr'] = gamma
            self._sgd_update(group)
        return loss_e