import torch

# 初始化参数
batch = 32
seq_len = 256  # 避免使用len作为变量名，它是Python内置函数

# 转移索引矩阵 (7,7)
trans_index = torch.tensor(
    [[1, 0, 1, 2, 1, 2, 1], 
     [1, 0, 1, 2, 1, 2, 1], 
     [1, 2, 1, 0, 1, 2, 1], 
     [1, 2, 1, 0, 1, 2, 1], 
     [1, 2, 1, 2, 1, 0, 1], 
     [1, 0, 1, 0, 1, 0, 1], 
     [1, 2, 1, 2, 1, 2, 2]], dtype=torch.long)

# 转移分数 (batch, seq_len-2, 3)，并将第2类分数设为0
transition_score = torch.rand(seq_len-2, batch, 3)
transition_score[:, :, 2] = 0

# 随机生成标签 (seq_len, batch)，标签范围0-6
tags = torch.randint(0, 7, (seq_len, batch))

# 计算真实的转移类型索引
# 注意：tags的索引需要调整，tags[i] 是第i个样本的所有标签
# 正确的做法是取第0个时间步和第1个时间步的标签
true_attn_trans_index = trans_index[tags[0], tags[1]]  #形状: torch.Size([32])
transition_slice = transition_score[1] 
#attn_score = transition_slice.gather(dim=1, index=true_attn_trans_index.unsqueeze(1)).squeeze(1) 

next_score =  transition_slice.gather(dim=1, index=trans_index.unsqueeze(1)).squeeze(1)
#next_score = transition_score[1,:, trans_index]
print(f"next_score 形状: {next_score.shape}")

# 获取注意力分数
#attn_score = transition_score[0, true_attn_trans_index]  # 第0个实体对，全批次，对应类型

# 验证维度
#print(f"true_attn_trans_index 形状: {true_attn_trans_index.shape}")
#print(f"attn_score 形状: {attn_score.shape}")

# 检查是否符合预期
#assert attn_score.shape == (batch,), f"维度错误，预期 ({batch},)，实际 {attn_score.shape}"
#print("维度验证通过!")
