# Standard-library and third-party dependencies.
import os, sys, wandb,  numpy as np, torch
from sklearn.linear_model import LinearRegression
# Print numpy floats with 4 digits of precision.
np.set_printoptions(precision=4)
from pathlib import Path
# Absolute, symlink-resolved path of this file.
file = Path(__file__).resolve()
# Directory two levels above this file's own directory, with a trailing '/'.
path2project = str(file.parents[2]) + '/'
# Current working directory at import time, with a trailing '/'.
path2currDir = str(Path.cwd()) + '/'
sys.path.append(path2project) # add top level directory -> geom_dl/
# NOTE(review): the project imports below are presumably resolved through the
# sys.path entry appended above, so they must stay after it -- confirm.
from utils.util_funcs import graph_gen_info
from train.train_funcs import which_dm
from model.model_utils import prediction_metrics, best_threshold_by_metric


def wandb_setup(offline_mode=True):
    """Set the wandb-related environment variables used by this project.

    ``WANDB_AGENT_MAX_INITIAL_FAILURES`` is set to ``"100"`` and
    ``WANDB_AGENT_DISABLE_FLAPPING`` to ``"true"`` unconditionally. When
    ``offline_mode`` is truthy, ``WANDB_MODE`` is additionally set to
    ``"offline"``; otherwise any existing ``WANDB_MODE`` value is left as is.
    """
    env_overrides = {
        "WANDB_AGENT_MAX_INITIAL_FAILURES": "100",
        "WANDB_AGENT_DISABLE_FLAPPING": "true",
    }
    if offline_mode:
        env_overrides["WANDB_MODE"] = "offline"
    os.environ.update(env_overrides)


def coeffs_str_builder(coeffs):
    """Join coefficients into one underscore-separated string.

    Each element of ``coeffs`` is rounded to 3 decimal places with ``round``
    and converted with ``str``; the pieces are joined with ``'_'``.

    Args:
        coeffs: iterable of values accepted by ``round(value, 3)``.

    Returns:
        The joined string, e.g. ``"0.5_1.235"``; ``""`` for an empty iterable.
    """
    # str.join avoids repeated string concatenation and the trailing-separator
    # slice the manual loop needed.
    return '_'.join(str(round(f, 3)) for f in coeffs)


def make_synthetic_datamodule(wandb, all_coeffs):
    """Construct the datamodule for the "synthetics" experiment.

    Args:
        wandb: object exposing a ``config`` attribute with the run settings
            read below (note: this parameter shadows the module-level
            ``wandb`` import inside the function).
        all_coeffs: 2-D indexable; column ``config.coeffs_index`` is passed
            on as the ``coeffs`` argument.

    Returns:
        The object produced by ``which_dm("synthetics")(**kwargs)``. No
        ``setup`` call is made here.
    """
    which_exp = "synthetics"
    cfg = wandb.config
    # graph_gen_info returns three values; the middle one is not used here.
    r, _prior_construction, sparsity_range = graph_gen_info(cfg.graph_gen)
    dm_args = dict(
        graph_gen=cfg.graph_gen,
        r=r,
        sparse_thresh_low=sparsity_range[0],
        sparse_thresh_high=sparsity_range[1],
        coeffs=all_coeffs[:, cfg.coeffs_index],
        num_samples_train=cfg.num_samples_train,
        num_samples_val=cfg.num_samples_val,
        num_samples_test=cfg.num_samples_test,
        num_train_workers=0,
        num_val_workers=4,
        num_test_workers=4,
        batch_size=10,  # no train set, doesnt matter
        # must do entire validation batch to choose threshold
        val_batch_size=cfg.num_samples_val,
        test_batch_size=cfg.num_samples_test,
        rand_seed=cfg.rand_seed,
        sum_stat=cfg.sum_stat,
        fc_norm=cfg.fc_norm,
        fc_norm_val='symeig',
        binarize_labels_for_train=False,
        num_signals=cfg.num_signals,
    )
    datamodule_cls = which_dm(which_exp)
    return datamodule_cls(**dm_args)


def make_pseudo_synthetic_datamodule(wandb, all_coeffs):
    """Construct and set up the datamodule for the "pseudo-synthetics" experiment.

    Args:
        wandb: object exposing a ``config`` attribute with the run settings
            read below (note: this parameter shadows the module-level
            ``wandb`` import inside the function).
        all_coeffs: 2-D indexable; column ``config.coeffs_index`` is passed
            on as the ``coeffs`` argument.

    Returns:
        The object produced by ``which_dm('pseudo-synthetics')(**kwargs)``,
        after ``setup("fit")`` has been called on it.
    """
    which_exp = 'pseudo-synthetics'
    cfg = wandb.config
    # Worker processes are disabled when the working directory path contains "max".
    worker_count = 0 if "max" in os.getcwd() else 4
    # fails on 0
    num_patients_test = cfg.num_samples_test if cfg.num_samples_test > 0 else 1
    dm_args = dict(
        coeffs=all_coeffs[:, cfg.coeffs_index],
        num_patients_val=cfg.num_samples_val,
        num_patients_test=num_patients_test,
        num_train_workers=worker_count,
        num_val_workers=worker_count,
        num_test_workers=worker_count,
        batch_size=cfg.num_samples_val,  # doesnt matter, not using train set
        # must do entire validation batch to choose threshold
        val_batch_size=cfg.num_samples_val,
        # NOTE(review): unlike num_patients_test, this is not bumped to 1 when
        # num_samples_test is 0 -- confirm the datamodule tolerates that.
        test_batch_size=cfg.num_samples_test,
        rand_seed=cfg.rand_seed,
        sum_stat=cfg.sum_stat,
        fc_norm=cfg.fc_norm,
        fc_norm_val='symeig',
        binarize_labels_for_train=False,
        num_signals=cfg.num_signals,
    )
    datamodule = which_dm(which_exp)(**dm_args)
    datamodule.setup("fit")
    return datamodule


# zero the diagonal of every slice of a batch of square matrices
def zero_diagonals(a):
    """Return a copy of ``a`` with the diagonal of each slice multiplied by 0.

    Args:
        a: torch tensor with 3 dimensions whose last two dimensions are equal
            (checked by the assertion below).

    Returns:
        ``a`` multiplied elementwise by a mask that is 0 on the diagonal and
        1 elsewhere. ``a`` itself is not modified. Because the diagonal is
        multiplied by 0 rather than overwritten, non-finite diagonal entries
        (inf/nan) come out as nan.
    """
    assert a.ndim == 3 and a.shape[-1] == a.shape[-2]
    N = a.shape[-1]
    # Build the mask on the same device as `a` so the product also works for
    # non-CPU tensors. The dtype is left at torch's default so that type
    # promotion of the result is unchanged.
    off_diagonal = 1.0 - torch.eye(N, device=a.device)
    # The (N, N) mask broadcasts over the leading dimension of `a`.
    return a * off_diagonal


# flatten both inputs to column vectors and fit least-squares maps y_hat -> y
def min_mse_regressions(y, y_hat):
    """Fit two linear regressions that predict ``y`` from ``y_hat``.

    Both tensors are flattened to shape ``(-1, 1)`` and converted to numpy
    before fitting, so every element is treated as one sample.

    Args:
        y: torch tensor of targets (must support ``.view`` and ``.numpy``).
        y_hat: torch tensor of predictions with the same number of elements.

    Returns:
        ``(ols, ols_no_intercept)``: fitted ``LinearRegression`` models, the
        first with an intercept and the second with ``fit_intercept=False``.
    """
    targets = y.view(-1, 1).numpy()
    features = y_hat.view(-1, 1).numpy()
    with_intercept = LinearRegression()
    with_intercept.fit(features, targets)
    without_intercept = LinearRegression(fit_intercept=False)
    without_intercept.fit(features, targets)
    return with_intercept, without_intercept


def find_mse(y, y_hat, regressions):
    """Compute raw and regression-adjusted mean squared errors.

    For each variant the squared error is reshaped back to the shape of
    ``y``, averaged over axes 1 and 2, and then averaged over the remaining
    values.

    Args:
        y: torch tensor of targets; the reductions over axes (1, 2) require
            at least 3 dimensions.
        y_hat: torch tensor of predictions with the same number of elements.
        regressions: mapping with keys ``'ols'`` and ``'ols_no_intercept'``
            whose values expose ``predict`` on a ``(-1, 1)`` numpy array.

    Returns:
        ``(raw_mse, ols_mse, ols_no_intercept_mse)``.
    """
    assert torch.is_tensor(y) and torch.is_tensor(y_hat)
    sample_shape = y.numpy().shape
    target_col = y.view(-1, 1).numpy()
    pred_col = y_hat.view(-1, 1).numpy()
    with_intercept = regressions['ols']
    without_intercept = regressions['ols_no_intercept']

    def _mean_sample_mse(estimate):
        # squared error per element -> mean per sample -> mean over samples
        squared_error = (estimate - target_col) ** 2
        return squared_error.reshape(sample_shape).mean(axis=(1, 2)).mean()

    raw_mse = _mean_sample_mse(pred_col)
    ols_mse = _mean_sample_mse(with_intercept.predict(pred_col))
    ols_no_intercept_mse = _mean_sample_mse(without_intercept.predict(pred_col))
    return raw_mse, ols_mse, ols_no_intercept_mse


def find_best_performances(y, y_hat, threshold=None, regressions=None, num_threshold_points=20):
    """Compute thresholded prediction metrics and MSE variants for ``y_hat`` vs ``y``.

    Args:
        y: torch tensor of targets (``find_mse`` asserts this).
        y_hat: torch tensor of predictions (``find_mse`` asserts this).
        threshold: threshold forwarded to ``prediction_metrics``. When None,
            one is chosen by ``best_threshold_by_metric`` from
            ``num_threshold_points`` evenly spaced candidates between 0 and
            ``y_hat.max()``.
        regressions: mapping with keys ``'ols'`` and ``'ols_no_intercept'``.
            When None, both models are fit here with ``min_mse_regressions``.
        num_threshold_points: number of candidate thresholds to try.

    Returns:
        ``(metrics, threshold, regressions)``. ``metrics`` is the dict from
        ``prediction_metrics`` with ``'mse'``, ``'mae'`` and ``'hinge'``
        removed and ``'error'`` (``1 - acc``), ``'raw_mse'``, ``'ols_mse'``
        and ``'ols_no_intercept_mse'`` added. The returned threshold and
        regressions are the ones actually used, so they can be reused on
        another data split.
    """
    # The helpers receive private copies, as before. ``.detach().clone()`` is
    # the documented equivalent of ``torch.tensor(t)`` for an existing tensor
    # and avoids the copy-construct UserWarning that call emits.
    # if threshold is None, find the best by discretizing possible thresholds
    if threshold is None:
        candidate_thresholds = np.linspace(0, y_hat.max(), num=num_threshold_points)
        threshold = best_threshold_by_metric(candidate_thresholds, y.detach().clone(), y_hat.detach().clone())
    metrics = prediction_metrics(y=y, y_hat=y_hat.detach().clone(), threshold=threshold,
                                 hinge_margin=0, hinge_slope=1, reduction='ave')
    metrics['error'] = 1 - metrics['acc']
    # These entries are replaced by the MSE variants computed below.
    del metrics['mse'], metrics['mae'], metrics['hinge']

    if regressions is None:
        ols, ols_no_intercept = min_mse_regressions(y, y_hat)
        regressions = {'ols': ols, 'ols_no_intercept': ols_no_intercept}

    raw_mse, ols_mse, ols_no_intercept_mse = find_mse(y=y, y_hat=y_hat, regressions=regressions)
    metrics['raw_mse'], metrics['ols_mse'], metrics['ols_no_intercept_mse'] = raw_mse, ols_mse, ols_no_intercept_mse

    return metrics, threshold, regressions
