import numpy as np
import torch
import networkx as nx
from torch_geometric.utils import from_networkx, to_networkx
from torch_geometric.utils import subgraph 
from torch_geometric.data import Data
from model.model import *


def filter_by_distance(idx, emb, centroid, nu):
    """Split ``idx`` into points near ``centroid`` and the farthest ``nu`` fraction.

    The squared Euclidean distance from each row ``emb[idx, :]`` to
    ``centroid`` is computed. The ``1 - nu`` quantile of those distances
    is used as a cut-off.

    Args:
        idx: Row indices into ``emb`` to consider (list or 1-D integer array).
        emb: Embedding tensor, indexed as ``emb[idx, :]``.
        centroid: Tensor broadcastable against ``emb[idx, :]``.
        nu: Fraction of points expected to lie at or beyond the cut-off.

    Returns:
        A tuple ``(retain_normal_idx, filtered_idx)``:

        - ``retain_normal_idx`` is a numpy array of the entries of ``idx``
          whose distance is strictly below the cut-off.
        - ``filtered_idx`` is a list of the remaining entries of ``idx``,
          with duplicates collapsed because it comes from a set difference.

        An empty ``idx`` yields an empty integer array and an empty list.
    """
    idx_arr = np.asarray(idx)
    if idx_arr.size == 0:
        # np.quantile raises on an empty array, so there is nothing to split.
        return idx_arr.astype(np.int64), []

    # Compute the distances once, detached from autograd and moved to host memory.
    dist = torch.sum((emb[idx, :] - centroid) ** 2, dim=1).detach().cpu().numpy()
    thresh = np.quantile(dist, 1 - nu)
    # Strict "<": points tied with the cut-off end up in filtered_idx.
    normal_pos = np.where(dist < thresh)[0]

    retain_normal_idx = idx_arr.take(normal_pos)
    filtered_idx = list(set(idx) - set(retain_normal_idx))
    return retain_normal_idx, filtered_idx

def get_candidate(data, y_pred, emb, centroid, nu=0.01):
    """Split nodes into a predicted-abnormal set and a distance-filtered normal set.

    The label (0 or 1) that occurs less often in ``y_pred`` is treated as the
    abnormal class; on a tie, label 1 is used. All other nodes are passed to
    ``filter_by_distance`` against the first row of ``centroid``.

    Args:
        data: Graph object; only ``data.num_nodes`` is read.
        y_pred: Tensor of predicted 0/1 labels, presumably one per node (confirm).
        emb: Node embedding tensor forwarded to ``filter_by_distance``.
        centroid: Tensor whose first row ``centroid[0, :]`` is the reference center.
        nu: Fraction forwarded to ``filter_by_distance``. Defaults to ``0.01``,
            the previously hard-coded value.

    Returns:
        ``(ab_idx, normal_idx, pre_normal_idx, aft_normal_idx)``:

        - ``ab_idx`` is a numpy array of the abnormal-label positions in ``y_pred``.
        - ``normal_idx`` is a sorted list of the other indices in ``range(data.num_nodes)``.
        - The last two are the outputs of ``filter_by_distance`` on ``normal_idx``.
    """
    labels = y_pred.cpu()
    # Count each label directly rather than converting y_pred to a list twice.
    class_0 = int((labels == 0).sum())
    class_1 = int((labels == 1).sum())
    abnormal_label = 0 if class_0 < class_1 else 1
    ab_idx = np.where(labels == abnormal_label)[0]

    abnormal = set(ab_idx.tolist())
    # Build normal_idx in ascending order instead of relying on set iteration order.
    normal_idx = [i for i in range(data.num_nodes) if i not in abnormal]

    pre_normal_idx, aft_normal_idx = filter_by_distance(
        normal_idx, emb, centroid[0, :], nu
    )
    return ab_idx, normal_idx, pre_normal_idx, aft_normal_idx



def get_anomalous_score(anomalous_set, normal_distribution):
    """Score each element of ``anomalous_set`` by its MMD to ``normal_distribution``.

    Each element gets a leading dimension via ``unsqueeze(0)``. It is then
    compared against ``normal_distribution`` with a single ``MMD_loss``
    instance, which is imported from ``model.model``.

    Returns:
        A list with one scalar per element of ``anomalous_set``, in iteration
        order. Each scalar is obtained with ``.item()``.
    """
    mmd = MMD_loss()
    return [
        mmd(graph_emb.unsqueeze(0), normal_distribution).item()
        for graph_emb in anomalous_set
    ]


