import torch
import warnings
from Bio import BiopythonWarning
from Bio.PDB import PDBIO
from Bio.PDB.StructureBuilder import StructureBuilder

from .constants import AA, restype_to_heavyatom_names


def save_pdb(data, path=None):
    """Build a Biopython structure from per-residue features and optionally save it as PDB.

    Args:
        data: Mapping holding the per-residue entries ``'chain_nb'``, ``'aa'``,
            ``'pos_heavyatom'``, ``'mask_heavyatom'``, ``'chain_id'``,
            ``'resseq'`` and ``'icode'``. ``'chain_nb'`` must be a tensor (its
            ``.unique()`` enumerates the chains). The other entries may be
            tensors, lists or strings indexed by residue.
            NOTE(review): presumably ``pos_heavyatom`` is (L, A, 3) and
            ``mask_heavyatom`` is (L, A), with A matching the length of the
            ``restype_to_heavyatom_names`` entries — confirm against the dataset.
        path: If not None, the structure is also written to this file with
            ``PDBIO``.

    Returns:
        The ``Bio.PDB`` structure produced by ``StructureBuilder`` (structure
        id 0, a single model 0, one chain per distinct ``chain_nb`` value in
        ascending order).
    """

    def _mask_select(v, mask):
        # Keep only the residues selected by `mask`; str/list/tensor are
        # filtered, any other value is passed through unchanged.
        if isinstance(v, str):
            return ''.join([s for i, s in enumerate(v) if mask[i]])
        elif isinstance(v, list):
            return [s for i, s in enumerate(v) if mask[i]]
        elif isinstance(v, torch.Tensor):
            return v[mask]
        else:
            return v

    def _build_chain(builder, aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, chain_id_ch, resseq_ch, icode_ch):
        # All residues passed in belong to one chain, so the first residue's
        # chain id names the chain.
        builder.init_chain(chain_id_ch[0])
        builder.init_seg('    ')

        for aa_res, pos_allatom_res, mask_allatom_res, resseq_res, icode_res in zip(
            aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, resseq_ch, icode_ch,
        ):
            restype = AA(int(aa_res))
            builder.init_residue(
                resname=str(restype),
                field=' ',
                resseq=int(resseq_res),
                icode=icode_res,
            )

            for i, atom_name in enumerate(restype_to_heavyatom_names[restype]):
                if atom_name == '':
                    continue  # No atom expected in this slot for this residue type
                if not mask_allatom_res[i].all():
                    continue  # Atom is missing from the input
                # PDB atom names occupy 4 columns; names shorter than 4
                # characters start in the second column and are left-justified.
                fullname = atom_name if len(atom_name) >= 4 else f' {atom_name:<3}'
                # Bio.PDB atoms expect a numpy coordinate (PDBParser builds
                # float32 arrays); a plain list breaks e.g. Atom.__sub__.
                # copy() keeps the structure from aliasing the input tensor.
                coord = pos_allatom_res[i].detach().cpu().float().numpy().copy()
                # Positional args: b_factor=0.0, occupancy=1.0, altloc=' '.
                builder.init_atom(atom_name, coord, 0.0, 1.0, ' ', fullname)

    # Scope the suppression so that calling this function does not silence
    # BiopythonWarning for the rest of the process.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', BiopythonWarning)
        builder = StructureBuilder()
        builder.init_structure(0)
        builder.init_model(0)

        chain_nb = data['chain_nb']
        for ch_nb in chain_nb.unique().tolist():
            mask = (chain_nb == ch_nb)
            aa = _mask_select(data['aa'], mask)
            pos_heavyatom = _mask_select(data['pos_heavyatom'], mask)
            mask_heavyatom = _mask_select(data['mask_heavyatom'], mask)
            chain_id = _mask_select(data['chain_id'], mask)
            resseq = _mask_select(data['resseq'], mask)
            icode = _mask_select(data['icode'], mask)

            _build_chain(builder, aa, pos_heavyatom, mask_heavyatom, chain_id, resseq, icode)

        structure = builder.get_structure()
        if path is not None:
            pdb_io = PDBIO()
            pdb_io.set_structure(structure)
            pdb_io.save(path)
    return structure
