import torch
import torch.nn.functional as F


def val_gzsl(test_X, test_label, target_classes, in_package, bias=0):
    """
        广义零样本学习(GZSL)评估函数

        在GZSL设置中，模型需要同时识别已见类和未见类，这更符合实际应用场景。
        与传统ZSL不同，GZSL允许测试时出现训练时见过的类别。

        Args:
            test_X: 测试特征，形状为(N, feature_dim)，通常是ResNet101提取的2048维特征
            test_label: 真实标签，形状为(N,)，包含类别ID
            target_classes: 目标类别索引，可以是已见类或未见类的索引
            in_package: 包含模型、设备、批次大小等信息的字典
            bias: 偏置校准参数，用于调整已见类/未见类的预测偏好

        Returns:
            acc: 在目标类别上的平均准确率
    """
    batch_size = in_package['batch_size']
    model = in_package['model']
    device = in_package['device']
    with torch.no_grad():
        start = 0
        ntest = test_X.size()[0]  # 测试样本总数
        predicted_label = torch.LongTensor(test_label.size())  # 预测标签容器
        # 批次处理测试数据
        for i in range(0, ntest, batch_size):
            end = min(ntest, start + batch_size)
            # 获取当前批次数据
            input = test_X[start:end].to(device)

            # 前向传播：视觉特征 → Transformer → 语义嵌入 → 分类得分
            out_package = model(input)

            output = out_package['S_pp']  # 模型输出的分类得分，形状为(batch, num_classes)

            # 偏置校准：调整目标类别的得分
            # 正偏置增强目标类别，负偏置抑制目标类别
            output[:, target_classes] = output[:, target_classes] + bias

            # 预测类别：选择得分最高的类别
            predicted_label[start:end] = torch.argmax(output.data, 1)

            start = end

        # 计算在目标类别上的每类平均准确率
        acc = compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package)
        return acc


def map_label(label, classes):
    """
        标签映射函数：将原始类别ID映射到连续的索引

        在零样本学习中，未见类的标签可能不连续（如CUB中的类别ID），
        需要映射为0, 1, 2, ...的连续索引以便计算准确率。

        Args:
            label: 原始标签张量，包含原始类别ID
            classes: 类别索引张量，定义了映射关系

        Returns:
            mapped_label: 映射后的标签，-1表示不在classes中的类别

        Example:
            原始label=[5, 12, 8, 5], classes=[5, 8, 12]
            映射后=[0, 2, 1, 0]
    """
    mapped_label = torch.LongTensor(label.size()).fill_(-1)
    for i in range(classes.size(0)):
        # 将classes[i]类别的所有样本标签映射为索引i
        mapped_label[label == classes[i]] = i

    return mapped_label


# 评估未知类
def val_zs_gzsl(test_X, test_label, unseen_classes, in_package, bias=0):
    """
        零样本学习(ZSL)和广义零样本学习(GZSL)联合评估函数

        该函数同时计算ZSL和GZSL性能：
        - ZSL：只在未见类中进行分类
        - GZSL：在所有类（已见+未见）中进行分类

        Args:
            test_X: 未见类测试特征
            test_label: 未见类真实标签
            unseen_classes: 未见类别索引
            in_package: 模型信息包
            bias: 偏置校准参数，通常对未见类使用正偏置

        Returns:
            acc_gzsl: 广义零样本学习准确率（在所有类中分类）
            acc_zs_t: 纯零样本学习准确率（只在未见类中分类）
    """

    batch_size = in_package['batch_size']
    model = in_package['model']
    device = in_package['device']
    with torch.no_grad():
        start = 0
        ntest = test_X.size()[0]
        # 三种预测结果容器
        predicted_label_gzsl = torch.LongTensor(test_label.size())  # GZSL预测
        predicted_label_zsl = torch.LongTensor(test_label.size())  # ZSL预测（技巧实现）
        predicted_label_zsl_t = torch.LongTensor(test_label.size())  # ZSL预测（真实实现）
        for i in range(0, ntest, batch_size):
            end = min(ntest, start + batch_size)

            input = test_X[start:end].to(device)

            out_package = model(input)
            output = out_package['S_pp']  # 原始分类得分

            # === ZSL评估技巧实现 ===
            # 通过给未见类加上很大的正偏置，强制模型只能选择未见类
            output_t = output.clone()
            output_t[:, unseen_classes] = output_t[:, unseen_classes] + torch.max(output) + 1
            predicted_label_zsl[start:end] = torch.argmax(output_t.data, 1)

            # === ZSL评估真实实现 ===
            # 只考虑未见类的得分，直接在未见类中选择最大值
            predicted_label_zsl_t[start:end] = torch.argmax(output.data[:, unseen_classes], 1)

            # === GZSL评估 ===
            # 对未见类应用偏置校准，然后在所有类中选择
            output[:, unseen_classes] = output[:, unseen_classes] + bias
            predicted_label_gzsl[start:end] = torch.argmax(output.data, 1)

            start = end

        # 计算三种评估方式的准确率
        acc_gzsl = compute_per_class_acc_gzsl(test_label, predicted_label_gzsl, unseen_classes, in_package)
        acc_zs = compute_per_class_acc_gzsl(test_label, predicted_label_zsl, unseen_classes, in_package)
        acc_zs_t = compute_per_class_acc(map_label(test_label, unseen_classes), predicted_label_zsl_t,
                                         unseen_classes.size(0))

        return acc_gzsl, acc_zs_t


def compute_per_class_acc(test_label, predicted_label, nclass):
    """
        计算每个类别的准确率并返回平均值

        这是标准的分类准确率计算，适用于连续索引的类别标签。

        Args:
            test_label: 真实标签（连续索引0,1,2,...）
            predicted_label: 预测标签
            nclass: 类别总数

        Returns:
            平均每类准确率
    """
    acc_per_class = torch.FloatTensor(nclass).fill_(0)
    for i in range(nclass):
        idx = (test_label == i)  # 找到属于类别i的所有样本
        # 计算类别i的准确率：正确预测数 / 该类别总样本数
        acc_per_class[i] = torch.sum(test_label[idx] == predicted_label[idx]).float() / torch.sum(idx).float()
    return acc_per_class.mean().item()


# 计算每个类别的准确率
def compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package):
    """
        计算广义零样本学习中目标类别的每类准确率

        与标准准确率计算不同，这里的类别标签可能不连续，
        需要根据target_classes来确定要评估的类别。

        Args:
            test_label: 真实标签（原始类别ID）
            predicted_label: 预测标签（原始类别ID）
            target_classes: 目标类别索引（如未见类的索引）
            in_package: 包含设备信息等

        Returns:
            目标类别上的平均每类准确率
    """
    device = in_package['device']
    per_class_accuracies = torch.zeros(target_classes.size()[0]).float().to(device).detach()

    predicted_label = predicted_label.to(device)

    for i in range(target_classes.size()[0]):
        # 找到属于target_classes[i]类别的所有样本
        is_class = test_label == target_classes[i]
        # 计算准确率
        per_class_accuracies[i] = torch.div((predicted_label[is_class] == test_label[is_class]).sum().float(),
                                            is_class.sum().float())
    # 返回平均准确率
    return per_class_accuracies.mean().item()


# 广义零样本学习(GZSL)的评估函数
def eval_zs_gzsl(dataloader, model, device, bias_seen=0, bias_unseen=0, batch_size=50):
    """
        零样本学习和广义零样本学习的主要评估函数

        这是TransZero模型评估的核心函数，同时计算：
        1. 已见类在其测试集上的准确率
        2. 未见类在其测试集上的准确率（ZSL和GZSL）
        3. 调和平均数H，用于综合评估GZSL性能

        Args:
            dataloader: 数据加载器，包含训练和测试数据
            model: TransZero模型
            device: 计算设备
            bias_seen: 已见类偏置校准参数（通常为0或负值）
            bias_unseen: 未见类偏置校准参数（通常为正值）
            batch_size: 评估时的批次大小

        Returns:
            acc_seen: 已见类准确率
            acc_novel: 未见类准确率（GZSL设置）
            H: 调和平均数 = 2 * acc_seen * acc_novel / (acc_seen + acc_novel)
            acc_zs: 未见类准确率（纯ZSL设置）
    """
    model.eval()

    # 从数据加载器获取测试数据
    test_seen_feature = dataloader.data['test_seen']['resnet_features']  # 已知类的测试特征
    test_seen_label = dataloader.data['test_seen']['labels'].to(device)  # 已知类的测试标签

    test_unseen_feature = dataloader.data['test_unseen']['resnet_features']  # 未知类的测试特征
    test_unseen_label = dataloader.data['test_unseen']['labels'].to(device)  # 未知类的测试标签

    # 获取已见类和未见类的索引
    seenclasses = dataloader.seenclasses
    unseenclasses = dataloader.unseenclasses

    batch_size = batch_size

    # 构建评估所需的信息包
    in_package = {'model': model, 'device': device, 'batch_size': batch_size}

    with torch.no_grad():
        # 评估已见类性能：在已见类测试集上，使用已见类作为候选
        acc_seen = val_gzsl(test_seen_feature, test_seen_label, seenclasses, in_package, bias=bias_seen)
        # 评估未见类性能：在未见类测试集上，分别计算GZSL和ZSL性能
        acc_novel, acc_zs = val_zs_gzsl(test_unseen_feature, test_unseen_label, unseenclasses, in_package,
                                        bias=bias_unseen)

    # 计算调和平均数
    if (acc_seen + acc_novel) > 0:
        H = (2 * acc_seen * acc_novel) / (acc_seen + acc_novel)
    else:
        H = 0

    return acc_seen, acc_novel, H, acc_zs


def val_gzsl_k(k, test_X, test_label, target_classes, in_package, bias=0, is_detect=False):
    """
        Top-K广义零样本学习评估函数

        与Top-1不同，Top-K允许模型预测K个最可能的类别，
        只要真实类别在这K个预测中即算正确。

        Args:
            k: Top-K中的K值
            test_X: 测试特征
            test_label: 真实标签
            target_classes: 目标类别
            in_package: 模型信息包（需包含num_class）
            bias: 偏置校准参数
            is_detect: 是否进行新颖性检测（检测已见vs未见）

        Returns:
            acc: Top-K准确率
    """
    batch_size = in_package['batch_size']
    model = in_package['model']
    device = in_package['device']
    n_classes = in_package["num_class"]

    with torch.no_grad():
        start = 0
        ntest = test_X.size()[0]

        # 将标签转换为one-hot编码，便于Top-K评估
        test_label = F.one_hot(test_label, num_classes=n_classes)
        predicted_label = torch.LongTensor(test_label.size()).fill_(0).to(test_label.device)
        for i in range(0, ntest, batch_size):

            end = min(ntest, start + batch_size)

            input = test_X[start:end].to(device)

            out_package = model(input)

            output = out_package['S_pp']

            # 应用偏置校准
            output[:, target_classes] = output[:, target_classes] + bias
            # 获取Top-K预测
            _, idx_k = torch.topk(output, k, dim=1)
            if is_detect:
                # 新颖性检测模式：判断样本是已见类还是未见类
                assert k == 1  # 新颖性检测只支持Top-1
                detection_mask = in_package["detection_mask"]
                predicted_label[start:end] = detection_mask[torch.argmax(output.data, 1)]
            else:
                # 标准Top-K预测：在Top-K位置设置为1
                predicted_label[start:end] = predicted_label[start:end].scatter_(1, idx_k, 1)
            start = end

        acc = compute_per_class_acc_gzsl_k(test_label, predicted_label, target_classes, in_package)
        return acc


def val_zs_gzsl_k(k, test_X, test_label, unseen_classes, in_package, bias=0, is_detect=False):
    """
        Top-K零样本学习和广义零样本学习评估函数

        Args:
            k: Top-K中的K值
            其他参数与val_zs_gzsl类似

        Returns:
            acc_gzsl: Top-K GZSL准确率
            -1: ZSL准确率（在Top-K版本中不计算）
    """
    batch_size = in_package['batch_size']
    model = in_package['model']
    device = in_package['device']
    n_classes = in_package["num_class"]
    with torch.no_grad():
        start = 0
        ntest = test_X.size()[0]

        # GZSL评估：在所有类别中进行Top-K预测
        test_label_gzsl = F.one_hot(test_label, num_classes=n_classes)
        predicted_label_gzsl = torch.LongTensor(test_label_gzsl.size()).fill_(0).to(test_label.device)

        # ZSL评估相关（在Top-K版本中暂不实现）
        predicted_label_zsl = torch.LongTensor(test_label.size())
        predicted_label_zsl_t = torch.LongTensor(test_label.size())
        for i in range(0, ntest, batch_size):

            end = min(ntest, start + batch_size)

            input = test_X[start:end].to(device)

            # ZSL预测（技巧实现）
            out_package = model(input)
            output = out_package['S_pp']
            # GZSL Top-K预测
            output_t = output.clone()
            output_t[:, unseen_classes] = output_t[:, unseen_classes] + torch.max(output) + 1
            predicted_label_zsl[start:end] = torch.argmax(output_t.data, 1)
            predicted_label_zsl_t[start:end] = torch.argmax(output.data[:, unseen_classes], 1)

            output[:, unseen_classes] = output[:, unseen_classes] + bias
            _, idx_k = torch.topk(output, k, dim=1)
            if is_detect:
                assert k == 1
                detection_mask = in_package["detection_mask"]
                predicted_label_gzsl[start:end] = detection_mask[torch.argmax(output.data, 1)]
            else:
                predicted_label_gzsl[start:end] = predicted_label_gzsl[start:end].scatter_(1, idx_k, 1)

            start = end

        acc_gzsl = compute_per_class_acc_gzsl_k(test_label_gzsl, predicted_label_gzsl, unseen_classes, in_package)
        #print('acc_zs: {} acc_zs_t: {}'.format(acc_zs,acc_zs_t))
        return acc_gzsl, -1


def compute_per_class_acc_k(test_label, predicted_label, nclass):
    """
        计算Top-K每类准确率（标准版本）

        适用于连续索引的类别标签的Top-K准确率计算。
    """
    acc_per_class = torch.FloatTensor(nclass).fill_(0)
    for i in range(nclass):
        idx = (test_label == i)
        acc_per_class[i] = torch.sum(test_label[idx] == predicted_label[idx]).float() / torch.sum(idx).float()
    return acc_per_class.mean().item()


def compute_per_class_acc_gzsl_k(test_label, predicted_label, target_classes, in_package):
    """
        计算Top-K广义零样本学习每类准确率

        使用one-hot编码的标签和预测进行Top-K准确率计算。

        Args:
            test_label: one-hot编码的真实标签，形状为(N, num_classes)
            predicted_label: one-hot编码的Top-K预测，形状为(N, num_classes)
            target_classes: 目标类别索引
            in_package: 包含设备信息

        Returns:
            目标类别上的平均Top-K准确率
        """
    device = in_package['device']
    per_class_accuracies = torch.zeros(target_classes.size()[0]).float().to(device).detach()

    predicted_label = predicted_label.to(device)

    hit = test_label * predicted_label
    for i in range(target_classes.size()[0]):
        target = target_classes[i]
        n_pos = torch.sum(hit[:, target])
        n_gt = torch.sum(test_label[:, target])
        per_class_accuracies[i] = torch.div(n_pos.float(), n_gt.float())
        #pdb.set_trace()
    return per_class_accuracies.mean().item()


def eval_zs_gzsl_k(k, dataloader, model, device, bias_seen, bias_unseen, is_detect=False):
    """
        Top-K零样本学习和广义零样本学习的主要评估函数

        提供Top-K评估和可选的新颖性检测功能。

        Args:
            k: Top-K中的K值
            dataloader: 数据加载器
            model: 模型
            device: 计算设备
            bias_seen: 已见类偏置
            bias_unseen: 未见类偏置
            is_detect: 是否进行新颖性检测

        Returns:
            acc_seen: 已见类Top-K准确率
            acc_novel: 未见类Top-K准确率
            H: 调和平均数
            acc_zs: ZSL准确率（Top-K版本中返回-1）
    """
    model.eval()
    print('bias_seen {} bias_unseen {}'.format(bias_seen, bias_unseen))
    test_seen_feature = dataloader.data['test_seen']['resnet_features']
    test_seen_label = dataloader.data['test_seen']['labels'].to(device)

    test_unseen_feature = dataloader.data['test_unseen']['resnet_features']
    test_unseen_label = dataloader.data['test_unseen']['labels'].to(device)

    seenclasses = dataloader.seenclasses
    unseenclasses = dataloader.unseenclasses

    batch_size = 100
    n_classes = dataloader.ntrain_class + dataloader.ntest_class
    in_package = {'model': model, 'device': device, 'batch_size': batch_size, 'num_class': n_classes}

    if is_detect:
        print("Measure novelty detection k: {}".format(k))

        detection_mask = torch.zeros((n_classes, n_classes)).long().to(dataloader.device)
        detect_label = torch.zeros(n_classes).long().to(dataloader.device)
        detect_label[seenclasses] = 1
        detection_mask[seenclasses, :] = detect_label

        detect_label = torch.zeros(n_classes).long().to(dataloader.device)
        detect_label[unseenclasses] = 1
        detection_mask[unseenclasses, :] = detect_label
        in_package["detection_mask"] = detection_mask

    with torch.no_grad():
        acc_seen = val_gzsl_k(k, test_seen_feature, test_seen_label, seenclasses, in_package, bias=bias_seen,
                              is_detect=is_detect)
        acc_novel, acc_zs = val_zs_gzsl_k(k, test_unseen_feature, test_unseen_label, unseenclasses, in_package,
                                          bias=bias_unseen, is_detect=is_detect)

    if (acc_seen + acc_novel) > 0:
        H = (2 * acc_seen * acc_novel) / (acc_seen + acc_novel)
    else:
        H = 0

    return acc_seen, acc_novel, H, acc_zs
