import torch
import torch.nn.functional as F
import os
import sys
sys.path.append('')
from myutils.common import *
import scipy.sparse as sp
from torch import nn
import numpy as np
from myutils.common import *
import einops

from myutils.config import *


from gpnn2.gpnn2 import GPNN4 as GPNN5
from myutils.gpfplus import (SimplePrompt,GPFPlus,SinglePrompt)

def self_count(model, x, node):
    """Invoke ``model`` once on freshly sampled random inputs.

    The ``x`` and ``node`` arguments are accepted but never read: both are
    replaced by standard-normal tensors of fixed shape before the call.

    Args:
        model: Callable taking two positional tensors.
        x: Ignored.
        node: Ignored.

    Returns:
        None. The result of ``model`` is discarded.
    """
    # Sampling order (x first, then node) is kept so RNG consumption matches.
    dummy_x = torch.randn(1, 16, 768)
    dummy_node = torch.randn(16, 768)
    model(dummy_x, dummy_node)

class MixLayer(nn.Module):
    """One temporal + one spatial transformer encoder applied side by side.

    ``forward`` swaps the layouts of its two inputs before encoding: the
    tensor passed as ``temporal_x`` (laid out ``(b n) f d``) is rearranged to
    ``(b f) n d`` and fed to the spatial encoder, while the tensor passed as
    ``spaital_x`` (laid out ``(b f) n d``) is rearranged to ``(b n) f d`` and
    fed to the temporal encoder.

    Args:
        config: Object providing ``hidden_size``, ``transformer.heads``,
            ``dropout``, ``mix.tfm_layer``, ``frames``, ``actors`` and, when
            ``mlp_flag`` is set, ``dims`` and ``eps``.
        mlp_flag: When True a ``TwoLayer`` module is applied to each stream
            before its encoder.
    """

    def __init__(self, config,mlp_flag=False):
        super().__init__()
        temporal_encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.transformer.heads,
            dim_feedforward=config.hidden_size * 4,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True
        )
        self.temporal = nn.TransformerEncoder(
            encoder_layer=temporal_encoder_layer, num_layers=config.mix.tfm_layer
        )

        spatial_encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.transformer.heads,
            dim_feedforward=config.hidden_size * 4,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True
        )
        self.spatial = nn.TransformerEncoder(
            encoder_layer=spatial_encoder_layer, num_layers=config.mix.tfm_layer
        )
        if mlp_flag:
            self.temporal_mlp=TwoLayer(config.dims,config.dropout,config.eps)
            self.spatial_mlp=TwoLayer(config.dims,config.dropout,config.eps)
        self.frames=config.frames
        # One extra slot on the node axis on top of config.actors.
        self.actors=config.actors+1
        self.mlp_flag=mlp_flag


    def forward(self,temporal_x,spaital_x,mask):
        """Encode both streams and return ``(out_t, out_s)``.

        Args:
            temporal_x: Tensor laid out ``(b n) f d``.
            spaital_x: Tensor laid out ``(b f) n d``.
            mask: Optional ``b f n`` key-padding mask, or None. It is never
                modified by this method.

        Returns:
            ``out_t`` in ``(b n) f d`` layout (temporal encoder output) and
            ``out_s`` in ``(b f) n d`` layout (spatial encoder output).
        """
        per_node_seq=einops.rearrange(spaital_x,'(b f) n d ->  (b n) f d',f=self.frames,n=self.actors)
        per_frame_seq=einops.rearrange(temporal_x,'(b n) f d -> (b f) n d',f=self.frames,n=self.actors)

        tem_mask=None
        spa_mask=None
        if mask is not None:
            tem_mask=einops.rearrange(mask,'b f n -> (b f) n',f=self.frames,n=self.actors)
            # rearrange may hand back a view of `mask` (e.g. when b == 1).
            # Clone before the in-place write so the caller's mask -- which is
            # reused by later layers -- is never silently altered.
            spa_mask=einops.rearrange(mask,'b f n ->  (b n) f ',f=self.frames,n=self.actors).clone()
            # Rows that are masked at every frame are un-masked so the
            # temporal encoder never sees a sequence with all keys padded.
            spa_mask[torch.all(spa_mask==True,dim=-1)]=False

        # Call order (spatial branch first, then temporal) is kept as-is so
        # dropout RNG consumption does not change.
        spatial_in=self.spatial_mlp(per_frame_seq) if self.mlp_flag else per_frame_seq
        out_s=self.spatial(spatial_in,src_key_padding_mask=tem_mask)
        temporal_in=self.temporal_mlp(per_node_seq) if self.mlp_flag else per_node_seq
        out_t=self.temporal(temporal_in,src_key_padding_mask=spa_mask)
        return out_t,out_s

class MixBlock(nn.Module):
    """Two stacked ``MixLayer`` modules with a mixing step in between.

    The first layer is built without the pre-encoder MLPs, the second with
    them. Between the two, the streams are either cross-mixed through an MLP
    (``config.mix.mode == 'mix'``) or given a plain residual connection.
    """

    def __init__(self, config):
        super().__init__()
        self.layer1 = MixLayer(config, False)
        self.layer2 = MixLayer(config, True)

        self.frames = config.frames
        self.actors = config.actors + 1
        self.mode = config.mix.mode
        self.tmr_mlp = MLP(config.dims, config.dims, config.dropout, config.eps)
        self.spt_mlp = MLP(config.dims, config.dims, config.dropout, config.eps)

    def forward(self, temporal_in, spatial_in, mask):
        """Return the ``(temporal, spatial)`` outputs of the second layer.

        Args:
            temporal_in: Tensor laid out ``(b n) f d``.
            spatial_in: Tensor laid out ``(b f) n d``.
            mask: Mask forwarded unchanged to both inner layers (may be None).
        """
        first_t, first_s = self.layer1(temporal_in, spatial_in, mask)

        if self.mode != 'mix':
            # Plain residual around the first layer.
            mixed_t = temporal_in + first_t
            mixed_s = spatial_in + first_s
        else:
            # Each stream also receives the other stream's output, re-laid
            # out to its own axis order, before going through its MLP.
            s_as_temporal = einops.rearrange(
                first_s, '(b f) n d -> (b n) f d', f=self.frames, n=self.actors
            )
            mixed_t = self.tmr_mlp(s_as_temporal + first_t + temporal_in)
            t_as_spatial = einops.rearrange(
                first_t, '(b n) f d -> (b f) n d', f=self.frames, n=self.actors
            )
            mixed_s = self.spt_mlp(t_as_spatial + first_s + spatial_in)

        second_t, second_s = self.layer2(mixed_t, mixed_s, mask)
        return second_t, second_s

class MixTSE(nn.Module):
    """Stack of ``config.mix.layer`` ``MixBlock`` modules plus an output MLP."""

    def __init__(self, config):
        super().__init__()
        self.layer = config.mix.layer
        self.tse = nn.ModuleList([MixBlock(config) for _ in range(self.layer)])
        self.mlp = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(config.dims, config.dims),
            nn.GELU(),
        )

    def forward(self, X, mask=None):
        """Run every block and merge the two streams.

        Args:
            X: 4-D tensor unpacked as ``(batch, frame, node, dim)``.
            mask: Optional ``(batch, frame, node - 1)``-style mask; a column of
                False is prepended on the last axis before it is passed to the
                blocks.  # assumes bool mask -- TODO confirm against callers

        Returns:
            ``self.mlp`` applied to the sum of both streams, laid out
            ``b f n d``.
        """
        ba, fr, nod, dim = X.shape
        temporal_stream = einops.rearrange(X, 'b f n d -> (b n) f d')
        spatial_stream = einops.rearrange(X, 'b f n d -> (b f) n d')

        if mask is not None:
            # The first node slot gets an always-unmasked (False) entry.
            leading = torch.zeros((ba, fr, 1), dtype=torch.bool, device=mask.device)
            mask = torch.cat([leading, mask], dim=-1)

        for block in self.tse:
            temporal_stream, spatial_stream = block(temporal_stream, spatial_stream, mask)

        temporal_out = einops.rearrange(
            temporal_stream, '(b n) f d -> b f n d', b=ba, f=fr, n=nod
        )
        spatial_out = einops.rearrange(
            spatial_stream, '(b f) n d -> b f n d', b=ba, f=fr, n=nod
        )
        return self.mlp(temporal_out + spatial_out)


class CLSHead(nn.Module):
    """``MLP`` feature block followed by a linear classifier.

    Args:
        dims: Feature width fed to the MLP and the classifier.
        eps: Passed through to ``MLP``.
        dropout: Passed through to ``MLP``.
        cls: Number of classifier outputs.
    """

    def __init__(self, dims, eps, dropout, cls):
        super().__init__()
        self.ffn = MLP(dims, dims, dropout, eps)
        self.cls_head = nn.Linear(dims, cls)

    def get_last_layer(self):
        """Return whatever ``self.ffn.get_last_layer()`` returns."""
        ffn = self.ffn
        return ffn.get_last_layer()

    def forward(self, X):
        """Apply the MLP, then the linear classifier."""
        hidden = self.ffn(X)
        logits = self.cls_head(hidden)
        return logits

class ReconstructNetwork(nn.Module):
    """Thin wrapper around an ``MLPs`` stack of ``config.recs.layer`` layers."""

    def __init__(self, config):
        super().__init__()
        depth = config.recs.layer
        self.layer = MLPs(config.dims, config.dropout, config.eps, depth)

    def forward(self, X):
        """Return ``self.layer`` applied to ``X``."""
        reconstructed = self.layer(X)
        return reconstructed

class CLIPAdapter(nn.Module):
    """Residual two-layer adapter followed by a projection.

    ``forward`` computes ``pj(x + adapter(x))``. The adapter therefore has to
    keep the input width, and that width has to match the projection's input
    size, so both are taken from ``config.adapter.pj.indims`` (the previous
    hard-coded 512 is reproduced exactly when ``indims == 512``).

    Args:
        config: Object providing ``dropout``, ``adapter.pj.indims`` and
            ``adapter.pj.outdims``.
    """

    def __init__(self, config):
        super().__init__()
        in_dims=config.adapter.pj.indims
        out_dims=config.adapter.pj.outdims
        # Construction order (pj, then adapter) is kept so seeded parameter
        # initialisation is unchanged.
        self.pj=nn.Sequential(nn.Dropout(config.dropout),
                              nn.Linear(in_dims,out_dims),nn.ReLU(inplace=True))
        self.adapter=nn.Sequential(
            nn.Linear(in_dims, in_dims),
            nn.ReLU(inplace=True),
            nn.Linear(in_dims, in_dims),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        """Return ``pj(x + adapter(x))``.

        Args:
            x: Tensor whose last dimension equals ``config.adapter.pj.indims``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``config.adapter.pj.outdims``.
        """
        fea=self.pj(x+self.adapter(x))
        return fea

class VisualModel(nn.Module):
    """Residual stack of ``config.lan.layers`` ``FFN`` blocks, then LayerNorm + GELU.

    Every block but the last is built with ``FFN``'s default ``norm_``; the
    last one is built with ``norm_=False``.
    """

    def __init__(self, config):
        super().__init__()
        self.norm = nn.Sequential(nn.LayerNorm(config.dims), nn.GELU())
        self.sq = nn.Sequential()
        final_idx = config.lan.layers - 1
        for idx in range(config.lan.layers):
            if idx == final_idx:
                block = FFN(in_dim=config.dims, dropout=config.dropout, eps=config.eps, norm_=False)
            else:
                block = FFN(in_dim=config.dims, dropout=config.dropout, eps=config.eps)
            self.sq.add_module(f'ffn{idx}', block)

    def forward(self, X):
        """Return ``norm(X + sq(X))``."""
        residual = X + self.sq(X)
        return self.norm(residual)

class ProjectionHead(nn.Module):
    """Lift 512-wide features to 768, add a box term, then run ``VisualModel``."""

    def __init__(self, config):
        super().__init__()
        self.li = nn.Sequential(
            nn.Linear(512, 768),
            nn.LayerNorm(768),
            nn.GELU(),
        )
        self.ffn = VisualModel(config)

    def forward(self, X, bbx):
        """Return ``ffn(li(X) + bbx)``.

        Args:
            X: Tensor whose last dimension is 512.
            bbx: Tensor added to the lifted features (must broadcast with them).
        """
        lifted = self.li(X)
        return self.ffn(lifted + bbx)

class Head(nn.Module):
    """``DynamicFlatter`` followed by a ``CLSHead`` classifier.

    The ``*_visual`` methods simply delegate to the ``DynamicFlatter``.
    """

    def __init__(self, dims, eps, dropout, out_cls):
        super().__init__()
        self.dy = DynamicFlatter(dims, dropout, eps)
        self.head = CLSHead(dims, eps, dropout, out_cls)

    def forward(self, X):
        """Return ``head(dy(X))``."""
        dy_out = self.dy(X)
        return self.head(dy_out)

    def get_last_layer(self):
        """Return the classifier's last layer as reported by ``CLSHead``."""
        classifier = self.head
        return classifier.get_last_layer()

    def set_visual(self, flag):
        """Forward ``flag`` to ``self.dy.set_visual``."""
        flatter = self.dy
        flatter.set_visual(flag)

    def get_visual(self):
        """Return ``self.dy.get_visual()``."""
        flatter = self.dy
        return flatter.get_visual()

    def clear_visual(self):
        """Call ``self.dy.clear_visual()``."""
        flatter = self.dy
        flatter.clear_visual()

# new no middle

class GPNNMix4(nn.Module):
    # false: linear, true: embedding
    # stage_1 train_stage private/common/total
    # stage_2 train_stage continue total middle
    # stage_3 train_stage backbone only -> middle
    # stage_4 inference -> middle total common private
    # stage_5 train_stage backbone only -> middle continue
    # stage_6 single branch
    # stage_7 dual branch
    # stage_8 single branch continue
    # stage_9 dual branch continue
    # stage_10 single branch inference
    # stage_11 dual branch inference
    # stage_12 ground-truth test total out
    # stage_13 no prompt test total out
    # init_1 backbone
    # init_2 backbone+private+common+total+middle

    def __init__(self, config,flag=False,train_stage=1,pre=False,lt=0):
        """Store the settings, build the sub-modules for ``train_stage``, freeze.

        Args:
            config: Configuration object handed to the ``model_init*`` builders.
            flag: Stored as ``self.flag``.
            train_stage: Selects which ``model_init*`` builders run.
            pre: Stored as ``self.pre``.
            lt: Loss type -- 0: all losses, 1: reconstruction,
                2: separation, 3: no loss.

        Raises:
            NotImplementedError: If ``train_stage`` is not a known stage.
        """
        super().__init__()
        self.stage = train_stage
        self.flag = flag
        self.config = config
        self.pre = pre
        self.lt = lt
        print('train_stage',self.stage)

        # NOTE(review): stages 2 and 8 do not run model_init3, so no m_head is
        # created for them -- confirm that matches the forward they dispatch to.
        if self.stage in (3, 5):
            builders = (self.model_init1, self.model_init3)
        elif self.stage in (2, 4, 1, 6, 7, 8, 12, 13, 9):
            builders = (self.model_init1, self.model_init2)
        elif self.stage in (10, 11):
            builders = (self.model_init1, self.model_init2, self.model_init3)
        else:
            raise NotImplementedError

        for build in builders:
            build(config)
        self.freeze()
    
    def set_visual(self,flag=True):
        self.gpnn.set_visual(flag)
        self.p_head.set_visual(flag)
        self.c_head.set_visual(flag)
        self.cls_head.set_visual(flag)

    def get_visual(self):
        return self.gpnn.visual(),self.p_head.get_visual(),self.c_head.get_visual(),self.cls_head.get_visual()
    def clear_visual(self):
        self.gpnn.clear_visual(),self.p_head.clear_visual(),self.c_head.clear_visual(),self.cls_head.clear_visual()
    
    def model_init1(self,config):
        """Build the backbone modules shared by every stage.

        The class-embedding and relation-projection widths follow
        ``config.dims`` (previously hard-coded to 768): their outputs are
        added to features that pass through ``config.dims``-wide modules
        (``self.mffn`` and the output of ``self.edge_fun``), so any other
        width could not work. With ``config.dims == 768`` the modules are
        identical to before.

        Args:
            config: Object providing ``dims``, ``eps``, ``dropout``, ``frames``,
                ``cls.rel``, ``cls.obj`` and whatever ``MixTSE`` and
                ``CLIPAdapter`` read.
        """
        dims=config.dims
        # 38 embedding rows with row 0 as padding; 30 relation input features.
        self.cls_embed=nn.Embedding(38,dims,padding_idx=0)
        self.rel_embed=nn.Linear(30,dims)
        self.tse=MixTSE(config)

        self.mffn=FFN(dims,config.eps,dims*4,config.dropout)

        self.bbx_linear=nn.Sequential(nn.Linear(4,dims),nn.LayerNorm(dims,eps=config.eps),nn.GELU())
        self.fusion=nn.Sequential(nn.Dropout(config.dropout),nn.Linear(dims,dims),nn.LayerNorm(dims,eps=config.eps),nn.GELU())
        self.rel_mlp=MLPCLS(dims,config.cls.rel,config.dropout,config.eps)

        # Consumes the concatenation of three dims-wide features.
        self.edge_fun=nn.Sequential(nn.Dropout(config.dropout),nn.Linear(dims*3,dims),nn.GELU(),
                                    nn.Dropout(config.dropout),nn.Linear(dims,dims),nn.LayerNorm(dims,eps=config.eps),nn.GELU())
        self.obj_mlp=MLPCLS(dims,config.cls.obj,config.dropout,config.eps)
        self.adapter=CLIPAdapter(config)
        self.pj=ProjectionHeadFT(dims,config.eps,dims*4,config.dropout)
        # Zero-initialised parameter of shape (1, frames, 1, dims).
        self.pos = nn.Parameter(torch.zeros(1,config.frames,1,dims))

    # stage 2/3
    def model_init2(self,config):
        if config.prompt.type==1:
            print('pgfp')
            self.pgpfp=GPFPlus(config,self.flag)
            self.cgpfp=GPFPlus(config,self.flag)
        elif config.prompt.type==0:
            self.pgpfp=SimplePrompt(config,self.flag)
            self.cgpfp=SimplePrompt(config,self.flag)
        else:
            raise NotImplementedError
        self.gpnn=GPNN5(config,config.gpnn.layer.one)
        # self.m_head2=Head(config.dims,config.eps,config.dropout,config.cls.ag)

        self.mffn2=FFN(config.dims,config.eps,config.dims*4,config.dropout)
        self.mffn3=FFN(config.dims,config.eps,config.dims*4,config.dropout)
        if self.stage not in [9]:
            self.p_head=Head(config.dims,config.eps,config.dropout,config.cls.ag+1)
            self.c_head=Head(config.dims,config.eps,config.dropout,config.cls.ag+1)
        if self.lt<2:
            self.recs=ReconstructNetwork(config)
        self.total_pj=FFN(config.dims,config.eps,config.dims*4,config.dropout)
        self.gf=GateFusion(config)
        self.cls_head=Head(config.dims,config.eps,config.dropout,config.cls.ag)

    # stage 2/3
    def model_init3(self,config):
        """Create ``self.m_head``, a ``Head`` with ``config.cls.ag`` outputs."""
        middle_head = Head(config.dims, config.eps, config.dropout, config.cls.ag)
        self.m_head = middle_head


    def get_weight(self,p_loss,c_loss):
        if self.stage in [6]:
            p_grad = torch.autograd.grad(p_loss, self.c_head.get_last_layer(), retain_graph=True)[0]
        else:
            p_grad = torch.autograd.grad(p_loss, self.p_head.get_last_layer(), retain_graph=True)[0]
        c_grad = torch.autograd.grad(c_loss, self.c_head.get_last_layer(), retain_graph=True)[0]

        d_weight = torch.norm(c_grad) / (torch.norm(p_grad) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight
    
    def freeze(self):
        if self.stage in [1,3,4,6,7,10,11,12,13]:
            return
        for param in self.parameters():
            param.requires_grad = False
        # 2 total/middle
        # 3 middle backbone only
        train_modules={
            2: ['total_pj','cls_head'],
            8: ['total_pj','cls_head'],
            9: ['total_pj','cls_head'],
            5: ['m_head'],
            #2: [self.mffn, self.m_head, self.gf, self.cls_head, self.total_pj]
        }
        # modules=[m for m in getattr(self,module_name)]
        modules=[getattr(self,module_name) for module_name in train_modules.get(self.stage,[])]
        for module in modules:
            for param in module.parameters():
                param.requires_grad=True

    # private common total
    def forward1(self,frames,cls,rel,bbx_list,task_id,mask=None,tfm_mask=None):
        # nums =node+1
        B,Frame,Nums,dims=frames.shape
        # breakpoint()
        Nums=Nums
        # bbx_list=bbx_list[:,:,1:,:]
        bbx=self.bbx_linear(bbx_list)
        # breakpoint()
        cls_feature=self.cls_embed(cls)
        rel_feature=self.rel_embed(rel)
        

        pos=self.pos.repeat(B,1,Nums,1)
        adapter_feature=self.adapter(frames)
        # supervised
        adapter_humam=adapter_feature[:,:,1:,:]
        # pre_feature=
 
        # frames_features=
        # projection head
        frames_features=self.pj(self.tse(self.fusion(adapter_feature+bbx)+pos,tfm_mask))
        # breakpoint()

        # total features for a consist edge cls
        human_obj_feature=frames_features[:,:,1:,:]
        # supevised by adapter feature
        cls_ans=self.obj_mlp(adapter_humam)
        # scene graph
        human_obj_feature=human_obj_feature+cls_feature
        #no scenegraph
        # human_obj_feature=human_obj_feature
        human_feature=human_obj_feature[:,:,0,:].unsqueeze(-2)
        human_features=human_feature.repeat(1,1,Nums-2,1)
        obj_feature=human_obj_feature[:,:,1:,:]
        global_feature=frames_features[:,:,0,:].unsqueeze(-2).repeat(1,1,Nums-2,1)
        edge_feature=self.edge_fun(torch.cat([human_features,global_feature,obj_feature],dim=-1))
        rel_ans=self.rel_mlp(edge_feature)
        # print('shap',human_features.shape,global_feature.shape,obj_feature.shape)

        # scene graph
        edge_feature=edge_feature+rel_feature

        human_obj_feature=self.mffn(human_obj_feature)

        

        p_f=self.mffn2(self.pgpfp(human_obj_feature,task_id))
        c_f=self.mffn3(self.cgpfp(human_obj_feature,task_id))
        
        p_human_feature,p_obj_feature=self.gpnn(p_f[:,:,0,:].unsqueeze(-2),p_f[:,:,1:,:],edge_feature,self.pgpfp,task_id,mask,tfm_mask)

        c_human_feature,c_obj_feature=self.gpnn(c_f[:,:,0,:].unsqueeze(-2),c_f[:,:,1:,:],edge_feature,self.cgpfp,task_id,mask,tfm_mask)


        p_features=torch.cat([p_human_feature,p_obj_feature],dim=-2)

        c_features=torch.cat([c_human_feature,c_obj_feature],dim=-2)


        c_ans=self.c_head(c_features)
        p_ans=self.p_head(p_features)
        # rec=
        # t_node_features=
        t_node=self.gf(p_features,c_features)
        t_ans=self.cls_head(self.total_pj(t_node))
        recs=self.recs(t_node)
        

        return c_ans,p_ans,cls_ans,rel_ans,c_features,p_features,recs,human_obj_feature,t_ans
    # total middle
    def forward2(self,frames,cls,rel,bbx_list,task_id,mask=None,tfm_mask=None):
        # nums =node+1
        B,Frame,Nums,dims=frames.shape
        # breakpoint()
        Nums=Nums
        # bbx_list=bbx_list[:,:,1:,:]
        bbx=self.bbx_linear(bbx_list)
        # breakpoint()
        cls_feature=self.cls_embed(cls)
        rel_feature=self.rel_embed(rel)
        

        pos=self.pos.repeat(B,1,Nums,1)
        adapter_feature=self.adapter(frames)

        frames_features=self.pj(self.tse(self.fusion(adapter_feature+bbx)+pos))

        # total features for a consist edge cls
        human_obj_feature=frames_features[:,:,1:,:]
        # supevised by adapter feature

        human_obj_feature=human_obj_feature+cls_feature
        human_feature=human_obj_feature[:,:,0,:].unsqueeze(-2)
        human_features=human_feature.repeat(1,1,Nums-2,1)
        obj_feature=human_obj_feature[:,:,1:,:]
        global_feature=frames_features[:,:,0,:].unsqueeze(-2).repeat(1,1,Nums-2,1)
        edge_feature=self.edge_fun(torch.cat([human_features,global_feature,obj_feature],dim=-1))

        # print('shap',human_features.shape,global_feature.shape,obj_feature.shape)
        edge_feature=edge_feature+rel_feature

        human_obj_feature=self.mffn(human_obj_feature)
        m_ans=self.m_head(human_obj_feature)
        

        p_f=self.mffn2(self.pgpfp(human_obj_feature,task_id))
        c_f=self.mffn3(self.cgpfp(human_obj_feature,task_id))

        p_human_feature,p_obj_feature=self.gpnn(p_f[:,:,0,:].unsqueeze(-2),p_f[:,:,1:,:],edge_feature,self.pgpfp,task_id)

        # c_human_feature,c_obj_feature=self.cgpnn(human_feature,obj_feature,edge_feature,self.cgpfp,task_id)
        c_human_feature,c_obj_feature=self.gpnn(c_f[:,:,0,:].unsqueeze(-2),c_f[:,:,1:,:],edge_feature,self.cgpfp,task_id)

        p_features=torch.cat([p_human_feature,p_obj_feature],dim=-2)

        c_features=torch.cat([c_human_feature,c_obj_feature],dim=-2)
        t_node=self.gf(p_features,c_features)
        t_ans=self.cls_head(self.total_pj(t_node))
        

        return t_ans,m_ans

    # backbone only
    def forward3(self,frames,cls,rel,bbx_list,task_id,mask=None,tfm_mask=None):
        # nums =node+1
        B,Frame,Nums,dims=frames.shape
        # breakpoint()
        Nums=Nums
        # bbx_list=bbx_list[:,:,1:,:]
        bbx=self.bbx_linear(bbx_list)
        # breakpoint()
        cls_feature=self.cls_embed(cls)

        

        pos=self.pos.repeat(B,1,Nums,1)
        adapter_feature=self.adapter(frames)
        # supervised
        adapter_humam=adapter_feature[:,:,1:,:]

        # projection head
        frames_features=self.pj(self.tse(self.fusion(adapter_feature+bbx)+pos))

        # total features for a consist edge cls
        human_obj_feature=frames_features[:,:,1:,:]
        # supevised by adapter feature
        cls_ans=self.obj_mlp(adapter_humam)
        human_obj_feature=human_obj_feature+cls_feature
        human_feature=human_obj_feature[:,:,0,:].unsqueeze(-2)
        human_features=human_feature.repeat(1,1,Nums-2,1)
        obj_feature=human_obj_feature[:,:,1:,:]
        global_feature=frames_features[:,:,0,:].unsqueeze(-2).repeat(1,1,Nums-2,1)
        edge_feature=self.edge_fun(torch.cat([human_features,global_feature,obj_feature],dim=-1))
        rel_ans=self.rel_mlp(edge_feature)

        human_obj_feature=self.mffn(human_obj_feature)
        m_ans=self.m_head(human_obj_feature)
        
        return m_ans,cls_ans,rel_ans
    # inference -> middle/common/private/total
    def forward4(self,frames,cls,rel,bbx_list,task_id,mask=None,tfm_mask=None):
        # nums =node+1
        B,Frame,Nums,dims=frames.shape
        # breakpoint()
        Nums=Nums
        # bbx_list=bbx_list[:,:,1:,:]
        bbx=self.bbx_linear(bbx_list)
        # breakpoint()
        cls_feature=self.cls_embed(cls)
        rel_feature=self.rel_embed(rel)
        

        pos=self.pos.repeat(B,1,Nums,1)
        adapter_feature=self.adapter(frames)
        # supervised
        adapter_humam=adapter_feature[:,:,1:,:]
        # pre_feature=
 
        # frames_features=
        # projection head
        frames_features=self.pj(self.tse(self.fusion(adapter_feature+bbx)+pos,tfm_mask))
        # breakpoint()

        # total features for a consist edge cls
        human_obj_feature=frames_features[:,:,1:,:]
        # supevised by adapter feature
        # scene graph
        human_obj_feature=human_obj_feature+cls_feature
        #no scenegraph
        # human_obj_feature=human_obj_feature
        human_feature=human_obj_feature[:,:,0,:].unsqueeze(-2)
        human_features=human_feature.repeat(1,1,Nums-2,1)
        obj_feature=human_obj_feature[:,:,1:,:]
        global_feature=frames_features[:,:,0,:].unsqueeze(-2).repeat(1,1,Nums-2,1)
        edge_feature=self.edge_fun(torch.cat([human_features,global_feature,obj_feature],dim=-1))
        # print('shap',human_features.shape,global_feature.shape,obj_feature.shape)

        # scene graph
        edge_feature=edge_feature+rel_feature

        human_obj_feature=self.mffn(human_obj_feature)
        # m_ans=self.m_head(human_obj_feature)

        
        task_id2=(~task_id.bool()).float()
        p_f=self.mffn2(self.pgpfp(human_obj_feature,task_id2))
        c_f=self.mffn3(self.cgpfp(human_obj_feature,task_id))
        
        p_human_feature,p_obj_feature=self.gpnn(p_f[:,:,0,:].unsqueeze(-2),p_f[:,:,1:,:],edge_feature,self.pgpfp,task_id2,mask,tfm_mask)

        c_human_feature,c_obj_feature=self.gpnn(c_f[:,:,0,:].unsqueeze(-2),c_f[:,:,1:,:],edge_feature,self.cgpfp,task_id,mask,tfm_mask)
        p_features=torch.cat([p_human_feature,p_obj_feature],dim=-2)
        c_features=torch.cat([c_human_feature,c_obj_feature],dim=-2)
        c_ans=self.c_head(c_features)
        p_ans=self.p_head(p_features)
        # rec=
        # t_node_features=
        t_node=self.gf(p_features,c_features)
        # t_node=self.mergin_feature(torch.cat([p_features,c_features],dim=-1))
        # t_node=self.feature_mergin(torch.cat([p_features,c_features],dim=-1))
        t_ans=self.cls_head(self.total_pj(t_node))
        

        return c_ans,p_ans,t_ans
    # backbone only continue
    def forward5(self,frames,cls,rel,bbx_list,task_id,mask=None,tfm_mask=None):
        # nums =node+1
        B,Frame,Nums,dims=frames.shape
        # breakpoint()
        Nums=Nums
        # bbx_list=bbx_list[:,:,1:,:]
        bbx=self.bbx_linear(bbx_list)
        # breakpoint()

        cls_feature=self.cls_embed(cls)

        pos=self.pos.repeat(B,1,Nums,1)
        adapter_feature=self.adapter(frames)

        # projection head
        frames_features=self.pj(self.tse(self.fusion(adapter_feature+bbx)+pos))

        # total features for a consist edge cls
        human_obj_feature=frames_features[:,:,1:,:]+cls_feature
        # supevised by adapter feature
        human_obj_feature=human_obj_feature
        human_obj_feature=self.mffn(human_obj_feature)
        m_ans=self.m_head(human_obj_feature)
        return m_ans
 
  # single branch
    def forward6(self,frames,cls,rel,bbx_list,task_id,mask=None,tfm_mask=None):
        # nums =node+1
        B,Frame,Nums,dims=frames.shape
        # breakpoint()
        Nums=Nums
        # bbx_list=bbx_list[:,:,1:,:]
        bbx=self.bbx_linear(bbx_list)
        # breakpoint()
        cls_feature=self.cls_embed(cls)
        rel_feature=self.rel_embed(rel)
        

        pos=self.pos.repeat(B,1,Nums,1)
        adapter_feature=self.adapter(frames)
        # supervised
        adapter_humam=adapter_feature[:,:,1:,:]
        # pre_feature=
 
        # frames_features=
        # projection head
        frames_features=self.pj(self.tse(self.fusion(adapter_feature+bbx)+pos,tfm_mask))
        # breakpoint()

        # total features for a consist edge cls
        human_obj_feature=frames_features[:,:,1:,:]
        # supevised by adapter feature
        cls_ans=self.obj_mlp(adapter_humam)
        # scene graph
        human_obj_feature=human_obj_feature+cls_feature
        #no scenegraph
        # human_obj_feature=human_obj_feature
        human_feature=human_obj_feature[:,:,0,:].unsqueeze(-2)
        human_features=human_feature.repeat(1,1,Nums-2,1)
        obj_feature=human_obj_feature[:,:,1:,:]
        global_feature=frames_features[:,:,0,:].unsqueeze(-2).repeat(1,1,Nums-2,1)
        edge_feature=self.edge_fun(torch.cat([human_features,global_feature,obj_feature],dim=-1))
        rel_ans=self.rel_mlp(edge_feature)
        # print('shap',human_features.shape,global_feature.shape,obj_feature.shape)

        # scene graph
        edge_feature=edge_feature+rel_feature

        human_obj_feature=self.mffn(human_obj_feature)
        task_id=torch.cat([task_id,(~task_id.bool()).float()],dim=0)
        nhuman_obj_feature=torch.cat([human_obj_feature,human_obj_feature],dim=0)
        

        # p_f=self.mffn2(self.pgpfp(human_obj_feature,task_id))
        # c_f=self.mffn3(self.cgpfp(human_obj_feature,task_id))
        pc_f=self.mffn2(self.cgpfp(nhuman_obj_feature,task_id))
        edge_feature=torch.cat([edge_feature,edge_feature],dim=0)
        
        pc_human_feature,pc_obj_feature=self.gpnn(pc_f[:,:,0,:].unsqueeze(-2),pc_f[:,:,1:,:],edge_feature,self.cgpfp,task_id,mask,tfm_mask)


        # pc
        pc_feature=torch.cat([pc_human_feature,pc_obj_feature],dim=-2)
        p_features=pc_feature[B:,:,:,:]

        c_features=pc_feature[:B,:,:,:]


        pc_ans=self.c_head(pc_feature)
        p_ans=pc_ans[B:,:]
        c_ans=pc_ans[:B,:]
        # p_ans=self.p_head(p_features)
        # rec=
        # t_node_features=
        t_node=self.gf(p_features,c_features)
        t_ans=self.cls_head(self.total_pj(t_node))
        recs=self.recs(t_node)
        

        return c_ans,p_ans,cls_ans,rel_ans,c_features,p_features,recs,human_obj_feature,t_ans
    # dual branch
    def forward7(self,frames,cls,rel,bbx_list,task_id,mask=None,tfm_mask=None):
        # nums =node+1
        B,Frame,Nums,dims=frames.shape
        # breakpoint()
        Nums=Nums
        # bbx_list=bbx_list[:,:,1:,:]
        bbx=self.bbx_linear(bbx_list)
        # breakpoint()
        cls_feature=self.cls_embed(cls)
        rel_feature=self.rel_embed(rel)
        

        pos=self.pos.repeat(B,1,Nums,1)
        adapter_feature=self.adapter(frames)
        # supervised
        adapter_humam=adapter_feature[:,:,1:,:]
        # pre_feature=
 
        # frames_features=
        # projection head
        frames_features=self.pj(self.tse(self.fusion(adapter_feature+bbx)+pos,tfm_mask))
        # breakpoint()

        # total features for a consist edge cls
        human_obj_feature=frames_features[:,:,1:,:]
        # supevised by adapter feature
        cls_ans=self.obj_mlp(adapter_humam)
        # scene graph
        human_obj_feature=human_obj_feature+cls_feature
        #no scenegraph
        # human_obj_feature=human_obj_feature
        human_feature=human_obj_feature[:,:,0,:].unsqueeze(-2)
        human_features=human_feature.repeat(1,1,Nums-2,1)
        obj_feature=human_obj_feature[:,:,1:,:]
        global_feature=frames_features[:,:,0,:].unsqueeze(-2).repeat(1,1,Nums-2,1)
        edge_feature=self.edge_fun(torch.cat([human_features,global_feature,obj_feature],dim=-1))
        rel_ans=self.rel_mlp(edge_feature)
        # print('shap',human_features.shape,global_feature.shape,obj_feature.shape)

        # scene graph
        edge_feature=edge_feature+rel_feature

        human_obj_feature=self.mffn(human_obj_feature)

        
        task_id2=(~task_id.bool()).float()
        p_f=self.mffn2(self.pgpfp(human_obj_feature,task_id2))
        c_f=self.mffn3(self.cgpfp(human_obj_feature,task_id))
        
        p_human_feature,p_obj_feature=self.gpnn(p_f[:,:,0,:].unsqueeze(-2),p_f[:,:,1:,:],edge_feature,self.pgpfp,task_id2,mask,tfm_mask)

        c_human_feature,c_obj_feature=self.gpnn(c_f[:,:,0,:].unsqueeze(-2),c_f[:,:,1:,:],edge_feature,self.cgpfp,task_id,mask,tfm_mask)
        p_features=torch.cat([p_human_feature,p_obj_feature],dim=-2)
        c_features=torch.cat([c_human_feature,c_obj_feature],dim=-2)
        c_ans=self.c_head(c_features)
        p_ans=self.p_head(p_features)
        # rec=
        # t_node_features=
        t_node=self.gf(p_features,c_features)
        # t_node=self.mergin_feature(torch.cat([p_features,c_features],dim=-1))
        # t_node=self.feature_mergin(torch.cat([p_features,c_features],dim=-1))
        t_ans=self.cls_head(self.total_pj(t_node))
        if self.lt<2:
            recs=self.recs(t_node)
        else:
            recs=[]
        return c_ans,p_ans,cls_ans,rel_ans,c_features,p_features,recs,human_obj_feature,t_ans
    # total middle
  # single branch
    def forward8(self,frames,cls,rel,bbx_list,task_id,mask=None,tfm_mask=None):
        with torch.no_grad():
            # nums =node+1
            B,Frame,Nums,dims=frames.shape
            # breakpoint()
            Nums=Nums
            # bbx_list=bbx_list[:,:,1:,:]
            bbx=self.bbx_linear(bbx_list)
            # breakpoint()
            cls_feature=self.cls_embed(cls)
            rel_feature=self.rel_embed(rel)
            

            pos=self.pos.repeat(B,1,Nums,1)
            adapter_feature=self.adapter(frames)
            # supervised
            # pre_feature=
    
            # frames_features=
            # projection head
            frames_features=self.pj(self.tse(self.fusion(adapter_feature+bbx)+pos,tfm_mask))
            # breakpoint()

            # total features for a consist edge cls
            human_obj_feature=frames_features[:,:,1:,:]
            # supevised by adapter feature
            # scene graph
            human_obj_feature=human_obj_feature+cls_feature
            #no scenegraph
            # human_obj_feature=human_obj_feature
            human_feature=human_obj_feature[:,:,0,:].unsqueeze(-2)
            human_features=human_feature.repeat(1,1,Nums-2,1)
            obj_feature=human_obj_feature[:,:,1:,:]
            global_feature=frames_features[:,:,0,:].unsqueeze(-2).repeat(1,1,Nums-2,1)
            edge_feature=self.edge_fun(torch.cat([human_features,global_feature,obj_feature],dim=-1))
    
            # print('shap',human_features.shape,global_feature.shape,obj_feature.shape)

            # scene graph
            edge_feature=edge_feature+rel_feature

            human_obj_feature=self.mffn(human_obj_feature)
            task_id=torch.cat([task_id,(~task_id.bool()).float()],dim=0)
            nhuman_obj_feature=torch.cat([human_obj_feature,human_obj_feature],dim=0)
            

            # p_f=self.mffn2(self.pgpfp(human_obj_feature,task_id))
            # c_f=self.mffn3(self.cgpfp(human_obj_feature,task_id))
            pc_f=self.mffn2(self.cgpfp(nhuman_obj_feature,task_id))
            edge_feature=torch.cat([edge_feature,edge_feature],dim=0)
            
            pc_human_feature,pc_obj_feature=self.gpnn(pc_f[:,:,0,:].unsqueeze(-2),pc_f[:,:,1:,:],edge_feature,self.cgpfp,task_id,mask,tfm_mask)


            # pc
            pc_feature=torch.cat([pc_human_feature,pc_obj_feature],dim=-2)
            p_features=pc_feature[B:,:,:,:]

            c_features=pc_feature[:B,:,:,:]

            # p_ans=self.p_head(p_features)
            # rec=
            # t_node_features=
            t_node=self.gf(p_features,c_features)
        t_ans=self.cls_head(self.total_pj(t_node))
        m_ans=self.m_head(human_obj_feature)

        

        return t_ans,m_ans
    # dual branch
    def forward9(self, frames, cls, rel, bbx_list, task_id, mask=None, tfm_mask=None):
        """Dual-branch pass: run the ``pgpfp`` and ``cgpfp`` prompt branches and fuse them.

        The ``p`` branch is prompted by ``self.pgpfp`` with the logically
        inverted ``task_id``; the ``c`` branch is prompted by ``self.cgpfp``
        with ``task_id`` itself. Both branches share the same edge features
        and the same ``self.gpnn``; their outputs are merged by ``self.gf``.

        Args:
            frames: 4-D tensor whose shape is unpacked as (B, Frame, Nums, dims).
                After encoding, slot 0 on the node axis is used as the global
                feature, slot 1 as the human node and the rest as object nodes.
            cls: input of ``self.cls_embed``; the embedding is added to the
                human/object node features.
            rel: input of ``self.rel_embed``; the embedding is added to the
                edge features.
            bbx_list: input of ``self.bbx_linear``; the result is added to the
                adapter output before fusion.
            task_id: tensor (``.bool()`` is called on it).
            mask: forwarded unchanged to ``self.gpnn``.
            tfm_mask: forwarded unchanged to ``self.tse`` and ``self.gpnn``.

        Returns:
            ``self.cls_head(self.total_pj(...))`` applied to the fused features.
        """
        batch, _, node_cnt, _ = frames.shape
        # Number of human-object pairs: all nodes minus the global slot and the human.
        pair_cnt = node_cnt - 2

        box_emb = self.bbx_linear(bbx_list)
        cls_emb = self.cls_embed(cls)
        rel_emb = self.rel_embed(rel)
        pos_emb = self.pos.repeat(batch, 1, node_cnt, 1)
        adapted = self.adapter(frames)

        # fusion -> tse encoder -> projection head
        encoded = self.pj(self.tse(self.fusion(adapted + box_emb) + pos_emb, tfm_mask))

        # Human + object nodes (global slot dropped), with the class embedding added.
        nodes = encoded[:, :, 1:, :] + cls_emb
        human_tiled = nodes[:, :, 0, :].unsqueeze(-2).repeat(1, 1, pair_cnt, 1)
        objects = nodes[:, :, 1:, :]
        global_tiled = encoded[:, :, 0, :].unsqueeze(-2).repeat(1, 1, pair_cnt, 1)

        # One edge per object: [human | global | object] -> edge_fun, plus relation embedding.
        edges = self.edge_fun(torch.cat([human_tiled, global_tiled, objects], dim=-1))
        edges = edges + rel_emb

        nodes = self.mffn(nodes)

        inverted_id = (~task_id.bool()).float()
        p_nodes = self.mffn2(self.pgpfp(nodes, inverted_id))
        c_nodes = self.mffn3(self.cgpfp(nodes, task_id))

        p_human, p_obj = self.gpnn(
            p_nodes[:, :, 0, :].unsqueeze(-2), p_nodes[:, :, 1:, :],
            edges, self.pgpfp, inverted_id, mask, tfm_mask,
        )
        c_human, c_obj = self.gpnn(
            c_nodes[:, :, 0, :].unsqueeze(-2), c_nodes[:, :, 1:, :],
            edges, self.cgpfp, task_id, mask, tfm_mask,
        )

        p_features = torch.cat([p_human, p_obj], dim=-2)
        c_features = torch.cat([c_human, c_obj], dim=-2)

        fused = self.gf(p_features, c_features)
        return self.cls_head(self.total_pj(fused))
    # total middle
  # single branch
    @torch.no_grad()
    def forward10(self, frames, cls, rel, bbx_list, task_id, mask=None, tfm_mask=None):
        """Single-branch pass with gradient tracking disabled.

        Both task-id polarities are stacked along the batch axis (``task_id``
        first, its logical inverse second) and pushed through
        ``cgpfp -> mffn2 -> gpnn`` in one call. The first ``B`` rows of the
        stacked result are treated as the ``c`` half, the last ``B`` rows as
        the ``p`` half.

        Args:
            frames: 4-D tensor whose shape is unpacked as (B, Frame, Nums, dims).
            cls: input of ``self.cls_embed`` (added to node features).
            rel: input of ``self.rel_embed`` (added to edge features).
            bbx_list: input of ``self.bbx_linear`` (added to the adapter output).
            task_id: tensor; concatenated with ``(~task_id.bool()).float()``.
            mask: forwarded unchanged to ``self.gpnn``.
            tfm_mask: forwarded unchanged to ``self.tse`` and ``self.gpnn``.

        Returns:
            Tuple ``(c_ans, p_ans, t_ans, m_ans)``: the two halves of
            ``self.c_head`` on the stacked features, the fused classification
            from ``self.cls_head`` and ``self.m_head`` on the post-``mffn``
            node features.
        """
        batch, _, node_cnt, _ = frames.shape
        # Number of human-object pairs: all nodes minus the global slot and the human.
        pair_cnt = node_cnt - 2

        box_emb = self.bbx_linear(bbx_list)
        cls_emb = self.cls_embed(cls)
        rel_emb = self.rel_embed(rel)
        pos_emb = self.pos.repeat(batch, 1, node_cnt, 1)
        adapted = self.adapter(frames)

        # fusion -> tse encoder -> projection head
        encoded = self.pj(self.tse(self.fusion(adapted + box_emb) + pos_emb, tfm_mask))

        # Human + object nodes (global slot dropped), with the class embedding added.
        nodes = encoded[:, :, 1:, :] + cls_emb
        human_tiled = nodes[:, :, 0, :].unsqueeze(-2).repeat(1, 1, pair_cnt, 1)
        objects = nodes[:, :, 1:, :]
        global_tiled = encoded[:, :, 0, :].unsqueeze(-2).repeat(1, 1, pair_cnt, 1)

        # One edge per object: [human | global | object] -> edge_fun, plus relation embedding.
        edges = self.edge_fun(torch.cat([human_tiled, global_tiled, objects], dim=-1))
        edges = edges + rel_emb

        nodes = self.mffn(nodes)

        # Duplicate the batch so one gpnn call serves both task-id polarities.
        stacked_ids = torch.cat([task_id, (~task_id.bool()).float()], dim=0)
        stacked_nodes = torch.cat([nodes, nodes], dim=0)
        prompted = self.mffn2(self.cgpfp(stacked_nodes, stacked_ids))
        stacked_edges = torch.cat([edges, edges], dim=0)

        human_out, obj_out = self.gpnn(
            prompted[:, :, 0, :].unsqueeze(-2), prompted[:, :, 1:, :],
            stacked_edges, self.cgpfp, stacked_ids, mask, tfm_mask,
        )
        stacked_out = torch.cat([human_out, obj_out], dim=-2)

        # Second half used the inverted ids (p), first half the given ids (c).
        p_features = stacked_out[batch:, :, :, :]
        c_features = stacked_out[:batch, :, :, :]

        stacked_ans = self.c_head(stacked_out)
        p_ans = stacked_ans[batch:, :]
        c_ans = stacked_ans[:batch, :]

        fused = self.gf(p_features, c_features)
        t_ans = self.cls_head(self.total_pj(fused))
        m_ans = self.m_head(nodes)

        return c_ans, p_ans, t_ans, m_ans
    # dual branch
    @torch.no_grad()
    def forward11(self, frames, cls, rel, bbx_list, task_id, mask=None, tfm_mask=None):
        """Dual-branch pass with gradient tracking disabled.

        The ``p`` branch goes through ``pgpfp -> mffn2 -> gpnn`` with the
        logically inverted ``task_id``; the ``c`` branch goes through
        ``cgpfp -> mffn3 -> gpnn`` with ``task_id`` itself. ``self.gf`` merges
        the two branches before classification.

        Args:
            frames: 4-D tensor whose shape is unpacked as (B, Frame, Nums, dims).
            cls: input of ``self.cls_embed`` (added to node features).
            rel: input of ``self.rel_embed`` (added to edge features).
            bbx_list: input of ``self.bbx_linear`` (added to the adapter output).
            task_id: tensor (``.bool()`` is called on it).
            mask: forwarded unchanged to ``self.gpnn``.
            tfm_mask: forwarded unchanged to ``self.tse`` and ``self.gpnn``.

        Returns:
            Tuple ``(t_ans, m_ans)``: the fused classification from
            ``self.cls_head`` and ``self.m_head`` applied to the human node
            feature taken before ``self.mffn``.
        """
        batch, _, node_cnt, _ = frames.shape
        # Number of human-object pairs: all nodes minus the global slot and the human.
        pair_cnt = node_cnt - 2

        box_emb = self.bbx_linear(bbx_list)
        cls_emb = self.cls_embed(cls)
        rel_emb = self.rel_embed(rel)
        pos_emb = self.pos.repeat(batch, 1, node_cnt, 1)
        adapted = self.adapter(frames)

        # fusion -> tse encoder -> projection head
        encoded = self.pj(self.tse(self.fusion(adapted + box_emb) + pos_emb, tfm_mask))

        # Human + object nodes (global slot dropped), with the class embedding added.
        nodes = encoded[:, :, 1:, :] + cls_emb
        human = nodes[:, :, 0, :].unsqueeze(-2)
        human_tiled = human.repeat(1, 1, pair_cnt, 1)
        objects = nodes[:, :, 1:, :]
        global_tiled = encoded[:, :, 0, :].unsqueeze(-2).repeat(1, 1, pair_cnt, 1)

        # One edge per object: [human | global | object] -> edge_fun, plus relation embedding.
        edges = self.edge_fun(torch.cat([human_tiled, global_tiled, objects], dim=-1))
        edges = edges + rel_emb

        nodes = self.mffn(nodes)

        inverted_id = (~task_id.bool()).float()
        p_nodes = self.mffn2(self.pgpfp(nodes, inverted_id))
        c_nodes = self.mffn3(self.cgpfp(nodes, task_id))

        p_human, p_obj = self.gpnn(
            p_nodes[:, :, 0, :].unsqueeze(-2), p_nodes[:, :, 1:, :],
            edges, self.pgpfp, inverted_id, mask, tfm_mask,
        )
        c_human, c_obj = self.gpnn(
            c_nodes[:, :, 0, :].unsqueeze(-2), c_nodes[:, :, 1:, :],
            edges, self.cgpfp, task_id, mask, tfm_mask,
        )
        p_features = torch.cat([p_human, p_obj], dim=-2)
        c_features = torch.cat([c_human, c_obj], dim=-2)

        fused = self.gf(p_features, c_features)
        t_ans = self.cls_head(self.total_pj(fused))
        # NOTE(review): m_head receives the pre-mffn human slice here, whereas
        # forward10 feeds it the post-mffn human+object features — confirm
        # this difference is intended.
        m_ans = self.m_head(human)
        return t_ans, m_ans
    # ground truth test
    def forward12(self, frames, cls, rel, bbx_list, task_id, mask=None, tfm_mask=None):
        """Stacked single-branch pass with caller-supplied task ids and auxiliary heads.

        Unlike the variants that invert ``task_id`` themselves, this one takes
        ``task_id`` as a sequence of tensors and concatenates them along dim 0.
        The node and edge features are duplicated along the batch axis to
        match, then pushed through ``cgpfp -> mffn2 -> gpnn`` once.

        Args:
            frames: 4-D tensor whose shape is unpacked as (B, Frame, Nums, dims).
            cls: input of ``self.cls_embed`` (added to node features).
            rel: input of ``self.rel_embed`` (added to edge features).
            bbx_list: input of ``self.bbx_linear`` (added to the adapter output).
            task_id: sequence of tensors passed to ``torch.cat(..., dim=0)``.
                Presumably two per-sample id tensors, since the features are
                stacked twice — confirm against the caller.
            mask: forwarded unchanged to ``self.gpnn``.
            tfm_mask: forwarded unchanged to ``self.tse`` and ``self.gpnn``.

        Returns:
            A 9-tuple, in order: first and second halves of ``self.c_head`` on
            the stacked features; ``self.obj_mlp`` on the adapter output
            (global slot dropped); ``self.rel_mlp`` on the edge features
            before the relation embedding is added; first and second halves of
            the stacked features; ``self.recs`` on the fused features; the
            post-``mffn`` node features; ``self.cls_head`` on the fused
            features.
        """
        batch, _, node_cnt, _ = frames.shape
        # Number of human-object pairs: all nodes minus the global slot and the human.
        pair_cnt = node_cnt - 2

        box_emb = self.bbx_linear(bbx_list)
        cls_emb = self.cls_embed(cls)
        rel_emb = self.rel_embed(rel)
        pos_emb = self.pos.repeat(batch, 1, node_cnt, 1)
        adapted = self.adapter(frames)
        adapted_nodes = adapted[:, :, 1:, :]

        # fusion -> tse encoder -> projection head
        encoded = self.pj(self.tse(self.fusion(adapted + box_emb) + pos_emb, tfm_mask))

        # Auxiliary object-class prediction straight from the adapter output.
        cls_ans = self.obj_mlp(adapted_nodes)

        # Human + object nodes (global slot dropped), with the class embedding added.
        nodes = encoded[:, :, 1:, :] + cls_emb
        human_tiled = nodes[:, :, 0, :].unsqueeze(-2).repeat(1, 1, pair_cnt, 1)
        objects = nodes[:, :, 1:, :]
        global_tiled = encoded[:, :, 0, :].unsqueeze(-2).repeat(1, 1, pair_cnt, 1)

        # One edge per object: [human | global | object] -> edge_fun.
        raw_edges = self.edge_fun(torch.cat([human_tiled, global_tiled, objects], dim=-1))
        # Auxiliary relation prediction is taken before the relation embedding is mixed in.
        rel_ans = self.rel_mlp(raw_edges)
        edges = raw_edges + rel_emb

        nodes = self.mffn(nodes)

        stacked_ids = torch.cat(task_id, dim=0)
        stacked_nodes = torch.cat([nodes, nodes], dim=0)
        prompted = self.mffn2(self.cgpfp(stacked_nodes, stacked_ids))
        stacked_edges = torch.cat([edges, edges], dim=0)

        human_out, obj_out = self.gpnn(
            prompted[:, :, 0, :].unsqueeze(-2), prompted[:, :, 1:, :],
            stacked_edges, self.cgpfp, stacked_ids, mask, tfm_mask,
        )
        stacked_out = torch.cat([human_out, obj_out], dim=-2)

        stacked_ans = self.c_head(stacked_out)

        first_half = stacked_out[:batch, :, :, :]
        second_half = stacked_out[batch:, :, :, :]
        fused = self.gf(first_half, second_half)
        t_ans = self.cls_head(self.total_pj(fused))
        recs = self.recs(fused)

        return (
            stacked_ans[:batch, :],
            stacked_ans[batch:, :],
            cls_ans,
            rel_ans,
            first_half,
            second_half,
            recs,
            nodes,
            t_ans,
        )
    # no prompt test
    def forward13(self, frames, cls, rel, bbx_list, task_id, mask=None, tfm_mask=None):
        """Prompt-free pass: no prompt module is applied and ``None`` is given to ``gpnn``.

        The node features go straight through ``mffn -> mffn2 -> gpnn`` and the
        result is classified without any branch fusion.

        Args:
            frames: 4-D tensor whose shape is unpacked as (B, Frame, Nums, dims).
            cls: input of ``self.cls_embed`` (added to node features).
            rel: input of ``self.rel_embed`` (added to edge features).
            bbx_list: input of ``self.bbx_linear`` (added to the adapter output).
            task_id: forwarded unchanged to ``self.gpnn``.
            mask: forwarded unchanged to ``self.gpnn``.
            tfm_mask: forwarded unchanged to ``self.tse`` and ``self.gpnn``.

        Returns:
            Tuple ``(cls_ans, rel_ans, t_ans)``: ``self.obj_mlp`` on the adapter
            output (global slot dropped), ``self.rel_mlp`` on the edge features
            before the relation embedding is added, and ``self.cls_head`` on
            the projected graph output.
        """
        batch, _, node_cnt, _ = frames.shape
        # Number of human-object pairs: all nodes minus the global slot and the human.
        pair_cnt = node_cnt - 2

        box_emb = self.bbx_linear(bbx_list)
        cls_emb = self.cls_embed(cls)
        rel_emb = self.rel_embed(rel)
        pos_emb = self.pos.repeat(batch, 1, node_cnt, 1)
        adapted = self.adapter(frames)
        adapted_nodes = adapted[:, :, 1:, :]

        # fusion -> tse encoder -> projection head
        encoded = self.pj(self.tse(self.fusion(adapted + box_emb) + pos_emb, tfm_mask))

        # Auxiliary object-class prediction straight from the adapter output.
        cls_ans = self.obj_mlp(adapted_nodes)

        # Human + object nodes (global slot dropped), with the class embedding added.
        nodes = encoded[:, :, 1:, :] + cls_emb
        human_tiled = nodes[:, :, 0, :].unsqueeze(-2).repeat(1, 1, pair_cnt, 1)
        objects = nodes[:, :, 1:, :]
        global_tiled = encoded[:, :, 0, :].unsqueeze(-2).repeat(1, 1, pair_cnt, 1)

        # One edge per object: [human | global | object] -> edge_fun.
        raw_edges = self.edge_fun(torch.cat([human_tiled, global_tiled, objects], dim=-1))
        # Auxiliary relation prediction is taken before the relation embedding is mixed in.
        rel_ans = self.rel_mlp(raw_edges)
        edges = raw_edges + rel_emb

        projected = self.mffn2(self.mffn(nodes))

        # No prompt module in this variant, hence the explicit None.
        human_out, obj_out = self.gpnn(
            projected[:, :, 0, :].unsqueeze(-2), projected[:, :, 1:, :],
            edges, None, task_id, mask, tfm_mask,
        )
        graph_out = torch.cat([human_out, obj_out], dim=-2)

        t_ans = self.cls_head(self.total_pj(graph_out))
        return cls_ans, rel_ans, t_ans
    # add [0.,0.,1.,1.] to the first line of every batch 
    def forward(self,frames,cls,rel,bbx_list,task_id,mask=None,tfm_mask=None):
        """Dispatch to the stage-specific forward pass selected by ``self.stage``.

        ``self.stage`` is 1-based: stage ``k`` runs ``self.forward<k>`` for
        ``k`` in 1..13. All arguments are passed through positionally and the
        selected method's return value is returned unchanged, so the return
        type depends on the stage.

        Raises:
            IndexError: if ``self.stage`` is outside ``1..13``.
        """
        forwards=(
            self.forward1,
            self.forward2,
            self.forward3,
            self.forward4,
            self.forward5,
            self.forward6,
            self.forward7,
            self.forward8,
            self.forward9,
            self.forward10,
            self.forward11,
            self.forward12,
            self.forward13,
        )
        stage=self.stage
        # Without this check a stage of 0 (or a small negative value) would
        # wrap around through negative indexing and silently run the wrong
        # variant; too-large stages already raised IndexError.
        if not 1<=stage<=len(forwards):
            raise IndexError(f'stage must be in 1..{len(forwards)}, got {stage!r}')
        return forwards[stage-1](frames,cls,rel,bbx_list,task_id,mask,tfm_mask)


if __name__ == '__main__':
    # Manual smoke test: build the model on one GPU and push random input through it.
    # NOTE(review): the call below passes two arguments and unpacks four results,
    # which does not match the forward(frames, cls, rel, bbx_list, task_id, ...)
    # dispatcher in this file — if MyModel is that class, this script looks
    # stale; confirm before relying on it.
    device = 'cuda:0'
    config = load_config()

    sample = process_data(torch.randn(5, 16, 10, 768))
    X = sample.x.to(device)
    edge = sample.edge_index.to(device)

    model = MyModel(config)
    model.to(device)

    outputs = model(X, edge)
    private, common, atom, ans = outputs

    # log2 of the std over the last axis of the first `private` entry.
    entro = torch.log2(torch.std(private[0], -1))
    print(ans.shape, entro.shape)
