from pytorch_lightning.cli import LightningCLI
from symo.nanogpt_pl import NanoGPTLitModule
from symo.data import ShakespeareDataModule


class NanoGPTCLI(LightningCLI):
    """LightningCLI subclass used as the NanoGPT training entry point.

    The model and datamodule classes are supplied by the caller at
    construction time. The only customization here is an override of the
    parser defaults for the Trainer's gradient-clipping options.
    """

    def add_arguments_to_parser(self, parser):
        """Pin the Trainer's gradient-clipping defaults to ``None``.

        Per the Lightning ``Trainer`` documentation, ``gradient_clip_val=None``
        disables Lightning's built-in gradient clipping. Values passed on the
        command line or in a config file still override these defaults.

        Args:
            parser: The argument parser built by LightningCLI. Its
                ``set_defaults`` receives a single mapping keyed by dotted
                option names.
        """
        # NOTE(review): presumably clipping, if any, is done inside the
        # LightningModule rather than by the Trainer — confirm before
        # re-enabling Trainer-level clipping.
        clip_option_names = (
            "trainer.gradient_clip_val",
            "trainer.gradient_clip_algorithm",
        )
        # dict.fromkeys maps every option name to None.
        parser.set_defaults(dict.fromkeys(clip_option_names))


if __name__ == "__main__":
    # With run=True, LightningCLI parses the command line and itself executes
    # the requested subcommand (e.g. `fit`, `validate`, `test`, `predict`).
    # No explicit `cli.trainer.fit(...)` call is needed; adding one after
    # construction would start a second fit.
    cli = NanoGPTCLI(
        model_class=NanoGPTLitModule,
        datamodule_class=ShakespeareDataModule,
        # Disable SaveConfigCallback so the resolved config is not written
        # into the run's log directory.
        save_config_callback=None,
        # Do not auto-add --optimizer / --lr_scheduler options; the model's
        # own configure_optimizers() is used instead.
        auto_configure_optimizers=False,
        run=True,
    )
