import torch


class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (SAM / ASAM) wrapper around a base optimizer.

    One call to :meth:`step` runs the closure twice:

    1. compute the gradient ``g`` at the current weights ``w``;
    2. move the weights to ``w + e(w)`` (:meth:`first_step`), where
       ``e(w) = rho * g / ||g||`` (or the element-wise scaled ASAM variant);
    3. compute the gradient again at the perturbed weights;
    4. restore ``w`` and let the base optimizer step with that second
       gradient (:meth:`second_step`).

    Args:
        params: iterable of parameters or param-group dicts.
        base_optimizer: optimizer *class* (e.g. ``torch.optim.SGD``). It is
            instantiated on this optimizer's param groups with ``**kwargs``,
            and both optimizers share the same ``param_groups`` list.
        rho: perturbation radius; must be non-negative.
        adaptive: if True, scale the perturbation element-wise by
            ``T(w) = |w| + eps`` (ASAM-style).
        eps: small constant added to the gradient norm and, when adaptive, to
            ``|w|``. It is stored in each param group under ``"sam_eps"`` (not
            ``"eps"``) so that it does not override the base optimizer's own
            ``"eps"`` hyperparameter (Adam, AdamW, RMSprop, Adagrad, ...).
        **kwargs: forwarded to ``base_optimizer`` (e.g. ``lr``, ``momentum``).

    NOTE(review): the ``"rho"`` group key is shared with any base optimizer
    that also reads ``"rho"`` (e.g. ``torch.optim.Adadelta``). It is kept
    under that name for compatibility with code that schedules
    ``group["rho"]``.
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, eps=1e-12, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        # "sam_eps" rather than "eps": the base optimizer fills missing group
        # keys with setdefault, so a SAM "eps" would silently replace its own.
        defaults = dict(rho=rho, adaptive=adaptive, sam_eps=eps, **kwargs)
        super().__init__(params, defaults)

        # The base optimizer is built on (and mutates) the very same group
        # dicts, then both objects point at one shared param_groups list.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb every parameter that has a gradient to ``w + e(w)``.

        The unperturbed value is saved in ``self.state[p]["old_p"]`` so that
        :meth:`second_step` can restore it. If no parameter has a gradient,
        this is a no-op (and ``zero_grad`` is not applied).

        Args:
            zero_grad: if True, set all gradients to ``None`` afterwards.
        """
        grad_norm = self._grad_norm()
        if grad_norm is None:
            return

        for group in self.param_groups:
            eps = group["sam_eps"]
            scale = group["rho"] / (grad_norm + eps)

            for p in group["params"]:
                if p.grad is None:
                    continue

                self.state[p]["old_p"] = p.data.clone()

                if group["adaptive"]:
                    # ASAM: e(w) = rho * T(w)^2 * g / ||T(w) g||, T(w) = |w| + eps
                    t = p.data.abs() + eps
                    e_w = t * t * p.grad * scale
                else:
                    e_w = p.grad * scale

                p.add_(e_w)

        if zero_grad:
            self.zero_grad(set_to_none=True)

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the weights saved by :meth:`first_step` and step the base optimizer.

        Restoration is driven by the saved ``"old_p"`` copies, not by which
        parameters have a gradient right now:

        - a parameter perturbed in the first pass is restored even if it got
          no gradient in the second pass;
        - a parameter that only received a gradient in the second pass was
          never perturbed and is left untouched.

        Each saved copy is removed once used, so a stale copy from an earlier
        step can never be written back.

        Args:
            zero_grad: if True, set all gradients to ``None`` afterwards.
        """
        for group in self.param_groups:
            for p in group["params"]:
                # .get avoids creating empty state entries in the defaultdict.
                old_p = self.state.get(p, {}).pop("old_p", None)
                if old_p is not None:
                    p.data.copy_(old_p)

        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad(set_to_none=True)

    @torch.no_grad()
    def step(self, closure=None):
        """Run one full SAM update.

        Args:
            closure: callable that runs the forward pass, calls ``backward()``
                and returns the loss. It is called twice, with gradients
                enabled: once at ``w`` and once at ``w + e(w)``.

        Returns:
            The loss returned by the first closure call (at the unperturbed
            weights).
        """
        assert closure is not None, "SAM requires closure"

        closure = torch.enable_grad()(closure)

        # First pass: gradient at the current weights w.
        loss = closure()

        self.first_step(zero_grad=True)
        # Second pass: gradient at the perturbed weights w + e(w).
        closure()
        self.second_step(zero_grad=True)

        return loss

    def _grad_norm(self):
        """Return the global L2 norm of all gradients, or ``None`` if there are none.

        For groups with ``adaptive`` enabled, each gradient is first scaled
        element-wise by ``|w| + sam_eps``. Per-tensor norms are moved to the
        device of the first one before being combined.
        """
        norms = []
        for group in self.param_groups:
            adaptive = group["adaptive"]
            eps = group["sam_eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                g = (p.data.abs() + eps) * p.grad if adaptive else p.grad
                norms.append(g.norm(p=2))

        if not norms:
            return None

        device = norms[0].device
        return torch.norm(torch.stack([n.to(device) for n in norms]), p=2)

    def load_state_dict(self, state_dict):
        """Load SAM state and param groups, then re-share the groups with the base optimizer.

        NOTE(review): only this optimizer's own ``self.state`` and the shared
        param groups go through ``state_dict``/``load_state_dict``. The
        per-parameter state of ``self.base_optimizer`` (e.g. momentum buffers)
        lives in ``self.base_optimizer.state`` and is not saved or restored
        here.
        """
        super().load_state_dict(state_dict)
        for group in self.param_groups:
            # Groups saved without "sam_eps" held SAM's epsilon under "eps".
            group.setdefault("sam_eps", group.get("eps", self.defaults["sam_eps"]))
        self.base_optimizer.param_groups = self.param_groups
