# Recovered visible snippet from FEVER tightened-run thread.
# This is not the full original fever_pretrained_gpt2_experiment.py file.

# Fragment from inside forward(), after hidden_states = outputs.last_hidden_state.
# Pools the hidden states (restricted to the pre-[SEP] prefix when
# pooling_mode == "evidence_only_pooling") and scores the pooled vector with
# the consistency head.
# NOTE(review): relies on names bound earlier in forward() that are not visible
# here: pooling_mode, B, sep_pos, labelsep_pos, attention_mask. B is presumably
# the batch size and sep_pos a per-example [SEP] index (list or 1-D tensor);
# confirm against the surrounding code.
if pooling_mode == "evidence_only_pooling":
    # Zero every hidden state from the [SEP] position onward, so only positions
    # strictly before [SEP] keep non-zero values when pooled. This is applied
    # unconditionally for this mode (nothing here is optional), and the [SEP]
    # position itself is zeroed too, because the slice below starts at
    # sep_pos[i] inclusive.
    # NOTE(review): only the hidden states are zeroed. The attention_mask passed
    # to _pool below is NOT narrowed. If _pool averages over attention_mask, the
    # zeroed positions still count in the denominator and shrink the pooled
    # vector in proportion to the post-[SEP] length. Confirm how _pool handles
    # this mode.
    # Work on a copy so the in-place writes below leave hidden_states untouched.
    masked_hs = hidden_states.clone()
    for i in range(B):
        # sep_pos[i] is index of [SEP]
        # Zero positions >= sep_pos[i]: [SEP] itself plus everything after it
        # (per the original note: claim, [LABELSEP], label, EOS, padding).
        # Indexing is (example, position, features), presumably
        # (batch, seq_len, hidden) -- TODO confirm.
        # NOTE(review): a negative sentinel in sep_pos (e.g. -1 for "not found")
        # would zero only the trailing positions; verify sep_pos is always valid.
        start = sep_pos[i]
        masked_hs[i, start:, :] = 0.0
    pooled = self._pool(masked_hs, pooling_mode, sep_pos, labelsep_pos, attention_mask)
else:
    # Every other mode pools the unmodified hidden states. pooling_mode is
    # forwarded to _pool unchanged.
    pooled = self._pool(hidden_states, pooling_mode, sep_pos, labelsep_pos, attention_mask)
# Score the pooled representation with the consistency head.
cons_logits = self.consistency_head(pooled)
