# Module-level debug switch; not referenced by the code visible in this script.
DEBUG: bool = False
import os, sys, torch, wandb, pickle, numpy as np, pytorch_lightning as pl

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from pathlib import Path
# Resolve the paths this script works with. The project root is parents[2], the
# grandparent of this file's directory. The current working directory is the base
# for the 'checkpoints/' folder. Both are kept as strings with a trailing slash so
# they can be concatenated directly.
file = Path(__file__).resolve()
path2project = f"{file.parents[2]}/"
path2currDir = f"{Path.cwd()}/"
# Put the top-level directory (geom_dl/) on the import path so the `train`, `model`
# and `utils` packages imported below can be found.
sys.path.append(path2project)

from train.train_funcs import make_checkpoint_callback_dict, which_dm
from train.real.real import print_results
from model.pl_prox_model import plCovNN
from utils.util_funcs import sample_spherical, graph_gen_info, coeffs_str_builder, construct_run_name

# Run Weights & Biases offline whenever the working-directory path contains 'max'
# (presumably the author's local machine -- confirm).
if os.getcwd().find('max') != -1:
    os.environ["WANDB_MODE"] = "offline"

if __name__ == "__main__":

    which_exp = 'synthetics'
    project = f'{which_exp}-single-runs'
    # Same heuristic as the WANDB_MODE toggle: a cwd containing 'max' is treated as the
    # local machine, where DataLoader worker processes are disabled.
    on_local_machine = 'max' in os.getcwd()

    # this rand seed ensures the coefficient sample is the same across all experiments
    rand_seed = 50
    num_coeffs_sample = 3
    # Kept even though the sampled coefficients are overridden below: the call may also
    # consume/seed RNG state that later code relies on.
    all_coeffs = sample_spherical(npoints=num_coeffs_sample, ndim=3, rand_seed=rand_seed)

    # dataset
    num_vertices, graph_gen, num_signals = 68, 'geom', 50
    # coeffs = all_coeffs[:, 1]  # sampled coefficients; overridden by the hand-picked ones below
    coeffs = [.5, .5, .2]
    sum_stat = 'sample_cov'
    r, prior_construction, sparsity_range = graph_gen_info(graph_gen)
    num_samples_train, num_samples_val, num_samples_test = 500, 100, 100

    # training
    batch_size = 200 if torch.cuda.is_available() else 32
    task = 'link-pred'
    if 'link' in task:
        monitor = 'val/full/error'
        loss = 'hinge'
    elif 'regress' in task:
        monitor = 'val/full/mae'  # CHANGE THIS FOR DETERMINING REGRESSION MONITOR
        # monitor = 'val/full/mse'
        loss = 'mse'
    else:
        # Without this, monitor/loss would be undefined and the script would die later with a NameError.
        raise ValueError(f"unrecognized task '{task}': expected a 'link...' or 'regress...' task")
    max_epochs = 10000
    save_checkpoint = True
    check_val_every_n_epoch = 15 if torch.cuda.is_available() else 1
    dm_args = {'graph_gen': graph_gen,
               'r': r, 'sparse_thresh_low': sparsity_range[0], 'sparse_thresh_high': sparsity_range[1],
               'num_samples_train': num_samples_train,
               'num_samples_val': num_samples_val,
               'num_samples_test': num_samples_test,
               'num_train_workers': 0 if on_local_machine else 4,
               'num_val_workers': 0 if on_local_machine else 2,
               'num_test_workers': 0 if on_local_machine else 1,
               'batch_size': batch_size,
               # must do entire validation batch to choose threshold
               # NOTE(review): on CPU batch_size is 32 < num_samples_val (100), so validation spans
               # several batches -- confirm whether this should be num_samples_val.
               'val_batch_size': batch_size,
               'test_batch_size': batch_size,
               'rand_seed': rand_seed,
               'sum_stat': sum_stat, ## CHANGE
               'fc_norm': 'max_eig',
               'fc_norm_val': 'symeig',
               'binarize_labels_for_train': False,
               'coeffs': coeffs
               }
    dm = which_dm(which_exp)(**dm_args)
    dm.setup('fit')

    # Trainer
    trainer_args = {'max_epochs': max_epochs,
                    'gpus': 1 if torch.cuda.is_available() else 0,
                    'logger': WandbLogger(project=project) if project is not None else None,
                    'check_val_every_n_epoch': check_val_every_n_epoch,
                    'callbacks': []}
    checkpoint_callback_args = \
        make_checkpoint_callback_dict(path2currDir=path2currDir, monitor=monitor, task=task, loss=loss, which_exp=which_exp,
                                      rand_seed=rand_seed, trainer_args=trainer_args,
                                      run_directory=f"{graph_gen}_coeffs{coeffs_str_builder(coeffs)}"
                                      )
    checkpoint_callback = ModelCheckpoint(**checkpoint_callback_args)
    trainer_args['default_root_dir'] = path2currDir + 'checkpoints/'  # <- path to all checkpoints
    trainer_args['callbacks'].append(checkpoint_callback)

    # error: .0005 -> .05% decrease in error every _ epochs
    # .01 % decr in error every 15*13 ~= 200 epochs.
    min_delta = .0001 #if 'error' in monitor else None

    #if min_delta is None:
    #    raise ValueError(f"loss is {loss}: consider min delta for regression")
    # adjust patience for batch size...larger batch sizes will go through data faster, so give them more time
    patience = 13
    early_stop_cb = EarlyStopping(monitor=monitor,
                                  min_delta=min_delta,
                                  patience=patience,  # counted in validation checks, i.e. * check_val_every_n_epoch epochs
                                  verbose=False,
                                  mode='min',
                                  strict=True,
                                  check_finite=False,  # allow some NaN...we will resample params
                                  stopping_threshold=None,
                                  divergence_threshold=None,
                                  check_on_train_epoch_end=False)  # runs at end of validations
    trainer_args['callbacks'].append(early_stop_cb)

    # Model
    channels = 8*[1]
    share_parameters = False
    poly_fc_order = 1
    model_args = {
        'channels': channels,
        'poly_fc_orders': [poly_fc_order] * (len(channels) - 1),
        'where_normalize_slices': "after_reduction",
        'poly_basis': 'cheb',
        'optimizer': "adam",
        'learning_rate': .01,  # .1 mse (>.5 produces unstable training), .3 hinge
        'momentum': 0.9,  # momentum,
        'share_parameters': share_parameters,
        'real_network': 'full',
        'threshold_metric_test_points': np.arange(.3, .7, .01),
        'monitor': monitor,
        'which_loss': loss,
        'logging': 'only_scalars',
        'learn_tau': True,
        'hinge_margin': 0.25,
        'include_nonlinearity': True,
        'n_train': num_vertices,
        'prior_construction': prior_construction,
        # 'single' 'block' 'single_grouped' 'multi', None #, 'prior_groups': 50}
        'rand_seed': rand_seed}

    # Directory of an earlier run. To resume, set modelChkpt to a .ckpt file (e.g. one inside
    # this directory); None trains from scratch.
    path2Chpts = path2currDir + 'checkpoints/geom_prior_zeros_coeffs-0.017_0.797_0.604/'
    modelChkpt = None
    trainChkpt = modelChkpt

    model = plCovNN(**model_args) if modelChkpt is None else plCovNN.load_from_checkpoint(checkpoint_path=modelChkpt)
    model.subnetwork_masks = dm.subnetwork_masks

    # Always build the trainer from trainer_args (epochs, gpus, logger, checkpoint and early-stopping
    # callbacks). resume_from_checkpoint=None is Lightning's default and starts a fresh run. The
    # previous inverted ternary dropped all of trainer_args for fresh runs and ignored the
    # checkpoint when resuming.
    trainer = pl.Trainer(resume_from_checkpoint=trainChkpt, **trainer_args)
    trainer.fit(model=model, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader())
    print('Training Results')
    print_results(model, subnetworks=['full'], stage='train')
    print('\n\n\n')

    trainer.test(datamodule=dm, model=None if trainChkpt is None else model) # by not feeding in model arg, trainer will load best checkpoint automatically
    print('Testing Results')
    best_model_path = modelChkpt if modelChkpt is not None else trainer.checkpoint_callback.best_model_path
    #print_results(plCovNN.load_from_checkpoint(checkpoint_path=best_model_path), subnetworks=['full'], stage='test')
