# Copyright (c) Microsoft Corporation.
# The file is modified based on the original Graphormer's source code.
# Copyright (c) 2022 Tianyu Wen
# Licensed under the MIT License.

import time
from argparse import ArgumentParser

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger

from data import GraphDataModule
from model import Graphormer


def cli_main():
    """Parse CLI arguments, build the data module and model, then train and test.

    Arguments are collected from three sources: the ``pl.Trainer`` argparse
    options, ``Graphormer.add_model_specific_args`` and
    ``GraphDataModule.add_argparse_args``. The model is fitted with
    ``trainer.fit`` and then evaluated once with ``trainer.test`` on the same
    data module. Metrics go to a ``CSVLogger`` rooted at ``./result``, and the
    learning rate is logged every optimizer step.
    """
    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = Graphormer.add_model_specific_args(parser)
    parser = GraphDataModule.add_argparse_args(parser)
    args = parser.parse_args()
    # NOTE(review): ``max_steps`` is stored on ``args`` but the Trainer below
    # is built from explicit keyword arguments, so it is NOT forwarded there;
    # training length is governed by ``max_epochs``. Kept for parity with
    # anything else that may read ``args.max_steps``.
    args.max_steps = args.tot_updates + 1
    # ``--test`` / ``--validate`` only silence the verbose dumps in this
    # function; fit followed by test always runs below.
    if not args.test and not args.validate:
        print(args)
    pl.seed_everything(args.seed)

    # ------------
    # data
    # ------------
    dm = GraphDataModule.from_argparse_args(args)

    # ------------
    # model
    # ------------
    model = Graphormer(
        n_layers=args.n_layers,
        num_heads=args.num_heads,
        hidden_dim=args.hidden_dim,
        attention_dropout_rate=args.attention_dropout_rate,
        dropout_rate=args.dropout_rate,
        intput_dropout_rate=args.intput_dropout_rate,
        weight_decay=args.weight_decay,
        ffn_dim=args.ffn_dim,
        dataset_name=dm.dataset_name,
        warmup_updates=args.warmup_updates,
        tot_updates=args.tot_updates,
        peak_lr=args.peak_lr,
        end_lr=args.end_lr,
        edge_type=args.edge_type,
        multi_hop_max_dist=args.multi_hop_max_dist,
        flag=args.flag,
        flag_m=args.flag_m,
        flag_step_size=args.flag_step_size,
        epochs=args.max_epochs,
        # The same ``end_lr`` value is passed again as ``lr_min``.
        lr_min=args.end_lr,
    )
    if not args.test and not args.validate:
        print(model)
    print('total params:', sum(p.numel() for p in model.parameters()))

    # ------------
    # training
    # ------------
    # Version tag: <dataset>_<k>_<launch timestamp>, so repeated runs get
    # distinct log directories.
    version = f"{dm.dataset_name}_{args.k}_{time.strftime('%Y_%m_%d_%H_%M_%S')}"
    csv_logger = CSVLogger(save_dir='./result',
                           name='graphormer',
                           version=version,
                           )
    trainer = Trainer(default_root_dir='./result',
                      max_epochs=args.max_epochs,
                      gpus=[0],
                      benchmark=True,
                      logger=[csv_logger],
                      # Register callbacks via the constructor so Lightning's
                      # callback connector handles them at init time, rather
                      # than mutating ``trainer.callbacks`` after the fact.
                      callbacks=[LearningRateMonitor(logging_interval='step')],
                      enable_progress_bar=True,
                      num_sanity_val_steps=0,
                      precision=16,
                      )

    trainer.fit(model, datamodule=dm)
    # Evaluates the in-memory model as it stands at the end of ``fit``.
    trainer.test(model, datamodule=dm)


if __name__ == '__main__':
    # Run the full train-then-test pipeline only when executed as a script,
    # not when this module is imported.
    cli_main()
