from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np

from coarsebind_public.esm2.io_schema import ESM2IOSchema
from coarsebind_public.mol_encoder.io_schema import MolEncIOSchema


@dataclass(slots=True)
class OutputSchemaDisto:

    pw_distance_cutoff: float = None
    res_type: np.ndarray = None
    entity_id: np.ndarray = None
    sym_id: np.ndarray = None
    asym_id: np.ndarray = None
    res_num: np.ndarray = None
    res_name: np.ndarray = None

    potency_ligand_mask: np.ndarray = None
    bin_probs: np.ndarray = None
    norm_bin_entropy: np.ndarray = None
    pw_distances: np.ndarray = None
    save_layer_reps: Dict[str, np.ndarray] = None
    save_s_reps: Dict[str, np.ndarray] = None
    within_cutoff_mask: np.ndarray = None
    within_cutoff_pair_mask: np.ndarray = None

    # coarse cofold output
    coarse_cofold_protein_pdb_str: str = None
    coarse_cofold_protein_pdb_uri: str = None
    coarse_cofold_ligand_pdb_str: str = None
    coarse_cofold_ligand_pdb_uri: str = None

    # TODO: where should this go..?
    template_path: str = None
    chain_id: str = None
    template_res_idxs: np.ndarray = None
    coarse_cofold_template_coords: np.ndarray = None
    # false for unresolved residues
    template_mask: np.ndarray = None


@dataclass(slots=True)
class OutputSchemaDistoPotency:

    pred_quant: float = None
    pred_binary: float = None
    pred_epinet_samples: np.ndarray = None
    pred_25_pct: float = None
    pred_75_pct: float = None


@dataclass(slots=True)
class IOSchemaCoarseBind:
    """Record holding the inputs, metadata, features, outputs and error state of one run.

    Only ``smiles`` is required. Every other field has a default (``None``
    unless shown otherwise below), and ``None``-defaulted fields are
    annotated ``Optional[...]`` to reflect that.

    Field order defines the positional ``__init__`` signature; append new
    fields at the end rather than reordering existing ones.
    """

    # --- input data ---
    smiles: str

    # official protein_db complex_id only
    complex_id: Optional[str] = None
    target_name: Optional[str] = None
    mol_id: Optional[str] = None
    assay_id: Optional[str] = None
    # either a single string or a list of strings
    # NOTE(review): presumably one string per chain -- confirm.
    sequence: Optional[Union[str, List[str]]] = None

    # enumerated sequence residue ids (list of lists, one per chain)
    res_num: Optional[List[List[int]]] = None

    # res_num in the user specified binding site (list of lists, one per chain)
    binding_site_res_num: Optional[List[List[int]]] = None

    # --- deployment info ---
    deployment_id: Optional[str] = None
    artifact_id: Optional[str] = None
    model_task: Optional[str] = None
    model_source_env: Optional[str] = None
    artifact_s3: Optional[str] = None

    # --- prelim features ---
    canon_smiles: Optional[str] = None

    # Full MolEnc output with smiles_embed and e3nn_embed
    mol_enc_io: Optional[MolEncIOSchema] = None
    # Full ESM2 output with embeddings
    esm2_io: Optional[ESM2IOSchema] = None

    # --- output features ---

    # disto model output
    disto_output: Optional[OutputSchemaDisto] = None

    disto_potency_output: Optional[OutputSchemaDistoPotency] = None

    # coarsebind output
    pred: Optional[float] = None

    # --- error handling ---

    error: bool = False
    error_msg: str = ""

    # cache id used to identify this run
    cache_id: Optional[str] = None
