import time
import numpy as np
import torch
from tqdm import tqdm
from utils.misc import (AverageMeter, accuracy, multi_class_accuracy, de_interleave,
                        get_cosine_schedule_with_warmup, interleave)
from sklearn.preprocessing import OneHotEncoder, LabelEncoder


def test(args, test_loader, model, loss_func, device):
    """Evaluate ``model`` on ``test_loader``; return metrics and the split loss.

    Samples whose target is ``1`` are treated as positives and samples whose
    target is ``-1`` as negatives when building the inputs of ``loss_func``.
    Samples with any other target value are left out of the loss but still
    take part in the metrics.

    Args:
        args: Run configuration. Not used by this function; kept so existing
            callers keep working.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        model: Network called as ``model(inputs)`` and exposing a
            ``num_classifier`` attribute (number of output columns).
        loss_func: Callable invoked as
            ``loss_func(outputs_p, outputs_n, targets_p, targets_n)`` that
            returns ``(loss, loss_p, loss_n)``.
        device: Device the batches and the aggregated loss inputs are moved to.

    Returns:
        ``(overall_metrics, class_metrics, loss, loss_p, loss_n)``, where the
        first two come from ``multi_class_accuracy``.

    Note:
        ``model`` is switched to eval mode and left there.
    """

    def _to_numpy(tensor):
        # Move a tensor to host memory as a NumPy array.
        return tensor.detach().cpu().numpy()

    num_columns = model.num_classifier
    single = num_columns == 1

    # Per-batch results are collected in lists and stacked once after the
    # loop; re-stacking a growing array on every batch is O(n^2). The float64
    # empty seeds make the final dtypes, and the shapes for an empty loader,
    # the same as stacking onto ``np.array([])`` batch by batch.
    outputs_p_parts = [np.empty((0, num_columns))]
    outputs_n_parts = [np.empty((0, num_columns))]
    targets_p_parts = [np.array([])]
    targets_n_parts = [np.array([])]
    targets_parts = [np.array([])]
    # Single-classifier results are 1-D and hstacked onto a float64 seed;
    # multi-classifier results are 2-D and vstacked with no seed.
    probs_parts = [np.array([])] if single else []
    predicts_parts = [np.array([])] if single else []

    model.eval()

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)

            # Split the batch into positive (+1) and negative (-1) samples
            # for the positive / negative loss terms.
            pos_mask = targets == 1
            neg_mask = targets == -1
            outputs_p_parts.append(_to_numpy(outputs[pos_mask]))
            outputs_n_parts.append(_to_numpy(outputs[neg_mask]))
            targets_p_parts.append(_to_numpy(targets[pos_mask]))
            targets_n_parts.append(_to_numpy(targets[neg_mask]))

            batch_targets = _to_numpy(targets)
            targets_parts.append(batch_targets)

            probs = _to_numpy(torch.sigmoid(outputs))
            if single:
                probs = np.reshape(probs, len(batch_targets))
                # Threshold at 0.5 so every sample gets a hard {-1, 1} label.
                predicts = np.where(probs > 0.5, 1, -1)
            else:
                # NOTE: torch.sign maps an exactly-zero logit to 0.
                predicts = _to_numpy(torch.sign(outputs))
            probs_parts.append(probs)
            predicts_parts.append(predicts)

    outputs_p_all = torch.from_numpy(np.vstack(outputs_p_parts)).to(device)
    outputs_n_all = torch.from_numpy(np.vstack(outputs_n_parts)).to(device)
    targets_p_all = torch.from_numpy(np.hstack(targets_p_parts)).to(device)
    targets_n_all = torch.from_numpy(np.hstack(targets_n_parts)).to(device)
    loss, loss_p, loss_n = loss_func(outputs_p_all, outputs_n_all, targets_p_all, targets_n_all)

    targets_all = np.hstack(targets_parts)
    if single:
        probs_all = np.hstack(probs_parts)
        predicts_all = np.hstack(predicts_parts)
    else:
        probs_all = np.vstack(probs_parts) if probs_parts else np.array([])
        predicts_all = np.vstack(predicts_parts) if predicts_parts else np.array([])

    overall_metrics, class_metrics = multi_class_accuracy(probs_all, predicts_all, targets_all)
    return overall_metrics, class_metrics, loss, loss_p, loss_n


def record(args, logger, valid_loader, model, epoch):
    """Run ``model`` over ``valid_loader`` and compute metrics via ``accuracy``.

    Each batch from ``valid_loader`` is ``(inputs, (target_u, targets))``;
    only ``targets`` is used. With a single classifier the metrics are
    computed once, and acc/auc/f1 are logged. With several classifiers the
    metrics are computed per output column ``i``.

    Args:
        args: Configuration; reads ``no_progress``, ``local_rank`` and
            ``device``.
        logger: Logger used for the single-classifier summary lines; it is
            returned unchanged.
        valid_loader: Iterable of ``(inputs, (target_u, targets))`` batches.
        model: Network called as ``model(inputs)`` with a ``num_classifier``
            attribute.
        epoch: Not used by this function; kept for interface compatibility.

    Returns:
        With one classifier:
        ``(acc, auc, f1, precision, recall, erate, npp, logger)``.
        Otherwise the same layout, but each metric is a list with one entry
        per classifier.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    end = time.time()
    if not args.no_progress:
        valid_loader = tqdm(valid_loader,
                            disable=args.local_rank not in [-1, 0])

    single = model.num_classifier == 1

    # Collect per-batch arrays and stack once after the loop instead of
    # re-stacking a growing array every batch (O(n^2)). The float64 seeds keep
    # the same dtypes as stacking onto ``np.array([])``.
    targets_parts = [np.array([])]
    probs_parts = [np.array([])] if single else []
    predicts_parts = [np.array([])] if single else []

    model.eval()
    for batch_idx, (inputs, (_target_u, targets)) in enumerate(valid_loader):
        data_time.update(time.time() - end)

        inputs = inputs.to(args.device)
        targets = targets.to(args.device)

        with torch.no_grad():
            outputs = model(inputs)

        t = targets.detach().cpu().numpy()
        targets_parts.append(t)

        p = torch.sigmoid(outputs).detach().cpu().numpy()
        # NOTE: torch.sign maps an exactly-zero logit to 0.
        o = torch.sign(outputs).detach().cpu().numpy()
        if single:
            size = len(t)
            p = np.reshape(p, size)
            o = np.reshape(o, size)
        probs_parts.append(p)
        predicts_parts.append(o)

        batch_time.update(time.time() - end)
        end = time.time()
        if not args.no_progress:
            valid_loader.set_description(
                "Record Iter: {batch:4}/{iter:4}. Data: {data:.3f}s. Batch: {bt:.3f}s. "
                .format(
                    batch=batch_idx + 1,
                    iter=len(valid_loader),
                    data=data_time.avg,
                    bt=batch_time.avg,
                ))
    if not args.no_progress:
        valid_loader.close()

    targets_all = np.hstack(targets_parts)

    if single:
        probs_all = np.hstack(probs_parts)
        predicts_all = np.hstack(predicts_parts)
        (acc, auc, f1, precision, recall, erate,
         npp) = accuracy(probs_all, predicts_all, targets_all)
        logger.info("acc: {:.6f}".format(acc))
        logger.info("auc: {:.6f}".format(auc))
        logger.info("f1: {:.6f}".format(f1))
        return (acc, auc, f1, precision, recall, erate, npp, logger)

    probs_all = np.vstack(probs_parts) if probs_parts else np.array([])
    predicts_all = np.vstack(predicts_parts) if predicts_parts else np.array([])

    accs = []
    aucs = []
    f1s = []
    precisions = []
    recalls = []
    erates = []
    npps = []
    for i in range(model.num_classifier):
        # Column i holds classifier i's probabilities / predictions. The
        # original sliced ``predicts_all[: i,]`` (the first i rows), which
        # paired each classifier with the wrong predictions.
        (acc, auc, f1, precision, recall, erate,
         npp) = accuracy(probs_all[:, i], predicts_all[:, i], targets_all)
        accs.append(acc)
        aucs.append(auc)
        f1s.append(f1)
        precisions.append(precision)
        recalls.append(recall)
        erates.append(erate)
        npps.append(npp)

    return (accs, aucs, f1s, precisions, recalls, erates, npps, logger)
