import sys, os, pickle, torch
from pathlib import Path
file = Path(__file__).resolve()
path2project = str(file.parents[2]) + '/'
path2currDir = str(Path.cwd()) + '/'
sys.path.append(path2project) # add top level directory -> geom_dl/

import numpy as np
from model.pl_prox_model import plCovNN
from train.train_funcs import make_model_dict, make_checkpoint_callback_dict, which_dm
from utils.util_funcs import print_subnet_perf_dict, best_subnetwork_at_best_metric

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


def print_results(model, subnetworks=None, stage='train'):
    """Print a model's prior-prediction metrics and, for training, epoch metrics.

    Args:
        model: trained model. Reads ``prior_construction``; reads
            ``prior_metrics_val`` when ``stage == 'train'`` and
            ``prior_metrics_test`` when ``stage == 'test'``; reads
            ``training_prior_threshold[i]`` for the i-th prior. For the
            ``'train'`` stage it additionally reads ``list_of_metrics`` and
            calls ``best_metrics(...)``.
        subnetworks: subnetwork names forwarded to
            ``best_subnetwork_at_best_metric`` (train stage only). ``None``
            (the default) means
            ``['frontal', 'temporal', 'occipital', 'parietal', 'full']``.
        stage: ``'train'`` -> threshold chosen on TRAIN, predictions on
            VALIDATION; ``'test'`` -> threshold chosen on VALIDATION,
            predictions on TEST.

    Raises:
        ValueError: if ``stage`` is neither ``'train'`` nor ``'test'``.
    """
    if subnetworks is None:
        # Built per call instead of a shared mutable default argument.
        subnetworks = ['frontal', 'temporal', 'occipital', 'parietal', 'full']

    if stage == 'train':
        prior_metrics = model.prior_metrics_val
    elif stage == 'test':
        prior_metrics = model.prior_metrics_test
    else:
        raise ValueError(f"stage must be 'train' or 'test', got {stage!r}")

    sort_subnetwork = 'full'
    if stage == 'train':
        # Epoch-level metrics are only printed for the train stage, so only
        # gather them there.
        final_epoch_metrics = model.list_of_metrics[-1]
        best_err_epoch_metrics = model.best_metrics('error', sort_subnetwork=sort_subnetwork, top_k=1, maximize=False)[0]
        best_mse_epoch_metrics = model.best_metrics('mse', sort_subnetwork=sort_subnetwork, top_k=1, maximize=False)[0]
        best_mae_epoch_metrics = model.best_metrics('mae', sort_subnetwork=sort_subnetwork, top_k=1, maximize=False)[0]
        best_mcc_epoch_metrics = model.best_metrics('mcc', sort_subnetwork=sort_subnetwork, top_k=1)[0]

    # prior metrics is a dict over priors (single, single-batched, multi, zeros)
    print(f'\nPrior Prediction Metrics: {model.prior_construction} prior over {stage.upper()}')
    print(f"Model Trained on TRAIN set, threshold chosen with {'TRAIN' if 'train' in stage else 'VALIDATION'} set. Prediction occuring on {'VALIDATION' if 'train' in stage else 'TEST'} set.")
    for i, (prior_channel_name, prior_metrics_) in enumerate(prior_metrics.items()):
        # Threshold i is paired with the i-th prior in dict iteration order.
        print(f'\tprior: {prior_channel_name} > {model.training_prior_threshold[i].item():.4f}')
        print_subnet_perf_dict(prior_metrics_, indents=2, convert_to_percent=['acc', 'error'])

    if stage == 'train':
        print(f'Metrics of epoch which optimize metric _ on subnetwork {sort_subnetwork} (using train for train/threshold finding -> predict on validation)')
        for label, epoch_metrics in (('ERR', best_err_epoch_metrics),
                                     ('MCC', best_mcc_epoch_metrics),
                                     ('MSE', best_mse_epoch_metrics),
                                     ('MAE', best_mae_epoch_metrics)):
            print(f'\t{label}')
            print_subnet_perf_dict(epoch_metrics, indents=2, convert_to_percent=['acc', 'error'])
        print('Final epoch metrics')
        print('\tFinal')
        print_subnet_perf_dict(final_epoch_metrics, indents=2, convert_to_percent=['acc', 'error'])
        # prints best val metrics over all epochs
        best_subnetwork_at_best_metric(model, prior_metrics, subnetworks=subnetworks, metrics=['mse', 'mae', 'error', 'mcc'], indents=1)


if __name__ == "__main__":
    # Script entry point:
    # - build the data module;
    # - configure the Lightning Trainer (W&B logger, checkpointing, early stopping);
    # - train plCovNN and print validation results;
    # - if a test split is requested, reload the best checkpoint and print test results.

    # ---------------- Data ----------------
    which_exp = 'real'
    tags = ['train_val_split']  # NOTE(review): not referenced below (e.g. not passed to the logger)

    rand_seed = 50
    #num_splits, num_splits_train = 12, 1
    num_splits, num_patients_val, num_patients_test = None, 50, 100
    # Larger batches / more loader workers when a GPU is available.
    batch_size, sum_stat, fc_norm = 200 if torch.cuda.is_available() else 20, 'cov', 'max_eig'
    dm_args = {'num_splits': num_splits,
               'num_patients_val': num_patients_val,
               'num_patients_test': num_patients_test,
               'num_train_workers': 4 if torch.cuda.is_available() else 1,
               'num_val_workers': 2 if torch.cuda.is_available() else 0,
               'num_test_workers': 0,
               'batch_size': batch_size,
               'val_batch_size': None, #batch_size,
               'test_batch_size': batch_size,
               'rand_seed': rand_seed,
               'sum_stat': sum_stat,
               'fc_norm': fc_norm,
               'fc_norm_val': 'symeig',
               'binarize_labels_for_train': False,
               'fc_construction': 'concat_scandir_timeseries',
               'scan_combination': 'mean'
               }
    dm = which_dm(which_exp)(**dm_args)
    dm.setup("fit")

    # ---------------- Task: loss, monitored metric, threshold grid ----------------
    task = 'regress'  # link-pred or regress # CHANGE BETWEEN 2 TASKS HERE
    if 'link' in task:
        loss = 'hinge'
        monitor = 'val/full/error'
        # Candidate thresholds (passed as threshold_metric_test_points), finest in [.405, .6).
        threshold_test_points = np.concatenate((np.arange(0, .4, .01), np.arange(.405, .6, .002), np.arange(.605, .9, .05)), axis=0)
    elif 'regress' in task:
        loss = 'mae'
        monitor = f'val/full/{loss}'  # CHANGE THIS FOR DETERMINING REGRESSION MONITOR
        # monitor = 'val/full/mse'
        # Candidate thresholds (passed as threshold_metric_test_points), finest in [.09, .18).
        threshold_test_points = np.concatenate((np.arange(0, .07, .02), np.arange(.09, .18, .001), np.arange(.18, .7, .05)), axis=0)
    else:
        # Fail fast instead of a NameError on `loss` / `monitor` further down.
        raise ValueError(f"task must contain 'link' or 'regress', got {task!r}")

    # ---------------- Trainer ----------------
    project = 'brain-data'
    # Validate every 30 epochs on GPU, every 3 on CPU.
    check_val_every_n_epoch, max_epochs = 30 if torch.cuda.is_available() else 3, 10000

    trainer_args = {'max_epochs': max_epochs,
                    'gpus': 1 if torch.cuda.is_available() else 0,
                    'logger': WandbLogger(project=project, name='practice'+task) if project is not None else None,
                    'check_val_every_n_epoch': check_val_every_n_epoch,
                    'callbacks': []}

    run_directory = f"{which_exp}_task_{task}"
    checkpoint_callback_args = \
        make_checkpoint_callback_dict(path2currDir=path2currDir, monitor=monitor, task=task, loss=loss,
                                      which_exp=which_exp,
                                      rand_seed=rand_seed, trainer_args=trainer_args,
                                      run_directory=run_directory)
    checkpoint_callback = ModelCheckpoint(**checkpoint_callback_args)
    trainer_args['default_root_dir'] = path2currDir + 'checkpoints/' # <- path to all checkpoints
    trainer_args['callbacks'].append(checkpoint_callback)

    # Early stopping on `monitor` (lower is better). min_delta is in absolute terms;
    # patience counts validation checks (one every check_val_every_n_epoch epochs).
    early_stop_cb = EarlyStopping(monitor=monitor,
                                  min_delta=1e-8,#min_delta,
                                  patience=30, #26  # *check_val_every_n_epochs
                                  verbose=False,
                                  mode='min',
                                  strict=True,
                                  check_finite=False,  # allow some NaN...we will resample params
                                  stopping_threshold=None,
                                  divergence_threshold=None, #divergence_threshold, #need to be able to get/stay within this when checked every check_val_every_n_epochs
                                  check_on_train_epoch_end=False)  # runs at end of validations

    trainer_args['callbacks'].append(early_stop_cb)
    trainer = pl.Trainer(**trainer_args)

    # ---------------- Model ----------------
    channels = 8*[8] if torch.cuda.is_available() else 6*[1]
    model_args = {
        'channels': channels,
        'poly_fc_orders': [1] * (len(channels) - 1),
        'where_normalize_slices': 'after_reduction',
        'poly_basis': 'cheb',
        'optimizer': 'adam',
        'learning_rate': .008,  # .1 mse (>.5 produces unstable training), .3 hinge
        'momentum': 0.9,  # wandb.config.momentum,
        'share_parameters': False,
        'real_network': 'full',
        'threshold_metric_test_points': threshold_test_points,
        'monitor': monitor,
        'which_loss': loss,
        'logging': 'only_scalars',
        'learn_tau': True,
        'hinge_margin': .25,
        'hinge_slope': 1,
        'include_nonlinearity': True,
        'n_train': 68,
        'prior_construction': 'mean',
        # 'single' 'block' 'single_grouped' 'multi', None #, 'prior_groups': 50}
        'rand_seed': rand_seed}

    # Set to a path of the form "checkpoints/filename" to load a saved model
    # instead of building a fresh one from model_args.
    model_checkpoint_path = None
    if model_checkpoint_path is None:
        model = plCovNN(**model_args)
    else:
        model = plCovNN.load_from_checkpoint(checkpoint_path=model_checkpoint_path)

    # ---------------- Train / validate / test ----------------
    model.subnetwork_masks = dm.subnetwork_masks
    trainer.fit(model=model, train_dataloader=dm.train_dataloader(), val_dataloaders=dm.val_dataloader())
    print('VALIDATION')
    print_results(model, stage='train')
    if (num_patients_test is not None) and (num_patients_test > 0):
        # Test the weights stored at checkpoint_callback.best_model_path, not the final-epoch model.
        test_model = plCovNN.load_from_checkpoint(checkpoint_path=checkpoint_callback.best_model_path)
        test_model.subnetwork_masks = dm.subnetwork_masks
        trainer.test(model=test_model, datamodule=dm)
        print('TEST')
        # Report on test_model: it is the model trainer.test() actually ran on.
        print_results(test_model, stage='test')


