import torch
from adaptation import LMMDLoss
from utils.clip_util import AverageMeter
import utils.clip_util as clu
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from sklearn.metrics import balanced_accuracy_score, f1_score
from DCA_ROC import calculate_net_benefit_multiclass, net_benefit_all, net_benefit_none, roc

def toeval(model):
    """Put the wrapped backbone and the feature-attention module into eval mode.

    ``model.model`` is switched first, then ``model.fea_attn``; nothing is
    returned. Any other sub-modules of ``model`` are left untouched.
    """
    for attr_name in ("model", "fea_attn"):
        getattr(model, attr_name).eval()


def _evaluate(model, data_loader, device, discr):
    """Score every evaluable batch with ``discr`` and compute the metrics.

    Shared body of :func:`test` and :func:`test_K`; callers are responsible
    for switching the relevant modules to eval mode beforehand.

    Args:
        model: wrapper exposing ``labels`` (class texts) and ``model`` (the
            encoder used for both text and image features).
        data_loader: iterable of ``(image, text, label, _)`` batches.
        device: device the images, labels and text features are moved to.
        discr: callable ``discr(image_features, text_features)`` returning
            per-class scores; top-1 over its last dim is the prediction.
            (Presumably shaped ``(batch, num_classes)`` — confirm.)

    Returns:
        ``(accuracy, balanced_accuracy, macro_f1, net_benefits, tpr, fpr,
        net_benefits_all, net_benefits_none)`` where the last five come from
        the DCA_ROC helpers.

    Raises:
        TypeError: if ``discr`` is None (it is the only scoring head here).
        ValueError: if no batch in ``data_loader`` could be evaluated.
    """
    if discr is None:
        raise TypeError("evaluation requires a discriminator `discr`; got None")

    score_chunks, pred_chunks, label_chunks = [], [], []
    with torch.no_grad():
        # Computed under no_grad: evaluation never needs an autograd graph.
        text_features = clu.get_text_features_list(
            model.labels, model.model, device=device
        ).float()
        for image, text, label, _ in data_loader:
            # NOTE(review): single-sample batches are skipped, so one trailing
            # sample can be left out of the metrics. Kept so results stay
            # comparable with earlier runs.
            if len(text) <= 1:
                continue
            image = image.to(device)
            label = label.to(device)
            image_features = model.model.encode_image(image).float()
            similarity = discr(image_features, text_features)
            # topk(1) returns (..., 1) indices; flatten explicitly rather than
            # squeeze(), which would produce a 0-d tensor for a batch of one.
            _, indices = similarity.topk(1)
            pred = indices.view(-1)
            score_chunks.append(similarity.cpu().numpy())
            pred_chunks.append(pred.cpu().numpy())
            label_chunks.append(label.cpu().numpy())

    if not label_chunks:
        raise ValueError(
            "data_loader yielded no evaluable batches "
            "(empty, or every batch had a single sample)"
        )

    all_scores = np.concatenate(score_chunks)        # raw discriminator scores
    all_pred_labels = np.concatenate(pred_chunks)    # top-1 class indices
    all_labels = np.concatenate(label_chunks)

    accuracy = np.mean(all_pred_labels == all_labels.reshape(-1))
    net_benefits = calculate_net_benefit_multiclass(all_labels, all_scores)
    net_benefits_all = net_benefit_all(all_labels)
    net_benefits_none = net_benefit_none()
    tpr, fpr = roc(all_labels, all_scores)
    bacc = balanced_accuracy_score(all_labels, all_pred_labels)
    f1 = f1_score(all_labels, all_pred_labels, average='macro')

    return accuracy, bacc, f1, net_benefits, tpr, fpr, net_benefits_all, net_benefits_none


def test(args, model, data_loader, device, discr=None):
    """Evaluate ``model`` + ``discr`` on ``data_loader`` in full eval mode.

    Switches ``model.model``, ``model.fea_attn`` and ``discr`` to eval mode,
    then delegates to :func:`_evaluate`. ``args`` is unused and kept only for
    call-site compatibility.

    Returns:
        ``(accuracy, balanced_accuracy, macro_f1, net_benefits, tpr, fpr,
        net_benefits_all, net_benefits_none)``.

    Raises:
        TypeError: if ``discr`` is None.
        ValueError: if no batch could be evaluated.
    """
    toeval(model)
    if discr is not None:
        discr.eval()
    return _evaluate(model, data_loader, device, discr)


def test_K(args, model, data_loader, device, discr=None):
    """Variant of :func:`test` that only switches ``model.fea_attn`` to eval.

    ``model.model`` keeps whatever train/eval mode it already had; ``discr``
    is switched to eval mode. ``args`` is unused (call-site compatibility).

    Returns:
        Same tuple as :func:`test`.

    Raises:
        TypeError: if ``discr`` is None.
        ValueError: if no batch could be evaluated.
    """
    # NOTE(review): unlike test(), model.model's mode is left untouched here;
    # if it contains dropout/batch-norm in train mode results may vary —
    # confirm this is intentional.
    model.fea_attn.eval()
    if discr is not None:
        discr.eval()
    return _evaluate(model, data_loader, device, discr)