import sys

import numpy as np

sys.path.append('..')

import torch
# import visualization_metric
from . import visualization_metric


def evaluate(net, d_loader):
    """Run ``net`` over ``d_loader`` without gradients and average per-batch metrics.

    For every batch the network embeds the inputs (``z_hat``), the loss from
    ``net.loss_fn_kl_t`` is computed against the batch's ``zdist``, and
    ``visualization_metric.get_nmi_sc`` scores the embedding against the
    labels (NMI and, presumably, a silhouette-style score ``sc``).

    The model's previous train/eval mode is restored on exit, even if a
    batch raises.

    Args:
        net: model exposing ``__call__(xs, n_view_graphs)``,
            ``loss_fn_kl_t(xs, n_view_graphs, zdist)``, a ``device``
            attribute and the usual ``training``/``eval``/``train`` API.
        d_loader: iterable of ``(xs, n_view_graphs, z, zdist, y)`` tuples.
            Only ``zdist[0]`` and ``y[0]`` are used, so a batch size of 1
            is presumably assumed — TODO confirm.

    Returns:
        ``(nmi, sc, loss)``: each a 0-d ``torch.Tensor`` holding the mean
        over batches (NaN when the loader yields nothing).
    """
    was_training = net.training
    net.eval()
    scs = []
    nmis = []
    losses = []
    try:
        with torch.no_grad():
            for xs, n_view_graphs, _z, zdist, y in d_loader:
                z_hat = net(xs, n_view_graphs)
                zdist = zdist[0].to(torch.float32).to(net.device)
                student_t_loss = net.loss_fn_kl_t(xs, n_view_graphs, zdist)

                nmi, sc = visualization_metric.get_nmi_sc(
                    z_hat.cpu().numpy(), y[0].tolist()
                )

                scs.append(float(sc))
                nmis.append(float(nmi))
                # Keep a plain float rather than the (possibly device-resident)
                # loss tensor so per-batch results don't pin device memory.
                losses.append(float(student_t_loss))
    finally:
        # Restore whatever mode the caller had, even on error.
        net.train(was_training)

    scs_t = torch.mean(torch.Tensor(scs))
    nmis_t = torch.mean(torch.Tensor(nmis))
    losses_t = torch.mean(torch.Tensor(losses))
    return nmis_t, scs_t, losses_t


def get_true_metric(d_loader):
    """Score the stored t-SNE and UMAP embeddings of each graph in ``d_loader``.

    For each ``(d_name, n_view_graphs)`` pair, the node data ``'z'``
    (t-SNE) and ``'z_umap'`` (UMAP) are scored against the labels in
    ``'y'`` with ``visualization_metric.get_nmi_sc``. The scores are printed
    and collected per dataset name.

    Returns:
        ``(nmis, scs)``: two dicts keyed by ``d_name``; each value is a dict
        ``{'tsne': score, 'umap': score}``.
    """
    nmis = {}
    scs = {}
    for d_name, n_view_graphs in d_loader:
        embeddings = {
            'tsne': n_view_graphs.ndata['z'],
            'umap': n_view_graphs.ndata['z_umap'],
        }
        labels = n_view_graphs.ndata['y']

        nmi_row = {}
        sc_row = {}
        for method, emb in embeddings.items():
            # A fresh label list per call, exactly as before.
            nmi_row[method], sc_row[method] = visualization_metric.get_nmi_sc(
                emb, labels.tolist()
            )

        print(d_name, (nmi_row['tsne'], nmi_row['umap']), (sc_row['tsne'], sc_row['umap']))
        scs[d_name] = sc_row
        nmis[d_name] = nmi_row

    return nmis, scs


def get_ground_truth_from_precomp(d_names, gt_dict):
    """Aggregate precomputed NMI and score values over ``d_names``.

    ``gt_dict[0]`` maps ``(d_name,)`` to an NMI value and ``gt_dict[1]``
    maps ``(d_name,)`` to a score value; note the 1-tuple keys.

    Returns:
        ``((nmi_mean, nmi_std), (sc_mean, sc_std))`` computed with NumPy.

    Raises:
        KeyError: if a name is missing from either table.
    """
    # Look up both tables per name, in the same interleaved order as a loop.
    pairs = [(gt_dict[0][(name,)], gt_dict[1][(name,)]) for name in d_names]
    nmi_arr = np.array([nmi for nmi, _ in pairs])
    sc_arr = np.array([sc for _, sc in pairs])
    nmi_stats = (np.mean(nmi_arr), np.std(nmi_arr))
    sc_stats = (np.mean(sc_arr), np.std(sc_arr))
    return nmi_stats, sc_stats


def filter_low_nmi(d_names, gt_dict, th=0.1):
    """Return the names from ``d_names`` whose precomputed NMI is at least ``th``.

    NMI values are read from ``gt_dict[0]`` using ``(d_name,)`` 1-tuple keys.
    The input order is preserved, and a value equal to ``th`` is kept.
    """
    nmi_table_lookup = lambda name: gt_dict[0][(name,)]  # noqa: E731
    return [name for name in d_names if nmi_table_lookup(name) >= th]
