import numpy as np
import torch
from utils.general import lossfun


def reject(X, X_hat, Y, abnormality, target, q="null"):
	"""Select samples of one class whose reconstruction error exceeds a quantile.

	The per-sample squared reconstruction error ``sum((X - X_hat) ** 2)`` is
	taken over every dimension after the first. It is restricted to the
	samples whose ``abnormality`` equals ``target`` and compared against its
	own ``q``-quantile.

	Args:
		X: Input tensor; dimension 0 indexes samples.
		X_hat: Reconstruction of ``X``.
		Y: Unused; kept for call-site compatibility.
		abnormality: Per-sample labels compared against ``target``
			(presumably a 1-D 0/1 tensor -- confirm against callers).
		target: Label of the samples to consider; must be 0 or 1.
		q: Quantile level in [0, 1], as a Python or NumPy int/float. Any
			other value, including the default ``"null"``, disables rejection.

	Returns:
		A dict whose ``"q_indices"`` entry is one of two things. It is the
		tuple returned by ``torch.where`` when a quantile is computed; those
		positions are relative to the selected ``target`` subset, not to the
		full batch, and mark samples whose error is strictly above the
		quantile. It is an empty list when no sample matches ``target`` or
		``q`` is not numeric.

	Raises:
		AssertionError: if ``target`` is neither 0 nor 1.
	"""
	assert target == 1 or target == 0, "target must be 0 or 1"

	# An exact type() check would reject NumPy scalars such as np.float32(0.9)
	# and silently skip rejection. bool is an int subclass but was never
	# treated as a quantile level, so it stays excluded.
	numeric_q = isinstance(q, (int, float, np.integer, np.floating)) and not isinstance(q, bool)

	indices = torch.where(abnormality == target)
	if len(indices[0]) == 0 or not numeric_q:
		return {"q_indices": []}

	# Only index positions are produced, so no autograd graph is needed.
	with torch.no_grad():
		if X_hat.dim() > 1:
			dims = tuple(range(1, X_hat.dim()))
			dist = torch.sum((X - X_hat) ** 2, dim=dims)
		else:
			# 1-D input has no non-batch dims to reduce, and an empty ``dim``
			# tuple is not a per-sample reduction. Each element is a sample.
			dist = (X - X_hat) ** 2

		mse_dist = dist[indices]
		threshold = torch.quantile(mse_dist, float(q))
		q_indices = torch.where(mse_dist > threshold)

	return {"q_indices": q_indices}




