class BlockAwareModel(torch.nn.Module):
    """Wrap a base model and optionally pass it a block-structured attention mask.

    The sequence is split into consecutive blocks by a list of boundary
    indices. Every position may attend to its own block, plus any other
    blocks listed for that block in the visibility rules.
    """

    def __init__(self, base_model):
        """Store the wrapped model.

        Args:
            base_model: Callable invoked as ``base_model(input_ids, **kwargs)``.
        """
        super().__init__()
        self.model = base_model
        # Block visibility rules kept on the instance. Note that ``forward``
        # does not read this attribute; it only uses rules passed per call.
        # Example: {0: [], 1: [0], 2: [0, 1]}
        self.block_rules = {}

    def generate_block_mask(self, input_ids, block_rules, block_boundaries):
        """Build an additive attention mask from block visibility rules.

        Args:
            input_ids: Tensor whose second dimension (``shape[1]``) is used as
                the sequence length; the mask is created on its device.
            block_rules: Mapping ``block index -> iterable of block indices``
                that the block may attend to in addition to itself.
            block_boundaries: Indexable sequence of positions; block ``i``
                covers ``[block_boundaries[i], block_boundaries[i + 1])``.

        Returns:
            A ``(seq_len, seq_len)`` float tensor holding ``0`` where
            attention is allowed and ``-inf`` where it is blocked. Rows index
            the attending position, columns the attended position.

        Raises:
            IndexError / KeyError: If a block index used in ``block_rules``
                has no corresponding entries in ``block_boundaries``.
        """
        seq_len = input_ids.shape[1]
        # Start fully blocked and open up the permitted regions below.
        # NOTE(review): positions that belong to no block listed in
        # ``block_rules`` keep an all ``-inf`` row; a softmax over such a row
        # would presumably produce NaN, so confirm the rules cover the whole
        # sequence.
        mask = torch.full(
            (seq_len, seq_len), float('-inf'), device=input_ids.device
        )

        for block_idx, visible_blocks in block_rules.items():
            start = block_boundaries[block_idx]
            end = block_boundaries[block_idx + 1]

            # Full connectivity inside the block itself.
            mask[start:end, start:end] = 0

            # Cross-block visibility.
            for v_block in visible_blocks:
                v_start = block_boundaries[v_block]
                v_end = block_boundaries[v_block + 1]
                mask[start:end, v_start:v_end] = 0

        return mask

    def forward(self, input_ids, block_rules=None, block_boundaries=None,
                **kwargs):
        """Run the wrapped model, injecting a block mask when rules are given.

        Args:
            input_ids: Passed to the wrapped model as its first argument.
            block_rules: Optional visibility rules (see
                ``generate_block_mask``). When ``None`` or empty, no mask is
                generated and ``kwargs`` are forwarded untouched.
            block_boundaries: Block boundary positions; required whenever
                ``block_rules`` is non-empty.
            **kwargs: Forwarded to the wrapped model. Any ``attention_mask``
                entry is overwritten when a block mask is generated.

        Returns:
            Whatever the wrapped model returns.

        Raises:
            ValueError: If ``block_rules`` is non-empty but
                ``block_boundaries`` is not provided.
        """
        if block_rules:
            if block_boundaries is None:
                raise ValueError(
                    "block_boundaries is required when block_rules is provided"
                )
            # NOTE(review): the wrapped model must accept an additive float
            # mask of shape (seq_len, seq_len) under ``attention_mask`` --
            # confirm against the base model's API.
            kwargs['attention_mask'] = self.generate_block_mask(
                input_ids, block_rules, block_boundaries
            )
        return self.model(input_ids, **kwargs)