
import torch.nn as nn

def tensor_prompt(a, b, c=None):
    """Create the bias-free linear module(s) that hold a learnable prompt.

    Args:
        a: Number of prompt rows when ``c`` is None; otherwise the number of
            separate projections to create.
        b: Width of each prompt row when ``c`` is None; otherwise the output
            features of every projection.
        c: Optional input features of every projection. Selects which kind
            of module is returned.

    Returns:
        ``nn.Linear(b, a, bias=False)`` (weight shape ``(a, b)``) when ``c``
        is None. Otherwise an ``nn.ModuleList`` of ``a`` modules
        ``nn.Linear(c, b, bias=False)``, each with weight shape ``(b, c)``.
        For example, a=5, b=10, c=768 gives five (10, 768) weights, i.e.
        ``torch.Size([5, 10, 768])`` when stacked.
    """
    if c is not None:
        # One independent, bias-free c -> b projection per index in range(a).
        projections = [nn.Linear(c, b, bias=False) for _ in range(a)]
        return nn.ModuleList(projections)

    # A single bias-free Linear whose weight matrix serves as the prompt.
    return nn.Linear(b, a, bias=False)



class PrefixOne(nn.Module):
    """Per-layer learnable prefix prompts, shared across the batch.

    For every layer index in ``e_layers``, the bias-free ``nn.Linear``
    returned by ``tensor_prompt(e_p_length, emb_d)`` is registered as the
    submodule ``e_p_<layer>``. Only its ``weight`` (shape
    ``(e_p_length, emb_d)``) is used as the prompt.

    Args:
        emb_d: Width of each prompt vector.
        e_p_length: Number of prompt vectors per layer.
        e_layers: Iterable of layer indices that receive a prompt. It is
            copied into a list, so one-shot iterables are accepted.
    """

    def __init__(self, emb_d, e_p_length, e_layers):
        super().__init__()
        self.emb_d = emb_d
        self.e_p_length = e_p_length
        # Materialize once: the loop below iterates over e_layers and
        # forward() tests membership in it. Storing a generator/iterator
        # as-is would leave it exhausted, so forward() would never match.
        self.e_layers = list(e_layers)

        for e in self.e_layers:
            # nn.Module.__setattr__ registers the submodule, so its weight is
            # a trainable parameter stored under 'e_p_<e>.weight'.
            setattr(self, f'e_p_{e}', tensor_prompt(self.e_p_length, self.emb_d))

    def forward(self, l, batch_size):
        """Return the prompt for layer ``l`` broadcast over the batch.

        Args:
            l: Layer index to look up.
            batch_size: Size of the leading batch dimension to expand to.

        Returns:
            A tensor of shape ``(batch_size, e_p_length, emb_d)`` that is an
            expanded view of the layer's weight (no copy; all batch entries
            share storage). Returns None when ``l`` is not in ``e_layers``.
        """
        if l not in self.e_layers:
            return None

        p = getattr(self, f'e_p_{l}')
        return p.weight.expand(batch_size, -1, -1)

