#!/usr/bin/python
# -*- coding:utf-8 -*-
from typing import Optional, List

from ..utils import recur_index, is_standard_aa
from ..vocab import VOCAB
from ..hierarchy import Complex, Molecule, Block, Atom, BondType
from .. import const


def _is_peptide_bond(cplx, start_id, end_id, bond_type):
    '''
        Tell whether a bond is an ordinary backbone peptide bond between two
        consecutive standard amino acids.

        Args:
            cplx: the complex that the ids index into (via recur_index)
            start_id: numerical id of the first atom; start_id[:-1] addresses its block
            end_id: numerical id of the second atom; end_id[:-1] addresses its block
            bond_type: type of the bond, compared against BondType.SINGLE

        Returns:
            True if both blocks are standard amino acids, the bond is single, the
            two blocks are adjacent in sequence, and the bond joins the atom named
            'C' of the earlier block with the atom named 'N' of the later block.
    '''
    first_block = recur_index(cplx, start_id[:-1])
    second_block = recur_index(cplx, end_id[:-1])
    first_atom = recur_index(cplx, start_id)
    second_atom = recur_index(cplx, end_id)

    residues_ok = is_standard_aa(first_block.name) and is_standard_aa(second_block.name)
    is_single = bond_type == BondType.SINGLE
    if not residues_ok:
        return False
    if not is_single:
        return False

    # signed offset from the first block to the second along the sequence
    offset = second_block.id[0] - first_block.id[0]
    if offset == 0:
        # same position number: fall back to comparing the insertion codes, where
        # an empty code is treated as the character right before 'A'
        no_code = chr(ord('A') - 1)
        first_code = first_block.id[1]
        second_code = second_block.id[1]
        first_code = no_code if first_code == '' else first_code
        second_code = no_code if second_code == '' else second_code
        offset = ord(second_code) - ord(first_code)

    # orient the pair so that the earlier residue provides the carbonyl side
    if offset == 1:
        c_side, n_side = first_atom, second_atom
    elif offset == -1:
        c_side, n_side = second_atom, first_atom
    else:
        return False    # not consecutive

    return c_side.name == 'C' and n_side.name == 'N'


def complex_to_pdb(
        cplx: Complex,
        pdb_path: str,
        selected_chains: Optional[List[str]]=None,
        title: Optional[str]=None,
        explict_bonds: Optional[List[tuple]]=None
    ):
    '''
        Write a complex into a PDB file (TITLE, ATOM/HETATM, TER and CONECT records).

        Args:
            cplx: Complex, the complex to be written into the pdb file
            pdb_path: str, output path
            selected_chains: list of chain ids to write. None (default) writes all chains.
            title: the title of the pdb file (written in upper case)
            explict_bonds: list of bonds to write as CONECT (each bond is represented as (id1, id2, bond_type)).
                The bond_type will be ignored as pdb does not record such information. The id1 and id2 should be
                provided as numerical ids, e.g. (0, 10, 1) means the atom at cplx[0][10][1].
                Normal peptide bonds are dropped, and so are bonds with an atom that was not
                written (e.g. it belongs to a chain excluded by selected_chains).

        Returns:
            None, the result is the file written at pdb_path.
    '''
    atom_number = 1
    id2atom_number = {}     # (chain index, block index, atom index) -> atom serial number
    # the context manager guarantees the file is closed even if writing fails midway
    with open(pdb_path, 'w') as fout:
        if title is not None: fout.write(f'TITLE     {title.upper()}\n')
        for i, mol in enumerate(cplx): # chain
            # no selection means every chain is written
            if selected_chains is not None and mol.id not in selected_chains: continue
            block = None    # stays None if the chain has no block
            for j, block in enumerate(mol):
                block_name = block.name
                if block_name not in const.AA_GEOMETRY: # fragments
                    atom_mark = 'HETATM'
                    block_name = VOCAB.abrv_to_symbol(block_name).replace('f', '')[-3:]
                else:
                    atom_mark = 'ATOM  '
                insert_code = block.id[1]
                if 'original_name' in block.properties:
                    block_name = block.properties['original_name']
                    insert_code = ''.join([s for s in insert_code if not s.isdigit()])
                if insert_code.isdigit(): insert_code = chr(ord('A') + int(insert_code))
                # Sometimes fragmentation leads to insert codes like A0, A1 if the residue already has one insert code.
                # As pdb only permits a single-character insert code, the digit is put in the column right after it.
                # Therefore the insert code needs to be left-justified instead of right-justified.
                insert_code = insert_code.ljust(2)
                for k, atom in enumerate(block):
                    occupancy = atom.get_property('occupancy', 1.0)
                    bfactor = atom.get_property('bfactor', 0.0)
                    fout.write(''.join([
                        atom_mark,                                  # 1-6, ATOM, or HETATM
                        str(atom_number).rjust(5),                  # 7-11, atom serial number
                        ' ',                                        # 12
                        atom.name.ljust(4),                         # 13-16, atom name
                        ' ',                                        # 17, alternate location indicator
                        block_name.rjust(3),                        # 18-20, residue name
                        ' ',                                        # 21
                        mol.id[0].strip(),                          # 22, chain identifier
                        str(block.id[0]).rjust(4),                  # 23-26, residue sequence number
                        insert_code,                                # 27-28, code for insertions of residues, 28 for illegal digits generated by fragmentation
                        '  ',                                       # 29-30
                        str(round(atom.coordinate[0], 3)).rjust(8), # 31-38, X orthogonal \AA coordinate
                        str(round(atom.coordinate[1], 3)).rjust(8), # 39-46, Y orthogonal \AA coordinate
                        str(round(atom.coordinate[2], 3)).rjust(8), # 47-54, Z orthogonal \AA coordinate
                        str(round(occupancy, 2)).rjust(6),          # 55-60, occupancy
                        str(round(bfactor, 2)).rjust(6),            # 61-66, temperature factor
                        ' ' * 6,                                    # 67-72
                        ' ' * 4,                                    # 73-76, segment identifier
                        atom.element.rjust(2),                      # 77-78, element symbol
                        ' ' * 2,                                    # 79-80, charge
                        '\n'
                    ]))
                    id2atom_number[(i, j, k)] = atom_number
                    atom_number += 1
            # a chain without any block has no terminal residue to describe
            if block is None: continue
            # NOTE(review): the TER serial number is not consumed, so the first atom of the
            # next chain reuses it. Kept as is so that existing numbering does not change.
            fout.write(''.join([
                'TER ',                                     # 1-4, TER
                ' ' * 2,                                    # 5-6
                str(atom_number).rjust(5),                  # 7-11, atom serial number
                ' ',                                        # 12
                ''.ljust(4),                                # 13-16, atom name
                ' ',                                        # 17, alternate location indicator
                block_name.rjust(3),                        # 18-20, residue name (last block)
                ' ',                                        # 21
                mol.id[0],                                  # 22, chain identifier
                str(block.id[0]).rjust(4),                  # 23-26, residue sequence number
                insert_code,                                # 27, code for insertions of residues
                '\n'
            ]))
        # write explicit bonds (drop normal peptide bond)
        if explict_bonds is not None:
            bonds = {}  # atom serial number -> serial numbers of bonded atoms
            for start_id, end_id, bond_type in explict_bonds:
                # ids may be given as lists, while the lookup table is keyed by tuples
                start_id, end_id = tuple(start_id), tuple(end_id)
                # at least one side was not written, so the bond cannot be expressed
                if start_id not in id2atom_number or end_id not in id2atom_number: continue
                if _is_peptide_bond(cplx, start_id, end_id, bond_type): continue    # PDB does not record normal peptide bonds
                start_atom_number = id2atom_number[start_id]
                end_atom_number = id2atom_number[end_id]
                bonds.setdefault(start_atom_number, []).append(end_atom_number)
                bonds.setdefault(end_atom_number, []).append(start_atom_number)
            for serial, partners in bonds.items():
                # CONECT uses fixed 5-column fields (7-11 for the atom, then up to four
                # bonded atoms); more partners go to extra records with the same atom
                for offset in range(0, len(partners), 4):
                    fields = ''.join([str(n).rjust(5) for n in partners[offset:offset + 4]])
                    fout.write(f'CONECT{str(serial).rjust(5)}{fields}\n')