import dataclasses
import numpy as np
import torch
import string
import io
from Bio.PDB import PDBParser
from typing import Any, Sequence, Mapping, Optional
from rdkit import Chem
from entity import (
    entity_constants as ec,
    molecule_constants as mc,
    molecule_processing
)


@dataclasses.dataclass(frozen=True)
class Entity:
    """Immutable container of per-token features for a protein or a molecule.

    A "token" is one residue for proteins and one atom for molecules (see
    `entity_from_pdb_string` and `entity_from_mol_file`). `num_res` in the
    shape comments below is the number of tokens.
    """

    # [NEW FEATURE]
    # Kind of entity. The constructors in this file set it from
    # ec.entity_type_order["protein"] / ec.entity_type_order["molecule"].
    # NOTE(review): annotated as `str`, but the value comes from an `*_order`
    # lookup table and is presumably an integer index -- confirm.
    entity_type: str

    # [UPDATED FEATURE]
    # Cartesian coordinates of atoms in angstroms. The second axis is indexed
    # through ec.atom_order (atom names listed in ec.atom_types). For molecule
    # tokens only the slot ec.atom_order[mc.mol_atom_name] is filled.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

    # [NEW FEATURE]
    # Token type for each token, as an integer looked up in
    # ec.token_type_order. Per the original note: values lie between 0 and
    # 20 + 119, where 20 is 'X', 20+1 ~ 20+118 are heavy atom types, and the
    # last value is the unknown heavy atom (presumably 20 + 119 -- confirm
    # against ec.token_type_order).
    token_type: np.ndarray  # [num_res]

    # [UPDATED FEATURE]
    # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
    # is present and 0.0 if not. This should be used for loss masking.
    atom_mask: np.ndarray  # [num_res, num_atom_type]

    # Positional index of each token. It is not necessarily continuous or
    # 0-indexed: entity_from_pdb_string fills it with a running residue counter
    # (with extra offsets at chain breaks and between chains) rather than the
    # residue number stored in the PDB file, and entity_from_mol_file sets it
    # to 0 for every atom.
    token_index: np.ndarray  # [num_res]

    # 0-indexed number corresponding to the chain that each token belongs to.
    # The constructors in this file store it as float32.
    chain_index: np.ndarray  # [num_res]

    # Original chain id (as found in the input file) of each token; the i-th
    # value is the chain id of token i.
    original_chain_ids: Sequence[str]  # [num_res]

    # B-factors, or temperature factors, of each atom (in sq. angstroms units),
    # representing the displacement of the atom from its ground truth mean
    # value.
    b_factors: np.ndarray  # [num_res, num_atom_type]

    # [NEW FEATURE]
    # Additional features, including features generated from RDKit for heavy atoms
    # in molecules.
    extra_feat: Optional[np.ndarray] = None  #[num_res, num_extra_feat]

    # [NEW FEATURE]
    # Pairwise features, currently only the one-hot bond type of heavy atoms in
    # molecules
    pair_feat: Optional[np.ndarray] = None  #[num_res, num_res, num_pair_feat]

    # Templates used to generate this protein (prediction-only)
    parents: Optional[Sequence[str]] = None

    # Chain corresponding to each parent
    parents_chain_index: Optional[Sequence[int]] = None

    # Per-token frame indices, three per token. All zeros for proteins built by
    # entity_from_pdb_string; produced by
    # molecule_processing.mol_make_fape_frame_idx for molecules.
    fape_frame_idx: Optional[np.ndarray] = None  # [num_res, 3]

    # Pairwise edge-type indicators between tokens (last axis indexed through
    # ec.edge_type_order for proteins).
    edges: Optional[np.ndarray] = None  # [num_res, num_res, num_edge_type]


def _parse_parent_records(pdb_str: str):
    """Collect template names from the PARENT lines of a PDB string.

    Returns:
      A `(parents, parents_chain_index)` pair. Both are `None` when the string
      contains no "PARENT" text at all. Otherwise `parents` lists every name
      found on PARENT lines (lines containing "N/A" contribute nothing) and
      `parents_chain_index` holds, for each name, the 0-based ordinal of the
      PARENT line it came from.
    """
    if "PARENT" not in pdb_str:
        return None, None

    parents = []
    parents_chain_index = []
    parent_line_idx = 0
    for line in pdb_str.split("\n"):
        if "PARENT" not in line:
            continue
        if "N/A" not in line:
            names = line.split()[1:]
            parents.extend(names)
            parents_chain_index.extend([parent_line_idx] * len(names))
        # Every PARENT line advances the ordinal, including "N/A" ones.
        parent_line_idx += 1
    return parents, parents_chain_index


def entity_from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Sequence[Entity]:
    """Takes a PDB string and constructs one `Entity` per model in it.

    WARNING: All non-standard residue types will be converted into UNK. All
      non-standard atoms will be ignored. Residues without a CA atom are
      dropped.

    Args:
      pdb_str: The contents of the pdb file
      chain_id: If None, then the whole pdb file is parsed. If chain_id is specified (e.g. A), then only that chain
        is parsed.

    Returns:
      A list of `Entity` objects parsed from the pdb contents, one for each
      model in the file, in file order.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("none", io.StringIO(pdb_str))

    # PARENT records depend only on the raw text, so parse them once. A
    # dedicated counter is used inside the helper: the previous implementation
    # reused (and overwrote) `chain_id` for this, which made the chain filter
    # below reject every chain of every model after the first one.
    parents, parents_chain_index = _parse_parent_records(pdb_str)

    unk_token = ec.token_type_order["UNK"]
    adjacent = ec.edge_type_order["prot_adjacent"]

    all_models = []
    for model in structure.get_models():
        atom_positions = []
        token_type = []
        atom_mask = []
        token_index = []
        chain_ids = []
        b_factors = []

        # Running index written to `token_index`; restarted for every model.
        residue_count = 0

        for chain in model:
            if chain_id is not None and chain.id != chain_id:
                continue
            last_ca = None
            for res in chain:
                # Counted even if the residue is dropped below, so dropped
                # residues leave a hole in the index.
                residue_count += 1
                token_type_idx = ec.token_type_order.get(res.resname, unk_token)
                pos = np.zeros((ec.atom_type_num, 3))
                mask = np.zeros((ec.atom_type_num,))
                res_b_factors = np.zeros((ec.atom_type_num,))
                current_ca = None
                for atom in res:
                    if atom.name not in ec.atom_types:
                        continue
                    slot = ec.atom_order[atom.name]
                    pos[slot] = atom.coord
                    mask[slot] = 1.0
                    res_b_factors[slot] = atom.bfactor
                    if atom.name == "CA":
                        current_ca = atom.coord
                if current_ca is None:
                    continue
                # A CA-CA distance above 4 (coordinate units, angstroms in PDB
                # files) to the previously kept residue is treated as a chain
                # break and adds an extra gap of one to the index.
                if last_ca is not None and np.linalg.norm(current_ca - last_ca) > 4:
                    residue_count += 1
                last_ca = current_ca
                if np.sum(mask) < 0.5:
                    # If no known atom positions are reported for the residue then skip it.
                    continue
                token_type.append(token_type_idx)
                atom_positions.append(pos)
                atom_mask.append(mask)
                token_index.append(residue_count)
                chain_ids.append(chain.id)
                b_factors.append(res_b_factors)
            # Large index jump so tokens of different chains are far apart.
            residue_count += 200

        entity_token_num = len(token_type)

        # Chains are numbered by the sorted order of their ids (np.unique).
        unique_chain_ids = np.unique(chain_ids)
        chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
        chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids], dtype=np.float32)

        # Symmetric adjacency between consecutive tokens.
        # NOTE(review): consecutive tokens are linked even across chain
        # boundaries and chain breaks -- confirm this is intended.
        edges = np.zeros((entity_token_num, entity_token_num, ec.edge_type_num), dtype=np.float32)
        left = np.arange(entity_token_num - 1)
        edges[left, left + 1, adjacent] = 1
        edges[left + 1, left, adjacent] = 1

        all_models.append(Entity(
            entity_type=ec.entity_type_order["protein"],
            atom_positions=np.array(atom_positions, dtype=np.float32),
            atom_mask=np.array(atom_mask, dtype=np.float32),
            token_type=np.array(token_type, dtype=np.int32),
            token_index=np.array(token_index, dtype=np.int32),
            chain_index=chain_index,
            original_chain_ids=list(chain_ids),
            b_factors=np.array(b_factors, dtype=np.float32),
            # Fresh list per model so the entities do not share mutable state.
            parents=None if parents is None else list(parents),
            parents_chain_index=(
                None if parents_chain_index is None else list(parents_chain_index)
            ),
            extra_feat=np.zeros((entity_token_num, ec.extra_feat_num), dtype=np.float32),
            pair_feat=np.zeros((entity_token_num, entity_token_num, ec.pair_feat_num), dtype=np.float32),
            fape_frame_idx=np.zeros((entity_token_num, 3)),
            edges=edges,
        ))

    return all_models

def entity_from_mol_file(file_name: str) -> Entity:
    """Load a molecule file and build a molecule `Entity` with one token per atom.

    Args:
      file_name: Path handed to `molecule_processing.mol_from_file`.

    Returns:
      An `Entity` of type "molecule". Every atom is placed on chain "A" with
      token index 0 and all-zero B-factors; its coordinates occupy the single
      atom slot `ec.atom_order[mc.mol_atom_name]`. Extra, pair, edge and FAPE
      frame features come from the `molecule_processing` featurizers.
    """
    mol = molecule_processing.mol_from_file(file_name)
    conformer_xyz = mol.GetConformer().GetPositions()

    token_ids = []
    coords_per_atom = []
    masks_per_atom = []
    for atom_idx, atom in enumerate(mol.GetAtoms()):
        # Token names for molecule atoms are "*<atomic number>"; anything not
        # in the table falls back to the unknown-molecule token.
        element_key = f"*{atom.GetAtomicNum()}"
        token_ids.append(
            ec.token_type_order.get(element_key, ec.token_type_order[mc.mol_unk_token_name])
        )

        slot = ec.atom_order[mc.mol_atom_name]

        coords = np.zeros((ec.atom_type_num, 3))
        coords[slot] = conformer_xyz[atom_idx]
        coords_per_atom.append(coords)

        present = np.zeros((ec.atom_type_num,))
        present[slot] = 1.0
        masks_per_atom.append(present)

    num_atoms = len(token_ids)
    chain_labels = ["A"] * num_atoms
    zero_b_factors = [np.zeros((ec.atom_type_num,)) for _ in range(num_atoms)]

    # Same call order as before: pair/edge features, FAPE frames, extra features.
    pair_feat, edges = molecule_processing.mol_pair_featurizer(mol)
    fape_frame_idx = molecule_processing.mol_make_fape_frame_idx(mol)
    extra_feat = molecule_processing.mol_extra_featurizer(mol)

    return Entity(
        entity_type=ec.entity_type_order["molecule"],
        atom_positions=np.array(coords_per_atom, dtype=np.float32),
        atom_mask=np.array(masks_per_atom, dtype=np.float32),
        token_type=np.array(token_ids, dtype=np.int32),
        # Every atom gets token index 0.
        token_index=np.zeros(num_atoms, dtype=np.int32),
        # The only chain label is "A", which maps to chain index 0.
        chain_index=np.zeros(num_atoms, dtype=np.float32),
        original_chain_ids=chain_labels,
        b_factors=np.array(zero_b_factors, dtype=np.float32),
        parents=None,
        parents_chain_index=None,
        extra_feat=extra_feat,
        pair_feat=pair_feat,
        fape_frame_idx=fape_frame_idx,
        edges=edges,
    )
