# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from unicore import utils
from unicore.data import Dictionary
from unicore.models import (BaseUnicoreModel, register_model,
                            register_model_architecture)
from unicore.modules import LayerNorm
import unicore


from .transformer_encoder_with_pair import TransformerEncoderWithPair
from .unimol import NonLinearHead, UniMolModel, base_architecture

logger = logging.getLogger(__name__)


@register_model("binding_affinity")
class BindingAffinityModel(BaseUnicoreModel):
    """Two-tower pocket/molecule model producing in-batch similarity logits.

    A molecule ``UniMolModel`` and a pocket ``UniMolModel`` encode their
    inputs independently.  The representation at sequence position 0 of each
    encoder output is passed through a ``NonLinearHead``, L2-normalised, and
    every pocket of the batch is compared with every molecule of the batch by
    dot product.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--mol-pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pocket-pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pocket-encoder-layers",
            type=int,
            help="pocket encoder layers",
        )
        parser.add_argument(
            "--recycling",
            type=int,
            default=1,
            help="recycling nums of decoder",
        )

    def __init__(self, args, mol_dictionary, pocket_dictionary):
        """Build both encoders and the projection heads.

        Args:
            args: model arguments; ``args.mol`` and ``args.pocket`` are read
                after ``unimol_docking_architecture(args)`` has been applied.
            mol_dictionary: dictionary handed to the molecule ``UniMolModel``.
            pocket_dictionary: dictionary handed to the pocket ``UniMolModel``.
        """
        super().__init__()
        unimol_docking_architecture(args)

        self.args = args
        self.mol_model = UniMolModel(args.mol, mol_dictionary)
        self.pocket_model = UniMolModel(args.pocket, pocket_dictionary)

        # NOTE(review): the four heads cross_distance_project,
        # holo_distance_project, fuse_project and classification_head are not
        # referenced by forward() in this class; they are kept so parameter
        # names and creation order stay unchanged (presumably for checkpoint
        # compatibility - confirm before removing).
        self.cross_distance_project = NonLinearHead(
            args.mol.encoder_embed_dim * 2 + args.mol.encoder_attention_heads, 1, "relu"
        )
        self.holo_distance_project = DistanceHead(
            args.mol.encoder_embed_dim + args.mol.encoder_attention_heads, "relu"
        )

        self.mol_project = NonLinearHead(args.mol.encoder_embed_dim, 128, "relu")

        # Learnable log-scale of the logits.  Created on the default device
        # (not hard-coded to CUDA) so it follows the module through
        # ``.to()`` / ``.cuda()`` / ``.half()`` like every other parameter.
        self.logit_scale = nn.Parameter(torch.ones([1]) * np.log(14))
        logger.info(
            "logit_scale=%s requires_grad=%s is_leaf=%s",
            self.logit_scale,
            self.logit_scale.requires_grad,
            self.logit_scale.is_leaf,
        )

        self.pocket_project = NonLinearHead(args.pocket.encoder_embed_dim, 128, "relu")

        self.fuse_project = NonLinearHead(256, 1, "relu")
        # NOTE(review): the input width adds the pocket embed dim twice;
        # presumably mol + pocket was meant (identical when both encoders
        # share a width) - confirm.
        self.classification_head = nn.Sequential(
            nn.Linear(args.pocket.encoder_embed_dim + args.pocket.encoder_embed_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

        # Default so get_num_updates() works before the first trainer update.
        self._num_updates = 0

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        return cls(args, task.dictionary, task.pocket_dictionary)

    @staticmethod
    def _pair_attn_bias(model, dist, edge_type):
        """Turn pair distances and edge types into the encoder attention bias.

        Applies ``model.gbf`` followed by ``model.gbf_proj``, moves the last
        axis of the result to position 1 and merges the first two axes, giving
        a tensor viewed as ``(-1, n_node, n_node)``.
        """
        n_node = dist.size(-1)
        bias = model.gbf_proj(model.gbf(dist, edge_type))
        bias = bias.permute(0, 3, 1, 2).contiguous()
        return bias.view(-1, n_node, n_node)

    @classmethod
    def _encode_tokens(cls, model, src_tokens, src_distance, src_edge_type):
        """Run ``model``'s encoder and return the first element of its output."""
        padding_mask = src_tokens.eq(model.padding_idx)
        token_emb = model.embed_tokens(src_tokens)
        attn_bias = cls._pair_attn_bias(model, src_distance, src_edge_type)
        outputs = model.encoder(
            token_emb, padding_mask=padding_mask, attn_mask=attn_bias
        )
        return outputs[0]

    @staticmethod
    def _unit_normalize(emb):
        """Divide each row of ``emb`` by its L2 norm."""
        return emb / emb.norm(dim=1, keepdim=True)

    @staticmethod
    def _duplicate_matrix(names, dtype, device):
        """Return an ``(n, n)`` tensor with 1 where ``names[i] == names[j]``."""
        names = np.array(names, dtype=str)
        same = names[:, None] == names[None, :]
        return torch.tensor(same, dtype=dtype, device=device)

    def forward(
        self,
        mol_src_tokens,
        mol_src_distance,
        mol_src_edge_type,
        pocket_src_tokens,
        pocket_src_distance,
        pocket_src_edge_type,
        smi_list=None,
        pocket_list=None,
        encode=False,
        masked_tokens=None,
        features_only=True,
        is_train=True,
        **kwargs
    ):
        """Score every pocket of the batch against every molecule of the batch.

        Args:
            mol_src_tokens, mol_src_distance, mol_src_edge_type: inputs of the
                molecule encoder.
            pocket_src_tokens, pocket_src_distance, pocket_src_edge_type:
                inputs of the pocket encoder.
            smi_list: one molecule name per batch item; equal names mark
                duplicate molecules.  Required despite the ``None`` default.
            pocket_list: one pocket name per batch item; equal names mark
                duplicate pockets.  Required despite the ``None`` default.
            encode, masked_tokens, features_only, is_train, **kwargs:
                accepted for interface compatibility; not read here.

        Returns:
            ``(ba_predict, scale)``: ``ba_predict[i, j]`` is the scaled dot
            product of pocket ``i`` and molecule ``j`` with duplicate
            off-diagonal pairs pushed down by ``-1e6``; ``scale`` is
            ``exp(logit_scale)`` (attached to the autograd graph).
        """
        mol_encoder_rep = self._encode_tokens(
            self.mol_model, mol_src_tokens, mol_src_distance, mol_src_edge_type
        )
        pocket_encoder_rep = self._encode_tokens(
            self.pocket_model,
            pocket_src_tokens,
            pocket_src_distance,
            pocket_src_edge_type,
        )

        # Position 0 of each sequence is used as the pooled representation.
        mol_emb = self._unit_normalize(self.mol_project(mol_encoder_rep[:, 0, :]))
        pocket_emb = self._unit_normalize(
            self.pocket_project(pocket_encoder_rep[:, 0, :])
        )

        ba_predict = torch.matmul(pocket_emb, torch.transpose(mol_emb, 0, 1))
        bsz = ba_predict.shape[0]
        device = ba_predict.device

        pocket_duplicate_matrix = self._duplicate_matrix(
            pocket_list, ba_predict.dtype, device
        )
        mol_duplicate_matrix = self._duplicate_matrix(
            smi_list, ba_predict.dtype, device
        )

        # Both duplicate matrices are 1 on the diagonal, so subtracting
        # 2 * identity leaves 0 for the true pair (i, i) and a positive count
        # for every off-diagonal pair sharing a pocket name and/or a molecule
        # name.  The identity keeps torch's default dtype, as before.
        onehot_labels = torch.eye(bsz, device=device)
        indicater_matrix = (
            pocket_duplicate_matrix + mol_duplicate_matrix - 2 * onehot_labels
        )

        # The scale is detached here, so the logits send no gradient to
        # logit_scale; the attached value is returned separately.
        ba_predict = ba_predict * self.logit_scale.exp().detach()
        ba_predict = indicater_matrix * -1e6 + ba_predict

        return ba_predict, self.logit_scale.exp()

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""

        self._num_updates = num_updates

    def get_num_updates(self):
        """Return the update count last stored by ``set_num_updates``."""
        return self._num_updates



@register_model("binding_affinity_ns")
class BindingAffinityNSModel(BaseUnicoreModel):
    """Two-tower pocket/molecule model with extra molecules as negatives.

    In addition to the molecules of the current batch, ``forward`` encodes a
    batch of molecules drawn from ``self.global_data`` and appends them as
    additional columns of the pocket-vs-molecule logit matrix.
    ``self.global_data`` is ``None`` after construction and must be assigned
    by external code before ``forward`` is called.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--mol-pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pocket-pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pocket-encoder-layers",
            type=int,
            help="pocket encoder layers",
        )
        parser.add_argument(
            "--recycling",
            type=int,
            default=1,
            help="recycling nums of decoder",
        )

    def __init__(self, args, mol_dictionary, pocket_dictionary):
        """Build both encoders and the projection heads.

        Args:
            args: model arguments; ``args.mol`` and ``args.pocket`` are read
                after ``unimol_docking_architecture(args)`` has been applied.
            mol_dictionary: dictionary handed to the molecule ``UniMolModel``.
            pocket_dictionary: dictionary handed to the pocket ``UniMolModel``.
        """
        super().__init__()
        unimol_docking_architecture(args)

        self.args = args
        self.mol_model = UniMolModel(args.mol, mol_dictionary)
        self.pocket_model = UniMolModel(args.pocket, pocket_dictionary)

        # NOTE(review): cross_distance_project, holo_distance_project,
        # fuse_project and classification_head are not referenced by the
        # methods of this class; they are kept so parameter names and creation
        # order stay unchanged (presumably for checkpoint compatibility).
        self.cross_distance_project = NonLinearHead(
            args.mol.encoder_embed_dim * 2 + args.mol.encoder_attention_heads, 1, "relu"
        )
        self.holo_distance_project = DistanceHead(
            args.mol.encoder_embed_dim + args.mol.encoder_attention_heads, "relu"
        )

        self.mol_project = NonLinearHead(args.mol.encoder_embed_dim, 128, "relu")

        # Learnable log-scale of the logits.  Created on the default device
        # (not hard-coded to CUDA) so it follows the module through
        # ``.to()`` / ``.cuda()`` / ``.half()`` like every other parameter.
        self.logit_scale = nn.Parameter(torch.ones([1]) * np.log(1 / 0.07))
        logger.info(
            "logit_scale=%s requires_grad=%s is_leaf=%s",
            self.logit_scale,
            self.logit_scale.requires_grad,
            self.logit_scale.is_leaf,
        )

        # Both are expected to be assigned from outside the model;
        # forward() reads global_data, global_dataset is not read in this class.
        self.global_dataset = None
        self.global_data = None

        self.pocket_project = NonLinearHead(args.pocket.encoder_embed_dim, 128, "relu")

        self.fuse_project = NonLinearHead(256, 1, "relu")
        # NOTE(review): the input width adds the pocket embed dim twice;
        # presumably mol + pocket was meant - confirm.
        self.classification_head = nn.Sequential(
            nn.Linear(args.pocket.encoder_embed_dim + args.pocket.encoder_embed_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

        # Default so get_num_updates() works before the first trainer update.
        self._num_updates = 0

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        return cls(args, task.dictionary, task.pocket_dictionary)

    @staticmethod
    def _pair_attn_bias(model, dist, edge_type):
        """Turn pair distances and edge types into the encoder attention bias.

        Applies ``model.gbf`` followed by ``model.gbf_proj``, moves the last
        axis of the result to position 1 and merges the first two axes, giving
        a tensor viewed as ``(-1, n_node, n_node)``.
        """
        n_node = dist.size(-1)
        bias = model.gbf_proj(model.gbf(dist, edge_type))
        bias = bias.permute(0, 3, 1, 2).contiguous()
        return bias.view(-1, n_node, n_node)

    @classmethod
    def _encode_tokens(cls, model, src_tokens, src_distance, src_edge_type):
        """Run ``model``'s encoder and return the first element of its output."""
        padding_mask = src_tokens.eq(model.padding_idx)
        token_emb = model.embed_tokens(src_tokens)
        attn_bias = cls._pair_attn_bias(model, src_distance, src_edge_type)
        outputs = model.encoder(
            token_emb, padding_mask=padding_mask, attn_mask=attn_bias
        )
        return outputs[0]

    @staticmethod
    def _unit_normalize(emb):
        """Divide each row of ``emb`` by its L2 norm."""
        return emb / emb.norm(dim=1, keepdim=True)

    def forward_once(
        self,
        mol_src_tokens,
        mol_src_distance,
        mol_src_edge_type,
        pocket_src_tokens,
        pocket_src_distance,
        pocket_src_edge_type,
        smi_list,
        pocket_list,
        mol_model,
        pocket_model,
        **kwargs
    ):
        """Encode molecules and pockets with the supplied encoder models.

        Unlike ``forward`` this uses the ``mol_model`` / ``pocket_model``
        passed by the caller rather than ``self.mol_model`` /
        ``self.pocket_model``.  The previous implementation additionally
        computed projected embeddings and masked logits and then discarded
        them; that dead computation has been removed.

        Args:
            mol_src_tokens, mol_src_distance, mol_src_edge_type: inputs of
                ``mol_model``.
            pocket_src_tokens, pocket_src_distance, pocket_src_edge_type:
                inputs of ``pocket_model``.
            smi_list, pocket_list, **kwargs: accepted for interface
                compatibility; not read.
            mol_model: encoder model applied to the molecule inputs.
            pocket_model: encoder model applied to the pocket inputs.

        Returns:
            ``(mol_encoder_rep, pocket_encoder_rep)``, the first output of
            each encoder.
        """
        mol_encoder_rep = self._encode_tokens(
            mol_model, mol_src_tokens, mol_src_distance, mol_src_edge_type
        )
        pocket_encoder_rep = self._encode_tokens(
            pocket_model, pocket_src_tokens, pocket_src_distance, pocket_src_edge_type
        )
        return mol_encoder_rep, pocket_encoder_rep

    def forward(
        self,
        mol_src_tokens,
        mol_src_distance,
        mol_src_edge_type,
        pocket_src_tokens,
        pocket_src_distance,
        pocket_src_edge_type,
        smi_list=None,
        pocket_list=None,
        encode=False,
        masked_tokens=None,
        features_only=True,
        is_train=True,
        **kwargs
    ):
        """Score batch pockets against batch molecules plus global molecules.

        Args:
            mol_src_tokens, mol_src_distance, mol_src_edge_type: inputs of the
                molecule encoder.
            pocket_src_tokens, pocket_src_distance, pocket_src_edge_type:
                inputs of the pocket encoder.
            smi_list, pocket_list, encode, masked_tokens, features_only,
            is_train, **kwargs: accepted for interface compatibility; not
                read here.

        Returns:
            ``(ba_predict, scale)``: ``ba_predict`` holds one row per pocket
            and one column per molecule, batch molecules first and global
            molecules after them; ``scale`` is ``exp(logit_scale)`` (attached
            to the autograd graph).

        Raises:
            RuntimeError: if ``self.global_data`` has not been assigned.
        """
        if self.global_data is None:
            raise RuntimeError(
                "BindingAffinityNSModel.global_data must be assigned before "
                "forward() is called"
            )

        mol_encoder_rep = self._encode_tokens(
            self.mol_model, mol_src_tokens, mol_src_distance, mol_src_edge_type
        )
        pocket_encoder_rep = self._encode_tokens(
            self.pocket_model,
            pocket_src_tokens,
            pocket_src_distance,
            pocket_src_edge_type,
        )

        # Position 0 of each sequence is used as the pooled representation.
        mol_emb = self._unit_normalize(self.mol_project(mol_encoder_rep[:, 0, :]))
        pocket_emb = self._unit_normalize(
            self.pocket_project(pocket_encoder_rep[:, 0, :])
        )

        # NOTE(review): a fresh iterator is created on every call, so this
        # always takes the first item global_data yields for a new iteration;
        # whether that is a different batch each time depends on how
        # global_data was built (e.g. shuffling) - confirm.
        global_sample = next(iter(self.global_data))
        global_sample = unicore.utils.move_to_cuda(global_sample)
        global_input = global_sample["net_input"]

        global_mol_encoder_rep = self._encode_tokens(
            self.mol_model,
            global_input["mol_src_tokens"],
            global_input["mol_src_distance"],
            global_input["mol_src_edge_type"],
        )
        global_mol_emb = self._unit_normalize(
            self.mol_project(global_mol_encoder_rep[:, 0, :])
        )

        cat_mol_emb = torch.cat([mol_emb, global_mol_emb], dim=0)

        ba_predict = torch.mm(pocket_emb, cat_mol_emb.transpose(0, 1))
        # The scale is detached here, so the logits send no gradient to
        # logit_scale; the attached value is returned separately.
        ba_predict = ba_predict * self.logit_scale.exp().detach()

        return ba_predict, self.logit_scale.exp()

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""

        self._num_updates = num_updates

    def get_num_updates(self):
        """Return the update count last stored by ``set_num_updates``."""
        return self._num_updates


@register_model("binding_affinity_hns")
class BindingAffinityHNSModel(BaseUnicoreModel):
    """Two-tower pocket/molecule model with per-pocket hard negatives.

    ``forward`` returns the in-batch pocket-vs-molecule logits and, in
    training mode, a two-column matrix pairing each pocket's hard-negative
    score with its true-pair score.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--mol-pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pocket-pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pocket-encoder-layers",
            type=int,
            help="pocket encoder layers",
        )
        parser.add_argument(
            "--recycling",
            type=int,
            default=1,
            help="recycling nums of decoder",
        )

    def __init__(self, args, mol_dictionary, pocket_dictionary):
        """Build both encoders and the projection heads.

        Args:
            args: model arguments; ``args.mol`` and ``args.pocket`` are read
                after ``unimol_docking_architecture(args)`` has been applied.
            mol_dictionary: dictionary handed to the molecule ``UniMolModel``.
            pocket_dictionary: dictionary handed to the pocket ``UniMolModel``.
        """
        super().__init__()
        unimol_docking_architecture(args)

        self.args = args
        self.mol_model = UniMolModel(args.mol, mol_dictionary)
        self.pocket_model = UniMolModel(args.pocket, pocket_dictionary)

        # NOTE(review): cross_distance_project, holo_distance_project,
        # fuse_project and classification_head are not referenced by the
        # methods of this class; they are kept so parameter names and creation
        # order stay unchanged (presumably for checkpoint compatibility).
        self.cross_distance_project = NonLinearHead(
            args.mol.encoder_embed_dim * 2 + args.mol.encoder_attention_heads, 1, "relu"
        )
        self.holo_distance_project = DistanceHead(
            args.mol.encoder_embed_dim + args.mol.encoder_attention_heads, "relu"
        )

        self.mol_project = NonLinearHead(args.mol.encoder_embed_dim, 128, "relu")

        # Learnable log-scale of the logits.  Created on the default device
        # (not hard-coded to CUDA) so it follows the module through
        # ``.to()`` / ``.cuda()`` / ``.half()`` like every other parameter.
        self.logit_scale = nn.Parameter(torch.ones([1]) * np.log(1 / 0.07))
        logger.info(
            "logit_scale=%s requires_grad=%s is_leaf=%s",
            self.logit_scale,
            self.logit_scale.requires_grad,
            self.logit_scale.is_leaf,
        )

        # Neither attribute is read by the methods of this class.
        self.global_dataset = None
        self.global_data = None

        self.pocket_project = NonLinearHead(args.pocket.encoder_embed_dim, 128, "relu")

        self.fuse_project = NonLinearHead(256, 1, "relu")
        # NOTE(review): the input width adds the pocket embed dim twice;
        # presumably mol + pocket was meant - confirm.
        self.classification_head = nn.Sequential(
            nn.Linear(args.pocket.encoder_embed_dim + args.pocket.encoder_embed_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

        # Default so get_num_updates() works before the first trainer update.
        self._num_updates = 0

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        return cls(args, task.dictionary, task.pocket_dictionary)

    @staticmethod
    def _pair_attn_bias(model, dist, edge_type):
        """Turn pair distances and edge types into the encoder attention bias.

        Applies ``model.gbf`` followed by ``model.gbf_proj``, moves the last
        axis of the result to position 1 and merges the first two axes, giving
        a tensor viewed as ``(-1, n_node, n_node)``.
        """
        n_node = dist.size(-1)
        bias = model.gbf_proj(model.gbf(dist, edge_type))
        bias = bias.permute(0, 3, 1, 2).contiguous()
        return bias.view(-1, n_node, n_node)

    @classmethod
    def _encode_tokens(cls, model, src_tokens, src_distance, src_edge_type):
        """Run ``model``'s encoder and return the first element of its output."""
        padding_mask = src_tokens.eq(model.padding_idx)
        token_emb = model.embed_tokens(src_tokens)
        attn_bias = cls._pair_attn_bias(model, src_distance, src_edge_type)
        outputs = model.encoder(
            token_emb, padding_mask=padding_mask, attn_mask=attn_bias
        )
        return outputs[0]

    @staticmethod
    def _unit_normalize(emb):
        """Divide each row of ``emb`` by its L2 norm."""
        return emb / emb.norm(dim=1, keepdim=True)

    def forward_once(
        self,
        mol_src_tokens,
        mol_src_distance,
        mol_src_edge_type,
        pocket_src_tokens,
        pocket_src_distance,
        pocket_src_edge_type,
        smi_list,
        pocket_list,
        mol_model,
        pocket_model,
        **kwargs
    ):
        """Encode molecules and pockets with the supplied encoder models.

        Unlike ``forward`` this uses the ``mol_model`` / ``pocket_model``
        passed by the caller rather than ``self.mol_model`` /
        ``self.pocket_model``.  The previous implementation additionally
        computed projected embeddings and masked logits and then discarded
        them; that dead computation has been removed.

        Args:
            mol_src_tokens, mol_src_distance, mol_src_edge_type: inputs of
                ``mol_model``.
            pocket_src_tokens, pocket_src_distance, pocket_src_edge_type:
                inputs of ``pocket_model``.
            smi_list, pocket_list, **kwargs: accepted for interface
                compatibility; not read.
            mol_model: encoder model applied to the molecule inputs.
            pocket_model: encoder model applied to the pocket inputs.

        Returns:
            ``(mol_encoder_rep, pocket_encoder_rep)``, the first output of
            each encoder.
        """
        mol_encoder_rep = self._encode_tokens(
            mol_model, mol_src_tokens, mol_src_distance, mol_src_edge_type
        )
        pocket_encoder_rep = self._encode_tokens(
            pocket_model, pocket_src_tokens, pocket_src_distance, pocket_src_edge_type
        )
        return mol_encoder_rep, pocket_encoder_rep

    def forward(
        self,
        mol_src_tokens,
        mol_src_distance,
        mol_src_edge_type,
        pocket_src_tokens,
        pocket_src_distance,
        pocket_src_edge_type,
        mol_src_tokens_hns=None,
        mol_src_distance_hns=None,
        mol_src_edge_type_hns=None,
        smi_list=None,
        pocket_list=None,
        encode=False,
        masked_tokens=None,
        features_only=True,
        is_train=True,
        **kwargs
    ):
        """Compute in-batch logits and, when training, hard-negative logits.

        Args:
            mol_src_tokens, mol_src_distance, mol_src_edge_type: inputs of the
                molecule encoder.
            pocket_src_tokens, pocket_src_distance, pocket_src_edge_type:
                inputs of the pocket encoder.
            mol_src_tokens_hns, mol_src_distance_hns, mol_src_edge_type_hns:
                molecule-encoder inputs of the hard negatives, aligned
                row-by-row with the pockets.  Required when ``is_train`` is
                true.
            is_train: when false, the hard-negative branch is skipped.
            smi_list, pocket_list, encode, masked_tokens, features_only,
            **kwargs: accepted for interface compatibility; not read here.

        Returns:
            ``(ba_predict, ba_predict_hns, scale)``.  ``ba_predict`` is the
            pocket-vs-molecule dot-product matrix multiplied by
            ``exp(logit_scale)``.  When ``is_train`` is false
            ``ba_predict_hns`` is ``ba_predict`` itself; otherwise it has two
            columns per pocket: the hard-negative score and the true-pair
            score.  ``scale`` is ``exp(logit_scale)``.

        Raises:
            ValueError: if ``is_train`` is true and any ``*_hns`` input is
                ``None``.
        """
        mol_encoder_rep = self._encode_tokens(
            self.mol_model, mol_src_tokens, mol_src_distance, mol_src_edge_type
        )
        pocket_encoder_rep = self._encode_tokens(
            self.pocket_model,
            pocket_src_tokens,
            pocket_src_distance,
            pocket_src_edge_type,
        )

        # Position 0 of each sequence is used as the pooled representation.
        mol_emb = self._unit_normalize(self.mol_project(mol_encoder_rep[:, 0, :]))
        pocket_emb = self._unit_normalize(
            self.pocket_project(pocket_encoder_rep[:, 0, :])
        )

        ba_predict = torch.mm(pocket_emb, mol_emb.transpose(0, 1))
        ba_predict = ba_predict * self.logit_scale.exp()

        if not is_train:
            return ba_predict, ba_predict, self.logit_scale.exp()

        if (
            mol_src_tokens_hns is None
            or mol_src_distance_hns is None
            or mol_src_edge_type_hns is None
        ):
            raise ValueError(
                "mol_src_tokens_hns, mol_src_distance_hns and "
                "mol_src_edge_type_hns are required when is_train is True"
            )

        mol_encoder_rep_hns = self._encode_tokens(
            self.mol_model,
            mol_src_tokens_hns,
            mol_src_distance_hns,
            mol_src_edge_type_hns,
        )
        mol_emb_hns = self._unit_normalize(
            self.mol_project(mol_encoder_rep_hns[:, 0, :])
        )

        # True-pair score of every pocket, kept as a column vector.
        ba_predict_diag = torch.diag(ba_predict).unsqueeze(1)

        # Row-wise dot product of each pocket with its own hard negative.
        ba_predict_hns = torch.sum(pocket_emb * mol_emb_hns, dim=1, keepdim=True)
        ba_predict_hns = torch.cat([ba_predict_hns, ba_predict_diag], dim=1)

        # NOTE(review): ba_predict_diag is taken from the already-scaled
        # ba_predict, so after the multiplication below the true-pair column
        # is scaled by exp(logit_scale) twice while the hard-negative column
        # is scaled once.  Left unchanged because it affects the training
        # objective - confirm whether this is intended.
        ba_predict_hns = ba_predict_hns * self.logit_scale.exp()

        return ba_predict, ba_predict_hns, self.logit_scale.exp()

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""

        self._num_updates = num_updates

    def get_num_updates(self):
        """Return the update count last stored by ``set_num_updates``."""
        return self._num_updates

@register_model("binding_affinity_colbert")
class BindingAffinityColbertModel(BaseUnicoreModel):
    """Pocket/molecule similarity model with a token-level late-interaction score.

    ``forward`` encodes molecules and pockets with ``mol_colbert_model`` and
    ``pocket_colbert_model``, projects every token with ``mol_project_t`` /
    ``pocket_project_t``, L2-normalises the projections and returns

    * a similarity matrix built from the token at position 0 of each sequence,
    * a token-level similarity matrix from ``late_interaction`` (both
      orientations),

    all multiplied by ``exp(logit_scale)`` (detached).

    ``mol_model``, ``pocket_model``, ``cross_distance_project``,
    ``holo_distance_project``, ``mol_project_s``, ``pocket_project_s``,
    ``fuse_project`` and ``classification_head`` are constructed here but are
    not used by ``forward`` in this class; they are kept (in the original
    creation order) so parameter names and ordering stay unchanged --
    presumably for checkpoint compatibility.
    """

    # Number of best matches averaged at each stage of ``late_interaction``.
    _LATE_INTERACTION_TOPK = 4

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--mol-pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pocket-pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pocket-encoder-layers",
            type=int,
            help="pocket encoder layers",
        )
        parser.add_argument(
            "--recycling",
            type=int,
            default=1,
            help="recycling nums of decoder",
        )
        parser.add_argument(
            "--emb-dropout",
            type=float,
            metavar="D",
            help="dropout probability for embeddings",
        )
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )

    def __init__(self, args, mol_dictionary, pocket_dictionary):
        """Build the encoders and projection heads.

        Args:
            args: argument namespace; ``unimol_docking_architecture`` fills in
                ``args.mol`` and ``args.pocket`` before the sub-models are
                built.
            mol_dictionary: dictionary handed to the molecule ``UniMolModel``s.
            pocket_dictionary: dictionary handed to the pocket ``UniMolModel``s.
        """
        super().__init__()
        # NOTE(review): several functions in this module share this name, so
        # this call resolves to whichever definition comes last in the file.
        unimol_docking_architecture(args)

        self.args = args
        self.mol_model = UniMolModel(args.mol, mol_dictionary)
        self.pocket_model = UniMolModel(args.pocket, pocket_dictionary)

        # Encoders actually used by ``forward``.
        self.mol_colbert_model = UniMolModel(args.mol, mol_dictionary)
        self.pocket_colbert_model = UniMolModel(args.pocket, pocket_dictionary)

        self.cross_distance_project = NonLinearHead(
            args.mol.encoder_embed_dim * 2 + args.mol.encoder_attention_heads, 1, "relu"
        )
        self.holo_distance_project = DistanceHead(
            args.mol.encoder_embed_dim + args.mol.encoder_attention_heads, "relu"
        )

        self.mol_project_s = NonLinearHead(args.mol.encoder_embed_dim, 128, "relu")
        self.mol_project_t = NonLinearHead(args.mol.encoder_embed_dim, 128, "relu")

        # Created on the default device (not hard-coded to CUDA) so the model
        # can be built on CPU-only hosts; it moves with the module on
        # ``.to()`` / ``.cuda()`` like every other parameter.
        self.logit_scale = nn.Parameter(torch.ones([1]) * np.log(1 / 0.07))
        logging.getLogger(__name__).debug(
            "logit_scale=%s requires_grad=%s is_leaf=%s",
            self.logit_scale,
            self.logit_scale.requires_grad,
            self.logit_scale.is_leaf,
        )

        self.pocket_project_s = NonLinearHead(
            args.pocket.encoder_embed_dim, 128, "relu"
        )
        self.pocket_project_t = NonLinearHead(
            args.pocket.encoder_embed_dim, 128, "relu"
        )

        self.fuse_project = NonLinearHead(256, 1, "relu")
        self.classification_head = nn.Sequential(
            nn.Linear(args.pocket.encoder_embed_dim + args.pocket.encoder_embed_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        return cls(args, task.dictionary, task.pocket_dictionary)

    @staticmethod
    def _encode(model, src_tokens, src_distance, src_edge_type):
        """Run one ``UniMolModel`` encoder and return ``(outputs[0], outputs[-1])``."""
        padding_mask = src_tokens.eq(model.padding_idx)
        token_emb = model.embed_tokens(src_tokens)

        # Distance/edge-type features -> attention bias, flattened so the
        # leading dimension combines the batch axis and the projected feature
        # axis (permute(0, 3, 1, 2) followed by view(-1, n, n)).
        n_node = src_distance.size(-1)
        attn_bias = model.gbf_proj(model.gbf(src_distance, src_edge_type))
        attn_bias = attn_bias.permute(0, 3, 1, 2).contiguous()
        attn_bias = attn_bias.view(-1, n_node, n_node)

        outputs = model.encoder(
            token_emb, padding_mask=padding_mask, attn_mask=attn_bias
        )
        return outputs[0], outputs[-1]

    def forward_once(
        self,
        mol_src_tokens,
        mol_src_distance,
        mol_src_edge_type,
        pocket_src_tokens,
        pocket_src_distance,
        pocket_src_edge_type,
        mol_model,
        pocket_model,
        **kwargs
    ):
        """Encode a batch of molecules and pockets with the given sub-models.

        The molecule is encoded first, then the pocket.

        Args:
            mol_src_tokens, mol_src_distance, mol_src_edge_type: molecule
                inputs passed to ``mol_model``.
            pocket_src_tokens, pocket_src_distance, pocket_src_edge_type:
                pocket inputs passed to ``pocket_model``.
            mol_model: model providing ``padding_idx``, ``embed_tokens``,
                ``gbf``, ``gbf_proj`` and ``encoder`` for molecules.
            pocket_model: same, for pockets.
            **kwargs: accepted and ignored.

        Returns:
            ``(mol_rep, pocket_rep, mol_last, pocket_last)`` where ``*_rep`` is
            element 0 and ``*_last`` is the final element of the respective
            encoder's output tuple (their exact meaning is defined by the
            encoder, which is not part of this class).
        """
        mol_rep, mol_last = self._encode(
            mol_model, mol_src_tokens, mol_src_distance, mol_src_edge_type
        )
        pocket_rep, pocket_last = self._encode(
            pocket_model, pocket_src_tokens, pocket_src_distance, pocket_src_edge_type
        )
        return mol_rep, pocket_rep, mol_last, pocket_last

    def remove_duplicate(self, sim_mx, pocket_list, smi_list):
        """Push down similarities of in-batch pairs that share an identifier.

        For every off-diagonal entry ``(r, c)``, ``1e6`` is subtracted once if
        ``pocket_list[r] == pocket_list[c]`` and once more if
        ``smi_list[r] == smi_list[c]``.  Diagonal entries are left untouched.

        Args:
            sim_mx: square ``(bsz, bsz)`` similarity matrix.
            pocket_list: ``bsz`` pocket identifiers (compared as strings), or
                ``None`` to treat all pockets as distinct.
            smi_list: ``bsz`` molecule identifiers (compared as strings), or
                ``None`` to treat all molecules as distinct.

        Returns:
            A new tensor with the penalties applied, on ``sim_mx``'s device.
        """
        bsz = sim_mx.shape[0]
        device = sim_mx.device

        def same_id_matrix(ids):
            # 1 where two batch items carry the same identifier, else 0.
            if ids is None:
                return torch.eye(bsz, device=device).to(sim_mx.dtype)
            names = np.array(ids, dtype=str)
            return torch.tensor(
                names[:, None] == names[None, :], dtype=sim_mx.dtype, device=device
            )

        pocket_duplicates = same_id_matrix(pocket_list)
        mol_duplicates = same_id_matrix(smi_list)

        # Each item trivially matches itself in both matrices; subtracting
        # 2 * I cancels that so the diagonal is never penalised.
        identity = torch.eye(bsz, device=device)
        indicator = pocket_duplicates + mol_duplicates - 2 * identity

        return indicator * -1e6 + sim_mx

    def late_interaction(self, c_len_, q_len_, c_emb, q_emb):
        """Token-level similarity between every query and every context.

        Steps:

        1. dot product of every query token with every context token;
        2. positions at or beyond the given lengths get ``-1e6`` added;
        3. per query token, the (up to) 4 largest context similarities are
           averaged, masked ones counting as 0;
        4. per (query, context) pair, the (up to) 4 largest of those per-token
           scores are averaged.

        When an axis has fewer than 4 positions, all of them are used instead
        of raising.  Rows past a query's length score 0 in step 3 and still
        take part in step 4.

        Args:
            c_len_: number of valid tokens per context, one entry per context.
            q_len_: number of valid tokens per query, one entry per query.
            c_emb: ``(t, j, k)`` context token embeddings.
            q_emb: ``(b, i, k)`` query token embeddings.

        Returns:
            ``(b, t)`` tensor of similarity scores.
        """
        device = q_emb.device
        raw_sim = torch.einsum("bik,tjk->btij", q_emb, c_emb)
        _, _, i, j = raw_sim.shape

        q_len = torch.as_tensor(q_len_, device=device)
        c_len = torch.as_tensor(c_len_, device=device)

        # Broadcast two small validity masks instead of materialising
        # repeated (b, t, i, j) index tensors.
        row_valid = torch.arange(i, device=device).unsqueeze(0) < q_len.unsqueeze(1)  # (b, i)
        col_valid = torch.arange(j, device=device).unsqueeze(0) < c_len.unsqueeze(1)  # (t, j)
        valid = row_valid[:, None, :, None] & col_valid[None, :, None, :]  # (b, t, i, j)

        masked_sim = raw_sim + (~valid) * -1e6

        top_k = self._LATE_INTERACTION_TOPK

        # Best context matches for each query token.
        best_ctx, _ = masked_sim.topk(min(top_k, j), dim=-1)
        # Anything below -1 carries the -1e6 mask penalty: count it as 0.
        best_ctx = torch.where(best_ctx < -1, torch.zeros_like(best_ctx), best_ctx)
        per_token = best_ctx.mean(-1)

        # Best query tokens for each (query, context) pair.
        best_tokens, _ = per_token.topk(min(top_k, i), dim=-1)
        return best_tokens.mean(-1)

    def forward(
        self,
        mol_src_tokens,
        mol_src_distance,
        mol_src_edge_type,
        pocket_src_tokens,
        pocket_src_distance,
        pocket_src_edge_type,
        mol_len,
        pocket_len,
        smi_list=None,
        pocket_list=None,
        encode=False,
        masked_tokens=None,
        features_only=True,
        is_train=True,
        **kwargs
    ):
        """Compute scaled in-batch pocket/molecule similarity matrices.

        Args:
            mol_src_tokens, mol_src_distance, mol_src_edge_type: molecule
                encoder inputs.
            pocket_src_tokens, pocket_src_distance, pocket_src_edge_type:
                pocket encoder inputs.
            mol_len: valid-token count per molecule, measured on the sequence
                with position 0 removed.
            pocket_len: valid-token count per pocket, measured the same way.
            smi_list: per-item molecule identifiers used to penalise
                duplicates (``None`` = all distinct).
            pocket_list: per-item pocket identifiers used to penalise
                duplicates (``None`` = all distinct).
            encode, masked_tokens, features_only, is_train, **kwargs:
                accepted for interface compatibility; not used here.

        Returns:
            ``(sim_mx_s, sim_mx_t_pocket, sim_mx_t_mol)``:

            * ``sim_mx_s`` -- pocket x molecule similarity of the position-0
              token embeddings, duplicates penalised;
            * ``sim_mx_t_pocket`` -- pocket x molecule late-interaction
              similarity, duplicates penalised;
            * ``sim_mx_t_mol`` -- transpose of ``sim_mx_t_pocket``.

            All three are multiplied by ``exp(logit_scale)``, detached so that
            no gradient reaches ``logit_scale`` through this method.
        """
        mol_rep, pocket_rep, _, _ = self.forward_once(
            mol_src_tokens,
            mol_src_distance,
            mol_src_edge_type,
            pocket_src_tokens,
            pocket_src_distance,
            pocket_src_edge_type,
            self.mol_colbert_model,
            self.pocket_colbert_model,
        )

        # Project every token and L2-normalise along the feature axis.
        mol_emb = self.mol_project_t(mol_rep)
        mol_emb = mol_emb / mol_emb.norm(dim=-1, keepdim=True)

        pocket_emb = self.pocket_project_t(pocket_rep)
        pocket_emb = pocket_emb / pocket_emb.norm(dim=-1, keepdim=True)

        # Token-level score with molecules as queries and pockets as contexts
        # (position 0 is excluded) -> (n_mol, n_pocket).
        mol_by_pocket = self.late_interaction(
            pocket_len, mol_len, pocket_emb[:, 1:, :], mol_emb[:, 1:, :]
        )
        sim_mx_t_pocket = self.remove_duplicate(mol_by_pocket.T, pocket_list, smi_list)
        sim_mx_t_mol = sim_mx_t_pocket.T

        # Sequence-level score from the position-0 embeddings.
        sim_mx_s = torch.matmul(
            pocket_emb[:, 0, :], torch.transpose(mol_emb[:, 0, :], 0, 1)
        )
        sim_mx_s = self.remove_duplicate(sim_mx_s, pocket_list, smi_list)

        scale = self.logit_scale.exp().detach()
        return sim_mx_s * scale, sim_mx_t_pocket * scale, sim_mx_t_mol * scale

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""

        self._num_updates = num_updates

    def get_num_updates(self):
        """Return the update count last stored by ``set_num_updates``."""
        return self._num_updates


class DistanceHead(nn.Module):
    """Map a 4-D feature tensor to a symmetric matrix per batch item.

    The last dimension (size ``heads``) goes through
    ``Linear -> activation -> LayerNorm -> Linear(heads, 1)``; the result is
    then averaged with its own transpose over the two middle dimensions.
    """

    def __init__(
        self,
        heads,
        activation_fn,
    ):
        """
        Args:
            heads: size of the input's last dimension.
            activation_fn: activation name resolved through
                ``utils.get_activation_fn``.
        """
        super().__init__()
        self.dense = nn.Linear(heads, heads)
        self.layer_norm = nn.LayerNorm(heads)
        self.out_proj = nn.Linear(heads, 1)
        self.activation_fn = utils.get_activation_fn(activation_fn)

    def forward(self, x):
        """Return the symmetrised ``(bsz, seq_len, seq_len)`` projection of ``x``.

        Args:
            x: tensor of shape ``(bsz, seq_len, seq_len, heads)``.  Entries
                equal to ``-inf`` are treated as 0.  The caller's tensor is
                left unmodified.
        """
        # The third dimension is the one used for the output shape.
        bsz, _, seq_len, _ = x.size()
        # Out-of-place so the caller's tensor (and any autograd graph holding
        # it) is not altered.
        x = x.masked_fill(x == float("-inf"), 0)
        x = self.dense(x)
        x = self.activation_fn(x)
        x = self.layer_norm(x)
        x = self.out_proj(x).view(bsz, seq_len, seq_len)
        # Average with the transpose so out[b, i, j] == out[b, j, i].
        x = (x + x.transpose(-1, -2)) * 0.5
        return x




@register_model_architecture("binding_affinity", "binding_affinity")
def unimol_docking_architecture(args):
    """Fill ``args.mol`` / ``args.pocket`` for the "binding_affinity" architecture.

    Two fresh, empty namespaces are attached to ``args``.  Each setting is
    read from ``args.mol_<name>`` / ``args.pocket_<name>`` when that attribute
    exists and otherwise takes the default listed below.  The five ``*_loss``
    fields are always set to ``-1.0``.  ``base_architecture(args)`` is called
    last.

    NOTE(review): several functions in this module share this name, so the
    plain module-level name refers to whichever definition comes last.
    """
    # (field name, default used when ``args`` has no ``<prefix>_<field>``)
    overridable = (
        ("encoder_layers", 15),
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 2048),
        ("encoder_attention_heads", 64),
        ("dropout", 0.1),
        ("emb_dropout", 0.1),
        ("attention_dropout", 0.1),
        ("activation_dropout", 0.0),
        ("pooler_dropout", 0.0),
        ("max_seq_len", 512),
        ("activation_fn", "gelu"),
        ("pooler_activation_fn", "tanh"),
        ("post_ln", False),
    )
    # Fields assigned unconditionally.
    fixed_losses = (
        "masked_token_loss",
        "masked_coord_loss",
        "masked_dist_loss",
        "x_norm_loss",
        "delta_pair_repr_norm_loss",
    )

    blank_parser = argparse.ArgumentParser()
    args.mol = blank_parser.parse_args([])
    args.pocket = blank_parser.parse_args([])

    for prefix, sub_args in (("mol", args.mol), ("pocket", args.pocket)):
        for field, fallback in overridable:
            setattr(sub_args, field, getattr(args, prefix + "_" + field, fallback))
        for field in fixed_losses:
            setattr(sub_args, field, -1.0)

    base_architecture(args)



@register_model_architecture("binding_affinity", "binding_affinity_test")
def unimol_docking_architecture(args):
    """Fill ``args.mol`` / ``args.pocket`` for the "binding_affinity_test" architecture.

    Two fresh, empty namespaces are attached to ``args``.  Each setting is
    read from ``args.mol_<name>`` / ``args.pocket_<name>`` when that attribute
    exists and otherwise takes the default listed below (all dropout defaults
    are ``0.0`` here).  The five ``*_loss`` fields are always set to ``-1.0``.
    ``base_architecture(args)`` is called last.

    NOTE(review): several functions in this module share this name, so the
    plain module-level name refers to whichever definition comes last.
    """
    # (field name, default used when ``args`` has no ``<prefix>_<field>``)
    overridable = (
        ("encoder_layers", 15),
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 2048),
        ("encoder_attention_heads", 64),
        ("dropout", 0.0),
        ("emb_dropout", 0.0),
        ("attention_dropout", 0.0),
        ("activation_dropout", 0.0),
        ("pooler_dropout", 0.0),
        ("max_seq_len", 512),
        ("activation_fn", "gelu"),
        ("pooler_activation_fn", "tanh"),
        ("post_ln", False),
    )
    # Fields assigned unconditionally.
    fixed_losses = (
        "masked_token_loss",
        "masked_coord_loss",
        "masked_dist_loss",
        "x_norm_loss",
        "delta_pair_repr_norm_loss",
    )

    blank_parser = argparse.ArgumentParser()
    args.mol = blank_parser.parse_args([])
    args.pocket = blank_parser.parse_args([])

    for prefix, sub_args in (("mol", args.mol), ("pocket", args.pocket)):
        for field, fallback in overridable:
            setattr(sub_args, field, getattr(args, prefix + "_" + field, fallback))
        for field in fixed_losses:
            setattr(sub_args, field, -1.0)

    base_architecture(args)

@register_model_architecture("binding_affinity_hns", "binding_affinity_hns")
def unimol_docking_architecture(args):
    """Fill ``args.mol`` / ``args.pocket`` for the "binding_affinity_hns" architecture.

    Two fresh, empty namespaces are attached to ``args``.  Each setting is
    read from ``args.mol_<name>`` / ``args.pocket_<name>`` when that attribute
    exists and otherwise takes the default listed below.  The five ``*_loss``
    fields are always set to ``-1.0``.  ``base_architecture(args)`` is called
    last.

    NOTE(review): several functions in this module share this name, so the
    plain module-level name refers to whichever definition comes last.
    """
    # (field name, default used when ``args`` has no ``<prefix>_<field>``)
    overridable = (
        ("encoder_layers", 15),
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 2048),
        ("encoder_attention_heads", 64),
        ("dropout", 0.1),
        ("emb_dropout", 0.1),
        ("attention_dropout", 0.1),
        ("activation_dropout", 0.0),
        ("pooler_dropout", 0.0),
        ("max_seq_len", 512),
        ("activation_fn", "gelu"),
        ("pooler_activation_fn", "tanh"),
        ("post_ln", False),
    )
    # Fields assigned unconditionally.
    fixed_losses = (
        "masked_token_loss",
        "masked_coord_loss",
        "masked_dist_loss",
        "x_norm_loss",
        "delta_pair_repr_norm_loss",
    )

    blank_parser = argparse.ArgumentParser()
    args.mol = blank_parser.parse_args([])
    args.pocket = blank_parser.parse_args([])

    for prefix, sub_args in (("mol", args.mol), ("pocket", args.pocket)):
        for field, fallback in overridable:
            setattr(sub_args, field, getattr(args, prefix + "_" + field, fallback))
        for field in fixed_losses:
            setattr(sub_args, field, -1.0)

    base_architecture(args)



@register_model_architecture("binding_affinity_hns", "binding_affinity_hns_test")
def unimol_docking_architecture(args):
    """Fill ``args.mol`` / ``args.pocket`` for the "binding_affinity_hns_test" architecture.

    Two fresh, empty namespaces are attached to ``args``.  Each setting is
    read from ``args.mol_<name>`` / ``args.pocket_<name>`` when that attribute
    exists and otherwise takes the default listed below (all dropout defaults
    are ``0.0`` here).  The five ``*_loss`` fields are always set to ``-1.0``.
    ``base_architecture(args)`` is called last.

    NOTE(review): several functions in this module share this name, so the
    plain module-level name refers to whichever definition comes last.
    """
    # (field name, default used when ``args`` has no ``<prefix>_<field>``)
    overridable = (
        ("encoder_layers", 15),
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 2048),
        ("encoder_attention_heads", 64),
        ("dropout", 0.0),
        ("emb_dropout", 0.0),
        ("attention_dropout", 0.0),
        ("activation_dropout", 0.0),
        ("pooler_dropout", 0.0),
        ("max_seq_len", 512),
        ("activation_fn", "gelu"),
        ("pooler_activation_fn", "tanh"),
        ("post_ln", False),
    )
    # Fields assigned unconditionally.
    fixed_losses = (
        "masked_token_loss",
        "masked_coord_loss",
        "masked_dist_loss",
        "x_norm_loss",
        "delta_pair_repr_norm_loss",
    )

    blank_parser = argparse.ArgumentParser()
    args.mol = blank_parser.parse_args([])
    args.pocket = blank_parser.parse_args([])

    for prefix, sub_args in (("mol", args.mol), ("pocket", args.pocket)):
        for field, fallback in overridable:
            setattr(sub_args, field, getattr(args, prefix + "_" + field, fallback))
        for field in fixed_losses:
            setattr(sub_args, field, -1.0)

    base_architecture(args)


@register_model_architecture("binding_affinity_colbert", "binding_affinity_colbert")
def unimol_docking_architecture(args):
    """Fill ``args.mol`` / ``args.pocket`` for the "binding_affinity_colbert" architecture.

    Two fresh, empty namespaces are attached to ``args``.  Each setting is
    read from ``args.mol_<name>`` / ``args.pocket_<name>`` when that attribute
    exists and otherwise takes the default listed below.  The five ``*_loss``
    fields are always set to ``-1.0``.  ``base_architecture(args)`` is called
    last.

    NOTE(review): several functions in this module share this name, so the
    plain module-level name refers to whichever definition comes last.
    """
    # (field name, default used when ``args`` has no ``<prefix>_<field>``)
    overridable = (
        ("encoder_layers", 15),
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 2048),
        ("encoder_attention_heads", 64),
        ("dropout", 0.1),
        ("emb_dropout", 0.1),
        ("attention_dropout", 0.1),
        ("activation_dropout", 0.0),
        ("pooler_dropout", 0.0),
        ("max_seq_len", 512),
        ("activation_fn", "gelu"),
        ("pooler_activation_fn", "tanh"),
        ("post_ln", False),
    )
    # Fields assigned unconditionally.
    fixed_losses = (
        "masked_token_loss",
        "masked_coord_loss",
        "masked_dist_loss",
        "x_norm_loss",
        "delta_pair_repr_norm_loss",
    )

    blank_parser = argparse.ArgumentParser()
    args.mol = blank_parser.parse_args([])
    args.pocket = blank_parser.parse_args([])

    for prefix, sub_args in (("mol", args.mol), ("pocket", args.pocket)):
        for field, fallback in overridable:
            setattr(sub_args, field, getattr(args, prefix + "_" + field, fallback))
        for field in fixed_losses:
            setattr(sub_args, field, -1.0)

    base_architecture(args)


@register_model_architecture("binding_affinity_colbert", "binding_affinity_colbert_test")
def unimol_docking_architecture(args):
    """Fill ``args.mol`` / ``args.pocket`` for the "binding_affinity_colbert_test" architecture.

    Two fresh, empty namespaces are attached to ``args``.  Each setting is
    read from ``args.mol_<name>`` / ``args.pocket_<name>`` when that attribute
    exists and otherwise takes the default listed below (all dropout defaults
    are ``0.0`` here).  The five ``*_loss`` fields are always set to ``-1.0``.
    ``base_architecture(args)`` is called last.

    NOTE(review): several functions in this module share this name, so the
    plain module-level name refers to whichever definition comes last.
    """
    # (field name, default used when ``args`` has no ``<prefix>_<field>``)
    overridable = (
        ("encoder_layers", 15),
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 2048),
        ("encoder_attention_heads", 64),
        ("dropout", 0.0),
        ("emb_dropout", 0.0),
        ("attention_dropout", 0.0),
        ("activation_dropout", 0.0),
        ("pooler_dropout", 0.0),
        ("max_seq_len", 512),
        ("activation_fn", "gelu"),
        ("pooler_activation_fn", "tanh"),
        ("post_ln", False),
    )
    # Fields assigned unconditionally.
    fixed_losses = (
        "masked_token_loss",
        "masked_coord_loss",
        "masked_dist_loss",
        "x_norm_loss",
        "delta_pair_repr_norm_loss",
    )

    blank_parser = argparse.ArgumentParser()
    args.mol = blank_parser.parse_args([])
    args.pocket = blank_parser.parse_args([])

    for prefix, sub_args in (("mol", args.mol), ("pocket", args.pocket)):
        for field, fallback in overridable:
            setattr(sub_args, field, getattr(args, prefix + "_" + field, fallback))
        for field in fixed_losses:
            setattr(sub_args, field, -1.0)

    base_architecture(args)


@register_model_architecture("binding_affinity_ns", "binding_affinity_ns")
def unimol_docking_architecture(args):
    """Fill ``args.mol`` and ``args.pocket`` with encoder settings.

    Two fresh, empty namespaces are attached to ``args`` as ``args.mol`` and
    ``args.pocket``.  Each overridable setting is copied from the matching
    prefixed attribute on ``args`` (``mol_<name>`` / ``pocket_<name>``) when
    that attribute exists, otherwise the default listed below is used.  The
    five loss-weight attributes are always set to ``-1.0`` (presumably this
    switches the corresponding pre-training losses off -- confirm against
    ``UniMolModel``).  ``base_architecture(args)`` is called last.

    Args:
        args: namespace-like object holding the parsed options; it is
            modified in place.

    NOTE(review): this module contains more than one function named
    ``unimol_docking_architecture``.  The module-level name refers to
    whichever definition is executed last, so a direct call by name does not
    necessarily reach this variant.
    """
    # (attribute name, default used when ``args`` lacks "<prefix>_<name>")
    overridable_defaults = (
        ("encoder_layers", 15),
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 2048),
        ("encoder_attention_heads", 64),
        ("dropout", 0.1),
        ("emb_dropout", 0.1),
        ("attention_dropout", 0.1),
        ("activation_dropout", 0.0),
        ("pooler_dropout", 0.0),
        ("max_seq_len", 512),
        ("activation_fn", "gelu"),
        ("pooler_activation_fn", "tanh"),
        ("post_ln", False),
    )
    # Attributes that are forced to -1.0 and cannot be overridden via ``args``.
    fixed_loss_weights = (
        "masked_token_loss",
        "masked_coord_loss",
        "masked_dist_loss",
        "x_norm_loss",
        "delta_pair_repr_norm_loss",
    )

    for prefix in ("mol", "pocket"):
        # Always start from an empty namespace, replacing any previous one.
        sub_args = argparse.Namespace()
        setattr(args, prefix, sub_args)

        for name, default in overridable_defaults:
            setattr(sub_args, name, getattr(args, f"{prefix}_{name}", default))

        for name in fixed_loss_weights:
            setattr(sub_args, name, -1.0)

    base_architecture(args)


@register_model_architecture("binding_affinity_ns", "binding_affinity_ns_test")
def unimol_docking_architecture(args):
    """Fill ``args.mol`` and ``args.pocket`` with dropout-free encoder settings.

    Two fresh, empty namespaces are attached to ``args`` as ``args.mol`` and
    ``args.pocket``.  Each overridable setting is copied from the matching
    prefixed attribute on ``args`` (``mol_<name>`` / ``pocket_<name>``) when
    that attribute exists, otherwise the default listed below is used; every
    dropout default in this variant is ``0.0``.  The five loss-weight
    attributes are always set to ``-1.0`` (presumably this switches the
    corresponding pre-training losses off -- confirm against
    ``UniMolModel``).  ``base_architecture(args)`` is called last.

    Args:
        args: namespace-like object holding the parsed options; it is
            modified in place.

    NOTE(review): this module contains more than one function named
    ``unimol_docking_architecture``.  The module-level name refers to
    whichever definition is executed last, so direct calls by name may pick
    up this variant's defaults even when another one was intended.
    """
    # (attribute name, default used when ``args`` lacks "<prefix>_<name>")
    overridable_defaults = (
        ("encoder_layers", 15),
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 2048),
        ("encoder_attention_heads", 64),
        ("dropout", 0.0),
        ("emb_dropout", 0.0),
        ("attention_dropout", 0.0),
        ("activation_dropout", 0.0),
        ("pooler_dropout", 0.0),
        ("max_seq_len", 512),
        ("activation_fn", "gelu"),
        ("pooler_activation_fn", "tanh"),
        ("post_ln", False),
    )
    # Attributes that are forced to -1.0 and cannot be overridden via ``args``.
    fixed_loss_weights = (
        "masked_token_loss",
        "masked_coord_loss",
        "masked_dist_loss",
        "x_norm_loss",
        "delta_pair_repr_norm_loss",
    )

    for prefix in ("mol", "pocket"):
        # Always start from an empty namespace, replacing any previous one.
        sub_args = argparse.Namespace()
        setattr(args, prefix, sub_args)

        for name, default in overridable_defaults:
            setattr(sub_args, name, getattr(args, f"{prefix}_{name}", default))

        for name in fixed_loss_weights:
            setattr(sub_args, name, -1.0)

    base_architecture(args)