import torch
import torch.nn as nn
import copy

import math
def tensor_prompt(a, b, c=None,ortho=False):
    #将这个张量转换为 PyTorch Parameter 对象，并将 requires_grad 设为 True，表示在反向传播过程中将对这个张量计算梯度。
    if c is None:
        p = torch.nn.Parameter(torch.FloatTensor(a, b), requires_grad=True)
        if ortho:
            nn.init.orthogonal_(p)#使用正交矩阵将张量进行初始化
        else:
            nn.init.kaiming_uniform_(p, a=math.sqrt(5))
    else:#torch.Size([5, 10, 768])
        p = torch.nn.Parameter(torch.FloatTensor(a, b, c), requires_grad=True)
        if ortho:
            nn.init.orthogonal_(p)#使用正交矩阵将张量进行初始化
        else:
            for index in range(a):
                nn.init.kaiming_uniform_(p[index], a=math.sqrt(5))
    #nn.init.uniform_(p)#使用均匀分布从中提取随机值对张量进行初始化
    return p



class PrefixKeqV(nn.Module):
    def __init__(self, emb_d, n_tasks, e_p_length, e_pool_size, e_layers):
        super().__init__()
        self.task_count = 0
        self.emb_d = emb_d # 输入特征向量维度
        
        self.n_tasks = n_tasks # 全部任务数量
        self._init_smart( e_p_length,e_pool_size,e_layers)

        
        # e prompt init
        for e in self.e_layers:
            # 按池数量去初始化，对于dualPrompt每个任务对应一个e-prompt,对于L2p来说不需要一对一
            # 采用按任务数量去创建p，便于锁定不是当前任务的参数
            p = tensor_prompt(self.e_pool_size, self.e_p_length, emb_d)
            setattr(self, f'e_p_{e}', p)

    def _init_smart(self, e_p_length, e_pool_size, e_layers):

        # self.top_k = 1
        # self.task_id_bootstrap = True #训练模式下启用 任务id引导，测试下根据相似度找最相近的

        # prompt locations 
        # g_prompt插入层序号
        # self.g_layers = [0, 1]
        # self.e_layers = [2, 3, 4]#e_prompt插入层序号
        self.e_layers = e_layers
        # prompt parameter args:
        #    arg 1 = e-prompt pool size (# tasks) 等于实验设置的任务数量，每个任务对应一个e-prompt
        #    arg 2 = e-prompt pool length
        #    arg 3 = g-prompt pool length
        self.e_p_length = e_p_length
        self.e_pool_size = e_pool_size

    def process_task_count(self):#任务数量加一
        self.task_count += 1
    '''
        
        其中 batch_size 表示查询向量的数量，emb_d 表示每个查询向量的特征向量维度。
        l:一个整数，表示层序号。
        
        train:一个布尔值，表示是否在训练模式下。
        task_id:一个整数，表示任务 ID(训练),或者选择的任务序列（测试推理）
    '''
    def forward(self, l,batch_size, task_id=None):

        p_return = None
        #判断当前层是否需要使用 e-prompt
        if l in self.e_layers:
            
            p = getattr(self, f'e_p_{l}')  # 0 based indexing here 取出该层的e-prompt
            if task_id is None:
                return None
            if isinstance(task_id, int):
            #loss = (1.0 - cos_sim[:, task_id]).sum()#本批次数据是在同一task_id下，选出该task对应的相似度，损失是1-相似度 的和
                P_ = p[task_id].expand(batch_size, -1, -1)#该任务对应的e-prompt按批次扩展维度，变成batch_size,self.e_p_length, emb_d    
            else:
                # if self.task_count>0:
                #     print(1)
                P_ = p[task_id]#拼成batch_size,top_k,e_p_length, emb_d 
                
            # select prompts 将本批次的每个e-prompt拆分为Ek和Ev，用于preix tuning
            #if train:
            # i = int(self.e_p_length/2)
            # Ek = P_[:, :i, :].reshape((batch_size, -1, self.emb_d)) 
            # Ev = P_[:, i:, :].reshape((batch_size, -1, self.emb_d))
            #else:
                # i = int(self.e_p_length/2)
                # Ek = P_[:, :, :i, :].reshape((batch_size, -1, self.emb_d))
                # Ev = P_[:, :, i:, :].reshape((batch_size, -1, self.emb_d))
            p_return = P_
    
        return p_return

class PrefixKneqV(nn.Module):
    def __init__(self, emb_d, n_tasks, e_p_length, e_pool_size, e_layers):
        super().__init__()
        self.task_count = 0
        self.emb_d = emb_d # 输入特征向量维度
        
        self.n_tasks = n_tasks # 全部任务数量
        self._init_smart( e_p_length,e_pool_size,e_layers)

        
        # e prompt init
        for e in self.e_layers:
            # 按池数量去初始化，对于dualPrompt每个任务对应一个e-prompt,对于L2p来说不需要一对一
            # 采用按任务数量去创建p，便于锁定不是当前任务的参数
            p = tensor_prompt(self.e_pool_size, self.e_p_length, emb_d)
            setattr(self, f'e_p_{e}', p)

    def _init_smart(self, e_p_length, e_pool_size, e_layers):

        # self.top_k = 1
        # self.task_id_bootstrap = True #训练模式下启用 任务id引导，测试下根据相似度找最相近的

        # prompt locations 
        # g_prompt插入层序号
        # self.g_layers = [0, 1]
        # self.e_layers = [2, 3, 4]#e_prompt插入层序号
        self.e_layers = e_layers
        # prompt parameter args:
        #    arg 1 = e-prompt pool size (# tasks) 等于实验设置的任务数量，每个任务对应一个e-prompt
        #    arg 2 = e-prompt pool length
        #    arg 3 = g-prompt pool length
        self.e_p_length = e_p_length
        self.e_pool_size = e_pool_size

    def process_task_count(self):#任务数量加一
        self.task_count += 1
    '''
        
        其中 batch_size 表示查询向量的数量，emb_d 表示每个查询向量的特征向量维度。
        l:一个整数，表示层序号。
        
        train:一个布尔值，表示是否在训练模式下。
        task_id:一个整数，表示任务 ID(训练),或者选择的任务序列（测试推理）
    '''
    def forward(self, l,batch_size, task_id=None):

        p_return = None
        #判断当前层是否需要使用 e-prompt
        if l in self.e_layers:
            
            p = getattr(self, f'e_p_{l}')  # 0 based indexing here 取出该层的e-prompt
            if task_id is None:
                return None
            if isinstance(task_id, int):
            #loss = (1.0 - cos_sim[:, task_id]).sum()#本批次数据是在同一task_id下，选出该task对应的相似度，损失是1-相似度 的和
                P_ = p[task_id].expand(batch_size, -1, -1)#该任务对应的e-prompt按批次扩展维度，变成batch_size,self.e_p_length, emb_d    
            else:
                # if self.task_count>0:
                #     print(1)
                P_ = p[task_id]#拼成batch_size,top_k,e_p_length, emb_d 
            i = int(self.e_p_length/2)
            Ek = P_[:, :i, :].reshape((batch_size, -1, self.emb_d)) 
            Ev = P_[:, i:, :].reshape((batch_size, -1, self.emb_d)) 
            # select prompts 将本批次的每个e-prompt拆分为Ek和Ev，用于preix tuning
            #if train:
            # i = int(self.e_p_length/2)
            # Ek = P_[:, :i, :].reshape((batch_size, -1, self.emb_d)) 
            # Ev = P_[:, i:, :].reshape((batch_size, -1, self.emb_d))
            #else:
                # i = int(self.e_p_length/2)
                # Ek = P_[:, :, :i, :].reshape((batch_size, -1, self.emb_d))
                # Ev = P_[:, :, i:, :].reshape((batch_size, -1, self.emb_d))
            p_return = [Ek, Ev]
    
        return p_return


