import torch
import torch.nn as nn
import torch.nn.functional as F
# from TorchCRF import CRF
from .layers.crf import CRF

from transformers import BertModel,BertPreTrainedModel
from .layers.linears import PoolerEndLogits, PoolerStartLogits
from torch.nn import CrossEntropyLoss
from losses.focal_loss import FocalLoss
from losses.label_smoothing import LabelSmoothingCrossEntropy

class BertSoftmaxForNer(BertPreTrainedModel):
    """BERT encoder followed by a per-token linear classifier for NER.

    The loss used in training is selected by ``config.loss_type``:
    ``'lsr'`` (label smoothing), ``'focal'`` or ``'ce'`` (cross entropy).
    All three are built with ``ignore_index=0``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.loss_type = config.loss_type
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,labels=None):
        """Run the tagger.

        Returns a tuple ``(loss), logits, *extra`` where ``loss`` is present
        only when ``labels`` is given and ``extra`` is whatever the encoder
        returned after its first two outputs.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        token_states = self.dropout(encoded[0])
        logits = self.classifier(token_states)
        # Forward hidden states / attentions if the encoder produced them.
        result = (logits,) + encoded[2:]
        if labels is None:
            return result  # scores, (hidden_states), (attentions)

        assert self.loss_type in ['lsr', 'focal', 'ce']
        loss_classes = {
            'lsr': LabelSmoothingCrossEntropy,
            'focal': FocalLoss,
            'ce': CrossEntropyLoss,
        }
        loss_fct = loss_classes[self.loss_type](ignore_index=0)

        flat_logits = logits.view(-1, self.num_labels)
        flat_labels = labels.view(-1)
        if attention_mask is not None:
            # Restrict the loss to positions the attention mask marks as 1.
            keep = attention_mask.view(-1) == 1
            flat_logits = flat_logits[keep]
            flat_labels = flat_labels[keep]
        loss = loss_fct(flat_logits, flat_labels)
        return (loss,) + result  # loss, scores, (hidden_states), (attentions)

#Bert+clsmlp
class BertCrfForNerO(BertPreTrainedModel):
    """BERT + CRF tagger that fuses an attention-weighted [CLS] vector into each token.

    For every position the (dropout-regularised) [CLS] embedding is scaled by
    the head-averaged attention that [CLS] pays to that position in the last
    encoder layer. The result is concatenated with the token's own
    representation, passed through an MLP and classified; a CRF scores the
    tag sequence.
    """

    def __init__(self, config):
        # Zero-argument super(): the original named a different class
        # (BertCrfForNer), which raises TypeError for instances of this one.
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.encoder_hidden_size = config.hidden_size
        self.num_labels = config.num_labels
        # MLP that maps [token ; weighted CLS] (2 * hidden) back to hidden size.
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size * 2, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, config.hidden_size)
        )
        self.mlp_dropout = nn.Dropout(config.hidden_dropout_prob)  # dropout after the MLP
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,labels=None):
        """Return ``(loss), logits``; ``loss`` is present only when ``labels`` is given.

        The value returned by the CRF is negated before being returned as the
        loss (presumably the CRF yields a log-likelihood; confirm against the
        project's CRF layer).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=True,  # the attention maps are needed below
        )
        sequence_output = self.dropout(outputs[0])

        # ---------------- CLS / attention fusion ----------------
        # Last element of the encoder output holds the per-layer attentions;
        # take the last layer: (batch_size, num_heads, seq_len, seq_len).
        last_attention = outputs[-1][-1]

        # Attention from [CLS] (query position 0) to every token, averaged
        # over heads: (batch_size, seq_len).
        cls_attention = last_attention[:, :, 0, :].mean(dim=1)

        # [CLS] representation: (batch_size, hidden_size).
        cls_embedding = sequence_output[:, 0, :]

        # Scale the CLS vector per position by its attention weight:
        # (batch_size, 1, hidden) * (batch_size, seq_len, 1)
        #   -> (batch_size, seq_len, hidden).
        # NOTE(review): these lines were commented out in the original, which
        # left `weighted_cls` undefined (NameError). Restored from the
        # commented-out code; confirm this is the intended weighting.
        weighted_cls = cls_embedding.unsqueeze(1) * cls_attention.unsqueeze(2)

        # Concatenate token features with the weighted CLS vector:
        # (batch_size, seq_len, hidden_size * 2).
        concatenated_output = torch.cat([sequence_output, weighted_cls], dim=-1)

        mlp_output = self.mlp(concatenated_output)  # (batch_size, seq_len, hidden_size)
        mlp_output = self.mlp_dropout(mlp_output)

        logits = self.classifier(mlp_output)  # (batch_size, seq_len, num_labels)

        outputs = (logits,)
        if labels is not None:
            loss = self.crf(emissions=logits, tags=labels, mask=attention_mask)
            outputs = (-1 * loss,) + outputs
        return outputs  # (loss), scores

class BertCrfForNerOLD(BertPreTrainedModel):
    """BERT + CRF tagger with an (currently bypassed) CLS/attention fusion branch.

    The fusion branch is computed in ``forward`` but, because the local flag
    ``base`` is hard-coded to ``True``, the classifier reads the plain BERT
    output and the fused features are discarded.
    """

    def __init__(self, config):
        # Zero-argument super(): the original named a different class
        # (BertCrfForNer), which raises TypeError for instances of this one.
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)

        # MLP over [CLS features ; per-head CLS attention] for each position.
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size + config.num_attention_heads, config.hidden_size * 2),
            nn.ReLU(),
            nn.Linear(config.hidden_size * 2, config.hidden_size)
        )

        # Projects [token ; mlp output] (2 * hidden) back to hidden size.
        self.linear = nn.Linear(config.hidden_size * 2, config.hidden_size)

        self.init_weights()

        # Re-initialise the classifier after init_weights():
        # small-variance normal weights, zero bias.
        torch.nn.init.normal_(self.classifier.weight, mean=0, std=0.01)
        torch.nn.init.constant_(self.classifier.bias, 0.0)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        """Return ``(loss), logits``; ``loss`` is present only when ``labels`` is given.

        The value returned by the CRF is printed and then negated to form the
        loss.
        """
        # Ask BERT for attention maps as well as hidden states.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=True
        )

        sequence_output = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        sequence_output = self.dropout(sequence_output)

        # [CLS] features: [batch_size, hidden_size]
        cls_output = sequence_output[:, 0, :]

        # Last-layer attention: [batch_size, num_heads, seq_len, seq_len]
        last_layer_attentions = outputs.attentions[-1]

        # Attention from [CLS] (query position 0) to every token:
        # [batch_size, num_heads, seq_len]
        cls_attentions = last_layer_attentions[:, :, 0, :]

        # Broadcast the CLS vector to every position:
        # [batch_size, seq_len, hidden_size]
        seq_len = sequence_output.size(1)
        cls_output_expanded = cls_output.unsqueeze(1).expand(-1, seq_len, -1)

        # [batch_size, seq_len, num_heads]
        cls_attentions_reshaped = cls_attentions.permute(0, 2, 1)

        # [batch_size, seq_len, hidden_size + num_heads]
        combined_features = torch.cat([cls_output_expanded, cls_attentions_reshaped], dim=-1)

        # [batch_size, seq_len, hidden_size] (the MLP's last layer outputs hidden_size)
        mlp_output = self.mlp(combined_features)

        # [batch_size, seq_len, hidden_size * 2]
        combined_output = torch.cat([sequence_output, mlp_output], dim=-1)

        # [batch_size, seq_len, hidden_size]
        new_sequence_output = self.linear(combined_output)

        # `base` selects the plain BERT features; flip it to classify on the
        # fused features instead.
        base = True
        if base:
            logits = self.classifier(sequence_output)
        else:
            logits = self.classifier(new_sequence_output)  # [batch_size, seq_len, num_labels]

        outputs = (logits,)
        if labels is not None:
            loss = self.crf(logits, labels, mask=attention_mask)
            print("loss:", loss)
            outputs = (-1 * loss,) + outputs

        return outputs  # (loss), scores

class BertCrfForNerRaw(BertPreTrainedModel):
    """BERT + CRF tagger with CLS/attention fusion and an auxiliary MLM head.

    Token features are fused with a scalar derived from the [CLS] vector and
    the per-head attention [CLS] pays to each token; the fused features feed
    the classifier and CRF. A masked-language-model head over the BERT output
    contributes an extra loss term when masked positions and labels are given.
    """

    def __init__(self, config):
        # Zero-argument super(): the original named a different class
        # (BertCrfForNer), which raises TypeError for instances of this one.
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)

        # MLP over [CLS features ; weighted per-head CLS attention] -> scalar.
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size + config.num_attention_heads, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

        # Projects [token ; scalar] (hidden + 1) back to hidden size.
        self.linear = nn.Linear(config.hidden_size + 1, config.hidden_size)
        # MLM head: linear -> GELU -> LayerNorm -> decoder.
        self.linears = nn.Linear(config.hidden_size, config.hidden_size)
        self.norm = nn.LayerNorm(config.hidden_size)
        self.activ2 = nn.GELU()

        self.vocab_size = config.vocab_size
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Learnable per-head scale applied to the CLS attention scores.
        self.attention_weight = nn.Parameter(torch.ones(1, config.num_attention_heads, 1))

        self.init_weights()

    def get_mlm_parameters(self):
        """Return the parameters that belong only to the MLM head."""
        mlm_params = []
        mlm_params.extend(list(self.linears.parameters()))
        mlm_params.extend(list(self.norm.parameters()))
        mlm_params.extend(list(self.decoder.parameters()))
        return mlm_params

    def _mlm_loss(self, sequence_output, masked_pos, masked_labels):
        """Cross-entropy of the MLM head at ``masked_pos`` against ``masked_labels``."""
        # Expand indices so gather picks whole hidden vectors:
        # [batch_size, max_pred, hidden_size]
        gather_index = masked_pos[:, :, None].expand(-1, -1, sequence_output.size(-1))
        h_masked = torch.gather(sequence_output, 1, gather_index)
        h_masked = self.norm(self.activ2(self.linears(h_masked)))
        logits_lm = self.decoder(h_masked)  # [batch_size, max_pred, vocab_size]
        return nn.CrossEntropyLoss()(logits_lm.view(-1, self.vocab_size), masked_labels.view(-1))

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,masked_pos=None,masked_labels=None):
        """Return ``(loss), logits``.

        ``loss`` is present only when ``labels`` is given. It is
        ``-(crf_value - mlm_loss)``, where ``crf_value`` is what the CRF layer
        returns (presumably a log-likelihood; confirm against the project's
        CRF). The MLM term is included only when both ``masked_pos`` and
        ``masked_labels`` are supplied; otherwise it is zero.

        Raises:
            ValueError: if the combined loss is NaN.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=True
        )

        sequence_output = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        sequence_output = self.dropout(sequence_output)

        cls_output = sequence_output[:, 0, :]  # [batch_size, hidden_size]

        # Last-layer attention from [CLS] (query 0) to every token:
        # [batch_size, num_heads, seq_len]
        cls_attentions = outputs.attentions[-1][:, :, 0, :]
        weighted_cls_attentions = cls_attentions * self.attention_weight

        seq_len = sequence_output.size(1)
        cls_output_expanded = cls_output.unsqueeze(1).expand(-1, seq_len, -1)  # [batch_size, seq_len, hidden_size]
        cls_attentions_reshaped = weighted_cls_attentions.permute(0, 2, 1)  # [batch_size, seq_len, num_heads]

        # [batch_size, seq_len, hidden_size + num_heads]
        combined_features = torch.cat([cls_output_expanded, cls_attentions_reshaped], dim=-1)
        mlp_output = self.mlp(combined_features)  # [batch_size, seq_len, 1]

        # [batch_size, seq_len, hidden_size + 1] -> [batch_size, seq_len, hidden_size]
        combined_output = torch.cat([sequence_output, mlp_output], dim=-1)
        new_sequence_output = self.linear(combined_output)

        logits_clsf = self.classifier(new_sequence_output)  # [batch_size, seq_len, num_labels]
        outputs = (logits_clsf,)

        # Inference: the original fell through to an undefined `loss_crf`
        # here and raised NameError.
        if labels is None:
            return outputs  # scores

        loss_crf = self.crf(emissions=logits_clsf, tags=labels, mask=attention_mask)
        if masked_pos is not None and masked_labels is not None:
            loss_mlm = self._mlm_loss(sequence_output, masked_pos, masked_labels)
        else:
            # No MLM supervision supplied: the MLM term contributes nothing.
            loss_mlm = loss_crf.new_zeros(())
        total_loss = loss_crf - loss_mlm
        print(f"TOTAL:{total_loss.item():.4f},CRF:{loss_crf.item():.4f}, MLM:{loss_mlm.item():.4f}")

        # Guard against NaN and warn on very large values.
        if torch.isnan(total_loss):
            raise ValueError(f"Loss出现NaN! CRF:{loss_crf.item():.4f}, MLM:{loss_mlm.item():.4f}")
        elif total_loss > 10000:
            print(f"警告：损失值过大({total_loss.item():.4f})！CRF:{loss_crf.item():.4f}, MLM:{loss_mlm.item():.4f}")

        outputs = (-1 * total_loss,) + outputs

        return outputs  # loss, scores


class BertCrfForNer(BertPreTrainedModel):
    """BERT + dynamic-CRF tagger with CLS/attention fusion and an auxiliary MLM head.

    Besides the fused emission scores, the model derives three transition
    scores per position from the attention [CLS] pays to the tokens and hands
    them to the CRF as ``tran_attn_scores``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)

        # MLP over [CLS features ; weighted per-head CLS attention] -> scalar.
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size + config.num_attention_heads, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

        # Dynamic-CRF MLP A: maps per-head attention to three transition
        # scores (1: inside the same entity, 2: at an entity boundary,
        # 3: non-entity).
        self.mlp_tran_attn = nn.Sequential(
            nn.Linear(config.num_attention_heads, 256),
            nn.ReLU(),
            nn.Linear(256, 3)
        )
        # Dynamic-CRF MLP B: "tokens belong to different entities" score.
        # Not used in forward(); kept so existing checkpoints still load.
        self.mlp_diff_entity = nn.Sequential(
            nn.Linear(config.num_attention_heads, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

        # Projects [token ; scalar] (hidden + 1) back to hidden size.
        self.linear = nn.Linear(config.hidden_size + 1, config.hidden_size)
        # MLM head: linear -> GELU -> LayerNorm -> decoder.
        self.linears = nn.Linear(config.hidden_size, config.hidden_size)
        self.norm = nn.LayerNorm(config.hidden_size)
        self.activ2 = nn.GELU()

        self.vocab_size = config.vocab_size
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Learnable per-head scale applied to the CLS attention scores.
        self.attention_weight = nn.Parameter(torch.ones(1, config.num_attention_heads, 1))

        self.init_weights()

    def get_mlm_parameters(self):
        """Return the parameters that belong only to the MLM head."""
        mlm_params = []
        mlm_params.extend(list(self.linears.parameters()))
        mlm_params.extend(list(self.norm.parameters()))
        mlm_params.extend(list(self.decoder.parameters()))
        return mlm_params

    def _mlm_loss(self, sequence_output, masked_pos, masked_labels):
        """Cross-entropy of the MLM head at ``masked_pos`` against ``masked_labels``."""
        # Expand indices so gather picks whole hidden vectors:
        # [batch_size, max_pred, hidden_size]
        gather_index = masked_pos[:, :, None].expand(-1, -1, sequence_output.size(-1))
        h_masked = torch.gather(sequence_output, 1, gather_index)
        h_masked = self.norm(self.activ2(self.linears(h_masked)))
        logits_lm = self.decoder(h_masked)  # [batch_size, max_pred, vocab_size]
        return nn.CrossEntropyLoss()(logits_lm.view(-1, self.vocab_size), masked_labels.view(-1))

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,masked_pos=None,masked_labels=None):
        """Return ``(loss), logits, tran_attn_scores``.

        ``loss`` is present only when ``labels`` is given. It is
        ``-(crf_value - mlm_loss)``, where ``crf_value`` is what the CRF layer
        returns (presumably a log-likelihood; confirm against the project's
        CRF). The MLM term is included only when both ``masked_pos`` and
        ``masked_labels`` are supplied; otherwise it is zero.

        The logits stay directly after the optional loss, as in the other
        taggers in this file; the transition scores are appended last.

        Raises:
            ValueError: if the combined loss is NaN.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=True
        )

        sequence_output = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        sequence_output = self.dropout(sequence_output)

        cls_output = sequence_output[:, 0, :]  # [batch_size, hidden_size]

        # Last-layer attention from [CLS] (query 0) to every token:
        # [batch_size, num_heads, seq_len]
        cls_attentions = outputs.attentions[-1][:, :, 0, :]
        weighted_cls_attentions = cls_attentions * self.attention_weight

        seq_len = sequence_output.size(1)
        cls_output_expanded = cls_output.unsqueeze(1).expand(-1, seq_len, -1)  # [batch_size, seq_len, hidden_size]
        cls_attentions_reshaped = weighted_cls_attentions.permute(0, 2, 1)  # [batch_size, seq_len, num_heads]

        # [batch_size, seq_len, hidden_size + num_heads]
        combined_features = torch.cat([cls_output_expanded, cls_attentions_reshaped], dim=-1)
        mlp_output = self.mlp(combined_features)  # [batch_size, seq_len, 1]

        # [batch_size, seq_len, hidden_size + 1] -> [batch_size, seq_len, hidden_size]
        combined_output = torch.cat([sequence_output, mlp_output], dim=-1)
        new_sequence_output = self.linear(combined_output)

        logits_clsf = self.classifier(new_sequence_output)  # [batch_size, seq_len, num_labels]

        # ---------------- Dynamic CRF ----------------
        # CLS attention to positions 1 .. n-2 and to positions 2 .. n-1,
        # each [batch_size, num_heads, seq_len - 2].
        cls_attentions_1_to_n_1 = cls_attentions[:, :, 1:-1]
        cls_attentions_2_to_n = cls_attentions[:, :, 2:]

        # Concatenated along the position axis:
        # [batch_size, num_heads, 2 * (seq_len - 2)].
        # NOTE(review): the original comments claimed seq_len - 2 positions
        # here; confirm which shape the CRF expects for tran_attn_scores.
        combined_attentions = torch.cat([cls_attentions_1_to_n_1, cls_attentions_2_to_n], dim=-1)
        combined_attentions = combined_attentions.permute(0, 2, 1)  # [batch_size, 2 * (seq_len - 2), num_heads]

        # Three raw (un-normalised) transition scores per position:
        # [batch_size, 2 * (seq_len - 2), 3].
        tran_attn_scores = self.mlp_tran_attn(combined_attentions)

        outputs = (logits_clsf, tran_attn_scores)

        # Inference: the original fell through to an undefined `loss_crf`
        # here and raised NameError.
        if labels is None:
            return outputs  # scores, tran_attn_scores

        loss_crf = self.crf(emissions=logits_clsf, tags=labels, mask=attention_mask,
                            tran_attn_scores=tran_attn_scores)
        if masked_pos is not None and masked_labels is not None:
            loss_mlm = self._mlm_loss(sequence_output, masked_pos, masked_labels)
        else:
            # No MLM supervision supplied: the MLM term contributes nothing.
            loss_mlm = loss_crf.new_zeros(())
        total_loss = loss_crf - loss_mlm
        print(f"TOTAL:{total_loss.item():.4f},CRF:{loss_crf.item():.4f}, MLM:{loss_mlm.item():.4f}")

        # Guard against NaN and warn on very large values.
        if torch.isnan(total_loss):
            raise ValueError(f"Loss出现NaN! CRF:{loss_crf.item():.4f}, MLM:{loss_mlm.item():.4f}")
        elif total_loss > 10000:
            print(f"警告：损失值过大({total_loss.item():.4f})！CRF:{loss_crf.item():.4f}, MLM:{loss_mlm.item():.4f}")

        # The original concatenated a tuple with a bare tensor here, which
        # raises; the scores are carried inside the tuple instead.
        outputs = (-1 * total_loss,) + outputs

        return outputs  # loss, scores, tran_attn_scores




class BertCrfForNerNew1(BertPreTrainedModel):
    """BERT + CRF tagger that also computes a CLS/attention fusion branch.

    NOTE(review): ``forward`` computes the fused features but the classifier
    reads the plain BERT output, so the fusion branch does not affect the
    logits. This is kept as in the original; confirm whether it is intended.
    """

    def __init__(self, config):
        # Zero-argument super(): the original named a different class
        # (BertCrfForNer), which raises TypeError for instances of this one.
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        # MLP over [CLS features ; per-head CLS attention] -> scalar.
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size + config.num_attention_heads, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

        # Projects [token ; scalar] (hidden + 1) back to hidden size.
        self.linear = nn.Linear(config.hidden_size + 1, config.hidden_size)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,labels=None):
        """Return ``(loss), logits``; ``loss`` is present only when ``labels`` is given.

        The value returned by the CRF is negated to form the loss.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=True,  # the attention maps are needed below
        )

        # Last-layer hidden states: [batch_size, seq_len, hidden_size]
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)

        # [CLS] features: [batch_size, hidden_size]
        cls_output = sequence_output[:, 0, :]

        # Last-layer attention from [CLS] (query 0) to every token:
        # [batch_size, num_heads, seq_len]
        cls_attentions = outputs.attentions[-1][:, :, 0, :]

        # Broadcast the CLS vector to every position:
        # [batch_size, seq_len, hidden_size]
        seq_len = sequence_output.size(1)
        cls_output_expanded = cls_output.unsqueeze(1).expand(-1, seq_len, -1)

        # [batch_size, seq_len, hidden_size + num_heads]
        cls_attentions_reshaped = cls_attentions.permute(0, 2, 1)
        combined_features = torch.cat([cls_output_expanded, cls_attentions_reshaped], dim=-1)

        mlp_output = self.mlp(combined_features)  # [batch_size, seq_len, 1]

        # [batch_size, seq_len, hidden_size + 1] -> [batch_size, seq_len, hidden_size]
        combined_output = torch.cat([sequence_output, mlp_output], dim=-1)
        new_sequence_output = self.linear(combined_output)  # computed but not used below

        logits = self.classifier(sequence_output)
        outputs = (logits,)
        if labels is not None:
            loss = self.crf(emissions=logits, tags=labels, mask=attention_mask)
            outputs = (-1 * loss,) + outputs
        return outputs  # (loss), scores


class BertCrfForNerOld(BertPreTrainedModel):
    """Baseline BERT + linear classifier + CRF sequence tagger."""

    def __init__(self, config):
        # Zero-argument super(): the original named a different class
        # (BertCrfForNer), which raises TypeError for instances of this one.
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,labels=None):
        """Return ``(loss), logits``; ``loss`` is present only when ``labels`` is given.

        The value returned by the CRF is negated to form the loss.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)
        outputs = (logits,)
        if labels is not None:
            loss = self.crf(emissions=logits, tags=labels, mask=attention_mask)
            outputs = (-1 * loss,) + outputs
        return outputs  # (loss), scores

class BertSpanForNer(BertPreTrainedModel):
    """BERT span tagger that predicts a start label and an end label per token.

    The end head is conditioned on start information: the gold start labels
    while training (when provided), otherwise the model's own start
    predictions. With ``config.soft_label`` that information is a
    ``num_labels``-wide vector per token, otherwise a single scalar label id.
    """

    def __init__(self, config,):
        super(BertSpanForNer, self).__init__(config)
        self.soft_label = config.soft_label
        self.num_labels = config.num_labels
        self.loss_type = config.loss_type
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels)
        if self.soft_label:
            self.end_fc = PoolerEndLogits(config.hidden_size + self.num_labels, self.num_labels)
        else:
            self.end_fc = PoolerEndLogits(config.hidden_size + 1, self.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,end_positions=None):
        """Return ``(loss), start_logits, end_logits, *extra``.

        ``loss`` is present only when both ``start_positions`` and
        ``end_positions`` are given; it is the mean of the start and end
        losses, restricted to positions where ``attention_mask == 1`` (all
        positions when no mask is passed). ``extra`` is whatever the encoder
        returned after its first two outputs.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = self.dropout(outputs[0])
        start_logits = self.start_fc(sequence_output)

        if start_positions is not None and self.training:
            # Teacher forcing: condition the end head on the gold start labels.
            if self.soft_label:
                batch_size, seq_len = input_ids.size(0), input_ids.size(1)
                # One-hot encode the gold labels, allocated on the input's device.
                label_logits = torch.zeros(
                    batch_size, seq_len, self.num_labels,
                    dtype=torch.float, device=input_ids.device,
                )
                label_logits.scatter_(2, start_positions.unsqueeze(2), 1)
            else:
                label_logits = start_positions.unsqueeze(2).float()
        else:
            # Otherwise condition on the model's own start predictions.
            label_logits = F.softmax(start_logits, -1)
            if not self.soft_label:
                label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float()

        end_logits = self.end_fc(sequence_output, label_logits)
        outputs = (start_logits, end_logits,) + outputs[2:]

        if start_positions is not None and end_positions is not None:
            assert self.loss_type in ['lsr', 'focal', 'ce']
            if self.loss_type == 'lsr':
                loss_fct = LabelSmoothingCrossEntropy()
            elif self.loss_type == 'focal':
                loss_fct = FocalLoss()
            else:
                loss_fct = CrossEntropyLoss()

            flat_start_logits = start_logits.view(-1, self.num_labels)
            flat_end_logits = end_logits.view(-1, self.num_labels)
            flat_start_labels = start_positions.view(-1)
            flat_end_labels = end_positions.view(-1)

            # attention_mask defaults to None; the original dereferenced it
            # unconditionally and crashed. Without a mask every position counts.
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                flat_start_logits = flat_start_logits[active_loss]
                flat_end_logits = flat_end_logits[active_loss]
                flat_start_labels = flat_start_labels[active_loss]
                flat_end_labels = flat_end_labels[active_loss]

            start_loss = loss_fct(flat_start_logits, flat_start_labels)
            end_loss = loss_fct(flat_end_logits, flat_end_labels)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs
        return outputs







###


        # NOTE(review): leftover experiment. This fragment is indented at
        # method-body level and follows the `return outputs` of the preceding
        # forward(), so it is never executed. It also reads names that are not
        # assigned anywhere in this fragment (max_position_embeddings,
        # last_attention, seq_len, batch_size). Consider moving it out of the
        # module or deleting it.
        # Pad the key axis of the attention map up to the maximum sequence length.
        max_seq_len = max_position_embeddings
        pad_len = max_seq_len - seq_len
        last_attention = torch.nn.functional.pad(last_attention, (0, pad_len, 0, 0, 0, 0, 0, 0), value=0)

        # Extract the attention scores of [CLS] (query position 0)
        cls_attention = last_attention[:, :, 0, :]  # shape: (batch_size, num_heads, seq_len)
        cls_attention = cls_attention.reshape(batch_size, -1)  # flattened to (batch_size, num_heads * seq_len)
        # Extract the [CLS] representation
        cls_embedding = sequence_output[:, 0, :]  # shape: (batch_size, hidden_size)

        # Feed [CLS embedding ; flattened CLS attention] to the MLP
        mlp_input = torch.cat([cls_embedding, cls_attention], dim=-1)  # shape: (batch_size, hidden_size + num_heads * seq_len)
        mlp_output = self.mlp(mlp_input)  # shape: (batch_size, hidden_size)

        # Broadcast the MLP output to the same sequence length as the BERT output
        mlp_output_expanded = mlp_output.unsqueeze(1).expand(-1, seq_len, -1)  # shape: (batch_size, seq_len, hidden_size)
        # Concatenate the BERT output with the MLP output
        concatenated_output = torch.cat([sequence_output, mlp_output_expanded], dim=-1)  # shape: (batch_size, seq_len, hidden_size * 2)

###
   