import torch


class GCSAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization with a cosine-similarity gradient constraint.

    The optimizer wraps ``base_optimizer`` and performs a two-pass SAM update:

    1. ``first_step`` moves every parameter from ``w`` to ``w + e(w)``. Here
       ``e(w)`` is the gradient scaled so that the global perturbation has
       norm ``rho``. It also records ``w`` and the gradient at ``w``.
    2. The caller recomputes the gradient at the perturbed point.
    3. ``second_step`` constrains that gradient (see ``gradient_constrain``),
       restores ``w``, and lets ``base_optimizer`` apply the update.

    For each parameter, ``gradient_constrain`` compares the new gradient
    ``g_sam`` with the recorded one ``g_old`` via their cosine similarity
    ``cos``:

    * ``low_bound > 0``:
        - ``low_bound < cos < up_bound``: use the projection of ``g_sam``
          onto ``g_sam + g_old``;
        - ``cos <= low_bound``: fall back to ``g_old``;
        - otherwise keep ``g_sam``.
    * ``low_bound <= 0``: project whenever ``cos < up_bound``; otherwise
      keep ``g_sam``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        base_optimizer: optimizer *class* (e.g. ``torch.optim.SGD``). It is
            instantiated on this optimizer's param groups with ``**kwargs``.
        rho: perturbation radius; must be non-negative.
        low_bound: lower cosine threshold (see above).
        up_bound: upper cosine threshold (see above).
        adaptive: if True, the perturbation is scaled element-wise by the
            parameter (``|p|`` in the norm, ``p**2`` in the step).
        **kwargs: forwarded to ``base_optimizer`` and stored as group defaults.
    """

    def __init__(self, params, base_optimizer, rho=0.05, low_bound=0.5, up_bound=0.8, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super().__init__(params, defaults)

        # The base optimizer is built on our param groups, and we then adopt its
        # groups, so both optimizers read/write the same group dicts.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.low_bound = low_bound
        self.up_bound = up_bound

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb parameters to ``w + e(w)`` and record ``w`` and its gradient.

        Args:
            zero_grad: if True, clear gradients afterwards so the caller can
                recompute them at the perturbed point.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.data.clone()
                # Gradient at the unperturbed point; compared against in gradient_constrain.
                self.state[p]["old_g"] = p.grad.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad:
            self.zero_grad()

    @staticmethod
    def _project(vec, direction):
        """Return the projection of ``vec`` onto ``direction`` (same shape as ``vec``)."""
        dir_norm = torch.linalg.norm(direction)
        if dir_norm == 0:
            # Projecting onto a zero vector is undefined; contribute no update
            # instead of producing 0/0 = NaN.
            return torch.zeros_like(vec)
        unit = direction / dir_norm
        return torch.dot(vec.flatten(), unit.flatten()) * unit

    @torch.no_grad()
    def gradient_constrain(self, low_bound, up_bound):
        """Rewrite each gradient in place based on its cosine similarity to ``old_g``.

        See the class docstring for the decision rule. The cosine is undefined
        when either the current gradient or the stored ``old_g`` is all zeros.
        Such parameters are left unchanged, as are parameters with no gradient
        recorded by ``first_step``.

        Args:
            low_bound: lower cosine threshold; ``<= 0`` disables the fallback to ``old_g``.
            up_bound: upper cosine threshold; at or above it the gradient is kept.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                old_g = self.state[p].get("old_g")
                if old_g is None:
                    continue
                new_g = p.grad
                new_norm = torch.linalg.norm(new_g)
                old_norm = torch.linalg.norm(old_g)
                if new_norm == 0 or old_norm == 0:
                    continue
                cos = (torch.dot(new_g.flatten(), old_g.flatten()) / (new_norm * old_norm)).item()

                # copy_ writes into the existing grad tensor rather than aliasing
                # optimizer state (old_g) into p.grad.
                if low_bound > 0:
                    if low_bound < cos < up_bound:
                        p.grad.copy_(self._project(new_g, new_g + old_g))
                    elif cos <= low_bound:
                        p.grad.copy_(old_g)
                elif cos < up_bound:
                    p.grad.copy_(self._project(new_g, new_g + old_g))

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Constrain gradients, restore ``w`` and apply the base optimizer step.

        Args:
            zero_grad: if True, clear gradients after the update.
        """
        self.gradient_constrain(low_bound=self.low_bound, up_bound=self.up_bound)
        for group in self.param_groups:
            for p in group["params"]:
                # Parameters without a recorded "old_p" were never perturbed.
                if p.grad is None or "old_p" not in self.state[p]:
                    continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run a full two-pass update.

        Args:
            closure: callable that performs a full forward-backward pass. It is
                re-run with gradients enabled at the perturbed point. This
                assumes gradients for the current point were already computed
                before calling ``step``.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the global L2 norm over all gradients (``|p|``-scaled if adaptive)."""
        # Put everything on the same device, in case of model parallelism.
        shared_device = self.param_groups[0]["params"][0].device
        per_param = [
            ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        return torch.norm(torch.stack(per_param), p=2)

    def load_state_dict(self, state_dict):
        """Load state and re-share the (possibly replaced) param groups with the base optimizer."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
