import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from transformers import Qwen2ForCausalLM, Qwen2Tokenizer,Qwen2Model, AutoModel
from transformers import Qwen3Model

class Relation_Classifier(nn.Module):
    """Two-layer feed-forward head: Linear -> ReLU -> Linear.

    The output of the second linear layer is returned unchanged; no softmax
    (or any other normalisation) is applied.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

    def _initialize_weights(self):
        """Reset both linear layers: Xavier-uniform weights and zero biases."""
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)


class Relevance_model(nn.Module):
    """Small MLP scorer: Linear -> ReLU -> Linear, returning the raw outputs."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Apply the first projection, the ReLU and the second projection."""
        activated = self.relu(self.fc1(x))
        return self.fc2(activated)

    def _initialize_weights(self):
        """Xavier-uniform the weights and zero the biases of both layers."""
        for linear in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)



class Qwen_relevence_atten_model(nn.Module):
    def __init__(self, input_size, hidden_size, output_size = 24, llm = None,
                forward_method = None,
                apply_mask = False):
        """Set up the encoder, the fusion layers for ``self.fused_method`` and the classifier head.

        Args:
            input_size: Input width of the ``Relation_Classifier`` head. For the
                "mlp(last_token||attention)||add(emb)" method it is rescaled to
                ``input_size // 3 * 2`` before the head is built.
            hidden_size: Hidden width of the classifier head.
            output_size: Number of classifier outputs.
            llm: Encoder to use. When falsy, "Qwen/Qwen3-0.6B" is loaded with
                eager attention and cast to bfloat16.
            forward_method: Stored as ``self.forward_method``; selects how
                ``get_span_and_content_attn_emb`` combines the two boundary
                tokens of a span.
            apply_mask: Stored as ``self.apply_mask``; gates the random zeroing
                of entity embeddings in the span-embedding helpers.
        """
        super(Qwen_relevence_atten_model, self).__init__()
        if not llm:
            # Eager attention — presumably because forward() asks the encoder
            # for attention weights (output_attentions=True); TODO confirm.
            self.embedding = Qwen3Model.from_pretrained(
                "Qwen/Qwen3-0.6B",
                # "Qwen/Qwen3-1.7B",
                attn_implementation="eager"
            ).to(torch.bfloat16)
        else:
            self.embedding = llm

        print("using point line surface ")
        # Read the encoder's hidden size and number of attention heads.
        self.emb_dim = self.embedding.config.hidden_size
        self.emb_attn_head = self.embedding.config.num_attention_heads
        emb_dim = self.emb_dim
        emb_attn_head = self.emb_attn_head
        # Entity-type embeddings are switched off here; the dict maps type names to embedding rows.
        self.add_entity_type = False
        self.entity_type_dict = {'ORGANIZATION':0, 'PERSON':1, 'NUMBER':2, 'NATIONALITY':3, 'DATE':4, 'COUNTRY':5, 'MISC':6, 'TITLE':7, 'RELIGION':8, 'LOCATION':9, 'CITY':10, 'IDEOLOGY':11, 'DURATION':12, 'STATE_OR_PROVINCE':13, 'CAUSE_OF_DEATH':14, 'URL':15, 'CRIMINAL_CHARGE':16}

        # Active fusion strategy. The alternatives handled by the branches below are:
        self.fused_method = "mlp(last_token||attention)||emb"
        # self.fused_method = "cat(cls||(atten1+dropout(span1))@(atten2+dropout(span2)))"
        # self.fused_method = "cat(cls||(atten1+span1)||(atten2+span2))"
        # self.fused_method = "cat(cls||(atten1+span1)||(atten2+span2))RNN"
        # self.fused_method = "mlp(last_token||mlp(attention))||emb"
        # self.fused_method = "mlp(last_token||attention)||add(emb)"
        # self.fused_method = "other"
        if self.fused_method == "mlp(last_token||attention)||emb" or self.fused_method == "mlp(last_token||attention)||add(emb)":
            # Earlier variants of this branch (a `self.mlp` over D + 2H inputs and
            # a `self.attn_mlp` over 2H inputs) are disabled.
            # NOTE(review): no `self.mlp` is created in this branch, but forward()
            # calls `self.mlp` for "mlp(last_token||attention)||add(emb)".

            # 0803 "mlp(last_token||attention) add emb":
            # input is three per-head attention vectors plus one hidden vector.
            self.attn_mlp =nn.Sequential(
                nn.Linear(emb_attn_head *3 + emb_dim , emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, emb_dim) 
            ).to(torch.bfloat16)

        elif self.fused_method == "mlp(last_token||mlp(attention))||emb":
            self.mlp = nn.Sequential(
                nn.Linear(emb_dim * 2, emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, emb_dim)
            ).to(torch.bfloat16)

            self.attn_mlp = nn.Sequential(
                nn.Linear(emb_attn_head *6 , emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, emb_dim) 
            ).to(torch.bfloat16)

        elif self.fused_method == "cat(cls||(atten1+dropout(span1))@(atten2+dropout(span2)))":
            self.mlp = nn.Sequential(
                nn.Linear(emb_dim, emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, emb_dim)
            ).to(torch.bfloat16)

            self.attn_mlp = nn.Sequential(
                nn.Linear(emb_attn_head, emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, emb_dim) 
            ).to(torch.bfloat16)

            self.dropout = nn.Dropout(0.01)
        elif self.fused_method == "cat(cls||(atten1+span1)||(atten2+span2))":
            self.mlp = nn.Sequential(
                nn.Linear(emb_dim, emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, emb_dim)
            ).to(torch.bfloat16)

            # Input is head, tail and span-averaged attention (3H); a 2H variant is disabled.
            self.attn_mlp = nn.Sequential(
                nn.Linear(emb_attn_head * 3, emb_dim),
                # nn.Linear(emb_attn_head * 2, emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, emb_dim) 
            ).to(torch.bfloat16)
        elif self.fused_method == "cat(cls||(atten1+span1)||(atten2+span2))RNN":
            self.mlp = nn.Sequential(
                nn.Linear(emb_dim, emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, emb_dim)
            ).to(torch.bfloat16)

            # Separate 3-layer RNNs encode the subject and object token spans.
            self.sub_rnn = nn.RNN(emb_dim, 
                            emb_dim,
                            num_layers= 3,
                            batch_first=True).to(torch.bfloat16)

            self.obj_rnn = nn.RNN(emb_dim, 
                            emb_dim,
                            num_layers= 3,
                            batch_first=True).to(torch.bfloat16)

            self.attn_mlp = nn.Sequential(
                nn.Linear(emb_attn_head * 3, emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, emb_dim) 
            ).to(torch.bfloat16)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(emb_dim , emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, emb_dim)
            ).to(torch.bfloat16)

        # The add(emb) method concatenates two blocks instead of three, so the head input shrinks.
        if  self.fused_method == "mlp(last_token||attention)||add(emb)":
            input_size = input_size //3 * 2
        self.forward_method = forward_method
        self.apply_mask = apply_mask
        self.classifier = Relation_Classifier(input_size, hidden_size, output_size).to(self.embedding.device).to(torch.bfloat16)
        # NOTE(review): the D -> D two-layer MLPs `L_mlp`, `R_mlp`, `entity_mlp`,
        # `eL_mlp` and `eR_mlp` are not created (their definitions are disabled),
        # yet get_span_and_content_emb still reads L_mlp, R_mlp, eL_mlp and eR_mlp.

        # Separate projections for the subject (head) and object (tail) entity embeddings.
        self.head_entity_mlp = nn.Sequential(
                nn.Linear(emb_dim , emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, emb_dim)
            ).to(torch.bfloat16)
        self.tail_entity_mlp = nn.Sequential(
                nn.Linear(emb_dim , emb_dim),
                nn.ReLU(),  
                nn.Linear(emb_dim, emb_dim)
            ).to(torch.bfloat16)

        # Projects a hidden vector concatenated with one per-head attention vector (D + H) to D.
        self.attn_emb_mlp = nn.Sequential(
                nn.Linear(emb_dim + emb_attn_head , emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, emb_dim)
            ).to(torch.bfloat16)
        # self.loss_fn = nn.CrossEntropyLoss()
        if self.add_entity_type:
            self.entity_type_embedding = nn.Embedding(len(self.entity_type_dict), emb_dim).to(torch.bfloat16)
        # Multi-label objective: one independent sigmoid per output.
        self.loss_fn = nn.BCEWithLogitsLoss()



    # def get_span_and_content_emb(self, embs, entity_spans, attn, context_window=5):
    def get_span_and_content_emb(self, embs, subject_spans, object_spans, attn, context_window=5):
        """
        Compute entity embeddings, context embeddings and one attention score per sample.

        NOTE(review): this method reads ``self.eL_mlp``, ``self.eR_mlp``,
        ``self.L_mlp`` and ``self.R_mlp``. None of them is assigned in the
        visible ``__init__`` (their definitions are commented out), so calling
        it raises ``AttributeError`` unless they are attached elsewhere.

        Args:
        - embs: (batch_size, seq_len, embedding_dim) token-level embeddings.
        - subject_spans: (batch_size, 2) tensor of (start, end) token indices.
        - object_spans: (batch_size, 2) tensor of (start, end) token indices.
        - attn: (batch_size, num_heads, seq_len, seq_len) attention tensor.
        - context_window: int, number of tokens the context boundary extends
          to the left of the span start and to the right of the span end.

        Returns:
        - sub_embs, obj_embs: mean of the projected start-token and end-token
          embeddings of each span; (B, D) assuming the MLPs keep the last
          dimension — TODO confirm.
        - sub_content_embs, obj_content_embs: mean of the projected embeddings
          at the left and right context boundaries; same shape assumption.
        - attn_score: (B, H) values ``attn[b, h, sub_right, obj_right]`` taken
          at the subject's and the object's right context boundary.
        """

        batch_size,  _ = subject_spans.shape  # unused below
        _, seq_len, embedding_dim = embs.shape  # only seq_len is used below

        # Start and end indices of each entity span.
        subject_spans_start = subject_spans[:,0]
        subject_spans_end = subject_spans[:,1]
        object_spans_start = object_spans[:,0]
        object_spans_end = object_spans[:,1]

        # Left/right context boundaries, clamped to stay inside the sequence.
        sub_left_indices  = torch.clamp(subject_spans_start - context_window, min=0)  # left boundary, never below 0
        sub_right_indices = torch.clamp(subject_spans_end   + context_window, max=seq_len - 1)  # right boundary, never past the last token

        obj_left_indices  = torch.clamp(object_spans_start - context_window, min=0)  # left boundary, never below 0
        obj_right_indices = torch.clamp(object_spans_end   + context_window, max=seq_len - 1)  # right boundary, never past the last token
        
        # Entity embeddings: gather the start and end token of each span.
        # NOTE(review): the span indices themselves are not clamped here, so
        # out-of-range spans make gather fail.
        sub_embs = embs.gather(dim=1,index=subject_spans.unsqueeze(-1).expand(-1, -1, embs.size(-1)))  # (B, 2, D)
        obj_embs = embs.gather(dim=1,index=object_spans.unsqueeze(-1).expand(-1, -1, embs.size(-1)))  # (B, 2, D)
        sub_left_embs = self.eL_mlp(sub_embs[:, 0, :])  # (B, D)
        sub_right_embs = self.eR_mlp(sub_embs[:, 1, :])  # (B, D)
        sub_embs = (sub_left_embs + sub_right_embs) / 2  # (B, D)

        obj_left_embs = self.eL_mlp(obj_embs[:, 0, :])  # (B, D)
        obj_right_embs = self.eR_mlp(obj_embs[:, 1, :])  # (B, D)
        obj_embs = (obj_left_embs + obj_right_embs) / 2  # (B, D)


        # Context embeddings: one token at each context boundary, projected and averaged.
        sub_context_left_embs = self.L_mlp(embs.gather(dim=1, 
                                                       index=sub_left_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1,embs.size(-1)))  ).squeeze(1)  # (B, D)
        sub_context_right_embs = self.R_mlp(embs.gather(dim=1, 
                                                        index=sub_right_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, embs.size(-1))) 
                                                        ).squeeze(1)  # (B, D)
        sub_content_embs = (sub_context_left_embs + sub_context_right_embs) / 2

        obj_context_left_embs = self.L_mlp(embs.gather(dim=1, 
                                                       index=obj_left_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, embs.size(-1)))
                                                       ).squeeze(1)  # (B, D)
        obj_context_right_embs = self.R_mlp(embs.gather(dim=1, 
                                                        index=obj_right_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, embs.size(-1)))
                                                        ).squeeze(1)  # (B, D)
        obj_content_embs = (obj_context_left_embs + obj_context_right_embs) / 2

        # Attention from the subject's right boundary row to the object's right boundary column.
        attn_score = self.get_attn(attn, start = sub_right_indices.unsqueeze(-1), end = obj_right_indices.unsqueeze(-1))  # (B, 1, H)
        attn_score = attn_score.squeeze(1) # (B, H)
        # print(attn_score.shape)
        
        return sub_embs, obj_embs, sub_content_embs, obj_content_embs, attn_score  # entity embeddings, context embeddings, attention score

    def get_span_and_content_attn_emb(self, embs, subject_spans, object_spans, attn, context_window=5):
        """
        计算实体的平均 embedding、上下文 embedding 和注意力分数。

        参数：
        - embs: (batch_size, seq_len, embedding_dim) 的张量，表示 token 级别的 embedding。
        - subject_spans: (batch_size, 2) 的张量
        - object_spans: (batch_size, 2) 的张量
        - attn: (batch_size, num_heads, seq_len, seq_len) 的张量，表示注意力分数。
        - context_window: int，表示左右边界的 token 数量。

        返回：
        - entity_embs_avg: (batch_size, num_entities, embedding_dim) 的张量，每个 batch 保持自己的实体数。
        - context_embs_avg: (batch_size, num_entities, embedding_dim) 的张量，每个实体的上下文 embedding。
        - attention_scores: (batch_size, num_entities, num_heads) 的张量，每个实体的注意力分数。
        """

        # Batch_size, seq_len, embedding_dim = embs.shape

        # sequence_lengths = attention_mask.sum(dim=1) - 1  # 计算有效 token 的索引
        # last_token = embs[torch.arange(batch_size), sequence_lengths]  # (B, D)

        # # 计算实体的起始和结束索引
        # subject_spans_start = subject_spans[:,0]
        # subject_spans_end = subject_spans[:,1]
        # object_spans_start = object_spans[:,0]
        # object_spans_end = object_spans[:,1]

        # 计算上下文的左右边界
        # sub_left_indices  = torch.clamp(subject_spans_start - context_window, min=0)  # 左边界，防止越界
        # sub_right_indices = torch.clamp(subject_spans_end   + context_window, max=seq_len - 1)  # 右边界，防止越界

        # obj_left_indices  = torch.clamp(object_spans_start - context_window, min=0)  # 左边界，防止越界
        # obj_right_indices = torch.clamp(object_spans_end   + context_window, max=seq_len - 1)  # 右边界，防止越界
        
        # entity embs fix
        try:
            seq_len = embs.size(1)  # 获取序列长度
            # 裁剪 subject_spans 和 object_spans 中的索引值
            subject_spans = torch.clamp(subject_spans, 0, seq_len - 1)
            object_spans = torch.clamp(object_spans, 0, seq_len - 1)
            sub_embs = embs.gather(dim=1,index=subject_spans.unsqueeze(-1).expand(-1, -1, embs.size(-1)))  # (B, 2, D)
            obj_embs = embs.gather(dim=1,index=object_spans.unsqueeze(-1).expand(-1, -1, embs.size(-1)))  # (B, 2, D)
        except:
            print(subject_spans)

        # sub_embs = self.entity_mlp(sub_embs[:,0,:] + sub_embs[:, 1, :])  # (B, D)
        # obj_embs = self.entity_mlp(obj_embs[:,0,:] + obj_embs[:, 1, :])  # (B, D)

        # sub_embs = self.head_entity_mlp(sub_embs[:,0,:] + sub_embs[:, 1, :])  # (B, D)
        # obj_embs = self.tail_entity_mlp(obj_embs[:,0,:] + obj_embs[:, 1, :])  # (B, D)
        # 0722 fix
        if self.forward_method == "fhead_add_ftail":
            sub_embs = self.head_entity_mlp(sub_embs[:,0,:]) + self.head_entity_mlp(sub_embs[:, 1, :])
            obj_embs = self.tail_entity_mlp(obj_embs[:,0,:]) + self.tail_entity_mlp(obj_embs[:, 1, :])
        elif self.forward_method == "fhead_sub_ftail":
            sub_embs = self.head_entity_mlp(sub_embs[:,0,:]) - self.head_entity_mlp(sub_embs[:, 1, :])  # (B, D)
            obj_embs = self.tail_entity_mlp(obj_embs[:,0,:]) - self.tail_entity_mlp(obj_embs[:, 1, :])  # (B, D)
        elif self.forward_method == "fhead_add_tail":
            sub_embs = self.head_entity_mlp(sub_embs[:,0,:] + sub_embs[:, 1, :])
            obj_embs = self.tail_entity_mlp(obj_embs[:,0,:] + obj_embs[:, 1, :])
        elif self.forward_method == "fhead_sub_tail":
            sub_embs = self.head_entity_mlp(sub_embs[:,0,:] - sub_embs[:, 1, :])
            obj_embs = self.tail_entity_mlp(obj_embs[:,0,:] - obj_embs[:, 1, :])

        # if self.apply_mask :
        if self.apply_mask == 'True' :
            sub_embs_mask = torch.rand(sub_embs.size(0), 1, device=sub_embs.device) < 0.1
            sub_embs = sub_embs.masked_fill(sub_embs_mask, 0)  # (B, D)
            obj_embs_mask = torch.rand(obj_embs.size(0), 1, device=obj_embs.device) < 0.1
            obj_embs = obj_embs.masked_fill(obj_embs_mask, 0)  # (B, D)

        # sub_embs = self.head_entity_mlp(sub_embs[:,0,:]) +self.head_entity_mlp(sub_embs[:, 1, :])  # (B, D)
        # obj_embs = self.tail_entity_mlp(obj_embs[:,0,:]) +self.tail_entity_mlp(obj_embs[:, 1, :])  # (B, D)

        # sub_embs = sub_embs[:,0,:] - sub_embs[:, 1, :]
        # obj_embs = obj_embs[:,0,:] - obj_embs[:, 1, :]


        # sub_left_embs = self.eL_mlp(sub_embs[:, 0, :])  # (B, D)
        # sub_right_embs = self.eR_mlp(sub_embs[:, 1, :])  # (B, D)
        # sub_embs = (sub_left_embs + sub_right_embs) / 2  # (B, D)

        # obj_left_embs = self.eL_mlp(obj_embs[:, 0, :])  # (B, D)
        # obj_right_embs = self.eR_mlp(obj_embs[:, 1, :])  # (B, D)
        # obj_embs = (obj_left_embs + obj_right_embs) / 2  # (B, D)


        # context embs

        # # content_left_indices = torch.min(subject_spans_start, object_spans_start)
        # content_left_indices = torch.min(sub_left_indices, obj_left_indices)
        # content_right_indices = torch.max(sub_right_indices, obj_right_indices)

        # content_left_embs = self.L_mlp(embs.gather(dim=1,index=content_left_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1,embs.size(-1)))  ).squeeze(1)  # (B, D)
        # content_right_embs = self.R_mlp(embs.gather(dim=1,index=content_right_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, embs.size(-1)))).squeeze(1)  # (B, D)
        # content_embs = (content_left_embs + content_right_embs) / 2  # (B, D)
        # content_attn = self.get_attn(attn, start = content_left_indices.unsqueeze(-1), end= content_right_indices.unsqueeze(-1)).squeeze(1)  # (B, H)
        # content_attn_emb = self.attn_emb_mlp(torch.cat([content_embs, content_attn], dim=-1))  # (B, D + H) -> (B, D) # 融合attn_score

        # attn_score = self.get_attn(attn, start = sub_right_indices.unsqueeze(-1), end = obj_right_indices.unsqueeze(-1))  # (B, E, H)
        # attn_score = attn_score.squeeze(1) # B H
        # print(attn_score.shape)
        
        # return sub_embs, obj_embs, content_attn_emb, attn_score  # 返回实体和上下文的 embedding
        return sub_embs, obj_embs

    def get_rnn_emb(self, embs, subject_spans, object_spans, attn = None, context_window=5):

        # # entity embs fix
        # try:
        #     seq_len = embs.size(1)  # 获取序列长度
        #     # 裁剪 subject_spans 和 object_spans 中的索引值
        #     subject_spans = torch.clamp(subject_spans, 0, seq_len - 1)
        #     object_spans = torch.clamp(object_spans, 0, seq_len - 1)
        #     sub_embs = embs.gather(dim=1,index=subject_spans.unsqueeze(-1).expand(-1, -1, embs.size(-1)))  # (B, 2, D)
        #     obj_embs = embs.gather(dim=1,index=object_spans.unsqueeze(-1).expand(-1, -1, embs.size(-1)))  # (B, 2, D)
        # except Exception as e:
        #     print(e)
        #     print(subject_spans)
        # 取出区间内的所有token的emb
        try:
            seq_len = embs.size(1)  # 获取序列长度
            # 裁剪 subject_spans 和 object_spans 中的索引值
            subject_spans = torch.clamp(subject_spans, 0, seq_len - 1)
            object_spans = torch.clamp(object_spans, 0, seq_len - 1)

            batch_size = embs.size(0)
            sub_embs_list = []
            obj_embs_list = []

            # 遍历每个批次
            for i in range(batch_size):
                sub_start, sub_end = subject_spans[i]
                obj_start, obj_end = object_spans[i]

                # 提取区间内的所有 token
                sub_embs_batch = embs[i, sub_start:sub_end + 1, :]
                obj_embs_batch = embs[i, obj_start:obj_end + 1, :]

                sub_embs_list.append(sub_embs_batch)
                obj_embs_list.append(obj_embs_batch)

            # 使用 pad_sequence 对不同长度的序列进行填充
            sub_embs = torch.nn.utils.rnn.pad_sequence(sub_embs_list, batch_first=True)
            obj_embs = torch.nn.utils.rnn.pad_sequence(obj_embs_list, batch_first=True)

        except Exception as e:
            print(e)
            print(subject_spans)

        # 0722 fix
        # sub_embs = self.sub_rnn(sub_embs)[-1].permute(1, 0, 2) # (B, 2, D)
        output_sub, sub_embs = self.sub_rnn(sub_embs)
        output_obj, obj_embs = self.obj_rnn(obj_embs)
        sub_embs = sub_embs[-1] # B * D
        obj_embs = obj_embs[-1] # B * D
        
        if self.apply_mask :
            sub_embs_mask = torch.rand(sub_embs.size(0), 1, device=sub_embs.device) < 0.1
            sub_embs = sub_embs.masked_fill(sub_embs_mask, 0)  # (B, D)
            obj_embs_mask = torch.rand(obj_embs.size(0), 1, device=obj_embs.device) < 0.1
            obj_embs = obj_embs.masked_fill(obj_embs_mask, 0)  # (B, D)

        return sub_embs, obj_embs

    def get_attn(self, attn, start, end):
        """
        根据 start_indices 和 end_indices 从 attn 中提取注意力分数。

        参数：
        - attn: (B, H, S, S) 的张量，表示注意力矩阵。
        - start_indices: (B, E) 的张量，表示每个实体的起始索引。
        - end_indices: (B, E) 的张量，表示每个实体的结束索引。

        返回：
        - attn_scores: (B, E, H) 的张量，表示每个实体的注意力分数。
        """
        batch_size, num_heads, seq_len, _ = attn.shape
        start = start.clamp(min=0, max=seq_len - 1)  # (B, E)
        end = end.clamp(min=0, max=seq_len - 1)  # (B, E)
        token_num = start.size(1)  # E
        # 扩展 start_indices 和 end_indices 的维度
        start_indices_expanded = start.unsqueeze(1).unsqueeze(-1).expand(-1, num_heads, -1, seq_len)  # (B, H, E, S)
        end_indices_expanded = end.unsqueeze(1).unsqueeze(-1).expand(-1, num_heads, token_num, 1)  # (B, H, E, 1)

        # 从 attn 中提取起始索引的值
        attn_start = torch.gather(attn, dim=2, index=start_indices_expanded)  # (B, H, E, S)
        # 从 attn_start 中提取结束索引的值
        attn_end = torch.gather(attn_start, dim=3, index=end_indices_expanded)  # (B, H, E, 1)

        # 调整维度为 (B, E, H)
        # attn_scores = attn_end.squeeze(-1).squeeze(-1) # (B, H)
        attn_scores = attn_end.squeeze(-1).permute(0, 2, 1)  # (B, E, H)
        return attn_scores

    def get_avg_attn(self, attn, start_indices, end_indices):
        batch_size, num_heads, seq_len, _ = attn.shape
        num_token = start_indices.size(1)

        # 确保索引在有效范围内
        start_indices = start_indices.clamp(min=0, max=seq_len - 1)  # (B, E)
        end_indices = end_indices.clamp(min=0, max=seq_len - 1)      # (B, E)

        # 存储每个实体的注意力分数
        # 我们希望最终得到 (B, E, H)
        avg_attn_scores = []

        # 遍历每个批次中的每个实体
        for b in range(batch_size):
            start_begin = start_indices[b, 0]
            strat_end = start_indices[b, -1]
            end = end_indices[b, 0]
            # 1. 获取 Span 内所有 Token 的索引 注意：这里假设end是包含的，所以范围是 [start, end]
            span_token_indices = torch.arange(start_begin, strat_end + 1, device=attn.device) # (span_len,)
            # current_span_attn = attn[b, :, span_token_indices, :] # (H, span_len, S)
            current_span_attn = attn[b, :, span_token_indices, end] # (H, span_len, S)
            span_attn_end_avg = torch.mean(current_span_attn, dim=1) # (H,)
            avg_attn_scores.append(span_attn_end_avg)
        # 将所有批次的结果堆叠起来
        attn_scores = torch.stack(avg_attn_scores, dim=0) 
        return attn_scores

    # def forward(self,input_ids,attention_mask,labels):
    def forward(self,input_ids,attention_mask,span_info,labels):
        outputs = self.embedding.forward(input_ids,attention_mask,output_attentions=True)
        embs, attn = outputs.last_hidden_state,outputs.attentions[-1] # embs: B * S * D, attn: B * Head * S * S
        # batch_size, num_heads, seq_len, _ = attn.shape  # attn 维度 (B, Head, S, S)
        relations = span_info
        # multi_hot_labels = self.convert_to_multi_hot(labels, num_classes=24)
        # multi_hot_labels = labels.clone().detach()  # 直接使用标签
        multi_hot_labels = labels.to(self.embedding.device).to(torch.float32)  # 确保标签在正确的设备上
        all_subj_span = []
        all_obj_span = []

        batch_size = input_ids.size(0)
        sequence_lengths = attention_mask.sum(dim=1) - 1  # 计算有效 token 的索引 
        last_token_index = sequence_lengths.unsqueeze(-1)  # (B, 1)

        last_token = embs[torch.arange(batch_size), sequence_lengths]  # (B, D)
        if self.add_entity_type:
            all_sub_entity_type = []
            all_obj_entity_type = []
        for rel in relations:
            subj_span = rel['subj_tok_span']  # (start, end)
            obj_span = rel['obj_tok_span']  # (start, end)
            all_subj_span.append(subj_span)
            all_obj_span.append(obj_span)
            if self.add_entity_type:
                all_sub_entity_type.append(self.entity_type_dict[rel['sub_label']])
                all_obj_entity_type.append(self.entity_type_dict[rel['obj_label']])
                sub_type_index = torch.tensor(all_sub_entity_type).to(self.embedding.device)
                obj_type_index = torch.tensor(all_obj_entity_type).to(self.embedding.device)
                sub_type_embs = self.entity_type_embedding(sub_type_index)
                obj_type_embs = self.entity_type_embedding(obj_type_index)
        # 转换为 tensor
        all_subj_span = torch.tensor(all_subj_span, dtype=torch.long, device=embs.device)  # (B, 2)
        all_obj_span = torch.tensor(all_obj_span, dtype=torch.long, device=embs.device)  # (B, 2)

        # sub_embs, obj_embs, sub_context_emb,obj_context_emb, attn_lastoken = self.get_span_and_content_emb(embs, all_subj_span, all_obj_span, attn=attn)  # (B, entity, D)
        # cls_mode = "last"

        # sub_embs, obj_embs, context_emb, attn_content = self.get_span_and_content_attn_emb(embs, all_subj_span, all_obj_span, attn=attn)  # (B, D)
        if self.fused_method == "cat(cls||(atten1+span1)||(atten2+span2))RNN":
            sub_embs, obj_embs = self.get_rnn_emb(embs, all_subj_span, all_obj_span, attn=attn)
        else:
            sub_embs, obj_embs = self.get_span_and_content_attn_emb(embs, all_subj_span, all_obj_span, attn=attn)

        sub_avg_attn = self.get_avg_attn(attn, start_indices=all_subj_span, end_indices=last_token_index)  # (B, E, H)
        obj_avg_attn = self.get_avg_attn(attn, start_indices=all_obj_span, end_indices=last_token_index)  # (B, E, H)

        sub_attn = self.get_attn(attn, start=all_subj_span, end = last_token_index ).squeeze(1) # (B *2 * H)
        obj_attn = self.get_attn(attn, start=all_obj_span, end = last_token_index ).squeeze(1) # (B *2 *H)

        sub_head_attn = sub_attn[:,0,:].squeeze(1)  # (B, H)
        sub_tail_attn = sub_attn[:,1,:].squeeze(1)  # (B, H)
        obj_head_attn = obj_attn[:,0,:].squeeze(1)  # (B, H)
        obj_tail_attn = obj_attn[:,1,:].squeeze(1)  # (B, H)

        # sub_attn = self.get_attn(attn, start=all_subj_span, end = last_token_index ).squeeze(1) # (B * H)
        # obj_attn = self.get_attn(attn, start=all_obj_span, end = last_token_index ).squeeze(1) # (B * H)
        # if cls_mode == "last":
        #     cls_indice = attention_mask.sum(dim=1) - 1  # 获取每个样本的最后一个 token 的索引
        #     context_emb = embs[torch.arange(embs.size(0)), cls_indice]  # (B, D)
        
        fused_method = self.fused_method

        if fused_method == "mlp(last_token||attention)||emb":
            context_emb = last_token
            # attn_content = torch.cat([sub_attn, obj_attn], dim=-1) # (B, 2H)
            # # context_attn_mlp_emb = self.mlp(torch.cat([context_emb, attn_content], dim=-1))  # (B, D + H) -> (B, D) # 融合attn_score
            # context_attn_mlp_emb = self.mlp(torch.cat([context_emb, attn_content], dim=-1))
            # cat_embs = torch.cat([sub_embs, obj_embs,context_attn_mlp_emb], dim=-1)  # (B, 3 * D)

            # 0803
            # "mlp(last_token||attention) add emb"
            sub_cls_attn = self.attn_mlp(torch.cat([sub_head_attn,sub_tail_attn,sub_avg_attn,context_emb],dim= -1))
            obj_cls_attn = self.attn_mlp(torch.cat([obj_head_attn,obj_tail_attn,obj_avg_attn,context_emb],dim= -1))
            

            if self.add_entity_type:
                # cat_embs = torch.cat([sub_embs + sub_cls_attn + sub_type_embs ,obj_embs +obj_cls_attn + obj_type_embs],dim = -1) # (B, 2 * D)
                cat_embs = torch.cat([sub_embs + sub_cls_attn + sub_type_embs ,obj_embs +obj_cls_attn + obj_type_embs, context_emb ],dim = -1) # (B, 3 * D)
            else:
                cat_embs = torch.cat([sub_embs + sub_cls_attn ,obj_embs +obj_cls_attn, context_emb],dim = -1) # (B, 2 * D)

            # cat_embs = torch.cat([sub_embs + sub_cls_attn ,obj_embs +obj_cls_attn, context_emb],dim = -1) # (B, 2 * D)

        elif fused_method == "mlp(last_token||mlp(attention))||emb":
            context_emb = last_token
            # attn_content = self.attn_mlp(torch.cat([sub_attn, obj_attn], dim=-1)) # (B, 2H) -> (B, D)
            # context_attn_mlp_emb = self.mlp(torch.cat([context_emb,attn_content],dim = -1)) # (B, 2*D) -> (B, D)
            # cat_embs = torch.cat([sub_embs, obj_embs,context_attn_mlp_emb], dim=-1)  # (B, 3 * D)

            # 0803
            attn = self.attn_mlp(torch.cat([sub_head_attn,sub_tail_attn,sub_avg_attn,obj_head_attn,obj_tail_attn,obj_avg_attn],dim = -1)) #(B * 6H) -> (B * D)
            context_attn_mlp_emb = self.mlp(torch.cat([context_emb,attn],dim = -1)) # (B, 2*D) -> (B, D)
            cat_embs = torch.cat([sub_embs, obj_embs,context_attn_mlp_emb], dim=-1)  # (B, 3 * D)

        elif fused_method == "mlp(last_token||attention)||add(emb)":
            context_emb = last_token
            attn_content = torch.cat([sub_attn, obj_attn], dim=-1) # (B, 2H)
            # context_attn_mlp_emb = self.mlp(torch.cat([context_emb, attn_content], dim=-1))  # (B, D + H) -> (B, D) # 融合attn_score
            context_attn_mlp_emb = self.mlp(torch.cat([context_emb, attn_content], dim=-1))
            entity_embs = sub_embs + obj_embs
            cat_embs = torch.cat([entity_embs,context_attn_mlp_emb], dim=-1)  # (B, 2 * D)
        elif fused_method == "cat(cls||(atten1+dropout(span1))@(atten2+dropout(span2)))":
            context_emb = last_token
            context_emb = self.mlp(context_emb) # 0713 add
            sub_attn = self.attn_mlp(sub_attn) #(B * H) -> (B * D)
            obj_attn = self.attn_mlp(obj_attn) #(B * H) -> (B * D)
            # 在batch维度对embs 进行随机mask,embs 形状：B * D
            # sub_embs_mask = torch.rand(embs.size(0), 1, device=embs.device) < 0.1
            # sub_embs = sub_embs.masked_fill(sub_embs_mask, 0)  # (B, D)
            # obj_embs_mask = torch.rand(embs.size(0), 1, device=embs.device) < 0.1
            # obj_embs = obj_embs.masked_fill(obj_embs_mask, 0)  # (B, D)
            # sub_embs = self.dropout(sub_embs)
            # obj_embs = self.dropout(obj_embs)
            # cat_embs = self.mlp(torch.cat([context_emb, sub_embs + sub_attn, obj_embs + obj_attn])) # (B * 3D)
            # cat_embs = torch.cat([context_emb, sub_embs + sub_attn, obj_embs + obj_attn],dim = -1) # (B * 3D)
            # cat_embs = torch.cat([context_emb, sub_embs + sub_attn + obj_embs + obj_attn],dim = -1) # (B * 2D)
            # cat_embs = torch.cat([context_emb, sub_embs + sub_attn - obj_embs - obj_attn],dim = -1) # (B * 2D)
            # cat_embs = torch.cat([context_emb, sub_embs + sub_attn , obj_embs + obj_attn],dim = -1) # (B * 3D)
            cat_embs = torch.cat([context_emb, sub_embs + sub_tail_attn , obj_embs + obj_tail_attn],dim = -1) # (B * 3D)
        elif fused_method == "cat(cls||(atten1+span1)||(atten2+span2))":
            context_emb = last_token
            context_emb = self.mlp(context_emb)
            # left cat right cat avg
            sub_attn = self.attn_mlp(torch.cat([sub_head_attn,sub_tail_attn,sub_avg_attn],dim = -1)) #(B * 3H) -> (B * D)
            obj_attn = self.attn_mlp(torch.cat([obj_head_attn,obj_tail_attn,obj_avg_attn],dim = -1)) #(B * 3H) -> (B * D)
            # left cat right
            # sub_attn = self.attn_mlp(torch.cat([sub_head_attn,sub_tail_attn],dim = -1)) #(B * 2H) -> (B * D)
            # obj_attn = self.attn_mlp(torch.cat([obj_head_attn,obj_tail_attn],dim = -1)) #(B * 2H) -> (B * D)
            # cat_embs = torch.cat([context_emb, sub_embs + sub_attn , obj_embs + obj_attn],dim = -1) # (B * 3D)
            if self.add_entity_type:
                cat_embs = torch.cat([context_emb, sub_embs + sub_attn + sub_type_embs , obj_embs + obj_attn + obj_type_embs],dim = -1) # (B * 3D)
            else:
                cat_embs = torch.cat([context_emb, sub_embs + sub_attn , obj_embs + obj_attn],dim = -1) # (B * 3D)

        elif fused_method == "cat(cls||(atten1+span1)||(atten2+span2))RNN":
            context_emb = last_token
            context_emb = self.mlp(context_emb)
            sub_attn = self.attn_mlp(torch.cat([sub_head_attn,sub_tail_attn,sub_avg_attn],dim = -1)) #(B * 3H) -> (B * D)
            obj_attn = self.attn_mlp(torch.cat([obj_head_attn,obj_tail_attn,obj_avg_attn],dim = -1)) #(B * 3H) -> (B * D)
            cat_embs = torch.cat([context_emb, sub_embs + sub_attn , obj_embs + obj_attn],dim = -1) # (B * 3D)


        out = self.classifier(cat_embs)
        # loss = self.loss_fn(out,relation_label)
        loss = self.loss_fn(out,multi_hot_labels)
        return {"loss":loss,"logits":out}