import torch
from torch.optim import Optimizer


class SAMAdam(Optimizer):
    r"""
    Sharpness-Aware Minimization (SAM) wrapped around Adam.

    The wrapper owns a ``torch.optim.Adam`` instance (``base_optimizer``) and
    shares its ``param_groups`` and ``state`` objects, so Adam's buffers are
    reachable via ``optimizer.state[p]`` (e.g. 'exp_avg', 'exp_avg_sq').

    One SAM update needs two forward/backward passes:

      1. gradients at ``w`` -> ``first_step`` moves the weights to ``w + e(w)``
      2. gradients at ``w + e(w)`` -> ``second_step`` restores ``w`` and lets
         Adam step with those gradients.

    ``step(closure)`` runs both passes for you.

    Perturbation:
      * SAM:  ``e(w) = rho * g / ||g||``
      * ASAM: ``e(w) = rho * w^2 * g / || |w| * g ||``

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups
        rho (float): SAM neighborhood size (non-negative)
        adaptive (bool): if True, use ASAM (element-wise scaling by |w|)
        adam_kwargs: arguments forwarded to torch.optim.Adam
            (betas, eps, weight_decay, amsgrad, foreach, maximize, capturable,
             differentiable, fused as supported by your PyTorch version)

    Raises:
        ValueError: if ``rho`` is negative.
    """

    # Key under which first_step stashes the unperturbed weights in self.state[p].
    _OLD_P_KEY = "__sam_old_p__"

    def __init__(self, params, rho: float = 0.05, adaptive: bool = False, **adam_kwargs):
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if rho < 0.0:
            raise ValueError(f"Invalid rho, should be non-negative: {rho}")

        # Initialize a minimal Optimizer just to normalize `params` into groups
        # carrying the SAM hyperparameters.
        defaults = dict(rho=rho, adaptive=adaptive)
        super().__init__(params, defaults)

        # Build Adam on the same group dicts; Adam fills in its own
        # hyperparameters and leaves the 'rho'/'adaptive' keys in place.
        self.base_optimizer = torch.optim.Adam(self.param_groups, **adam_kwargs)

        # Share param_groups/defaults/state with Adam so callers can access Adam internals
        self.param_groups = self.base_optimizer.param_groups      # same list object
        self.defaults.update(self.base_optimizer.defaults)        # merge hyperparams
        self.state = self.base_optimizer.state                    # expose Adam state

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False):
        """Move every parameter that has a gradient from ``w`` to ``w + e(w)``.

        The original weights are saved in ``self.state[p]`` so that
        ``second_step`` can restore them.

        Args:
            zero_grad: if True, clear gradients (set to None) afterwards.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # grad_norm is a tensor, so scale is a 0-dim tensor as well.
            scale = group["rho"] / (grad_norm + 1e-12)
            adaptive = group["adaptive"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                # save original weight
                self.state[p][self._OLD_P_KEY] = p.data.clone()

                # e(w): SAM uses g; ASAM uses T_w^2 * g with T_w = |w|, i.e. w^2 * g
                # (the norm in _grad_norm uses T_w * g = |w| * g).
                e_w = (p.pow(2) if adaptive else 1.0) * p.grad
                # ensure scale lives on the same device/dtype as the perturbation
                e_w = e_w * scale.to(device=p.device, dtype=e_w.dtype)
                p.add_(e_w)  # w <- w + e(w)

        if zero_grad:
            self.zero_grad(set_to_none=True)

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        """Restore the weights saved by ``first_step`` and run the Adam update.

        Args:
            zero_grad: if True, clear gradients (set to None) afterwards.
        """
        # Restore every parameter that was perturbed, whether or not it
        # received a gradient in the second backward pass; otherwise such a
        # parameter would stay at w + e(w) forever.
        for group in self.param_groups:
            for p in group["params"]:
                # .get() avoids creating empty state entries in the defaultdict.
                state_p = self.state.get(p)
                if state_p is not None and self._OLD_P_KEY in state_p:
                    p.data = state_p.pop(self._OLD_P_KEY)  # w + e(w) -> w

        # actual optimizer step at original point using grads from perturbed point
        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad(set_to_none=True)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one full SAM update.

        SAM requires a closure that does a full forward/backward pass and
        returns the loss. It is called twice: at w (to get grads), and at
        w+e(w) (to get grads for the final step).

        Args:
            closure: callable that recomputes the loss and its gradients.

        Returns:
            The loss returned by the first closure call (evaluated at w).

        Raises:
            ValueError: if ``closure`` is None.
        """
        if closure is None:
            raise ValueError("SAM requires a closure that returns the loss.")

        # 1) forward-backward at w (step runs under no_grad, so re-enable grad)
        with torch.enable_grad():
            loss = closure()

        # 2) perturb to w+e(w)
        self.first_step(zero_grad=True)

        # 3) forward-backward at w+e(w)
        with torch.enable_grad():
            closure()

        # 4) restore weights and step at w
        self.second_step(zero_grad=True)
        return loss

    def zero_grad(self, set_to_none: bool = False):
        """Clear gradients by delegating to the wrapped Adam optimizer."""
        self.base_optimizer.zero_grad(set_to_none=set_to_none)

    def _grad_norm(self) -> torch.Tensor:
        """Return the global L2 norm of the gradients as a 0-dim tensor.

        With ``adaptive=True`` in a group, that group's gradients are weighted
        element-wise by ``|w|`` before the norm is taken. Per-parameter norms
        are moved to the device of the first parameter encountered so they can
        be stacked. Returns 0.0 when no parameter has a gradient.
        """
        device = None
        norms = []
        for group in self.param_groups:
            adaptive = group["adaptive"]
            for p in group["params"]:
                if device is None:
                    # First parameter seen decides where the norms are gathered;
                    # unlike indexing param_groups[0]["params"][0], this also
                    # works when the first group is empty.
                    device = p.device
                if p.grad is None:
                    continue
                v = (p.abs() if adaptive else 1.0) * p.grad
                norms.append(v.norm(p=2).to(device))
        if not norms:
            return torch.tensor(0.0, device=device)
        return torch.norm(torch.stack(norms), p=2)

    def state_dict(self):
        """Return the wrapped Adam optimizer's state dict.

        The group dicts are shared with Adam, so the SAM keys 'rho' and
        'adaptive' are part of the serialized param_groups.
        """
        return self.base_optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped Adam and re-tie shared references.

        Groups that lack the SAM keys (e.g. a checkpoint written by a plain
        ``torch.optim.Adam``) get 'rho' and 'adaptive' from this optimizer's
        defaults so that ``first_step`` keeps working.
        """
        self.base_optimizer.load_state_dict(state_dict)
        # Adam replaces its param_groups/state objects on load; point at the new ones.
        self.param_groups = self.base_optimizer.param_groups
        self.state = self.base_optimizer.state
        self.defaults.update(self.base_optimizer.defaults)
        for group in self.param_groups:
            group.setdefault("rho", self.defaults["rho"])
            group.setdefault("adaptive", self.defaults["adaptive"])
