import os, sys, wandb,  numpy as np, torch
from sklearn.covariance import graphical_lasso
# Print numpy arrays with 4 decimal places (affects console output only).
np.set_printoptions(precision=4)
from pathlib import Path
# Make the project root importable so the project-local `utils` / `baselines`
# packages imported below resolve. file.parents[2] is two directory levels
# above the directory containing this file.
file = Path(__file__).resolve()
path2project = str(file.parents[2]) + '/'
# NOTE(review): path2currDir is not referenced anywhere else in this file as shown.
path2currDir = str(Path.cwd()) + '/'
sys.path.append(path2project) # add top level directory -> geom_dl/

from utils.util_funcs import sample_spherical
from baselines.baseline_utils import zero_diagonals, make_synthetic_datamodule, make_pseudo_synthetic_datamodule, wandb_setup, find_best_performances
# Import-time side effect: configure wandb. Offline mode is requested when the
# current working directory path contains the substring 'max'.
wandb_setup(offline_mode='max' in os.getcwd())
"""
# Cov and Pr are 3D tensors. Each slice is cov/prec. Same for Pr.
def normalize_dataset(cov: np.ndarray, eps=1e-12):
    assert cov.ndim == 3
    dataset_size = cov.shape[0]

    # This will be the identity mapping if Cov is already Corr.

    # row i are the variances of Cov i
    vars = np.diagonal(cov, axis1=1, axis2=2)
    stdvs = np.sqrt(vars)

    # V is 3D tensor where each slice is a diagonal matrix
    # where V[*, i, i]     <- stdv_ii
    # where V_inv[*, i, i] <- 1/stdv_ii
    V, V_inv = np.zeros_like(cov), np.zeros_like(cov)
    for i in range(dataset_size):
        V[i] = np.diag(stdvs[i])
        V_inv[i] = np.diag(1/(stdvs[i]+eps))

    # corr = V^-1 * Cov * V^-1
    corr = np.matmul(np.matmul(V_inv, cov), V_inv)

    ## corr^-1 = (V^-1 * Cov * V^-1)^-1 = V * Pr * V
    #rescale_prec = np.matmul( np.matmul(V, prec), V)

    return corr, V


def normalize_datasets_TESTS():
    ###### normalize dataset test cases #####
    N = 68
    prng = np.random.RandomState(1)
    num_test_samples = 10
    for i in range(num_test_samples):
        prec = make_sparse_spd_matrix(N, alpha=.98,
                                      smallest_coef=.01,
                                      largest_coef=.99,
                                      random_state=prng)
        cov = linalg.inv(prec)
        V = np.diag(np.sqrt(np.diag(cov)))
        V_inv = np.diag(1/(np.sqrt(np.diag(cov)) + 1e-12))
        corr = np.matmul( np.matmul(V_inv, cov), V_inv)
        rescale_prec = np.matmul( np.matmul(V, prec), V)

        #corr_out, rescale_prec_out, V_out = normalize_dataset(cov[np.newaxis, :], prec[np.newaxis, :])
        corr_out, V_out = normalize_dataset(cov[np.newaxis, :], prec[np.newaxis, :])
        assert np.allclose(corr, corr_out) \
               and np.allclose(V, V_out) #and np.allclose(rescale_prec, rescale_prec_out)\

    #####


#entire batch must succeed (optimization cannot fail on any sample)
def run_batch(wandb, dataloader):
    batch = next(iter(dataloader))
    alpha, threshold = wandb.config.alpha, wandb.config.threshold
    x, y, _, _, _ = batch
    prec_estimate, cov_estimate = run_glasso(emp_cov=x.numpy(), alpha=alpha)
    y_hat = np.abs(prec_estimate)
    # TO DO: find best threshold by 'acc', and log it. Then run test on best val at logged threshold....reduces search space of sampling
    metrics = prediction_metrics(y=y, y_hat=torch.tensor(y_hat), threshold=threshold, hinge_margin=0, hinge_slope=1, reduction='ave')
    del metrics['hinge']
    metrics['error'] = 1 - metrics['acc']
    return metrics
# emp_cov :: 3D tensor of sample covariances
# prec :: 3D tensor of TRUE precision
# alpha : regularization parameter
# max_eigval_norm :: divide input cov/corr by max_eigval, then multiply output estimates by max_eigval
def run_glasso(emp_cov: np.ndarray, alpha: float, max_eigenvalue_norm=True):

    # GLASSO computes an estimate of the inverse of sample_cov under the assumption that the data used to compute
    #  it was generated by a gaussian: N(u, prec^-1)

    # GLASSO uses L1 regularization on the off diagonal terms of the estimated precision matrix. Because of this shared
    # alpha, a different scaling of the variables can punish some variables more than others:
    # https://stats.stackexchange.com/questions/118309/is-standardizing-data-necessary-for-glasso

    # To solve this scaling issue, scale the input (to a correlation matrix) and then rescale output (to come
    # back to an estimate of the precision matrix of the sample covariance).

    # GLASSO(Cov)  = Pr = Cov^-1
    # GLASSO(Corr) = Corr^-1 = V*Pr*V
    # Thus V*GLASSO(Corr)*V -> Pr

    # Thus, we will bring all Covs -> Corr, and Pr -> V*Pr*V for a fair comparison.
    emp_corr, V = normalize_dataset(emp_cov)

    # run glasso on each emp_corr matrix
    corr_estimates_from_corr, prec_estimates_from_corr = np.zeros_like(emp_cov), np.zeros_like(emp_cov)
    for i in range(emp_corr.shape[0]):
        print(f'Running GLASSO on {i}/{emp_corr.shape[0]}')
        if max_eigenvalue_norm:
            max_eig_corr = max(np.linalg.eigvals(emp_corr[i]).real)
        else:
            max_eig_corr = 1.0 # no scaling done
        corr_estimates_from_scaled_corr, prec_estimates_from_scaled_corr, costs_ = graphical_lasso(emp_cov=emp_corr[i]/max_eig_corr, alpha=alpha, return_costs=True)
        corr_estimates_from_corr[i] = max_eig_corr*corr_estimates_from_scaled_corr
        prec_estimates_from_corr[i] = max_eig_corr*prec_estimates_from_scaled_corr

    # to get an MSE estimate, scale back UP (stdvs on diagonal) to cov scale:
    prec_estimates_from_cov = np.matmul(np.matmul(V, prec_estimates_from_corr), V)

    # GLASSO also gives estimate of our input. Scale back to cov
    cov_estimates_from_corr = np.matmul(np.matmul(V, corr_estimates_from_corr), V)


    # we can now compare either:
    #    cov_prec_estimates  vs prec                OR
    #    corr_prec_estimates vs corr_prec_estimates

    return prec_estimates_from_cov, cov_estimates_from_corr
"""


# x :: 3D tensor of sample correlation matrices with normalized maximum eigenvalue
def glasso(x, alpha: float):
    """Run sklearn's ``graphical_lasso`` independently on every matrix in a batch.

    Args:
        x: batch of square matrices stacked along the first axis, given as a
            torch tensor or a numpy array. Each ``x[i]`` is passed to
            ``graphical_lasso`` as its ``emp_cov`` argument.
        alpha: L1 regularization strength forwarded to ``graphical_lasso``.

    Returns:
        A torch tensor shaped like ``x`` whose slice ``i`` is the precision
        matrix estimated from ``x[i]``. Floating-point tensor input keeps its
        dtype and device; integer input is promoted to float64.

    Raises:
        Whatever ``graphical_lasso`` raises for a sample; the whole batch fails.

    Progress is printed to stdout as ``0, 1, 2, ...`` followed by ``Done``.
    """
    # torch.zeros_like rejects numpy arrays, so normalise the input to a tensor
    # first; this is a no-op (no copy) when x is already a tensor.
    x = torch.as_tensor(x)
    if not torch.is_floating_point(x):
        # Precision estimates are real-valued; avoid truncating them into an int tensor.
        x = x.to(torch.float64)

    y_hat = torch.zeros_like(x)
    for i in range(len(x)):
        # graphical_lasso works on numpy arrays; detach/cpu lets inputs that
        # require grad or live on a GPU through as well.
        emp_cov = x[i].detach().cpu().numpy()
        # Returns (covariance_estimate, precision_estimate); only precision is used.
        _, prec_i = graphical_lasso(emp_cov=emp_cov, alpha=alpha)
        # Assignment copies into y_hat, casting dtype / moving device as needed.
        y_hat[i] = torch.from_numpy(prec_i)
        print(f'{i}, ', end="")
    print('\nDone')
    return y_hat


# entire batch must succeed (optimization cannot fail on any sample)
def run_batch(wandb, dataloader, threshold=None, regressions=None):
    """Estimate precision matrices with GLASSO on one batch and score them.

    Only the first batch produced by ``dataloader`` is used. GLASSO is run on
    every sample in it, so a failure on any single sample fails the batch.

    Args:
        wandb: object exposing ``config.alpha`` (the GLASSO regularization
            strength); ``train`` passes the ``wandb`` module here.
        dataloader: iterable whose batches unpack into five items, of which
            the first two are used as inputs ``x`` and targets ``y``.
        threshold: forwarded to ``find_best_performances``. ``train`` passes
            ``None`` for validation and the validation result for test.
        regressions: forwarded to ``find_best_performances``, following the
            same validation-then-reuse pattern as ``threshold``.

    Returns:
        ``(metrics, threshold, regressions)`` as produced by
        ``find_best_performances``.
    """
    x, y, _, _, _ = next(iter(dataloader))

    y_hat = glasso(x=x, alpha=wandb.config.alpha)
    # Score with the magnitude of the precision estimate. zero_diagonals is a
    # project helper; presumably it removes self-edges before comparison to y.
    y_hat = zero_diagonals(torch.abs(y_hat))

    # Threshold selection, classification metrics and OLS regressions are all
    # delegated to the shared find_best_performances helper.
    metrics, threshold, regressions = find_best_performances(y, y_hat, threshold, regressions)
    return metrics, threshold, regressions


def train():
    """Evaluate the GLASSO baseline on one validation batch and one test batch.

    Starts a wandb run with ``hyperparameter_defaults`` as its config and
    samples three coefficient vectors with ``sample_spherical`` (seeded by
    ``rand_seed``). It builds a synthetic datamodule, or a pseudo-synthetic one
    when ``graph_gen == 'SC'``, and then runs two phases:

    * validation: ``run_batch`` with no threshold/regressions. Logs the
      ``val/``-prefixed metrics, the selected threshold and the OLS coefficients.
    * test: ``run_batch`` reusing the validation threshold/regressions. Logs
      the ``test/``-prefixed metrics.

    Raises:
        RuntimeError: if either phase fails (chained to the original error).
            Before raising, sentinel values (error=1, mse=1e8,
            convergence=False) are logged for the failing phase.
    """
    hyperparameter_defaults = dict(
        graph_gen='geom',
        coeffs_index=1,
        alpha=.008,
        num_vertices=68,
        fc_norm='max_eig',
        sum_stat="sample_corr",
        num_signals=50,
        num_samples_train=0, num_samples_val=3, num_samples_test=3,
        rand_seed=50)

    with wandb.init(config=hyperparameter_defaults) as run:
        # build coefficients -> will be the same given same rand_seed
        num_coeffs_sample = 3
        all_coeffs = sample_spherical(npoints=num_coeffs_sample, ndim=3, rand_seed=wandb.config.rand_seed)
        if wandb.config.graph_gen != 'SC':
            dm = make_synthetic_datamodule(wandb, all_coeffs)
        else:
            dm = make_pseudo_synthetic_datamodule(wandb, all_coeffs)
        dm.setup('fit')

        # validation: choose threshold / regressions on the validation batch
        try:
            metrics, threshold, regressions = run_batch(wandb, dataloader=dm.val_dataloader())
            metrics = {'val/' + m: v for m, v in metrics.items()}
            run.log(data=metrics)
            run.log({'threshold': threshold,
                     'ols_coeff': regressions['ols'].coef_[0], 'ols_intercept': regressions['ols'].intercept_[0],
                     'ols-wo-intercept_coeff': regressions['ols_no_intercept'].coef_[0]})
            run.log({'val/convergence': True})
            print(f'val metrics: {metrics}')
            del metrics  # so dont reuse by accident
        except Exception as e:
            # Log sentinel values so a failed run still reports comparable (worst-case) numbers.
            run.log(data={'val/error': 1, 'val/mse': 1e8, 'val/convergence': False})
            raise RuntimeError(f'Validation failed {e}') from e

        # test: reuse the validation threshold / regressions
        try:
            metrics, _, _ = run_batch(wandb, dataloader=dm.test_dataloader(), threshold=threshold, regressions=regressions)
            metrics = {'test/' + m: v for m, v in metrics.items()}
            run.log(data=metrics)
            run.log({'test/convergence': True})
            print(f'test metrics: {metrics}')
            del metrics
        except Exception as e:
            # Mirror the validation failure path so test failures are also recorded.
            run.log(data={'test/error': 1, 'test/mse': 1e8, 'test/convergence': False})
            raise RuntimeError(f'Test Failed: {e}.') from e

        if 'colab' in os.getcwd():
            wandb.finish()  # only needed on colab


if __name__ == '__main__':
    train()