import torch

class FGSAMp(torch.optim.Optimizer):
    """Sharpness-aware optimizer wrapper in the style of LookSAM.

    Based on LookSAM (https://arxiv.org/pdf/2203.02714.pdf). The expensive
    perturbed forward/backward pass runs only every ``k`` steps. On those steps
    the component ``gv`` of the mixed gradient that is orthogonal to a
    caller-supplied reference gradient is cached. On the steps in between, the
    cached ``gv`` (rescaled) is added to the plain gradient.

    Unlike plain LookSAM:
    - the reference gradient comes from the caller via ``g_mlp``;
    - on SAM steps, the gradients of perturbed groups are mixed as
      ``(1 - lam) * g + lam * g_perturbed``.

    NOTE(review): the name and the ``g_mlp`` argument suggest this is the FGSAM+
    variant (MLP reference gradients). Confirm and cite the paper.
    """

    def __init__(self, params, base_optimizer, k, alpha, rho=0.05, lam=0.5, num_pert=-1, **kwargs):
        """
        :param params: parameters (or param-group dicts) of the model
        :param base_optimizer: optimizer class (SGD, Adam, etc...). It is
            instantiated here on this optimizer's param groups with ``**kwargs``.
        :param k: a SAM (perturbed) step is taken whenever ``t % k == 0``
        :param alpha: scale of the cached ``gv`` correction added on non-SAM steps
        :param rho: perturbation radius, stored per param group (default: 0.05)
        :param lam: weight of the perturbed-loss gradient on SAM steps. The
            clean gradient of perturbed groups is scaled by ``1 - lam``
            (default: 0.5).
        :param num_pert: number of leading param groups to perturb; -1 means
            all groups (default: -1)
        :param kwargs: forwarded to ``base_optimizer`` (e.g. ``lr``) and stored
            in the param-group defaults

        :return: None

        Usage:
            model = YourModel()
            optimizer = FGSAMp(params=model.parameters(),
                               base_optimizer=torch.optim.SGD,
                               k=k,
                               alpha=alpha,
                               rho=rho,
                               lr=lr)
            ...

            def forward(*args):
                # step() calls this as forward(False) on SAM steps
                ...
                return loss, acc

            for train_index, data in enumerate(loader):
                loss, acc = forward()
                loss.backward()
                g_mlp = ...  # g_mlp[i][j]: reference gradient for param j of group i
                optimizer.step(train_index, forward, g_mlp, zero_grad=True)
        """
        defaults = dict(alpha=alpha, rho=rho, **kwargs)
        super().__init__(params, defaults)

        self.k = k
        self.alpha = torch.tensor(alpha, requires_grad=False)
        self.lam = lam
        # -1 perturbs every param group; otherwise only the first ``num_pert`` groups.
        self.num_pert = len(self.param_groups) if num_pert == -1 else num_pert

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Rebind to the base optimizer's list so both objects read the same groups.
        self.param_groups = self.base_optimizer.param_groups

    @staticmethod
    def normalized(g, eps=1e-12):
        """Return ``g`` scaled to unit L2 norm.

        The norm is clamped to ``eps``, so an all-zero tensor maps to zeros
        instead of NaN. A NaN here would poison the cached ``gv`` and, through
        it, every later non-SAM update.
        """
        return g / g.norm(p=2).clamp_min(eps)

    def step(self, t, forward, g_mlp, zero_grad=False):
        """Perform one optimization step.

        :param t: global step index (0 --> T-1). A SAM step is taken when
            ``t % self.k == 0``. Before the first SAM step no ``gv`` is cached,
            so plain gradients are used.
        :param forward: closure, called as ``forward(False)`` on SAM steps only.
            The first element of its result must be the scalar loss to
            back-propagate.
        :param g_mlp: nested sequence where ``g_mlp[i][j]`` is the reference
            gradient for parameter ``j`` of perturbed param group ``i``. It is
            only read on SAM steps.
        :param zero_grad: if True, zero all gradients after the update.

        NOTE(review): param groups beyond ``num_pert`` are neither perturbed nor
        rescaled by ``1 - lam``. On SAM steps their gradient therefore becomes
        ``g + lam * g_perturbed``. Confirm this asymmetry is intended.
        """
        if t % self.k == 0:  # including the first time
            perturbed = self._perturb(g_mlp)
            loss_perturbed = forward(False)[0]
            (self.lam * loss_perturbed).backward()
            self._project_and_restore(perturbed)
        else:
            # Parameters were not perturbed on this step, so there is nothing to restore.
            self._apply_gv_correction()

        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    def _perturb(self, g_mlp):
        """Move each perturbed-group param by rho * grad / ||grad||_global; return the moved params.

        Before moving each parameter, this snapshots it (``old_p``) and stores
        its reference gradient (``old_p_grad``). It also scales the current
        gradient by ``1 - lam``, so the following perturbed backward pass
        accumulates ``(1 - lam) * g + lam * g_perturbed``.
        """
        grad_norm = self._grad_norm()
        perturbed = []
        for i, group in enumerate(self.param_groups[:self.num_pert]):
            scale = group['rho'] / (grad_norm + 1e-8)  # rho / ||grad||

            for j, p in enumerate(group['params']):
                if p.grad is None:
                    continue

                state = self.state[p]
                state['old_p'] = p.data.clone()
                state['old_p_grad'] = g_mlp[i][j]  # reference gradient g

                with torch.no_grad():
                    e_w = p.grad * scale.to(p)
                    p.data = p.data + e_w
                    p.grad.data = (1. - self.lam) * p.grad.data
                perturbed.append(p)
        return perturbed

    def _project_and_restore(self, perturbed):
        """Cache ``gv`` (grad minus its projection onto the reference grad) and undo the perturbation."""
        with torch.no_grad():
            for p in perturbed:
                state = self.state[p]
                if p.grad is not None:
                    g_unit = FGSAMp.normalized(state['old_p_grad'])  # g/||g||
                    gs_unit = FGSAMp.normalized(p.grad)  # gs/||gs||
                    # gs - ||gs|| * <g/||g||, gs/||gs||> * g/||g||
                    state['gv'] = torch.sub(
                        p.grad,
                        p.grad.norm(p=2) * torch.sum(g_unit * gs_unit) * g_unit,
                    )
                # Restore unconditionally: the weights were moved even if forward() cleared the grad.
                p.data = state['old_p']

    def _apply_gv_correction(self):
        """On non-SAM steps, add alpha * ||grad|| / ||gv|| * gv to each perturbed-group gradient."""
        with torch.no_grad():
            for group in self.param_groups[:self.num_pert]:
                for p in group['params']:
                    gv = self.state[p].get('gv')
                    # No cached gv yet (no SAM step has run for p): use the plain gradient.
                    if p.grad is None or gv is None:
                        continue
                    p.grad.add_(self.alpha.to(p) * (p.grad.norm(p=2) / (gv.norm(p=2) + 1e-8) * gv))

    def _grad_norm(self):
        """Return the global L2 norm over all gradients of the perturbed param groups."""
        shared_device = self.param_groups[0]['params'][0].device
        norms = [
            p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups[:self.num_pert]
            for p in group['params']
            if p.grad is not None
        ]
        if not norms:
            # Nothing to perturb; torch.stack([]) would raise.
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(norms), p=2)
