import os, sys
from typing import Mapping, Optional, Sequence, Any
import numpy as np
from entity import entity, entity_constants, residue_constants, molecule_processing
import pickle
import torch


# Feature-name -> numpy array mapping; the return type of the feature
# builders in this module (e.g. make_token_seq_features, make_entity_features).
FeatureDict = Mapping[str, np.ndarray]


def _token_type_to_token_seq(token_type: np.ndarray):
    """Translate token-type codes into token names.

    Each element of ``token_type`` (converted with ``.tolist()``) is used as a
    key into ``entity_constants.token_types``; the looked-up values are
    returned as a list in the same order.
    """
    return [entity_constants.token_types[code] for code in token_type.tolist()]


def _make_entity_type_feature(token_num, entity_type):
    """Build a per-token one-hot entity-type matrix.

    Returns a float32 array of shape
    ``(token_num, entity_constants.entity_type_num)`` whose column
    ``entity_type`` is 1 for every row and all other entries are 0.
    """
    shape = (token_num, entity_constants.entity_type_num)
    one_hot = np.zeros(shape, dtype=np.float32)
    one_hot[:, entity_type] = 1
    return one_hot


def make_token_seq_features(
    entity_type: str,
    token_seq: Sequence[str],
    token_index: np.ndarray,
    entity_token_num: int
) -> FeatureDict:
    """Construct a feature dict of sequence features.

    Keys produced:
      * ``target_feat``: ``entity_constants.token_seq_to_onehot`` of
        ``token_seq`` over ``entity_constants.token_type_order``. Unknown
        tokens map to ``"UNK"`` for proteins and are not remapped otherwise.
      * ``token_index``: ``token_index`` passed through unchanged.
      * ``seq_length``: int32 array of length ``entity_token_num`` filled
        with ``entity_token_num``.

    NOTE(review): ``entity_type`` is annotated ``str`` but is compared with
    ``entity_constants.entity_type_order["protein"]`` and elsewhere used as an
    array column index, so it is presumably an integer code — confirm.
    """
    is_protein = entity_type == entity_constants.entity_type_order["protein"]
    unknown_token = "UNK" if is_protein else None
    return {
        "target_feat": entity_constants.token_seq_to_onehot(
            sequence=token_seq,
            mapping=entity_constants.token_type_order,
            map_unknown_to=unknown_token,
        ),
        "token_index": token_index,
        "seq_length": np.array(
            [entity_token_num] * entity_token_num, dtype=np.int32
        ),
    }


def make_entity_features(
    entity_object: entity.Entity, 
    is_distillation: bool = False,
    confidence_threshold: float = 50.,
) -> FeatureDict:
    """Build the feature dict for a single entity.

    Combines the sequence features from ``make_token_seq_features``, a one-hot
    ``entity_type`` matrix, and the entity's structural arrays cast to
    float32 (``fape_frame_idx`` to int32). It also adds ``resolution``
    (always ``[0.]``) and a 0-d ``is_distillation`` flag.

    When ``is_distillation`` is true, ``all_atom_mask`` rows are zeroed for
    tokens where no entry of ``entity_object.b_factors`` along the last axis
    exceeds ``confidence_threshold``.
    """
    token_type = entity_object.token_type
    n_tokens = len(token_type)

    feats = {}
    feats.update(
        make_token_seq_features(
            entity_type=entity_object.entity_type,
            token_seq=_token_type_to_token_seq(token_type),
            token_index=entity_object.token_index,
            entity_token_num=n_tokens,
        )
    )
    feats["entity_type"] = _make_entity_type_feature(
        n_tokens, entity_object.entity_type
    )

    # (output key, attribute on entity_object, target dtype)
    cast_spec = (
        ("all_atom_positions", "atom_positions", np.float32),
        ("all_atom_mask", "atom_mask", np.float32),
        ("extra_feat", "extra_feat", np.float32),
        ("pair_feat", "pair_feat", np.float32),
        ("fape_frame_idx", "fape_frame_idx", np.int32),
        ("edges", "edges", np.float32),
    )
    # Read every attribute first, then cast (astype copies, so the in-place
    # mask update below never touches entity_object's own arrays).
    raw_arrays = [
        (key, getattr(entity_object, attr), dtype)
        for key, attr, dtype in cast_spec
    ]
    for key, value, dtype in raw_arrays:
        feats[key] = value.astype(dtype)

    feats["resolution"] = np.zeros(1, dtype=np.float32)
    feats["is_distillation"] = np.array(
        1.0 if is_distillation else 0.0, dtype=np.float32
    )

    if is_distillation:
        # b_factors presumably (tokens, atoms) — TODO confirm. A token is kept
        # if any value along the last axis exceeds the threshold.
        keep = np.any(entity_object.b_factors > confidence_threshold, axis=-1)
        feats["all_atom_mask"] *= keep[..., None]

    return feats


def make_complex_feature(
    entity_feat_list: Sequence[Mapping[Any, Any]],
):
    """Concatenate per-entity feature dicts into one complex-level dict.

    Placement rules:
      * Per-token features are stacked along axis 0 in list order.
      * ``pair_feat`` and ``edges`` are written as diagonal blocks; entries
        between different entities stay zero.
      * ``token_index`` of each entity is shifted by a running offset. The
        offset grows by (that entity's max token_index + 200) after each
        non-empty entity.
      * ``fape_frame_idx`` is shifted by the entity's starting row, so it
        indexes into the concatenated token axis.
      * ``seq_length`` is set to the total token count for every token.

    Args:
        entity_feat_list: Per-entity feature dicts (e.g. from
            ``make_entity_features`` or ``dummy_feat``). Each must contain
            every key of the schema below.

    Returns:
        Dict of numpy arrays. Values are cast to the schema dtypes on
        assignment, so e.g. ``edges`` and ``entity_type`` are stored as
        int32 even when the inputs are float32.
    """
    # Materialise once: the input is iterated twice below.
    entity_feats = list(entity_feat_list)
    token_num = sum(len(entity_feat["target_feat"]) for entity_feat in entity_feats)

    all_entity_feats = {
        "target_feat": {
            "shape": (token_num, entity_constants.token_type_num),
            "type": np.int32
        },
        "token_index": {
            "shape": (token_num, ),
            "type": np.int32
        },
        "seq_length": {
            "shape": (token_num, ),
            "type": np.int32
        },
        "entity_type": {
            "shape": (token_num, entity_constants.entity_type_num),
            "type": np.int32
        },
        "all_atom_positions": {
            "shape": (token_num, entity_constants.atom_type_num, 3),
            "type": np.float32
        },
        "all_atom_mask": {
            "shape": (token_num, entity_constants.atom_type_num),
            "type": np.float32
        },
        "extra_feat": {
            "shape": (token_num, entity_constants.extra_feat_num),
            "type": np.float32
        },
        "pair_feat": {
            "shape": (token_num, token_num, entity_constants.pair_feat_num),
            "type": np.float32
        },
        "fape_frame_idx": {
            "shape": (token_num, 3),
            "type": np.int32
        },
        # NOTE(review): per-entity "edges" are float32 but stored as int32
        # here — fine for 0/1 values; confirm that is the intent.
        "edges": {
            "shape": (token_num, token_num, entity_constants.edge_type_num),
            "type": np.int32
        },
    }
    feature = {
        feat: np.zeros(feat_info["shape"], dtype=feat_info["type"])
        for feat, feat_info in all_entity_feats.items()
    }

    pos = 0
    token_index_offset = 0
    for entity_feat in entity_feats:
        entity_length = len(entity_feat["target_feat"])
        end = pos + entity_length
        # TODO: drive per-feature placement from config rather than key names.
        for feat in all_entity_feats:
            if feat in ("pair_feat", "edges"):
                feature[feat][pos:end, pos:end] = entity_feat[feat]
            elif feat == "token_index":
                feature[feat][pos:end] = entity_feat[feat] + token_index_offset
            elif feat == "fape_frame_idx":
                feature[feat][pos:end] = entity_feat[feat] + pos
            else:
                feature[feat][pos:end] = entity_feat[feat]
        if entity_length > 0:
            # .max() on an empty array raises ValueError; a zero-token entity
            # contributes nothing, so it must not move the offset.
            token_index_offset += entity_feat["token_index"].max() + 200
        pos = end
    feature["seq_length"][...] = pos
    return feature


def dummy_feat(entity_type, seq_length):
    """Build placeholder features for an entity with known length only.

    Only ``entity_type == "protein"`` is supported; any other value raises
    ``NotImplementedError``. The arrays match the keys produced for real
    entities and are all zeros, except for:
      * ``token_index``: 1..seq_length (int32).
      * ``seq_length``: seq_length repeated seq_length times (int32).
      * ``entity_type``: the protein column set to 1.
      * ``all_atom_mask``: the first 5 atom slots set to 1.
      * ``edges``: symmetric "prot_adjacent" edges between consecutive tokens.
    """
    if entity_type != "protein":
        raise NotImplementedError

    feature = {
        "target_feat": np.zeros(
            (seq_length, entity_constants.token_type_num), dtype=np.int32
        ),
        "token_index": np.arange(1, seq_length + 1, dtype=np.int32),
        "seq_length": np.full(seq_length, seq_length, dtype=np.int32),
        "entity_type": np.zeros(
            (seq_length, entity_constants.entity_type_num), dtype=np.int32
        ),
        "all_atom_positions": np.zeros(
            (seq_length, entity_constants.atom_type_num, 3), dtype=np.float32
        ),
        "all_atom_mask": np.zeros(
            (seq_length, entity_constants.atom_type_num), dtype=np.float32
        ),
        "extra_feat": np.zeros(
            (seq_length, entity_constants.extra_feat_num), dtype=np.float32
        ),
        "pair_feat": np.zeros(
            (seq_length, seq_length, entity_constants.pair_feat_num),
            dtype=np.float32,
        ),
        "edges": np.zeros(
            (seq_length, seq_length, entity_constants.edge_type_num),
            dtype=np.float32,
        ),
        "fape_frame_idx": np.zeros((seq_length, 3), dtype=np.int32),
    }
    feature["entity_type"][:, entity_constants.entity_type_order["protein"]] = 1
    feature["all_atom_mask"][:, :5] = 1

    if seq_length > 1:
        adjacent = entity_constants.edge_type_order["prot_adjacent"]
        left = np.arange(seq_length - 1)
        right = left + 1
        # Both directions, so the adjacency block is symmetric.
        feature["edges"][left, right, adjacent] = 1
        feature["edges"][right, left, adjacent] = 1
    return feature


class AllAtomDataPipeline:
    """Turns dataset records into per-entity feature dicts.

    Proteins are read from PDB files and molecules from SDF/MOL2 files via the
    ``entity`` module. Non-training targets that carry a ``length`` get
    placeholder features from ``dummy_feat``.
    """

    def __init__(self):
        pass

    def process_pdb(
        self,
        pdb_path: str,
        chain_id: Optional[str] = None,
    ) -> Sequence[FeatureDict]:
        """Assemble features for a protein in a PDB file.

        Args:
            pdb_path: Path of the PDB file to read.
            chain_id: Optional chain selector, passed through to
                ``entity.entity_from_pdb_string``.

        Returns:
            A list with one ``make_entity_features`` dict per object returned
            by ``entity.entity_from_pdb_string`` (presumably one per model).
        """
        with open(pdb_path, 'r') as f:
            pdb_str = f.read()
        protein_objects = entity.entity_from_pdb_string(pdb_str, chain_id)
        return [make_entity_features(protein_object) for protein_object in protein_objects]

    def process_mol(
        self,
        mol_path: str,
        is_distillation=False
    ):
        """Assemble features for a molecule in a SDF/MOL2 file.

        Args:
            mol_path: Path of the molecule file, read by
                ``entity.entity_from_mol_file``.
            is_distillation: Forwarded to ``make_entity_features``.

        Returns:
            The feature dict produced by ``make_entity_features``.
        """
        mol_object = entity.entity_from_mol_file(mol_path)
        return make_entity_features(mol_object, is_distillation=is_distillation)

    def _load_protein(self, paths, data_dir):
        """Return features from the first PDB path that loads; re-raise the last error if none do."""
        last_error = None
        for possible_path in paths:
            try:
                return self.process_pdb(os.path.join(data_dir, possible_path))
            except Exception as e:
                # Remember the failure and try the next candidate path.
                last_error = e
        if last_error is not None:
            raise last_error
        return None  # no candidate paths at all -> entity is filtered

    def _load_molecule(self, paths, data_dir):
        """Return features from the first molecule path that loads, or None if none do."""
        for possible_path in paths:
            try:
                return self.process_mol(os.path.join(data_dir, possible_path))
            except Exception as e:
                # Best-effort: report and fall through to the next path.
                print(e)
        return None

    def _load_entity(self, name, entity_info, data_dir, mode):
        """Return the features for one entity record, or None if it cannot be built."""
        if "is_target" in entity_info and entity_info["is_target"] == True and mode != "train":
            # Non-training targets use placeholder features; without a
            # "length" they cannot be built and the record is filtered.
            if "length" in entity_info:
                return dummy_feat(entity_info["type"], entity_info["length"])
            return None
        entity_type = entity_info["type"]
        if entity_type == "protein":
            return self._load_protein(entity_info["path"], data_dir)
        if entity_type == "molecule":
            return self._load_molecule(entity_info["path"], data_dir)
        raise ValueError(f"unsupported entity type in {name}")

    def process_data(self, name, current_data, data_dir, mode):
        """Build the feature list for every entity of one data record.

        Args:
            name: Record name, used only in error messages.
            current_data: Record dict with an ``"entities"`` list. Each entity
                has ``"type"`` ("protein" or "molecule"), a ``"path"`` list of
                candidate files relative to ``data_dir``, and optionally
                ``"is_target"`` / ``"length"``.
            data_dir: Directory the entity paths are joined onto.
            mode: Split name; targets are replaced by placeholders unless
                it is ``"train"``.

        Returns:
            The list of per-entity features (a list of dicts for proteins, a
            dict otherwise), or None if any entity could not be loaded.

        Side effects:
            Sets ``"length"`` on each processed entity. For proteins this is
            the max over models. On success, also sets
            ``current_data["seq_length"]`` to the sum of entity lengths.

        Raises:
            ValueError: for an unsupported entity type.
            Exception: the last load error when every protein path fails.
        """
        entity_feat_list = []
        seq_length = 0
        for entity_info in current_data["entities"]:
            entity_feat = self._load_entity(name, entity_info, data_dir, mode)
            if entity_feat is None:
                return None
            entity_feat_list.append(entity_feat)
            if isinstance(entity_feat, list):
                entity_info["length"] = 0
                for model in entity_feat:
                    entity_info["length"] = max(
                        entity_info["length"], int(model["seq_length"][0])
                    )
            else:
                entity_info["length"] = int(entity_feat["seq_length"][0])
            seq_length += entity_info["length"]
        current_data["seq_length"] = int(seq_length)
        return entity_feat_list
