"""
Parameter-scaled D-Adapt SGD
Adapted from https://github.com/facebookresearch/dadaptation/blob/main/dadaptation/dadapt_adagrad.py
"""


import math
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch
import torch.optim
import pdb
import logging

# `_params_t` is only needed for annotations: static type checkers get the
# alias from torch, while at runtime the name is simply bound to `Any` so that
# importing this module never requires the underscore-prefixed torch name.
if TYPE_CHECKING:
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any


class PSDASGD(torch.optim.Optimizer):
    """Parameter-scaled D-Adapt SGD optimizer.

    Each step rescales every parameter (and its gradient) element-wise by
    ``alpha = sqrt(sqrt(v_hat) + eps)``, where ``v_hat`` is the bias-corrected
    exponential moving average of the squared gradient. In the rescaled
    coordinates it performs a D-Adaptation style SGD update with iterate
    averaging, and finally maps everything back to the original coordinates.

    The scalars ``d``, ``eps``, ``lr``, ``momentum``, ``beta``, ``k``,
    ``log_every``, ``growth_rate`` and ``numerator`` are read from the FIRST
    parameter group only and are shared by all groups; ``weight_decay`` is
    honoured per group.

    Arguments:
        params: Iterable of parameters or dicts defining parameter groups.
        lr: Multiplier applied to the adapted step size. Must be positive;
            ``step`` additionally floors it at ``eps``.
        momentum: Weight kept on the current iterate in
            ``p = momentum * p + (1 - momentum) * z``. Must be non-negative.
        beta: Decay rate of the squared-gradient moving average.
        log_every: If positive, ``d`` statistics are logged with
            ``logging.info`` every ``log_every`` steps; ``0`` disables logging.
        weight_decay: Coefficient of the (rescaled) parameter added to the
            (rescaled) gradient.
        eps: Positive constant added to ``sqrt(v_hat)``; also the floor for
            ``lr``.
        d0: Positive initial value of the distance estimate ``d``.
        growth_rate: ``d`` may grow at most by this factor per step.
        clip_ratio: Stored on the instance as ``self.clip_ratio``; it is not
            read by ``step``.

    Raises:
        ValueError: If ``d0``, ``lr`` or ``eps`` is not positive, or
            ``momentum`` is negative.
    """

    def __init__(
        self,
        # String annotation: the alias is only needed by type checkers, so the
        # class definition does not evaluate the module-level name at runtime.
        params: "_params_t",
        lr: float = 1.0,
        momentum: float = 0.9,
        beta: float = 0.999,
        log_every: int = 0,
        weight_decay: float = 0.0,
        eps: float = 1e-8,
        d0: float = 1e-6,
        growth_rate: float = float('inf'),
        clip_ratio: float = 0.1,
    ):
        if d0 <= 0:
            raise ValueError("Invalid d0 value: {}".format(d0))
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if momentum < 0:
            raise ValueError(f"Momentum {momentum} must be non-negative")
        if eps <= 0:
            raise ValueError("Invalid epsilon value: {}".format(eps))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            beta=beta,
            eps=eps,
            weight_decay=weight_decay,
            log_every=log_every,
            d=d0,
            growth_rate=growth_rate,
            k=0,
            numerator=0.0,
            elr=0.0,
        )
        self.d0 = d0
        self.clip_ratio = clip_ratio
        super().__init__(params, defaults)

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Parameters whose ``grad`` is ``None`` are skipped entirely for this
        step (no state is created or touched for them). If no parameter has a
        gradient the call is a no-op apart from evaluating ``closure``.

        Note that ``p.grad`` is modified in place: it is rescaled, the weight
        decay term is added, and it is scaled back, so after the call it
        contains the decay contribution.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # Scalars shared by all groups live in (and are read from) group 0.
        shared = self.param_groups[0]
        d = shared['d']
        eps = shared['eps']
        lr = max(shared['lr'], eps)
        momentum = shared['momentum']
        beta = shared['beta']
        k = shared['k']
        log_every = shared['log_every']
        growth_rate = shared['growth_rate']
        numerator = shared['numerator']

        # Gather the parameters that actually received a gradient, lazily
        # creating their state on first use.
        active = []
        for group in self.param_groups:
            decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if 'alpha' not in state:
                    state['alpha'] = torch.zeros_like(p)
                    state['alpha_max'] = torch.zeros_like(p)
                    state['x0'] = torch.clone(p).detach()
                    state['z'] = torch.clone(p).detach()
                    state['gM'] = torch.clone(p.grad.abs())
                    state['sk'] = torch.zeros_like(p)
                    state['exp_avg_sqgrad'] = torch.zeros_like(p)
                active.append((p, state, decay))

        if not active:
            return loss

        # Update the second-moment estimate, the per-element scale `alpha`,
        # and the running maxima of |grad| and alpha.
        for p, state, _ in active:
            alpha = state['alpha']
            alpha_max = state['alpha_max']
            gM = state['gM']
            exp_avg_sqgrad = state['exp_avg_sqgrad']

            exp_avg_sqgrad.mul_(beta).addcmul_(p.grad, p.grad, value=1 - beta)

            unbiased_exp_avg_sqgrad = exp_avg_sqgrad / (1 - beta ** (k + 1))
            _sqalpha = unbiased_exp_avg_sqgrad.sqrt() + eps
            alpha.copy_(_sqalpha.sqrt())

            torch.maximum(p.grad.abs(), gM, out=gM)
            torch.maximum(alpha, alpha_max, out=alpha_max)

        # Move parameters, anchors and gradients into the alpha-scaled
        # coordinates, then apply weight decay there.
        for p, state, decay in active:
            alpha = state['alpha']
            p.data.mul_(alpha)
            state['x0'].mul_(alpha)
            state['z'].mul_(alpha)
            p.grad.div_(alpha)
            state['sk'].div_(alpha)
            p.grad.add_(p.data, alpha=decay)

        max_alpha_ratio = 0.0
        gM_sq = 0.0
        for _, state, _ in active:
            alpha = state['alpha']
            alpha_max = state['alpha_max']
            gM = state['gM']

            gM_sq += (gM / alpha_max).norm().pow(2).item()
            alpha_ratio = (alpha / alpha_max).max().item()
            max_alpha_ratio = max(alpha_ratio, max_alpha_ratio)

        gM_norm = math.sqrt(gM_sq)

        # gM_norm is exactly zero when every gradient seen so far is zero;
        # take a zero-length step instead of dividing by zero. (`== 0.0` keeps
        # NaN propagating exactly as the plain division would.)
        if gM_norm == 0.0:
            dlr = 0.0
        else:
            dlr = d * lr / gM_norm * max_alpha_ratio
        for group in self.param_groups:
            group['elr'] = dlr

        sksq = 0.0
        for p, state, _ in active:
            sk = state['sk']
            dx = state['x0'] - p.data

            numerator += dlr * torch.dot(p.grad.flatten(), dx.flatten()).item()
            sk.add_(p.grad, alpha=dlr)
            sksq += (sk * sk).sum().item()

        # No accumulated step yet -> no evidence to grow d.
        if sksq == 0.0:
            d_hat = 0.0
        else:
            d_hat = numerator / math.sqrt(sksq)
        d = max(d, min(d_hat, d * growth_rate))

        if log_every > 0 and k % log_every == 0:
            logging.info(f"d_hat: {d_hat}, d: {d}. sksq={sksq:1.1e} numerator={numerator:1.1e} lr={lr}")

        for group in self.param_groups:
            group['numerator'] = numerator
            group['d'] = d
            group['k'] = k + 1

        for p, state, _ in active:
            alpha = state['alpha']
            x0 = state['x0']
            z = state['z']
            sk = state['sk']

            # SGD step on the auxiliary sequence z, then average into p.
            z.add_(p.grad, alpha=-dlr)
            p.data.mul_(momentum).add_(z, alpha=1 - momentum)

            # Map everything back to the original coordinates.
            p.data.div_(alpha)
            x0.div_(alpha)
            z.div_(alpha)
            p.grad.mul_(alpha)
            sk.mul_(alpha)

        return loss
