# optimizer.py

import torch
import torch.optim

class CLAGR(torch.optim.Optimizer):
    """Wrapper optimizer that feeds ``base_optimizer`` a gradient built from
    symmetric parameter perturbations.

    Each ``step(closure)`` runs ``inner_step`` rounds. In round ``i`` every active
    parameter ``w`` is evaluated at ``w + rho * v_i`` (gradient ``g_add``) and at
    ``w - rho * v_i`` (gradient ``g_sub``), and its gradient is replaced by::

        0.5 * (g_add + g_sub) + cr_lambda / (2 * rho) * (g_add - g_sub)

    ``v_0`` is the current gradient divided by the global gradient norm. For
    ``i > 0``::

        v_i = beta_i * v_{i-1}
              + (1 - beta_i) * (gamma_interp * g_add_normal + (1 - gamma_interp) * g_sub_normal)
        beta_i = beta_start + (beta_end - beta_start) * i / (inner_step - 1)

    Here ``g_add_normal`` / ``g_sub_normal`` are the previous round's gradients,
    each divided by its global norm. After ``base_optimizer.step()`` the weights
    are pulled back towards their pre-step values::

        w = w_before + lmomentum * (w_after - w_before)

    The wrapper shares ``param_groups`` with ``base_optimizer``. Its own ``state``
    only holds scratch tensors for the step in progress.
    """

    def __init__(self, params, base_optimizer, rho, inner_step, cr_lambda, lmomentum, gamma_interp, beta_start, beta_end) -> None:
        assert isinstance(base_optimizer, torch.optim.Optimizer), f"base_optimizer must be an `Optimizer`"
        self.base_optimizer = base_optimizer

        # rho must be strictly positive: the combined gradient divides by 2 * rho,
        # so rho == 0 could only fail later, in the middle of a step.
        assert 0 < rho, f"rho should be positive:{rho}"
        assert 1 <= inner_step, f"inner_step should be >= 1:{inner_step}"
        assert 0 <= cr_lambda <= 1, f"cr_lambda should be in [0,1]:{cr_lambda}"
        assert 0 <= lmomentum <= 1, f"lmomentum should be in [0,1]:{lmomentum}"
        assert 0 <= gamma_interp <= 1, f"gamma_interp should be in [0,1]:{gamma_interp}"
        assert 0 <= beta_start <= 1, f"beta_start should be in [0,1]:{beta_start}"
        assert 0 <= beta_end <= 1, f"beta_end should be in [0,1]:{beta_end}"

        defaults = dict(rho=rho, inner_step=inner_step, cr_lambda=cr_lambda, lmomentum=lmomentum, gamma_interp=gamma_interp, beta_start=beta_start, beta_end=beta_end)
        super(CLAGR, self).__init__(params, defaults)

        # Share the group dicts with the base optimizer, so edits to a group
        # (e.g. its learning rate) are seen by both; then add CLAGR's settings.
        self.param_groups = self.base_optimizer.param_groups
        for group in self.param_groups:
            group.update(defaults)

    @torch.no_grad()
    def first_step(self, closure, zero_grad=False):
        """Run the perturbation rounds, the base optimizer step and the pull-back.

        Args:
            closure: callable that re-evaluates the model and calls ``backward``.
                It is called twice per inner step.
            zero_grad: if true, gradients are reset with ``set_to_none=True``
                before every closure call, so the closure does not accumulate.

        Only parameters that require grad and hold a gradient when this method is
        entered are perturbed and pulled back. Their perturbation is always undone,
        even if a closure call leaves their gradient as ``None``.
        """
        active = self._snapshot_active_params()

        hp = self.param_groups[0]
        inner_steps = hp["inner_step"]
        cr_lambda = hp["cr_lambda"]
        beta_start = hp["beta_start"]
        beta_end = hp["beta_end"]
        gamma_interp = hp["gamma_interp"]

        grad_norm = self._grad_norm()
        for i in range(inner_steps):
            # 1. Move every active parameter to w + rho * v_i.
            if i > 0:
                beta_i = beta_start + (beta_end - beta_start) * i / (inner_steps - 1)
            for group, p in active:
                state = self.state[p]
                if i == 0:
                    # v_0: current gradient normalised by the global gradient norm.
                    v_i = self._normalized(p.grad, grad_norm)
                else:
                    # Blend the previous direction with the normalised gradients of
                    # the previous round; a missing (None) gradient counts as zero.
                    add_normal = state.get("p_grad_add_normal")
                    if add_normal is None:
                        add_normal = torch.zeros_like(p)
                    sub_normal = state.get("p_grad_sub_normal")
                    if sub_normal is None:
                        sub_normal = torch.zeros_like(p)
                    v_i = beta_i * state["v_i"] + (1 - beta_i) * (gamma_interp * add_normal + (1 - gamma_interp) * sub_normal)
                state["v_i"] = v_i
                state["perturb"] = group["rho"] * v_i
                p.add_(state["perturb"])

            # 2. Gradient at w + rho * v_i, then move every active parameter to w - rho * v_i.
            if zero_grad:
                self.zero_grad(set_to_none=True)
            with torch.enable_grad():
                closure()
            grad_norm_add = self._grad_norm()
            for _, p in active:
                state = self.state[p]
                if p.grad is not None:
                    state["p_grad_add"] = p.grad.clone(memory_format=torch.contiguous_format)
                    state["p_grad_add_normal"] = self._normalized(p.grad, grad_norm_add)
                else:
                    # Drop last round's values so they are not mistaken for this round's.
                    state.pop("p_grad_add", None)
                    state.pop("p_grad_add_normal", None)
                p.sub_(state["perturb"], alpha=2)

            # 3. Gradient at w - rho * v_i, return to w and build the combined gradient.
            if zero_grad:
                self.zero_grad(set_to_none=True)
            with torch.enable_grad():
                closure()
            grad_norm_sub = self._grad_norm()
            for group, p in active:
                state = self.state[p]
                p.add_(state["perturb"], alpha=1)
                if p.grad is None:
                    state.pop("p_grad_sub_normal", None)
                    continue
                grad_sub = p.grad.clone(memory_format=torch.contiguous_format)
                state["p_grad_sub_normal"] = self._normalized(p.grad, grad_norm_sub)
                grad_add = state.get("p_grad_add")
                if grad_add is None:
                    grad_add = torch.zeros_like(grad_sub)
                # Mean of the two gradients plus cr_lambda times the finite
                # difference (g_add - g_sub) / (2 * rho).
                p.grad = 1/2*(grad_add + grad_sub) + cr_lambda/(2*group["rho"])*(grad_add - grad_sub)

        self.base_optimizer.step()

        # 4. Pull the stepped weights back towards the pre-step weights and drop
        #    the scratch state of this step.
        lmomentum = self.param_groups[0]["lmomentum"]
        for _, p in active:
            p_original = self.state.pop(p)["p_original"]
            # copy_ keeps p's storage and memory format (rebinding p.data would not).
            p.copy_(p_original + lmomentum * (p - p_original))

    @torch.no_grad()
    def step(self, closure=None, **kwargs):
        """Perform one CLAGR update; ``closure`` is mandatory."""
        assert closure is not None, "CLAGR requires a closure to re-evaluate the gradient."
        self.first_step(closure, zero_grad=True)

    def _snapshot_active_params(self):
        """Save ``p_original`` for each parameter that requires grad and has a gradient.

        Returns:
            list of ``(param_group, param)`` pairs for the saved parameters.
        """
        active = []
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None and p.requires_grad:
                    self.state[p]["p_original"] = p.clone(memory_format=torch.contiguous_format)
                    active.append((group, p))
        return active

    @staticmethod
    def _normalized(grad, norm):
        """Return ``grad / (norm + 1e-7)``, moving ``norm`` onto ``grad``'s device first."""
        return grad / (norm + 1e-7).to(grad.device)

    def _grad_norm(self):
        """Global L2 norm of all existing gradients, placed on the first parameter's device.

        Returns a zero tensor when no parameter has a gradient.
        """
        shared_device = self.param_groups[0]["params"][0].device
        norms = [p.grad.norm(p=2).to(shared_device) for group in self.param_groups for p in group["params"] if p.grad is not None]
        if not norms: return torch.tensor(0.0, device=shared_device)
        return torch.norm(torch.stack(norms), p=2)



class SAM(torch.optim.Optimizer):
    """Sharpness-aware wrapper around ``base_optimizer``.

    ``step(closure)`` works in three phases:

    1. Every parameter with a gradient is moved to ``w + e(w)``, where
       ``e(w) = rho * grad / (||grad|| + 1e-7)`` and ``||grad||`` is the global
       L2 norm of all gradients.
    2. ``closure`` re-evaluates the gradient at that point.
    3. Each perturbed parameter is moved back to ``w``, and ``base_optimizer``
       steps with the re-evaluated gradient.

    The wrapper shares ``param_groups`` with ``base_optimizer``.
    """

    def __init__(self, params, base_optimizer, rho) -> None:
        assert isinstance(base_optimizer, torch.optim.Optimizer), f"base_optimizer must be an `Optimizer`"
        self.base_optimizer = base_optimizer

        assert 0 <= rho, f"rho should be non-negative:{rho}"
        self.rho = rho
        super(SAM, self).__init__(params, dict(rho=rho))

        # Share the group dicts with the base optimizer and add rho to each group.
        self.param_groups = self.base_optimizer.param_groups
        for group in self.param_groups:
            group["rho"] = rho

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Move every parameter that has a gradient to ``w + e(w)`` and remember ``e(w)``."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-7)
            for p in group["params"]:
                if p.grad is None: continue
                # Move the scale onto p's device (model parallelism: the norm lives
                # on the first parameter's device).
                e_w = p.grad * scale.to(p.device)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w
        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Undo every perturbation applied by ``first_step``, then step the base optimizer.

        The undo is keyed on the stored ``e_w``, not on the current gradient.
        A parameter whose gradient disappeared during the closure is still restored.
        A parameter that only gained a gradient in the closure is left untouched.
        """
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state.get(p)
                if state is None or "e_w" not in state: continue
                # pop so a stale e_w can never be subtracted again in a later step
                p.sub_(state.pop("e_w"))  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()
        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None, **kwargs):
        """Perform one SAM update; ``closure`` is mandatory."""
        assert closure is not None, "SAM requires closure, which is not provided."

        self.first_step(True)
        with torch.enable_grad():
            closure()
        self.second_step()

    def _grad_norm(self):
        """Global L2 norm of all existing gradients, placed on the first parameter's device.

        Returns a zero tensor when no parameter has a gradient.
        """
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norms = [
            p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups for p in group["params"]
            if p.grad is not None
        ]
        if not norms:
            return torch.tensor(0.0, device=shared_device)
        return torch.norm(torch.stack(norms), p=2)

