import os, sys, torch, wandb, numpy as np, pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from pathlib import Path
# Locations are derived from where this script lives on disk; every path
# string carries a trailing '/' so callers can append to it directly.
file = Path(__file__).resolve()
path2project = f"{file.parents[2]}/"  # two levels above this file's directory
path2SmoothData = f"{path2project}data/smooth_signals/"
path2BrainData = f"{path2project}data/brain_data/"
path2currDir = f"{Path.cwd()}/"  # directory the script was launched from
# Put the top-level project directory on the import search path.
sys.path.append(path2project)

from gdn import gdn
from train_funcs import make_checkpoint_run_folder, make_checkpoint_callback_dict
from data.brain_data.brain_data import PsuedoSyntheticDataModule
from utils import sample_spherical


if __name__ == "__main__":
    # Single training run: build the datamodule, configure the model, the
    # trainer and its callbacks, fit (optionally resuming from a checkpoint),
    # then evaluate on the test split.
    which_fcs = 'ps'
    which_exp = f'gdn-brain-{which_fcs}'
    # Leave as None to run without a W&B logger; set to a project name
    # (e.g. f'{which_exp}-single-runs') to enable WandbLogger below.
    project = None

    # One seed is shared by the coefficient sampling, the datamodule and the
    # model so that samples are the same across experiments.
    seed = 50

    # dataset
    all_rand_coeffs = sample_spherical(npoints=3, ndim=3, rand_seed=seed)
    fc_info = {'coeffs': all_rand_coeffs[:, 1],  # column 1 of the sampled coefficients
               'remove_diag': False, 'summary_statistic': 'sample_cov',
               'normalization': 'max_eig', 'normalization_value': 'symeig',
               'frob_norm_high': None}
    dm_args = {'fc_info': fc_info,
               'sc_info': {'scaling': 9.9, 'edge_density_low': 0.35},
               'num_signals': 50,
               'num_patients_val': 50,
               'num_patients_test': 100,
               'num_workers': 2,
               'batch_size': 64,
               'seed': seed}
    dm = PsuedoSyntheticDataModule(**dm_args)

    # training: the task selects the monitored validation metric and the loss
    task = 'regress'
    if 'link' in task:
        monitor = 'val/full/f1/mean'  # alternative: 'val/error/mean'
        loss = 'hinge'
    elif 'regress' in task:
        monitor = 'val/full/se/mean'  # mae
        loss = 'se'
    else:
        # Fail here rather than with a NameError on `monitor`/`loss` further down.
        raise ValueError(f"unknown task {task!r}: expected it to contain 'link' or 'regress'")
    # Score-like metrics are maximised; everything else is treated as an error to minimise.
    mode = 'max' if any(metric in monitor for metric in ('f1', 'mcc', 'acc')) else 'min'

    # Model
    good_model_path = None  # set to a checkpoint path to resume instead of training from scratch
    lr, depth, share_parameters = .05, 11, False
    mimo_architecture = (depth + 1) * [1]
    model_args = {
        'depth': depth,
        'mimo_architecture': mimo_architecture,
        'poly_fc_orders': depth * [1],
        'gradient_ablation': False,
        'poly_basis': 'cheb',
        'optimizer': "adam",
        'learning_rate': lr,
        'momentum': 0.9,  # momentum,
        'share_parameters': share_parameters,
        'real_network': 'full',
        'monitor': monitor,
        'which_loss': loss,
        'tau_info': {'type': 'scalar', 'learn': True, 'low': .01, 'high': .99, 'max_clamp_val': 0.99},
        'normalization_info': {'method': 'max_abs', 'value': .99, 'where': 'after_reduction', 'norm_last_layer': False},
        'gamma': 0.0,
        'learn_prior': False,
        'hinge_margin': 0.25,
        'include_nonlinearity': True,
        'n_train': 68,
        'prior_construction': 'mean',
        # 'single' 'block' 'single_grouped' 'multi', None #, 'prior_groups': 50}
        'seed': seed}

    max_epochs = 10000
    use_gpu = torch.cuda.is_available()
    # Validate less often when a GPU is available.
    check_val_every_n_epoch = 15 if use_gpu else 5

    # Trainer
    checkpoint_root = path2currDir + 'checkpoints/'  # <- path to all checkpoints
    logger = None
    if project is not None:
        logger = WandbLogger(project=project,
                             name=f'task{task}_graphsBRAIN{which_fcs}_depth{depth}_lr{lr}')
    trainer_args = {'max_epochs': max_epochs,
                    'gradient_clip_val': 0.0,  # 1.0,
                    'gpus': 1 if use_gpu else 0,
                    'logger': logger,
                    'check_val_every_n_epoch': check_val_every_n_epoch,
                    'default_root_dir': checkpoint_root,
                    'callbacks': []}

    checkpoint_callback_args = make_checkpoint_callback_dict(
        path2currDir=path2currDir, monitor=monitor, mode=mode, task=task,
        loss=loss, which_exp=which_exp,
        rand_seed=seed,
        run_directory=f"{which_fcs}",
        misc=which_fcs,
        subnets=True,
    )
    checkpoint_callback = ModelCheckpoint(**checkpoint_callback_args)
    trainer_args['callbacks'].append(checkpoint_callback)

    min_delta = .0001
    patience = 20  # counted in validation checks, not in epochs
    early_stop_cb = EarlyStopping(monitor=monitor,
                                  min_delta=min_delta,
                                  patience=patience,  # *check_val_every_n_epochs
                                  verbose=False,
                                  mode=mode,
                                  strict=True,
                                  check_finite=False,  # allow some NaN...we will resample params
                                  stopping_threshold=None,
                                  divergence_threshold=None,
                                  check_on_train_epoch_end=False)  # runs at end of validations
    trainer_args['callbacks'].append(early_stop_cb)

    modelChkpt = good_model_path
    trainChkpt = modelChkpt

    if modelChkpt is not None:
        # Resume: load the model weights, then let fit() restore the trainer
        # history from the same checkpoint. The trainer is built from
        # trainer_args so the resumed run keeps the same epoch budget, device
        # selection, logger, checkpointing and early stopping as a fresh run.
        model, trainer = gdn.load_from_checkpoint(checkpoint_path=modelChkpt), pl.Trainer(**trainer_args)
        trainer.fit(model, ckpt_path=modelChkpt, datamodule=dm)
    else:
        model, trainer = gdn(**model_args), pl.Trainer(**trainer_args)
        trainer.fit(model=model, datamodule=dm)

    # Fresh run (no checkpoint given): evaluate the best checkpoint saved during
    # this fit. Resumed run: evaluate the weights currently held by `model`.
    trainer.test(model=model, datamodule=dm, ckpt_path='best' if trainChkpt is None else None)
