DEBUG = False
import os, sys, torch, wandb, pickle, numpy as np, pytorch_lightning as pl, math
from copy import deepcopy

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from pathlib import Path
# Absolute path of this script, independent of the directory it is launched from.
file = Path(__file__).resolve()
# Parent of the directory that contains this script, with a trailing '/'.
path2project = str(file.parents[1]) + '/'
# Directory the script was launched from, with a trailing '/'.
path2currDir = str(Path.cwd()) + '/'
sys.path.append(path2project) # add top level directory -> geom_dl/ so the project-local imports below resolve

from train.train_funcs import make_checkpoint_callback_dict, which_dm
from train.real.real import print_results
from model.pl_prox_model import plCovNN
from utils.util_funcs import sample_spherical, graph_gen_info, coeffs_str_builder, construct_run_name
from data.pl_data import SyntheticDataModule


# Experiment family; passed to which_dm() to pick the data module class.
which_exp = 'synthetics'
# Weights & Biases project that the evaluation runs are logged to.
project = "graph-size-generalization-experiment"

def log_metrics(metrics, wandb):
    """Print and log the mean and standard error of every metric.

    One summary line is printed for all metrics. For each metric a single
    ``wandb.log`` call records ``<name>/mean``, ``<name>/stde`` and the raw
    values under ``<name>``.

    Args:
        metrics: mapping ``metric name -> values``. Each ``values`` object is
            passed to ``torch.mean``/``torch.std`` and ``len``
            (presumably a 1-D float tensor -- confirm against the caller).
        wandb: any object exposing ``log(dict)``. NOTE: this parameter name
            shadows the ``wandb`` module inside the function.
    """
    print('Performance: ', end='')
    for name, samples in metrics.items():
        avg = torch.mean(samples)
        # standard error of the mean: std / sqrt(number of samples)
        std_err = torch.std(samples) / math.sqrt(len(samples))
        print(f'{name}: mean {avg:.5f}, stde {std_err:.5f} || ', end="")
        wandb.log({
            f'{name}/mean': avg,
            f'{name}/stde': std_err,
            f'{name}': samples,
        })
    # terminate the summary line
    print()


def construct_large_data(config):
    """Build the data module described by ``config``.

    Args:
        config: dict providing ``graph_gen``, ``num_vertices``,
            ``num_samples_train``/``_val``/``_test``, ``batch_size``,
            ``rand_seed``, ``sum_stat``, ``fc_norm`` and ``coeffs_index``.
            The single ``batch_size`` is used for train, val and test.

    Returns:
        The data module instance created by ``which_dm(which_exp)``.
    """
    # Seeding with config['rand_seed'] keeps the sampled coefficients identical
    # across all experiments.
    n_coeff_draws = 3
    sampled_coeffs = sample_spherical(npoints=n_coeff_draws, ndim=3, rand_seed=config['rand_seed'])

    # The second returned value (prior construction) is not needed here.
    r, _prior_construction, sparsity_range = graph_gen_info(config['graph_gen'])
    sparse_low, sparse_high = sparsity_range[0], sparsity_range[1]

    # No dataloader workers when the working directory contains "max"
    # (presumably a local test machine -- confirm); 16 otherwise.
    n_workers = 0 if "max" in os.getcwd() else 16
    batch_size = config['batch_size']

    dm_args = dict(
        graph_gen=config['graph_gen'],
        num_vertices=config['num_vertices'],
        r=r,
        sparse_thresh_low=sparse_low,
        sparse_thresh_high=sparse_high,
        num_samples_train=config['num_samples_train'],
        num_samples_val=config['num_samples_val'],
        num_samples_test=config['num_samples_test'],
        num_train_workers=n_workers,
        num_val_workers=n_workers,
        num_test_workers=n_workers,
        batch_size=batch_size,
        val_batch_size=batch_size,
        test_batch_size=batch_size,
        rand_seed=config['rand_seed'],
        sum_stat=config['sum_stat'],
        fc_norm=config['fc_norm'],
        fc_norm_val='symeig',
        binarize_labels_for_train=False,
        # one column of the sampled coefficients, chosen by the config
        coeffs=sampled_coeffs[:, config['coeffs_index']],
    )
    dm_class = which_dm(which_exp)
    return dm_class(**dm_args)


def eval(graph_sizes, model_checkpoints):
    """Evaluate trained checkpoints on synthetic test graphs of several sizes.

    For every graph size a data module is built, then every checkpoint is
    loaded and evaluated through ``plCovNN.test_large_graphs``. Each
    (size, model) pair is logged to its own Weights & Biases run named
    ``<model_name>__size=<size>_tune=<tune_threshold>``.

    No ``pl.Trainer`` is created: evaluation never goes through a trainer, and
    the legacy ``gpus=`` Trainer argument is rejected by recent Lightning
    releases.

    NOTE: this function shadows the builtin ``eval``; the name is kept because
    it is the entry point used by the script.

    Args:
        graph_sizes: iterable of vertex counts to evaluate on.
        model_checkpoints: mapping ``model name -> checkpoint path``. The name
            also supplies run metadata: ``'share'`` in the name is logged as
            shared parameters, ``'no-mimo'`` as 1 channel (otherwise 8).
    """
    # see train_models.yaml for these values
    config = {'graph_gen': 'geom', 'num_samples_train': 0, 'num_samples_val': 0, 'num_samples_test': 200,
              'rand_seed': 50, 'sum_stat': 'analytic_cov', 'fc_norm': 'max_eig',
              'coeffs_index': 1}  # important: coeffs index 1 was used for these models

    # Keep wandb offline when launched from a directory containing 'max' (for testing).
    # Set once up front; it only has to be in place before the first wandb.init.
    if 'max' in os.getcwd():
        os.environ["WANDB_MODE"] = 'offline'

    use_cuda = torch.cuda.is_available()

    for size in graph_sizes:
        # batch size depends on graph size: large batches only for small graphs
        size_config = deepcopy(config)
        size_config['batch_size'] = 20 if size < 500 else 1
        size_config['num_vertices'] = size
        dm = construct_large_data(size_config)
        # NOTE(review): stage 'fit' is requested although only the test (and
        # optionally val) dataloader is consumed below -- confirm the data
        # module prepares the test split and subnetwork_masks for this stage.
        dm.setup('fit')

        for model_name, model_chkpt in model_checkpoints.items():
            test_model = plCovNN.load_from_checkpoint(checkpoint_path=model_chkpt)
            # masks depend on the graph size, so take them from the current data module
            test_model.subnetwork_masks = dm.subnetwork_masks
            if use_cuda:
                # data will be moved to cuda one by one to avoid overloading memory
                test_model.cuda()
                print(f'MODEL ON GPU: {test_model.device}')

            # Only the untuned-threshold setting is evaluated; the validation
            # dataloader is requested only when tuning is enabled.
            for tune_threshold in [False]:
                metrics, threshold = test_model.test_large_graphs(
                    val_dl=dm.val_dataloader() if tune_threshold else None,
                    test_dl=dm.test_dataloader(),
                    use_val_for_threshold=tune_threshold)

                # Fresh dict per run so per-model metadata never leaks into the
                # configuration shared by the other models of this size.
                run_config = dict(
                    size_config,
                    tune_threshold=tune_threshold,
                    share_parameters='share' in model_name,
                    channels=1 if 'no-mimo' in model_name else 8,
                    model_checkpoint=model_chkpt,
                )
                with wandb.init(project=project, reinit=True, config=run_config) as run:
                    run.name = model_name + f'__size={size}_tune={tune_threshold}'  # override run name
                    print(f'graph_size : {size} - tune {tune_threshold}')
                    log_metrics(metrics, run)
                    run.log({'threshold': threshold})


if __name__ == "__main__":
    # Checkpoints are read from ./models/ below the directory the script is launched from.
    # Alternative location kept for reference:
    #   "/home/jovyan/proximal_gradient_topology_inference/graph_size_generalization/checkpoints/"
    path2models = path2currDir + "models/"

    # Checkpoint naming (the key is parsed by eval for run metadata):
    #   iterations = 8
    #   mimo  -> 8 channels in/out, else 1 channel
    #   share -> share-parameters=True
    #   all models use analytic_cov
    model_checkpoints = {
        'no-mimo_share': path2models + "no-mimo_share/link-pred_loss_hinge_epoch00029_error0.0623837_mcc0.8736465_mse0.0752334_mae0.2019557_seed50_date&time09-29_03:31:47.ckpt",
    }
    # Currently disabled checkpoints:
    #   train a different no-mimo_indep for less epochs for stable generalization
    #   'no-mimo_indep-overfit': path2models+"no-mimo_indep/link-pred_loss_hinge_epoch00734_error0.0432615_mcc0.9122221_mse0.4961985_mae0.4969408_seed50_date&time09-29_03:31:42.ckpt",
    #   'mimo_share':    path2models+  "mimo_share/link-pred_loss_hinge_epoch00839_error0.0531281_mcc0.8926293_mse0.6532385_mae0.5226228_seed50_date&time09-29_03:31:54.ckpt",
    #   'mimo_indep':    path2models+  "mimo_indep/link-pred_loss_hinge_epoch00809_error0.0354354_mcc0.9280853_mse0.6906086_mae0.5159252_seed50_date&time09-29_03:31:49.ckpt"

    # Vertex counts of the test graphs to evaluate on.
    graph_sizes = [68, 75, 100, 150]

    eval(graph_sizes, model_checkpoints)
