# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
# from transformers import Qwen2ForCausalLM, Qwen2Tokenizer,Qwen2Model, AutoModel
from transformers import Qwen3Model

class Relation_Classifier(nn.Module):
    """Feed-forward head: Linear -> ReLU -> Linear, producing raw logits.

    No softmax/sigmoid is applied; callers pair the output with a
    logits-based loss.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # The attribute names (fc1 / relu / fc2) and their creation order
        # determine the state_dict keys and the order in which random init
        # consumes the RNG, so both are kept fixed.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape (..., input_size) to logits of shape (..., output_size)."""
        return self.fc2(self.relu(self.fc1(x)))

    def _initialize_weights(self):
        """Re-initialise both linear layers: Xavier-uniform weights, zero biases."""
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)


class Relevance_model(nn.Module):
    """Two-layer perceptron (Linear, ReLU, Linear) that returns unnormalised scores."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Keep the fc1 -> relu -> fc2 registration order: it fixes both the
        # state_dict keys and the RNG consumption during default init.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Project ``x`` (..., input_size) to scores (..., output_size)."""
        hidden = self.relu(self.fc1(x))
        scores = self.fc2(hidden)
        return scores

    def _initialize_weights(self):
        """Reset weights with Xavier-uniform init and biases to zero."""
        layers = [self.fc1, self.fc2]
        for linear in layers:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)



class Qwen_relevence_atten_model(nn.Module):
    """Multi-label relation classifier on top of a Qwen3 backbone.

    The backbone (``self.embedding``) encodes the sentence. Entity-span
    embeddings, the last-token embedding and last-layer attention weights are
    fused into one feature vector, which ``self.classifier`` maps to
    ``output_size`` logits trained with ``BCEWithLogitsLoss``.

    ``self.fused_method`` selects the fusion (D = backbone hidden size,
    H = number of attention heads):

    - ``"mlp(last_token||attention)||emb"``: classifier input is 3 * D.
    - ``"mlp(last_token||mlp(attention))||emb"``: classifier input is 3 * D.
    - ``"mlp(last_token||attention)||add(emb)"``: fused width is 2 * D; the
      ``input_size`` passed in is rescaled by ``// 3 * 2``.
    - ``"cat(cls||(atten1+dropout(span1))@(atten2+dropout(span2)))"``
      (the default): classifier input is 2 * D.

    ``input_size`` must match the width for the selected method.
    """

    def __init__(self, input_size, hidden_size, output_size=24, llm=None):
        """
        Args:
            input_size: input width of the classifier head (see class docstring).
            hidden_size: hidden width of the classifier head.
            output_size: number of relation labels (multi-hot).
            llm: optional pre-built backbone. When None, Qwen3-Embedding-0.6B
                is loaded in bfloat16.
        """
        super().__init__()
        if llm is None:
            # Eager attention, so that the attention weights requested with
            # output_attentions=True in forward() can be returned.
            self.embedding = Qwen3Model.from_pretrained(
                "Qwen/Qwen3-Embedding-0.6B",
                attn_implementation="eager",
            ).to(torch.bfloat16)
        else:
            self.embedding = llm

        # Backbone hidden size and number of attention heads.
        self.emb_dim = self.embedding.config.hidden_size
        self.emb_attn_head = self.embedding.config.num_attention_heads
        emb_dim = self.emb_dim
        emb_attn_head = self.emb_attn_head

        self.fused_method = "cat(cls||(atten1+dropout(span1))@(atten2+dropout(span2)))"
        fused_method = self.fused_method

        # Module creation order is kept stable so seeded initialisation and
        # state_dict keys do not change.
        if fused_method in ("mlp(last_token||attention)||emb",
                            "mlp(last_token||attention)||add(emb)"):
            self.mlp = self._bf16_mlp(emb_dim + emb_attn_head * 2, emb_dim)
            self.attn_mlp = self._bf16_mlp(emb_attn_head * 2, emb_dim)
        elif fused_method == "mlp(last_token||mlp(attention))||emb":
            self.mlp = self._bf16_mlp(emb_dim * 2, emb_dim)
            # forward() passes cat(sub_attn, obj_attn) (2H wide) through
            # attn_mlp in this mode, so it has to exist here too.
            self.attn_mlp = self._bf16_mlp(emb_attn_head * 2, emb_dim)
        elif fused_method == "cat(cls||(atten1+dropout(span1))@(atten2+dropout(span2)))":
            self.mlp = self._bf16_mlp(emb_dim, emb_dim)
            self.attn_mlp = self._bf16_mlp(emb_attn_head, emb_dim)
            self.dropout = nn.Dropout(0.01)
        else:
            self.mlp = self._bf16_mlp(emb_dim, emb_dim)

        if fused_method == "mlp(last_token||attention)||add(emb)":
            input_size = input_size // 3 * 2
        self.classifier = Relation_Classifier(input_size, hidden_size, output_size).to(
            self.embedding.device
        ).to(torch.bfloat16)

        # Not used by forward(); kept so the parameter set / state_dict keys
        # stay unchanged.
        self.entity_mlp = self._bf16_mlp(emb_dim, emb_dim)
        self.attn_emb_mlp = self._bf16_mlp(emb_dim + emb_attn_head, emb_dim)

        self.loss_fn = nn.BCEWithLogitsLoss()

    @staticmethod
    def _bf16_mlp(in_features, width):
        """Return Linear(in_features, width) -> ReLU -> Linear(width, width) in bfloat16."""
        return nn.Sequential(
            nn.Linear(in_features, width),
            nn.ReLU(),
            nn.Linear(width, width),
        ).to(torch.bfloat16)

    @staticmethod
    def _last_token_indices(attention_mask):
        """Return the index of the last non-padding token of each row, shape (B,).

        This works for both left- and right-padded batches. An all-zero row
        maps to the last position.
        """
        mask = attention_mask.to(torch.long)
        seq_len = mask.size(1)
        # argmax returns the first maximum, so on the flipped mask it finds
        # the right-most 1 of the original row.
        return seq_len - 1 - mask.flip(dims=(1,)).argmax(dim=1)

    def get_span_and_content_emb(self, embs, subject_spans, object_spans, attn, context_window=5):
        """Legacy feature extractor: entity, context and attention features.

        This method needs the submodules ``eL_mlp``, ``eR_mlp``, ``L_mlp`` and
        ``R_mlp``, which ``__init__`` does not create. Without them it raises
        ``AttributeError`` immediately. ``forward`` does not call it.

        Args:
            embs: token embeddings, (B, S, D).
            subject_spans: (B, 2) [start, end] token indices.
            object_spans: (B, 2) [start, end] token indices.
            attn: attention weights, (B, H, S, S).
            context_window: number of tokens each span is extended by on both
                sides to pick the context tokens.

        Returns:
            ``(sub_embs, obj_embs, sub_content_embs, obj_content_embs,
            attn_score)``: four (B, D) tensors and one (B, H) tensor.
        """
        missing = [name for name in ("eL_mlp", "eR_mlp", "L_mlp", "R_mlp")
                   if not hasattr(self, name)]
        if missing:
            raise AttributeError(
                "get_span_and_content_emb requires submodules that __init__ does "
                f"not create: {', '.join(missing)}"
            )

        seq_len, dim = embs.size(1), embs.size(-1)

        def gather_tokens(index):
            # (B, K) token indices -> (B, K, D) embeddings
            return embs.gather(dim=1, index=index.unsqueeze(-1).expand(-1, -1, dim))

        # Clamp so out-of-range spans cannot cause an out-of-bounds gather.
        subject_spans = subject_spans.clamp(0, seq_len - 1)
        object_spans = object_spans.clamp(0, seq_len - 1)

        # Context boundaries: span edges pushed outwards by context_window,
        # clamped to the sequence.
        sub_left = torch.clamp(subject_spans[:, 0] - context_window, min=0)
        sub_right = torch.clamp(subject_spans[:, 1] + context_window, max=seq_len - 1)
        obj_left = torch.clamp(object_spans[:, 0] - context_window, min=0)
        obj_right = torch.clamp(object_spans[:, 1] + context_window, max=seq_len - 1)

        # Entity embeddings: mean of the projected start and end tokens.
        sub_tokens = gather_tokens(subject_spans)  # (B, 2, D)
        obj_tokens = gather_tokens(object_spans)  # (B, 2, D)
        sub_embs = (self.eL_mlp(sub_tokens[:, 0, :]) + self.eR_mlp(sub_tokens[:, 1, :])) / 2
        obj_embs = (self.eL_mlp(obj_tokens[:, 0, :]) + self.eR_mlp(obj_tokens[:, 1, :])) / 2

        # Context embeddings: mean of the projected left/right boundary tokens.
        sub_content_embs = (
            self.L_mlp(gather_tokens(sub_left.unsqueeze(-1))).squeeze(1)
            + self.R_mlp(gather_tokens(sub_right.unsqueeze(-1))).squeeze(1)
        ) / 2
        obj_content_embs = (
            self.L_mlp(gather_tokens(obj_left.unsqueeze(-1))).squeeze(1)
            + self.R_mlp(gather_tokens(obj_right.unsqueeze(-1))).squeeze(1)
        ) / 2

        attn_score = self.get_attn(
            attn, start=sub_right.unsqueeze(-1), end=obj_right.unsqueeze(-1)
        ).squeeze(1)  # (B, H)

        return sub_embs, obj_embs, sub_content_embs, obj_content_embs, attn_score

    def get_span_and_content_attn_emb(self, embs, subject_spans, object_spans, attn, context_window=5):
        """Return boundary-difference embeddings for the subject and object spans.

        Each entity is represented as ``embs[start] - embs[end]``. Span indices
        are first clamped into ``[0, seq_len - 1]``.

        Args:
            embs: token embeddings, (B, S, D).
            subject_spans: (B, 2) long tensor of [start, end] token indices.
            object_spans: (B, 2) long tensor of [start, end] token indices.
            attn: unused; kept for interface compatibility.
            context_window: unused; kept for interface compatibility.

        Returns:
            ``(sub_embs, obj_embs)``, each of shape (B, D).

        Raises:
            ValueError: if a span tensor is not 2-D with B rows and at least
                two columns.
        """
        batch_size, seq_len, dim = embs.shape
        for name, spans in (("subject_spans", subject_spans), ("object_spans", object_spans)):
            if spans.dim() != 2 or spans.size(0) != batch_size or spans.size(1) < 2:
                raise ValueError(
                    f"{name} must have shape ({batch_size}, 2), got {tuple(spans.shape)}"
                )

        subject_spans = subject_spans.clamp(0, seq_len - 1)
        object_spans = object_spans.clamp(0, seq_len - 1)

        sub_tokens = embs.gather(dim=1, index=subject_spans.unsqueeze(-1).expand(-1, -1, dim))  # (B, K, D)
        obj_tokens = embs.gather(dim=1, index=object_spans.unsqueeze(-1).expand(-1, -1, dim))  # (B, K, D)

        sub_embs = sub_tokens[:, 0, :] - sub_tokens[:, 1, :]
        obj_embs = obj_tokens[:, 0, :] - obj_tokens[:, 1, :]
        return sub_embs, obj_embs

    def get_attn(self, attn, start, end):
        """Gather ``attn[b, h, start[b, e], end[b, e]]`` for every batch row, head and pair.

        Under the usual (batch, heads, query, key) attention layout, ``start``
        indexes the attending (query) token and ``end`` the attended (key)
        token. Both index tensors are clamped into ``[0, S - 1]``.

        Args:
            attn: attention weights, (B, H, S, S).
            start: (B, E_start) row indices.
            end: (B, E) column indices, with E <= E_start. Pair e uses
                ``start[:, e]`` and ``end[:, e]``, so extra ``start`` columns
                are ignored.

        Returns:
            (B, E, H) tensor of attention weights.
        """
        _, num_heads, seq_len, _ = attn.shape

        start = start.clamp(min=0, max=seq_len - 1)
        end = end.clamp(min=0, max=seq_len - 1)

        row_index = start.unsqueeze(1).unsqueeze(-1).expand(-1, num_heads, -1, seq_len)  # (B, H, E_start, S)
        col_index = end.unsqueeze(1).unsqueeze(-1).expand(-1, num_heads, -1, 1)  # (B, H, E, 1)

        rows = torch.gather(attn, dim=2, index=row_index)  # (B, H, E_start, S)
        picked = torch.gather(rows, dim=3, index=col_index)  # (B, H, E, 1)

        return picked.squeeze(-1).permute(0, 2, 1)  # (B, E, H)

    def forward(self, input_ids, attention_mask, span_info, labels=None):
        """Score every relation label for each example in the batch.

        Args:
            input_ids: (B, S) token ids.
            attention_mask: (B, S) mask; non-zero marks real tokens.
            span_info: sequence of B mappings, each with ``'subj_tok_span'``
                and ``'obj_tok_span'`` ``(start, end)`` token indices.
            labels: optional (B, output_size) multi-hot targets.

        Returns:
            ``{"loss": loss, "logits": logits}``. ``loss`` is None when
            ``labels`` is None.

        Raises:
            ValueError: if ``span_info`` does not have one entry per example,
                or ``self.fused_method`` is not supported.
        """
        # Call the module rather than .forward() so registered hooks run.
        outputs = self.embedding(input_ids, attention_mask, output_attentions=True)
        # embs: (B, S, D); attn: last-layer attention weights (B, H, S, S).
        embs, attn = outputs.last_hidden_state, outputs.attentions[-1]
        device = embs.device
        batch_size = input_ids.size(0)

        relations = list(span_info)
        if len(relations) != batch_size:
            raise ValueError(
                f"span_info has {len(relations)} entries but the batch has {batch_size} examples"
            )
        all_subj_span = torch.tensor(
            [rel['subj_tok_span'] for rel in relations], dtype=torch.long, device=device
        )  # (B, 2)
        all_obj_span = torch.tensor(
            [rel['obj_tok_span'] for rel in relations], dtype=torch.long, device=device
        )  # (B, 2)

        # Last real token per row, robust to the tokenizer's padding side.
        last_token_index = self._last_token_indices(attention_mask).to(device).unsqueeze(-1)  # (B, 1)
        last_token = embs[torch.arange(batch_size, device=device), last_token_index.squeeze(-1)]  # (B, D)

        sub_embs, obj_embs = self.get_span_and_content_attn_emb(embs, all_subj_span, all_obj_span, attn=attn)

        # How strongly the last token (query row) attends to each entity's
        # first token (key column). The transposed entry
        # attn[entity_start, last_token] lies above the diagonal and is 0 for
        # a causal (decoder-only) backbone such as Qwen3.
        sub_attn = self.get_attn(attn, start=last_token_index, end=all_subj_span[:, :1]).squeeze(1)  # (B, H)
        obj_attn = self.get_attn(attn, start=last_token_index, end=all_obj_span[:, :1]).squeeze(1)  # (B, H)

        fused_method = self.fused_method
        if fused_method == "mlp(last_token||attention)||emb":
            attn_content = torch.cat([sub_attn, obj_attn], dim=-1)  # (B, 2H)
            context_attn_mlp_emb = self.mlp(torch.cat([last_token, attn_content], dim=-1))  # (B, D + 2H) -> (B, D)
            cat_embs = torch.cat([sub_embs, obj_embs, context_attn_mlp_emb], dim=-1)  # (B, 3D)
        elif fused_method == "mlp(last_token||mlp(attention))||emb":
            attn_content = self.attn_mlp(torch.cat([sub_attn, obj_attn], dim=-1))  # (B, 2H) -> (B, D)
            context_attn_mlp_emb = self.mlp(torch.cat([last_token, attn_content], dim=-1))  # (B, 2D) -> (B, D)
            cat_embs = torch.cat([sub_embs, obj_embs, context_attn_mlp_emb], dim=-1)  # (B, 3D)
        elif fused_method == "mlp(last_token||attention)||add(emb)":
            attn_content = torch.cat([sub_attn, obj_attn], dim=-1)  # (B, 2H)
            context_attn_mlp_emb = self.mlp(torch.cat([last_token, attn_content], dim=-1))  # (B, D)
            cat_embs = torch.cat([sub_embs + obj_embs, context_attn_mlp_emb], dim=-1)  # (B, 2D)
        elif fused_method == "cat(cls||(atten1+dropout(span1))@(atten2+dropout(span2)))":
            context_emb = self.mlp(last_token)  # (B, D)
            sub_attn = self.attn_mlp(sub_attn)  # (B, H) -> (B, D)
            obj_attn = self.attn_mlp(obj_attn)  # (B, H) -> (B, D)
            sub_embs = self.dropout(sub_embs)
            obj_embs = self.dropout(obj_embs)
            cat_embs = torch.cat([context_emb, sub_embs + sub_attn + obj_embs + obj_attn], dim=-1)  # (B, 2D)
        else:
            raise ValueError(f"Unsupported fused_method: {fused_method!r}")

        out = self.classifier(cat_embs)

        loss = None
        if labels is not None:
            multi_hot_labels = labels.to(device=out.device, dtype=torch.float32)
            # Compute the BCE on float32 logits rather than bfloat16.
            loss = self.loss_fn(out.float(), multi_hot_labels)
        return {"loss": loss, "logits": out}
