# inside forward(), after hidden_states = outputs.last_hidden_state
if pooling_mode == "evidence_only_pooling":
    # Keep only the evidence span: for each example i, zero every position
    # from sep_pos[i] (the [SEP] token itself) to the end of the sequence.
    # Per the original author's note, that tail holds the claim, [LABELSEP],
    # label, EOS and padding.
    # Clone first so the in-place writes below don't modify hidden_states.
    masked_hs = hidden_states.clone()
    for i in range(B):  # assumes B == hidden_states.size(0) — TODO confirm
        masked_hs[i, sep_pos[i]:, :] = 0.0
    # NOTE(review): zeroing activations does not drop these positions from
    # pooling by itself. If _pool averages over attention_mask, the zeroed
    # positions still count in the denominator. If it max-pools, 0.0 can beat
    # negative activations. Confirm _pool excludes positions >= sep_pos for
    # this mode.
    pool_input = masked_hs
else:
    pool_input = hidden_states
pooled = self._pool(pool_input, pooling_mode, sep_pos, labelsep_pos, attention_mask)
cons_logits = self.consistency_head(pooled)
