"""
包含实现基本测试功能代码的文件
"""

import math
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import nn
def test(model, classifier_name, device, dataiter, test_set, batch_size, recoder, log, tp=False):
    """
    Run one epoch of evaluation over the test set.

    Tensor layouts as described by the original author (NOTE(review):
    ``my_forward`` compares ``label`` against class indices via
    ``label.view(1, -1)``, so in practice ``label`` appears to hold one class
    index per sample -- confirm):
        input rcs: [batch_size, time_steps, input_channel=1]
        label:     [batch_size, class_num]
        output:    [batch_size, class_num]

    :param model: deep-learning model; switched to eval mode here
    :param classifier_name: classifier name, forwarded unchanged to ``my_forward``
    :param device: device each batch is moved to
    :param dataiter: iterator yielding ``(input_data, label, img_name, originlable)``
        tuples; it must yield at least ``ceil(len(test_set) / batch_size)`` batches,
        otherwise ``next`` raises ``StopIteration``
    :param test_set: test set; only ``len(test_set)`` is used
    :param batch_size: batch size, used only to compute the number of iterations
    :param recoder: logger object (``log_test_loss`` / ``log_test_acc``, and
        ``log_test_label`` inside ``my_forward``)
    :param log: when ``True``, per-batch loss and top-5 accuracy are logged
    :param tp: forwarded to ``my_forward``; whether to record per-sample
        ground truth / predictions
    :return: (loss, acc) -- mean loss and mean top-1 accuracy over all
        evaluated samples
    :raises ZeroDivisionError: if no samples were evaluated (empty test set)
    """
    model.eval()  # switch the model to evaluation mode
    iter_times = math.ceil(len(test_set) / batch_size)
    total_loss = 0
    total_correct1 = 0
    total_samples = 0
    with torch.no_grad():
        for _ in tqdm(range(iter_times)):
            input_data, label, _img_name, _originlable = next(dataiter)
            input_data = input_data.clone().detach().float().to(device)
            label = label.clone().detach().to(device)
            n_samples = label.size(0)
            loss, acc, _embedding, _output = my_forward(
                model, classifier_name, input_data, label, recoder, log, tp)
            acc1, acc5 = acc[0], acc[1]
            # Accumulate sums weighted by the real batch size so that a smaller
            # final batch does not skew the epoch averages.
            total_loss += loss.cpu().detach().numpy().sum()
            total_correct1 += acc1.cpu().detach().numpy() * n_samples
            total_samples += n_samples
            if log is True:
                recoder.log_test_loss(loss.cpu().detach().numpy())
                # NOTE(review): logs top-5 accuracy while the function returns
                # top-1 accuracy -- confirm this is intended.
                recoder.log_test_acc(acc5.cpu().detach().numpy())

    return total_loss / total_samples, total_correct1 / total_samples

def my_forward(model, classifier_name, input_data, label, recoder, log, tp=False, topk=(1, 5)):
    """
    Run a single forward pass and compute per-sample loss and top-k accuracy.

    :param model: deep-learning model; called as ``model(input_data)`` and must
        return an ``(output, embedding)`` pair
    :param classifier_name: classifier name (not used by this function)
    :param input_data: input batch
    :param label: ground-truth class indices, one per sample
    :param recoder: logger; only ``log_test_label`` is used, and only when both
        ``tp`` is truthy and ``log is True``
    :param log: whether logging through ``recoder`` is enabled
    :param tp: whether to record per-sample (label, prediction) pairs of this batch
    :param topk: the k values for which top-k accuracy is computed
    :return: (loss, res, embedding, output)
        loss      -- per-sample cross-entropy tensor (``reduction='none'``)
        res       -- list of accuracy tensors, one per entry of ``topk``,
                     expressed as fractions of the batch
        embedding -- the model's embedding, as a detached numpy array
        output    -- the raw model output
    """
    # reduction='none' keeps one loss value per sample; callers aggregate it.
    loss_function = nn.CrossEntropyLoss(reduction='none')
    output, embedding = model(input_data)
    # detach() so this also works when called outside torch.no_grad().
    embedding = embedding.detach().cpu().numpy()
    loss = loss_function(output, label.long())

    batch_size = label.size(0)
    # Tensor.topk raises if k exceeds the number of classes; clamp it.
    # For k >= number of classes, correct[:k] then spans every prediction row.
    maxk = min(max(topk), output.size(1))
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # [maxk, batch_size]: row j holds each sample's j-th best class
    correct = pred.eq(label.view(1, -1).expand_as(pred)).contiguous()
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(1 / batch_size))

    if tp and log is True:
        # Record ground truth and prediction of every sample in this batch.
        # Host copies are made once per batch instead of once per sample.
        predictions = torch.argmax(output, dim=1).cpu().numpy()
        truths = label.cpu().numpy()
        for truth, prediction in zip(truths, predictions):
            recoder.log_test_label(truth, prediction)
    return loss, res, embedding, output
