from transformers import LlamaForCausalLM
from transformers.utils import add_start_docstrings_to_model_forward
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
from typing import Optional, Tuple, Union, List, Dict
import torch
import torch.nn.functional as F
from transformers.modeling_outputs import ModelOutput
from ...extras.constants import IGNORE_INDEX
from dataclasses import dataclass
import torch.nn as nn
import copy
from pytorch_metric_learning import losses
import requests
import json
import time
def call_remote_teacher_service(input_ids, attention_mask, special_token_num,
                                url='http://11.171.194.62:8000/inference',
                                max_retries=1, retry_interval=2, timeout=30):
    """
    Fetch teacher hidden states from a remote inference service.

    POSTs ``input_ids``, ``attention_mask`` and ``special_token_num`` as JSON to
    ``url`` and converts the ``'last_hidden_state'`` field of the JSON reply
    into a float32 tensor on ``input_ids.device``.

    Args:
        input_ids (torch.Tensor): token ids, shape [batch_size, seq_len].
        attention_mask (torch.Tensor): attention mask, same shape as input_ids.
        special_token_num (int): number of special embedding tokens; forwarded
            to the service unchanged.
        url (str): service endpoint. NOTE(review): the default is a hard-coded
            internal address.
        max_retries (int): total number of POST attempts (not retries after the
            first one); a value <= 0 makes no request and raises immediately.
        retry_interval (float): seconds to wait between consecutive attempts.
        timeout (float): per-request timeout in seconds, passed to requests.

    Returns:
        torch.Tensor: float32 tensor built from ``last_hidden_state``. Its shape
        is defined by the service; callers in this file index it as 3-D
        [batch, positions, hidden] - confirm against the service.

    Raises:
        RuntimeError: if no attempt returns HTTP 200; chained to the most recent
            request exception, if there was one.
        KeyError: if a 200 response has no ``'last_hidden_state'`` field.
    """
    payload = {
        "input_ids": input_ids.cpu().tolist(),
        "attention_mask": attention_mask.cpu().tolist(),
        "special_token_num": special_token_num
    }
    headers = {'Content-Type': 'application/json'}

    last_error = None
    for attempt in range(max_retries):
        try:
            response = requests.post(url, json=payload, headers=headers, timeout=timeout)
            if response.status_code == 200:
                hidden_states = response.json()['last_hidden_state']
                # Build the tensor directly on the caller's device (no CPU round-trip).
                return torch.tensor(hidden_states, dtype=torch.float32, device=input_ids.device)
            print(f"Attempt {attempt+1}/{max_retries} failed, status code: {response.status_code}, response: {response.text}")
        except requests.exceptions.RequestException as e:
            last_error = e
            print(f"Attempt {attempt+1}/{max_retries} encountered exception: {e}")

        # Only wait when another attempt will follow; sleeping before raising is wasted time.
        if attempt + 1 < max_retries:
            time.sleep(retry_interval)

    raise RuntimeError(f"Remote teacher service failed after {max_retries} attempts.") from last_error


@dataclass
class MultipleChoiceModelOutput(ModelOutput):
    """Output container returned by this module's embedding model.

    NOTE: this is a local class that shadows the name of
    ``transformers.modeling_outputs.MultipleChoiceModelOutput``; it adds the
    ``loss_dict`` field.

    Attributes:
        loss: total loss, or ``None`` when no loss was computed.
        logits: model logits, or ``None``.
        hidden_states: per-layer hidden states of the backbone, or ``None``.
        attentions: per-layer attention weights of the backbone, or ``None``.
        loss_dict: optional mapping from loss-component name to its value.
    """

    loss: Optional[torch.FloatTensor] = None
    # Annotated Optional because the default is None.
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    loss_dict: Optional[Dict[str, torch.FloatTensor]] = None




class CustormerLlamaForMultiEmb(LlamaForCausalLM):
    """Llama model that predicts one vocabulary distribution per special embedding token.

    Inputs carry ``config.special_tokens_num`` special tokens (ids starting at
    ``_EMB_TOKEN_START_ID``). Their final hidden states are projected with
    ``lm_head`` and trained with any combination of losses selected by
    substring match on ``config.main_loss_type``:

    * ``"CE"``  - cross-entropy against the target tokens in ``labels``;
    * ``"KL"``  - KL divergence to the (remote) teacher's token distribution;
    * ``"MSE"`` - MSE between L2-normalised student and teacher hidden states;

    plus an optional supervised contrastive loss (``config.add_teacher_cl``).
    Teacher-input construction and teacher slicing assume the special tokens
    occupy the last ``special_tokens_num`` positions of every row.
    """

    # First id of the added special embedding tokens (assumed already added to the
    # tokenizer); 128256 presumably is the Llama-3 base vocab size - confirm.
    _EMB_TOKEN_START_ID = 128256
    # Token id substituted for -100 label positions when building teacher inputs;
    # presumably Llama-3's <|eot_id|> - confirm.
    _LABEL_FILL_TOKEN_ID = 128009

    def __init__(self, config):
        super().__init__(config)
        # Switch for the in-batch contrastive loss; not read anywhere in this class.
        self.in_batch_cl = False
        # Temperature for the contrastive losses.
        self.cl_temperature = 0.1
        self.supcon_loss = losses.SupConLoss(temperature=self.cl_temperature)
        self.special_tokens_num = int(config.special_tokens_num)
        self.emb_token_ids = list(
            range(self._EMB_TOKEN_START_ID, self._EMB_TOKEN_START_ID + self.special_tokens_num)
        )
        self.add_cl = config.add_teacher_cl
        self.if_freeze_teacher = config.if_freeze_teacher
        self.main_loss_type = config.main_loss_type

    def _gather_tensors(self, tensor):
        """All-gather ``tensor`` from every rank and concatenate along dim 0.

        ``torch.distributed.all_gather`` does not propagate gradients, so the
        slot of the current rank is replaced by the local tensor to keep its
        autograd graph. Returns ``tensor`` unchanged when torch.distributed is
        unavailable or not initialized. All ranks must pass equal shapes.
        """
        dist = torch.distributed
        if not (dist.is_available() and dist.is_initialized()):
            return tensor

        tensor = tensor.contiguous()
        gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        gathered[dist.get_rank()] = tensor  # keep gradients for the local shard
        return torch.cat(gathered, dim=0)

    def in_batch_contrastive_loss(self, student_emb, teacher_emb, temperature):
        """Symmetric in-batch InfoNCE loss between student and teacher embeddings.

        Embeddings are gathered across ranks; row ``i`` of the student is the
        positive for row ``i`` of the teacher and vice versa. Computed in
        float32 with autocast disabled.

        Args:
            student_emb: (B, D) student embeddings.
            teacher_emb: (B, D) teacher embeddings, row-aligned with student_emb.
            temperature: divisor applied to the cosine similarities.

        Returns:
            Mean of the row-wise and column-wise cross-entropy losses.
        """
        gathered_student = self._gather_tensors(student_emb)  # (world_size*B, D)
        gathered_teacher = self._gather_tensors(teacher_emb)  # (world_size*B, D)

        with torch.autocast(device_type=gathered_student.device.type, enabled=False):
            gathered_student = gathered_student.float()
            gathered_teacher = gathered_teacher.float()

            # Cross-rank similarity matrix (N, N).
            similarity = F.cosine_similarity(
                gathered_student.unsqueeze(1),
                gathered_teacher.unsqueeze(0),
                dim=2,
            ) / temperature

            targets = torch.arange(gathered_student.size(0), device=similarity.device)
            row_loss = F.cross_entropy(similarity, targets)
            col_loss = F.cross_entropy(similarity.t(), targets)
            return (row_loss + col_loss) / 2

    def get_label_hidden_mean(self, hidden_states, labels, mask_token=_LABEL_FILL_TOKEN_ID):
        """Mean of ``hidden_states`` over positions where ``labels != mask_token``.

        Args:
            hidden_states: (batch, seq_len, hidden) tensor.
            labels: (batch, seq_len) ids; positions equal to ``mask_token`` are excluded.
            mask_token: id treated as padding (default ``_LABEL_FILL_TOKEN_ID``).

        Returns:
            (batch, hidden) tensor; rows with no valid position are all zeros.
        """
        mask = (labels != mask_token).unsqueeze(-1)  # (batch, seq_len, 1)
        sum_hidden = (hidden_states * mask).sum(dim=1)  # (batch, hidden)
        # Counts are integers: clamp to 1 (not 1e-9) so empty rows give 0/1, not 0/0.
        valid_counts = mask.sum(dim=1).clamp(min=1)  # (batch, 1)
        return sum_hidden / valid_counts

    def get_special_tokens_hidden_mean(self, hidden_states):
        """Mean of the hidden states over the last ``special_tokens_num`` positions.

        Assumes the special tokens sit at the end of every sequence.
        Returns a (batch, hidden) tensor.
        """
        special_hidden = hidden_states[:, -self.special_tokens_num:, :]  # (batch, n_special, hidden)
        return special_hidden.mean(dim=1)

    def _build_teacher_input_ids(self, input_ids, labels):
        """Copy of ``input_ids`` whose last ``special_tokens_num`` positions hold the labels.

        ``-100`` label entries are replaced by ``_LABEL_FILL_TOKEN_ID`` so the
        teacher sees the target tokens where the student sees the special tokens.
        """
        n_special = self.special_tokens_num
        teacher_input_ids = input_ids.clone()
        label_tail = labels[:, -n_special:].clone()
        label_tail[label_tail == -100] = self._LABEL_FILL_TOKEN_ID
        teacher_input_ids[:, -n_special:] = label_tail
        return teacher_input_ids

    def _fetch_remote_teacher_hidden(self, teacher_input_ids, attention_mask, fallback):
        """Query the remote teacher, falling back to a detached copy of ``fallback``.

        On success the result is cast to ``fallback.dtype`` and moved to
        ``teacher_input_ids.device``. Any exception is treated as best-effort
        failure: a RuntimeWarning is emitted and ``fallback`` (detached) is
        returned so the training step can continue.
        """
        import warnings

        try:
            teacher_hidden = call_remote_teacher_service(
                input_ids=teacher_input_ids,
                attention_mask=attention_mask,
                special_token_num=self.special_tokens_num,
            )
            return teacher_hidden.to(device=teacher_input_ids.device, dtype=fallback.dtype)
        except Exception as exc:  # deliberate best-effort; do not abort training
            warnings.warn(
                f"Remote teacher unavailable ({exc!r}); using detached student hidden states instead.",
                RuntimeWarning,
            )
            return fallback.clone().detach()

    def _teacher_cl_loss(self, hidden_states1, hidden_states2, hidden_states_teacher, hidden_states_answer):
        """Supervised contrastive loss over several views of every sample.

        Views: the hidden state at index ``-special_tokens_num - 1`` (the last
        input token, assuming the special tokens are at the end) of both
        student passes and of the teacher pass, plus the last position of the
        reference-answer pass. Teacher and answer views are detached; all views
        of sample ``i`` share label ``i``.
        """
        input_last_index = -self.special_tokens_num - 1
        views = [
            hidden_states1[:, input_last_index, :],  # (B, D)
            hidden_states2[:, input_last_index, :],  # (B, D)
            hidden_states_teacher[:, input_last_index, :].detach(),  # (B, D)
            hidden_states_answer[:, -1, :].detach(),  # (B, D)
        ]
        # Alternative "TSM" views (mean over the special-token positions):
        # views = [F.normalize(self.get_special_tokens_hidden_mean(h), p=2, dim=-1)
        #          for h in (hidden_states1, hidden_states2, hidden_states_answer)]

        num_views = len(views)
        batch_size = views[0].size(0)
        # (B, V, D) -> (B*V, D), sample-major so labels can use repeat_interleave.
        all_embeddings = torch.cat([view.unsqueeze(1) for view in views], dim=1)
        all_embeddings = all_embeddings.view(-1, all_embeddings.size(-1))
        labels_cl = torch.arange(batch_size, device=all_embeddings.device).repeat_interleave(num_views)

        # All-gather alternative (global batch, globally unique labels):
        # rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        # labels_cl = (torch.arange(batch_size, device=all_embeddings.device)
        #              + rank * batch_size).repeat_interleave(num_views)
        # return self.supcon_loss(self._gather_tensors(all_embeddings).float(),
        #                         self._gather_tensors(labels_cl))
        return self.supcon_loss(all_embeddings.float(), labels_cl)

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        original_labels: Optional[torch.LongTensor] = None,
        labels_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        original_labels (`torch.LongTensor`, *optional*):
            Token ids of the reference answer. Encoded without gradients and used
            as one view of the contrastive loss; required when
            `config.add_teacher_cl` is enabled and `labels` is given.
        labels_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for `original_labels`.
        labels (`torch.LongTensor` of shape `(batch_size, L)`, `L >= special_tokens_num`, *optional*):
            Only the last `special_tokens_num` columns are used: they are the
            target token of each special embedding token (`-100` = ignored) and
            are written over the special tokens to build the teacher input.

        `output_attentions`, `output_hidden_states` and `return_dict` are
        ignored: the main pass always returns attentions and hidden states
        (for the duplicated 2*batch input), and a `MultipleChoiceModelOutput`
        is always returned. `logits` has shape
        `(batch_size, special_tokens_num * vocab_size)`; `loss`/`loss_dict`
        are populated only when `labels` is given.
        """
        if input_ids is None:
            raise ValueError("CustormerLlamaForMultiEmb.forward requires `input_ids`.")

        batch_size = input_ids.size(0)
        n_special = self.special_tokens_num

        def _twice(tensor):
            # Duplicate along the batch dimension; None stays None.
            return None if tensor is None else torch.cat([tensor, tensor], dim=0)

        # One pass over the batch duplicated along dim 0: the two halves are two
        # student "views" of each sample.
        outputs = self.model(
            input_ids=_twice(input_ids),
            attention_mask=_twice(attention_mask),
            position_ids=_twice(position_ids),
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=True,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden_states1, hidden_states2 = torch.split(outputs.hidden_states[-1], batch_size, dim=0)

        # Hidden states at the special embedding tokens -> one vocab distribution each.
        special_token_mask = (input_ids >= self.emb_token_ids[0]) & (input_ids <= self.emb_token_ids[-1])
        # The view requires exactly `n_special` special tokens in every row.
        emb_hidden = hidden_states1[special_token_mask].view(batch_size, n_special, -1)
        emb_logits = self.lm_head(emb_hidden)  # (batch, n_special, vocab)
        logits = emb_logits.view(batch_size, -1)  # (batch, n_special * vocab)

        loss = None
        loss_dict = {}
        if labels is not None:
            label_tail = labels[:, -n_special:]  # (batch, n_special)
            valid_mask = (label_tail != -100).float()  # (batch, n_special)
            teacher_input_ids = self._build_teacher_input_ids(input_ids, labels)
            remote_teacher = None  # remote result, fetched at most once per step
            teacher_cl_loss = None

            if self.add_cl:
                if self.if_freeze_teacher:
                    remote_teacher = self._fetch_remote_teacher_hidden(
                        teacher_input_ids, attention_mask, hidden_states1
                    )
                    hidden_states_teacher = remote_teacher
                else:
                    # Attentions / all-layer hidden states are not needed here.
                    with torch.no_grad():
                        hidden_states_teacher = self.model(
                            input_ids=teacher_input_ids,
                            attention_mask=attention_mask,
                            position_ids=position_ids,
                            past_key_values=past_key_values,
                            inputs_embeds=inputs_embeds,
                            use_cache=use_cache,
                            return_dict=True,
                        ).last_hidden_state
                # Reference-answer pass: only its detached last position is used,
                # so it runs without gradients and only when the CL loss needs it.
                with torch.no_grad():
                    hidden_states_answer = self.model(
                        input_ids=original_labels,
                        attention_mask=labels_attention_mask,
                        use_cache=False,
                        return_dict=True,
                    ).last_hidden_state
                teacher_cl_loss = self._teacher_cl_loss(
                    hidden_states1, hidden_states2, hidden_states_teacher, hidden_states_answer
                )
                loss_dict["supcon_loss"] = teacher_cl_loss.detach().cpu()

            emb_hidden_teacher = None
            if "KL" in self.main_loss_type or "MSE" in self.main_loss_type:
                if remote_teacher is None:
                    remote_teacher = self._fetch_remote_teacher_hidden(
                        teacher_input_ids, attention_mask, hidden_states1
                    )
                # Teacher states at the last n_special positions (assumed to be the special tokens).
                emb_hidden_teacher = remote_teacher[:, -n_special:, :].view(batch_size, n_special, -1)

            if "CE" in self.main_loss_type:
                ce_loss = F.cross_entropy(
                    emb_logits.view(-1, emb_logits.size(-1)),  # (batch*n_special, vocab)
                    label_tail.reshape(-1),  # (batch*n_special,)
                    ignore_index=-100,
                )
                loss = ce_loss
                loss_dict["CEloss"] = ce_loss.detach().cpu()

            if "KL" in self.main_loss_type:
                with torch.no_grad():  # the teacher distribution is a fixed target
                    teacher_probs = F.softmax(self.lm_head(emb_hidden_teacher), dim=-1)
                student_log_probs = F.log_softmax(emb_logits, dim=-1)
                kl_per_position = F.kl_div(
                    student_log_probs,
                    teacher_probs,
                    reduction="none",
                    log_target=False,
                ).sum(dim=-1)  # (batch, n_special)
                # clamp: all-ignored batches give 0 instead of NaN
                kl_loss = (kl_per_position * valid_mask).sum() / valid_mask.sum().clamp(min=1)
                loss = kl_loss if loss is None else loss + kl_loss
                loss_dict["KLDivLoss"] = kl_loss.detach().cpu()

            if "MSE" in self.main_loss_type:
                teacher_norm = F.normalize(emb_hidden_teacher, p=2, dim=-1)
                position_mask = valid_mask.unsqueeze(-1)  # (batch, n_special, 1)
                # Symmetric: average the masked MSE of both student views.
                total_mse = 0
                for view_hidden in (hidden_states1, hidden_states2):
                    view_emb = view_hidden[special_token_mask].view(batch_size, n_special, -1)
                    view_norm = F.normalize(view_emb, p=2, dim=-1)
                    per_element = F.mse_loss(view_norm, teacher_norm, reduction="none")
                    total_mse = total_mse + (per_element * position_mask).sum()
                total_mse = total_mse / (2 * position_mask.sum().clamp(min=1))
                loss = total_mse if loss is None else loss + total_mse
                loss_dict["MSELoss"] = total_mse.detach().cpu()

            if teacher_cl_loss is not None:
                # Out-of-place: never mutates tensors already stored in loss_dict.
                loss = teacher_cl_loss if loss is None else loss + teacher_cl_loss

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            loss_dict=loss_dict,
        )