# Adapted from https://github.com/lucidrains/lion-pytorch
from __future__ import annotations
from typing import Tuple, Callable

import torch
from torch.optim.optimizer import Optimizer
from torch.cuda.amp import GradScaler
import torch.distributed as dist

def exists(val):
    """Return True when *val* is anything other than ``None``."""
    return not (val is None)

def clip_tensor(tensor, max_norm=1.0, eps=1e-6):
    """Rescale *tensor* so its L2 norm does not exceed *max_norm*.

    If the norm is already within bounds, the very same tensor object is
    returned. Otherwise a new, scaled tensor is returned; *eps* guards the
    division.
    """
    total_norm = tensor.norm(p=2)
    if not total_norm > max_norm:
        return tensor
    scale = max_norm / (total_norm + eps)
    return tensor * scale


def update_fn(p, grad, prev_grad, curr_grad, exp_avg, lr, wd, beta1, beta2):
    """Apply one in-place Lion step, with an optional variance-reduction term.

    Args:
        p: parameter tensor; modified in place (weight decay, then signed step).
        grad: gradient used for the step.
        prev_grad: earlier gradient snapshot, or None.
        curr_grad: current gradient snapshot, or None.
        exp_avg: momentum buffer; modified in place.
        lr: learning rate.
        wd: weight-decay coefficient (``p`` is scaled by ``1 - lr * wd``).
        beta1: interpolation factor used to form the update direction.
        beta2: decay factor for the momentum buffer.

    The correction ``curr_grad - prev_grad`` is applied only when both
    snapshots are provided. Otherwise this is a plain Lion update.
    NOTE(review): all tensors are assumed to share ``p``'s shape — confirm
    against callers.
    """
    # The correction term needs both snapshots; fall back to plain Lion if
    # either is missing instead of failing on ``None - tensor``.
    if prev_grad is None or curr_grad is None:
        correction = None
    else:
        correction = curr_grad - prev_grad

    # Update direction: interpolate momentum and gradient with beta1, add the
    # correction, then take the sign. Work on a clone so exp_avg is not yet
    # modified.
    update = exp_avg.clone().mul_(beta1).add_(grad, alpha=1. - beta1)
    if correction is not None:
        update.add_(correction, alpha=beta1)
    update.sign_()

    # Momentum update with beta2. Every step must be in-place; an
    # out-of-place ``add`` here would silently discard the correction term.
    exp_avg.mul_(beta2).add_(grad, alpha=1. - beta2)
    if correction is not None:
        exp_avg.add_(correction, alpha=beta2)

    # Weight decay, then the signed step.
    p.data.mul_(1. - lr * wd)
    p.add_(update, alpha=-lr)
    

class Lion_VR(Optimizer):
    """Lion optimizer with an optional variance-reduction correction.

    The per-parameter update is delegated to ``self.update_fn``, which is set
    to the module-level ``update_fn`` on construction.

    Callers may assign ``self.prev_grads`` and ``self.curr_grads`` before
    calling :meth:`step`. Each is a mapping keyed by parameter and looked up
    with ``.get``, or None. For each parameter, the looked-up entries (or None
    when absent) are forwarded to ``update_fn`` as ``prev_grad`` /
    ``curr_grad``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate; must be > 0.
        betas: ``(beta1, beta2)``, each in ``[0, 1]``.
        weight_decay: weight-decay coefficient passed to ``update_fn``.
        decoupled_weight_decay: if True, ``weight_decay`` is divided by the
            initial ``lr`` before being passed to ``update_fn``.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.1,
        decoupled_weight_decay: bool = False,
    ):
        assert lr > 0.
        assert all(0. <= beta <= 1. for beta in betas)

        self._init_lr = lr
        self.decoupled_wd = decoupled_weight_decay

        defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
        )

        super().__init__(params, defaults)

        self.update_fn = update_fn

        # Gradient snapshots for the variance-reduction term. They are
        # assigned externally before a step; None means "no snapshot".
        self.prev_grads = None
        self.curr_grads = None

    @torch.no_grad()
    def step(self, closure: Callable | None = None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It runs with gradients enabled, so it may call
                ``backward()``.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd enabled.
            with torch.enable_grad():
                loss = closure()

        # These are read with getattr so instances restored without running
        # __init__ (or from older pickles) still work.
        prev_grads = getattr(self, 'prev_grads', None)
        curr_grads = getattr(self, 'curr_grads', None)
        decoupled_wd = getattr(self, 'decoupled_wd', False)
        init_lr = getattr(self, '_init_lr', 1.0)

        for group in self.param_groups:
            lr = group['lr']
            wd = group['weight_decay']
            beta1, beta2 = group['betas']

            if decoupled_wd:
                wd = wd / init_lr

            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]

                # Lazily create the momentum buffer on the first step.
                if 'exp_avg' not in state:
                    state['exp_avg'] = torch.zeros_like(p)

                prev_grad = prev_grads.get(p, None) if prev_grads is not None else None
                curr_grad = curr_grads.get(p, None) if curr_grads is not None else None

                self.update_fn(
                    p,
                    p.grad,
                    prev_grad,
                    curr_grad,
                    state['exp_avg'],
                    lr,
                    wd,
                    beta1,
                    beta2,
                )

        return loss