import re
import jax
from jax import numpy as jnp
import jraph

from ProtLig_GPCRclassA.mol2graph.jraph.convert import smiles_to_jraph

from ProtLig_GPCRclassA.utils import create_line_graph

from ProtLig_GPCRclassA.amino_GNN.concentration.element_conc import AminoConcentrationElementPrecompute


def make_predict_single_step(model, return_intermediates = False, num_classes = 2):
    """
    Build a ``jax.jit``-compiled prediction step for ``model``.

    The ``'_main_label'`` entry of the model output is turned into
    probabilities with a sigmoid when ``num_classes == 2`` and with a
    softmax for any other value of ``num_classes``.

    Args:
        model: Object exposing ``apply(params, batch, deterministic=..., mutable=...)``
            whose output is a mapping containing ``'_main_label'``.
        return_intermediates: When True, ``model.apply`` is called with
            ``mutable=['intermediates']`` and the step returns
            ``(pred_probs, intermediates)``; otherwise only ``pred_probs``.
        num_classes: Selects the output activation as described above.

    Returns:
        The jit-compiled ``predict_step(params, batch)``.
    """
    # Pick the activation once; both step variants below share it.
    activation = jax.nn.sigmoid if num_classes == 2 else jax.nn.softmax

    if not return_intermediates:
        def predict_step(params, batch):
            outputs = model.apply(params, batch, deterministic = True)
            return activation(outputs['_main_label'])
        return jax.jit(predict_step)

    def predict_step(params, batch):
        # With mutable collections, apply returns (output, mutated_state).
        outputs, state = model.apply(params, batch, deterministic = True, mutable=['intermediates'])
        return activation(outputs['_main_label']), state

    return jax.jit(predict_step)


# ---------------
# predict single:
# ---------------
def make_predict_conc_single_apply(model, apply_seqs_model, padding_n_node, padding_n_edge, return_intermediates = False, num_classes = 2, self_loops = False, line_graph = False, IncludeHs = False, line_graph_max_size = None):
    """
    Placeholder factory for a single-batch predictor.

    Judging by its name and the ``apply_seqs_model`` argument, this is
    presumably meant to compute sequence features on the fly rather than read
    them from a precomputed table (TODO confirm). It is not implemented: every
    call raises regardless of the arguments passed.

    Raises:
        NotImplementedError: Always.
    """
    message = 'Apply is not implemented yet.'
    raise NotImplementedError(message)


def make_predict_conc_single_precompute(model, padding_n_node, padding_n_edge, h5_table, from_disk = False, return_intermediates = False, num_classes = 2, self_loops = False, line_graph = False, IncludeHs = False, line_graph_max_size = None):
    """
    Build a ``predict_single`` function that scores (SMILES, sequence id,
    concentration) triples using sequence features read through
    ``AminoConcentrationElementPrecompute``.

    Args:
        model: Model passed to ``make_predict_single_step``; its
            ``atom_features`` and ``bond_features`` attributes are forwarded
            to ``smiles_to_jraph``.
        padding_n_node: Forwarded to ``AminoConcentrationElementPrecompute``.
        padding_n_edge: Forwarded to ``AminoConcentrationElementPrecompute``.
        h5_table: Passed to ``AminoConcentrationElementPrecompute`` as
            ``bert_table`` (presumably an HDF5 table of precomputed sequence
            embeddings -- confirm against that class).
        from_disk: Forwarded to ``AminoConcentrationElementPrecompute``.
        return_intermediates: If True, ``predict_single`` returns
            ``(pred_probs, intermediates)`` instead of ``pred_probs``.
        num_classes: 2 selects a sigmoid output, any other value a softmax.
        self_loops: Forwarded to ``smiles_to_jraph``.
        line_graph: If True, each molecule is given to the element as
            ``(G, create_line_graph(G))`` instead of ``(G,)``.
        IncludeHs: Forwarded to ``smiles_to_jraph``.
        line_graph_max_size: Forwarded to ``create_line_graph`` as ``max_size``.

    Returns:
        ``predict_single(params, list_of_smiles, seq_ids, concentrations)``.
    """
    predict_single_step = make_predict_single_step(model = model, return_intermediates = return_intermediates, num_classes = num_classes)
    element = AminoConcentrationElementPrecompute(bert_table = h5_table,
                                                  padding_n_node = padding_n_node,
                                                  padding_n_edge = padding_n_edge,
                                                  from_disk = from_disk)
    # TODO: the original code carried a note "Query the table HERE..." at this
    # point; its intent is unclear -- confirm whether table access belongs here.

    def _read_graph(smiles):
        """Convert one SMILES string into the tuple of graphs ``element.mol_prepro`` expects."""
        G = smiles_to_jraph(smiles, u = None, validate = False, IncludeHs = IncludeHs,
                            atom_features = model.atom_features, bond_features = model.bond_features,
                            self_loops = self_loops)
        if line_graph:
            return (G, create_line_graph(G, max_size = line_graph_max_size))
        return (G, )

    def predict_single(params, list_of_smiles, seq_ids, concentrations):
        """
        Predict for a batch of (SMILES, sequence id, concentration) triples.

        Args:
            params: Model parameters passed to the jit-compiled predict step.
            list_of_smiles: SMILES strings, one per example.
            seq_ids: Sequence identifiers, one per example; each is passed to
                ``element.seq_prepro`` as ``{'_seq_id': seq_id}``.
            concentrations: One scalar concentration per example.

        Returns:
            Predicted probabilities, or ``(pred_probs, intermediates)`` when
            ``return_intermediates`` was set on the factory.

        Raises:
            ValueError: If the batch is empty or the three inputs differ in length.
        """
        n_examples = len(seq_ids)
        if n_examples == 0:
            raise ValueError('predict_single needs at least one example.')
        if len(list_of_smiles) != n_examples or len(concentrations) != n_examples:
            # Without this check extra SMILES/concentrations were silently
            # dropped, or a short list raised an opaque IndexError.
            raise ValueError(
                'list_of_smiles, seq_ids and concentrations must have the same length; '
                f'got {len(list_of_smiles)}, {n_examples} and {len(concentrations)}.')

        seqs, seq_attn_masks, mols = [], [], []
        for smiles, seq_id, conc in zip(list_of_smiles, seq_ids, concentrations):
            seq, seq_attn_mask = element.seq_prepro({'_seq_id': seq_id})
            # element.mol_prepro expects the molecule as a tuple of graphs and
            # the concentration wrapped in a 1-element array.
            mol = element.mol_prepro(_read_graph(smiles), jnp.array([conc]))
            seqs.append(seq)
            seq_attn_masks.append(seq_attn_mask)
            mols.append(mol)

        batch = ((jnp.stack(seqs), jnp.stack(seq_attn_masks)), jraph.batch(mols))
        return predict_single_step(params, batch)
    return predict_single