#!/usr/bin/python3
"""
Main
    Model
    Config/Model/Dataset
    Train/Validate/Test
    Main/Hydra/Fold/Train
========================

"""
import argparse
import yaml
from yacs.config import CfgNode as CN
import os

import torch
from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.graphgym.config import load_cfg
from torch_geometric.seed import seed_everything

from pytorch_lightning import Trainer

from models import MODEL_REGISTRY
from register import LOADER_REGISTRY
from optim.lightning import LightningWrapper
from optim.metrics import register_metrics


"""
=========
Configure
    Hardware > Seed > Logging
"""

def configure(cfg):
    """Apply global hardware and reproducibility settings from ``cfg``.

    ``cfg`` is mutated in place: ``cfg['device']`` is replaced by ``'cpu'``
    when CUDA is unavailable, so the Trainer built afterwards never requests
    an accelerator that is not present.

    Args:
        cfg: Root config mapping (e.g. a yacs ``CfgNode``). Reads
            ``cfg['device']``, ``cfg['trainer']['deterministic']`` and the
            optional ``cfg['seed']`` (default 42, used only in deterministic
            mode).
    """
    # Hardware
    #---------
    cfg['device'] = cfg['device'] if torch.cuda.is_available() else 'cpu'
    if cfg['trainer']['deterministic']:
        # Fix the global RNG seed and force deterministic cuDNN kernels.
        seed_everything(cfg.get('seed', 42))
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        # Let cuDNN autotune the fastest kernels; results may vary per run.
        torch.backends.cudnn.benchmark = True
    # TF32 is disabled unconditionally so matmuls/convolutions stay in FP32.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    # Logging
    #--------
    # TODO: no logging setup is performed here yet.



"""
====
Load
    Dataset > Model > Lightning
"""

def load(cfg):
    """Build the data loaders, the Lightning-wrapped model and the Trainer.

    Args:
        cfg: Root config mapping. Reads the ``loader``, ``optim``, ``model``
            and ``trainer`` sub-configs plus ``device``, ``path_log``,
            ``path_pt`` and ``load_pt``. ``cfg['model']`` is mutated: its
            ``degrees_hist`` entry is filled in from the dataset loaders.
            ``cfg['trainer']`` may optionally set ``gradient_clip_val``
            (default 1.0).

    Returns:
        ``(trainer, model, train_loader, val_loader, test_loader)``.
    """
    import warnings

    # Dataset
    #--------
    # The registered factory returns a mapping that must contain at least
    # 'train', 'val', 'test' and 'degrees_hist'.
    loader_cfg = cfg['loader']
    loaders = LOADER_REGISTRY.get(loader_cfg['dataset'])(loader_cfg)

    # Metrics
    #--------
    register_metrics(cfg['optim'], loaders)

    # Model
    #------
    model_config = cfg['model']
    # The model constructor receives the dataset's degree histogram through
    # its config.
    model_config['degrees_hist'] = loaders['degrees_hist']
    model = MODEL_REGISTRY.get(model_config['name'])(model_config)
    model = LightningWrapper(model, cfg['optim'])
    if cfg['load_pt']:
        ckpt_path = cfg['path_pt']
        if os.path.exists(ckpt_path):
            # map_location='cpu' lets a checkpoint saved on GPU be restored on
            # a CPU-only host; load_state_dict copies the values onto the
            # model parameters' own devices.
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            warnings.warn(
                f"load_pt is set but checkpoint {ckpt_path!r} does not exist; "
                "the model keeps its freshly initialized weights."
            )

    # Trainer
    #--------
    trainer_cfg = cfg['trainer']
    trainer = Trainer(
        default_root_dir=cfg['path_log'],
        log_every_n_steps=trainer_cfg['freq_log'],
        accelerator=cfg['device'],
        max_epochs=trainer_cfg['max_epochs'],
        check_val_every_n_epoch=trainer_cfg['freq_val'],
        enable_model_summary=trainer_cfg['summary'],
        enable_progress_bar=trainer_cfg['progress'],
        precision=trainer_cfg['precision'],
        accumulate_grad_batches=trainer_cfg['accumulate_grad_batches'],
        enable_checkpointing=trainer_cfg['checkpoint'],
        gradient_clip_val=trainer_cfg.get('gradient_clip_val', 1.0),
        gradient_clip_algorithm='norm',
        inference_mode=trainer_cfg['inference'],
        deterministic=trainer_cfg['deterministic'],
        # Add detect_anomaly=True here when debugging NaN/inf gradients.
    )

    return trainer, model, loaders['train'], loaders['val'], loaders['test']


"""
=======
Drivers
"""

def run(cfg):
    """Configure the environment, then train and test the model from ``cfg``.

    Training uses the train/validation loaders produced by ``load``; testing
    runs on the test loader with the model as it stands after ``fit``.

    Returns:
        Always ``1``.
    """
    configure(cfg)
    trainer, model, *dataloaders = load(cfg)
    train_loader, val_loader, test_loader = dataloaders
    trainer.fit(model, train_loader, val_loader)
    trainer.test(model, test_loader)
    return 1

if __name__ == '__main__':
    # Parse the CLI (config file path + trailing KEY VALUE overrides), build
    # the config tree from the YAML file, apply the overrides, then run.
    cli_args = parse_args()
    with open(cli_args.cfg_file, 'r') as cfg_stream:
        root_cfg = CN.load_cfg(cfg_stream)
    root_cfg.merge_from_list(cli_args.opts)
    run(root_cfg)
