import argparse
import torch


def get_args_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument("--dataset_name", default="sun397")
    parser.add_argument("--dataset_dir", default="/home/datasets/vtab-1k/vtab-1k/")

    args = parser.parse_args()
    return args

def default_flist_reader(flist):
    """
    flist format: impath label\nimpath label\n ...(same to caffe's filelist)
    """
    imlist = []
    with open(flist, 'r') as rf:
        for line in rf.readlines():
            impath, imlabel = line.strip().split()
            imlist.append((impath, int(imlabel)))

    return imlist


from collections import Counter


def count_value_occurrences(pairs):
    """
    统计二元组列表中每个 value 出现的次数

    参数:
        pairs: list of tuples, 每个元组形式为 (name, value)

    返回:
        dict: {value: 出现次数}
    """
    values = [value for _, value in pairs]  # 提取所有value
    return Counter(values)  # 返回计数结果

def get_class(path, dataset):
    trainval_flist = path + dataset + "/train800val200.txt"
    # trainval_flist = path + dataset + "/test.txt"

    train_list = default_flist_reader(trainval_flist)

    print(count_value_occurrences(train_list))
    unique_values = set()  # 用集合存储以自动去重
    for name, value in train_list:
        unique_values.add(value)
    print(len(unique_values))


def expand_with_mask(tensor, mask):
    """
    Args:
        tensor: 二维张量 [num_rows, num_selected_cols], 仅包含 mask=1 对应的数据
        mask: 一维张量 [num_cols], 值为 0 或 1, mask.sum() = num_selected_cols
    Returns:
        output: 二维张量 [num_rows, num_cols], mask=1 的位置为 tensor 值, mask=0 的位置为 0
    """
    # 检查输入合法性
    assert tensor.dim() == 3, "tensor 必须是二维张量"
    assert mask.dim() == 1, "mask 必须是一维张量"
    assert mask.sum() == tensor.size(-1), "mask 中 1 的数量必须等于 tensor 的列数"

    # num_rows, num_selected_cols = tensor.shape
    # num_cols = mask.size(0)

    # 初始化全零张量（支持梯度）
    output = torch.zeros_like(tensor, dtype=tensor.dtype, device=tensor.device)

    # 获取 mask=1 的列索引（布尔掩码）
    mask_bool = mask.bool()  # [num_cols]
    print(mask_bool)
    # 将 tensor 的值填充到 output 的 mask=1 位置
    output[:, mask_bool] = tensor

    return output


print(torch.zeros(2,2,3))
# 示例使用
mask = torch.tensor([1, 0, 1, 0, 1])  # shape: [5], 3 个 1
tensor = torch.tensor([[
    [10, 20, 30],  # 对应 mask=1 的位置（第0、2、4列）
    [40, 50, 60]
],[ [10, 20, 30],  # 对应 mask=1 的位置（第0、2、4列）
    [40, 50, 60]]], dtype=torch.float32, requires_grad=True)  # shape: [2, 3]

output = expand_with_mask(tensor, mask)
print("Mask:\n", mask)
print("Original tensor:\n", tensor)
print("Expanded output:\n", output)

# 验证梯度传播
loss = output.sum()  # 假设损失函数是求和
loss.backward()
print("Gradient of tensor:\n", tensor.grad)  # 应为全1（因为梯度是1的广播）

# if __name__ == '__main__':
#     # args = get_args_parser()
    # get_class(args.dataset_dir, args.dataset_name)

