import torch
from torch.optim.optimizer import Optimizer


import math
from torch.optim import Optimizer

class AProx(Optimizer):
    r"""
    Adaptive Proximal Gradient Optimiser (clean version).

    One step, for every parameter ``p`` that has a gradient ``g``:

    1. ``m   <- beta  * m + (1 - beta)  * g``          momentum EMA
    2. ``v   <- alpha * v + (1 - alpha) * g * g``      second-moment EMA
    3. ``p   <- p - lr * m_hat / (sqrt(v_hat) + eps)`` bias-corrected step
    4. ``p   <- p / (1 + lr * wd)``                    proximal weight decay
    5. ``ema <- gamma * ema + (1 - gamma) * p``        Polyak average

    ``m_hat`` / ``v_hat`` are ``m`` / ``v`` divided by ``1 - beta**t`` /
    ``1 - alpha**t``.  Step 4 is the closed-form proximal map of
    ``(wd / 2) * ||p||^2``.

    NOTE(review): the body of ``step`` was missing from this file (it only
    returned an undefined ``loss``).  The rule above is reconstructed from
    the documented hyper-parameters -- confirm it against the reference
    algorithm before relying on exact numerics.

    Args
    ----
    params : iterable
    lr     : base learning rate (α in paper)
    beta   : momentum coefficient   (β₁)
    alpha  : second-moment coeff.   (β₂ in Adam paper)
    gamma  : parameter-EMA coeff.
    wd     : weight-decay λ
    eps    : numerical fudge

    Raises
    ------
    ValueError
        If ``lr``, ``wd`` or ``eps`` is negative, ``beta`` / ``alpha`` is
        outside ``[0, 1)``, or ``gamma`` is outside ``[0, 1]``.
    """

    def __init__(self, params, lr=3e-4, beta=0.9, alpha=0.99,
                 gamma=0.999, wd=5e-4, eps=1e-8):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        # beta/alpha equal to 1 would make the bias corrections divide by 0.
        if not 0.0 <= beta < 1.0:
            raise ValueError(f"Invalid beta: {beta}")
        if not 0.0 <= alpha < 1.0:
            raise ValueError(f"Invalid alpha: {alpha}")
        if not 0.0 <= gamma <= 1.0:
            raise ValueError(f"Invalid gamma: {gamma}")
        if wd < 0.0:
            raise ValueError(f"Invalid weight decay: {wd}")
        if eps < 0.0:
            raise ValueError(f"Invalid eps: {eps}")
        defaults = dict(lr=lr, beta=beta, alpha=alpha,
                        gamma=gamma, wd=wd, eps=eps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimisation step.

        Args
        ----
        closure : callable, optional
            Re-evaluates the model and returns the loss.

        Returns
        -------
        The value returned by ``closure``, or ``None`` when no closure is
        given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on
            # so that it can call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr, beta, alpha = group["lr"], group["beta"], group["alpha"]
            gamma, wd, eps = group["gamma"], group["wd"], group["eps"]

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError(
                        "AProx does not support sparse gradients")

                state = self.state[p]
                if not state:
                    # Lazy initialisation on the first step of this tensor.
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)
                    state["ema"] = p.detach().clone()

                state["step"] += 1
                t = state["step"]
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                exp_avg.mul_(beta).add_(grad, alpha=1 - beta)
                exp_avg_sq.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

                # Both EMAs start at zero; undo that bias.
                bias1 = 1 - beta ** t
                bias2 = 1 - alpha ** t
                denom = (exp_avg_sq / bias2).sqrt_().add_(eps)
                p.addcdiv_(exp_avg, denom, value=-lr / bias1)

                # Proximal map of (wd / 2) * ||p||^2: a pure shrink.
                if wd != 0:
                    p.div_(1 + lr * wd)

                state["ema"].mul_(gamma).add_(p, alpha=1 - gamma)

        return loss

    # Convenience accessor -------------------------------------------------
    def ema_parameters(self):
        """Iterator over the smoothed (Polyak) parameters.

        Yields one tensor per parameter, in ``param_groups`` order.  A
        parameter that has not been stepped yet has no average, so its
        current (detached) value is yielded instead.
        """
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state.get(p)
                if state and "ema" in state:
                    yield state["ema"]
                else:
                    yield p.detach()


class AProx_NoProx(Optimizer):
    r"""
    Adaptive gradient optimiser with parameter EMA and *no* proximal step.

    One step, for every parameter ``p`` that has a gradient ``g``:

    1. ``g   <- g + wd * p``                           L2 penalty gradient
    2. ``m   <- beta  * m + (1 - beta)  * g``          momentum EMA
    3. ``v   <- alpha * v + (1 - alpha) * g * g``      second-moment EMA
    4. ``p   <- p - lr * m_hat / (sqrt(v_hat) + eps)`` bias-corrected step
    5. ``ema <- gamma * ema + (1 - gamma) * p``        Polyak average

    ``m_hat`` / ``v_hat`` are ``m`` / ``v`` divided by ``1 - beta**t`` /
    ``1 - alpha**t``.

    NOTE(review): the body of ``step`` was missing from this file (it only
    returned an undefined ``loss``).  The rule above is reconstructed from
    the class name and the documented hyper-parameters; in particular,
    folding ``wd`` into the gradient (rather than a proximal shrink) is an
    assumption -- confirm against the reference algorithm.

    Args
    ----
    params : iterable
    lr     : base learning rate (α in paper)
    beta   : momentum coefficient   (β₁)
    alpha  : second-moment coeff.   (β₂ in Adam paper)
    gamma  : parameter-EMA coeff.
    wd     : weight-decay λ
    eps    : numerical fudge

    Raises
    ------
    ValueError
        If ``lr``, ``wd`` or ``eps`` is negative, ``beta`` / ``alpha`` is
        outside ``[0, 1)``, or ``gamma`` is outside ``[0, 1]``.
    """

    def __init__(self, params, lr=3e-4, beta=0.9, alpha=0.99,
                 gamma=0.999, wd=5e-4, eps=1e-8):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        # beta/alpha equal to 1 would make the bias corrections divide by 0.
        if not 0.0 <= beta < 1.0:
            raise ValueError(f"Invalid beta: {beta}")
        if not 0.0 <= alpha < 1.0:
            raise ValueError(f"Invalid alpha: {alpha}")
        if not 0.0 <= gamma <= 1.0:
            raise ValueError(f"Invalid gamma: {gamma}")
        if wd < 0.0:
            raise ValueError(f"Invalid weight decay: {wd}")
        if eps < 0.0:
            raise ValueError(f"Invalid eps: {eps}")
        defaults = dict(lr=lr, beta=beta, alpha=alpha,
                        gamma=gamma, wd=wd, eps=eps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimisation step.

        Args
        ----
        closure : callable, optional
            Re-evaluates the model and returns the loss.

        Returns
        -------
        The value returned by ``closure``, or ``None`` when no closure is
        given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on
            # so that it can call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr, beta, alpha = group["lr"], group["beta"], group["alpha"]
            gamma, wd, eps = group["gamma"], group["wd"], group["eps"]

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError(
                        "AProx_NoProx does not support sparse gradients")

                state = self.state[p]
                if not state:
                    # Lazy initialisation on the first step of this tensor.
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)
                    state["ema"] = p.detach().clone()

                state["step"] += 1
                t = state["step"]
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                # Out-of-place add so the caller's p.grad is left untouched.
                if wd != 0:
                    grad = grad.add(p, alpha=wd)

                exp_avg.mul_(beta).add_(grad, alpha=1 - beta)
                exp_avg_sq.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

                # Both EMAs start at zero; undo that bias.
                bias1 = 1 - beta ** t
                bias2 = 1 - alpha ** t
                denom = (exp_avg_sq / bias2).sqrt_().add_(eps)
                p.addcdiv_(exp_avg, denom, value=-lr / bias1)

                state["ema"].mul_(gamma).add_(p, alpha=1 - gamma)

        return loss

    # Convenience accessor -------------------------------------------------
    def ema_parameters(self):
        """Iterator over the smoothed (Polyak) parameters.

        Yields one tensor per parameter, in ``param_groups`` order.  A
        parameter that has not been stepped yet has no average, so its
        current (detached) value is yielded instead.
        """
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state.get(p)
                if state and "ema" in state:
                    yield state["ema"]
                else:
                    yield p.detach()

class AProx_NoAdaptive(Optimizer):
    r"""
    Momentum + proximal weight decay, without second-moment scaling.

    One step, for every parameter ``p`` that has a gradient ``g``:

    1. ``m <- beta * m + (1 - beta) * g``      momentum EMA
    2. ``p <- p - lr * m / (1 - beta**t)``     bias-corrected step
    3. ``p <- p / (1 + lr * weight_decay)``    proximal weight decay

    Step 3 is the closed-form proximal map of
    ``(weight_decay / 2) * ||p||^2``.

    NOTE(review): the body of ``step`` was missing from this file (it only
    returned an undefined ``loss``).  The rule above is reconstructed from
    the class name and constructor arguments -- confirm it against the
    reference algorithm.

    Args
    ----
    params       : iterable
    lr           : base learning rate
    beta         : momentum coefficient
    weight_decay : weight-decay strength

    Raises
    ------
    ValueError
        If ``lr`` or ``weight_decay`` is negative, or ``beta`` is outside
        ``[0, 1)``.
    """

    def __init__(self, params, lr=1e-3, beta=0.9, weight_decay=0.0):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        # beta equal to 1 would make the bias correction divide by 0.
        if not 0.0 <= beta < 1.0:
            raise ValueError(f"Invalid beta: {beta}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight decay: {weight_decay}")
        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay)
        # Zero-argument super(): the previous explicit class name
        # (APGO_NoAdaptive) does not exist and raised NameError.
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimisation step.

        Args
        ----
        closure : callable, optional
            Re-evaluates the model and returns the loss.

        Returns
        -------
        The value returned by ``closure``, or ``None`` when no closure is
        given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on
            # so that it can call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr, beta = group["lr"], group["beta"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError(
                        "AProx_NoAdaptive does not support sparse gradients")

                state = self.state[p]
                if not state:
                    # Lazy initialisation on the first step of this tensor.
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)

                state["step"] += 1
                exp_avg = state["exp_avg"]
                exp_avg.mul_(beta).add_(grad, alpha=1 - beta)

                # The EMA starts at zero; undo that bias.
                bias1 = 1 - beta ** state["step"]
                p.add_(exp_avg, alpha=-lr / bias1)

                # Proximal map of (weight_decay / 2) * ||p||^2: a pure shrink.
                if weight_decay != 0:
                    p.div_(1 + lr * weight_decay)

        return loss


class AProx_NoMomentum(Optimizer):
    r"""
    Second-moment scaling + proximal weight decay, without momentum.

    One step, for every parameter ``p`` that has a gradient ``g``:

    1. ``v <- alpha * v + (1 - alpha) * g * g``        second-moment EMA
    2. ``p <- p - lr * g / (sqrt(v_hat) + epsilon)``   scaled raw-gradient step
    3. ``p <- p / (1 + lr * weight_decay)``            proximal weight decay

    ``v_hat`` is ``v / (1 - alpha**t)``.  Step 3 is the closed-form proximal
    map of ``(weight_decay / 2) * ||p||^2``.

    NOTE(review): the body of ``step`` was missing from this file (it only
    returned an undefined ``loss``).  The rule above is reconstructed from
    the class name and constructor arguments -- confirm it against the
    reference algorithm.

    Args
    ----
    params       : iterable
    lr           : base learning rate
    alpha        : second-moment coefficient
    epsilon      : numerical fudge added to the denominator
    weight_decay : weight-decay strength

    Raises
    ------
    ValueError
        If ``lr``, ``epsilon`` or ``weight_decay`` is negative, or ``alpha``
        is outside ``[0, 1)``.
    """

    def __init__(self, params, lr=1e-3, alpha=0.99, epsilon=1e-8, weight_decay=0.0):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        # alpha equal to 1 would make the bias correction divide by 0.
        if not 0.0 <= alpha < 1.0:
            raise ValueError(f"Invalid alpha: {alpha}")
        if epsilon < 0.0:
            raise ValueError(f"Invalid epsilon: {epsilon}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight decay: {weight_decay}")
        defaults = dict(lr=lr, alpha=alpha, epsilon=epsilon, weight_decay=weight_decay)
        # Zero-argument super(): the previous explicit class name
        # (APGO_NoMomentum) does not exist and raised NameError.
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step without momentum.

        Args
        ----
        closure : callable, optional
            Re-evaluates the model and returns the loss.

        Returns
        -------
        The value returned by ``closure``, or ``None`` when no closure is
        given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on
            # so that it can call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr, alpha = group["lr"], group["alpha"]
            epsilon, weight_decay = group["epsilon"], group["weight_decay"]

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError(
                        "AProx_NoMomentum does not support sparse gradients")

                state = self.state[p]
                if not state:
                    # Lazy initialisation on the first step of this tensor.
                    state["step"] = 0
                    state["exp_avg_sq"] = torch.zeros_like(p)

                state["step"] += 1
                exp_avg_sq = state["exp_avg_sq"]
                exp_avg_sq.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

                # The EMA starts at zero; undo that bias.
                bias2 = 1 - alpha ** state["step"]
                denom = (exp_avg_sq / bias2).sqrt_().add_(epsilon)
                p.addcdiv_(grad, denom, value=-lr)

                # Proximal map of (weight_decay / 2) * ||p||^2: a pure shrink.
                if weight_decay != 0:
                    p.div_(1 + lr * weight_decay)

        return loss



