import torch
from torch import nn
from torch.nn import LayerNorm

from fairseq.models import FairseqDecoder

from .modules import gelu


class ProteinCoordinatePrediction(FairseqDecoder):
    """Per-position prediction head: Conv1d -> Linear -> GELU -> LayerNorm -> Linear.

    Args:
        args: Not used by this class.
        embed_dim: Channel size of the incoming features.
        label_num: Number of outputs produced for each position.
        alphabet: Passed to ``FairseqDecoder.__init__``.
    """

    def __init__(self, args, embed_dim, label_num, alphabet):
        super().__init__(alphabet)
        half_dim = int(embed_dim / 2)
        # Submodules are created in the original order so parameter
        # registration (init-time RNG use, state_dict layout) is unchanged.
        self.conv = nn.Conv1d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=3,
        )
        self.dense = nn.Linear(embed_dim, half_dim)
        self.layer_norm = LayerNorm(half_dim)
        self.classifier = nn.Linear(half_dim, label_num)

    def forward(self, prev_result):
        """Map ``prev_result`` to per-position unnormalized outputs.

        ``prev_result`` must be 3-D with ``embed_dim`` as its last axis
        (presumably ``(batch, seq_len, embed_dim)`` -- confirm with caller).
        The kernel-3 convolution uses no padding, so the position axis
        shrinks by 2 and the result is ``(batch, seq_len - 2, label_num)``.
        NOTE(review): the 2-position trim may be meant to drop BOS/EOS; confirm.
        """
        # Conv1d wants channels on axis 1, so swap the last two axes in and out.
        features = self.conv(prev_result.transpose(1, 2)).transpose(1, 2)
        hidden = self.layer_norm(gelu(self.dense(features)))
        return self.classifier(hidden)


class ProteinInteractionPrediction(FairseqDecoder):
    """Two-output prediction head for a pair of proteins.

    The two proteins arrive stacked along the first axis of ``prev_result``.
    Each is projected (Linear -> GELU -> Linear -> LayerNorm) and cropped to
    its non-padding length. The two are then compared residue-by-residue
    through a sigmoid-squashed affinity matrix. That matrix weights a pooled
    summary of each protein. The two summaries are concatenated and passed
    through a small MLP that ends in 2 outputs.

    Args:
        args: Not used by this class.
        embed_dim: Feature size of ``prev_result``'s last axis.
        alphabet: Passed to ``FairseqDecoder.__init__``; its ``padding_idx``
            marks padding positions in ``tokens``.
    """

    def __init__(self, args, embed_dim, alphabet):
        super().__init__(alphabet)
        self.alphabet = alphabet
        self.embed_dim = embed_dim
        # Submodule creation order is kept as-is: it fixes parameter
        # registration order, init-time RNG consumption and state_dict keys.
        self.input_fc = nn.Linear(self.embed_dim, self.embed_dim)
        self.hidden_fc = nn.Linear(self.embed_dim, self.embed_dim)
        self.layer_norm = LayerNorm(embed_dim)
        # Input is the concatenation of the two pooled protein summaries.
        self.output_fc = nn.Linear(self.embed_dim * 2, int(self.embed_dim / 2))
        self.half_output_fc = nn.Linear(int(self.embed_dim / 2), int(self.embed_dim / 8))
        self.classifier = nn.Linear(int(self.embed_dim / 8), 2)

    def _num_non_padding(self, row):
        """Return how many entries of ``row`` differ from the padding index, as an int."""
        # Tensor.sum() reduces in one kernel; the builtin sum() previously used
        # here iterated the tensor element by element in Python.
        return int(row.ne(self.alphabet.padding_idx).sum())

    def forward(self, prev_result, tokens):
        """Return 2 unnormalized scores for the protein pair.

        Args:
            prev_result: 3-D features whose rows 0 and 1 are the two proteins
                and whose last axis is ``embed_dim`` (presumably
                ``(2, seq_len, embed_dim)`` -- confirm with caller).
            tokens: Token ids aligned with ``prev_result``; ``tokens[0]`` and
                ``tokens[1]`` are used to crop the first and second protein.

        Returns:
            A 1-D tensor of shape ``(2,)``.
        """
        x = self.layer_norm(self.hidden_fc(gelu(self.input_fc(prev_result))))

        # Crop each protein to its count of non-padding tokens. Taking a prefix
        # assumes padding sits at the end of each row -- confirm with caller.
        protein1 = x[0, :self._num_non_padding(tokens[0]), :]  # (n1, embed_dim)
        protein2 = x[1, :self._num_non_padding(tokens[1]), :]  # (n2, embed_dim)

        # The protein slices are 2-D, so torch.matmul is required: torch.bmm
        # only accepts 3-D batches and raised on these inputs.
        rel = torch.special.expit(torch.matmul(protein1, protein2.transpose(-2, -1)))  # (n1, n2)

        # Affinity-weighted summaries, each reduced to shape (embed_dim,).
        summary1 = torch.matmul(rel.transpose(-2, -1), protein1).sum(dim=0)
        summary2 = torch.matmul(rel, protein2).sum(dim=0)

        # Two (embed_dim,) vectors concatenate to the 2 * embed_dim features
        # output_fc was built for.
        x = gelu(self.output_fc(torch.cat([summary1, summary2])))
        return self.classifier(gelu(self.half_output_fc(x)))
