import numpy as np
import torch
from sklearn.metrics import roc_auc_score


def test_stage(z, y, c, R, device, trained_class, temp_auc):
    """Score one evaluation set against a hypersphere and compute ROC AUC.

    Each sample's anomaly score is its squared Euclidean distance to the
    center ``c`` minus ``R ** 2``. Positive scores lie outside the sphere
    and negative scores lie inside it. Samples whose label equals
    ``trained_class`` are treated as normal (0); every other label is
    treated as anomalous (1).

    Args:
        z: Embeddings; the distance is summed over ``dim=1``
            (presumably shape ``(n_samples, n_features)`` -- TODO confirm).
        y: Ground-truth labels; any object with ``.tolist()`` (tensor or
            ndarray), one entry per row of ``z``.
        c: Hypersphere center, broadcastable against ``z``.
        R: Hypersphere radius (Python number or tensor).
        device: Device that ``z`` and ``c`` are moved to before scoring.
        trained_class: Label value regarded as "normal".
        temp_auc: Best AUC observed so far by the caller.

    Returns:
        Tuple ``(test_auc, best_auc)``: the AUC of this evaluation, and
        ``max(temp_auc, test_auc)``.

    Raises:
        ValueError: From ``roc_auc_score`` when ``y`` contains only one of
            the two classes (normal / anomalous).
    """
    z = z.to(device)
    c = c.to(device)

    # Pure evaluation: no autograd graph is needed, and detaching is what
    # allows the NumPy conversion below even if z / c / R require grad.
    with torch.no_grad():
        dist = torch.sum((z - c) ** 2, dim=1)
        scores = (dist - R ** 2).cpu().numpy()

    labels = np.where(np.asarray(y.tolist()) == trained_class, 0, 1)
    test_auc = roc_auc_score(labels, scores)

    # Same semantics as "replace only if strictly better" (NaN-safe too).
    best_auc = max(temp_auc, test_auc)
    return test_auc, best_auc


def test_stage_for_improved(z, y, c, R_max, R_min, device, trained_class, temp_auc):
    """Score one evaluation set against a two-radius shell and compute ROC AUC.

    With ``d`` the Euclidean distance of a sample to the center ``c``, the
    anomaly score is ``(d - R_max) * (d - R_min)``. The score is <= 0 when
    ``d`` lies between the two radii (inside the shell) and > 0 outside it.
    Samples whose label equals ``trained_class`` are treated as normal (0);
    every other label is treated as anomalous (1).

    Args:
        z: Embeddings; the distance is computed over ``dim=1``
            (presumably shape ``(n_samples, n_features)`` -- TODO confirm).
        y: Ground-truth labels; any object with ``.tolist()`` (tensor or
            ndarray), one entry per row of ``z``.
        c: Shell center, broadcastable against ``z``.
        R_max: Outer radius (Python number or tensor).
        R_min: Inner radius (Python number or tensor).
        device: Device that ``z`` and ``c`` are moved to before scoring.
        trained_class: Label value regarded as "normal".
        temp_auc: Best AUC observed so far by the caller.

    Returns:
        Tuple ``(test_auc, best_auc)``: the AUC of this evaluation, and
        ``max(temp_auc, test_auc)``.

    Raises:
        ValueError: From ``roc_auc_score`` when ``y`` contains only one of
            the two classes (normal / anomalous).
    """
    z = z.to(device)
    c = c.to(device)

    # Evaluate without autograd. Previously ``scores.cpu().numpy()`` raised
    # RuntimeError whenever z, c, R_max or R_min required grad.
    with torch.no_grad():
        dist = torch.sqrt(torch.sum((z - c) ** 2, dim=1))
        scores = ((dist - R_max) * (dist - R_min)).cpu().numpy()

    labels = np.where(np.asarray(y.tolist()) == trained_class, 0, 1)
    test_auc = roc_auc_score(labels, scores)

    # Same semantics as "replace only if strictly better" (NaN-safe too).
    best_auc = max(temp_auc, test_auc)
    return test_auc, best_auc