import torch
from torch.optim import Optimizer
import numpy as np
import math

from typing import List, Tuple, Dict, Optional, Callable, Union, Any, Iterable
from typing_extensions import ParamSpec, Self, TypeAlias
from torch import Tensor


# Type of the ``params`` constructor argument of the optimizers in this file:
# an iterable of tensors, an iterable of parameter-group dicts, or an iterable
# of (name, tensor) pairs.
ParamsT: TypeAlias = Union[
    Iterable[torch.Tensor], Iterable[dict[str, Any]], Iterable[tuple[str, torch.Tensor]]
]


class AdamBiasCorrectedEps(torch.optim.Optimizer):
    """Adam variant in which ``eps`` is added before bias correction.

    For each parameter that has a gradient, ``step`` applies

        update = -lr * sqrt(1 - beta2**t) / (1 - beta1**t) * m / (sqrt(v) + eps)

    where ``m`` and ``v`` are the exponential moving averages of the gradient
    and of the squared gradient, and ``t`` is the per-parameter step count.
    Since ``eps`` is added to the *uncorrected* ``sqrt(v)``, this is the same
    as ``-lr * m_hat / (sqrt(v_hat) + eps / sqrt(1 - beta2**t))`` in terms of
    the bias-corrected moments. With ``amsgrad=True`` the running element-wise
    maximum of ``v`` is used in place of ``v``.

    Weight decay is L2-style: ``weight_decay * p`` is added to the gradient
    before the moment estimates are updated.

    All hyperparameters are stored per parameter group, so per-group
    overrides and learning-rate schedulers (which edit
    ``param_groups[i]['lr']``) take effect. The ``lr``, ``weight_decay``,
    ``betas`` and ``eps`` attributes only mirror the constructor arguments;
    the values in ``param_groups`` are the ones used by ``step``.

    Args:
        params: Parameters to optimize, or dicts defining parameter groups.
        weight_decay: L2 penalty coefficient (must be >= 0).
        betas: Decay rates for the first and second moment estimates,
            each in ``[0, 1)``.
        lr: Learning rate (must be >= 0).
        eps: Term added to the denominator (must be >= 0).
        amsgrad: Whether to use the AMSGrad running maximum of ``v``.

    Raises:
        ValueError: If a hyperparameter is outside its valid range.
    """

    def __init__(
        self,
        params: ParamsT,
        weight_decay: float = 0.001,
        betas: Tuple[float, float] = (0.9, 0.999),
        lr: float = 1e-3,
        eps: float = np.finfo(np.float32).eps,
        amsgrad: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        # beta == 1 would make the bias-correction term zero (division by zero).
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        # The default eps is a NumPy scalar; store a plain float so that the
        # param groups (and therefore state_dict()) hold only builtin types.
        eps = float(eps)

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        super().__init__(params, defaults=defaults)

        # Mirrors of the constructor arguments, kept for callers that read
        # them; step() uses the per-group values instead.
        self.lr = lr
        self.weight_decay = weight_decay
        self.betas = betas
        self.eps = eps

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore state, filling in hyperparameters missing from old groups."""
        super().__setstate__(state)
        # Param groups saved before lr/eps/weight_decay were stored per group
        # lack those keys; fall back to this optimizer's defaults for them.
        for group in self.param_groups:
            for key, value in self.defaults.items():
                group.setdefault(key, value)

    def _compute_adam_step(self, p: Tensor, group: Dict[str, Any]) -> Optional[Tensor]:
        """Update the moment estimates of ``p`` and return its additive update.

        Args:
            p: Parameter whose optimizer state is advanced by one step.
            group: The parameter group ``p`` belongs to; supplies ``lr``,
                ``betas``, ``eps``, ``weight_decay`` and ``amsgrad``.

        Returns:
            The tensor to add to ``p.data``, or ``None`` if ``p`` has no
            gradient (in which case no state is touched).
        """
        if p.grad is None:
            return None

        grad = p.grad.data
        amsgrad = group['amsgrad']
        beta1, beta2 = group['betas']
        lr = group['lr']
        eps = group['eps']
        weight_decay = group['weight_decay']
        state = self.state[p]

        # Lazy state initialization on the first step for this parameter.
        if 'step' not in state:
            state['step'] = 0
            state['exp_avg'] = torch.zeros_like(p.data)
            state['exp_avg_sq'] = torch.zeros_like(p.data)
            if amsgrad:
                state['max_exp_avg_sq'] = torch.zeros_like(p.data)

        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
        state['step'] += 1

        if weight_decay != 0:
            # Out-of-place add: p.grad itself must not be modified.
            grad = grad.add(p.data, alpha=weight_decay)

        # Update the first and second moment estimates in place.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            max_exp_avg_sq = state['max_exp_avg_sq']
            # Keep the element-wise maximum of all second-moment estimates.
            torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            denom = max_exp_avg_sq.sqrt().add_(eps)
        else:
            denom = exp_avg_sq.sqrt().add_(eps)

        # Bias correction is folded into the scalar step size, so eps above
        # is added to the uncorrected sqrt(v).
        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1

        return -step_size * exp_avg / denom

    def step(self, closure=None) -> Optional[float]:
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is called once before the update.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                update = self._compute_adam_step(p, group)
                if update is None:
                    # Parameter received no gradient; leave it untouched.
                    continue
                p.data.add_(update)

        return loss


class Adam(torch.optim.Optimizer):
    """Adam optimizer with ``eps`` added after bias correction.

    For each parameter that has a gradient, ``step`` applies

        update = -lr / (1 - beta1**t) * m / (sqrt(v) / sqrt(1 - beta2**t) + eps)

    i.e. ``-lr * m_hat / (sqrt(v_hat) + eps)`` in terms of the bias-corrected
    moments, where ``m`` and ``v`` are the exponential moving averages of the
    gradient and of the squared gradient and ``t`` is the per-parameter step
    count. With ``amsgrad=True`` the running element-wise maximum of ``v`` is
    used in place of ``v``.

    Weight decay is L2-style: ``weight_decay * p`` is added to the gradient
    before the moment estimates are updated.

    All hyperparameters are stored per parameter group, so per-group
    overrides and learning-rate schedulers (which edit
    ``param_groups[i]['lr']``) take effect. The ``lr``, ``weight_decay``,
    ``betas`` and ``eps`` attributes only mirror the constructor arguments;
    the values in ``param_groups`` are the ones used by ``step``.

    Args:
        params: Parameters to optimize, or dicts defining parameter groups.
        weight_decay: L2 penalty coefficient (must be >= 0).
        betas: Decay rates for the first and second moment estimates,
            each in ``[0, 1)``.
        lr: Learning rate (must be >= 0).
        eps: Term added to the denominator (must be >= 0).
        amsgrad: Whether to use the AMSGrad running maximum of ``v``.

    Raises:
        ValueError: If a hyperparameter is outside its valid range.
    """

    def __init__(
        self,
        params: ParamsT,
        weight_decay: float = 0.001,
        betas: Tuple[float, float] = (0.9, 0.999),
        lr: float = 1e-3,
        eps: float = np.finfo(np.float32).eps,
        amsgrad: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        # beta == 1 would make the bias-correction term zero (division by zero).
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        # The default eps is a NumPy scalar; store a plain float so that the
        # param groups (and therefore state_dict()) hold only builtin types.
        eps = float(eps)

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        super().__init__(params, defaults=defaults)

        # Mirrors of the constructor arguments, kept for callers that read
        # them; step() uses the per-group values instead.
        self.lr = lr
        self.weight_decay = weight_decay
        self.betas = betas
        self.eps = eps

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore state, filling in hyperparameters missing from old groups."""
        super().__setstate__(state)
        # Param groups saved before lr/eps/weight_decay were stored per group
        # lack those keys; fall back to this optimizer's defaults for them.
        for group in self.param_groups:
            for key, value in self.defaults.items():
                group.setdefault(key, value)

    def _compute_adam_step(self, p: Tensor, group: Dict[str, Any]) -> Optional[Tensor]:
        """Update the moment estimates of ``p`` and return its additive update.

        Args:
            p: Parameter whose optimizer state is advanced by one step.
            group: The parameter group ``p`` belongs to; supplies ``lr``,
                ``betas``, ``eps``, ``weight_decay`` and ``amsgrad``.

        Returns:
            The tensor to add to ``p.data``, or ``None`` if ``p`` has no
            gradient (in which case no state is touched).
        """
        if p.grad is None:
            return None

        grad = p.grad.data
        amsgrad = group['amsgrad']
        beta1, beta2 = group['betas']
        lr = group['lr']
        eps = group['eps']
        weight_decay = group['weight_decay']
        state = self.state[p]

        # Lazy state initialization on the first step for this parameter.
        if 'step' not in state:
            state['step'] = 0
            state['exp_avg'] = torch.zeros_like(p.data)
            state['exp_avg_sq'] = torch.zeros_like(p.data)
            if amsgrad:
                state['max_exp_avg_sq'] = torch.zeros_like(p.data)

        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
        state['step'] += 1

        if weight_decay != 0:
            # Out-of-place add: p.grad itself must not be modified.
            grad = grad.add(p.data, alpha=weight_decay)

        # Update the first and second moment estimates in place.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            max_exp_avg_sq = state['max_exp_avg_sq']
            # Keep the element-wise maximum of all second-moment estimates.
            torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            denom = max_exp_avg_sq.sqrt()
        else:
            denom = exp_avg_sq.sqrt()

        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']

        step_size = lr / bias_correction1
        # eps is added after the second moment has been bias-corrected.
        denom = denom / math.sqrt(bias_correction2) + eps

        return -step_size * (exp_avg / denom)

    def step(self, closure=None) -> Optional[float]:
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is called once before the update.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                update = self._compute_adam_step(p, group)
                if update is None:
                    # Parameter received no gradient; leave it untouched.
                    continue
                p.data.add_(update)

        return loss
