import os, sys, wandb, numpy as np, torch
from pathlib import Path

import utils
from math import sqrt

# Put the directory two levels above the one containing this file on sys.path;
# presumably this is what lets the project-level imports below (e.g.
# ``data.network_diffusion``) resolve. Both path strings carry a trailing '/'
# (presumably because later code concatenates file names onto them).
file = Path(__file__).resolve()
path2project = f'{file.parents[2]}/'
path2currDir = f'{Path.cwd()}/'
sys.path.append(path2project)

from data.network_diffusion.diffused_signals import DiffusionDataset
from metrics import best_threshold_by_metric
from metrics import compute_metrics
import metrics
from sklearn.covariance import GraphicalLasso, GraphicalLassoCV, graphical_lasso
from utils import adj2vec


def glasso_batch(x, y, alpha, mode='lars', tol=1e-6, max_iter=100):
    """Run the graphical lasso independently on every sample of a batch.

    Each ``x[i]`` is passed to ``sklearn.covariance.graphical_lasso`` as the
    empirical covariance. The precision-matrix estimate it returns is written
    into slot ``i`` of an output tensor shaped like ``y``.

    Args:
        x: batch of empirical covariance matrices, indexed along the first
            dimension (assumed ``(batch, n, n)`` -- TODO confirm against the
            dataset).
        y: tensor whose shape and device the output copies (the targets in
            ``run_batch``).
        alpha: regularisation strength forwarded to ``graphical_lasso``.
        mode: solver forwarded to ``graphical_lasso``.
        tol: convergence tolerance forwarded to ``graphical_lasso``.
        max_iter: iteration cap forwarded to ``graphical_lasso``.

    Returns:
        Tensor with ``y``'s shape and device holding the precision estimates.
        It uses ``y``'s dtype when that is floating point, otherwise torch's
        default float dtype.
    """
    # Precision estimates are real-valued. Writing them into an integer/bool
    # tensor (e.g. if ``y`` is a 0/1 adjacency) would silently truncate them.
    out_dtype = y.dtype if y.is_floating_point() else torch.get_default_dtype()
    y_hat = torch.zeros_like(y, dtype=out_dtype)
    for i, emp_cov in enumerate(x):
        # detach()/cpu() so inputs that require grad or live on an accelerator
        # can still be handed to the NumPy-based sklearn solver.
        _, precision = graphical_lasso(
            emp_cov=emp_cov.detach().cpu().numpy(),
            alpha=alpha,
            mode=mode,
            tol=tol,
            max_iter=max_iter,
        )
        y_hat[i] = torch.as_tensor(precision, dtype=y_hat.dtype, device=y_hat.device)
    return y_hat


def _get_hparam(cfg, key):
    """Return hyper-parameter ``key`` read from ``cfg``, falling back to ``cfg.config``.

    ``run_batch`` used to read some settings as ``cfg.<key>`` and others as
    ``cfg.config.<key>``. This accepts either form, so the same object works
    for every setting.
    """
    try:
        return getattr(cfg, key)
    except (AttributeError, KeyError):
        return getattr(cfg.config, key)


def run_batch(wandb, dataloader, name, test=False):
    """Evaluate the graphical lasso on one batch and return namespaced metrics.

    Only the first batch yielded by ``dataloader`` is evaluated.

    Args:
        wandb: hyper-parameter source. ``alpha``, ``mode``, ``tol``,
            ``max_iter`` and (when ``test``) ``threshold`` are read from it
            directly or from its ``.config`` attribute. Presumably the wandb
            module/run or its config -- TODO confirm with callers.
        dataloader: iterable whose batches start with ``(x, y)``.
        name: prefix for the metric keys (``'<name>/<metric>'``).
        test: when False, a threshold is selected on this batch via
            ``best_threshold_by_metric`` over 100 evenly spaced values in
            ``[0, max|y_hat|]``. When True, the configured ``threshold`` is
            used instead.

    Returns:
        dict mapping ``'<name>/<metric>'`` to the values produced by
        ``metrics.compute_metrics_glad``.
    """
    x, y = next(iter(dataloader))[:2]
    y_hat = glasso_batch(
        x,
        y,
        alpha=_get_hparam(wandb, 'alpha'),
        mode=_get_hparam(wandb, 'mode'),
        tol=_get_hparam(wandb, 'tol'),
        max_iter=_get_hparam(wandb, 'max_iter'),
    )

    if test:
        threshold = _get_hparam(wandb, 'threshold')
    else:
        # Not testing: choose the threshold on this batch by discretising
        # the range of |y_hat|.
        y_hat_abs = y_hat.abs()
        # .item() hands linspace a Python number; older torch versions
        # reject a 0-d tensor for ``end``.
        thresholds = torch.linspace(start=0, end=y_hat_abs.max().item(), steps=100)
        threshold = best_threshold_by_metric(y_hat_abs, y.abs(), thresholds=thresholds)

    batch_metrics = metrics.compute_metrics_glad(
        y_hat=y_hat, y=y, threshold=threshold, reduction=torch.mean
    )
    # Return the prefixed metrics dict (previously the ``metrics`` module
    # itself was returned by mistake).
    return {f'{name}/{key}': value for key, value in batch_metrics.items()}


if __name__ == '__main__':
