import numpy as np
import torch
import torch.nn as nn
from imports import *

class Protein_dock(nn.Module):
    """Energy objective built from pre-computed per-complex data files.

    Each complex lives in ``data_protein/<name>/`` and provides six text files
    (``coor_init``, ``q``, ``e``, ``r``, ``basis``, ``eigval``) read with
    ``np.loadtxt``.  The loaded arrays are stacked along a leading axis, moved
    to ``DEVICE`` (presumably provided by ``from imports import *``) and used
    by :meth:`forward` to turn a low-dimensional vector into displaced
    coordinates and then into a scalar energy.

    Args:
        popSize: Number of copies the single test complex is tiled to when
            ``training`` is false.
        dim: Stored on the instance; not read by any method of this class.
        training: If true, load the built-in training list; otherwise load
            only the complex named by ``test``.
        test: Directory name (under ``data_protein/``) of the test complex.

    NOTE: ``training`` is the same attribute that ``nn.Module.train()`` /
    ``eval()`` toggle; data is only (re)loaded when :meth:`init` runs.
    """

    # Complexes loaded in training mode.  The list holds 25 names (an older
    # comment claimed 28); each contributes sub-directories ``<name>_1`` ..
    # ``<name>_5``, i.e. 125 entries in total.
    # NOTE(review): a previously commented-out variant loaded 8 entries for the
    # last complex (128 in total, matching the default popSize) -- confirm
    # which is intended.
    _TRAIN_PROTEINS = (
        '1ATN', '1AVX', '1AY7', '1BJ1', '1BVN', '1CGI', '1DFJ', '1EAW',
        '1EWY', '1EZU', '1GRN', '1IBR', '1IJK', '1IQD', '1JPS', '1KXQ', '1M10',
        '1MAH', '1N8O', '1PPE', '1R0R', '1XQS', '2B42', '2C0L', '2HRK',
    )
    _CONFORMATIONS_PER_PROTEIN = 5
    # File names inside one complex directory, in the order they are returned.
    _DATA_FILES = ('coor_init', 'q', 'e', 'r', 'basis', 'eigval')

    def __init__(self, popSize=128, dim=12, training=True, test='1ATN_7'):
        super().__init__()
        self.dim = dim
        self.popSize = popSize
        # These must be assigned BEFORE init(): data_loader() reads both.
        # Assigning them afterwards made the loader always see the nn.Module
        # default (training=True) and left self.test undefined during loading.
        self.training = training
        self.test = test
        self.init()

    def init(self) -> None:
        """(Re)load the data selected by ``self.training`` / ``self.test``.

        Keeps the raw NumPy arrays as ``s*`` attributes and stores torch
        copies on ``DEVICE``.  ``self.eigval`` ends up holding
        ``1 / sqrt(eigval)`` of the loaded values.
        """
        (self.scoor_init, self.sq, self.se,
         self.sr, self.sbasis, self.seval) = self.data_loader()
        # Kept for backward compatibility; forward() derives the size of the
        # distance matrix from the data instead of relying on this value.
        self.natoms = 100
        self.dtype = torch.float32
        self.coor_init = torch.from_numpy(self.scoor_init).to(DEVICE)
        self.q = torch.from_numpy(self.sq).to(DEVICE)
        self.e = torch.from_numpy(self.se).to(DEVICE)
        self.r = torch.from_numpy(self.sr).to(DEVICE)
        self.basis = torch.from_numpy(self.sbasis).to(DEVICE)
        self.eigval = 1.0 / torch.sqrt(torch.from_numpy(self.seval).to(DEVICE))
        self.softmax = nn.Softmax(dim=-1)

    def set_training(self, x) -> None:
        """Set the ``training`` flag; does not reload any data by itself."""
        self.training = x

    def set_testfun(self, x='1ATN_7') -> None:
        """Select test complex ``x`` and reload the data via :meth:`init`.

        The reload honours the current ``self.training`` flag, so the new
        name only affects the loaded data when ``training`` is false.
        """
        self.test = x
        self.init()

    def get_len(self) -> int:
        """Return the size of the leading axis of the loaded coordinates."""
        return self.scoor_init.shape[0]

    @staticmethod
    def _pairwise_terms(q, e, r):
        """Expand per-item vectors into pairwise matrices.

        Returns ``(q_i * q_j, sqrt(e_i * e_j), (r_i + r_j) / 2)``.
        """
        # tile(..., (1, 1)) promotes a 1-D vector to a single-row matrix so
        # that row.T @ row is the outer product.
        q_row = np.tile(q, (1, 1))
        e_row = np.tile(e, (1, 1))
        q_pair = np.matmul(q_row.T, q_row)
        e_pair = np.sqrt(np.matmul(e_row.T, e_row))
        r_rows = np.tile(r, (len(r), 1))
        r_pair = (r_rows + r_rows.T) / 2
        return q_pair, e_pair, r_pair

    @classmethod
    def _load_complex(cls, name):
        """Load the six float32 arrays of ``data_protein/<name>/``.

        ``q``, ``e`` and ``r`` are returned as pairwise matrices (see
        :meth:`_pairwise_terms`); the other arrays are returned as loaded.
        """
        folder = f"data_protein/{name}"
        x, q, e, r, basis, eigval = (
            np.float32(np.loadtxt(f"{folder}/{fname}"))
            for fname in cls._DATA_FILES
        )
        q, e, r = cls._pairwise_terms(q, e, r)
        return x, q, e, r, basis, eigval

    def data_loader(self):
        """Read the data files and return six stacked NumPy arrays.

        Returns:
            ``(scoor_init, sq, se, sr, sbasis, seval)``.  In training mode the
            leading axis enumerates every ``<protein>_<k>`` directory; in test
            mode the single test complex is tiled ``self.popSize`` times along
            the leading axis.  The shapes of ``sq``, ``se`` and ``seval`` are
            printed before returning.
        """
        if self.training:
            names = [
                f"{protein}_{k}"
                for protein in self._TRAIN_PROTEINS
                for k in range(1, self._CONFORMATIONS_PER_PROTEIN + 1)
            ]
        else:
            names = [self.test]

        records = [self._load_complex(name) for name in names]
        # Transpose the list of 6-tuples into six per-field stacks.
        scoor_init, sq, se, sr, sbasis, seval = (
            np.array(field) for field in zip(*records)
        )

        if not self.training:
            reps = (self.popSize, 1, 1)
            scoor_init = np.tile(scoor_init, reps)
            sq = np.tile(sq, reps)
            se = np.tile(se, reps)
            sr = np.tile(sr, reps)
            sbasis = np.tile(sbasis, reps)
            seval = np.tile(seval, (self.popSize, 1))

        print(sq.shape, se.shape, seval.shape)
        return scoor_init, sq, se, sr, sbasis, seval

    def forward(self, x, idx):
        """Evaluate the energy of the structures encoded by ``x``.

        Args:
            x: 3-D tensor; its first two axes are swapped before use.
               NOTE(review): the exact expected layout is not established in
               this file -- confirm against the caller.
            idx: Index applied to the leading axis of every loaded tensor.

        Returns:
            2-D tensor produced by ``energy.view(-1, energy.shape[-1])``.
        """
        x = x.permute(1, 0, 2)
        coor_init = self.coor_init[idx]

        # Scale by 1/sqrt(eigval), project through the basis and add the
        # result to the initial coordinates.
        product = torch.squeeze(
            torch.unsqueeze(x * self.eigval[idx], 1) @ self.basis[idx]
        )
        new_coor = torch.reshape(
            product, (x.shape[0], coor_init.shape[-2], coor_init.shape[-1])
        ) + coor_init

        # Pairwise distances via |a|^2 - 2 a.b + |b|^2; the 0.01 offset keeps
        # the argument of the square root strictly positive.
        sq_norm = torch.unsqueeze(torch.sum(new_coor * new_coor, 2), -1)
        gram = new_coor @ new_coor.permute(0, 2, 1)
        sq_dis = sq_norm - 2 * gram + torch.permute(sq_norm, (0, 2, 1)) + 0.01
        pair_dis = torch.pow(sq_dis, 1 / 2)

        # Distance masks, computed before the diagonal is shifted below.
        beyond_min = (pair_dis > 0.1).float()
        below_7 = (pair_dis < 7).float() * beyond_min
        between_7_9 = (pair_dis > 7).float() * (pair_dis < 9).float() * beyond_min

        # Add 1 on the diagonal.  One identity matrix broadcast over the batch
        # replaces the per-sample list of eye(self.natoms) tensors, and its
        # size follows the data instead of a hard-coded atom count.
        identity = torch.eye(
            pair_dis.shape[-1], dtype=pair_dis.dtype, device=pair_dis.device
        )
        pair_dis = pair_dis + identity

        r = self.r[idx]
        ratio = r / pair_dis
        # NOTE(review): self.e already holds sqrt(e_i * e_j) from the loader,
        # so a second square root is taken here -- confirm this is intended.
        coeff = self.q[idx] / (4 * pair_dis) + torch.sqrt(self.e[idx]) * (
            ratio ** 12 - ratio ** 6
        )

        # Full weight below 7; between 7 and 9 the term is scaled by a
        # polynomial that equals 1 at distance 7 and 0 at distance 9.
        taper = (9 - pair_dis) ** 2 * (-12 + 2 * pair_dis) / 8
        c1 = below_7 * coeff * 10 + 10 * between_7_9 * coeff * taper

        res = torch.sum(c1, 1)
        energy = torch.mean(res, -1) - 7000
        return energy.view(-1, energy.shape[-1])
       
  
if __name__ == '__main__':
    # Smoke run: build the module with its defaults (reads the training data
    # from disk) and evaluate one random input against entry 0.
    # NOTE(review): confirm that a (10, 1, 12) input matches what forward()
    # expects for an integer index.
    sample = torch.randn((10, 1, 12)).to(DEVICE)
    model = Protein_dock()
    energy = model(sample, 0)

