# Module-level debug switch.
# NOTE(review): it is not referenced anywhere else in this script -- confirm whether other
# tooling reads it before relying on it or removing it.
DEBUG: bool = False
import os, sys, torch, wandb, pickle, numpy as np, pytorch_lightning as pl

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from pathlib import Path
# Locate this script and put the project root (the directory containing this script's
# directory, i.e. geom_dl/) on sys.path so the project-local imports below resolve.
# Both directory strings end in '/' because later code builds paths by concatenation.
file = Path(__file__).resolve()
path2project = f"{file.parents[1]}/"
path2currDir = f"{Path.cwd()}/"
sys.path.append(path2project)

from train.train_funcs import make_checkpoint_callback_dict, which_dm
from train.real.real import print_results
from model.pl_prox_model import plCovNN
from utils.util_funcs import sample_spherical, graph_gen_info, coeffs_str_builder, construct_run_name
from data.pl_data import SyntheticDataModule

# Experiment key handed to which_dm(...) and make_checkpoint_callback_dict(...) inside train().
which_exp = 'synthetics'


def train():
    """Run one W&B-configured synthetic experiment: build data, fit, then test.

    Hyperparameters are read from ``wandb.config``, which ``wandb.init`` seeds with
    ``hyperparameter_defaults`` below (a W&B sweep may override them). The function

    1. samples coefficient vectors with ``sample_spherical`` and builds/sets up the
       data module selected by ``which_dm(which_exp)``,
    2. configures a Lightning ``Trainer`` with a W&B logger, a ``ModelCheckpoint``
       and an ``EarlyStopping`` callback, both driven by the task's validation metric,
    3. builds a ``plCovNN`` model, fits it, and prints training results,
    4. runs the test stage on the best checkpoint and prints test results.

    Raises:
        ValueError: if ``config.task`` contains neither ``'link'`` nor ``'regress'``.
    """
    hyperparameter_defaults = dict(
        channels=1,
        share_parameters=False,
        task='link-pred',
        graph_gen='geom',
        coeffs_index=1,
        num_vertices=68,
        fc_norm="max_eig",
        sum_stat="analytic_cov",
        num_signals=50,  # placeholder: signals are not used because sum_stat='analytic_cov'
        num_samples_train=913,
        num_samples_val=500,
        num_samples_test=500,
        iterations=8,
        include_nonlinearity=True,
        poly_fc_order=1,
        optimizer='adam',
        batch_size=200,
        learning_rate=.01,
        hinge_margin=.25,
        rand_seed=50,
    )
    with wandb.init(config=hyperparameter_defaults):
        config = wandb.config

        # Validation metric that drives checkpointing / early stopping, plus the training loss.
        # Resolved before any data is generated so an unsupported task fails fast and clearly.
        if 'link' in config.task:
            monitor = 'val/full/error'
            loss = 'hinge'
        elif 'regress' in config.task:
            monitor = 'val/full/mae'  # alternative regression monitor: 'val/full/mse'
            loss = 'mse'
        else:
            raise ValueError(
                f"Unsupported task {config.task!r}: expected it to contain 'link' or 'regress'")

        # rand_seed fixes the sampled coefficients so every experiment sees the same ones.
        num_coeffs_sample = 3
        all_coeffs = sample_spherical(npoints=num_coeffs_sample, ndim=3, rand_seed=config.rand_seed)

        # Dataset
        r, prior_construction, sparsity_range = graph_gen_info(config.graph_gen)
        # Worker counts drop to 0 when the working-directory path contains "max"
        # (presumably a particular local machine -- confirm).
        use_workers = "max" not in os.getcwd()
        dm_args = {'graph_gen': config.graph_gen,
                   'num_vertices': config.num_vertices,
                   'r': r, 'sparse_thresh_low': sparsity_range[0], 'sparse_thresh_high': sparsity_range[1],
                   'num_samples_train': config.num_samples_train,
                   'num_samples_val': config.num_samples_val,
                   'num_samples_test': config.num_samples_test,
                   'num_train_workers': 4 if use_workers else 0,
                   'num_val_workers': 2 if use_workers else 0,
                   'num_test_workers': 1 if use_workers else 0,
                   'batch_size': config.batch_size,
                   'val_batch_size': config.batch_size,
                   'test_batch_size': config.batch_size,
                   'rand_seed': config.rand_seed,
                   'sum_stat': config.sum_stat,
                   'fc_norm': config.fc_norm,
                   'fc_norm_val': 'symeig',
                   'binarize_labels_for_train': False,
                   # column `coeffs_index` of the sampled array
                   # (presumably shaped (ndim, npoints) -- confirm against sample_spherical)
                   'coeffs': all_coeffs[:, config.coeffs_index]}
        dm = which_dm(which_exp)(**dm_args)
        dm.setup('fit')

        # Trainer: validate every 15 epochs on GPU, every epoch on CPU.
        use_gpu = torch.cuda.is_available()
        check_val_every_n_epoch = 15 if use_gpu else 1
        run_name = ("mimo" if config.channels > 1 else "no-mimo") + ("_share" if config.share_parameters else "_indep")
        trainer_args = {'max_epochs': 5000,
                        'gpus': 1 if use_gpu else 0,
                        'logger': WandbLogger(),
                        'check_val_every_n_epoch': check_val_every_n_epoch,
                        'callbacks': []}
        # Must see trainer_args before 'default_root_dir' and the callbacks are added below.
        checkpoint_callback_args = make_checkpoint_callback_dict(
            path2currDir=path2currDir, monitor=monitor, task=config.task, loss=loss,
            which_exp=which_exp, rand_seed=config.rand_seed, trainer_args=trainer_args,
            run_directory=run_name)
        checkpoint_callback = ModelCheckpoint(**checkpoint_callback_args)
        trainer_args['default_root_dir'] = path2currDir + 'checkpoints/'  # root directory for all checkpoints
        trainer_args['callbacks'].append(checkpoint_callback)

        # Don't need to train super long; too long may overfit to smaller graphs. We want to generalize.
        early_stop_cb = EarlyStopping(monitor=monitor,
                                      min_delta=1e-8,
                                      patience=13,  # validation checks: ~200 epochs when validating every 15
                                      verbose=False,
                                      mode='min',
                                      strict=True,
                                      check_finite=False,  # allow some NaN...we will resample params
                                      stopping_threshold=None,
                                      divergence_threshold=None,
                                      check_on_train_epoch_end=False)  # check at end of validation
        trainer_args['callbacks'].append(early_stop_cb)

        # Model: `iterations` layers, each with `config.channels` channels.
        channels = config.iterations * [config.channels]
        model_args = {
            'channels': channels,
            'poly_fc_orders': [config.poly_fc_order] * (len(channels) - 1),
            'where_normalize_slices': "after_reduction",
            'poly_basis': 'cheb',
            'optimizer': config.optimizer,
            'learning_rate': config.learning_rate,  # .1 mse (>.5 produces unstable training), .3 hinge
            'momentum': 0.9,  # not used
            'share_parameters': config.share_parameters,
            'real_network': 'full',
            'threshold_metric_test_points': np.arange(.3, .7, .01),
            'monitor': monitor,
            'which_loss': loss,
            'logging': 'only_scalars',
            'learn_tau': True,
            'hinge_margin': config.hinge_margin,
            'include_nonlinearity': config.include_nonlinearity,
            'n_train': config.num_vertices,
            # author's note on options: 'single' 'block' 'single_grouped' 'multi', None
            'prior_construction': prior_construction,
            'rand_seed': config.rand_seed}

        model = plCovNN(**model_args)
        model.subnetwork_masks = dm.subnetwork_masks

        trainer = pl.Trainer(**trainer_args)
        # Dataloaders are passed positionally: the keyword name changed across Lightning
        # versions (`train_dataloader` -> `train_dataloaders`), the argument order did not.
        trainer.fit(model, dm.train_dataloader(), dm.val_dataloader())
        print('Training Results')
        print_results(model, subnetworks=['full'], stage='train')
        print('\n\n\n')

        # With no model argument the trainer tests its own module (`model`);
        # ckpt_path='best' restores the best checkpoint's weights into it first.
        trainer.test(datamodule=dm, ckpt_path='best')
        print('Testing Results')
        # Report from `model`: it is the instance the test stage just ran on, and it carries
        # the best weights plus the externally assigned `subnetwork_masks`.
        print_results(model, subnetworks=['full'], stage='test')


if __name__ == "__main__":
    train()
