import os
import sys

from graph_learning.config import Config

class ExperimentConfig(Config):
    """Base config for sub-commands that hand control to a trainer.

    Subclasses provide a ``command`` property whose value is passed to
    ``trainer.run`` by ``build``.
    """

    def __init__(self, args, context):
        super().__init__(args, context)
        # Keep a direct handle on the trainer supplied by the context so that
        # ``build`` can dispatch to it.
        self.trainer = context.trainer

    @classmethod
    def build(cls, args, context):
        """Construct the config, run the trainer, then hard-exit the process.

        Args:
            args: arguments forwarded to ``Config.__init__``.
            context: object exposing a ``trainer`` attribute.

        On success this never returns: the process is terminated with exit
        status 0 via ``os._exit``. If ``trainer.run`` raises, the exception
        propagates and the hard exit is not reached.
        """
        config = cls(args, context)
        config.trainer.run(config.command, config)
        # NOTE(review): os._exit skips interpreter cleanup (atexit handlers,
        # finalizers, joining non-daemon threads); presumably this is
        # deliberate so lingering workers cannot keep the process alive --
        # confirm. It ALSO skips flushing stdio buffers, so flush explicitly
        # to avoid losing output buffered during the run.
        for stream in (sys.stdout, sys.stderr):
            if stream is None:
                continue
            try:
                stream.flush()
            except (OSError, ValueError):
                # Closed or broken stream: flushing is best-effort only; the
                # process must still exit.
                pass
        os._exit(0)

@Config.register('train', help='Run training.')
class TrainConfig(ExperimentConfig):
    """``train`` sub-command config.

    Construction is inherited unchanged from ``ExperimentConfig``; this class
    only contributes the command name.
    """

    @classmethod
    def define_parser(cls, parser):
        """Register CLI options; ``train`` adds none of its own."""
        super().define_parser(parser)

    @property
    def command(self):
        """Command name handed to ``trainer.run`` by ``ExperimentConfig.build``."""
        return 'train'


@Config.register('test', help='Run testing')
class TestConfig(ExperimentConfig):
    """``test`` sub-command config.

    Construction is inherited unchanged from ``ExperimentConfig``; this class
    adds a ``--load-epoch`` option and the command name.
    """

    @classmethod
    def define_parser(cls, parser):
        """Register CLI options: the parent's, plus ``--load-epoch``.

        No ``type`` or ``default`` is given. Assuming ``parser`` is an
        ``argparse`` parser, the value is a string, or ``None`` when omitted.
        """
        super().define_parser(parser)
        # NOTE(review): presumably selects which saved epoch to evaluate --
        # confirm against how the trainer consumes this option.
        parser.add_argument('--load-epoch')

    @property
    def command(self):
        """Command name handed to ``trainer.run`` by ``ExperimentConfig.build``."""
        return 'test'
