import numpy as np

class Protein_dock():
    """Docking energy evaluated on reference coordinates displaced along a basis.

    Construction reads the input arrays via ``data_loader`` and exposes them as
    attributes. ``energy`` displaces the reference coordinates along the loaded
    basis vectors and returns a distance-switched pairwise energy.

    Args:
        popSize: stored as ``self.popSize``; not used inside this class.
        dim: stored as ``self.dim``; not used inside this class.
    """

    def __init__(self, popSize=128, dim=12):
        super().__init__()
        self.dim = dim
        self.popSize = popSize
        (self.scoor_init, self.sq, self.se, self.sr,
         self.sbasis, self.seval) = self.data_loader()
        # Atom count comes from the loaded coordinates so it always matches the
        # data (0 if nothing was loaded).
        self.natoms = (self.scoor_init.shape[-2]
                       if self.scoor_init.ndim >= 2 else 0)
        self.coor_init = self.scoor_init
        self.q = self.sq
        self.e = self.se
        self.r = self.sr
        self.basis = self.sbasis
        # energy() multiplies the search vector elementwise by 1/sqrt(eigval).
        self.eigval = 1.0 / np.sqrt(self.seval)

    def get_len(self):
        """Return the number of loaded entries stacked along axis 0 of ``scoor_init``."""
        return self.scoor_init.shape[0]

    def data_loader(self, protein_list=None, indices=range(7, 8),
                    root="data_protein"):
        """Load per-complex arrays from text files and stack them.

        For every ``name`` in ``protein_list`` and ``j`` in ``indices``, the
        files ``coor_init``, ``q``, ``e``, ``r``, ``basis`` and ``eigval`` are
        read from ``<root>/<name>_<j>/`` with ``np.loadtxt`` and cast to
        float32. The vectors ``q``, ``e`` and ``r`` are expanded into pairwise
        matrices ``q_i * q_j``, ``sqrt(e_i * e_j)`` and ``(r_i + r_j) / 2``.
        This assumes each of those files holds a single 1-D column/row.

        Args:
            protein_list: names to load; defaults to ``['1ATN']``.
            indices: suffixes ``j`` of the ``<name>_<j>`` directories to load.
            root: directory containing the ``<name>_<j>`` folders.

        Returns:
            Tuple ``(scoor_init, sq, se, sr, sbasis, seval)``. Each element
            stacks one entry per loaded directory along axis 0.
        """
        if protein_list is None:
            protein_list = [
                # '1ATN','1AVX','1AY7','1BJ1','1BVN','1CGI','1DFJ','1EAW',
                # '1EWY','1EZU','1GRN','1IBR','1IJK','1IQD','1JPS','1KXQ','1M10',
                # '1MAH','1N8O','1PPE','1R0R','1XQS','2B42','2C0L','2HRK'
                '1ATN',
                # '2JEL',
                # '7CEI'
            ]

        def load(folder, field):
            return np.float32(np.loadtxt(f"{folder}/{field}"))

        scoor_init, sq, se, sr, sbasis, seval = [], [], [], [], [], []
        for name in protein_list:
            for j in indices:
                folder = f"{root}/{name}_{j}"
                x = load(folder, "coor_init")
                q = load(folder, "q")
                e = load(folder, "e")
                r = load(folder, "r")
                basis = load(folder, "basis")
                eigval = load(folder, "eigval")
                # np.tile(v, (1, 1)) turns a 1-D vector into a (1, N) row, so
                # row.T @ row is the (N, N) outer product.
                q = np.tile(q, (1, 1))
                e = np.tile(e, (1, 1))
                r_rows = np.tile(r, (len(r), 1))
                scoor_init.append(x)
                sq.append(q.T @ q)
                se.append(np.sqrt(e.T @ e))
                sr.append((r_rows + r_rows.T) / 2)
                sbasis.append(basis)
                seval.append(eigval)

        scoor_init, sq, se, sr, sbasis, seval = (
            np.array(a) for a in (scoor_init, sq, se, sr, sbasis, seval))
        print(sq.shape, se.shape, seval.shape)
        return scoor_init, sq, se, sr, sbasis, seval

    def energy(self, x):
        """Return the energy of the first configuration in a batch.

        Args:
            x: array whose leading axis indexes the batch. Each entry is scaled
                elementwise by ``self.eigval`` and projected through
                ``self.basis``. The original note gives its shape as (1*1*12).

        Returns:
            Energy of ``x[0]``: the pair terms are summed over axis 1, averaged
            over the last axis, and offset by -7000.
        """
        x = np.asarray(x)
        batch = x.shape[0]
        n_atoms = self.coor_init.shape[-2]
        n_dims = self.coor_init.shape[-1]

        # Displace the reference coordinates along the scaled basis vectors.
        displacement = np.expand_dims(x * self.eigval, 1) @ self.basis
        new_coor = (np.reshape(displacement, (batch, n_atoms, n_dims))
                    + self.coor_init)

        # Pairwise distances from |a|^2 - 2 a.b + |b|^2, padded by 0.01 so the
        # diagonal sits at about sqrt(0.01) = 0.1 instead of 0.
        sq_norm = np.expand_dims(np.sum(new_coor * new_coor, 2), -1)
        gram = new_coor @ new_coor.transpose(0, 2, 1)
        pair_dis = np.sqrt(sq_norm - 2 * gram + sq_norm.transpose(0, 2, 1) + 0.01)

        # 0/1 float32 masks. c0 drops self-pairs. Rounding in the expansion
        # above can leave a diagonal entry slightly above 0.1, so the
        # diagonal is removed explicitly rather than relying on "> 0.1" alone.
        off_diag = 1 - np.eye(n_atoms, dtype=np.float32)
        c0 = np.float32(pair_dis > 0.1) * off_diag
        c7_small = np.float32(pair_dis < 7) * c0
        # ">= 7" so that a pair at exactly 7 is not dropped by both masks.
        # The switching factor below equals 1 at 7, matching the "< 7" branch.
        c79 = np.float32(pair_dis >= 7) * np.float32(pair_dis < 9) * c0

        # Shift the diagonal away from ~0.1 before dividing by distances. The
        # masks above already zero those entries.
        pair_dis = pair_dis + np.eye(n_atoms, dtype=pair_dis.dtype)

        coeff = (self.q / (4 * pair_dis)
                 + np.sqrt(self.e)
                 * ((self.r / pair_dis) ** 12 - (self.r / pair_dis) ** 6))
        # Switching factor on [7, 9): 1 at distance 7, 0 at distance 9.
        switch = (9 - pair_dis) ** 2 * (-12 + 2 * pair_dis) / 8
        pair_energy = (c7_small * coeff) * 10 + 10 * c79 * coeff * switch
        batch_energy = np.mean(np.sum(pair_energy, 1), -1) - 7000

        # Debug dumps read by test() alongside the tf_*.npy files:
        # np.save('np_pd.npy', pair_dis)
        # np.save('np_c7_small.npy', np.float32(pair_dis < 7))
        # np.save('np_c0.npy', c0)
        # np.save('np_c9.npy', np.float32(pair_dis < 9))
        # np.save('np_coeff.npy', coeff)
        return batch_energy[0]
       
       
def test():
    """Print the L1 difference between paired NumPy and TF debug dumps.

    Loads ``./np_<stem>.npy`` for every stem first, then ``./tf_<stem>.npy``.
    Afterwards it prints, per stem, a label and ``sum(|np - tf|)``.
    """
    label_and_stem = (
        ('l1_c7small', 'c7_small'),
        ('l1_pd', 'pd'),
        ('l1_c0', 'c0'),
        ('l1_c9', 'c9'),
        ('l1_ceff', 'coeff'),
    )
    numpy_dumps = [np.load(f'./np_{stem}.npy') for _, stem in label_and_stem]
    tf_dumps = [np.load(f'./tf_{stem}.npy') for _, stem in label_and_stem]

    for (label, _), ours, theirs in zip(label_and_stem, numpy_dumps, tf_dumps):
        print(label, np.sum(np.abs(ours - theirs)))

if __name__ == '__main__':
    # Score the saved test vector, then print the NumPy-vs-TF comparison of
    # the dumped intermediate arrays.
    dock = Protein_dock()
    test_vector = np.load('./testvector.npy')
    print(dock.energy(test_vector))

    test()
    