import os
from tqdm import tqdm
import numpy as np

from deli import save_json, load_json

from Bio.PDB import PDBParser
import prody
from prody import confProDy
confProDy(verbosity='none')

from flowdock.utils.preprocessing import extract_protein

# Three-letter residue name -> one-letter amino-acid code.
three_to_one = {
    'ALA': 'A',
    'ARG': 'R',
    'ASN': 'N',
    'ASP': 'D',
    'CYS': 'C',
    'GLN': 'Q',
    'GLU': 'E',
    'GLY': 'G',
    'HIS': 'H',
    'ILE': 'I',
    'LEU': 'L',
    'LYS': 'K',
    'MET': 'M',
    # MSE (selenomethionine) is MET with the sulfur replaced by selenium,
    # so it is mapped to the same one-letter code.
    'MSE': 'M',
    'PHE': 'F',
    'PRO': 'P',
    'SER': 'S',
    'THR': 'T',
    'TRP': 'W',
    'TYR': 'Y',
    'VAL': 'V',
}

# Module-level Biopython PDB parser instance, created once at import time.
# In the code visible in this file it is referenced only by the commented-out
# legacy implementation of get_structure_from_file below.
# NOTE(review): QUIET=True presumably silences Biopython's parser warnings -
# confirm against the Biopython PDBParser documentation.
biopython_parser = PDBParser(QUIET=True)

# def get_structure_from_file(file_path):
#     structure = biopython_parser.get_structure('random_id', file_path)
#     structure = structure[0]
#     l = []
#     for i, chain in enumerate(structure):
#         seq = ''
#         for res_idx, residue in enumerate(chain):
#             if residue.get_resname() == 'HOH':
#                 continue
#             c_alpha, n, c = None, None, None
#             for atom in residue:
#                 if atom.name == 'CA':
#                     c_alpha = list(atom.get_vector())
#                 if atom.name == 'N':
#                     n = list(atom.get_vector())
#                 if atom.name == 'C':
#                     c = list(atom.get_vector())
#             if c_alpha != None and n != None and c != None:  # only append residue if it is an amino acid
#                 try:
#                     seq += three_to_one[residue.get_resname()]
#                 except Exception as e:
#                     seq += '-'
#                     print("encountered unknown AA: ", residue.get_resname(), ' in the complex ', file_path, '. Replacing it with a dash - .')
#         l.append(seq)
#     return l


def get_structure_from_file(file_path):
    """Return the one-letter sequence of every chain in a PDB file.

    The file is parsed with ProDy and only its ``.ca`` selection is used, so
    each returned string holds one letter per selected C-alpha atom, as
    reported by ProDy's ``getSequence()``.  Residues are grouped by the
    concatenation of segment name and chain identifier, which keeps chains
    that share a chain ID but belong to different segments apart.

    Args:
        file_path: Path of the PDB file to read.

    Returns:
        A list of sequence strings, one per distinct segment+chain key.  The
        list follows the sorted order of those keys (``np.unique`` sorts its
        result), not the order in which the chains appear in the file.

    Raises:
        ValueError: If ProDy returns no structure for the file, or the
            structure has no C-alpha selection.
    """
    rec = prody.parsePDB(file_path)
    if rec is None:
        # Fail with a clear message instead of an AttributeError on NoneType.
        raise ValueError(f'could not parse any atoms from {file_path}')

    # `.ca` is a ProDy selection shorthand; evaluate it once and reuse it.
    ca_atoms = rec.ca
    if ca_atoms is None:
        raise ValueError(f'no C-alpha atoms found in {file_path}')

    # One array element per residue letter so a boolean mask can pick a chain.
    residue_letters = np.array(list(ca_atoms.getSequence()))
    chain_keys = np.asarray([
        seg + chain
        for seg, chain in zip(ca_atoms.getSegnames(), ca_atoms.getChids())
    ])

    return [
        ''.join(residue_letters[chain_keys == key])
        for key in np.unique(chain_keys)
    ]




data_dir = '<path_to_data astex_diverse_set>'
save_id2seq_path = '<path_to_save_id2seq>'
# os.listdir returns entries in arbitrary order; sort so the processing order
# (and therefore the key order of the saved JSON) is reproducible.
names = sorted(os.listdir(data_dir))

# astex or posebusters: expects one sub-directory per complex under data_dir,
# each containing the receptor as a .pdb file.
id2seq = {}   # '<complex>_chain_<i>' -> chain sequence
bad_ids = []  # complexes for which no sequence could be extracted
for name in tqdm(names):
    if name.startswith('.'):
        continue
    complex_dir = os.path.join(data_dir, name)
    if not os.path.isdir(complex_dir):
        # Stray files next to the complex directories are not complexes.
        continue

    # Sorted so the same file is picked on every run if there are several.
    pdb_fnames = sorted(
        fname for fname in os.listdir(complex_dir) if fname.endswith('.pdb')
    )
    if not pdb_fnames:
        # Record the complex as bad instead of aborting the whole run.
        tqdm.write(f'{name}: no .pdb file found, skipping')
        bad_ids.append(name)
        continue
    rec_path = os.path.join(complex_dir, pdb_fnames[0])

    try:
        chain_seqs = get_structure_from_file(rec_path)
    except Exception as e:
        # Best-effort processing: report the failure and move on.
        tqdm.write(f'{name}: failed to extract sequences from {rec_path} ({e!r})')
        bad_ids.append(name)
        continue

    for i, seq in enumerate(chain_seqs):
        id2seq[f'{name}_chain_{i}'] = seq

print('BAD IDS', len(bad_ids))
save_json(id2seq, save_id2seq_path)
