import torch
import torch.nn.functional as F
import torch.nn as nn
from critic_utils import *
import time
from multiprocessing import Pool
from critic_nn_layer import GAT_gate

# Raw per-atom feature widths: input sizes of SPgnn_reg.embede_protein and
# SPgnn_reg.embede_ligand respectively.
protein_features, ligand_features = 31, 43

class SPgnn_reg(torch.nn.Module):
    """Regress one scalar per sample from a ligand graph and a protein graph.

    Each graph is embedded separately: a bias-free linear projection of the
    node features, a Gaussian kernel on the adjacency values, and a stack of
    ``GAT_gate`` layers. The result is then summed over the node dimension.
    The two pooled vectors are concatenated (ligand first) and scored by an
    MLP whose last layer has a single output unit.
    """

    def __init__(self):
        super().__init__()
        n_graph_layer = 4
        d_graph_layer = 140
        n_FC_layer = 4
        d_FC_layer = 128
        self.dropout_rate = 0.0

        # Every GAT layer maps d_graph_layer -> d_graph_layer.
        self.layers1 = [d_graph_layer] * (n_graph_layer + 1)
        self.gconvprotein = nn.ModuleList(
            [GAT_gate(self.layers1[i], self.layers1[i + 1]) for i in range(len(self.layers1) - 1)]
        )
        self.gconvligand = nn.ModuleList(
            [GAT_gate(self.layers1[i], self.layers1[i + 1]) for i in range(len(self.layers1) - 1)]
        )
        # MLP head: input is the concatenated ligand + protein vectors,
        # output is a single unit.
        self.FC = nn.ModuleList([
            nn.Linear(self.layers1[-1] * 2, d_FC_layer) if i == 0 else
            nn.Linear(d_FC_layer, 1) if i == n_FC_layer - 1 else
            nn.Linear(d_FC_layer, d_FC_layer)
            for i in range(n_FC_layer)
        ])

        # NOTE(review): densaft is not used by any method of this class; it is
        # kept so existing checkpoints (state_dict keys) still load.
        self.densaft = nn.Linear(32, self.layers1[-1])
        # Learnable centre (mu) and width (dev) of the Gaussian kernel applied
        # to adjacency values in _embed_graph.
        self.mu = nn.Parameter(torch.Tensor([4.0]).float())
        self.dev = nn.Parameter(torch.Tensor([1.0]).float())
        self.embede_protein = nn.Linear(protein_features, d_graph_layer, bias=False)
        self.embede_ligand = nn.Linear(ligand_features, d_graph_layer, bias=False)

    def _embed_graph(self, feats, adjs, embed, convs):
        """Project node features, run a GAT stack, and sum-pool over dim 1.

        Args:
            feats: node features; last dim must match ``embed.in_features``.
            adjs: adjacency values passed (after the Gaussian kernel) to each
                layer in ``convs``.
            embed: the bias-free input projection for this graph.
            convs: the ``nn.ModuleList`` of graph layers for this graph.

        Returns:
            The final node states summed over dim 1.
        """
        h = embed(feats)
        # Gaussian kernel exp(-(a - mu)^2 / dev) on the raw adjacency values.
        adjs = torch.exp(-torch.pow(adjs - self.mu, 2) / self.dev)
        for conv in convs:
            h = conv(h, adjs)
            h = F.dropout(h, p=self.dropout_rate, training=self.training)
        # NOTE(review): this sums over every node along dim 1; the c_valid
        # element of the data tuple is not applied as a mask, so any padding
        # nodes also contribute. Confirm this is intended.
        return h.sum(1)

    def embede_graph_protein(self, data):
        """Embed the protein graph of ``data``.

        Args:
            data: tuple ``(c_hs1, c_hs2, c_adjs1, c_adjs2, c_valid)``. Only the
                protein features ``c_hs2`` (last dim ``protein_features``) and
                the protein adjacency ``c_adjs2`` are read.

        Returns:
            The pooled protein vector(s) produced by ``_embed_graph``.
        """
        _, c_hs2, _, c_adjs2, _ = data
        return self._embed_graph(c_hs2, c_adjs2, self.embede_protein, self.gconvprotein)

    def embede_graph_ligand(self, data):
        """Embed the ligand graph of ``data``.

        Args:
            data: tuple ``(c_hs1, c_hs2, c_adjs1, c_adjs2, c_valid)``. Only the
                ligand features ``c_hs1`` (last dim ``ligand_features``) and
                the ligand adjacency ``c_adjs1`` are read.

        Returns:
            The pooled ligand vector(s) produced by ``_embed_graph``.
        """
        c_hs1, _, c_adjs1, _, _ = data
        return self._embed_graph(c_hs1, c_adjs1, self.embede_ligand, self.gconvligand)

    def fully_connected(self, c_hs):
        """Apply the MLP head.

        Every hidden layer is Linear -> dropout -> ReLU. The final Linear has
        no activation, so the output is an unbounded regression value.
        """
        for layer in self.FC[:-1]:
            c_hs = layer(c_hs)
            c_hs = F.dropout(c_hs, p=self.dropout_rate, training=self.training)
            c_hs = F.relu(c_hs)
        return self.FC[-1](c_hs)

    def _predict(self, data):
        """Shared prediction path: embed both graphs, concatenate, score, flatten."""
        # Protein is embedded first, as in the original code, so the order of
        # any dropout RNG draws is unchanged.
        c_hs2 = self.embede_graph_protein(data)
        c_hs1 = self.embede_graph_ligand(data)
        # Concatenate on the feature (last) dim, ligand first. This matches the
        # FC input width layers1[-1] * 2 for both batched (B, D) and 1-D (D,)
        # pooled vectors.
        c_hs3 = torch.cat((c_hs1, c_hs2), -1)
        return self.fully_connected(c_hs3).view(-1)

    def train_model(self, data):
        """Return the flattened predictions for ``data`` (same as ``forward``)."""
        return self._predict(data)

    def test_model(self, data1):
        """Return the flattened predictions for ``data1`` (same as ``forward``).

        Earlier code concatenated along dim 0 here, which broke for batched
        inputs; it now shares the forward path.
        """
        return self._predict(data1)

    def forward(self, data):
        """Return one prediction per sample as a 1-D tensor."""
        return self._predict(data)
