import os, sys, torch, wandb, numpy as np, pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from pathlib import Path
# Absolute, symlink-resolved location of this script.
file = Path(__file__).resolve()
# Directory two levels above the folder containing this script. The trailing
# slash is kept because the paths below (and in make_trainer) are built by
# plain string concatenation.
path2project = f'{file.parents[2]}/'
path2SmoothData = f'{path2project}data/smooth_signals/'
path2BrainData = f'{path2project}data/brain_data/'
# Working directory at import time, also with a trailing slash.
path2currDir = f'{Path.cwd()}/'
# Make the top-level project directory importable for the imports that follow.
sys.path.append(path2project)

from data.brain_data.brain_data import PsuedoSyntheticDataModule
from glad import glad
from train_funcs import make_checkpoint_run_folder, make_checkpoint_callback_dict
from utils import sample_spherical
from data.network_diffusion.diffused_signals import DiffusionDataModule

# Presumably the wandb project name for these runs; it is not referenced
# anywhere else in this file -- TODO confirm it is still needed.
project = 'glad-synthetic-final-exps'


def coeffs_str_builder(coeffs):
    """Return the coefficients as one underscore-separated string.

    Each value is rounded to 3 decimal places and converted with ``str``.
    An empty iterable gives the empty string.
    """
    return '_'.join(str(round(value, 3)) for value in coeffs)


# Experiment tag; make_trainer forwards it to make_checkpoint_callback_dict as
# the `which_exp` argument.
# NOTE(review): this spot was previously annotated "log into correct wandb
# account", but nothing in the visible code performs a wandb login here.
which_exp = 'synthetics'


def make_datamodule(wandb, coeffs):
    """Build the DiffusionDataModule for the graph family named in the config.

    Args:
        wandb: Config-like object (the script passes ``wandb.config``). Reads
            ``graph_sampling``, ``label_norm``, ``label_norm_min_eig``,
            ``num_signals``, ``train_size``, ``val_size``, ``test_size`` and
            ``seed``. Inside this function the name shadows the ``wandb``
            module.
        coeffs: Forwarded unchanged as the ``coeffs`` argument of the
            datamodule.

    Returns:
        The constructed DiffusionDataModule; ``setup`` is not called here.

    Raises:
        ValueError: if ``graph_sampling`` contains none of 'BA', 'ER', 'sbm'
            or 'geom'.
    """
    graph_sampling = wandb.graph_sampling

    # Generator settings per graph family. Families are matched by substring,
    # in exactly this order, so the first tag found in `graph_sampling` wins.
    # NOTE(review): the previous if/elif version also set a local
    # `prior_construction` ('sbm' for the SBM family, 'zeros' for the others)
    # that was never read or passed on; it is not reproduced here.
    family_settings = (
        ('BA', {'num_vertices': 68, 'm': 15,  # pref-attach
                'edge_density_low': 0.3, 'edge_density_high': 0.4}),
        ('ER', {'num_vertices': 68, 'p': 0.56,
                'edge_density_low': 0.5, 'edge_density_high': 0.6}),
        ('sbm', {'num_vertices': 21, 'num_communities': 3, 'p_in': 0.6, 'p_out': 0.1,
                 'edge_density_low': 0.22, 'edge_density_high': 0.27}),
        ('geom', {'num_vertices': 68, 'r': 0.56, 'dim': 2,
                  'edge_density_low': 0.5, 'edge_density_high': 0.6}),
    )
    for tag, settings in family_settings:
        if tag in graph_sampling:
            graph_sampling_params = {'graph_sampling': graph_sampling, **settings}
            break
    else:
        raise ValueError(f'unrecognized graph generation func {graph_sampling}')

    return DiffusionDataModule(
        graph_sampling_params=graph_sampling_params,
        gso='adjacency',
        label='laplacian',
        label_norm={'normalization': wandb.label_norm, 'min_eig': wandb.label_norm_min_eig},
        normal_mle=False,
        num_signals=wandb.num_signals,
        train_size=wandb.train_size,
        val_size=wandb.val_size,
        test_size=wandb.test_size,
        # No loader workers when the working directory path contains "max".
        num_workers=0 if "max" in os.getcwd() else 4,
        batch_size=200 if torch.cuda.is_available() else 64,
        seed=wandb.seed,
        # NOTE(review): hard-coded rather than read from the config; the
        # original line carried a "## CHANGE" marker.
        sum_stat='sample_cov',
        sum_stat_norm='max_eig',
        sum_stat_norm_val='symeig',
        coeffs=coeffs,
    )


def make_trainer(wandb):
    """Assemble the ``pl.Trainer`` with checkpointing and early stopping.

    Args:
        wandb: Config-like object (the script passes ``wandb.config``). Reads
            ``task``, ``max_epochs``, ``graph_sampling``, ``coeffs_index`` and
            ``seed``. Inside this function the name shadows the ``wandb``
            module.

    Returns:
        A ``pl.Trainer`` carrying a ``ModelCheckpoint`` and an
        ``EarlyStopping`` callback that both watch the same metric.

    Raises:
        ValueError: if ``wandb.task`` contains neither 'link' nor 'regress'.
            Previously this case surfaced as an ``UnboundLocalError`` because
            ``monitor`` and ``loss`` were never assigned.
    """
    task = wandb.task
    if 'link' in task:
        monitor, loss = 'val/f1/mean', 'hinge'  # alternative monitor: 'val/error/mean'
    elif 'regress' in task:
        monitor, loss = 'val/se/mean', 'se'  # mae
    else:
        raise ValueError(f"unrecognized task {task!r}: expected it to contain 'link' or 'regress'")

    # Score-like metrics are maximised; anything else is treated as an error to minimise.
    mode = 'max' if any(tag in monitor for tag in ('f1', 'mcc', 'acc')) else 'min'

    on_gpu = torch.cuda.is_available()
    check_val_every_n_epoch = 10 if on_gpu else 2

    checkpoint_callback_args = make_checkpoint_callback_dict(
        path2currDir=path2currDir, monitor=monitor, mode=mode, task=task, loss=loss,
        which_exp=which_exp,
        rand_seed=wandb.seed,
        run_directory=f"{wandb.graph_sampling}",
        misc=wandb.graph_sampling,
        subnets=False,
    )
    checkpoint_callback = ModelCheckpoint(**checkpoint_callback_args)

    early_stop_cb = EarlyStopping(monitor=monitor,
                                  min_delta=.000001,
                                  patience=5,  # *check_val_every_n_epochs
                                  verbose=False,
                                  mode=mode,
                                  strict=True,
                                  check_finite=False,  # allow some NaN...we will resample params
                                  stopping_threshold=None,
                                  divergence_threshold=None,
                                  check_on_train_epoch_end=False)  # runs at end of validations

    trainer_args = {'max_epochs': wandb.max_epochs,
                    'gradient_clip_val': 0.0,  # 1.0,
                    # NOTE(review): `gpus` is kept as-is; confirm it is still accepted
                    # by the installed pytorch_lightning version.
                    'gpus': 1 if on_gpu else 0,
                    'logger': WandbLogger(name=f'graphs{wandb.graph_sampling}_coeffIdx{wandb.coeffs_index}'),
                    'check_val_every_n_epoch': check_val_every_n_epoch,
                    'default_root_dir': path2currDir + 'checkpoints/',  # <- path to all checkpoints
                    'callbacks': [checkpoint_callback, early_stop_cb]}

    return pl.Trainer(**trainer_args)


def make_model(wandb, dm):
    """Instantiate the ``glad`` model from the run config.

    Args:
        wandb: Config-like object (the script passes ``wandb.config``). Reads
            ``task``, ``depth``, ``learning_rate``, ``gamma`` and ``seed``.
            Inside this function the name shadows the ``wandb`` module.
        dm: Datamodule; only its ``non_neg_labels`` attribute is read.

    Returns:
        The object returned by ``glad(**model_args)``.

    Raises:
        ValueError: if ``wandb.task`` contains neither 'link' nor 'regress'.
            Previously this case surfaced as an ``UnboundLocalError`` because
            ``monitor`` and ``loss`` were never assigned.
        AssertionError: if ``dm.non_neg_labels`` does not compare equal to
            ``False``.
    """
    task = wandb.task
    if 'link' in task:
        monitor, loss = 'val/f1/mean', 'hinge'  # alternative monitor: 'val/error/mean'
    elif 'regress' in task:
        monitor, loss = 'val/se/mean', 'se'  # mae
    else:
        raise ValueError(f"unrecognized task {task!r}: expected it to contain 'link' or 'regress'")

    # Sanity check kept from the original, including the `== False` comparison,
    # so exactly the same values are accepted as before.
    assert dm.non_neg_labels == False, 'this script expects dm.non_neg_labels to be False'

    model_args = {
        # architecture
        'depth': wandb.depth,
        'h': 3,
        'share_parameters': True,
        'lambda_init': 1.0,
        'theta_init_offset': 1.0,  # 10,
        'non_neg_outputs': dm.non_neg_labels,
        # loss, optimizer
        'loss': loss,
        'monitor': monitor,
        'optimizer': 'adam',
        'learning_rate': wandb.learning_rate,
        'adam_beta_1': 0.9, 'adam_beta_2': 0.999,
        'gamma': wandb.gamma,
        # reproducibility
        'seed': wandb.seed,
        # thresholding
        'threshold_metric': 'acc',
    }
    return glad(**model_args)


if __name__ == '__main__':

    # Default run configuration handed to wandb.init as `config`.
    hyperparameter_defaults = {
        'task': "regress",
        # dataset
        'graph_sampling': 'geom',
        'fc_norm': "max_eig",
        'sum_stat': "sample_cov",
        'num_signals': 50,
        'train_size': 150,
        'val_size': 50,
        'test_size': 51,
        'coeffs_index': 1,
        'label_norm': "min_eig",
        'label_norm_min_eig': 100.0,
        # model
        'depth': 15,
        'gamma': 0.8,
        # optimizer
        'learning_rate': .1,
        # training
        'max_epochs': 100,
        # reproducibility
        'seed': 50,
    }
    with wandb.init(config=hyperparameter_defaults) as run:
        cfg = wandb.config

        # Draw the candidate coefficient vectors (per the original note these
        # are the same for a given seed) and keep the column picked by
        # `coeffs_index`.
        num_coeffs_sample = 3
        all_coeffs = sample_spherical(npoints=num_coeffs_sample, ndim=3, rand_seed=cfg.seed)
        coeffs = all_coeffs[:, cfg.coeffs_index]

        dm = make_datamodule(cfg, coeffs)
        trainer = make_trainer(cfg)
        model = make_model(cfg, dm)

        trainer.fit(model=model, datamodule=dm)
        # by not feeding in model arg, trainer will load best checkpoint automatically
        trainer.test(datamodule=dm)
