import numpy as np
import torch
import torch.nn.functional as F
# from imagenet import get_x_y_from_data_dict
from sklearn.svm import SVC

from .utils import AverageMeter, accuracy
def validate(val_loader, model, criterion, args):
    """
    Run evaluation of ``model`` on ``val_loader`` and return its average accuracy.

    Args:
        val_loader: iterable of ``(image, target)`` batches; ``len(val_loader)``
            is also used in the progress line.
        model: network to evaluate; it is switched to eval mode.
        criterion: loss callable, invoked as ``criterion(output, target)``.
        args: options object; only ``args.print_freq`` (print progress every
            that many batches) is read.

    Returns:
        ``top1.avg`` after all batches, where each batch records
        ``accuracy(output, target)[0]`` with weight ``image.size(0)``.
    """
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # Send batches to the model's own device instead of hard-coding ``.cuda()``,
    # so evaluation also works on CPU-only machines or a non-default GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        for i, (image, target) in enumerate(val_loader):
            image = image.to(device)
            target = target.to(device)

            # compute output (loss is computed before the float cast, as before)
            output = model(image)
            loss = criterion(output, target)

            output = output.float()
            loss = loss.float()

            # measure accuracy and record loss
            prec1 = accuracy(output.data, target)[0]
            losses.update(loss.item(), image.size(0))
            top1.update(prec1.item(), image.size(0))

            if i % args.print_freq == 0:
                print(
                    "Test: [{0}/{1}]\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "Accuracy {top1.val:.3f} ({top1.avg:.3f})".format(
                        i, len(val_loader), loss=losses, top1=top1
                    )
                )
    return top1.avg


def collect_prob(data_loader, model, num_classes=10):
    """
    Run ``model`` over ``data_loader`` and collect softmax probabilities and labels.

    Args:
        data_loader: iterable of ``(data, target)`` batches, or ``None``.
        model: network to query; it is switched to eval mode. Each batch is
            moved to the device of the model's first parameter (CPU if the
            model has no parameters).
        num_classes: width of the empty probability tensor returned when there
            is nothing to evaluate (defaults to the previously hard-coded 10).

    Returns:
        ``(prob, targets)``: ``prob`` is ``softmax(model(data), dim=-1)`` for every
        batch concatenated along dim 0, and ``targets`` the matching labels. When
        ``data_loader`` is ``None`` or yields no batches, returns an empty
        ``(0, num_classes)`` float tensor and an empty int64 label tensor on CPU.
    """
    if data_loader is None:
        # int64 labels so callers can use them directly as a gather/index tensor.
        return torch.zeros([0, num_classes]), torch.zeros([0], dtype=torch.long)

    model.eval()

    # Resolve the target device once rather than once per tensor per batch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    prob = []
    targets = []
    with torch.no_grad():
        for batch in data_loader:
            data, target = (tensor.to(device) for tensor in batch)
            output = model(data)
            prob.append(F.softmax(output, dim=-1))
            targets.append(target)

    if not prob:
        # torch.cat rejects an empty list; mirror the ``None`` placeholder instead.
        return torch.zeros([0, num_classes]), torch.zeros([0], dtype=torch.long)

    return torch.cat(prob), torch.cat(targets)


def SVC_fit_predict(
    shadow_train, shadow_test, target_train, target_test, *, C=3, gamma="auto", kernel="rbf"
):
    """
    Fit an SVM membership classifier on shadow features and score it on target features.

    Rows of ``shadow_train`` are labelled 1 and rows of ``shadow_test`` 0; the
    classifier is ``SVC(C=C, gamma=gamma, kernel=kernel)``. Every input is a tensor
    whose first dimension indexes samples; remaining dimensions are flattened into
    one feature vector per sample.

    Args:
        shadow_train: features for the class-1 training examples.
        shadow_test: features for the class-0 training examples.
        target_train: features scored by the fraction predicted as 1 (may be empty).
        target_test: features scored by the fraction predicted as 0 (may be empty).
        C, gamma, kernel: forwarded to ``sklearn.svm.SVC`` (defaults unchanged).

    Returns:
        Unweighted mean of the non-empty target scores above, or ``nan`` when both
        target sets are empty.
    """
    n_shadow_train = shadow_train.shape[0]
    n_shadow_test = shadow_test.shape[0]
    n_target_train = target_train.shape[0]
    n_target_test = target_test.shape[0]

    # Move each piece to CPU before concatenating: placeholders for missing loaders
    # may live on CPU while real features live on the model's device.
    X_shadow = (
        torch.cat([shadow_train.detach().cpu(), shadow_test.detach().cpu()])
        .numpy()
        .reshape(n_shadow_train + n_shadow_test, -1)
    )
    Y_shadow = np.concatenate([np.ones(n_shadow_train), np.zeros(n_shadow_test)])

    clf = SVC(C=C, gamma=gamma, kernel=kernel)
    clf.fit(X_shadow, Y_shadow)

    accs = []

    if n_target_train > 0:
        X_target_train = target_train.detach().cpu().numpy().reshape(n_target_train, -1)
        # fraction of target_train predicted as class 1
        accs.append(clf.predict(X_target_train).mean())

    if n_target_test > 0:
        X_target_test = target_test.detach().cpu().numpy().reshape(n_target_test, -1)
        # fraction of target_test predicted as class 0
        accs.append(1 - clf.predict(X_target_test).mean())

    if not accs:
        # Same value np.mean([]) would give, without the "Mean of empty slice" warning.
        return np.nan

    return np.mean(accs)


def SVC_MIA(shadow_train, target_train, target_test, shadow_test, model):
    """
    Confidence-based SVM membership-inference attack against ``model``.

    For every sample, the feature is the softmax probability the model assigns to
    the sample's true label. An SVM is fit on the shadow features (``shadow_train``
    labelled 1, ``shadow_test`` labelled 0) and scored on the target features via
    ``SVC_fit_predict``.

    Args:
        shadow_train: loader of ``(data, target)`` batches used as class-1 examples.
        target_train: loader scored by the fraction predicted as class 1, or ``None``.
        target_test: loader scored by the fraction predicted as class 0, or ``None``.
        shadow_test: loader of ``(data, target)`` batches used as class-0 examples.
        model: network under attack.

    Returns:
        dict with keys ``"correctness"``, ``"confidence"``, ``"entropy"``,
        ``"m_entropy"`` and ``"prob"``. Only ``"confidence"`` is computed; the other
        entries are ``None`` and kept so callers indexing them keep working.
    """

    def _true_label_confidence(loader):
        # Probability assigned to the true label, shape (n, 1).
        prob, labels = collect_prob(loader, model)
        # gather requires an int64 index; labels are not guaranteed to be int64
        # (e.g. an empty placeholder for a None loader), so cast defensively.
        return torch.gather(prob, 1, labels.long()[:, None])

    # Same collection order as before: shadow_train, shadow_test, target_train, target_test.
    shadow_train_conf = _true_label_confidence(shadow_train)
    shadow_test_conf = _true_label_confidence(shadow_test)
    target_train_conf = _true_label_confidence(target_train)
    target_test_conf = _true_label_confidence(target_test)

    acc_conf = SVC_fit_predict(
        shadow_train_conf, shadow_test_conf, target_train_conf, target_test_conf
    )

    return {
        "correctness": None,
        "confidence": acc_conf,
        "entropy": None,
        "m_entropy": None,
        "prob": None,
    }