import torch

from graph_learning.module import ModuleConfig, register_module

from functools import partial

@ModuleConfig.register('adam',
                       help='[Optimizer] Adam')
class AdamOptimizerModuleConfig(ModuleConfig):
    """Module config that registers a ``torch.optim.Adam`` factory.

    ``build`` does not instantiate the optimizer; it registers a
    ``functools.partial`` of ``torch.optim.Adam`` (with the configured
    hyper-parameters bound) under ``config.name`` in ``context``. The
    remaining arguments -- presumably the model parameters; confirm against
    the caller -- are supplied by whoever retrieves the module later.
    """

    @classmethod
    def define_parser(cls, parser):
        """Add Adam's ``--lr`` and ``--weight-decay`` options to *parser*.

        ``--lr`` has no default. When it is omitted, ``build`` leaves the
        learning rate to ``torch.optim.Adam``'s own default.
        """
        super().define_parser(parser)
        parser.add_argument('--lr', type=float)
        parser.add_argument('--weight-decay', type=float, default=0)

    @classmethod
    def build(cls, args, context):
        """Build the config from *args* and register the Adam factory in *context*."""
        config = AdamOptimizerModuleConfig(args, context)
        optim_kwargs = {'weight_decay': config.weight_decay}
        # Binding lr=None would make torch.optim.Adam fail its lr range check
        # with an opaque TypeError, and only once the optimizer is finally
        # constructed. Omit it instead so Adam's default learning rate is used.
        if config.lr is not None:
            optim_kwargs['lr'] = config.lr
        content = partial(torch.optim.Adam, **optim_kwargs)
        register_module(context, config.name, content)

@ModuleConfig.register('sgd',
                       help='[Optimizer] Sgd')
class SgdOptimizerModuleConfig(ModuleConfig):
    """Module config that registers a ``torch.optim.SGD`` factory.

    ``build`` does not instantiate the optimizer; it registers a
    ``functools.partial`` of ``torch.optim.SGD`` (with the configured
    hyper-parameters bound) under ``config.name`` in ``context``. The
    remaining arguments -- presumably the model parameters; confirm against
    the caller -- are supplied by whoever retrieves the module later.
    """

    @classmethod
    def define_parser(cls, parser):
        """Add SGD's ``--lr`` and ``--weight-decay`` options to *parser*.

        ``--lr`` has no default. When it is omitted, ``build`` defers to
        ``torch.optim.SGD``: recent torch versions use their default learning
        rate, while older ones raise an explicit "lr is required" error.
        """
        super().define_parser(parser)
        parser.add_argument('--lr', type=float)
        parser.add_argument('--weight-decay', type=float, default=0)

    @classmethod
    def build(cls, args, context):
        """Build the config from *args* and register the SGD factory in *context*."""
        config = SgdOptimizerModuleConfig(args, context)
        optim_kwargs = {'weight_decay': config.weight_decay}
        # Binding lr=None makes torch.optim.SGD fail on ``lr < 0.0`` with an
        # opaque TypeError at construction time; omit it and let torch decide.
        if config.lr is not None:
            optim_kwargs['lr'] = config.lr
        content = partial(torch.optim.SGD, **optim_kwargs)
        register_module(context, config.name, content)

from .lamb import Lamb

@ModuleConfig.register('lamb',
                       help='[Optimizer] Lamb')
class LambOptimizerModuleConfig(ModuleConfig):
    """Module config that registers a factory for the local ``Lamb`` optimizer.

    ``build`` binds ``lr`` and ``weight_decay`` onto ``Lamb`` with
    ``functools.partial`` and registers the result under ``config.name`` in
    ``context``. The optimizer itself is not created here.

    NOTE(review): ``--lr`` has no default, so ``lr=None`` is bound when the
    flag is omitted. Whether ``Lamb`` accepts that depends on its signature
    (defined in ``.lamb``, not visible here) -- confirm.
    """

    @classmethod
    def define_parser(cls, parser):
        """Register the Lamb hyper-parameter flags after the base options."""
        super().define_parser(parser)
        # Flags are added in this fixed order; only weight decay has a default.
        for flag, extra in (('--lr', {}), ('--weight-decay', {'default': 0})):
            parser.add_argument(flag, type=float, **extra)

    @classmethod
    def build(cls, args, context):
        """Build the config from *args* and register the Lamb factory in *context*."""
        cfg = LambOptimizerModuleConfig(args, context)
        hyper_params = dict(lr=cfg.lr, weight_decay=cfg.weight_decay)
        register_module(context, cfg.name, partial(Lamb, **hyper_params))
