import os, sys, wandb, pickle, argparse, numpy as np, pprint, torch, \
    pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pathlib import Path
file = Path(__file__).resolve()
path2project = str(file.parents[2]) + '/'
path2currDir = str(Path.cwd()) + '/'
sys.path.append(path2project) # add top level directory -> geom_dl/

from model.model_utils import prediction_metrics
from train.train_funcs import which_dm
from utils.util_funcs import graph_gen_info, sample_spherical, correlation_from_covariance


def coeffs_str_builder(coeffs):
    """Join coefficients into an underscore-separated string.

    Each coefficient is rounded to 3 decimal places before formatting, e.g.
    ``[1.23456, 2.0]`` -> ``"1.235_2.0"``. An empty iterable yields ``""``.
    """
    # str.join replaces the quadratic `+=` loop and the trailing-'_' slice.
    return '_'.join(str(round(f, 3)) for f in coeffs)


def make_synthetic_datamodule(wandb, all_coeffs):
    """Build the synthetic-graph datamodule described by the run config.

    Args:
        wandb: object exposing a ``config`` attribute (``train`` passes the
            ``wandb`` module itself); note it shadows the module-level import.
        all_coeffs: coefficient array; column ``wandb.config.coeffs_index`` is
            selected for this run (assumes 2-D ``[:, j]`` indexing -- TODO
            confirm against ``sample_spherical``).

    Returns:
        The datamodule built by ``which_dm("synthetics")``. It is constructed
        but NOT set up; the caller is responsible for calling ``setup``.
    """
    which_exp = "synthetics"
    # The middle value (prior construction) is not needed by this baseline.
    r, _, sparsity_range = graph_gen_info(wandb.config.graph_gen)
    dm_args = {'graph_gen': wandb.config.graph_gen,
               'r': r, 'sparse_thresh_low': sparsity_range[0], 'sparse_thresh_high': sparsity_range[1],
               'coeffs': all_coeffs[:, wandb.config.coeffs_index],
               'num_samples_train': wandb.config.num_samples_train,
               'num_samples_val': wandb.config.num_samples_val,
               'num_samples_test': wandb.config.num_samples_test,
               'num_train_workers': 1,
               'num_val_workers': 1,
               'num_test_workers': 1,
               'batch_size': 10,  # no train set, doesn't matter
               'val_batch_size': wandb.config.num_samples_val,  # must do entire validation batch to choose threshold
               'test_batch_size': wandb.config.num_samples_test,
               'rand_seed': wandb.config.rand_seed,
               'sum_stat': wandb.config.sum_stat,
               'fc_norm': wandb.config.fc_norm,
               'fc_norm_val': 'symeig',
               'binarize_labels_for_train': False,
               'num_signals': wandb.config.num_signals}
    dm = which_dm(which_exp)(**dm_args)
    return dm


def make_pseudo_synthetic_datamodule(wandb, all_coeffs):
    """Build and set up the pseudo-synthetic datamodule (used when graph_gen == 'SC').

    Args:
        wandb: object exposing a ``config`` attribute (``train`` passes the
            ``wandb`` module itself); note it shadows the module-level import.
        all_coeffs: coefficient array; column ``wandb.config.coeffs_index`` is
            selected for this run (assumes 2-D ``[:, j]`` indexing -- TODO
            confirm against ``sample_spherical``).

    Returns:
        The datamodule built by ``which_dm('pseudo-synthetics')`` with
        ``setup("fit")`` already called.
    """
    which_exp = 'pseudo-synthetics'
    dm_args = {'coeffs': all_coeffs[:, wandb.config.coeffs_index],
               'num_patients_val': wandb.config.num_samples_val,
               'num_patients_test': wandb.config.num_samples_test,
               'num_train_workers': 1,
               'num_val_workers': 1,
               'num_test_workers': 1,
               'batch_size': 50,  # not used: this script only reads the val/test loaders
               'val_batch_size': wandb.config.num_samples_val,  # must do entire validation batch to choose threshold
               'test_batch_size': wandb.config.num_samples_test,
               'rand_seed': wandb.config.rand_seed,
               'sum_stat': wandb.config.sum_stat,
               'fc_norm': wandb.config.fc_norm,
               'fc_norm_val': 'symeig',
               'binarize_labels_for_train': False,
               'num_signals': wandb.config.num_signals}
    dm = which_dm(which_exp)(**dm_args)
    dm.setup("fit")
    return dm

def run_hard_thresholding(emp_cov, threshold: float, use_corr=False, use_abs_val=True):
    """Predict edges by hard-thresholding sample covariance matrices.

    Args:
        emp_cov: tensor of sample covariances (described as 3D; presumably
            (batch, N, N) -- TODO confirm).
        threshold: cutoff mapping scores to structure; entries strictly greater
            than it become 1.0, all others 0.0.
        use_corr: if True, convert ``emp_cov`` to a correlation matrix with
            ``correlation_from_covariance`` before thresholding; otherwise the
            covariance entries are thresholded directly.
        use_abs_val: if True, threshold the absolute value of each entry
            instead of the raw (signed) value.

    Returns:
        Floating-point 0/1 tensor with the same shape as the thresholded scores.
    """
    scores = correlation_from_covariance(emp_cov) if use_corr else emp_cov
    if use_abs_val:
        scores = torch.abs(scores)
    # Adding 0.0 promotes the boolean mask to a floating-point 0/1 tensor.
    return (scores > threshold) + 0.0


def run_batch(wandb, dataloader, name):
    """Evaluate the hard-thresholding baseline on the first batch of a loader.

    The datamodule builders in this file set the val/test batch size to the
    full split size, so this first batch is intended to cover the whole split.

    Args:
        wandb: object exposing ``config.threshold`` and ``config.abs_val``.
        dataloader: iterable whose batches unpack into exactly five items; only
            the first two (inputs ``x`` and labels ``y``) are used.
        name: prefix for the metric keys, e.g. ``'val'`` or ``'test'``.

    Returns:
        dict mapping ``name + '/' + metric`` to value. The ``hinge`` entry from
        ``prediction_metrics`` is dropped and ``error = 1 - acc`` is added.
    """
    x, y, _, _, _ = next(iter(dataloader))
    config = wandb.config
    y_hat = run_hard_thresholding(emp_cov=x, threshold=config.threshold,
                                  use_corr=True, use_abs_val=config.abs_val)
    raw = prediction_metrics(y=y, y_hat=y_hat, threshold=0, hinge_margin=0,
                             hinge_slope=1, reduction='ave')
    raw.pop('hinge')  # the hinge metric is not reported for this baseline
    raw['error'] = 1 - raw['acc']
    return {name + '/' + key: value for key, value in raw.items()}


def train():
    """Run one hard-thresholding baseline evaluation as a wandb run.

    Builds the datamodule selected by ``graph_gen`` (``'SC'`` -> pseudo-synthetic,
    anything else -> synthetic). Evaluates the configured threshold on the
    validation and test splits and logs both metric dicts to the run.
    """
    hyperparameter_defaults = dict(
        graph_gen='geom',
        coeffs_index=1,
        abs_val=True,
        threshold=0.7,
        num_vertices=68,
        fc_norm="max_eig", sum_stat="sample_corr", num_signals=50,
        num_samples_train=0, num_samples_val=500, num_samples_test=500,
        rand_seed=50)

    with wandb.init(config=hyperparameter_defaults) as run:
        # build coefficients -> will be the same given same rand_seed
        num_coeffs_sample = 3
        all_coeffs = sample_spherical(npoints=num_coeffs_sample, ndim=3, rand_seed=wandb.config.rand_seed)
        if wandb.config.graph_gen != 'SC':
            dm = make_synthetic_datamodule(wandb, all_coeffs)
            dm.setup('fit')
        else:
            # make_pseudo_synthetic_datamodule already calls setup("fit"); a
            # second call here is redundant and may rebuild the data.
            dm = make_pseudo_synthetic_datamodule(wandb, all_coeffs)

        # validation
        metrics_val = run_batch(wandb, dataloader=dm.val_dataloader(), name='val')
        run.log(data=metrics_val)

        # test - do this now so don't have to re run later
        metrics_test = run_batch(wandb, dataloader=dm.test_dataloader(), name='test')
        run.log(data=metrics_test)

        if 'colab' in os.getcwd():
            wandb.finish()  # only needed on colab


if __name__ == '__main__':
    train()
