import sys, os, numpy as np, torch, wandb
from pathlib import Path
file = Path(__file__).resolve()
path2project = str(file.parents[2]) + '/'
path2currDir = str(Path.cwd()) + '/'
sys.path.append(path2project)# add top level directory -> geom_dl/
from utils.util_funcs import sparsity, sample_spherical
from baselines.baseline_utils import zero_diagonals, make_synthetic_datamodule, make_pseudo_synthetic_datamodule, wandb_setup, find_best_performances
np.set_printoptions(precision=4)

# Configure wandb once at import time, before any run is started. Offline mode is
# enabled when the current working directory path contains 'max'
# (NOTE(review): presumably a specific machine/user directory -- confirm).
wandb_setup(offline_mode='max' in os.getcwd())


def network_deconvolution(x):
    """Apply network deconvolution to a batch of symmetric matrices.

    Each matrix ``G_obs`` is eigendecomposed as ``V diag(lam) V^T`` and every
    eigenvalue is mapped ``lam -> lam / (1 + lam)``. The result is
    ``G_dir = V diag(lam / (1 + lam)) V^T``, which equals
    ``G_obs (I + G_obs)^{-1}`` for symmetric ``G_obs``.

    Args:
        x: tensor of shape (batch_size, N, N). Only the upper triangle of each
            matrix is read by the eigensolver, so the matrices are assumed to
            be symmetric.

    Returns:
        Tensor of shape (batch_size, N, N) holding the deconvolved matrices.

    Raises:
        ValueError: if ``x`` is not 3-D or its trailing two dims differ.
    """
    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    if x.ndim != 3:
        raise ValueError(f'input must be 3-D (batch, N, N), got shape {tuple(x.shape)}')
    if x.shape[1] != x.shape[2]:
        raise ValueError(f'input must be square, got shape {tuple(x.shape)}')

    # torch.symeig is deprecated/removed; torch.linalg.eigh is its replacement.
    # UPLO='U' reads the upper triangle, matching symeig's default ``upper=True``.
    vals, vecs = torch.linalg.eigh(x, UPLO='U')

    # NOTE(review): an eigenvalue of exactly -1 makes ``1 + vals`` vanish.
    # Presumably the caller normalizes eigenvalues (e.g. fc_norm='max_eig') -- confirm.
    nd_vals = vals / (1 + vals)  # network deconvolution of the spectrum

    # Reassemble V diag(nd_vals) V^T for every matrix in the batch.
    return vecs @ torch.diag_embed(nd_vals) @ vecs.transpose(-2, -1)


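# On validation, call with threshold=None and regressions=None; on test, pass in the threshold and regressions returned by the validation call.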
def run_batch(wandb, dataloader, threshold=None, regressions=None):
    """Run network deconvolution on the first batch of ``dataloader`` and report it.

    The first item yielded by ``dataloader`` must unpack into exactly five
    values. The first is deconvolved (with diagonals zeroed afterwards) to form
    the prediction; the second is the target handed to
    ``find_best_performances``.

    Args:
        wandb: object exposing ``config.sum_stat`` (only used in a printout).
        dataloader: iterable yielding 5-tuples; only its first item is used.
        threshold: forwarded to ``find_best_performances``.
        regressions: forwarded to ``find_best_performances``.

    Returns:
        ``(metrics, threshold, regressions)`` exactly as returned by
        ``find_best_performances``.
    """
    x, y, _, _, _ = next(iter(dataloader))
    pred = zero_diagonals(network_deconvolution(x))

    metrics, threshold, regressions = find_best_performances(y, pred, threshold, regressions)

    print(f'\n### Prediction Stats: using fc {wandb.config.sum_stat}, threshold {threshold:.5f} ###')

    w_max = np.nanmax(pred)
    w_median = np.nanmedian(pred)
    w_mean = np.nanmean(pred)
    print(f'\tpred edge weights: max: {w_max:.3f}, median {w_median:.3f}, mean {w_mean:.3f}')

    true_sparsity = sparsity(y).mean()
    print(f'\tave sparsity of true graphs: {true_sparsity:.3f}')

    raw_sparsity = sparsity(pred).mean()
    # Binarize at the threshold; adding 0.0 presumably casts the boolean mask to float.
    binarized_sparsity = sparsity((pred > threshold) + 0.0).mean()
    print(f'\tave sparsity of (i) raw pred: {raw_sparsity:.3f}, (ii) pred > {threshold:.7f}:  {binarized_sparsity:.3f}')

    print('### Performance ###')
    error_pct = metrics['error'] * 100
    print(f"\terrors {error_pct:.5f}%, using found best threshold: {threshold:.7f}")
    raw_mse = metrics['raw_mse']
    ols_mse = metrics['ols_mse']
    scaling_mse = metrics['ols_no_intercept_mse']
    print(f"\tmses:   raw: {raw_mse:.5f}, ols: {ols_mse:.5f}, scaling_mse: {scaling_mse:.5f}")

    return metrics, threshold, regressions


def _evaluate_split(run, banner, label, make_loader, threshold=None, regressions=None):
    """Evaluate one data split with ``run_batch`` and log its metrics to ``run``.

    Prints ``banner``, builds the dataloader by calling ``make_loader()``, and
    logs the metrics with keys prefixed by ``label + '/'``.

    Returns:
        ``(threshold, regressions)`` as returned by ``run_batch``.
    """
    print(banner)
    split_metrics, threshold, regressions = run_batch(
        wandb, dataloader=make_loader(), threshold=threshold, regressions=regressions)
    prefix = label + '/'
    logged = {prefix + key: value for key, value in split_metrics.items()}
    run.log(data=logged)
    print(f'{label} metrics: {logged}')
    return threshold, regressions


def train():
    """Evaluate network deconvolution on validation, then on test, logging to wandb.

    The validation split is run without a threshold or regressions. Those
    returned by ``run_batch`` are then reused for the test split. Metrics are
    logged under ``'val/'`` and ``'test/'`` prefixes.
    """
    defaults = {
        'graph_gen': 'SC',
        'coeffs_index': 1,
        'num_vertices': 68,
        'fc_norm': 'max_eig',
        'sum_stat': "sample_corr",
        'num_signals': 68,
        'num_samples_train': 1,
        'num_samples_val': 50,
        'num_samples_test': 50,
        'rand_seed': 50,
    }

    with wandb.init(config=defaults) as run:
        # Coefficients are reproducible for a fixed rand_seed.
        n_coeff_samples = 3
        all_coeffs = sample_spherical(npoints=n_coeff_samples, ndim=3, rand_seed=wandb.config.rand_seed)

        # 'SC' uses the pseudo-synthetic datamodule; every other generator is fully synthetic.
        build_datamodule = (make_pseudo_synthetic_datamodule
                            if wandb.config.graph_gen == 'SC'
                            else make_synthetic_datamodule)
        dm = build_datamodule(wandb, all_coeffs)
        dm.setup('fit')

        best_threshold, regressions = _evaluate_split(run, 'VALIDATION', 'val', dm.val_dataloader)
        _evaluate_split(run, 'TEST', 'test', dm.test_dataloader,
                        threshold=best_threshold, regressions=regressions)

        if 'colab' in os.getcwd():
            wandb.finish()  # only needed on colab

# Script entry point: run the validation-then-test evaluation in train().
if __name__ == "__main__":
    train()
