import torch.nn as nn
from transformers import PreTrainedModel
from transformers import LlamaPreTrainedModel, AutoTokenizer
from transformers.modeling_outputs import TokenClassifierOutput
from utils.utils import mean_pooling
from torch.cuda.amp import autocast
import torch
import torch.nn.functional as F
import transformers
import math
import numpy as np
import random

# class MLP(nn.Module):
#     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
#         super().init()
#         self.lins = nn.ModuleList()
#         self.norms = nn.ModuleList()
        
#     def forward(self, x):
#         pass

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> LayerNorm -> Linear.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: width of the hidden layer.
        output_dim: size of the last dimension of the output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, output_dim),
        ]
        # A single Sequential named ``net`` keeps state_dict keys
        # (net.0.*, net.2.*, net.3.*) stable for existing checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        out = self.net(x)
        return out

class LlamaClassifier(LlamaPreTrainedModel):
    """Sequence classifier on top of a Llama-style backbone.

    The backbone's last hidden layer goes through dropout, is mean-pooled with
    ``mean_pooling(emb, root_mask)``, is optionally projected to
    ``int(feat_shrink)`` dimensions, and is classified by a linear layer.

    Args:
        model: backbone module; its ``config`` is reused for this wrapper and
            it must return ``hidden_states`` when called with
            ``output_hidden_states=True``.
        n_labels: number of output classes.
        dropout: dropout probability applied to the last hidden layer.
        seed: accepted for interface compatibility; not used here.
        cla_bias: whether the shrink and classifier layers use a bias.
        feat_shrink: if truthy, the pooled vector is projected to
            ``int(feat_shrink)`` dimensions before classification.
        use_cls: stored as ``self.use_cls``; not read by ``forward``.
    """

    def __init__(self, model, n_labels, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True):
        super().__init__(model.config)
        self.bert_encoder = model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        hidden_dim = model.config.hidden_size
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')

        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                model.config.hidden_size, int(feat_shrink), bias=cla_bias)
            hidden_dim = int(feat_shrink)
        self.classifier = nn.Linear(hidden_dim, n_labels, bias=cla_bias)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                labels=None,
                return_dict=None,
                root_mask=None):
        """Encode, pool and classify one batch.

        Args:
            input_ids, attention_mask, return_dict: forwarded to the backbone.
            labels: class targets for the loss; may be ``None`` at inference.
            root_mask: mask handed to ``mean_pooling`` to choose which token
                states are averaged.

        Returns:
            ``TokenClassifierOutput`` whose ``loss`` is the label-smoothed
            cross-entropy (``None`` when ``labels`` is ``None``), ``logits``
            are the class scores and ``hidden_states`` is the pooled vector.
        """
        with autocast():
            outputs = self.bert_encoder(input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        return_dict=return_dict,
                                        output_hidden_states=True)
            emb = self.dropout(outputs['hidden_states'][-1])
            emb = mean_pooling(emb, root_mask)
            if self.feat_shrink:
                emb = self.feat_shrink_layer(emb)

        # The classifier runs outside the autocast region, but ``emb`` may have
        # been produced in reduced precision inside it (e.g. by the shrink
        # layer). Match the classifier's parameter dtype so the matmul does not
        # fail on mixed dtypes.
        emb = emb.to(self.classifier.weight.dtype)
        logits = self.classifier(emb)

        # Labels are absent at inference time; CrossEntropyLoss cannot take None.
        loss = self.loss_func(logits, labels) if labels is not None else None

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=emb)

class BertClassifier(PreTrainedModel):
    """Sequence classifier over a BERT-style encoder with mean pooling.

    The encoder's last hidden layer goes through dropout, is mean-pooled with
    ``mean_pooling(emb, root_mask)``, is optionally projected to
    ``int(feat_shrink)`` dimensions, and is classified by a linear layer.

    Args:
        model: encoder whose ``config`` is reused; it must return
            ``hidden_states`` when called with ``output_hidden_states=True``.
        n_labels: number of output classes.
        dropout: dropout probability on the last hidden layer.
        seed: accepted for interface compatibility; not used here.
        cla_bias: whether the shrink/classifier linear layers use a bias.
        feat_shrink: if truthy, the pooled vector is projected to
            ``int(feat_shrink)`` dimensions before classification.
        use_cls: stored as ``self.use_cls``; not read by ``forward``.
    """

    def __init__(self, model, n_labels, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True):
        super().__init__(model.config)
        self.bert_encoder = model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        hidden_dim = model.config.hidden_size
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')

        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                model.config.hidden_size, int(feat_shrink), bias=cla_bias)
            hidden_dim = int(feat_shrink)
        self.classifier = nn.Linear(hidden_dim, n_labels, bias=cla_bias)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                labels=None,
                return_dict=None,
                root_mask=None):
        """Encode, pool and classify one batch.

        Args:
            input_ids, attention_mask, return_dict: forwarded to the encoder.
            labels: class targets for the loss; may be ``None`` at inference.
            root_mask: mask handed to ``mean_pooling``.

        Returns:
            ``TokenClassifierOutput`` with ``loss`` (label-smoothed
            cross-entropy, or ``None`` when ``labels`` is ``None``),
            ``logits`` and the pooled embedding in ``hidden_states``.
        """
        outputs = self.bert_encoder(input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    return_dict=return_dict,
                                    output_hidden_states=True)

        emb = self.dropout(outputs['hidden_states'][-1])
        emb = mean_pooling(emb, root_mask)

        if self.feat_shrink:
            emb = self.feat_shrink_layer(emb)

        logits = self.classifier(emb)

        # Labels are absent at inference time; CrossEntropyLoss cannot take None.
        loss = self.loss_func(logits, labels) if labels is not None else None

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=emb)

class HierBertClassifier(PreTrainedModel):
    """Two-level (token -> node) classifier.

    Tokens are encoded by ``token_model``. Each node's token span (given by
    ``boundary``) is pooled with a small intra-node attention, layer-normed
    and fed as ``inputs_embeds`` to ``node_model``. The root node's pooled
    vector (``low_emb``) and its ``node_model`` output (``high_emb``) are
    mixed with learnable softmax weights and classified.

    Args:
        token_model: token-level encoder; its ``config`` initialises this model.
        node_model: node-level encoder that accepts ``inputs_embeds``.
        n_labels: number of output classes.
        dropout: dropout probability on the fused embedding.
        seed: accepted for interface compatibility; not used here.
        cla_bias: whether the shrink/classifier layers use a bias.
        feat_shrink: if truthy, project to ``int(feat_shrink)`` dims before
            the classifier.
        use_cls: stored as ``self.use_cls``; not read by ``forward``.
    """

    def __init__(self, token_model, node_model, n_labels, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True):
        super().__init__(token_model.config)
        self.token_encoder = token_model
        self.node_encoder = node_model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        hidden_dim = token_model.config.hidden_size
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')

        # Intra-node attention: one scalar score per token, softmaxed within a node.
        self.intra_node_attention = nn.Sequential(
            nn.Linear(token_model.config.hidden_size, 128),
            nn.Tanh(),
            nn.Linear(128, 1, bias=False)
        )

        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                token_model.config.hidden_size, int(feat_shrink), bias=cla_bias)
            hidden_dim = int(feat_shrink)
        self.classifier = nn.Linear(hidden_dim, n_labels, bias=cla_bias)

        # Logits of the (high_emb, low_emb) mixing weights; softmaxed in forward.
        self.fusion_weights = nn.Parameter(torch.tensor([0.65, 0.35]))
        self.node_input_norm = nn.LayerNorm(token_model.config.hidden_size)

    def _pool_nodes(self, token_embeddings, boundary, device):
        """Pool every node's token span into a single vector.

        Args:
            token_embeddings: last-layer token states indexed as ``[b, token, :]``.
            boundary: per batch item, an iterable of ``(start, end)`` tensor
                pairs with inclusive ``end``; ``start == -1`` marks padding.
            device: device for the returned tensors.

        Returns:
            ``(node_embeddings, node_mask)``: node vectors stacked over batch
            and nodes, and a long mask with 1 for real nodes and 0 for padding
            nodes (whose vector is all zeros).
        """
        hidden_size = self.token_encoder.config.hidden_size
        batch_embeddings, batch_masks = [], []
        for b in range(token_embeddings.size(0)):
            node_embeddings, node_mask = [], []
            for index in boundary[b]:
                start, end = index[0].item(), index[1].item()
                if start == -1:
                    # Padding node: zero vector (same dtype as real nodes), masked out.
                    node_mask.append(0)
                    node_embeddings.append(token_embeddings.new_zeros(hidden_size))
                    continue
                node_mask.append(1)
                node_tokens = token_embeddings[b, start:end + 1, :]  # [node_len, dim]
                # Attention weights over the tokens inside this node.
                attn_weights = torch.softmax(
                    self.intra_node_attention(node_tokens).squeeze(-1), dim=0)  # [node_len]
                # Attention-weighted sum -> one vector per node.
                node_embeddings.append(
                    (node_tokens * attn_weights.unsqueeze(-1)).sum(dim=0))  # [dim]
            batch_embeddings.append(torch.stack(node_embeddings).to(device))
            batch_masks.append(torch.tensor(node_mask, dtype=torch.long, device=device))
        return torch.stack(batch_embeddings), torch.stack(batch_masks)

    @staticmethod
    def _select_root(states, root):
        """Return ``states[b, root[b], :]`` for every batch item, stacked."""
        return torch.stack([states[b, root[b].item(), :] for b in range(states.size(0))])

    def forward(self,
                input_ids=None,
                attention_mask=None,
                labels=None,
                return_dict=None,
                root=None,
                boundary=None):
        """Classify each sample by its root node.

        Args:
            input_ids, attention_mask, return_dict: forwarded to the token encoder.
            labels: class targets; may be ``None`` at inference.
            root: per batch item, the index of the root node (``root[b].item()``).
            boundary: per batch item, ``(start, end)`` token spans of the nodes.

        Returns:
            ``TokenClassifierOutput``. ``loss`` is the label-smoothed
            cross-entropy plus ``0.1`` times a distribution-alignment MSE, or
            ``None`` when ``labels`` is ``None``.
        """
        outputs = self.token_encoder(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=return_dict,
                                     output_hidden_states=True)
        token_embeddings = outputs['hidden_states'][-1]

        all_node_embeddings, all_node_attention_masks = self._pool_nodes(
            token_embeddings, boundary, input_ids.device)
        all_node_embeddings = self.node_input_norm(all_node_embeddings)

        node_outputs = self.node_encoder(
            inputs_embeds=all_node_embeddings,  # [batch_size, num_nodes, dim]
            attention_mask=all_node_attention_masks,
            output_hidden_states=True
        )
        node_hidden_states = node_outputs['hidden_states'][-1]

        low_emb = self._select_root(all_node_embeddings, root)
        high_emb = self._select_root(node_hidden_states, root)

        weights = torch.softmax(self.fusion_weights, dim=0)
        emb = weights[0] * high_emb + weights[1] * low_emb
        emb = self.dropout(emb)

        if self.feat_shrink:
            emb = self.feat_shrink_layer(emb)

        logits = self.classifier(emb)

        loss = None
        if labels is not None:
            loss = self.loss_func(logits, labels)
            # Align the mean node representation with the mean token representation.
            # NOTE(review): both means include padding positions (padded nodes and
            # padded tokens); a masked mean may be intended -- confirm.
            distribution_loss = F.mse_loss(
                F.normalize(all_node_embeddings.mean(dim=1), p=2, dim=-1),
                F.normalize(token_embeddings.mean(dim=1), p=2, dim=-1)
            )
            loss = loss + 0.1 * distribution_loss  # tunable coefficient

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=emb)

class CatBertClassifier(PreTrainedModel):
    """Two-level classifier that concatenates low- and high-level root vectors.

    Like ``HierBertClassifier``, but the root node's pooled vector
    (``low_emb``) and its ``node_model`` output (``high_emb``) are
    concatenated instead of mixed, so the classifier takes ``2 * hidden_dim``
    inputs. With ``feat_shrink`` set, ``hidden_dim == int(feat_shrink)`` and
    each half is shrunk separately before concatenation.

    Args:
        token_model: token-level encoder; its ``config`` initialises this model.
        node_model: node-level encoder accepting ``inputs_embeds``.
        n_labels: number of output classes.
        dropout: dropout probability on the concatenated embedding.
        seed: accepted for interface compatibility; not used here.
        cla_bias: whether the shrink/classifier layers use a bias.
        feat_shrink: if truthy, per-half projection size ``int(feat_shrink)``.
        use_cls: stored as ``self.use_cls``; not read by ``forward``.
    """

    def __init__(self, token_model, node_model, n_labels, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True):
        super().__init__(token_model.config)
        self.token_encoder = token_model
        self.node_encoder = node_model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        hidden_dim = token_model.config.hidden_size
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')

        # Intra-node attention: one scalar score per token, softmaxed within a node.
        self.intra_node_attention = nn.Sequential(
            nn.Linear(token_model.config.hidden_size, 128),
            nn.Tanh(),
            nn.Linear(128, 1, bias=False)
        )

        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                token_model.config.hidden_size, int(feat_shrink), bias=cla_bias)
            hidden_dim = int(feat_shrink)
        # Input is [low_emb, high_emb] concatenated, each of size hidden_dim.
        self.classifier = nn.Linear(hidden_dim * 2, n_labels, bias=cla_bias)

    def _pool_nodes(self, token_embeddings, boundary, device):
        """Pool every node's token span into a single vector.

        Args:
            token_embeddings: last-layer token states indexed as ``[b, token, :]``.
            boundary: per batch item, an iterable of ``(start, end)`` tensor
                pairs with inclusive ``end``; ``start == -1`` marks padding.
            device: device for the returned tensors.

        Returns:
            ``(node_embeddings, node_mask)``: node vectors stacked over batch
            and nodes, and a long mask with 1 for real nodes and 0 for padding
            nodes (whose vector is all zeros).
        """
        hidden_size = self.token_encoder.config.hidden_size
        batch_embeddings, batch_masks = [], []
        for b in range(token_embeddings.size(0)):
            node_embeddings, node_mask = [], []
            for index in boundary[b]:
                start, end = index[0].item(), index[1].item()
                if start == -1:
                    # Padding node: zero vector (same dtype as real nodes), masked out.
                    node_mask.append(0)
                    node_embeddings.append(token_embeddings.new_zeros(hidden_size))
                    continue
                node_mask.append(1)
                node_tokens = token_embeddings[b, start:end + 1, :]  # [node_len, dim]
                # Attention weights over the tokens inside this node.
                attn_weights = torch.softmax(
                    self.intra_node_attention(node_tokens).squeeze(-1), dim=0)  # [node_len]
                # Attention-weighted sum -> one vector per node.
                node_embeddings.append(
                    (node_tokens * attn_weights.unsqueeze(-1)).sum(dim=0))  # [dim]
            batch_embeddings.append(torch.stack(node_embeddings).to(device))
            batch_masks.append(torch.tensor(node_mask, dtype=torch.long, device=device))
        return torch.stack(batch_embeddings), torch.stack(batch_masks)

    @staticmethod
    def _select_root(states, root):
        """Return ``states[b, root[b], :]`` for every batch item, stacked."""
        return torch.stack([states[b, root[b].item(), :] for b in range(states.size(0))])

    def forward(self,
                input_ids=None,
                attention_mask=None,
                labels=None,
                return_dict=None,
                root=None,
                boundary=None):
        """Classify each sample from the concatenated root representations.

        Args:
            input_ids, attention_mask, return_dict: forwarded to the token encoder.
            labels: class targets; may be ``None`` at inference.
            root: per batch item, the index of the root node.
            boundary: per batch item, ``(start, end)`` token spans of the nodes.

        Returns:
            ``TokenClassifierOutput`` with ``loss`` (``None`` without labels),
            ``logits`` and the final embedding in ``hidden_states``.
        """
        outputs = self.token_encoder(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=return_dict,
                                     output_hidden_states=True)
        token_embeddings = outputs['hidden_states'][-1]

        all_node_embeddings, all_node_attention_masks = self._pool_nodes(
            token_embeddings, boundary, input_ids.device)

        node_outputs = self.node_encoder(
            inputs_embeds=all_node_embeddings,  # [batch_size, num_nodes, dim]
            attention_mask=all_node_attention_masks,
            output_hidden_states=True
        )
        node_hidden_states = node_outputs['hidden_states'][-1]

        low_emb = self._select_root(all_node_embeddings, root)
        high_emb = self._select_root(node_hidden_states, root)

        emb = torch.cat([low_emb, high_emb], dim=-1)
        emb = self.dropout(emb)

        if self.feat_shrink:
            # feat_shrink_layer maps hidden_size -> feat_shrink and the classifier
            # expects 2 * feat_shrink inputs, so shrink each half on its own
            # (applying it to the concatenation was a shape mismatch).
            low_part, high_part = emb.split([low_emb.size(-1), high_emb.size(-1)], dim=-1)
            emb = torch.cat([self.feat_shrink_layer(low_part),
                             self.feat_shrink_layer(high_part)], dim=-1)

        logits = self.classifier(emb)

        # Labels are absent at inference time; CrossEntropyLoss cannot take None.
        loss = self.loss_func(logits, labels) if labels is not None else None

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=emb)


class AttBertClassifier(PreTrainedModel):
    """Classifier whose token attention is restricted by a node-level block mask.

    ``compute_attention_mask`` turns node boundaries and pairwise
    ``block_rules`` into a ``[batch, seq, seq]`` mask for the token encoder.
    Node vectors are then pooled with intra-node attention and layer-normed,
    and the root node's vector is classified.

    Args:
        token_model: token-level encoder; its ``config`` initialises this model.
        n_labels: number of output classes.
        dropout: stored as ``self.dropout``; not applied in ``forward``.
        seed: accepted for interface compatibility; not used here.
        cla_bias: whether the shrink/classifier layers use a bias.
        feat_shrink: if truthy, project to ``int(feat_shrink)`` dims first.
        use_cls: stored as ``self.use_cls``; not read by ``forward``.
    """

    def __init__(self, token_model, n_labels, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True):
        super().__init__(token_model.config)
        self.token_encoder = token_model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        hidden_dim = token_model.config.hidden_size
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')
        self.node_input_norm = nn.LayerNorm(token_model.config.hidden_size)

        # Intra-node attention: one scalar score per token, softmaxed within a node.
        self.intra_node_attention = nn.Sequential(
            nn.Linear(token_model.config.hidden_size, 128),
            nn.Tanh(),
            nn.Linear(128, 1, bias=False)
        )

        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                token_model.config.hidden_size, int(feat_shrink), bias=cla_bias)
            hidden_dim = int(feat_shrink)
        self.classifier = nn.Linear(hidden_dim, n_labels, bias=cla_bias)

        # Mask value for each block-rule index (a fixed list, not a trainable parameter).
        self.attention_co = [1, 1, 1, 1, 0, 0, 0]

    def compute_attention_mask(self, seq_length, block_rules, boundary):
        """Build a ``[batch, seq_length, seq_length]`` float attention mask.

        Node scanning stops at the first boundary whose ``end`` is -1. For
        every pair of real nodes ``(i, j)``, the token block
        ``[start_i..end_i] x [start_j..end_j]`` is set to
        ``self.attention_co[block_rules[b, i, j]]``. The token at ``end + 1``
        of each node (treated as its separator) attends to, and is attended by,
        every position up to and including ``true_seq_length``.
        """
        batch_size = block_rules.size(0)
        device = block_rules.device

        attention_mask = torch.zeros(batch_size, seq_length, seq_length, device=device)
        for b in range(batch_size):
            true_seq_length = -1
            nodes = []
            for index in boundary[b]:
                start, end = index[0].item(), index[1].item()
                if end == -1:
                    break
                nodes.append((start, end))
                true_seq_length = end + 1

            for i, (src_start, src_end) in enumerate(nodes):
                for j, (tgt_start, tgt_end) in enumerate(nodes):
                    rule_idx = block_rules[b, i, j]
                    attention_mask[b, src_start:src_end + 1, tgt_start:tgt_end + 1] = \
                        self.attention_co[rule_idx]

            # Separator tokens (the position right after each node's end).
            for src_start, src_end in nodes:
                attention_mask[b, src_end + 1, 0:true_seq_length + 1] = 1
                attention_mask[b, 0:true_seq_length + 1, src_end + 1] = 1

        return attention_mask

    def _pool_nodes(self, token_embeddings, boundary, device):
        """Pool every node's token span into a single vector.

        Args:
            token_embeddings: last-layer token states indexed as ``[b, token, :]``.
            boundary: per batch item, an iterable of ``(start, end)`` tensor
                pairs with inclusive ``end``; ``start == -1`` marks padding.
            device: device for the returned tensors.

        Returns:
            ``(node_embeddings, node_mask)``: node vectors stacked over batch
            and nodes, and a long mask with 1 for real nodes and 0 for padding
            nodes (whose vector is all zeros).
        """
        hidden_size = self.token_encoder.config.hidden_size
        batch_embeddings, batch_masks = [], []
        for b in range(token_embeddings.size(0)):
            node_embeddings, node_mask = [], []
            for index in boundary[b]:
                start, end = index[0].item(), index[1].item()
                if start == -1:
                    # Padding node: zero vector (same dtype as real nodes).
                    node_mask.append(0)
                    node_embeddings.append(token_embeddings.new_zeros(hidden_size))
                    continue
                node_mask.append(1)
                node_tokens = token_embeddings[b, start:end + 1, :]  # [node_len, dim]
                # Attention weights over the tokens inside this node.
                attn_weights = torch.softmax(
                    self.intra_node_attention(node_tokens).squeeze(-1), dim=0)  # [node_len]
                # Attention-weighted sum -> one vector per node.
                node_embeddings.append(
                    (node_tokens * attn_weights.unsqueeze(-1)).sum(dim=0))  # [dim]
            batch_embeddings.append(torch.stack(node_embeddings).to(device))
            batch_masks.append(torch.tensor(node_mask, dtype=torch.long, device=device))
        return torch.stack(batch_embeddings), torch.stack(batch_masks)

    @staticmethod
    def _select_root(states, root):
        """Return ``states[b, root[b], :]`` for every batch item, stacked."""
        return torch.stack([states[b, root[b].item(), :] for b in range(states.size(0))])

    def forward(self,
                input_ids=None,
                block_rules=None,
                labels=None,
                return_dict=None,
                root=None,
                boundary=None):
        """Classify each sample by its root node's pooled vector.

        Args:
            input_ids: token ids; their sequence length sizes the attention mask.
            block_rules: per batch item, node-by-node rule indices into
                ``self.attention_co``.
            labels: class targets; may be ``None`` at inference.
            return_dict: forwarded to the token encoder.
            root: per batch item, the index of the root node.
            boundary: per batch item, ``(start, end)`` token spans of the nodes.

        Returns:
            ``TokenClassifierOutput``. ``loss`` is the label-smoothed
            cross-entropy plus ``0.1`` times a distribution-alignment MSE, or
            ``None`` when ``labels`` is ``None``.
        """
        attention_mask = self.compute_attention_mask(input_ids.size(1), block_rules, boundary)

        outputs = self.token_encoder(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=return_dict,
                                     output_hidden_states=True)
        token_embeddings = outputs['hidden_states'][-1]

        # The node mask is not needed here: there is no node-level encoder.
        all_node_embeddings, _ = self._pool_nodes(token_embeddings, boundary, input_ids.device)
        all_node_embeddings = self.node_input_norm(all_node_embeddings)

        emb = self._select_root(all_node_embeddings, root)

        if self.feat_shrink:
            emb = self.feat_shrink_layer(emb)

        logits = self.classifier(emb)

        loss = None
        if labels is not None:
            loss = self.loss_func(logits, labels)
            # Align the mean node representation with the mean token representation.
            # NOTE(review): both means include padding positions -- confirm intended.
            distribution_loss = F.mse_loss(
                F.normalize(all_node_embeddings.mean(dim=1), p=2, dim=-1),
                F.normalize(token_embeddings.mean(dim=1), p=2, dim=-1)
            )
            loss = loss + 0.1 * distribution_loss  # tunable coefficient

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=emb)

class TransAttBertClassifier(PreTrainedModel):
    """Token-level classifier that pools the root node with Q/K attention.

    The token encoder runs under a node-structured attention mask. The root
    node's tokens are then pooled with weights mixing intra-root attention
    (0.9) and attention from neighbouring nodes' tokens (0.1), and the pooled
    vector is classified. ``forward`` also returns intermediate tensors in
    ``hidden_states`` for downstream consumers.

    Args:
        token_model: token-level encoder; its ``config`` initialises this model.
        n_labels: number of output classes.
        dropout: stored as ``self.dropout``; not applied in ``forward``.
        seed: global RNG seed set at construction (see ``_set_seed``).
        cla_bias: whether the shrink/classifier layers use a bias.
        feat_shrink: if truthy, project to ``int(feat_shrink)`` dims before
            the classifier.
        use_cls: stored as ``self.use_cls``; not read by ``forward``.
    """

    def __init__(self, token_model, n_labels, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True):
        super().__init__(token_model.config)

        self._set_seed(seed)

        self.token_encoder = token_model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        token_hidden = token_model.config.hidden_size
        hidden_dim = token_hidden
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')

        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                token_hidden, int(feat_shrink), bias=cla_bias)
            hidden_dim = int(feat_shrink)
        self.classifier = nn.Linear(hidden_dim, n_labels, bias=cla_bias)

        self.attention_co = [1, 1, 1, 1, 0, 0, 0]

        # Q/K project encoder outputs, so they are sized by the encoder width,
        # not by the (possibly shrunk) classifier input size.
        self.Q = nn.Linear(token_hidden, token_hidden)
        self.K = nn.Linear(token_hidden, token_hidden)
        # Width of the Q/K outputs; used for the 1/sqrt(d) attention scaling.
        self.hidden_dim = token_hidden

    def _set_seed(self, seed):
        """Seed torch, CUDA, numpy and ``random``; make cuDNN deterministic."""
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        random.seed(seed)

    def compute_attention_mask(self, seq_length, block_rules, boundary):
        """Build a ``[batch, seq_length, seq_length]`` float attention mask.

        Node scanning stops at the first boundary whose ``end`` is -1. Rows
        after the real sequence attend to the real prefix. Every real node
        attends to every real node, and each node's separator token (at
        ``end + 1``) attends to, and is attended by, the whole real prefix.
        ``self.attention_co`` per-rule values are currently not applied: the
        mask value is always 1.
        """
        batch_size = block_rules.size(0)
        device = block_rules.device

        attention_mask = torch.zeros(batch_size, seq_length, seq_length, device=device)
        for b in range(batch_size):
            true_seq_length = -1
            nodes = []
            for index in boundary[b]:
                start, end = index[0].item(), index[1].item()
                if end == -1:
                    break
                nodes.append((start, end))
                true_seq_length = end + 1

            # Padding rows attend to the real prefix so no row is fully masked.
            attention_mask[b, true_seq_length + 1:, 0:true_seq_length + 1] = 1

            for src_start, src_end in nodes:
                for tgt_start, tgt_end in nodes:
                    attention_mask[b, src_start:src_end + 1, tgt_start:tgt_end + 1] = 1

            # Separator tokens (the position right after each node's end).
            for src_start, src_end in nodes:
                attention_mask[b, src_end + 1, 0:true_seq_length + 1] = 1
                attention_mask[b, 0:true_seq_length + 1, src_end + 1] = 1

        return attention_mask

    def get_node_embedding(self, token_embeddings, block_rules, boundary, root):
        """Pool the root node's tokens into one vector per batch item.

        Neighbour tokens come from every non-root, non-padding node whose
        ``block_rules[b][root][idx]`` is at most 1. If there are none, the
        root's own tokens are used instead. Pooling weights are
        ``0.9 * intra_attns + 0.1 * inter_attns`` over the root's tokens.

        Returns:
            Tensor of shape ``[batch, d]``.
        """
        scale = self.hidden_dim ** 0.5

        pooled = []
        for b in range(block_rules.size(0)):
            root_idx = root[b].item()
            root_start, root_end = 0, 0
            neighbor_spans = []
            for idx, item in enumerate(boundary[b]):
                start, end = item[0].item(), item[1].item()
                if idx == root_idx:
                    root_start, root_end = start, end
                    continue
                # Skip padding nodes (start == -1; previously range(-1, 0)
                # pulled in the sequence's last token), nodes whose rule with
                # respect to the root is > 1, and empty spans.
                if start == -1 or block_rules[b][root_idx][idx] > 1 or end < start:
                    continue
                neighbor_spans.append(token_embeddings[b, start:end + 1])

            root_value = token_embeddings[b, root_start:root_end + 1]  # [N, d]
            # Without neighbour tokens, the root's own tokens act as neighbours.
            value = torch.cat(neighbor_spans, dim=0) if neighbor_spans else root_value

            query = self.Q(value)            # [M, d]
            root_query = self.Q(root_value)  # [N, d]
            root_key = self.K(root_value)    # [N, d]

            # Neighbour -> root attention: softmax over root tokens, summed over
            # neighbours, then renormalised into a distribution over root tokens.
            neighbor_attn = torch.softmax(query @ root_key.transpose(0, 1) / scale, dim=-1)  # [M, N]
            inter_attns = torch.softmax(neighbor_attn.sum(dim=0), dim=-1)  # [N]

            # Root self-attention, reduced the same way.
            self_attn = torch.softmax(root_query @ root_key.transpose(0, 1) / scale, dim=-1)  # [N, N]
            intra_attns = torch.softmax(self_attn.sum(dim=0), dim=-1)  # [N]

            final_weights = 0.9 * intra_attns + 0.1 * inter_attns
            pooled.append((root_value * final_weights.unsqueeze(-1)).sum(dim=0, keepdim=True))  # [1, d]

        return torch.cat(pooled, dim=0)  # [B, d]

    def forward(self,
                input_ids=None,
                block_rules=None,
                labels=None,
                return_dict=None,
                root=None,
                boundary=None,
                root_mask=None,
                sorted_sequence=None):
        """Encode with the node-structured mask, pool the root and classify.

        ``root_mask`` is accepted but not used. ``sorted_sequence`` is only
        passed through in ``hidden_states``.

        Returns:
            ``TokenClassifierOutput`` whose ``loss`` is ``None`` when
            ``labels`` is ``None`` and whose ``hidden_states`` is the tuple
            ``(token_embeddings, emb, boundary, block_rules, root, sorted_sequence)``.
        """
        attention_mask = self.compute_attention_mask(input_ids.size(1), block_rules, boundary)

        outputs = self.token_encoder(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=return_dict,
                                     output_hidden_states=True)
        token_embeddings = outputs['hidden_states'][-1]

        emb = self.get_node_embedding(token_embeddings, block_rules, boundary, root)

        if self.feat_shrink:
            emb = self.feat_shrink_layer(emb)

        logits = self.classifier(emb)

        # Labels are absent at inference time; CrossEntropyLoss cannot take None.
        loss = self.loss_func(logits, labels) if labels is not None else None

        hidden_states = (token_embeddings, emb, boundary, block_rules, root, sorted_sequence)

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=hidden_states)

class Node_AH_BertClassifier(PreTrainedModel):
    """Node-level classifier over precomputed token and root embeddings.

    Node vectors are built from ``token_embeddings`` with intra-node
    attention, except the root node, which uses ``frozen_node_embeddings[b]``.
    They are encoded by ``node_model``. The root's output (``high_emb``) is
    mixed 0.7/0.3 with the frozen root embedding (``low_emb``) and classified.

    Args:
        node_model: node-level encoder accepting ``inputs_embeds``; its
            ``config`` initialises this model.
        n_labels: number of output classes.
        low_embeddings: stored as ``self.low_embeddings``; not read by the
            active code paths.
        dropout: dropout probability on the fused embedding.
        seed: global RNG seed set at construction (see ``_set_seed``).
        cla_bias: whether the shrink/classifier layers use a bias.
        feat_shrink: if truthy, project to ``int(feat_shrink)`` dims before
            the classifier.
        use_cls: stored as ``self.use_cls``; not read by ``forward``.
    """

    def __init__(self, node_model, n_labels, low_embeddings, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True):
        super().__init__(node_model.config)

        self._set_seed(seed)

        self.low_embeddings = low_embeddings
        self.node_encoder = node_model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        hidden_dim = node_model.config.hidden_size
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')

        # Intra-node attention: one scalar score per token, softmaxed within a node.
        self.intra_node_attention = nn.Sequential(
            nn.Linear(node_model.config.hidden_size, 128),
            nn.Tanh(),
            nn.Linear(128, 1, bias=False)
        )

        classifier_dim = hidden_dim
        if feat_shrink:
            # forward() applies this layer whenever feat_shrink is set, so it
            # must exist (it was previously never created).
            self.feat_shrink_layer = nn.Linear(hidden_dim, int(feat_shrink), bias=cla_bias)
            classifier_dim = int(feat_shrink)
        self.classifier = nn.Linear(classifier_dim, n_labels, bias=cla_bias)

        self.attention_co = [1, 1, 1, 1, 0, 0, 0]
        # Node hidden size; also the width of the zero vectors used for padding nodes.
        self.hidden_dim = hidden_dim

    def _set_seed(self, seed):
        """Seed torch, CUDA, numpy and ``random``; make cuDNN deterministic."""
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        random.seed(seed)

    def compute_node_attention_mask(self, boundary, block_rules, seq_length=40):
        """Build a ``[batch, seq_length, seq_length]`` node attention mask.

        ``n`` real nodes are counted up to the first boundary whose ``end`` is
        -1. Real nodes attend to each other (``[:n, :n] = 1``), and padding
        rows attend to the real nodes so no row is fully masked.
        ``self.attention_co`` per-rule values are currently not applied.

        Args:
            boundary: per batch item, ``(start, end)`` spans of the nodes.
            block_rules: used for batch size and device.
            seq_length: number of node slots; should equal the node dimension
                of the embeddings passed to the node encoder.
        """
        batch_size = block_rules.size(0)
        device = block_rules.device

        node_attention_mask = torch.zeros(batch_size, seq_length, seq_length, device=device)
        for b in range(batch_size):
            n_real = 0
            for index in boundary[b]:
                if index[1].item() == -1:
                    break
                n_real += 1

            node_attention_mask[b, n_real:, 0:n_real] = 1
            node_attention_mask[b, :n_real, :n_real] = 1

        return node_attention_mask

    def get_node_embedding(self, frozen_node_embeddings, token_embeddings, block_rules, boundary, root, sorted_sequence):
        """Build one vector per node slot for every batch item.

        The root slot takes ``frozen_node_embeddings[b]``. Other real nodes
        (``start != -1``) are pooled from their token span in
        ``token_embeddings`` with intra-node attention. Padding slots are zero
        vectors of size ``self.hidden_dim``. ``sorted_sequence`` is accepted
        but not used.
        """
        batch_size = block_rules.size(0)
        device = block_rules.device

        all_node_embeddings = []
        for b in range(batch_size):
            root_idx = root[b].item()
            node_embeddings = []
            for idx, item in enumerate(boundary[b]):
                start, end = item[0].item(), item[1].item()

                if idx == root_idx:
                    node_embeddings.append(frozen_node_embeddings[b])
                elif start != -1:
                    node_tokens = token_embeddings[b, start:end + 1, :]  # [node_len, dim]
                    # Attention weights over the tokens inside this node.
                    attn_weights = torch.softmax(
                        self.intra_node_attention(node_tokens).squeeze(-1),
                        dim=0
                    )  # [node_len]
                    # Attention-weighted sum -> one vector per node.
                    node_embed = (node_tokens * attn_weights.unsqueeze(-1)).sum(dim=0)
                    node_embeddings.append(node_embed)
                else:
                    node_embeddings.append(torch.zeros(self.hidden_dim, device=device))

            all_node_embeddings.append(torch.stack(node_embeddings).to(device))

        return torch.stack(all_node_embeddings).to(device)

    def forward(self,
                frozen_node_embeddings=None,
                token_embeddings=None,
                block_rules=None,
                labels=None,
                return_dict=None,
                root=None,
                boundary=None,
                sorted_sequence=None):
        """Encode node vectors, fuse the root outputs and classify.

        Returns:
            ``TokenClassifierOutput`` with ``loss`` (``None`` when ``labels``
            is ``None``), ``logits`` and the fused embedding in ``hidden_states``.
        """
        batch_size = block_rules.size(0)

        all_node_embeddings = self.get_node_embedding(
            frozen_node_embeddings, token_embeddings, block_rules, boundary, root, sorted_sequence)
        # Size the mask by the actual node count instead of the hard-coded default.
        all_node_attention_masks = self.compute_node_attention_mask(
            boundary, block_rules, seq_length=all_node_embeddings.size(1))

        node_outputs = self.node_encoder(
            inputs_embeds=all_node_embeddings,  # [batch_size, num_nodes, dim]
            attention_mask=all_node_attention_masks,
            output_hidden_states=True
        )
        node_hidden_states = node_outputs['hidden_states'][-1]

        low_emb = frozen_node_embeddings
        high_emb = torch.stack(
            [node_hidden_states[b, root[b].item(), :] for b in range(batch_size)])

        emb = 0.7 * high_emb + 0.3 * low_emb
        emb = self.dropout(emb)

        if self.feat_shrink:
            emb = self.feat_shrink_layer(emb)

        logits = self.classifier(emb)

        # Labels are absent at inference time; CrossEntropyLoss cannot take None.
        loss = self.loss_func(logits, labels) if labels is not None else None

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=emb)

class TAH_BertClassifier(PreTrainedModel):
    """Two-level (token -> node) classifier with Q/K attention-pooled nodes.

    ``token_model`` encodes the token sequence under a block-structured
    attention mask built from ``boundary`` (per-node ``(start, end)`` token
    spans, padded with ``-1``) and ``block_rules`` (a rule id per node pair,
    mapped to a mask value through ``attention_co``). Each node's token span
    is pooled with a learned Q/K self-attention. The root node additionally
    blends in the attention its tokens receive from neighbouring nodes.
    ``node_model`` then encodes the node sequence. The root node's
    pre-encoder and post-encoder states are fused and classified.
    """

    def __init__(self, token_model, node_model, n_labels, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True):
        super().__init__(token_model.config)
        self.token_encoder = token_model
        self.node_encoder = node_model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        # Width of the token encoder's hidden states; node pooling works here.
        encoder_dim = token_model.config.hidden_size
        # Width of the vector fed to the classifier (shrunk if feat_shrink set).
        hidden_dim = encoder_dim
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')

        # Intra-node attention scorer. Not referenced by this class's forward
        # pass; presumably kept for state_dict compatibility — confirm before
        # removing.
        self.intra_node_attention = nn.Sequential(
            nn.Linear(encoder_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1, bias=False)
        )

        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                encoder_dim, int(feat_shrink), bias=cla_bias)
            hidden_dim = int(feat_shrink)
        self.classifier = nn.Linear(hidden_dim, n_labels, bias=cla_bias)
        # init_random_state(seed)

        # Learnable (post-encoder, pre-encoder) root fusion; softmaxed in forward.
        self.fusion_weights = nn.Parameter(torch.tensor([0.65, 0.35]))
        # Learnable (intra-node, neighbour) mix of root pooling weights.
        self.attention_weights = nn.Parameter(torch.tensor([0.9, 0.1]))
        self.node_input_norm = nn.LayerNorm(encoder_dim)
        # Rule id -> HF-style attention-mask value (1 = attend, 0 = blocked).
        self.attention_co = [1, 1, 1, 1, 0, 0, 0]
        # Q/K project raw token-encoder states, so they must use the encoder
        # width. Using the (possibly shrunk) classifier width broke feat_shrink.
        self.Q = nn.Linear(encoder_dim, encoder_dim)
        self.K = nn.Linear(encoder_dim, encoder_dim)
        # Dimension used to scale the Q/K dot products.
        self.hidden_dim = encoder_dim

    @staticmethod
    def _valid_spans(boundary_row):
        """Return the ``(start, end)`` spans of one sample, stopping at the first ``end == -1``."""
        spans = []
        for index in boundary_row:
            start, end = index[0].item(), index[1].item()
            if end == -1:
                break
            spans.append((start, end))
        return spans

    @staticmethod
    def _gather_root(states, root):
        """Select ``states[b, root[b]]`` for every sample ``b`` -> ``[batch, dim]``."""
        return torch.stack(
            [states[b, root[b].item(), :] for b in range(states.size(0))])

    def compute_attention_mask(self, seq_length, block_rules, boundary):
        """Build the ``[batch, seq_length, seq_length]`` token-level attention mask.

        Token spans of nodes ``i`` and ``j`` get ``attention_co[block_rules[b, i, j]]``.
        The token right after each node span (the separator) attends to, and
        is attended by, every real token. Rows after the last separator attend
        to all real tokens.
        """
        batch_size = block_rules.size(0)
        device = block_rules.device

        attention_mask = torch.zeros(batch_size, seq_length, seq_length, device=device)
        for b in range(batch_size):
            nodes = self._valid_spans(boundary[b])
            # Position of the separator after the last node (-1 if no nodes).
            true_seq_length = nodes[-1][1] + 1 if nodes else -1

            # Trailing (padding) rows see all real tokens; presumably so that no
            # attention row is entirely masked.
            attention_mask[b, true_seq_length + 1:, 0:true_seq_length + 1] = 1

            for i, (src_start, src_end) in enumerate(nodes):
                for j, (tgt_start, tgt_end) in enumerate(nodes):
                    co_value = self.attention_co[block_rules[b, i, j]]
                    attention_mask[b, src_start:src_end + 1, tgt_start:tgt_end + 1] = co_value

            # Handle the SEP token that follows each node span.
            for _, src_end in nodes:
                attention_mask[b, src_end + 1, 0:true_seq_length + 1] = 1
                attention_mask[b, 0:true_seq_length + 1, src_end + 1] = 1

        return attention_mask

    def compute_node_attention_mask(self, boundary, block_rules, seq_length=None):
        """Build the ``[batch, seq_length, seq_length]`` node-level attention mask.

        Args:
            boundary: Per-sample node spans, padded with ``end == -1``.
            block_rules: Rule id per node pair, indexed ``[b, i, j]``.
            seq_length: Number of node slots. Defaults to the number of boundary
                entries per sample, which is the number of node embeddings
                produced from the same boundary. (Previously hard-coded to 40.)
        """
        batch_size = block_rules.size(0)
        device = block_rules.device
        if seq_length is None:
            seq_length = len(boundary[0]) if batch_size > 0 else 0

        node_attention_mask = torch.zeros(batch_size, seq_length, seq_length, device=device)
        for b in range(batch_size):
            n_nodes = len(self._valid_spans(boundary[b]))
            # Padding node slots attend to every real node.
            node_attention_mask[b, n_nodes:, 0:n_nodes] = 1
            for i in range(n_nodes):
                for j in range(n_nodes):
                    node_attention_mask[b, i, j] = self.attention_co[block_rules[b, i, j]]

        return node_attention_mask

    def _pool_weights(self, query, key):
        """Return one pooling weight per key row.

        The weights are the scaled dot-product attention each key row receives,
        summed over the query rows, then softmax-normalised.
        """
        scores = torch.matmul(query, key.transpose(0, 1)) / (self.hidden_dim ** 0.5)
        received = torch.softmax(scores, dim=-1).sum(dim=0)
        return torch.softmax(received, dim=-1)

    def _root_node_embedding(self, sample_tokens, sample_rules, sample_boundary, root_idx):
        """Pool the root node's tokens with a learned mix of self- and neighbour attention.

        Neighbours are non-padding nodes ``idx`` with ``sample_rules[root_idx][idx] <= 1``.
        When there are none, the root's own tokens act as the neighbour set.
        """
        device = sample_rules.device
        root_start, root_end = 0, 0
        neighbour_spans = []
        for idx, span in enumerate(sample_boundary):
            start, end = span[0].item(), span[1].item()
            if start == -1:
                continue
            if idx == root_idx:
                root_start, root_end = start, end
            elif sample_rules[root_idx][idx] <= 1:
                neighbour_spans.append(sample_tokens[start:end + 1])

        root_value = sample_tokens[root_start:root_end + 1]
        # No neighbour tokens: fall back to the root tokens themselves.
        neighbours = torch.cat(neighbour_spans) if neighbour_spans else root_value
        neighbours = neighbours.to(device)

        root_query = self.Q(root_value)
        root_key = self.K(root_value)

        outra_attns = self._pool_weights(self.Q(neighbours), root_key)
        intra_attns = self._pool_weights(root_query, root_key)
        final_weights = (self.attention_weights[0] * intra_attns
                         + self.attention_weights[1] * outra_attns)

        return (root_value * final_weights.unsqueeze(-1)).sum(dim=0)

    def get_node_embedding(self, token_embeddings, block_rules, boundary, root):
        """Pool token states into one embedding per boundary entry.

        Returns a ``[batch, num_boundary_entries, hidden]`` tensor. The root
        uses ``_root_node_embedding``, other valid nodes use Q/K self-attention
        pooling, and padding entries (``start == -1``) are zero vectors.
        """
        batch_size = block_rules.size(0)
        device = block_rules.device
        hidden_size = self.token_encoder.config.hidden_size

        all_node_embeddings = []
        for b in range(batch_size):
            root_idx = root[b].item()
            node_embeddings = []
            for node_idx, span in enumerate(boundary[b]):
                start, end = span[0].item(), span[1].item()
                if node_idx == root_idx:
                    node_embeddings.append(self._root_node_embedding(
                        token_embeddings[b], block_rules[b], boundary[b], root_idx))
                elif start != -1:
                    tokens = token_embeddings[b, start:end + 1].to(device)
                    weights = self._pool_weights(self.Q(tokens), self.K(tokens))
                    node_embeddings.append((tokens * weights.unsqueeze(-1)).sum(dim=0))
                else:
                    node_embeddings.append(torch.zeros(hidden_size, device=device))
            all_node_embeddings.append(torch.stack(node_embeddings).to(device))

        return torch.stack(all_node_embeddings).to(device)

    def forward(self,
                input_ids=None,
                block_rules=None,
                labels=None,
                return_dict=None,
                root=None,
                boundary=None):
        """Encode tokens, then nodes, and classify each sample's root node.

        Args:
            input_ids: Token ids, indexed ``[batch, seq]``.
            block_rules: Rule id per node pair, indexed ``[b, i, j]``.
            labels: Class targets. When ``None`` no loss is computed.
            return_dict: Accepted for interface compatibility. The token encoder
                is always queried with ``return_dict=True`` because its hidden
                states are read by key.
            root: ``root[b]`` is the index of the root node of sample ``b``.
            boundary: Per-sample node token spans, padded with ``-1``.

        Returns:
            ``TokenClassifierOutput`` with ``loss`` (classification loss plus
            0.1 x embedding-distribution MSE; ``None`` without labels),
            ``logits`` and ``hidden_states`` = fused root embedding.
        """
        seq_length = input_ids.size(1)

        attention_mask = self.compute_attention_mask(seq_length, block_rules, boundary)
        outputs = self.token_encoder(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=True,
                                     output_hidden_states=True)
        token_embeddings = outputs['hidden_states'][-1]

        all_node_embeddings = self.get_node_embedding(token_embeddings, block_rules, boundary, root)
        all_node_attention_masks = self.compute_node_attention_mask(boundary, block_rules)
        all_node_embeddings = self.node_input_norm(all_node_embeddings)

        node_outputs = self.node_encoder(
            inputs_embeds=all_node_embeddings,   # [batch_size, num_nodes, dim]
            attention_mask=all_node_attention_masks,
            output_hidden_states=True
        )
        node_hidden_states = node_outputs['hidden_states'][-1]

        low_emb = self._gather_root(all_node_embeddings, root)
        high_emb = self._gather_root(node_hidden_states, root)

        weights = torch.softmax(self.fusion_weights, dim=0)
        emb = weights[0] * high_emb + weights[1] * low_emb
        emb = self.dropout(emb)

        if self.feat_shrink:
            emb = self.feat_shrink_layer(emb)

        logits = self.classifier(emb)

        loss = None
        if labels is not None:
            # Auxiliary term pulling the mean node embedding towards the mean
            # token embedding (cosine-normalised MSE).
            distribution_loss = F.mse_loss(
                F.normalize(all_node_embeddings.mean(dim=1), p=2, dim=-1),
                F.normalize(token_embeddings.mean(dim=1), p=2, dim=-1)
            )
            loss = self.loss_func(logits, labels) + 0.1 * distribution_loss  # tunable coefficient

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=emb)


class AH_BertClassifier(PreTrainedModel):
    """Two-level (token -> node) classifier with scorer-pooled node embeddings.

    ``token_model`` encodes the token sequence under a block-structured
    attention mask built from ``boundary`` (per-node ``(start, end)`` token
    spans, padded with ``-1``) and ``block_rules`` (a rule id per node pair,
    mapped through ``attention_co``). Each node span is pooled with the
    learned ``intra_node_attention`` scorer. ``node_model`` encodes the node
    sequence, and the root node's pre- and post-encoder states are fused and
    classified.
    """

    def __init__(self, token_model, node_model, n_labels, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True):
        super().__init__(token_model.config)
        self.token_encoder = token_model
        self.node_encoder = node_model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        hidden_dim = token_model.config.hidden_size
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')

        # Intra-node attention: scores each token of a node span for pooling.
        self.intra_node_attention = nn.Sequential(
            nn.Linear(token_model.config.hidden_size, 128),
            nn.Tanh(),
            nn.Linear(128, 1, bias=False)
        )

        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                token_model.config.hidden_size, int(feat_shrink), bias=cla_bias)
            hidden_dim = int(feat_shrink)
        self.classifier = nn.Linear(hidden_dim, n_labels, bias=cla_bias)
        # init_random_state(seed)

        # Learnable (post-encoder, pre-encoder) root fusion; softmaxed in forward.
        self.fusion_weights = nn.Parameter(torch.tensor([0.65, 0.35]))
        self.node_input_norm = nn.LayerNorm(token_model.config.hidden_size)
        # Rule id -> HF-style attention-mask value (1 = attend, 0 = blocked).
        self.attention_co = [1, 1, 1, 1, 0, 0, 0]

    @staticmethod
    def _valid_spans(boundary_row):
        """Return the ``(start, end)`` spans of one sample, stopping at the first ``end == -1``."""
        spans = []
        for index in boundary_row:
            start, end = index[0].item(), index[1].item()
            if end == -1:
                break
            spans.append((start, end))
        return spans

    @staticmethod
    def _gather_root(states, root):
        """Select ``states[b, root[b]]`` for every sample ``b`` -> ``[batch, dim]``."""
        return torch.stack(
            [states[b, root[b].item(), :] for b in range(states.size(0))])

    def compute_attention_mask(self, seq_length, block_rules, boundary):
        """Build the ``[batch, seq_length, seq_length]`` token-level attention mask.

        Token spans of nodes ``i`` and ``j`` get ``attention_co[block_rules[b, i, j]]``.
        The token right after each node span (the separator) attends to, and
        is attended by, every real token. Rows after the last separator attend
        to all real tokens.
        """
        batch_size = block_rules.size(0)
        device = block_rules.device

        attention_mask = torch.zeros(batch_size, seq_length, seq_length, device=device)
        for b in range(batch_size):
            nodes = self._valid_spans(boundary[b])
            # Position of the separator after the last node (-1 if no nodes).
            true_seq_length = nodes[-1][1] + 1 if nodes else -1

            # Trailing (padding) rows see all real tokens; presumably so that no
            # attention row is entirely masked.
            attention_mask[b, true_seq_length + 1:, 0:true_seq_length + 1] = 1

            for i, (src_start, src_end) in enumerate(nodes):
                for j, (tgt_start, tgt_end) in enumerate(nodes):
                    co_value = self.attention_co[block_rules[b, i, j]]
                    attention_mask[b, src_start:src_end + 1, tgt_start:tgt_end + 1] = co_value

            # Handle the SEP token that follows each node span.
            for _, src_end in nodes:
                attention_mask[b, src_end + 1, 0:true_seq_length + 1] = 1
                attention_mask[b, 0:true_seq_length + 1, src_end + 1] = 1

        return attention_mask

    def compute_node_attention_mask(self, boundary, block_rules, seq_length=None):
        """Build the ``[batch, seq_length, seq_length]`` node-level attention mask.

        Args:
            boundary: Per-sample node spans, padded with ``end == -1``.
            block_rules: Rule id per node pair, indexed ``[b, i, j]``.
            seq_length: Number of node slots. Defaults to the number of boundary
                entries per sample, which is the number of node embeddings
                pooled from the same boundary. (Previously hard-coded to 40.)
        """
        batch_size = block_rules.size(0)
        device = block_rules.device
        if seq_length is None:
            seq_length = len(boundary[0]) if batch_size > 0 else 0

        node_attention_mask = torch.zeros(batch_size, seq_length, seq_length, device=device)
        for b in range(batch_size):
            n_nodes = len(self._valid_spans(boundary[b]))
            # Padding node slots attend to every real node.
            node_attention_mask[b, n_nodes:, 0:n_nodes] = 1
            for i in range(n_nodes):
                for j in range(n_nodes):
                    node_attention_mask[b, i, j] = self.attention_co[block_rules[b, i, j]]

        return node_attention_mask

    def _pool_node_embeddings(self, token_embeddings, boundary, device):
        """Pool every boundary entry's token span into one vector.

        Uses softmax weights from ``intra_node_attention``. Padding entries
        (``start == -1``) become zero vectors. Returns
        ``[batch, num_boundary_entries, hidden]``.
        """
        hidden_size = self.token_encoder.config.hidden_size
        all_node_embeddings = []
        for b in range(token_embeddings.size(0)):
            node_embeddings = []
            for span in boundary[b]:
                start, end = span[0].item(), span[1].item()
                if start == -1:
                    node_embeddings.append(torch.zeros(hidden_size, device=device))
                    continue

                node_tokens = token_embeddings[b, start:end + 1, :]  # [node_len, dim]
                attn_weights = torch.softmax(
                    self.intra_node_attention(node_tokens).squeeze(-1),
                    dim=0
                )  # [node_len]
                node_embeddings.append(
                    (node_tokens * attn_weights.unsqueeze(-1)).sum(dim=0))  # [dim]

            all_node_embeddings.append(torch.stack(node_embeddings).to(device))

        return torch.stack(all_node_embeddings)

    def forward(self,
                input_ids=None,
                block_rules=None,
                labels=None,
                return_dict=None,
                root=None,
                boundary=None):
        """Encode tokens, then nodes, and classify each sample's root node.

        Args:
            input_ids: Token ids, indexed ``[batch, seq]``.
            block_rules: Rule id per node pair, indexed ``[b, i, j]``.
            labels: Class targets. When ``None`` no loss is computed.
            return_dict: Accepted for interface compatibility. The token encoder
                is always queried with ``return_dict=True`` because its hidden
                states are read by key.
            root: ``root[b]`` is the index of the root node of sample ``b``.
            boundary: Per-sample node token spans, padded with ``-1``.

        Returns:
            ``TokenClassifierOutput`` with ``loss`` (classification loss plus
            0.1 x embedding-distribution MSE; ``None`` without labels),
            ``logits`` and ``hidden_states`` = fused root embedding.
        """
        seq_length = input_ids.size(1)

        attention_mask = self.compute_attention_mask(seq_length, block_rules, boundary)
        outputs = self.token_encoder(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=True,
                                     output_hidden_states=True)
        token_embeddings = outputs['hidden_states'][-1]

        all_node_attention_masks = self.compute_node_attention_mask(boundary, block_rules)
        all_node_embeddings = self._pool_node_embeddings(token_embeddings, boundary, input_ids.device)
        all_node_embeddings = self.node_input_norm(all_node_embeddings)

        node_outputs = self.node_encoder(
            inputs_embeds=all_node_embeddings,   # [batch_size, num_nodes, dim]
            attention_mask=all_node_attention_masks,
            output_hidden_states=True
        )
        node_hidden_states = node_outputs['hidden_states'][-1]

        low_emb = self._gather_root(all_node_embeddings, root)
        high_emb = self._gather_root(node_hidden_states, root)

        weights = torch.softmax(self.fusion_weights, dim=0)
        emb = weights[0] * high_emb + weights[1] * low_emb
        emb = self.dropout(emb)

        if self.feat_shrink:
            emb = self.feat_shrink_layer(emb)

        logits = self.classifier(emb)

        loss = None
        if labels is not None:
            # Auxiliary term pulling the mean node embedding towards the mean
            # token embedding (cosine-normalised MSE).
            distribution_loss = F.mse_loss(
                F.normalize(all_node_embeddings.mean(dim=1), p=2, dim=-1),
                F.normalize(token_embeddings.mean(dim=1), p=2, dim=-1)
            )
            loss = self.loss_func(logits, labels) + 0.1 * distribution_loss  # tunable coefficient

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=emb)

class LabelBertClassifier(PreTrainedModel):
    """Text classifier that also attends over encoded label descriptions.

    The input text and every entry of ``label_text`` are encoded by the same
    ``token_model`` and attention-pooled with ``intra_node_attention``. The
    pooled text vector attends over the pooled label vectors. The text vector
    concatenated with that attention distribution is then classified.
    """

    def __init__(self, token_model, n_labels, label_text, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True, tokenizer_path=''):
        """
        Args:
            tokenizer_path: Local path of the tokenizer used to encode
                ``label_text``. Defaults to ``''`` (the previous hard-coded
                value); only local files are loaded.
        """
        super().__init__(token_model.config)
        self.token_encoder = token_model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls

        # Width of the token encoder's hidden states; all pooling/attention
        # below operates on vectors of this width.
        encoder_dim = token_model.config.hidden_size
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')

        # Intra-node attention: scores each token for attention pooling.
        self.intra_node_attention = nn.Sequential(
            nn.Linear(encoder_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1, bias=False)
        )

        # NOTE(review): created but not applied in forward().
        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                encoder_dim, int(feat_shrink), bias=cla_bias)

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            tokenizer_path,
            add_special_tokens=False,
            local_files_only=True  # only use local files
        )

        self.label_text = label_text
        self.label_tokens = []
        self.label_masks = []
        for label in label_text:
            label_encoding = self.tokenizer.encode(label, add_special_tokens=False)
            self.label_tokens.append(torch.tensor(label_encoding))
            self.label_masks.append(torch.tensor([1] * len(label_encoding)))

        # Q/K act on pooled encoder states, so they use the encoder width
        # (the shrunk width broke them whenever feat_shrink was set).
        self.query = nn.Linear(encoder_dim, encoder_dim)  # text-side projection
        self.key = nn.Linear(encoder_dim, encoder_dim)    # label-side projection

        # forward() concatenates the pooled text vector [encoder_dim] with the
        # attention distribution over label texts [len(label_text)]. This layer
        # was missing, so forward() raised AttributeError.
        self.classifier = nn.Linear(
            encoder_dim + len(self.label_tokens), n_labels, bias=cla_bias)

    def compute_attention_mask(self, seq_length, length):
        """Return a ``[batch, seq_length]`` mask with ones on the first ``length[b]`` positions."""
        batch_size = length.size(0)
        device = length.device

        attention_mask = torch.zeros(batch_size, seq_length, device=device)
        for b in range(batch_size):
            attention_mask[b, 0:length[b]] = 1
        return attention_mask

    def _attention_pool(self, tokens):
        """Pool ``[n_tokens, dim]`` states into ``[dim]`` with softmaxed scorer weights."""
        attn_weights = torch.softmax(
            self.intra_node_attention(tokens).squeeze(-1),
            dim=0
        )  # [n_tokens]
        return (tokens * attn_weights.unsqueeze(-1)).sum(dim=0)

    def _encode_label(self, label_token, label_mask, device):
        """Encode one label description and attention-pool it to ``[dim]``."""
        outputs = self.token_encoder(input_ids=label_token.unsqueeze(0).to(device),
                                     attention_mask=label_mask.unsqueeze(0).to(device),
                                     return_dict=True,
                                     output_hidden_states=True)
        return self._attention_pool(outputs['hidden_states'][-1][0])

    def forward(self,
                input_ids=None,
                labels=None,
                return_dict=None,
                length=None):
        """Classify each text using its pooled embedding and its attention over labels.

        Args:
            input_ids: Token ids, indexed ``[batch, seq]``.
            labels: Class targets. When ``None`` no loss is computed.
            return_dict: Accepted for interface compatibility. The encoder is
                always queried with ``return_dict=True`` because hidden states
                are read by key.
            length: Number of valid tokens per sample.

        Returns:
            ``TokenClassifierOutput`` with ``loss`` (``None`` without labels),
            ``logits`` and ``hidden_states`` = classifier input features.
        """
        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)
        device = input_ids.device

        attention_mask = self.compute_attention_mask(seq_length, length)
        outputs = self.token_encoder(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=True,
                                     output_hidden_states=True)
        token_embeddings = outputs['hidden_states'][-1]

        # One pooled vector per sample over its valid tokens: [batch, dim].
        node_embeddings = torch.stack([
            self._attention_pool(token_embeddings[b, 0:length[b], :])
            for b in range(batch_size)
        ]).to(device)

        # One pooled vector per label description: [n_label_texts, dim].
        # Re-encoded every call so gradients reach the shared encoder.
        label_embeddings = torch.stack([
            self._encode_label(label_token, label_mask, device)
            for label_token, label_mask in zip(self.label_tokens, self.label_masks)
        ]).to(device)

        attn_weights = torch.softmax(
            torch.matmul(
                self.query(node_embeddings),   # [batch, dim]
                self.key(label_embeddings).T   # [dim, n_label_texts]
            ) / math.sqrt(self.token_encoder.config.hidden_size),
            dim=-1
        )  # [batch, n_label_texts]

        # Concatenate the attention distribution directly onto the text vector.
        combined_features = torch.cat([node_embeddings, attn_weights], dim=-1)
        logits = self.classifier(combined_features)

        loss = self.loss_func(logits, labels) if labels is not None else None

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=combined_features)

class LAH_BertClassifier(PreTrainedModel):
    """Classify an example by its root node, fusing node-level context with label-text attention.

    ``forward`` pipeline (every encoder call uses only the last hidden layer):

    1. ``token_encoder`` encodes ``input_ids`` under a block attention mask built from
       ``boundary`` (token spans of the nodes) and ``block_rules`` (a rule id per node pair,
       mapped to a mask value through ``attention_co``).
    2. Each node span is pooled into one vector by attention pooling (``intra_node_attention``).
    3. ``node_encoder`` runs over the layer-normed node vectors (``inputs_embeds``).
    4. The root node's vector before and after ``node_encoder`` is mixed with
       ``softmax(fusion_weights)``, then dropout (and ``feat_shrink_layer`` if configured).
    5. The root text (``input_root_ids``) and every label text are encoded and pooled the same
       way; the softmax attention of the root over the labels is concatenated to the fused
       embedding and fed to ``classifier``.

    Args:
        token_model: token-level encoder; its ``config`` also configures this model.
        node_model: encoder called with ``inputs_embeds`` over the pooled node vectors.
        n_labels: number of classes.
        label_text: one string per class; tokenized here, re-encoded on every forward pass.
        dropout: dropout probability applied to the fused root embedding.
        seed: unused (kept for interface compatibility).
        cla_bias: whether ``classifier`` / ``feat_shrink_layer`` use a bias.
        feat_shrink: if truthy, ``int(feat_shrink)`` is the width the fused embedding is
            projected to before classification.
        use_cls: stored on the instance but not read by this class.
    """

    def __init__(self, token_model, node_model, n_labels, label_text, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True):
        super().__init__(token_model.config)
        encoder_dim = token_model.config.hidden_size
        self.token_encoder = token_model
        self.node_encoder = node_model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        hidden_dim = encoder_dim
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')

        # Attention pooling over the tokens of one span: score each token, softmax, weighted sum.
        self.intra_node_attention = nn.Sequential(
            nn.Linear(encoder_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1, bias=False)
        )

        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                encoder_dim, int(feat_shrink), bias=cla_bias)
            hidden_dim = int(feat_shrink)
        # Input: (possibly shrunk) root embedding concatenated with n_labels attention weights.
        self.classifier = nn.Linear(hidden_dim + n_labels, n_labels, bias=cla_bias)

        self.fusion_weights = nn.Parameter(torch.tensor([0.65, 0.35]))
        self.node_input_norm = nn.LayerNorm(encoder_dim)
        # attention_co[rule_id] -> mask value: rule ids 0-3 map to 1, 4-6 map to 0
        # (1 presumably means "may attend", following the HF attention-mask convention).
        self.attention_co = [1, 1, 1, 1, 0, 0, 0]

        # NOTE(review): empty placeholder path; with local_files_only=True it presumably has to
        # point at a local tokenizer directory for from_pretrained to succeed.
        model_path = ""
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_path,
            add_special_tokens=False,
            local_files_only=True  # only use local files, never download
        )

        # Label texts are tokenized once here; they are encoded on every forward pass.
        self.label_text = label_text
        self.label_tokens = []
        self.label_masks = []
        for label in label_text:
            label_ids = torch.tensor(self.tokenizer.encode(label, add_special_tokens=False))
            self.label_tokens.append(label_ids)
            self.label_masks.append(torch.ones_like(label_ids))

        # Projections for root-vs-label attention. Both inputs are raw encoder outputs, so they
        # use the encoder width even when feat_shrink is set (using the shrunk width crashed).
        self.query = nn.Linear(encoder_dim, encoder_dim)
        self.key = nn.Linear(encoder_dim, encoder_dim)

    @staticmethod
    def _node_spans(boundary_row):
        """Return the (start, end) token spans of one example, stopping at the first end == -1."""
        spans = []
        for index in boundary_row:
            start, end = index[0].item(), index[1].item()
            if end == -1:
                break
            spans.append((start, end))
        return spans

    def compute_attention_mask(self, seq_length, block_rules, boundary):
        """Build the token-level attention mask for ``token_encoder``.

        Args:
            seq_length: number of token positions.
            block_rules: tensor indexed as ``block_rules[b, i, j]`` giving the rule id between
                node i and node j of example b.
            boundary: per-example sequence of ``(start, end)`` inclusive token spans; an end of
                -1 terminates the list (presumably padding — confirm with the data pipeline).

        Returns:
            Float tensor ``(batch, seq_length, seq_length)`` on ``block_rules.device``.
            Token blocks of node pairs get ``attention_co[rule]``. Positions after
            ``last_end + 1`` attend to every position up to ``last_end + 1``. The position right
            after each span (treated as a SEP token) attends to, and is attended by, every
            position up to ``last_end + 1``. Everything else is 0.
        """
        batch_size = block_rules.size(0)
        device = block_rules.device

        attention_mask = torch.zeros(batch_size, seq_length, seq_length, device=device)
        for b in range(batch_size):
            nodes = self._node_spans(boundary[b])
            # Index one past the last span's end; -1 when the example has no nodes.
            true_seq_length = nodes[-1][1] + 1 if nodes else -1

            attention_mask[b, true_seq_length + 1:, 0:true_seq_length + 1] = 1

            # One host transfer per example instead of a tensor-index per node pair.
            rules = block_rules[b].tolist()
            for i, (src_start, src_end) in enumerate(nodes):
                for j, (tgt_start, tgt_end) in enumerate(nodes):
                    attention_mask[b, src_start:src_end + 1, tgt_start:tgt_end + 1] = self.attention_co[rules[i][j]]

            # SEP tokens (right after each span) see, and are seen by, the whole real sequence.
            for _, src_end in nodes:
                attention_mask[b, src_end + 1, 0:true_seq_length + 1] = 1
                attention_mask[b, 0:true_seq_length + 1, src_end + 1] = 1

        return attention_mask

    def compute_node_attention_mask(self, boundary, block_rules, seq_length=40):
        """Build the node-level attention mask for ``node_encoder``.

        Args:
            boundary: same format as in ``compute_attention_mask``; only the number of spans
                before the first ``end == -1`` is used.
            block_rules: rule ids, ``block_rules[b, i, j]`` for node pair (i, j).
            seq_length: number of node slots (size of both mask axes). Defaults to 40 for
                backward compatibility; ``forward`` passes the actual node count.

        Returns:
            Float tensor ``(batch, seq_length, seq_length)``: ``attention_co[rule]`` between
            real nodes, 1 from padding slots to real nodes, 0 elsewhere.
        """
        batch_size = block_rules.size(0)
        device = block_rules.device

        node_attention_mask = torch.zeros(batch_size, seq_length, seq_length, device=device)
        for b in range(batch_size):
            n_nodes = len(self._node_spans(boundary[b]))
            node_attention_mask[b, n_nodes:, 0:n_nodes] = 1

            rules = block_rules[b].tolist()
            for i in range(n_nodes):
                for j in range(n_nodes):
                    node_attention_mask[b, i, j] = self.attention_co[rules[i][j]]

        return node_attention_mask

    def compute_root_attention_mask(self, seq_length, length):
        """Return a ``(batch, seq_length)`` padding mask: 1 on the first ``length[b]`` positions."""
        positions = torch.arange(seq_length, device=length.device).unsqueeze(0)
        return (positions < length.reshape(-1, 1)).to(torch.get_default_dtype())

    def _attention_pool(self, tokens):
        """Pool ``tokens`` (``[n_tokens, dim]``) into one ``[dim]`` vector via intra_node_attention."""
        weights = torch.softmax(self.intra_node_attention(tokens).squeeze(-1), dim=0)
        return (tokens * weights.unsqueeze(-1)).sum(dim=0)

    def _pool_nodes(self, token_embeddings, boundary, batch_size, device):
        """Pool every span of ``boundary`` into a node vector -> ``[batch, n_nodes, dim]``.

        Slots whose start is -1 get a zero vector so all examples share the same node count.
        """
        hidden_size = self.token_encoder.config.hidden_size
        all_node_embeddings = []
        for b in range(batch_size):
            node_embeddings = []
            for index in boundary[b]:
                start, end = index[0].item(), index[1].item()
                if start == -1:
                    node_embeddings.append(torch.zeros(hidden_size, device=device))
                else:
                    node_embeddings.append(self._attention_pool(token_embeddings[b, start:end + 1, :]))
            all_node_embeddings.append(torch.stack(node_embeddings).to(device))
        return torch.stack(all_node_embeddings)

    def _pool_roots(self, root_token_embeddings, length, batch_size, device):
        """Pool the first ``length[b]`` root tokens of each example -> ``[batch, dim]``."""
        roots = [
            self._attention_pool(root_token_embeddings[b, 0:length[b].item(), :])
            for b in range(batch_size)
        ]
        return torch.stack(roots).to(device)

    def _encode_labels(self, device, return_dict):
        """Encode and pool every label text with ``token_encoder`` -> ``[n_labels, dim]``."""
        label_embeddings = []
        for label_token, label_mask in zip(self.label_tokens, self.label_masks):
            outputs = self.token_encoder(input_ids=label_token.unsqueeze(0).to(device),
                                         attention_mask=label_mask.unsqueeze(0).to(device),
                                         return_dict=return_dict,
                                         output_hidden_states=True)
            label_embeddings.append(self._attention_pool(outputs['hidden_states'][-1][0]))
        return torch.stack(label_embeddings).to(device)

    @staticmethod
    def _select_root(embeddings, root):
        """Return ``embeddings[b, root[b], :]`` for every example b -> ``[batch, dim]``."""
        return torch.stack([embeddings[b, root[b].item(), :] for b in range(embeddings.size(0))])

    def forward(self,
                input_ids=None,
                block_rules=None,
                labels=None,
                return_dict=None,
                root=None,
                boundary=None,
                input_root_ids=None,
                length=None):
        """Run the full pipeline described in the class docstring.

        Args:
            input_ids: token ids of the structured input, ``[batch, seq_len]``.
            block_rules: node-pair rule ids (see ``compute_attention_mask``).
            labels: class ids; when ``None`` no loss is computed (``loss`` is ``None``).
            return_dict: forwarded to the encoders (outputs are read by key, so presumably
                this must not be ``False``).
            root: index of the root node per example.
            boundary: node token spans per example.
            input_root_ids: token ids of the root text, ``[batch, root_len]``.
            length: number of valid root tokens per example.

        Returns:
            ``TokenClassifierOutput`` with ``logits [batch, n_labels]``, ``loss`` (cross entropy
            plus 0.1 x node/token distribution MSE, or ``None``) and ``hidden_states`` set to the
            classifier input.
        """
        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)
        device = input_ids.device

        attention_mask = self.compute_attention_mask(seq_length, block_rules, boundary)
        root_attention_mask = self.compute_root_attention_mask(input_root_ids.size(1), length)

        outputs = self.token_encoder(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=return_dict,
                                     output_hidden_states=True)
        root_outputs = self.token_encoder(input_ids=input_root_ids,
                                          attention_mask=root_attention_mask,
                                          return_dict=return_dict,
                                          output_hidden_states=True)
        token_embeddings = outputs['hidden_states'][-1]
        root_token_embeddings = root_outputs['hidden_states'][-1]

        all_node_embeddings = self._pool_nodes(token_embeddings, boundary, batch_size, device)
        root_embeddings = self._pool_roots(root_token_embeddings, length, batch_size, device)
        all_node_embeddings = self.node_input_norm(all_node_embeddings)

        # Size the node mask from the actual node count so it always matches inputs_embeds.
        all_node_attention_masks = self.compute_node_attention_mask(
            boundary, block_rules, seq_length=all_node_embeddings.size(1))
        node_outputs = self.node_encoder(
            inputs_embeds=all_node_embeddings,  # [batch_size, num_nodes, dim]
            attention_mask=all_node_attention_masks,
            output_hidden_states=True
        )
        node_hidden_states = node_outputs['hidden_states'][-1]

        label_embeddings = self._encode_labels(device, return_dict)

        # Learned mix of the root node before (low) and after (high) the node encoder.
        low_emb = self._select_root(all_node_embeddings, root)
        high_emb = self._select_root(node_hidden_states, root)
        mix = torch.softmax(self.fusion_weights, dim=0)
        emb = mix[0] * high_emb + mix[1] * low_emb
        emb = self.dropout(emb)
        if self.feat_shrink:
            # Required so the classifier input width is hidden_dim + n_labels.
            emb = self.feat_shrink_layer(emb)

        # Scaled dot-product attention of each root text over the label texts: [batch, n_labels].
        label_attention = torch.softmax(
            torch.matmul(
                self.query(root_embeddings),  # [batch, dim]
                self.key(label_embeddings).T  # [dim, n_labels]
            ) / math.sqrt(self.token_encoder.config.hidden_size),
            dim=-1
        )

        # The attention weights themselves are appended as features.
        combined_features = torch.cat([emb, label_attention], dim=-1)
        logits = self.classifier(combined_features)

        loss = None
        if labels is not None:
            loss = self.loss_func(logits, labels)
            # Keep the mean (normalized) node embedding close to the mean token embedding.
            distribution_loss = F.mse_loss(
                F.normalize(all_node_embeddings.mean(dim=1), p=2, dim=-1),
                F.normalize(token_embeddings.mean(dim=1), p=2, dim=-1)
            )
            loss = loss + 0.1 * distribution_loss  # tunable coefficient

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=combined_features)

class LAttBertClassifier(PreTrainedModel):
    """Classify an example by its pooled root node, with an auxiliary label-text loss.

    ``token_encoder`` encodes ``input_ids`` under a node-block attention mask; each node span is
    attention-pooled, layer-normed, and the root node's vector is classified. During training a
    second cross-entropy term scores the root vector against pooled label-text embeddings, and a
    distribution term keeps mean node and mean token embeddings aligned.

    Args:
        token_model: token-level encoder; its ``config`` also configures this model.
        n_labels: number of classes.
        label_text: one string per class; tokenized here, encoded when a loss is computed.
        dropout: stored as ``self.dropout`` but not applied in ``forward``.
        seed: unused (kept for interface compatibility).
        cla_bias: whether ``classifier`` / ``feat_shrink_layer`` use a bias.
        feat_shrink: if truthy, ``int(feat_shrink)`` is the width the root embedding is
            projected to before ``classifier``.
        use_cls: stored on the instance but not read by this class.
    """

    def __init__(self, token_model, n_labels, label_text, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True):
        super().__init__(token_model.config)
        encoder_dim = token_model.config.hidden_size
        self.token_encoder = token_model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        hidden_dim = encoder_dim
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')
        self.node_input_norm = nn.LayerNorm(encoder_dim)

        # Attention pooling over the tokens of one span: score each token, softmax, weighted sum.
        self.intra_node_attention = nn.Sequential(
            nn.Linear(encoder_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1, bias=False)
        )

        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                encoder_dim, int(feat_shrink), bias=cla_bias)
            hidden_dim = int(feat_shrink)
        self.classifier = nn.Linear(hidden_dim, n_labels, bias=cla_bias)

        # NOTE(review): empty placeholder path; with local_files_only=True it presumably has to
        # point at a local tokenizer directory for from_pretrained to succeed.
        model_path = ""
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_path,
            add_special_tokens=False,
            local_files_only=True  # only use local files, never download
        )

        self.label_text = label_text
        self.label_tokens = []
        self.label_masks = []
        for label in label_text:
            label_ids = torch.tensor(self.tokenizer.encode(label, add_special_tokens=False))
            self.label_tokens.append(label_ids)
            self.label_masks.append(torch.ones_like(label_ids))

        # attention_co[rule_id] -> mask value: rule ids 0-3 map to 1, 4-6 map to 0
        # (1 presumably means "may attend", following the HF attention-mask convention).
        self.attention_co = [1, 1, 1, 1, 0, 0, 0]

    @staticmethod
    def _node_spans(boundary_row):
        """Return the (start, end) token spans of one example, stopping at the first end == -1."""
        spans = []
        for index in boundary_row:
            start, end = index[0].item(), index[1].item()
            if end == -1:
                break
            spans.append((start, end))
        return spans

    def compute_attention_mask(self, seq_length, block_rules, boundary):
        """Build the token-level attention mask for ``token_encoder``.

        Args:
            seq_length: number of token positions.
            block_rules: tensor indexed as ``block_rules[b, i, j]`` giving the rule id between
                node i and node j of example b.
            boundary: per-example sequence of ``(start, end)`` inclusive token spans; an end of
                -1 terminates the list (presumably padding — confirm with the data pipeline).

        Returns:
            Float tensor ``(batch, seq_length, seq_length)`` on ``block_rules.device``.
            Token blocks of node pairs get ``attention_co[rule]``. Positions after
            ``last_end + 1`` attend to every position up to ``last_end + 1``. The position right
            after each span (treated as a SEP token) attends to, and is attended by, every
            position up to ``last_end + 1``. Everything else is 0.
        """
        batch_size = block_rules.size(0)
        device = block_rules.device

        attention_mask = torch.zeros(batch_size, seq_length, seq_length, device=device)
        for b in range(batch_size):
            nodes = self._node_spans(boundary[b])
            # Index one past the last span's end; -1 when the example has no nodes.
            true_seq_length = nodes[-1][1] + 1 if nodes else -1

            attention_mask[b, true_seq_length + 1:, 0:true_seq_length + 1] = 1

            # One host transfer per example instead of a tensor-index per node pair.
            rules = block_rules[b].tolist()
            for i, (src_start, src_end) in enumerate(nodes):
                for j, (tgt_start, tgt_end) in enumerate(nodes):
                    attention_mask[b, src_start:src_end + 1, tgt_start:tgt_end + 1] = self.attention_co[rules[i][j]]

            # SEP tokens (right after each span) see, and are seen by, the whole real sequence.
            for _, src_end in nodes:
                attention_mask[b, src_end + 1, 0:true_seq_length + 1] = 1
                attention_mask[b, 0:true_seq_length + 1, src_end + 1] = 1

        return attention_mask

    def _attention_pool(self, tokens):
        """Pool ``tokens`` (``[n_tokens, dim]``) into one ``[dim]`` vector via intra_node_attention."""
        weights = torch.softmax(self.intra_node_attention(tokens).squeeze(-1), dim=0)
        return (tokens * weights.unsqueeze(-1)).sum(dim=0)

    def _pool_nodes(self, token_embeddings, boundary, batch_size, device):
        """Pool every span of ``boundary`` into a node vector -> ``[batch, n_nodes, dim]``.

        Slots whose start is -1 get a zero vector so all examples share the same node count.
        """
        hidden_size = self.token_encoder.config.hidden_size
        all_node_embeddings = []
        for b in range(batch_size):
            node_embeddings = []
            for index in boundary[b]:
                start, end = index[0].item(), index[1].item()
                if start == -1:
                    node_embeddings.append(torch.zeros(hidden_size, device=device))
                else:
                    node_embeddings.append(self._attention_pool(token_embeddings[b, start:end + 1, :]))
            all_node_embeddings.append(torch.stack(node_embeddings).to(device))
        return torch.stack(all_node_embeddings)

    def _encode_labels(self, device, return_dict):
        """Encode and pool every label text with ``token_encoder`` -> ``[n_labels, dim]``."""
        label_embeddings = []
        for label_token, label_mask in zip(self.label_tokens, self.label_masks):
            outputs = self.token_encoder(input_ids=label_token.unsqueeze(0).to(device),
                                         attention_mask=label_mask.unsqueeze(0).to(device),
                                         return_dict=return_dict,
                                         output_hidden_states=True)
            label_embeddings.append(self._attention_pool(outputs['hidden_states'][-1][0]))
        return torch.stack(label_embeddings).to(device)

    @staticmethod
    def _select_root(embeddings, root):
        """Return ``embeddings[b, root[b], :]`` for every example b -> ``[batch, dim]``."""
        return torch.stack([embeddings[b, root[b].item(), :] for b in range(embeddings.size(0))])

    def forward(self,
                input_ids=None,
                block_rules=None,
                labels=None,
                return_dict=None,
                root=None,
                boundary=None):
        """Classify each example by its root node.

        Args:
            input_ids: token ids, ``[batch, seq_len]``.
            block_rules: node-pair rule ids (see ``compute_attention_mask``).
            labels: class ids; when ``None`` no loss is computed and label texts are not encoded.
            return_dict: forwarded to the encoder (outputs are read by key, so presumably this
                must not be ``False``).
            root: index of the root node per example.
            boundary: node token spans per example.

        Returns:
            ``TokenClassifierOutput`` with ``logits [batch, n_labels]``, ``hidden_states`` set
            to the classifier input, and ``loss`` = CE + 0.1 x label CE + 0.1 x distribution
            MSE (``None`` without labels).
        """
        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)
        device = input_ids.device

        attention_mask = self.compute_attention_mask(seq_length, block_rules, boundary)
        outputs = self.token_encoder(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=return_dict,
                                     output_hidden_states=True)
        token_embeddings = outputs['hidden_states'][-1]

        all_node_embeddings = self._pool_nodes(token_embeddings, boundary, batch_size, device)
        all_node_embeddings = self.node_input_norm(all_node_embeddings)

        # Encoder-width root embedding; kept separately for the label loss below.
        root_emb = self._select_root(all_node_embeddings, root)
        emb = root_emb
        if self.feat_shrink:
            emb = self.feat_shrink_layer(emb)
        emb_logits = self.classifier(emb)

        loss = None
        if labels is not None:
            label_embeddings = self._encode_labels(device, return_dict)
            emb_loss = self.loss_func(emb_logits, labels)
            # Dot-product scores against the label embeddings. Use the un-shrunk root embedding
            # so the widths match even when feat_shrink is set.
            label_logits = root_emb @ label_embeddings.T
            label_loss = self.loss_func(label_logits, labels)
            loss = emb_loss + 0.1 * label_loss

            # Keep the mean (normalized) node embedding close to the mean token embedding.
            distribution_loss = F.mse_loss(
                F.normalize(all_node_embeddings.mean(dim=1), p=2, dim=-1),
                F.normalize(token_embeddings.mean(dim=1), p=2, dim=-1)
            )
            loss = loss + 0.1 * distribution_loss  # tunable coefficient

        return TokenClassifierOutput(loss=loss, logits=emb_logits, hidden_states=emb)

class OLAttBertClassifier(PreTrainedModel):
    """Classify an example by its pooled root node, with an auxiliary root-text/label loss.

    ``token_encoder`` encodes ``input_ids`` under a node-block attention mask; each node span is
    attention-pooled, layer-normed, and the root node's vector is classified. During training
    the separately encoded root text (``input_root_ids``) is scored against pooled label-text
    embeddings as a second cross-entropy term, and a distribution term keeps mean node and mean
    token embeddings aligned.

    Args:
        token_model: token-level encoder; its ``config`` also configures this model.
        n_labels: number of classes.
        label_text: one string per class; tokenized here, encoded when a loss is computed.
        dropout: stored as ``self.dropout`` but not applied in ``forward``.
        seed: unused (kept for interface compatibility).
        cla_bias: whether ``classifier`` / ``feat_shrink_layer`` use a bias.
        feat_shrink: if truthy, ``int(feat_shrink)`` is the width the root embedding is
            projected to before ``classifier``.
        use_cls: stored on the instance but not read by this class.
    """

    def __init__(self, token_model, n_labels, label_text, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True):
        super().__init__(token_model.config)
        encoder_dim = token_model.config.hidden_size
        self.token_encoder = token_model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        hidden_dim = encoder_dim
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')
        self.node_input_norm = nn.LayerNorm(encoder_dim)

        # Attention pooling over the tokens of one span: score each token, softmax, weighted sum.
        self.intra_node_attention = nn.Sequential(
            nn.Linear(encoder_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1, bias=False)
        )

        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                encoder_dim, int(feat_shrink), bias=cla_bias)
            hidden_dim = int(feat_shrink)
        self.classifier = nn.Linear(hidden_dim, n_labels, bias=cla_bias)

        # NOTE(review): empty placeholder path; with local_files_only=True it presumably has to
        # point at a local tokenizer directory for from_pretrained to succeed.
        model_path = ""
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_path,
            add_special_tokens=False,
            local_files_only=True  # only use local files, never download
        )

        self.label_text = label_text
        self.label_tokens = []
        self.label_masks = []
        for label in label_text:
            label_ids = torch.tensor(self.tokenizer.encode(label, add_special_tokens=False))
            self.label_tokens.append(label_ids)
            self.label_masks.append(torch.ones_like(label_ids))

        # attention_co[rule_id] -> mask value: rule ids 0-3 map to 1, 4-6 map to 0
        # (1 presumably means "may attend", following the HF attention-mask convention).
        self.attention_co = [1, 1, 1, 1, 0, 0, 0]

    @staticmethod
    def _node_spans(boundary_row):
        """Return the (start, end) token spans of one example, stopping at the first end == -1."""
        spans = []
        for index in boundary_row:
            start, end = index[0].item(), index[1].item()
            if end == -1:
                break
            spans.append((start, end))
        return spans

    def compute_attention_mask(self, seq_length, block_rules, boundary):
        """Build the token-level attention mask for ``token_encoder``.

        Args:
            seq_length: number of token positions.
            block_rules: tensor indexed as ``block_rules[b, i, j]`` giving the rule id between
                node i and node j of example b.
            boundary: per-example sequence of ``(start, end)`` inclusive token spans; an end of
                -1 terminates the list (presumably padding — confirm with the data pipeline).

        Returns:
            Float tensor ``(batch, seq_length, seq_length)`` on ``block_rules.device``.
            Token blocks of node pairs get ``attention_co[rule]``. Positions after
            ``last_end + 1`` attend to every position up to ``last_end + 1``. The position right
            after each span (treated as a SEP token) attends to, and is attended by, every
            position up to ``last_end + 1``. Everything else is 0.
        """
        batch_size = block_rules.size(0)
        device = block_rules.device

        attention_mask = torch.zeros(batch_size, seq_length, seq_length, device=device)
        for b in range(batch_size):
            nodes = self._node_spans(boundary[b])
            # Index one past the last span's end; -1 when the example has no nodes.
            true_seq_length = nodes[-1][1] + 1 if nodes else -1

            attention_mask[b, true_seq_length + 1:, 0:true_seq_length + 1] = 1

            # One host transfer per example instead of a tensor-index per node pair.
            rules = block_rules[b].tolist()
            for i, (src_start, src_end) in enumerate(nodes):
                for j, (tgt_start, tgt_end) in enumerate(nodes):
                    attention_mask[b, src_start:src_end + 1, tgt_start:tgt_end + 1] = self.attention_co[rules[i][j]]

            # SEP tokens (right after each span) see, and are seen by, the whole real sequence.
            for _, src_end in nodes:
                attention_mask[b, src_end + 1, 0:true_seq_length + 1] = 1
                attention_mask[b, 0:true_seq_length + 1, src_end + 1] = 1

        return attention_mask

    def compute_root_attention_mask(self, seq_length, length):
        """Return a ``(batch, seq_length)`` padding mask: 1 on the first ``length[b]`` positions."""
        positions = torch.arange(seq_length, device=length.device).unsqueeze(0)
        return (positions < length.reshape(-1, 1)).to(torch.get_default_dtype())

    def _attention_pool(self, tokens):
        """Pool ``tokens`` (``[n_tokens, dim]``) into one ``[dim]`` vector via intra_node_attention."""
        weights = torch.softmax(self.intra_node_attention(tokens).squeeze(-1), dim=0)
        return (tokens * weights.unsqueeze(-1)).sum(dim=0)

    def _pool_nodes(self, token_embeddings, boundary, batch_size, device):
        """Pool every span of ``boundary`` into a node vector -> ``[batch, n_nodes, dim]``.

        Slots whose start is -1 get a zero vector so all examples share the same node count.
        """
        hidden_size = self.token_encoder.config.hidden_size
        all_node_embeddings = []
        for b in range(batch_size):
            node_embeddings = []
            for index in boundary[b]:
                start, end = index[0].item(), index[1].item()
                if start == -1:
                    node_embeddings.append(torch.zeros(hidden_size, device=device))
                else:
                    node_embeddings.append(self._attention_pool(token_embeddings[b, start:end + 1, :]))
            all_node_embeddings.append(torch.stack(node_embeddings).to(device))
        return torch.stack(all_node_embeddings)

    def _pool_roots(self, root_token_embeddings, length, batch_size, device):
        """Pool the first ``length[b]`` root tokens of each example -> ``[batch, dim]``."""
        roots = [
            self._attention_pool(root_token_embeddings[b, 0:length[b].item(), :])
            for b in range(batch_size)
        ]
        return torch.stack(roots).to(device)

    def _encode_labels(self, device, return_dict):
        """Encode and pool every label text with ``token_encoder`` -> ``[n_labels, dim]``."""
        label_embeddings = []
        for label_token, label_mask in zip(self.label_tokens, self.label_masks):
            outputs = self.token_encoder(input_ids=label_token.unsqueeze(0).to(device),
                                         attention_mask=label_mask.unsqueeze(0).to(device),
                                         return_dict=return_dict,
                                         output_hidden_states=True)
            label_embeddings.append(self._attention_pool(outputs['hidden_states'][-1][0]))
        return torch.stack(label_embeddings).to(device)

    @staticmethod
    def _select_root(embeddings, root):
        """Return ``embeddings[b, root[b], :]`` for every example b -> ``[batch, dim]``."""
        return torch.stack([embeddings[b, root[b].item(), :] for b in range(embeddings.size(0))])

    def forward(self,
                input_ids=None,
                block_rules=None,
                labels=None,
                return_dict=None,
                root=None,
                boundary=None,
                input_root_ids=None,
                length=None):
        """Classify each example by its root node.

        Args:
            input_ids: token ids, ``[batch, seq_len]``.
            block_rules: node-pair rule ids (see ``compute_attention_mask``).
            labels: class ids; when ``None`` no loss is computed, and the root text and label
                texts (used only by the auxiliary loss) are not encoded.
            return_dict: forwarded to the encoder (outputs are read by key, so presumably this
                must not be ``False``).
            root: index of the root node per example.
            boundary: node token spans per example.
            input_root_ids: token ids of the root text, ``[batch, root_len]``.
            length: number of valid root tokens per example.

        Returns:
            ``TokenClassifierOutput`` with ``logits [batch, n_labels]``, ``hidden_states`` set
            to the classifier input, and ``loss`` = CE + 0.1 x root/label CE + 0.1 x
            distribution MSE (``None`` without labels).
        """
        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)
        device = input_ids.device

        root_embeddings = None
        if labels is not None:
            # The root text is encoded on its own, before the main input, with a plain padding mask.
            root_attention_mask = self.compute_root_attention_mask(input_root_ids.size(1), length)
            root_outputs = self.token_encoder(input_ids=input_root_ids,
                                              attention_mask=root_attention_mask,
                                              return_dict=return_dict,
                                              output_hidden_states=True)
            root_embeddings = self._pool_roots(root_outputs['hidden_states'][-1], length, batch_size, device)

        attention_mask = self.compute_attention_mask(seq_length, block_rules, boundary)
        outputs = self.token_encoder(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=return_dict,
                                     output_hidden_states=True)
        token_embeddings = outputs['hidden_states'][-1]

        all_node_embeddings = self._pool_nodes(token_embeddings, boundary, batch_size, device)
        all_node_embeddings = self.node_input_norm(all_node_embeddings)

        emb = self._select_root(all_node_embeddings, root)
        if self.feat_shrink:
            emb = self.feat_shrink_layer(emb)
        emb_logits = self.classifier(emb)

        loss = None
        if labels is not None:
            label_embeddings = self._encode_labels(device, return_dict)
            emb_loss = self.loss_func(emb_logits, labels)
            # Dot-product scores of the root text against each label text.
            label_logits = root_embeddings @ label_embeddings.T
            label_loss = self.loss_func(label_logits, labels)
            loss = emb_loss + 0.1 * label_loss

            # Keep the mean (normalized) node embedding close to the mean token embedding.
            distribution_loss = F.mse_loss(
                F.normalize(all_node_embeddings.mean(dim=1), p=2, dim=-1),
                F.normalize(token_embeddings.mean(dim=1), p=2, dim=-1)
            )
            loss = loss + 0.1 * distribution_loss  # tunable coefficient

        return TokenClassifierOutput(loss=loss, logits=emb_logits, hidden_states=emb)

class TestLabelBertClassifier(PreTrainedModel):
    """Classify an input sequence by its similarity to encoded label texts.

    The input sequence ("root") and every label string are run through the same
    ``token_model``. The last-layer hidden states of each are collapsed into one
    vector by a learned attention pooling (``intra_node_attention``). The logits
    are the dot products between each pooled input and every pooled label.
    """

    def __init__(self, token_model, n_labels, label_text, dropout=0.0, seed=0, cla_bias=True,
                 feat_shrink='', use_cls=True, model_path=None):
        """Build the classifier and pre-tokenize the label texts.

        Args:
            token_model: Backbone encoder. ``forward`` calls it with ``input_ids``,
                ``attention_mask``, ``return_dict`` and ``output_hidden_states=True``,
                then reads ``outputs['hidden_states'][-1]``.
            n_labels: Number of classes. It is stored only; the logit width is
                ``len(label_text)``.
            label_text: Label strings, one per class, in label-id order.
            dropout: Probability for ``self.dropout``, which ``forward`` does not apply.
            seed: Unused; kept for signature compatibility with sibling classifiers.
            cla_bias: Whether ``feat_shrink_layer`` and ``classifier`` have a bias.
            feat_shrink: If truthy, the output width of ``feat_shrink_layer``.
            use_cls: Stored only.
            model_path: Local directory holding the tokenizer that matches
                ``token_model``. Defaults to ``token_model.config.name_or_path``.

        Raises:
            ValueError: If no tokenizer path can be determined, or if a label
                text tokenizes to an empty sequence.
        """
        super().__init__(token_model.config)
        self.n_labels = n_labels
        self.token_encoder = token_model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        hidden_dim = token_model.config.hidden_size
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')

        # Scores every token; a softmax over a sequence's scores gives its pooling weights.
        self.intra_node_attention = nn.Sequential(
            nn.Linear(token_model.config.hidden_size, 128),
            nn.Tanh(),
            nn.Linear(128, 1, bias=False),
        )

        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                token_model.config.hidden_size, int(feat_shrink), bias=cla_bias)
            hidden_dim = int(feat_shrink)

        # NOTE: `temperature`, `classifier` and `feat_shrink_layer` are registered but
        # not used by `forward`. They are kept so the module's parameter set is unchanged.
        self.temperature = nn.Parameter(torch.tensor(1.0))
        self.classifier = nn.Linear(hidden_dim, n_labels, bias=cla_bias)

        # An empty path can never be resolved with local_files_only=True, so fall back
        # to the path the backbone itself was loaded from.
        if not model_path:
            model_path = getattr(token_model.config, "name_or_path", "")
        if not model_path:
            raise ValueError(
                "TestLabelBertClassifier needs a tokenizer path: pass model_path=... or load "
                "token_model with from_pretrained so that config.name_or_path is set")
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_path,
            add_special_tokens=False,
            local_files_only=True,  # only use local files, never download
        )

        # Labels are tokenized once here. They are plain tensors (not buffers),
        # so `forward` moves them to the input's device on every call.
        self.label_text = label_text
        self.label_tokens = []
        self.label_masks = []
        for label in label_text:
            ids = self.tokenizer.encode(label, add_special_tokens=False)
            if not ids:
                raise ValueError(f"label text {label!r} tokenizes to an empty sequence")
            token_ids = torch.tensor(ids, dtype=torch.long)
            self.label_tokens.append(token_ids)
            self.label_masks.append(torch.ones_like(token_ids))

    def compute_attention_mask(self, boundary, seq_length):
        """Build a per-sample [seq, seq] block attention mask from node boundaries.

        Not called by ``forward``.

        Args:
            boundary: Tensor indexed as ``boundary[b][k] -> (start, end)``, with
                inclusive token spans. The last span of each sample is treated
                as the "true" node.
            seq_length: Sequence length L.

        Returns:
            ``(attention_mask, true_length)``. ``attention_mask`` is a float
            tensor of shape [B, L, L]. ``true_length[b]`` is the end of the
            last span plus one.
        """
        batch_size = boundary.size(0)
        device = boundary.device

        attention_mask = torch.zeros(batch_size, seq_length, seq_length, device=device)
        true_length = torch.zeros(batch_size, device=device)
        for b in range(batch_size):
            true_start = boundary[b][-1][0].item()
            true_end = boundary[b][-1][1].item()
            true_length[b] = boundary[b][-1][1].item() + 1
            # Everything up to the last span's end attends to everything up to it...
            attention_mask[b, 0:true_end+1, 0:true_end+1] = 1
            # ...except the region before that span, which is cleared...
            attention_mask[b, 0:true_start, 0:true_start] = 0
            # ...and then restricted to within-span attention for the other spans.
            for idx in range(len(boundary[b])-1):
                start, end = boundary[b][idx][0].item(), boundary[b][idx][1].item()
                attention_mask[b, start:end+1, start:end+1] = 1
            # Positions after the last span attend to the whole prefix.
            attention_mask[b, true_end+1:, 0:true_end+1] = 1

        return attention_mask, true_length

    def compute_root_attention_mask(self, seq_length, length):
        """Return a [B, seq_length] padding mask: 1 for the first ``length[b]`` positions, else 0.

        The mask has the default float dtype and lives on ``length``'s device.
        """
        device = length.device
        positions = torch.arange(seq_length, device=device).unsqueeze(0)  # [1, L]
        valid = positions < length.reshape(-1, 1)                          # [B, L]
        return valid.to(torch.get_default_dtype())

    def _attention_pool(self, tokens):
        """Collapse ``tokens`` of shape [n_tokens, hidden] into a single [hidden] vector.

        The weights are a softmax, over the token axis, of the
        ``intra_node_attention`` scores.
        """
        weights = torch.softmax(self.intra_node_attention(tokens).squeeze(-1), dim=0)
        return (tokens * weights.unsqueeze(-1)).sum(dim=0)

    def forward(self,
                labels=None,
                return_dict=None,
                input_root_ids=None,
                length=None):
        """Score every input sequence against every label text.

        Args:
            labels: Class ids for the loss. If None, no loss is computed.
            return_dict: Forwarded to ``token_encoder``.
            input_root_ids: Token ids of shape [B, L]. Only the first
                ``length[b]`` tokens of row b are valid.
            length: Valid length per row, one entry per batch element.

        Returns:
            ``TokenClassifierOutput``. ``logits`` has shape [B, len(label_text)]
            and ``loss`` is None when ``labels`` is None. ``hidden_states`` is
            also set to the logits.
        """
        batch_size = input_root_ids.size(0)
        device = input_root_ids.device
        root_attention_mask = self.compute_root_attention_mask(input_root_ids.size(1), length)

        root_outputs = self.token_encoder(input_ids=input_root_ids,
                                          attention_mask=root_attention_mask,
                                          return_dict=return_dict,
                                          output_hidden_states=True)
        root_token_embeddings = root_outputs['hidden_states'][-1]

        # Pool only the valid prefix of each sequence.
        root_embeddings = torch.stack([
            self._attention_pool(root_token_embeddings[b, 0:length[b].item(), :])
            for b in range(batch_size)
        ])

        # Encode every label text with the same backbone, one label per call.
        label_embeddings = []
        for token_ids, mask in zip(self.label_tokens, self.label_masks):
            outputs = self.token_encoder(input_ids=token_ids.unsqueeze(0).to(device),
                                         attention_mask=mask.unsqueeze(0).to(device),
                                         return_dict=return_dict,
                                         output_hidden_states=True)
            label_embeddings.append(self._attention_pool(outputs['hidden_states'][-1][0]))
        label_embeddings = torch.stack(label_embeddings)

        logits = root_embeddings @ label_embeddings.T
        loss = self.loss_func(logits, labels) if labels is not None else None

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=logits)

class SelfLabelBertClassifier(PreTrainedModel):
    """Classify pre-computed token encodings against learnable label vectors.

    ``forward`` receives token encodings, not token ids. For each sample, the
    valid prefix is contextualised by a one-layer ``nn.TransformerEncoder`` and
    pooled with learned attention weights. The pooled vector is then compared
    by dot product with ``label_encodings`` (one row per label) and divided by
    a learnable temperature.
    """

    def __init__(self, token_model, n_labels, label_initial, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True):
        """
        Args:
            token_model: Backbone whose ``config`` initialises this model. It is
                stored as ``token_encoder`` but not called in ``forward``.
            n_labels: Number of classes (stored).
            label_initial: Initial label matrix, copied into the trainable
                ``label_encodings`` parameter. Presumably [n_labels, hidden],
                since ``forward`` multiplies by its transpose (confirm with caller).
            dropout: Dropout inside the transformer layer. ``self.dropout`` is
                created but not applied in ``forward``.
            seed: If not None, passed to ``torch.manual_seed`` before the layers
                below are created. This reseeds the global RNG.
            cla_bias, feat_shrink, use_cls: Unused or only stored; kept for
                signature compatibility with sibling classifiers.
        """
        super().__init__(token_model.config)

        # The creation order below fixes how the seeded RNG is consumed; keep it stable.
        if seed is not None:
            torch.manual_seed(seed)

        hidden_dim = token_model.config.hidden_size
        self.token_encoder = token_model
        self.n_labels = n_labels
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')

        # Scores each token; a softmax over a sample's tokens gives its pooling weights.
        self.intra_node_attention = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1, bias=False),
        )

        # Registered but not used in `forward`.
        self.encoder = MLP(hidden_dim, hidden_dim*2, hidden_dim)

        # batch_first is left at its default (False): inputs are [seq, batch, dim].
        transformer_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*2, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=1)

        self.temperature = nn.Parameter(torch.tensor(1.0))

        self.label_encodings = nn.Parameter(label_initial.clone().detach().requires_grad_(True))

    def forward(self,
                input_encodings=None,
                length=None,
                labels=None,
                return_dict=None,
                ):
        """Score each sample against every learnable label vector.

        Args:
            input_encodings: Token encodings indexed as [B, L, hidden]. Only the
                first ``length[b]`` rows of sample b are used.
            length: Valid length per sample.
            labels: Class ids for the loss. If None, no loss is computed.
            return_dict: Accepted for interface compatibility; unused.

        Returns:
            ``TokenClassifierOutput`` with ``logits`` of shape [B, n_label_rows].
            ``loss`` is None when ``labels`` is None.
        """
        node_embeddings = []
        for b in range(input_encodings.size(0)):
            node_tokens = input_encodings[b, 0:length[b].item(), :]  # [n_tokens, hidden]

            # The encoder is sequence-first ([seq, batch, dim]), so present the
            # tokens as ONE sequence in a batch of size 1. The previous
            # transpose(0, 1) fed [1, n_tokens, dim], i.e. n_tokens sequences of
            # length 1, and no token could attend to any other.
            encoded = self.transformer_encoder(node_tokens.unsqueeze(1)).squeeze(1)

            weights = torch.softmax(self.intra_node_attention(encoded).squeeze(-1), dim=0)
            node_embeddings.append((encoded * weights.unsqueeze(-1)).sum(dim=0))

        node_embeddings = torch.stack(node_embeddings)

        logits = node_embeddings @ self.label_encodings.T / self.temperature
        loss = self.loss_func(logits, labels) if labels is not None else None

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=logits)


class Train_AttBertClassifier(PreTrainedModel):
    """Node classifier whose token attention mask is built from learnable rule weights.

    Each sample is a token sequence partitioned into nodes, where ``boundary``
    gives each node's inclusive token span. For every node pair, a rule id in
    ``block_rules`` selects one learnable coefficient from ``attention_co``.
    The resulting soft [B, L, L] mask is passed to ``token_encoder``. Every
    node is then attention-pooled into a vector, and the vector of the
    ``root`` node is classified.
    """

    def __init__(self, token_model, n_labels, dropout=0.0, seed=0, cla_bias=True, feat_shrink='', use_cls=True,
                 debug=False):
        """
        Args:
            token_model: Backbone encoder, called with a 3-D float ``attention_mask``.
            n_labels: Number of output classes of ``classifier``.
            dropout: Probability for ``self.dropout``, which ``forward`` does not apply.
            seed: Unused; kept for signature compatibility.
            cla_bias: Whether ``feat_shrink_layer`` and ``classifier`` have a bias.
            feat_shrink: If truthy, project the root vector to this width before
                classification.
            use_cls: Stored only.
            debug: Diagnostics that used to be always on, now opt-in. When True:
                register a hook that prints the gradient reaching
                ``attention_co``; in ``forward``, print ``attention_co`` and its
                ``.grad``; and build the mask and run the encoder under
                ``torch.autograd.detect_anomaly`` (slow; debugging only).
        """
        super().__init__(token_model.config)
        self.token_encoder = token_model
        self.dropout = nn.Dropout(dropout)
        self.feat_shrink = feat_shrink
        self.use_cls = use_cls
        self._debug = debug
        hidden_dim = token_model.config.hidden_size
        self.loss_func = nn.CrossEntropyLoss(
            label_smoothing=0.3, reduction='mean')
        self.node_input_norm = nn.LayerNorm(token_model.config.hidden_size)

        # Intra-node attention: scores each token of a node; a softmax over the
        # node's tokens gives its pooling weights.
        self.intra_node_attention = nn.Sequential(
            nn.Linear(token_model.config.hidden_size, 128),
            nn.Tanh(),
            nn.Linear(128, 1, bias=False),
        )

        if feat_shrink:
            self.feat_shrink_layer = nn.Linear(
                token_model.config.hidden_size, int(feat_shrink), bias=cla_bias)
            hidden_dim = int(feat_shrink)
        self.classifier = nn.Linear(hidden_dim, n_labels, bias=cla_bias)

        # One learnable coefficient per rule id; the values in `block_rules` index this vector.
        self.attention_co = nn.Parameter(torch.tensor([1, 1, 1, 1, 1, 1, 1], dtype=torch.float32), requires_grad=True)

        if debug:
            self.attention_co.register_hook(self._print_attention_co_grad)

    @staticmethod
    def _print_attention_co_grad(grad):
        """Backward hook (debug only) that prints the gradient reaching ``attention_co``."""
        if grad is not None:
            print(f"\n--- Hook: attention_co 梯度: {grad}")
            print(f"--- Hook: attention_co 梯度的范数: {grad.norm()}")
        else:
            print(f"\n--- Hook: attention_co 梯度为 None (可能未计算或已清零)")

    @staticmethod
    def _node_spans(rows):
        """Convert one sample's boundary rows into ``(start, end)`` Python int pairs."""
        return [(int(row[0]), int(row[1])) for row in rows]

    def compute_attention_mask(self, seq_length, block_rules, boundary):
        """Build the soft [B, L, L] attention mask from per-node-pair rule ids.

        For every pair of valid nodes (i, j) in sample b, with invalid spans
        having ``end == -1``, each token of node i (rows ``src_s..src_e``) and
        each token of node j (columns ``tgt_s..tgt_e``) receives
        ``attention_co[block_rules[b, i, j]]``. Later pairs overwrite earlier
        ones where spans overlap. Cells covered by no pair get rule 0.

        The mask is a one-hot @ ``attention_co`` product, so gradients flow
        back into ``attention_co``.

        Raises:
            RuntimeError: From ``F.one_hot`` if a rule id is negative or
                ``>= attention_co.numel()``.
        """
        batch_size = block_rules.size(0)
        device = block_rules.device
        num_rules = self.attention_co.numel()

        # Fill rule ids with tensor slice assignment instead of a B*L*L Python list;
        # pull each sample's rules to Python once instead of one .item() per pair.
        rule_indices = torch.zeros(batch_size, seq_length, seq_length, dtype=torch.long)
        for b in range(batch_size):
            nodes = [(s, e) for s, e in self._node_spans(boundary[b]) if e != -1]
            rules = block_rules[b].tolist()
            for i, (src_s, src_e) in enumerate(nodes):
                for j, (tgt_s, tgt_e) in enumerate(nodes):
                    rule_indices[b, src_s:src_e + 1, tgt_s:tgt_e + 1] = int(rules[i][j])
        rule_indices = rule_indices.to(device)

        rule_matrix = F.one_hot(rule_indices, num_classes=num_rules).to(self.attention_co.dtype)  # [B, L, L, R]

        # Differentiable lookup: each cell's one-hot row selects its coefficient.
        attention_values = torch.matmul(
            rule_matrix.flatten(0, 2),        # [B*L*L, R]
            self.attention_co.unsqueeze(-1),  # [R, 1]
        ).view(batch_size, seq_length, seq_length)

        return attention_values

    def forward(self,
                input_ids=None,
                block_rules=None,
                labels=None,
                return_dict=None,
                root=None,
                boundary=None):
        """Encode with the rule-weighted mask, pool nodes, and classify each sample's root node.

        Args:
            input_ids: Token ids of shape [B, L].
            block_rules: Per-sample rule ids for node pairs, indexed as ``[b, i, j]``.
            labels: Class ids for the loss. If None, no loss is computed.
            return_dict: Forwarded to ``token_encoder``.
            root: Per-sample index of the node to classify.
            boundary: Per-sample node spans, indexed as ``boundary[b][k] -> (start, end)``.
                Padding nodes have ``start == -1``.

        Returns:
            ``TokenClassifierOutput`` with ``logits`` of shape [B, n_labels] and
            ``hidden_states`` set to the (optionally shrunk) root vectors.
            ``loss`` is None when ``labels`` is None.
        """
        import contextlib

        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)

        if self._debug:
            assert self.attention_co.requires_grad
            torch.set_printoptions(precision=10)
            print(self.attention_co)

        anomaly_ctx = torch.autograd.detect_anomaly() if self._debug else contextlib.nullcontext()
        with anomaly_ctx:
            attention_mask = self.compute_attention_mask(seq_length, block_rules, boundary)
            # NOTE(review): how a 3-D float mask is interpreted (multiplicative vs.
            # additive / thresholded) depends on the backbone; confirm that the
            # gradient into `attention_co` is well-scaled for this model.
            outputs = self.token_encoder(input_ids=input_ids,
                                         attention_mask=attention_mask,
                                         return_dict=return_dict,
                                         output_hidden_states=True)

        if self._debug:
            print(f"参数梯度: {self.attention_co.grad}")

        token_embeddings = outputs['hidden_states'][-1]
        hidden_size = token_embeddings.size(-1)

        all_node_embeddings = []
        for b in range(batch_size):
            node_embeddings = []
            for start, end in self._node_spans(boundary[b]):
                if start == -1:
                    # Padding node: a zero vector in the hidden states' dtype/device, so
                    # the stack below does not mix dtypes.
                    node_embeddings.append(token_embeddings.new_zeros(hidden_size))
                    continue
                node_tokens = token_embeddings[b, start:end + 1, :]  # [node_len, dim]
                weights = torch.softmax(
                    self.intra_node_attention(node_tokens).squeeze(-1), dim=0)  # [node_len]
                node_embeddings.append((node_tokens * weights.unsqueeze(-1)).sum(dim=0))  # [dim]
            all_node_embeddings.append(torch.stack(node_embeddings))

        all_node_embeddings = self.node_input_norm(torch.stack(all_node_embeddings))  # [B, N, dim]

        # Pick the root node of each sample.
        rows = torch.arange(batch_size, device=all_node_embeddings.device)
        root_idx = root.reshape(-1).to(device=all_node_embeddings.device, dtype=torch.long)
        emb = all_node_embeddings[rows, root_idx]

        if self.feat_shrink:
            emb = self.feat_shrink_layer(emb)

        logits = self.classifier(emb)

        loss = None
        if labels is not None:
            # Auxiliary term: keep the mean node vector aligned in direction with the mean token vector.
            distribution_loss = F.mse_loss(
                F.normalize(all_node_embeddings.mean(dim=1), p=2, dim=-1),
                F.normalize(token_embeddings.mean(dim=1), p=2, dim=-1),
            )
            loss = self.loss_func(logits, labels) + 0.1 * distribution_loss  # 0.1: tunable weight

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=emb)