import re
import jax
from jax import numpy as jnp
import jraph

from ProtLig_GPCRclassA.mol2graph.jraph.convert import smiles_to_jraph

from ProtLig_GPCRclassA.utils import pad_graph, create_line_graph


def make_predict_single_step(model, return_intermediates = False, num_classes = 2):
    """
    Build a jitted function that turns model logits into class probabilities.

    When ``num_classes == 2`` the '_main_label' logits go through a sigmoid.
    For any other class count they go through ``jax.nn.softmax``, whose
    default is the last axis.

    Args:
        model: object exposing ``apply(params, batch, deterministic=..., ...)``
            that returns a mapping containing the key '_main_label'.
        return_intermediates: if True, ``apply`` is called with
            ``mutable=['intermediates']`` and the step returns
            ``(probabilities, intermediates)``.
        num_classes: number of output classes (2 selects the sigmoid path).

    Returns:
        ``jax.jit``-compiled ``predict_step(params, batch)``.
    """
    activation = jax.nn.sigmoid if num_classes == 2 else jax.nn.softmax

    if return_intermediates:
        def predict_step(params, batch):
            outputs, collected = model.apply(params, batch, deterministic = True, mutable=['intermediates'])
            return activation(outputs['_main_label']), collected
    else:
        def predict_step(params, batch):
            outputs = model.apply(params, batch, deterministic = True)
            return activation(outputs['_main_label'])

    return jax.jit(predict_step)


# ---------------
# predict single:
# ---------------
def make_apply_bert(model, n_last_layers = 5):
    """
    Build a function that embeds tokenized sequences with a BERT-style model.

    The returned ``apply_bert(seq)`` does three things:
    1. Runs ``model.module.apply`` with ``{'params': model.params}`` in
       deterministic mode, requesting hidden states.
    2. Takes token position 0 from each of the last ``n_last_layers``
       hidden-state tensors.
    3. Concatenates those vectors per input row.

    Args:
        model: object exposing ``module.apply`` and ``params`` (presumably a
            Flax / HuggingFace model wrapper - TODO confirm).
        n_last_layers: how many trailing hidden-state tensors to use.
            Defaults to 5. Must be >= 1.

    Returns:
        ``apply_bert(seq)``, where ``seq`` is a mapping of keyword arguments
        (e.g. ``input_ids``) forwarded to ``module.apply``. It returns an
        array of shape ``(batch, n_last_layers * hidden_dim)``, assuming each
        hidden state is ``(batch, seq_len, hidden_dim)``.

    Raises:
        ValueError: if ``n_last_layers < 1``. A slice of ``[-0:]`` would
            silently select every layer.
    """
    if n_last_layers < 1:
        raise ValueError('n_last_layers must be >= 1, got {}'.format(n_last_layers))

    def apply_bert(seq):
        bert_output = model.module.apply({'params': model.params}, **seq, deterministic = True,
                                         output_attentions = False,
                                         output_hidden_states = True,
                                         return_dict = True)
        # Stack the selected layers on a new axis 1: (batch, layers, ...).
        S = jnp.stack(bert_output.hidden_states[-n_last_layers:], axis = 1)
        # Keep token position 0 of every layer, then flatten layers into features.
        S = S[:, :, 0, :]
        # Shape passed positionally: the `newshape=` keyword is deprecated in JAX.
        return jnp.reshape(S, (S.shape[0], -1))

    return apply_bert


def make_preprocess_seq(tokenizer, seq_max_length):
    """
    Build a function that tokenizes a single protein sequence.

    The returned ``preprocess_seq(seq)`` does four things:
    1. Separates the characters of ``seq`` with spaces.
    2. Replaces every U, Z, O and B with X.
    3. Tokenizes the result as a one-element batch, padded and truncated to
       ``seq_max_length`` with numpy tensors.
    4. Adds a 'position_ids' entry (``0..L-1`` broadcast to the shape of
       'input_ids') when the tokenizer did not supply one.

    Args:
        tokenizer: callable with a HuggingFace-style tokenizer signature.
        seq_max_length: padding/truncation length passed to the tokenizer.

    Returns:
        ``preprocess_seq(seq) -> dict`` of tokenizer outputs.
    """
    # Single-character mapping U/Z/O/B -> X, applied after spacing the residues.
    to_x = str.maketrans(dict.fromkeys('UZOB', 'X'))

    def preprocess_seq(seq):
        spaced = ' '.join(seq).translate(to_x)
        encoded = dict(tokenizer([spaced], return_tensors='np', padding = 'max_length',
                                 max_length = seq_max_length, truncation = True))
        if 'position_ids' in encoded:
            return encoded
        input_ids = encoded['input_ids']
        width = jnp.atleast_2d(input_ids).shape[-1]
        encoded['position_ids'] = jnp.broadcast_to(jnp.arange(width), input_ids.shape)
        return encoded

    return preprocess_seq


def make_preprocess_graph(padding_n_node, padding_n_edge):
    """
    Build a function that pads one molecular graph and wraps it as a batch.

    Args:
        padding_n_node: node count passed to ``pad_graph``.
        padding_n_edge: edge count passed to ``pad_graph``.

    Returns:
        ``preprocess_graph(mol)``. It pads ``mol`` with ``pad_graph`` and
        returns ``jraph.batch`` of that single padded graph. If ``mol`` has
        no senders or no receivers, it prints ``mol`` and raises ValueError.
    """
    def preprocess_graph(mol):
        has_bonds = len(mol.senders) > 0 and len(mol.receivers) > 0
        if not has_bonds:
            print(mol)
            raise ValueError('Molecule with no bonds encountered.')
        padded = pad_graph(mol, padding_n_node = padding_n_node, padding_n_edge = padding_n_edge)
        return jraph.batch([padded])

    return preprocess_graph


def make_preprocess_graphs(padding_n_node, padding_n_edge, n_partitions = 0):
    """
    Build a function that pads a list of dataset entries and batches them.

    Each entry is a tuple whose first element is the graph. Per the original
    comment, dataset ``__getitem__`` returns a tuple even without line
    graphs. Only that first element is used.

    Args:
        padding_n_node: node count passed to ``pad_graph``.
        padding_n_edge: edge count passed to ``pad_graph``.
        n_partitions: if > 0, return a list of ``n_partitions`` batched
            graphs, each holding ``len(batch) // n_partitions`` graphs.
            Otherwise return one batched graph.

    Returns:
        ``preprocess_graphs(batch)``. It raises ValueError, after printing
        the molecule, on the first graph with no senders or no receivers.
    """
    def _padded(entry):
        # Validate and pad the graph stored in the entry's first slot.
        mol = entry[0]
        if len(mol.senders) == 0 or len(mol.receivers) == 0:
            print(mol)
            raise ValueError('Molecule with no bonds encountered.')
        return pad_graph(mol, padding_n_node = padding_n_node, padding_n_edge = padding_n_edge)

    def preprocess_graphs(batch):
        padded = [_padded(entry) for entry in batch]
        if not n_partitions > 0:
            return jraph.batch(padded)
        chunk = len(batch) // n_partitions
        # NOTE(review): floor division drops the trailing len(batch) % n_partitions
        # graphs; confirm callers always pass evenly divisible batches.
        return [jraph.batch(padded[k * chunk:(k + 1) * chunk]) for k in range(n_partitions)]

    return preprocess_graphs


def apply_seqs_model(seqs, n_partitions = 0):
    """
    Run the sequence model on ``seqs`` and return stacked states and masks.

    NOTE(review): ``preprocess_seqs``, ``apply_model`` and
    ``serialize_ESM2_hidden_states`` are not defined in this file. They must
    exist as module globals at call time, otherwise this raises NameError.

    Args:
        seqs: the input sequences (anything with ``len``). They are passed to
            ``preprocess_seqs``, and their count decides how many outputs are
            read back.
        n_partitions: partitioned output is not implemented. Any value > 0
            raises NotImplementedError. Defaults to 0.

    Returns:
        ``(hidden_states, attention_masks)``, each stacked on a new leading
        axis. For each sequence ``i``:
        - the hidden state is ``_states[i][-1]`` cast to float32 (presumably
          the final layer; confirm against ``serialize_ESM2_hidden_states``);
        - the mask is row ``i`` of the attention mask cast to int32.

    Raises:
        NotImplementedError: if ``n_partitions > 0``.
    """
    # Fail fast, before running the (potentially expensive) model.
    if n_partitions > 0:
        raise NotImplementedError('needs to be checked...')

    _seqs = preprocess_seqs(seqs)
    output = apply_model(_seqs)
    _states = serialize_ESM2_hidden_states(output.hidden_states)

    seqs_hidden_states = []
    seqs_attn = []
    # `i` indexes two different containers (_states and the mask), hence range().
    for i in range(len(seqs)):
        seqs_hidden_states.append(_states[i][-1, :, :].astype(jnp.float32))
        # The mask exposes .detach().numpy() (torch-style tensor, presumably).
        seqs_attn.append(_seqs['attention_mask'][i].detach().numpy().astype(jnp.int32))

    return (jnp.stack(seqs_hidden_states), jnp.stack(seqs_attn))




def make_query_sequences(h5_table, from_disk = False, n_partitions = 0):
    """
    Build a lookup from sequence ids to precomputed BERT outputs.

    ``h5_table`` rows must expose 'id' (bytes), 'hidden_states' and
    'attention_mask'. It is presumably a PyTables ``Table``; TODO confirm.

    Args:
        h5_table: source table.
        from_disk: controls how lookups are served.
            - If False, every row is read once via ``iterrows()`` into
              in-memory dicts keyed by the UTF-8-decoded id, and
              'Table loaded...' is printed.
            - If True, each lookup runs ``h5_table.where`` on the 'id'
              column.
        n_partitions: partitioned output is not implemented. Values > 0 make
            the returned function raise NotImplementedError.

    Returns:
        ``query_sequences(seq_ids) -> (hidden_states, attention_masks)``.
        The per-id arrays are stacked along a new leading axis in
        ``seq_ids`` order.

    The returned function raises:
        KeyError: in-memory mode, when an id is not in the table.
        ValueError: on-disk mode, when zero or several rows match an id.
        NotImplementedError: when ``n_partitions > 0``.
    """
    def _stack(states, attns):
        # Shared by both modes: turn per-id arrays into batch arrays.
        if n_partitions > 0:
            raise NotImplementedError('needs to be checked...')
        return (jnp.stack(states), jnp.stack(attns))

    if not from_disk:
        states_by_id = {}
        attn_by_id = {}
        for row in h5_table.iterrows():
            key = row['id'].decode('utf-8')
            states_by_id[key] = row['hidden_states']
            attn_by_id[key] = row['attention_mask']
        print('Table loaded...')

        def query_sequences(seq_ids):
            return _stack([states_by_id[_id] for _id in seq_ids],
                          [attn_by_id[_id] for _id in seq_ids])

        return query_sequences

    def _read_row(_id):
        # Bind the id through `condvars` instead of formatting it into the
        # condition string. Ids containing quotes, or given as bytes, would
        # otherwise corrupt the query. Table columns such as `id` remain
        # available as default condition variables.
        target = _id if isinstance(_id, bytes) else str(_id).encode('utf-8')
        matches = [(x['hidden_states'], x['attention_mask'])
                   for x in h5_table.where('id == target', condvars = {'target': target})]
        if not matches:
            raise ValueError('No record found in bert_table for id: {}'.format(_id))
        if len(matches) > 1:
            raise ValueError('Multiple records found in bert_table for id: {}'.format(_id))
        return matches[0]

    def query_sequences(seq_ids):
        rows = [_read_row(_id) for _id in seq_ids]
        return _stack([states for states, _ in rows], [attn for _, attn in rows])

    return query_sequences


def make_predict_single_apply(model, apply_seqs_model, padding_n_node, padding_n_edge, return_intermediates = False, num_classes = 2, self_loops = False, line_graph = False, IncludeHs = False, line_graph_max_size = None):
    """
    Build a ``predict_single`` function that embeds sequences on the fly.

    Args:
        model: model passed to ``make_predict_single_step``. Its
            ``atom_features`` / ``bond_features`` attributes are forwarded to
            ``smiles_to_jraph``.
        apply_seqs_model: callable mapping ``seqs`` to the sequence part
            ``S`` of the model batch. It shadows the module-level function of
            the same name.
        padding_n_node, padding_n_edge: pad sizes for each molecular graph.
        return_intermediates, num_classes: forwarded to
            ``make_predict_single_step``.
        self_loops, IncludeHs: forwarded to ``smiles_to_jraph``.
        line_graph, line_graph_max_size: when ``line_graph`` is True, a line
            graph is also built per molecule via ``create_line_graph``.

    NOTE(review): ``make_preprocess_graphs`` only reads ``entry[0]``, so any
    line graph built here appears to be discarded; confirm intent.

    Returns:
        ``predict_single(params, list_of_smiles, seqs)``. It returns the
        jitted step's output on the batch ``(S, G)``.
    """
    step = make_predict_single_step(model = model, return_intermediates = return_intermediates, num_classes = num_classes)
    batch_graphs = make_preprocess_graphs(padding_n_node = padding_n_node, padding_n_edge = padding_n_edge)

    def _to_graph_entry(smiles, u = None):
        # Convert one SMILES string into the tuple layout make_preprocess_graphs expects.
        graph = smiles_to_jraph(smiles, u = u, validate = False, IncludeHs = IncludeHs,
                                atom_features = model.atom_features, bond_features = model.bond_features,
                                self_loops = self_loops)
        entry = (graph, )
        if line_graph:
            entry = entry + (create_line_graph(graph, max_size = line_graph_max_size), )
        return entry

    def predict_single(params, list_of_smiles, seqs):
        S = apply_seqs_model(seqs)
        G = batch_graphs([_to_graph_entry(smiles) for smiles in list_of_smiles])
        return step(params, (S, G))

    return predict_single


def make_predict_single_precompute(model, padding_n_node, padding_n_edge, h5_table, from_disk = False, return_intermediates = False, num_classes = 2, self_loops = False, line_graph = False, IncludeHs = False, line_graph_max_size = None):
    """
    Build a ``predict_single`` function that reads precomputed sequence
    embeddings from ``h5_table`` by id, instead of running a sequence model.

    Args:
        model: model passed to ``make_predict_single_step``. Its
            ``atom_features`` / ``bond_features`` attributes are forwarded to
            ``smiles_to_jraph``.
        padding_n_node, padding_n_edge: pad sizes for each molecular graph.
        h5_table, from_disk: forwarded to ``make_query_sequences``.
        return_intermediates, num_classes: forwarded to
            ``make_predict_single_step``.
        self_loops, IncludeHs: forwarded to ``smiles_to_jraph``.
        line_graph, line_graph_max_size: when ``line_graph`` is True, a line
            graph is also built per molecule via ``create_line_graph``.

    NOTE(review): ``make_preprocess_graphs`` only reads ``entry[0]``, so any
    line graph built here appears to be discarded; confirm intent.

    Returns:
        ``predict_single(params, list_of_smiles, seq_ids)``. It returns the
        jitted step's output on the batch ``(S, G)``.
    """
    step = make_predict_single_step(model = model, return_intermediates = return_intermediates, num_classes = num_classes)
    lookup = make_query_sequences(h5_table = h5_table, from_disk = from_disk)
    batch_graphs = make_preprocess_graphs(padding_n_node = padding_n_node, padding_n_edge = padding_n_edge)

    def _to_graph_entry(smiles, u = None):
        # Convert one SMILES string into the tuple layout make_preprocess_graphs expects.
        graph = smiles_to_jraph(smiles, u = u, validate = False, IncludeHs = IncludeHs,
                                atom_features = model.atom_features, bond_features = model.bond_features,
                                self_loops = self_loops)
        entry = (graph, )
        if line_graph:
            entry = entry + (create_line_graph(graph, max_size = line_graph_max_size), )
        return entry

    def predict_single(params, list_of_smiles, seq_ids):
        S = lookup(seq_ids)
        G = batch_graphs([_to_graph_entry(smiles) for smiles in list_of_smiles])
        return step(params, (S, G))

    return predict_single