import sys
import argparse
import time
import wandb
import hydra
from omegaconf import DictConfig, OmegaConf, open_dict
from torch.optim.lr_scheduler import StepLR

@hydra.main(config_path='../0_config', config_name='config')
def main(args: DictConfig):
    """Train and/or evaluate a GPL model using the Hydra-composed config.

    Hydra loads the ``config`` config from ``../0_config`` (plus any
    command-line overrides) and passes it in as ``args``. This function then:

    1. resolves the experiment name and group via ``load_config``;
    2. starts a wandb run when training outside debug mode;
    3. builds the dataset, data loaders and model;
    4. selects training/evaluation callbacks for ``framework.gpl_version``;
    5. if ``args.train``: trains, then evaluates the ``'best'`` checkpoint on
       the test split; otherwise only evaluates ``evaluation.ckpt_file``.

    Args:
        args: Hydra-composed configuration.

    Raises:
        ValueError: if ``framework.gpl_version`` is not one of
            ``'prediction'``, ``'gpl_v3'`` or ``'gpl_v5'``.
    """
    from gpl import CONFIG_DIR, LOG_DIR, TBLOG_DIR, TBLOG_HPARAMS_DIR, CKPT_DIR, TMP_DIR
    from gpl.utils.utils import load_config
    from gpl.utils.visualize import show_assignments_callback, show_metrics_callback, investigate_Z_interpolation_evaluate_callback, show_prototypes_criticisms_callback
    from gpl.utils.evaluate import embedding_evaluate_callback, vis_Z_callback, prediction_task_callback, prototype_performance, compare_reconstruct_callback, validation_metric_functon
    from gpl.utils.evaluate import train_epoch_log_metrics, test_log_metrics
    from gpl.utils.get_dataset import get_dataset, get_dataloaders
    from gpl.utils.get_model import get_model, get_mlp
    from gpl.training import Trainer, DataLoaders, get_optimizer, get_scheduler, get_device, get_default_logger, dict_of_dicts_merge, yaml_load, get_git_revision_hash

    from gpl.models.gpl_v5 import ib_coeff_uppb_scheduler_callback

    EXP_NAME, maybe_altered_args = load_config(args)
    EXP_GROUP_NAME = maybe_altered_args.exp_group_name

    ################################################## using config, EXP_NAME, and EXP_GROUP_NAME

    # wandb is only used for real (non-debug) training runs. Remember the
    # decision so the run is finished under exactly the same condition.
    use_wandb = bool(args.train) and not args.debug
    if use_wandb:
        wandb.init(
            project='GPL',
            name=EXP_NAME,
            # wandb expects plain Python containers, so convert the DictConfig explicitly.
            config=OmegaConf.to_container(args, resolve=True),
            group=EXP_GROUP_NAME,
            notes=args.exp_note if args.exp_note != '' else None,
        )
        log_to_file = False
    else:
        log_to_file = args.log.log_file
    # NOTE(review): log_to_file is not used anywhere below; kept because reading
    # args.log.log_file also validates that the key exists. Confirm whether it
    # was meant to be passed on to the Trainer/logger.

    ####### running code
    device = maybe_altered_args.device
    debug = maybe_altered_args.debug

    # Only the PNA model asks the dataset for degree information
    # (it reaches the model through extra_info['deg']).
    compute_degree = maybe_altered_args.model.name == 'PNA'
    train_set, val_set, test_set, extra_info = get_dataset(
        maybe_altered_args.dataset.name,
        splits=maybe_altered_args.dataset.splits,
        device=device,
        mutag_x=maybe_altered_args.dataset.mutag_x,
        compute_degree=compute_degree,
    )
    print('Dataset info:', extra_info)

    data_loaders = get_dataloaders(train_set, val_set, test_set, batch_size=maybe_altered_args.dataset.batch_size)
    loaders_wraper = DataLoaders(data_loaders)

    # Copy dataset-derived properties into the config. open_dict temporarily
    # lifts OmegaConf's struct mode so that new keys may be added.
    with open_dict(maybe_altered_args):
        maybe_altered_args.framework.num_class = extra_info['num_class']
        maybe_altered_args.framework.multi_label = extra_info['multi_label']

    model_constructor_params = {
        'x_dim': extra_info['x_dim'],
        'edge_attr_dim': extra_info['edge_attr_dim'],
        'num_class': extra_info['num_class'],
        'multi_label': extra_info['multi_label'],
        'deg': extra_info['deg'],
        'config': maybe_altered_args,
    }
    model = get_model(**model_constructor_params)

    optimizer_params = maybe_altered_args['training']['optimizer_params']
    scheduler_params = maybe_altered_args['training']['scheduler_params']
    epochs = maybe_altered_args['training']['epochs']

    print('Experiment name:', EXP_NAME)

    ########## training / evaluation callbacks ##########
    # A single dispatch on gpl_version, so every branch defines all three
    # callback lists used further down.
    gpl_version = maybe_altered_args['framework']['gpl_version']
    if gpl_version == 'prediction':
        # No per-epoch callbacks for the plain prediction task. The names must
        # still be bound because trainer.train() below receives them.
        # Assumes Trainer.train treats an empty list as "no callbacks"; confirm.
        train_callbacks_epoch_end = []
        train_callbacks_epoch_start = []
        evaluation_callbacks = [prediction_task_callback]
    elif gpl_version in ('gpl_v3', 'gpl_v5'):
        train_callbacks_epoch_end = [train_epoch_log_metrics]
        train_callbacks_epoch_start = [ib_coeff_uppb_scheduler_callback]
        evaluation_callbacks = [test_log_metrics]
    else:
        raise ValueError(
            f"Unsupported framework.gpl_version {gpl_version!r}; "
            f"expected 'prediction', 'gpl_v3' or 'gpl_v5'."
        )

    # run
    trainer = Trainer(
        model_name=maybe_altered_args.model.name,
        dataset_name=maybe_altered_args.dataset.name,
        model=model,
        data_loaders=loaders_wraper,
        optimizer_constructor=get_optimizer,
        optimizer_params=optimizer_params,
        scheduler_constructor=StepLR,
        scheduler_params=scheduler_params,
        random_seed=maybe_altered_args.random_seed,
        tb_log_dir=TBLOG_DIR / EXP_GROUP_NAME,
        ckpt_dir=CKPT_DIR,
        log_dir=LOG_DIR,
        tmp_dir=TMP_DIR,
        hparams_save=maybe_altered_args,
        hparams_save_dir=TBLOG_HPARAMS_DIR,
        epochs=epochs,
        device=device,
        val_metric_func=validation_metric_functon,
        debug=debug,
        experiment_name=EXP_NAME,
        use_val=True,
        training_mode=args.train,
        log2tensorboard=False,
        ########### auxiliary arguments ###########
        task_id=maybe_altered_args.task_id,
    )

    if args.train:
        trainer.train(train_callbacks_epoch_end=train_callbacks_epoch_end, train_callbacks_epoch_start=train_callbacks_epoch_start)
        trainer.evaluate(ckpt_file='best', test_evaluation=True, evaluation_callbacks=evaluation_callbacks)
        # Finish only the run that was actually started above.
        if use_wandb:
            wandb.finish()
    else:
        trainer.evaluate(ckpt_file=maybe_altered_args.evaluation.ckpt_file, test_evaluation=False, evaluation_callbacks=evaluation_callbacks)

    print(f'\nExperiment {EXP_NAME} finished.')


if __name__ == '__main__':
    # No arguments here: the @hydra.main decorator builds the config and injects it.
    main()
