# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from unicore import utils
from unicore.data import Dictionary
from unicore.models import (BaseUnicoreModel, register_model,
                            register_model_architecture)
from unimol.models.unimol import ClassificationHead
from unicore.modules import LayerNorm
import unicore


from .transformer_encoder_with_pair import TransformerEncoderWithPair
from .unimol import NonLinearHead, UniMolModel, base_architecture, LinearHead

# Per-module logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


@register_model("pocket_matching")
class PocketMatchingModel(BaseUnicoreModel):
    """Siamese pocket-matching model built on one shared Uni-Mol pocket encoder.

    Both pockets of a pair are encoded by the same ``UniMolModel``. The pair
    is scored by the cosine similarity of the encoder outputs taken at token
    position 0 of each pocket.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--mol-pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pocket-pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pocket-encoder-layers",
            type=int,
            help="pocket encoder layers",
        )
        parser.add_argument(
            "--recycling",
            type=int,
            default=1,
            help="recycling nums of decoder",
        )

    def __init__(self, args, mol_dictionary, pocket_dictionary):
        super().__init__()
        # Builds the ``args.mol`` / ``args.pocket`` sub-configs used below.
        unimol_docking_architecture(args)

        self.args = args
        # NOTE(review): ``mol_dictionary`` is accepted for interface
        # compatibility but unused; build_model passes the pocket dictionary
        # for both arguments.
        self.pocket_model = UniMolModel(args.pocket, pocket_dictionary)

        # The three heads below do not contribute to ``forward``'s returned
        # score. They are kept (in their original registration order) so that
        # parameter names, initialisation order and saved checkpoints stay
        # compatible.
        self.pocket_project = NonLinearHead(
            args.pocket.encoder_embed_dim, 128, "relu"
        )
        self.pocket_project_fake = NonLinearHead(
            args.pocket.encoder_embed_dim, 128, "relu"
        )
        self.linear_prob = nn.Linear(args.pocket.encoder_embed_dim, 128)

        # Initialised here so get_num_updates() works even before the
        # trainer's first set_num_updates() call (previously AttributeError).
        self._num_updates = 0

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        return cls(args, task.pocket_dictionary, task.pocket_dictionary)

    def _encode_pocket(self, src_tokens, src_distance, src_edge_type):
        """Run the shared pocket encoder and return the position-0 output.

        Args:
            src_tokens: token ids; padding is detected via
                ``self.pocket_model.padding_idx``.
            src_distance: pairwise distances; its last dim is the node count.
            src_edge_type: pairwise edge types, passed to ``gbf`` together
                with ``src_distance``.

        Returns:
            ``encoder_output[:, 0, :]`` -- presumably the BOS/CLS token
            representation, one row per batch item (TODO confirm with the
            dataset pipeline).
        """
        padding_mask = src_tokens.eq(self.pocket_model.padding_idx)
        x = self.pocket_model.embed_tokens(src_tokens)

        # Pairwise (distance, edge type) features -> additive attention bias.
        # The permute moves the projected feature axis (presumably attention
        # heads) in front of the node axes before flattening it into the
        # batch dimension.
        n_node = src_distance.size(-1)
        gbf_feature = self.pocket_model.gbf(src_distance, src_edge_type)
        gbf_result = self.pocket_model.gbf_proj(gbf_feature)
        graph_attn_bias = gbf_result.permute(0, 3, 1, 2).contiguous()
        graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)

        encoder_rep = self.pocket_model.encoder(
            x, padding_mask=padding_mask, attn_mask=graph_attn_bias
        )[0]
        return encoder_rep[:, 0, :]

    def forward(
        self,
        pocket_src_tokens_a,
        pocket_src_distance_a,
        pocket_src_edge_type_a,
        pocket_src_tokens_b,
        pocket_src_distance_b,
        pocket_src_edge_type_b,
        encode=False,
        masked_tokens=None,
        features_only=True,
        is_train=True,
        **kwargs
    ):
        """Score each (pocket a, pocket b) pair by cosine similarity.

        ``encode``, ``masked_tokens``, ``features_only``, ``is_train`` and any
        extra keyword arguments are accepted for interface compatibility and
        are ignored.

        Returns:
            A 1-D tensor holding one cosine similarity per batch item.
        """
        # Pocket a is encoded before pocket b, matching the original order
        # (relevant for dropout RNG consumption).
        pocket_rep_a = self._encode_pocket(
            pocket_src_tokens_a, pocket_src_distance_a, pocket_src_edge_type_a
        )
        pocket_rep_b = self._encode_pocket(
            pocket_src_tokens_b, pocket_src_distance_b, pocket_src_edge_type_b
        )

        # F.normalize clamps the norm by a tiny eps, so an all-zero
        # representation gives a similarity of 0 instead of NaN; otherwise it
        # equals x / x.norm(dim=1, keepdim=True).
        pocket_rep_a = F.normalize(pocket_rep_a, p=2, dim=1)
        pocket_rep_b = F.normalize(pocket_rep_b, p=2, dim=1)
        return torch.sum(pocket_rep_a * pocket_rep_b, dim=1)

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""
        self._num_updates = num_updates

    def get_num_updates(self):
        """Return the update count last set by the trainer (0 initially)."""
        return self._num_updates











@register_model_architecture("pocket_matching", "pocket_matching")
def unimol_docking_architecture(args):
    """Rebuild the ``args.mol`` / ``args.pocket`` encoder configs, then apply base defaults.

    Both sub-configs are recreated as empty namespaces on every call. Each
    tunable field is copied from ``args.<prefix>_<field>`` when that attribute
    exists, otherwise a fixed default is used. The auxiliary loss weights are
    always set to -1.0. Finally ``base_architecture(args)`` is applied.
    """
    # (field, default) pairs shared by both encoders, in assignment order.
    tunable_fields = (
        ("encoder_layers", 15),
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 2048),
        ("encoder_attention_heads", 64),
        ("dropout", 0.1),
        ("emb_dropout", 0.1),
        ("attention_dropout", 0.1),
        ("activation_dropout", 0.0),
        ("pooler_dropout", 0.0),
        ("max_seq_len", 512),
        ("activation_fn", "gelu"),
        ("pooler_activation_fn", "tanh"),
        ("post_ln", False),
    )
    # Loss weights forced to -1.0 on both sub-configs.
    disabled_losses = (
        "masked_token_loss",
        "masked_coord_loss",
        "masked_dist_loss",
        "x_norm_loss",
        "delta_pair_repr_norm_loss",
    )

    # A bare parser parsing [] yields an empty Namespace; build it directly.
    args.mol = argparse.Namespace()
    args.pocket = argparse.Namespace()

    for prefix in ("mol", "pocket"):
        sub_config = getattr(args, prefix)
        for field, default in tunable_fields:
            setattr(sub_config, field, getattr(args, f"{prefix}_{field}", default))
        for loss_name in disabled_losses:
            setattr(sub_config, loss_name, -1.0)

    base_architecture(args)



