import torch.nn as nn
from models.protein_encoder import ProteinEncoder
from models.ptransGPT_cross import Protein_Translator

class ProteinModel(nn.Module):
    """Compose a ``ProteinEncoder`` with a ``Protein_Translator``.

    The encoder runs on protein indices. Its second output is then fed to
    the translator together with the function indices.

    Constructor arguments split into two groups. Each group is handed
    unchanged to one sub-module:

    * ``pos_grad``, ``down_t``, ``stride_t``, ``depth``, ``kernel_size``,
      ``n_token``, ``protein_vocab_size``, ``latent_conv_dim``,
      ``latent_trans_dim``, ``nb_layer``, ``nb_head`` -> ``ProteinEncoder``.
    * ``text_vocab_size``, ``embed_dim``, ``protein_dim``, ``block_size``,
      ``num_layers``, ``n_head``, ``drop_out_rate``, ``fc_rate``
      -> ``Protein_Translator``.

    NOTE(review): ``protein_dim`` presumably has to match the feature size
    the encoder emits (``latent_trans_dim``?). Nothing here enforces that,
    so confirm it against both sub-modules.
    """

    def __init__(
        self,
        pos_grad=False,
        down_t=2,
        stride_t=2,
        depth=3,
        kernel_size=15,
        n_token=100,
        protein_vocab_size=30,
        latent_conv_dim=256,
        latent_trans_dim=256,
        nb_layer=4,
        nb_head=4,
        text_vocab_size=30522,
        embed_dim=256,
        protein_dim=256,
        block_size=52,
        num_layers=4,
        n_head=8,
        drop_out_rate=0.1,
        fc_rate=4,
    ):
        super().__init__()

        encoder_cfg = dict(
            pos_grad=pos_grad,
            down_t=down_t,
            stride_t=stride_t,
            depth=depth,
            kernel_size=kernel_size,
            n_token=n_token,
            protein_vocab_size=protein_vocab_size,
            latent_conv_dim=latent_conv_dim,
            latent_trans_dim=latent_trans_dim,
            nb_layer=nb_layer,
            nb_head=nb_head,
        )
        translator_cfg = dict(
            text_vocab_size=text_vocab_size,
            embed_dim=embed_dim,
            protein_dim=protein_dim,
            block_size=block_size,
            num_layers=num_layers,
            n_head=n_head,
            drop_out_rate=drop_out_rate,
            fc_rate=fc_rate,
        )

        # nn.Module registers sub-modules in assignment order. Building the
        # encoder before the translator keeps state_dict key order stable,
        # and also the order of any random weight initialisation the
        # constructors may perform.
        self.protein_encoder = ProteinEncoder(**encoder_cfg)
        self.protein_translator = Protein_Translator(**translator_cfg)

    def forward(self, protein_idx, function_idx):
        """Run the encoder, then the translator on the encoder's features.

        Args:
            protein_idx: Passed straight to ``self.protein_encoder``.
                Presumably a batch of protein token ids; the shape and dtype
                are not visible here, so confirm them.
            function_idx: Passed as the first argument to
                ``self.protein_translator``. Presumably text token ids.

        Returns:
            A 3-tuple ``(global_feature, protein_feature, function_logits)``.
            The first two items are the encoder's outputs (it must return
            exactly two values). The third is whatever the translator returns.
        """
        global_out, feat_out = self.protein_encoder(protein_idx)
        translated = self.protein_translator(function_idx, feat_out)
        return (global_out, feat_out, translated)
    
    
def get_model(args):
    """Instantiate a ``ProteinModel`` from a configuration namespace.

    Each ``ProteinModel`` keyword is read from an attribute of ``args``.
    Several attribute names differ from the keyword they feed:

    * ``args.pos_grad_pe``   -> ``pos_grad``
    * ``args.seq_len_max``   -> ``n_token``
    * ``args.nb_layer_pe``   -> ``nb_layer``
    * ``args.nb_head_pe``    -> ``nb_head``
    * ``args.embed_dim_pt``  -> ``embed_dim``
    * ``args.n_head_pt``     -> ``n_head``
    * ``args.ff_rate``       -> ``fc_rate``

    Every other keyword is read from the attribute of the same name.

    Args:
        args: Any object that exposes the attributes above, presumably an
            ``argparse.Namespace``. A missing attribute raises
            ``AttributeError``.

    Returns:
        The newly constructed ``ProteinModel``.
    """
    # The dict literal evaluates in order, so attributes are read in a fixed
    # sequence. That fixes which missing option is reported first.
    model_kwargs = {
        "pos_grad": args.pos_grad_pe,
        "down_t": args.down_t,
        "stride_t": args.stride_t,
        "depth": args.depth,
        "kernel_size": args.kernel_size,
        "n_token": args.seq_len_max,
        "protein_vocab_size": args.protein_vocab_size,
        "latent_conv_dim": args.latent_conv_dim,
        "latent_trans_dim": args.latent_trans_dim,
        "nb_layer": args.nb_layer_pe,
        "nb_head": args.nb_head_pe,
        "text_vocab_size": args.text_vocab_size,
        "embed_dim": args.embed_dim_pt,
        "protein_dim": args.protein_dim,
        "block_size": args.block_size,
        "num_layers": args.num_layers,
        "n_head": args.n_head_pt,
        "drop_out_rate": args.drop_out_rate,
        "fc_rate": args.ff_rate,
    }
    return ProteinModel(**model_kwargs)
