import os
from Bio.PDB import PDBParser,Chain,Model,Structure
from Bio.PDB.PDBIO import PDBIO
from Bio.PDB import is_aa
from Bio.PDB.Residue import DisorderedResidue,Residue
from Bio.PDB.Atom import DisorderedAtom
import warnings
from Bio.PDB.StructureBuilder import PDBConstructionWarning
import numpy as np
from rdkit.Chem import Descriptors
import copy
import numpy as np
#from biopandas.mol2 import PandasMol2
from rdkit import Chem
import rdkit.Chem.AllChem as AllChem
from tqdm import tqdm
from io import StringIO
from rdkit.Chem.rdmolfiles import MolToPDBBlock
from Bio.PDB.NeighborSearch import NeighborSearch
import multiprocessing as mp
from freesasa import calcBioPDB
import freesasa
from Bio.PDB.Polypeptide import protein_letters_3to1
# Ignore every PDBConstructionWarning; the filter is installed as a side
# effect of importing this module.
warnings.filterwarnings(
    action='ignore',
    category=PDBConstructionWarning)


# One-letter amino-acid code -> SMILES fragment for that residue.
# Every fragment begins at the backbone nitrogen and ends with the carbonyl
# "C(=O)", so AA2SMILES can build a peptide by plain string concatenation.
aaSMILES = {'G':  'NCC(=O)',
            'A':  'N[C@@]([H])(C)C(=O)',
            'R':  'N[C@@]([H])(CCCNC(=N)N)C(=O)',
            'N':  'N[C@@]([H])(CC(=O)N)C(=O)',
            'D':  'N[C@@]([H])(CC(=O)O)C(=O)',
            'C':  'N[C@@]([H])(CS)C(=O)',
            'E':  'N[C@@]([H])(CCC(=O)O)C(=O)',
            'Q':  'N[C@@]([H])(CCC(=O)N)C(=O)',
            'H':  'N[C@@]([H])(CC1=CN=C-N1)C(=O)',
            'I':  'N[C@@]([H])(C(CC)C)C(=O)',
            'L':  'N[C@@]([H])(CC(C)C)C(=O)',
            'K':  'N[C@@]([H])(CCCCN)C(=O)',
            'M':  'N[C@@]([H])(CCSC)C(=O)',
            'F':  'N[C@@]([H])(Cc1ccccc1)C(=O)',
            'P':  'N1[C@@]([H])(CCC1)C(=O)',
            'S':  'N[C@@]([H])(CO)C(=O)',
            'T':  'N[C@@]([H])(C(O)C)C(=O)',
            'W':  'N[C@@]([H])(CC(=CN2)C1=C2C=CC=C1)C(=O)',
            'Y':  'N[C@@]([H])(Cc1ccc(O)cc1)C(=O)',
            'V':  'N[C@@]([H])(C(C)C)C(=O)'}

def AA2SMILES(aalist,C_amidation=True,N_actylation=True,three_letters=True):
    """Assemble a peptide SMILES string from a sequence of residues.

    Args:
        aalist: iterable of residue names; three-letter codes when
            ``three_letters`` is true, otherwise keys of ``aaSMILES``.
        C_amidation: when true the string is terminated with 'N',
            otherwise with 'O'.
        N_actylation: when true the string is prefixed with 'CC(=O)'.
        three_letters: translate each entry through
            ``protein_letters_3to1`` before the ``aaSMILES`` lookup.

    Returns:
        The concatenated SMILES string.

    Raises:
        KeyError: if a residue name is missing from the lookup tables.
    """
    codes = [protein_letters_3to1[name] for name in aalist] if three_letters else aalist
    n_terminus = 'CC(=O)' if N_actylation else ''
    c_terminus = 'N' if C_amidation else 'O'
    residues = ''.join(aaSMILES[code] for code in codes)
    return n_terminus + residues + c_terminus

class arbitraryClassifier(freesasa.Classifier):
    """freesasa classifier with a per-element fallback radius table.

    Classification and radius lookup are delegated to a default
    ``freesasa.Classifier``. Only when that classifier reports a negative
    radius is the atom looked up in ``atom_radius`` instead.
    """

    # this must be set explicitly in all derived classifiers
    purePython = True
    def __init__(self):
        super(arbitraryClassifier, self).__init__()
        # Fallback radii keyed by the first two characters of the atom name;
        # one-letter elements carry a leading space (" C", " N", ...).
        # NOTE(review): values look like van der Waals radii in Angstrom - confirm the source.
        self.atom_radius = {" H":1.10," C":1.70," N":1.55," O":1.52," P":1.80," S":1.80,"SE":1.90," F":1.47,"CL":1.75,"BR":1.83," I":1.98,"LI":1.81,"BE":1.53,"NA":2.27,"MG":1.73," K":2.75,"CA":2.31,"RB":3.03,"SR":2.49,"CS":3.43,"BA":2.68,"FR":3.48,"RA":2.83,"SC":2.11,"TI":1.95," V":1.06,"CR":1.13,"MN":1.19,"FE":1.26,"CO":1.13,"NI":1.63,"CU":1.40,"ZN":1.39," Y":1.61,"ZR":1.42,"NB":1.33,"MO":1.75,"TC":2.00,"RU":1.20,"RH":1.22,"PD":1.63,"AG":1.72,"CD":1.58,"HF":1.40,"TA":1.22," W":1.26,"RE":1.30,"OS":1.58,"IR":1.22,"PT":1.75,"AU":1.66,"HG":1.55,"AL":1.84,"GA":1.87,"IN":1.93,"SN":2.17,"TL":1.96,"PB":2.02,"BI":2.07,"PO":1.97," B":1.92,"SI":2.10,"GE":2.11,"AS":1.85,"SB":2.06,"TE":2.06,"AT":2.02,"HE":1.40,"NE":1.54,"AR":1.88,"KR":2.02,"XE":2.16,"RN":2.20,"LA":1.83,"CE":1.86,"PR":1.62,"ND":1.79,"PM":1.76,"SM":1.74,"EU":1.96,"GD":1.69,"TB":1.66,"DY":1.63,"HO":1.61,"ER":1.59,"TM":1.57,"YB":1.54,"LU":1.53,"AC":2.12,"TH":1.84,"PA":1.60," U":1.86,"NP":1.71,"PU":1.67,"AM":1.66,"CM":1.65,"BK":1.64,"CF":1.63,"ES":1.62,"FM":1.61,"MD":1.60,"NO":1.59,"LR":1.58}
        self.default_classifier = freesasa.Classifier()

    def classify(self, residueName, atomName):
        """Return the class assigned by the default freesasa classifier."""
        return self.default_classifier.classify(residueName,atomName)

    def radius(self, residueName, atomName):
        """Return the atom radius, falling back to ``atom_radius``.

        If the default classifier returns a negative radius, the first two
        characters of ``atomName`` are looked up in ``atom_radius``. When
        that lookup also misses, the negative value is returned unchanged.
        """
        radius = self.default_classifier.radius(residueName,atomName)
        if radius < 0:
            # Explicit miss handling instead of a bare ``except: pass``,
            # so unrelated errors are no longer silently swallowed.
            radius = self.atom_radius.get(atomName[:2], radius)
        return radius

def relative_sasa(recp_chain,lig_chain):
    """Compute the ligand surface area buried by the receptor.

    Both chains are copied (the inputs are not re-labelled) and renamed to
    'R' (receptor) and 'L' (ligand). The SASA of chain 'L' is computed
    first on its own and then with the receptor present in the same model.

    Args:
        recp_chain: Biopython chain of the receptor.
        lig_chain: Biopython chain of the ligand.

    Returns:
        ``(rel_bsa, abs_bsa)`` where ``abs_bsa`` is unbound minus bound
        ligand SASA and ``rel_bsa`` is ``abs_bsa`` divided by the unbound
        SASA. ``rel_bsa`` is 0.0 when the unbound SASA is not positive
        (previously a ZeroDivisionError).
    """
    recp_chain = recp_chain.copy()
    recp_chain.id = 'R'
    lig_chain = lig_chain.copy()
    lig_chain.id = 'L'
    tmp_structure = Structure.Structure('tmp')
    tmp_model = Model.Model(0)
    tmp_structure.add(tmp_model)
    tmp_model.add(lig_chain)
    classifier = arbitraryClassifier()

    def _ligand_sasa():
        # Total SASA over all residues of chain 'L' in the current tmp_structure.
        result = calcBioPDB(tmp_structure,options={'hetatm':True},classifier=classifier)[0]
        return sum(area.total for area in result.residueAreas()['L'].values())

    unbounded_SASA = _ligand_sasa()
    # Adding the receptor to the same model lets it occlude the ligand surface.
    tmp_model.add(recp_chain)
    bounded_SASA = _ligand_sasa()
    abs_bsa = unbounded_SASA - bounded_SASA
    rel_bsa = abs_bsa/unbounded_SASA if unbounded_SASA > 0 else 0.0
    return rel_bsa,abs_bsa

def res_is_connected(residue1,residue2):
    """Tell whether two residues look sequentially bonded.

    The difference between the two 'CA' entries (for Biopython atoms, the
    inter-atomic distance) must lie within 0.2 of 3.8.

    Returns:
        1 if the criterion holds, otherwise 0.

    Raises:
        KeyError: if either residue has no 'CA' entry.
    """
    ca_gap = residue1['CA'] - residue2['CA']
    deviation = abs(ca_gap - 3.8)
    return 1 if deviation <= 0.2 else 0

def merge_chains(pdbstruct,af2=False,tgt=None):
    """Merge the standard amino-acid residues of model 0 into one chain 'A'.

    Residues are visited chain by chain and renumbered consecutively from
    0; the hetero flag and insertion code of each id are kept. Disordered
    residues/atoms are reduced to their currently selected child.

    Side effects on ``pdbstruct``: every visited residue is detached from
    its parent chain, kept residues get their id rewritten, and
    non-disordered atoms of partially disordered residues are re-parented.
    Do not rely on the input structure afterwards.

    Args:
        pdbstruct: Biopython structure; only ``pdbstruct[0]`` is read.
        af2: when true, residues whose CA ``bfactor`` is below 50 are
            dropped (presumably the pLDDT of AlphaFold2 models - confirm).
        tgt: optional directory; the merged structure is written to
            ``<tgt>/<pdbstruct.id>.pdb``.

    Returns:
        ``(structure, break_point)``. ``break_point`` lists the new residue
        index at the start of every input chain, plus every index whose
        residue fails ``res_is_connected`` against the previously kept one.
        The list can contain duplicates.
    """
    break_point = []
    model = pdbstruct[0]
    tmp_chain = Chain.Chain('A')
    rid = 0
    last_res = None
    for chain in model:
        break_point.append(rid)
        for res in chain:
            try:
                # Detach first: assigning an id on an attached residue makes
                # Biopython check for id clashes inside the parent chain.
                res.detach_parent()
                if not is_aa(res,standard=True):
                    continue
                if af2 and res['CA'].bfactor<50:
                    continue
                if res.is_disordered():
                    if isinstance(res,DisorderedResidue):
                        res = res.selected_child
                    else:
                        # Ordinary residue with disordered atoms: rebuild it
                        # from the selected alternative of each atom.
                        new_res = Residue(res.id,res.resname,res.segid)
                        for atom in res:
                            if isinstance(atom,DisorderedAtom):
                                atom.selected_child.disordered_flag = 0
                                new_res.add(atom.selected_child.copy())
                            else:
                                new_res.add(atom)
                        res = new_res
                res.id = (res.id[0],rid,res.id[2])

                if last_res is not None and not res_is_connected(res,last_res):
                    break_point.append(rid)
                # Only the CA atom is read later, so a reference is enough;
                # no deep copy is needed.
                last_res = res
                tmp_chain.add(res.copy())
                rid += 1
            except Exception:
                # Best effort: a residue that cannot be processed (e.g. no
                # CA atom) is skipped. KeyboardInterrupt/SystemExit are no
                # longer swallowed as they were by the bare ``except``.
                continue
    tmp_structure = Structure.Structure(pdbstruct.id)
    tmp_model = Model.Model(0)
    tmp_structure.add(tmp_model)
    tmp_model.add(tmp_chain)
    if tgt is not None:
        io = PDBIO()
        io.set_structure(tmp_structure)
        io.save(os.path.join(tgt,pdbstruct.id+'.pdb'))
    return tmp_structure,break_point


def read_mol2_ligand(path,return_pdbstring = False):
    """Read a ligand from a mol2 file with RDKit.

    Args:
        path: path of the mol2 file.
        return_pdbstring: when true, also return a PDB block of the
            molecule written with ``flavor=2``.

    Returns:
        ``(coords, mw)`` or ``(coords, mw, pdb_string)``. ``coords`` is a
        numpy array of the conformer positions after ``Chem.RemoveHs``;
        ``mw`` is ``Descriptors.MolWt`` of the parsed molecule.

    Raises:
        ValueError: if RDKit cannot parse the file. Previously the ``None``
            result caused an obscure failure further down.
    """
    mol = Chem.MolFromMol2File(path)
    if mol is None:
        raise ValueError("cannot read mol2 " + str(path))
    mw = Descriptors.MolWt(mol)
    mol_noH = Chem.RemoveHs(mol)
    coords = mol_noH.GetConformer().GetPositions()
    if return_pdbstring:
        pdb_string = MolToPDBBlock(mol,flavor=2)
        return np.array(coords),mw,pdb_string
    else:
        return np.array(coords),mw

def get_binding_pockets(original_pdb,lig_coord,output=None,cutoff=6.0,chain_id='A'):
    """Extract the residues of one chain that lie close to a ligand.

    A residue is kept when the smallest distance between any of its atoms
    and any ligand coordinate is at most ``cutoff``.

    Args:
        original_pdb: Biopython structure; only model 0 is read.
        lig_coord: array-like of ligand coordinates, one row per atom
            (assumed shape (n_atoms, 3) - the code indexes it as 2-D).
        output: optional directory; the pocket is written to
            ``<output>/<original_pdb.id>.pdb``.
        cutoff: distance threshold, in the coordinate units (default 6.0,
            the previously hard-coded value).
        chain_id: id of the chain to scan and of the returned chain
            (default 'A', the previously hard-coded value).

    Returns:
        A new structure with a single model 0 holding copies of the
        selected residues.
    """
    lig_coord = np.asarray(lig_coord)
    chain = original_pdb[0][chain_id]
    tmp_chain = Chain.Chain(chain_id)
    for res in chain:
        res_coord = np.array([atom.get_coord() for atom in res.get_atoms()])
        if res_coord.size == 0:
            # A residue with no atoms cannot be near anything; skipping it
            # also avoids an IndexError in the broadcasting below.
            continue
        # Pairwise residue-atom x ligand-atom distances, reduced to the minimum.
        dist = np.linalg.norm(res_coord[:,None,:]-lig_coord[None,:,:],axis=-1).min()
        if dist<=cutoff:
            tmp_chain.add(res.copy())
    tmp_structure = Structure.Structure(original_pdb.id)
    tmp_model = Model.Model(0)
    tmp_structure.add(tmp_model)
    tmp_model.add(tmp_chain)
    if output is not None:
        io = PDBIO()
        io.set_structure(tmp_structure)
        io.save(os.path.join(output,original_pdb.id+'.pdb'))
    return tmp_structure


def get_contact(biopy_struc,radius=6.0,NCmask=5):
    """Build a symmetric residue-residue contact map.

    Residue pairs come from ``NeighborSearch.search_all`` at residue level.
    The sequence number of each residue id (``id[1]``) is used directly as
    a matrix index, so residues must be numbered 0..n-1, as
    ``merge_chains`` numbers them.

    Args:
        biopy_struc: Biopython structure; the map size is the length of
            chain 'A' in model 0.
        radius: search radius passed to ``search_all``.
        NCmask: contacts between residues whose indices differ by
            ``NCmask`` or less are removed.

    Returns:
        ``int8`` numpy array of shape (n, n) with 1 for a contact, else 0.
    """
    searcher = NeighborSearch(list(biopy_struc.get_atoms()))
    contact_pairs = searcher.search_all(radius=radius,level='R')
    n_res = len(biopy_struc[0]['A'])
    contact_map = np.zeros([n_res,n_res],dtype=np.int8)
    for res_a,res_b in contact_pairs:
        i,j = res_a.id[1],res_b.id[1]
        # Always write into the upper triangle: a pair reported as
        # (high, low) used to land below the diagonal and was then
        # discarded by np.triu.
        contact_map[min(i,j),max(i,j)] = 1
    contact_map = np.triu(contact_map,NCmask+1)
    contact_map = contact_map+contact_map.transpose()
    return contact_map


def mp_process(home,filelist,func,n_cpu = 32):
    """Apply ``func`` to every file in parallel, with a progress bar.

    Args:
        home: directory joined in front of every entry of ``filelist``.
        filelist: sequence of file names (``len`` is taken for the bar).
        func: callable taking one path; it must be usable with
            ``multiprocessing`` (i.e. picklable).
        n_cpu: number of worker processes.

    Returns:
        List of results in the order of ``filelist``.

    Raises:
        Whatever ``func`` raised in a worker, re-raised when the results
        are collected.
    """
    tbar = tqdm(total=len(filelist))

    def callback(_outcome):
        # Used for both success and failure so the bar still reaches 100%
        # when some tasks raise.
        tbar.update(1)

    try:
        # The context manager terminates the pool even if submission fails,
        # so worker processes are never leaked.
        with mp.Pool(n_cpu) as pool:
            result_list = [
                pool.apply_async(func=func, args=(os.path.join(home,f),),
                                 callback=callback, error_callback=callback)
                for f in filelist
            ]
            pool.close()
            pool.join()
    finally:
        tbar.close()
    return [r.get() for r in result_list]

def gen_conformation(smiles, num_conf=1, num_worker=1):
    """Generate 3-D conformers for a SMILES string with RDKit.

    Hydrogens are added, conformers are embedded and MMFF-optimised, and
    hydrogens are removed again.

    Args:
        smiles: SMILES string; strings longer than 250 characters are
            rejected.
        num_conf: number of conformers requested from the embedding.
        num_worker: thread count passed to embedding and optimisation.

    Returns:
        The RDKit molecule with its conformers. ``None`` (after printing a
        message) if the SMILES is too long, cannot be parsed, RDKit raises,
        or no conformer is produced.
    """
    if len(smiles)>250:
        print("exceede max smiles lens", smiles)
        return None
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparsable SMILES: report it explicitly rather than relying on
        # AddHs(None) raising.
        print("cannot gen conf", smiles)
        return None
    try:
        mol = Chem.AddHs(mol)
        AllChem.EmbedMultipleConfs(mol, numConfs=num_conf, numThreads=num_worker, pruneRmsThresh=1, maxAttempts=1000, useRandomCoords=False)
        AllChem.MMFFOptimizeMoleculeConfs(mol, numThreads=num_worker)
        mol = Chem.RemoveHs(mol)
    except Exception:
        # ``Exception`` rather than a bare except, so Ctrl-C can still
        # interrupt a long embedding run.
        print("cannot gen conf", smiles)
        return None
    if mol.GetNumConformers() == 0:
        print("cannot gen conf", smiles)
        return None
    return mol

if __name__ == '__main__':
    # Smoke test: parse the file named 'pdbfile', merge its chains and build
    # a contact map. Names avoid shadowing the builtin ``breakpoint``.
    parser = PDBParser()
    structure = parser.get_structure('0', 'pdbfile')
    merged_structure,break_points = merge_chains(structure)
    get_contact(merged_structure)
