import os, sys, torch, wandb, numpy as np, pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from pathlib import Path
# Locate the project root relative to this file, not the launch directory.
# parents[0] is this file's directory, so parents[2] is two levels above it.
file = Path(__file__).resolve()
path2project = f"{file.parents[2]}/"
path2SmoothData = f"{path2project}data/smooth_signals/"
path2BrainData = f"{path2project}data/brain_data/"
# Checkpoints are written relative to the current working directory (see below).
path2currDir = f"{Path.cwd()}/"
# Make top-level project packages (gdn, data, utils, ...) importable.
sys.path.append(path2project)

from gdn import gdn
from train_funcs import make_checkpoint_run_folder, make_checkpoint_callback_dict
from data.network_diffusion.diffused_signals import DiffusionDataModule
from utils import sample_spherical


if __name__ == "__main__":
    graph_sampling = 'geom'
    which_exp = f'gdn-{graph_sampling}'
    project = None #f'{which_exp}-single-runs'

    # this rand seed ensures sample across all experiments are the same
    seed = 50

    # dataset
    if 'small' in graph_sampling:
        good_model_path = None
    elif 'BA' in graph_sampling: # pref-attach
        graph_sampling_params = {'graph_sampling': graph_sampling, 'num_vertices': 68, 'r': 15, 'edge_density_low': 0.3, 'edge_density_high': 0.4}
        prior_construction = 'zeros'
        good_model_path = None
    elif 'ER' in graph_sampling:
        graph_sampling_params = {'graph_sampling': graph_sampling, 'num_vertices': 68, 'p': 0.56, 'edge_density_low': 0.5, 'edge_density_high': 0.6}
        prior_construction = 'zeros'
        good_model_path = None
    elif 'sbm' in graph_sampling:
        graph_sampling_params = {'graph_sampling': graph_sampling, 'num_vertices': 21, 'num_communities': 3, 'p_in': .6, 'p_out': .1, 'edge_density_low': 0.22, 'edge_density_high': 0.27}
        prior_construction = 'sbm'
        good_model_path = None
    elif 'geom' in graph_sampling:
        graph_sampling_params = {'graph_sampling': graph_sampling, 'num_vertices': 68, 'r': 0.56, 'dim': 2, 'edge_density_low': 0.5, 'edge_density_high': 0.6}
        prior_construction = 'zeros'
        good_model_path = None #path2currDir+"checkpoints/geom/geom-link_loss_hinge_epoch00104_error0.2598727_mcc0.4808988_f10.7557446_se807.5948486_ae991.8848267_seed50_date&time01-03_19:23:25.ckpt"
    else:
        raise ValueError(f'unrecognized graph generation func {graph_sampling}')

    batch_size = 200 if torch.cuda.is_available() else 64
    all_rand_coeffs = sample_spherical(npoints=3, ndim=3, rand_seed=seed)
    dm_args = {'graph_sampling_params': graph_sampling_params,
               'gso': 'adjacency',
               'label': 'adjacency',
               'sigma': 9,
               'normal_mle': False,
               'num_signals': 50,
               'train_size': 500,
               'val_size': 100,
               'test_size': 150,
               'num_workers': 4 if "max" not in os.getcwd() else 0,
               'batch_size': batch_size,
               'seed': seed,
               'sum_stat': 'sample_cov', ## CHANGE
               'sum_stat_norm': 'max_eig',
               'sum_stat_norm_val': 'symeig',
               'coeffs': all_rand_coeffs[:, 1],
               }
   dm = DiffusionDataModule(**dm_args)

    # training
    task = 'regress'
    if 'link' in task:
        monitor = 'val/full/f1/mean'
        #monitor = 'val/error/mean'
        loss = 'hinge'
    elif 'regress' in task:
        monitor = 'val/full/se/mean' # mae
        loss = 'se'
    mode = 'max' if any([(a in monitor) for a in ['f1', 'mcc', 'acc']]) else 'min'

    # Model
    lr, depth, share_parameters = .05, 11, False
    mimo_architecture = (depth+1)*[1]
    model_args = {
        'depth': depth,
        'mimo_architecture': (depth+1)*[1],
        'poly_fc_orders': depth*[1],
        'gradient_ablation': False,
        'poly_basis': 'cheb',
        'optimizer': "adam",
        'learning_rate': lr, #.05,  # .1 mse (>.5 produces unstable training), .3 hinge
        'momentum': 0.9,  # momentum,
        'share_parameters': share_parameters,
        'real_network': 'full',
        'monitor': monitor,
        'which_loss': loss,
        'tau_info': {'type': 'scalar', 'learn': True, 'low': .01, 'high': .99, 'max_clamp_val': 0.99},
        'normalization_info':  {'method': 'max_abs', 'value': .99, 'where': 'after_reduction', 'norm_last_layer': False, 'tau_mlp': {'num_features': 3, 'h_layer_size': 3, 'depth': 4}},
        'gamma': 0.0,
        'learn_prior': False,
        'hinge_margin': 0.25,
        'include_nonlinearity': True,
        'n_train': graph_sampling_params['num_vertices'],
        'prior_construction': prior_construction,
        # 'single' 'block' 'single_grouped' 'multi', None #, 'prior_groups': 50}
        'seed': seed}

    max_epochs = 10000
    save_checkpoint, check_val_every_n_epoch = True, 15 if torch.cuda.is_available() else 5

    # Trainer
    trainer_args = {'max_epochs': max_epochs,
                    'gradient_clip_val': 0.0, #1.0,
                    'gpus': 1 if torch.cuda.is_available() else 0,
                    'logger': WandbLogger(project=project, name=f'task{task}_graphs{graph_sampling}_depth{depth}_lr{lr}') if project is not None else None,
                    'check_val_every_n_epoch': check_val_every_n_epoch,
                    'callbacks': []}
    checkpoint_callback_args = \
        make_checkpoint_callback_dict(path2currDir=path2currDir, monitor=monitor, mode=mode, task=task, loss=loss, which_exp=which_exp,
                                      rand_seed=seed,
                                      run_directory=f"{graph_sampling}",
                                      misc=graph_sampling,
                                      subnets=True
                                      )
    checkpoint_callback = ModelCheckpoint(**checkpoint_callback_args)
    trainer_args['default_root_dir'] = path2currDir + 'checkpoints/'  # <- path to all checkpoints
    trainer_args['callbacks'].append(checkpoint_callback)

    min_delta = .0001
    patience = 20
    early_stop_cb = EarlyStopping(monitor=monitor,
                                  min_delta=min_delta,
                                  patience=patience,  # *check_val_every_n_epochs
                                  verbose=False,
                                  mode=mode,
                                  strict=True,
                                  check_finite=False,  # allow some NaN...we will resample params
                                  stopping_threshold=None,
                                  divergence_threshold=None,
                                  check_on_train_epoch_end=False)  # runs at end of validations
    trainer_args['callbacks'].append(early_stop_cb)

    path2Chpts = path2currDir + 'checkpoints/'
    modelChkpt = good_model_path
    trainChkpt = modelChkpt

    if modelChkpt is not None:
        # will load both model weights and trainer history
        model, trainer = gdn.load_from_checkpoint(checkpoint_path=modelChkpt), pl.Trainer()
        trainer.fit(model, ckpt_path=modelChkpt, datamodule=dm)
    else:
        model, trainer = gdn(**model_args), pl.Trainer(**trainer_args)
        trainer.fit(model=model, datamodule=dm)

    trainer.test(model=model, datamodule=dm, ckpt_path='best' if trainChkpt is None else None) # use current model weights if no trainer_checkpoint given, else load best checkpoint
