import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.inits import glorot
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from common import FFN

from omegaconf import OmegaConf

def load_config(path=''):
    """Load and return the OmegaConf configuration stored at ``path``.

    ``path`` defaults to the empty string, so callers are expected to pass a
    real file path; the return value is whatever ``OmegaConf.load`` produces.
    """
    return OmegaConf.load(path)




class SimplePrompt(nn.Module):
    """Add a learned, task-conditioned prompt vector to every position of ``X``.

    The prompt is produced from ``task_id`` by ``self.ptokens``, which maps
    ``config.cls.ag`` inputs to ``config.dims`` outputs.

    Args:
        config: config object; reads ``finetune.p_nums``, ``dims`` and ``cls.ag``.
        flag: if True, ``ptokens`` is an ``nn.Embedding`` (``task_id`` holds
            indices); otherwise it is an ``nn.Linear`` (``task_id`` is a feature
            vector per sample, e.g. multi-hot).
    """

    def __init__(self, config, flag=False):
        super().__init__()
        self.pnums = config.finetune.p_nums  # stored for reference; unused by forward
        self.dims = config.dims
        self.flag = flag
        if flag:
            self.ptokens = nn.Embedding(config.cls.ag, config.dims)
        else:
            self.ptokens = nn.Linear(config.cls.ag, config.dims)

    def forward(self, X, task_id):
        """Return ``ptokens(task_id)`` broadcast over ``X`` and added to it.

        Args:
            X: 4-D tensor unpacked as ``(b, f, n, d)``; ``d`` presumably equals
                ``config.dims`` (TODO confirm).
            task_id: 2-D tensor whose first dim is the batch. In Linear mode its
                last dim is ``config.cls.ag``; integer inputs are cast to the
                layer's dtype because ``nn.Linear`` rejects integer tensors.

        Returns:
            ``X`` plus the task prompt.
        """
        # Unpacking also enforces that X is 4-D, as before.
        _, f, n, _ = X.shape
        if self.flag:
            # NOTE(review): nn.Embedding appends a dims axis, so this yields
            # (b, f, n, k, dims), which only broadcasts against X when the sizes
            # happen to line up -- confirm the intended task_id layout.
            task_id = task_id.unsqueeze(1).unsqueeze(1).expand(-1, f, n, -1)
            return self.ptokens(task_id) + X

        # nn.Linear raises on integer input (e.g. ids built with torch.randint).
        task_id = task_id.to(self.ptokens.weight.dtype)
        # The prompt depends only on the sample, not on (f, n): project once per
        # sample, then broadcast over the f and n axes instead of expanding first.
        prompt = self.ptokens(task_id)[:, None, None, :]
        return prompt + X

class GPFPlus(nn.Module):
    """Add a task-conditioned, input-weighted mix of prompt tokens to ``X``.

    Layers (all built from ``config``):
      * ``tokens``:   Linear(cls.ag -> dims * p_nums), giving ``p_nums`` prompt
        vectors per sample.
      * ``ptokens``:  Linear(cls.ag -> dims), a task vector added to the input
        to form the query for the weighting net.
      * ``net``:      Linear(dims -> p_nums) + Sigmoid, giving one weight in
        (0, 1) per prompt vector at every (f, n) position.

    Args:
        config: config object; reads ``finetune.p_nums``, ``dims``,
            ``gpfp.detach`` and ``cls.ag``.
        flag: stored on the instance but not used; both projections are always
            ``nn.Linear``.
    """

    def __init__(self, config, flag=False):
        super().__init__()
        self.pnums = config.finetune.p_nums
        self.dims = config.dims
        # When True, the weighting net sees a detached copy of the input.
        self.detach = config.gpfp.detach
        print('gfpf detach',self.detach)
        self.flag = flag

        self.tokens = nn.Linear(config.cls.ag, config.dims * config.finetune.p_nums)
        self.ptokens = nn.Linear(config.cls.ag, config.dims)

        self.net = nn.Sequential(nn.Linear(config.dims, config.finetune.p_nums), nn.Sigmoid())

    def forward(self, X_in, task_id, detach=False):
        """Return ``X_in`` plus a weighted sum of the task's prompt vectors.

        Args:
            X_in: 4-D tensor ``(b, f, n, d)``; ``d`` must broadcast with
                ``config.dims``.
            task_id: 2-D tensor ``(b, cls.ag)``. Integer inputs are cast to the
                layer dtype, because ``nn.Linear`` rejects integer tensors.
            detach: detach the input before the weighting net for this call,
                in addition to the ``config.gpfp.detach`` setting.

        Returns:
            Tensor shaped like ``X_in``. Gradients always reach ``X_in`` through
            the residual term; detaching only cuts the path through ``net``.
        """
        if X_in.dim() != 4:
            raise ValueError(f"expected a 4-D input (b, f, n, d), got shape {tuple(X_in.shape)}")
        X = X_in.detach() if (self.detach or detach) else X_in

        # nn.Linear raises on integer input (e.g. ids built with torch.randint).
        task_id = task_id.to(self.tokens.weight.dtype)
        # (b, ag) -> (b, 1, 1, ag): project once per sample and let broadcasting
        # cover the f and n axes instead of repeating task_id f*n times.
        task_id = task_id[:, None, None, :]

        task_token = self.ptokens(task_id) + X                # (b, f, n, dims)
        weight = self.net(task_token).unsqueeze(-2)           # (b, f, n, 1, p_nums)
        prompt = self.tokens(task_id).reshape(
            *task_id.shape[:-1], self.pnums, self.dims
        )                                                     # (b, 1, 1, p_nums, dims)
        prompt = (weight @ prompt).squeeze(-2)                # (b, f, n, dims)
        return X_in + prompt


if __name__=='__main__':
    # Smoke test: run random features through GPFPlus and print the output shape.
    # Usage: python <this file> <config.yaml>
    # NOTE(review): the config presumably needs dims == 768 and cls.ag == 157 to
    # match the tensors below -- confirm against the config files in use.
    features = torch.randn(2, 10, 11, 768)  # (b, f, n, d)
    # GPFPlus projects task ids with nn.Linear, which needs floating-point input.
    task_ids = torch.randint(0, 157, (2, 157)).float()
    # The previous code always used the empty default path, which cannot load.
    config_path = sys.argv[1] if len(sys.argv) > 1 else ''
    config = load_config(config_path)
    gpf = GPFPlus(config, True)
    ans = gpf(features, task_ids)
    print(ans.shape)

    

