import numpy as np
import torch
from sklearn.metrics import roc_auc_score, f1_score
from tqdm import tqdm
import scipy
def test_stage_DOHSC_loader(test_loader, model, c, R, device, trained_class, temp_auc, temp_f1):
    """Evaluate a single-hypersphere (DOHSC) model on ``test_loader``.

    Each sample is embedded as ``z = model(x)`` and scored as
    ``sum((z - c) ** 2, dim=1) - R ** 2``. A non-negative score (on or
    outside the sphere) is predicted as label 1 for the F1 computation.
    The raw scores are used for ROC AUC.

    Args:
        test_loader: Iterable of ``(x, y)`` tensor batches that supports
            ``len()`` (used for the progress bar).
        model: Network mapping a batch ``x`` to embeddings ``z``. Presumably
            ``z`` is ``(batch, dim)``, since distances are reduced over
            ``dim=1``; TODO confirm.
        c: Hypersphere centre tensor; moved to ``device``.
        R: Hypersphere radius (scalar or tensor broadcastable against the
            per-sample distances).
        device: Device that batches and ``c`` are moved to.
        trained_class: Unused here; kept for interface compatibility.
        temp_auc: AUC recorded alongside the best F1 seen so far.
        temp_f1: Best F1 seen so far.

    Returns:
        ``(test_auc, temp_auc, f1, temp_f1)``: the AUC and F1 of this run,
        plus the best-F1 / paired-AUC bookkeeping. That bookkeeping is
        updated only when this run's F1 is strictly larger.

    Raises:
        ValueError: If ``test_loader`` yields no batches, or if the
            metrics cannot be computed (e.g. only one class in the labels).
    """
    model.eval()
    c = c.to(device)
    label_chunks = []
    score_chunks = []
    with torch.no_grad():
        for x, y in tqdm(test_loader, total=len(test_loader)):
            z = model(x.float().to(device))
            # Squared Euclidean distance to the centre, compared with R^2.
            dist = torch.sum((z - c) ** 2, dim=1)
            scores = dist - R ** 2
            label_chunks.append(y.detach().cpu().numpy())
            # float64 is lossless for every float dtype (including bfloat16,
            # which has no NumPy equivalent).
            score_chunks.append(scores.detach().to(torch.float64).cpu().numpy())

    if not score_chunks:
        raise ValueError("test_loader yielded no batches; cannot compute AUC/F1")

    labels = np.concatenate(label_chunks)
    scores = np.concatenate(score_chunks)

    y_pred = np.where(scores >= 0, 1, 0)
    f1 = f1_score(labels, y_pred)
    test_auc = roc_auc_score(labels, scores)

    # Keep the best F1 seen so far together with the AUC from that same run.
    if temp_f1 < f1:
        temp_f1 = f1
        temp_auc = test_auc

    return test_auc, temp_auc, f1, temp_f1


def test_stage_for_DO2HSC_loader(test_loader, model, c, R_max, R_min, device, trained_class, temp_auc, temp_f1):
    """Evaluate a two-radius (DO2HSC) model on ``test_loader``.

    Each sample is embedded as ``z = model(x)``. Its Euclidean distance
    ``d = ||z - c||`` is scored as ``(d - R_max) * (d - R_min)``. That score
    is non-positive when ``d`` lies between the two radii and positive
    outside that range. A non-negative score is predicted as label 1 for
    the F1 computation. The raw scores are used for ROC AUC.

    Args:
        test_loader: Iterable of ``(x, y)`` tensor batches that supports
            ``len()`` (used for the progress bar).
        model: Network mapping a batch ``x`` to embeddings ``z``. Presumably
            ``z`` is ``(batch, dim)``, since distances are reduced over
            ``dim=1``; TODO confirm.
        c: Centre tensor; moved to ``device``.
        R_max: One radius of the shell (scalar or broadcastable tensor).
        R_min: The other radius of the shell (scalar or broadcastable tensor).
        device: Device that batches and ``c`` are moved to.
        trained_class: Unused here; kept for interface compatibility.
        temp_auc: AUC recorded alongside the best F1 seen so far.
        temp_f1: Best F1 seen so far.

    Returns:
        ``(test_auc, temp_auc, f1, temp_f1)``: the AUC and F1 of this run,
        plus the best-F1 / paired-AUC bookkeeping. That bookkeeping is
        updated only when this run's F1 is strictly larger.

    Raises:
        ValueError: If ``test_loader`` yields no batches, or if the
            metrics cannot be computed (e.g. only one class in the labels).
    """
    model.eval()
    c = c.to(device)
    label_chunks = []
    score_chunks = []
    with torch.no_grad():
        for x, y in tqdm(test_loader, total=len(test_loader)):
            z = model(x.float().to(device))
            # Plain (non-squared) Euclidean distance to the centre.
            dist = torch.sqrt(torch.sum((z - c) ** 2, dim=1))
            scores = (dist - R_max) * (dist - R_min)
            label_chunks.append(y.detach().cpu().numpy())
            # float64 is lossless for every float dtype (including bfloat16,
            # which has no NumPy equivalent).
            score_chunks.append(scores.detach().to(torch.float64).cpu().numpy())

    if not score_chunks:
        raise ValueError("test_loader yielded no batches; cannot compute AUC/F1")

    labels = np.concatenate(label_chunks)
    scores = np.concatenate(score_chunks)

    y_pred = np.where(scores >= 0, 1, 0)
    f1 = f1_score(labels, y_pred)
    test_auc = roc_auc_score(labels, scores)

    # Keep the best F1 seen so far together with the AUC from that same run.
    if temp_f1 < f1:
        temp_f1 = f1
        temp_auc = test_auc

    return test_auc, temp_auc, f1, temp_f1
