import torch

class SPS_safe(torch.optim.Optimizer):
    """Stochastic Polyak step size (SPS) with a safeguarded denominator.

    Each call to :meth:`step` uses the step size

        gamma = (loss - ell_star) / max(M, ||g||^2)

    where ``g`` is the concatenation of the gradients of all parameters that
    have one. Each such parameter is then updated as

        p <- (1 - weight_decay * gamma) * p - gamma * grad(p)

    Args:
        params: iterable of parameters or parameter-group dicts.
        ell_star: target loss value subtracted from the loss in the numerator.
        M: lower bound on the step-size denominator. It should be positive;
            with ``M <= 0`` and a zero gradient the step size is not finite.
        weight_decay: multiplicative weight decay, scaled by the step size.

    Attributes:
        ss: step size used by the most recent :meth:`step` (float).
        grad_norm: Euclidean norm of the full gradient at the last step (float).
        tru: number of steps in which the uncapped Polyak step size was equal
            to the safeguarded one (i.e. ``M`` did not change the step).
        extra: ``[tru]`` after each step, kept for logging.
    """

    def __init__(self, params, ell_star, M, weight_decay):
        defaults = dict(ell_star=ell_star, M=M, weight_decay=weight_decay)
        super(SPS_safe, self).__init__(params, defaults)

        self.ell_star = ell_star
        self.M = M
        self.weight_decay = weight_decay

        self.ss = 0
        self.grad_norm = 0

        self.extra = []
        self.tru = 0

    def step(self, loss):
        """Take one safeguarded SPS step.

        Args:
            loss: current loss as a single-element tensor (``.item()`` is
                called on expressions derived from it).

        Returns:
            ``loss``, unchanged.
        """
        self.grad_norm = self.compute_grad_terms().item()
        grad_norm_sq = self.grad_norm ** 2

        # Uncapped Polyak step: only used to count how often the M safeguard
        # is inactive.
        true_sps = ((loss - self.ell_star) / grad_norm_sq).item()
        spsm = ((loss - self.ell_star) / max(self.M, grad_norm_sq)).item()
        if true_sps == spsm:
            self.tru += 1

        self.ss = spsm
        self.extra = [self.tru]

        for group in self.param_groups:
            for p in group['params']:
                # Parameters that received no gradient are left untouched
                # (including weight decay), as torch's built-in optimizers do.
                if p.grad is None:
                    continue
                p.data.mul_(1 - self.weight_decay * spsm)  # weight decay
                p.data.add_(other=p.grad.data.detach(), alpha=-spsm)

        return loss

    @torch.no_grad()
    def compute_grad_terms(self):
        """Return the Euclidean norm of all parameter gradients as a 0-dim tensor.

        Parameters whose ``.grad`` is ``None`` are ignored. If no parameter has
        a gradient, the result is ``tensor(0.)``.
        """
        grad_norm_sq = 0.
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                g = p.grad.data
                grad_norm_sq += torch.sum(torch.mul(g, g))

        # as_tensor covers the no-gradient case, where the sum is still a float.
        return torch.sqrt(torch.as_tensor(grad_norm_sq))




class IMA_SPS_safe(torch.optim.Optimizer):
    """Safeguarded SPS step on an auxiliary sequence ``z`` with averaged iterates.

    For every parameter ``p`` an auxiliary tensor ``z`` is kept in
    ``self.state[p]['z']``. It is initialised to a copy of ``p`` the first
    time ``p`` has a gradient. Each :meth:`step` computes

        gamma = max(loss - ell_star + <g, z - p>, 0) / max(M, ||g||^2)

    with inner products and norms taken over every parameter that has a
    gradient. It then updates

        z <- z - gamma * g
        p <- (lambd * p + z) / (1 + lambd)

    Args:
        params: iterable of parameters or parameter-group dicts.
        ell_star: target loss value subtracted from the loss.
        lambd: weight of the previous iterate in the average; larger values
            move ``p`` more slowly towards ``z``.
        M: lower bound on the step-size denominator. If ``M <= 0`` it is
            replaced by ``||g||^2`` of the step that sees it, and this repeats
            while that value is 0.
        weight_decay: stored in the param-group defaults; not used by
            :meth:`step`.

    Attributes:
        ss: step size gamma used by the most recent step (float).
        grad_norm: norm of the full gradient at the last step (0-dim tensor
            after the first step).
        number_steps: number of calls to :meth:`step`.
    """

    def __init__(self, params, ell_star, lambd, M, weight_decay):
        defaults = dict(ell_star=ell_star, lambd=lambd, M=M, weight_decay=weight_decay)
        super(IMA_SPS_safe, self).__init__(params, defaults)

        self.lambd = lambd
        self.ell_star = ell_star
        self.M = M
        self.number_steps = 0

        self.ss = 0
        self.grad_norm = 0

        self.extra = []

    def step(self, loss):
        """Take one IMA-SPS step.

        Args:
            loss: current loss, either a single-element tensor or a Python
                number.

        Returns:
            ``loss``, unchanged.
        """
        self.number_steps += 1
        dot = 0.
        norm_sq = 0.

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.detach()
                state = self.state[p]
                # Create z lazily. This covers parameters added after the first
                # step and avoids overwriting a z restored via load_state_dict.
                if 'z' not in state:
                    state['z'] = p.detach().clone()
                z = state['z']
                dot += torch.sum(torch.mul(grad, z - p.data))
                norm_sq += torch.sum(torch.mul(grad, grad))

        # Reduce to Python floats once, so the step size is always a float
        # (max(x, 0) may otherwise return the int 0 and mix types).
        loss_value = loss.item() if torch.is_tensor(loss) else float(loss)
        numerator = max(loss_value - self.ell_star + float(dot), 0.)
        norm_sq_value = float(norm_sq)
        if self.M <= 0:
            # No usable safeguard given: use this step's squared gradient norm
            # and keep it as M for later steps.
            denominator = norm_sq_value
            self.M = norm_sq_value
        else:
            denominator = max(self.M, norm_sq_value)
        # A zero denominator here means every gradient is zero, so z cannot
        # move anyway; use 0 instead of a NaN/inf step that would poison z.
        ima_sps = numerator / denominator if denominator > 0 else 0.

        self.grad_norm = torch.sqrt(torch.as_tensor(norm_sq))
        self.ss = ima_sps

        coef_p = self.lambd / (1 + self.lambd)
        coef_z = 1 / (1 + self.lambd)
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.detach()
                z = self.state[p]['z']
                z.add_(grad, alpha=-ima_sps)
                p.data.mul_(coef_p).add_(other=z, alpha=coef_z)

        return loss