from utils import merge_chains,get_contact,relative_sasa
from Bio.PDB import PDBParser
import gzip
import torch.nn.functional as F
import torch
import numpy as np
from Bio.PDB import Chain,Model,Structure
from Bio.PDB.PDBIO import PDBIO
import os
import random
from tqdm import tqdm
import multiprocessing as mp
from functools import partial
from Bio.PDB.Residue import Residue
cpu_num = 32
# Ask the numeric back-ends to use at most cpu_num threads.
# NOTE(review): OpenMP / OpenBLAS / MKL / vecLib / numexpr usually read these
# variables when the library is first loaded. numpy and torch are imported
# above, so these settings may not affect them in this process; consider
# moving this before those imports — confirm. torch.set_num_threads below
# sets torch's intra-op thread count directly.
os.environ.update(dict.fromkeys(
    (
        'OMP_NUM_THREADS',
        'OPENBLAS_NUM_THREADS',
        'MKL_NUM_THREADS',
        'VECLIB_MAXIMUM_THREADS',
        'NUMEXPR_NUM_THREADS',
    ),
    str(cpu_num),
))
torch.set_num_threads(cpu_num)
def normpdf(x, mu, sig):
    """Evaluate the normal probability density N(mu, sig**2) at ``x``.

    0.398942 is 1/sqrt(2*pi) rounded to six decimals. ``x`` may be a scalar
    or a numpy array (evaluated element-wise via ``np.exp``).
    """
    scale = 0.398942 / sig
    exponent = -0.5 * (x - mu) ** 2 / sig ** 2
    return scale * np.exp(exponent)

def tmp_sample_func(rbsa):
    """Randomly accept or reject a sample based on its relative BSA value.

    Returns True with probability ``min(1, ratio)``, where
    ``ratio = 1.1 * N(rbsa; 0.875, 0.175) / N(rbsa; 0.9, 0.11)`` evaluated with
    ``normpdf``. One ``random.random()`` draw is consumed per call.
    NOTE(review): presumably reshapes an observed rbsa distribution of roughly
    N(0.9, 0.11) toward N(0.875, 0.175) — confirm against the data.
    """
    p_num = normpdf(rbsa, 0.875, 0.175)
    p_den = normpdf(rbsa, 0.9, 0.11)
    acceptance = p_num / p_den * 1.1
    return bool(random.random() <= acceptance)

def mp_process(filelist, func, n_cpu=32):
    """Apply ``func`` to every item of ``filelist`` in a process pool, with a progress bar.

    Args:
        filelist: Sized collection of items; each one is passed as the single
            positional argument to ``func``.
        func: Callable executed in worker processes. It must be picklable
            (e.g. a module-level function or a ``functools.partial`` of one).
        n_cpu: Number of worker processes.

    Returns:
        A list with ``func(item)`` for every item, in ``filelist`` order.

    Raises:
        The exception raised by ``func`` for the first failing item (in
        ``filelist`` order), re-raised by ``AsyncResult.get``.
    """
    # Both context managers guarantee cleanup even if submission or result
    # collection raises: the bar is closed and the pool is terminated.
    with tqdm(total=len(filelist)) as tbar, mp.Pool(n_cpu) as pool:
        def _advance(_):
            # Invoked in the parent's result-handler thread for successes AND
            # failures, so the bar always reaches its total.
            tbar.update(1)

        pending = [
            pool.apply_async(
                func=func,
                args=(f,),
                callback=_advance,
                error_callback=_advance,
            )
            for f in filelist
        ]
        pool.close()
        pool.join()
        return [r.get() for r in pending]

def save_pdb(biopy_chain, ligres, recpres, filename, bsa=False, samplefunc=None):
    """Build a two-chain complex for one fragment and optionally write it as PDB.

    Chain ``L`` holds three parts:
    - a cap residue with the C, CA and O atoms of residue ``ligres[0]-1``;
    - the fragment residues ``ligres``;
    - a cap residue with the N atom of residue ``ligres[-1]+1``.

    Chain ``R`` holds every residue of ``recpres`` that lies outside
    ``[ligres[0]-1, ligres[-1]+1]``.

    Copies of residues/atoms are used, so ``biopy_chain`` is left untouched.

    Args:
        biopy_chain: Biopython chain whose residues are looked up by integer
            residue number.
        ligres: Increasing, consecutive residue numbers of the fragment.
        recpres: Residue numbers of candidate receptor residues.
        filename: Output PDB path.
        bsa: If true, compute ``relative_sasa(R_chain, L_chain)``.
        samplefunc: Optional predicate called with ``relbsa[0]`` (only when
            ``bsa`` is true); the file is written only if it returns truthy.

    Returns:
        ``relbsa`` (the ``relative_sasa`` result, or ``[1, 0]`` when ``bsa``
        is false) if the file was written, otherwise ``[0, 0]``.

    Raises:
        KeyError: if a flanking residue or a required backbone atom is absent
            from ``biopy_chain``.
    """
    # Entity.add() re-parents whatever it receives. Adding the caller's own
    # residues/atoms to these throw-away chains would rewire parent pointers
    # (and full_ids) inside biopy_chain, which callers reuse for many
    # fragments. Detached copies keep the input structure intact.
    first_flank = ligres[0] - 1
    last_flank = ligres[-1] + 1

    L_chain = Chain.Chain('L')

    prev_res = biopy_chain[first_flank]
    N_res = Residue(prev_res.id, prev_res.resname, prev_res.segid)
    for atom_name in ('C', 'CA', 'O'):
        N_res.add(prev_res[atom_name].copy())
    L_chain.add(N_res)

    for i in ligres:
        L_chain.add(biopy_chain[i].copy())

    next_res = biopy_chain[last_flank]
    C_res = Residue(next_res.id, next_res.resname, next_res.segid)
    C_res.add(next_res['N'].copy())
    L_chain.add(C_res)

    R_chain = Chain.Chain('R')
    for i in recpres:
        # Skip the fragment and its flanking residues.
        if i < first_flank or i > last_flank:
            R_chain.add(biopy_chain[i].copy())

    if bsa:
        relbsa = relative_sasa(R_chain, L_chain)
    else:
        relbsa = [1, 0]

    if (samplefunc is None) or (not bsa) or samplefunc(relbsa[0]):
        tmp_structure = Structure.Structure('tmp')
        tmp_model = Model.Model(0)
        tmp_structure.add(tmp_model)
        tmp_model.add(L_chain)
        tmp_model.add(R_chain)
        io = PDBIO()
        io.set_structure(tmp_structure)
        io.save(filename)
        return relbsa
    return [0, 0]

def _kmer_contact_counts(contact_map, breakpoints):
    """Return per-window contact counts and contact masks for windows of 1..7 residues.

    Entry k of each returned list describes windows of k+1 consecutive
    residues.
    - ``counts[k][p]`` is the number of residues in contact with the window
      starting at residue p.
    - ``masks[k][p]`` is the boolean contact mask of that window over all
      residues.

    For windows of 2..7 residues, counts at start positions
    ``[b - size, b]`` are zeroed for each ``b`` in ``breakpoints``.

    Assumes ``contact_map`` is an (N, 1, N) float tensor.
    ``breakpoints`` is whatever ``merge_chains`` returns; it presumably marks
    the junctions between merged chains — confirm.
    """
    counts = [(contact_map > 0).sum(dim=0).squeeze_()]
    masks = [(contact_map > 0).squeeze_()]
    for size in range(2, 8):
        kernel = torch.ones((1, 1, size))
        # (N, N-size+1): residue r contacts the window starting at p.
        window_contacts = (F.conv1d(contact_map, kernel) > 0).squeeze_()
        masks.append(window_contacts.transpose(0, 1))
        window_counts = window_contacts.sum(dim=0)
        for b in breakpoints:
            window_counts[max(0, b - size):b + 1] = 0
        counts.append(window_counts)
    return counts, masks


def process_one_pdb(pdb_file, home, tgt, sample_rate_file='sample_rate_clip22.npy'):
    """Sample contact-based fragments from one gzipped PDB file and save them.

    Steps:
    1. Parse ``home/pdb_file`` (gzip text).
    2. Merge its chains with ``merge_chains`` and compute contacts with
       ``get_contact``.
    3. For every window of 1..7 consecutive residues with 1..63 contacting
       residues, keep it with probability
       ``sample_rate[k, n_contacts - 1]``, where ``k`` is the window
       length minus one.
    4. Write each kept window with ``save_pdb`` (``bsa=True``,
       ``samplefunc=tmp_sample_func``) to ``tgt/<stem>_<idx>.pdb``.

    Args:
        pdb_file: File name relative to ``home``. A trailing ``.pdb1.gz``
            is stripped to form the output stem; otherwise the full name is
            used as the stem.
        home: Input directory.
        tgt: Output directory; created if missing.
        sample_rate_file: Path of the ``.npy`` sampling-rate table.

    Returns:
        The list of ``save_pdb`` return values for the attempted fragments.
        Fragments whose ``save_pdb`` call raised are skipped. If parsing or
        preprocessing raises, the values collected so far are returned.
    """
    sample_rate = np.load(sample_rate_file)
    # Without the directory every PDBIO.save would fail and be silently skipped.
    os.makedirs(tgt, exist_ok=True)
    input_file = os.path.join(home, pdb_file)
    suffix = '.pdb1.gz'
    # Always distinct per fragment, even for names without the usual suffix.
    stem = pdb_file[:-len(suffix)] if pdb_file.endswith(suffix) else pdb_file
    all_rel_bsa = []
    try:
        parser = PDBParser(PERMISSIVE=1)
        with gzip.open(input_file, 'rt') as handle:
            structure = parser.get_structure('pdb_file', handle)
        preprocessed, breakpoints = merge_chains(structure)
        contact_map = torch.tensor(get_contact(preprocessed)).unsqueeze_(1).float()
        kmer_nn_results, cmap = _kmer_contact_counts(contact_map, breakpoints)
        chain = preprocessed[0]['A']
        frag_idx = 0
        for k, (kmer_nn, contact) in enumerate(zip(kmer_nn_results, cmap)):
            for p, (r_num, c) in enumerate(zip(kmer_nn, contact)):
                # One random draw per window, whether or not it qualifies.
                rand_num = random.random()
                n_contacts = int(r_num)
                if not (1 <= n_contacts <= 63 and rand_num < sample_rate[k, n_contacts - 1]):
                    continue
                frag_idx += 1
                try:
                    all_rel_bsa.append(save_pdb(
                        biopy_chain=chain,
                        ligres=list(range(p, p + k + 1)),
                        recpres=torch.where(c)[0].tolist(),
                        filename=os.path.join(tgt, f'{stem}_{frag_idx}.pdb'),
                        bsa=True,
                        samplefunc=tmp_sample_func,
                    ))
                except Exception:
                    # Best-effort per fragment, e.g. save_pdb raises KeyError
                    # when a flanking residue is missing.
                    continue
    except Exception:
        # Best-effort per file: unreadable or unprocessable inputs are skipped.
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
        pass
    return all_rel_bsa
if __name__ == '__main__':
    # Placeholder arguments: replace with a real input directory, a gzipped
    # PDB file inside it, and the directory that receives the fragment files.
    process_one_pdb(
        pdb_file='pdb1.gz_filename',
        home='pdb_home',
        tgt='output_dir',
    )


    
