"""
Entry point that trains (with validation) and then tests a StructureClassifier
using the PyTorch Lightning Trainer, logging the run to Weights & Biases.
"""
import sys
from argparse import ArgumentParser
from pprint import pprint

from pytorch_lightning import Trainer, seed_everything

# The seed_everything(123) call after the imports sets the seeds for numpy,
# torch, etc.; a fixed seed is required for DDP to work well.
from pytorch_lightning.loggers import WandbLogger
from zendo.learning.models.structure_classifier import StructureClassifier
from zendo.utils import load_envs

# Seed every RNG at import time, before any model is constructed.
seed_everything(seed=123)


def main(args):
    """Fit a StructureClassifier, then run the test loop, logging to W&B.

    Args:
        args: the parsed ``argparse.Namespace``. It is passed to the model as
            ``hparams`` and to ``Trainer.from_argparse_args`` to configure
            the trainer.
    """
    classifier = StructureClassifier(hparams=args)

    # W&B logger for the "zendo" project. log_model=True presumably uploads
    # checkpoints to the run (see WandbLogger docs) -- confirm for this version.
    run_logger = WandbLogger(project="zendo", log_model=True)
    # Accessing .experiment starts the W&B run; watch() hooks the classifier
    # so W&B can record its gradients/parameters (log="all").
    run_logger.experiment.watch(classifier, log="all")

    pl_trainer = Trainer.from_argparse_args(args, logger=run_logger)
    pl_trainer.fit(classifier)

    # NOTE(review): no model is passed here, so Lightning decides which
    # weights are tested (last fitted model vs. best checkpoint depends on
    # the installed Lightning version) -- confirm this is intended.
    pl_trainer.test()


if __name__ == "__main__":
    # Load environment configuration before anything reads it.
    load_envs()

    # NOTE(review): add_help=False disables -h/--help for this script. It is
    # presumably kept so parsers built from this one via parents=[...] do not
    # clash on -h; confirm against add_model_specific_args before enabling.
    cli = ArgumentParser(add_help=False)

    # Script-level options describing the dataset / structure size.
    cli.add_argument("--dataset_name", default="DATASET_S6_startPROP", type=str)
    cli.add_argument("--structure_dim", default=6, type=int)

    # Lightning registers its Trainer flags (read back by
    # Trainer.from_argparse_args in main); override the epoch default.
    cli = Trainer.add_argparse_args(cli)
    cli.set_defaults(max_epochs=10)

    # The LightningModule declares its own hyperparameters.
    cli = StructureClassifier.add_model_specific_args(cli)

    cli_args = cli.parse_args()
    main(cli_args)
