import os, sys, wandb, pickle, argparse, numpy as np, pprint, torch, \
    pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pathlib import Path
# Path bootstrap: this must run before the project-local imports below so that
# the `model`, `train` and `utils` packages can be resolved.
file = Path(__file__).resolve()  # absolute path of this script
# parents[2] is the directory two levels above the one containing this script
path2project = str(file.parents[2]) + '/'
# working directory at import time; used further down as the base for checkpoints
path2currDir = str(Path.cwd()) + '/'
sys.path.append(path2project) # add top level directory -> geom_dl/

from model.pl_prox_model import plCovNN
from train.train_funcs import make_checkpoint_callback_dict, which_dm
from train.real.real import print_results
from utils.util_funcs import graph_gen_info, sample_spherical

def coeffs_str_builder(coeffs):
    """Join coefficients into an underscore-separated string.

    Each value is rounded to 3 decimal places and converted with ``str``,
    e.g. ``[0.12345, 1.0]`` becomes ``"0.123_1.0"``.

    Args:
        coeffs: iterable of numbers accepted by ``round``.

    Returns:
        The joined string, or ``""`` when ``coeffs`` is empty.
    """
    # str.join avoids the repeated concatenation and the trailing-separator
    # slice that the loop-based version needed.
    return '_'.join(str(round(f, 3)) for f in coeffs)

# parse graph_gen argument
# NOTE: the argparse code below is disabled. It is wrapped in a bare
# triple-quoted string, which is evaluated as a no-op expression statement,
# so `graph_gen` is never defined by this file.
"""

parser = argparse.ArgumentParser(description="which synthetic to run")
parser.add_argument('--graph_gen', type=str)
#args = parser.parse_args()
args, unknown = parser.parse_known_args()
print(args.graph_gen)
assert args.graph_gen in ['geom', 'ER', 'er', 'sbm', 'pref_attach']
graph_gen = args.graph_gen
"""
# The default hyperparameters and the explicit project name below are disabled
# in the same way (bare triple-quoted string).
# NOTE(review): these disabled defaults name `num_samples_val` /
# `num_samples_test`, while make_datamodule and train_model read
# `num_patients_val` / `num_patients_test` from wandb.config. Confirm that the
# externally supplied config uses the `num_patients_*` keys.
"""
hyperparameter_defaults = dict(
    task = "link-pred",
    num_vertices=68, fc_norm="max_eig", sum_stat="sample_cov", num_signals=50,
    num_samples_train=913, num_samples_val=500, num_samples_test=500,
    coeffs_index=1,
    #model
    channels=8, iterations=8, share_parameters=False,
    poly_fc_order=1, where_normalize_slices="after_reduction",
    learn_tau=True, include_nonlinearity=True,
    #optimizer
    optimizer="adam", learning_rate=.008, hinge_margin=.25,
    # training
    max_epochs=4000, batch_size=50, rand_seed=50
    )
project = f"{graph_gen}-final-experiments"
if 'max' in os.getcwd():
    project = 'tinker_YAML--' + project
wandb.init(config=hyperparameter_defaults, project=project)
"""
# Start the run with no arguments: no config or project is passed here, so
# every `wandb.config.*` value read below has to be supplied from outside this
# file (presumably a W&B sweep / YAML definition -- TODO confirm).
wandb.init()

# Choose the loss name, the monitored validation metric and the grid of
# candidate thresholds from the task named in the run config. Each grid is
# three ranges with different step sizes, concatenated end to end.
if 'link' not in wandb.config.task:
    loss = 'mse'
    monitor = "val/full/mse"
    threshold_test_points = np.concatenate([
        np.arange(0, .07, .02),
        np.arange(.09, .18, .001),
        np.arange(.18, .7, .05),
    ])
else:
    loss = "hinge"
    monitor = "val/full/error"
    threshold_test_points = np.concatenate([
        np.arange(0, .4, .01),
        np.arange(.405, .6, .005),
        np.arange(.605, .9, .05),
    ])

# Sample the candidate coefficient sets. Per the original author's note the
# result is the same whenever rand_seed is the same.
num_coeffs_sample = 3
all_coeffs = sample_spherical(
    npoints=num_coeffs_sample,
    ndim=3,
    rand_seed=wandb.config.rand_seed,
)
# The run uses the column selected by the config's coeffs_index.
coeffs = all_coeffs[:, wandb.config.coeffs_index]

# Forwarded to the model as its 'prior_construction' argument.
prior_construction = 'mean'

def make_datamodule(wandb):
    """Build the experiment's datamodule from the run config and set it up for fitting.

    Args:
        wandb: object exposing the run configuration as ``wandb.config``
            (the script passes the ``wandb`` module itself).

    Returns:
        The object created by ``which_dm(which_exp)(**dm_args)``, after
        ``setup("fit")`` has been called on it.

    Relies on the module-level ``all_coeffs``.
    NOTE(review): ``which_exp`` is read as a global, but no assignment to it
    is visible in this file -- confirm where it is supposed to be defined.
    """
    cfg = wandb.config

    # Dataloader worker processes are switched off (0) whenever the working
    # directory path contains "max"; otherwise train/val/test use 4/2/1.
    if "max" in os.getcwd():
        train_workers, val_workers, test_workers = 0, 0, 0
    else:
        train_workers, val_workers, test_workers = 4, 2, 1

    dm_args = dict(
        coeffs=all_coeffs[:, cfg.coeffs_index],
        num_patients_val=cfg.num_patients_val,
        num_patients_test=cfg.num_patients_test,
        num_train_workers=train_workers,
        num_val_workers=val_workers,
        num_test_workers=test_workers,
        batch_size=cfg.batch_size,
        # original note: must do entire validation batch to choose threshold
        val_batch_size=None,
        test_batch_size=cfg.batch_size,
        rand_seed=cfg.rand_seed,
        sum_stat=cfg.sum_stat,
        fc_norm=cfg.fc_norm,
        fc_norm_val='symeig',
        binarize_labels_for_train=False,
        num_signals=cfg.num_signals,
    )

    datamodule_cls = which_dm(which_exp)
    dm = datamodule_cls(**dm_args)
    dm.setup("fit")
    return dm


def make_trainer(wandb):
    """Assemble the ``pl.Trainer`` for this run with checkpointing and early stopping.

    Args:
        wandb: object exposing the run configuration as ``wandb.config``
            (the script passes the ``wandb`` module itself).

    Returns:
        A ``pl.Trainer`` logging through ``WandbLogger`` whose callbacks are a
        ``ModelCheckpoint`` followed by an ``EarlyStopping``.

    Relies on the module-level ``coeffs``, ``monitor``, ``loss`` and
    ``path2currDir``.
    NOTE(review): ``which_exp`` is read as a global, but no assignment to it
    is visible in this file -- confirm where it is supposed to be defined.
    """
    cfg = wandb.config
    has_gpu = torch.cuda.is_available()

    # Validate every 30 epochs when a GPU is present, every epoch otherwise.
    check_val_every_n_epoch = 30 if has_gpu else 1

    # Run name; original note: this will be done for each graph gen over each 3 coeffs.
    run_directory = f"task_{cfg.task}_coeffs{coeffs_str_builder(coeffs)}_shareParams{cfg.share_parameters}_"

    trainer_args = dict(
        max_epochs=cfg.max_epochs,
        gpus=1 if has_gpu else 0,
        logger=WandbLogger(name=run_directory),
        check_val_every_n_epoch=check_val_every_n_epoch,
        callbacks=[],
    )

    # The helper is handed trainer_args in its current state, i.e. before
    # default_root_dir is set and before any callback is appended.
    checkpoint_kwargs = make_checkpoint_callback_dict(
        path2currDir=path2currDir,
        monitor=monitor,
        task=cfg.task,
        loss=loss,
        which_exp=which_exp,
        rand_seed=cfg.rand_seed,
        trainer_args=trainer_args,
        run_directory=run_directory,
    )
    checkpoint_cb = ModelCheckpoint(**checkpoint_kwargs)
    trainer_args['default_root_dir'] = path2currDir + 'checkpoints/'  # <- path to all checkpoints
    trainer_args['callbacks'].append(checkpoint_cb)

    # Deliberately lenient early stopping for the final runs. Patience is 25
    # non-improving validation checks; with a check every 30 epochs (GPU case)
    # that is 25 * 30 = 750 epochs, and 25 epochs when validating every epoch.
    # (Earlier comments quoted 500 / 600 epochs, which does not match 25 * 30.)
    early_stop_kwargs = dict(
        monitor=monitor,
        min_delta=1e-8,
        patience=25,
        verbose=False,
        mode='min',
        strict=True,
        # original note: allow some NaN...we will resample params
        check_finite=False,
        stopping_threshold=None,
        # a prior_vals[graph_gen][prior_construction][monitor] lookup was once
        # considered for this threshold; it is currently disabled
        divergence_threshold=None,
        # evaluate the stopping criterion at the end of validation
        check_on_train_epoch_end=False,
    )
    trainer_args['callbacks'].append(EarlyStopping(**early_stop_kwargs))

    return pl.Trainer(**trainer_args)


def make_model(wandb):
    """Instantiate ``plCovNN`` from the run configuration.

    Args:
        wandb: object exposing the run configuration as ``wandb.config``
            (the script passes the ``wandb`` module itself).

    Returns:
        The constructed ``plCovNN`` model.

    Relies on the module-level ``threshold_test_points``, ``monitor``,
    ``loss`` and ``prior_construction``.
    """
    cfg = wandb.config

    # One channel count per iteration, all equal to the configured width.
    channels = [cfg.channels] * cfg.iterations
    if not cfg.share_parameters:
        # original note: save unused params -- first and last entries are set to 1
        channels[0] = channels[-1] = 1

    model_args = dict(
        channels=channels,
        # one polynomial order per consecutive pair of channel entries
        poly_fc_orders=[cfg.poly_fc_order] * (len(channels) - 1),
        where_normalize_slices=cfg.where_normalize_slices,
        poly_basis='cheb',
        optimizer=cfg.optimizer,
        # original note: .1 mse (>.5 produces unstable training), .3 hinge
        learning_rate=cfg.learning_rate,
        # original note: momentum unused atm
        momentum=0.9,
        share_parameters=cfg.share_parameters,
        real_network='full',
        threshold_metric_test_points=threshold_test_points,
        monitor=monitor,
        which_loss=loss,
        logging='only_scalars',
        learn_tau=cfg.learn_tau,
        hinge_margin=cfg.hinge_margin,
        include_nonlinearity=cfg.include_nonlinearity,
        n_train=cfg.num_vertices,
        # original note lists: 'single' 'block' 'single_grouped' 'multi', None
        # (and a disabled 'prior_groups': 50 entry)
        prior_construction=prior_construction,
        rand_seed=cfg.rand_seed,
    )
    return plCovNN(**model_args)


def train_model(dm, model, trainer):
    """Fit ``model``, print its results, and optionally test the best checkpoint.

    Args:
        dm: datamodule providing ``subnetwork_masks``, ``train_dataloader()``
            and ``val_dataloader()``; also handed to ``trainer.test``.
        model: the model to train; ``dm.subnetwork_masks`` is attached to it
            before fitting.
        trainer: the ``pl.Trainer`` to run. Its ``checkpoint_callback`` must
            have a ``best_model_path`` when testing is requested.

    Reads ``num_patients_test`` from the module-level ``wandb.config``; the
    test phase runs only when that value is not ``None`` and greater than 0.
    """
    model.subnetwork_masks = dm.subnetwork_masks

    # NOTE(review): the `train_dataloader=` keyword matches the older Lightning
    # API this script targets (see `gpus=` in make_trainer) -- revisit on upgrade.
    trainer.fit(model=model, train_dataloader=dm.train_dataloader(), val_dataloaders=dm.val_dataloader())
    print_results(model)

    num_patients_test = wandb.config.num_patients_test
    if num_patients_test is not None and num_patients_test > 0:
        best_ckpt_path = trainer.checkpoint_callback.best_model_path
        test_model = plCovNN.load_from_checkpoint(checkpoint_path=best_ckpt_path)
        test_model.subnetwork_masks = dm.subnetwork_masks
        # Hand the restored best-checkpoint model to the trainer explicitly.
        # Previously it was loaded and given its masks but never passed on,
        # so the object built here was dead code.
        trainer.test(model=test_model, datamodule=dm)

    if 'content' in os.getcwd():
        wandb.finish()  # only needed on colab

if __name__ == '__main__':
    # Script entry point: build the datamodule, trainer and model (in that
    # order) from the active wandb run, then train.
    dm = make_datamodule(wandb)
    trainer = make_trainer(wandb)
    model = make_model(wandb)
    train_model(dm, model, trainer)
