from fairseq.optim import FairseqOptimizer, register_optimizer
from .AdaBelief import AdaBelief
from .agd import AGD

@register_optimizer("agd")
class FairseqAGD(FairseqOptimizer):
    """Fairseq optimizer wrapper, registered as ``agd``, that delegates to ``AGD``."""

    def __init__(self, args, params):
        """Create the wrapped ``AGD`` optimizer.

        Args:
            args: parsed option namespace, handed to the base class
                (presumably exposed afterwards as ``self.args`` -- confirm
                against ``FairseqOptimizer``). ``optimizer_config`` reads
                ``lr``, ``agd_betas``, ``agd_delta`` and ``weight_decay``
                from it.
            params: forwarded unchanged as the first argument of ``AGD``.
        """
        super().__init__(args)
        self._optimizer = AGD(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Register the AGD-specific command-line options on ``parser``."""
        parser.add_argument('--agd-betas', default='(0.9, 0.999)', metavar='B',
                            help='betas for AGD optimizer')
        parser.add_argument('--agd-delta', type=float, default=1e-14, metavar='D',
                            help='delta for AGD optimizer')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')

    @property
    def optimizer_config(self):
        """Return the keyword arguments passed to ``AGD``.

        Only the first entry of ``args.lr`` is used as the learning rate.
        """
        return {
            'lr': self.args.lr[0],
            # Must match the dest argparse derives from '--agd-betas'
            # (``agd_betas``); no option in this class defines any other
            # betas attribute.
            # NOTE(security): the betas option arrives as a string and is
            # eval()'d, so it must only ever come from a trusted command line.
            'betas': eval(self.args.agd_betas),
            'delta': self.args.agd_delta,
            'weight_decay': self.args.weight_decay,
        }

@register_optimizer("adabelief")
class FairseqAdaBelief(FairseqOptimizer):
    """Fairseq optimizer wrapper, registered as ``adabelief``, around ``AdaBelief``."""

    def __init__(self, args, params):
        """Create the wrapped ``AdaBelief`` optimizer.

        Args:
            args: parsed option namespace, handed to the base class
                (presumably exposed afterwards as ``self.args`` -- confirm
                against ``FairseqOptimizer``).
            params: forwarded unchanged as the first argument of
                ``AdaBelief``.
        """
        super().__init__(args)
        adabelief_kwargs = self.optimizer_config
        self._optimizer = AdaBelief(params, **adabelief_kwargs)

    @staticmethod
    def add_args(parser):
        """Register the AdaBelief-specific command-line options on ``parser``."""
        parser.add_argument(
            '--adabelief-betas',
            metavar='B',
            default='(0.9, 0.999)',
            help='betas for AdaBelief optimizer',
        )
        parser.add_argument(
            '--adabelief-eps',
            metavar='D',
            type=float,
            default=1e-16,
            help='epsilon for AdaBelief optimizer',
        )
        parser.add_argument(
            '--weight-decay', '--wd',
            metavar='WD',
            type=float,
            default=1e-4,
            help='weight decay',
        )

    @property
    def optimizer_config(self):
        """Return the keyword arguments passed to ``AdaBelief``.

        Only the first entry of ``args.lr`` is used as the learning rate.
        """
        # NOTE(security): the betas option arrives as a string and is
        # eval()'d, so it must only ever come from a trusted command line.
        return dict(
            lr=self.args.lr[0],
            betas=eval(self.args.adabelief_betas),
            eps=self.args.adabelief_eps,
            weight_decay=self.args.weight_decay,
        )
