import functools
from jax import numpy as jnp

import tensorflow as tf

from ProtLig_GPCRclassA.utils import create_line_graph, Sequence, Label, pad_graph


# ------
# Utils:
# ------
def _get_seq_tokenized(seq, tokenizer, seq_max_length):
    """Tokenize ``seq`` padded/truncated to ``seq_max_length`` tokens.

    ``tokenizer`` is called as ``tokenizer(seq, return_tensors='np',
    padding='max_length', max_length=seq_max_length, truncation=True)``
    (presumably a Hugging Face tokenizer -- confirm against callers). Its
    output is copied into a plain ``dict``. If that dict has no
    ``'position_ids'`` entry, one is added: ``0..L-1`` (``L`` = size of the
    last axis of ``input_ids``) broadcast to the shape of ``input_ids``.

    Note that the added ``'position_ids'`` is a JAX array, while the other
    entries are whatever the tokenizer returned.

    Returns
    -------
    dict
        The tokenizer's fields, guaranteed to include ``'position_ids'``.
    """
    encoded = dict(
        tokenizer(
            seq,
            return_tensors='np',
            padding='max_length',
            max_length=seq_max_length,
            truncation=True,
        )
    )
    # Tokenizer already supplied positions: nothing to add.
    if 'position_ids' in encoded:
        return encoded

    ids = encoded['input_ids']
    seq_len = jnp.atleast_2d(ids).shape[-1]
    encoded['position_ids'] = jnp.broadcast_to(jnp.arange(seq_len), ids.shape)
    return encoded

# def _get_seq_embedding_from_ram(seq, bert_dict):
#     _id = seq['_seq_id']
#     seqs_hidden_states = bert_dict['states'][_id]
#     seqs_attn = bert_dict['attn'][_id]
#     seqs = (seqs_hidden_states, seqs_attn)
#     return seqs

# def _get_seq_embedding_from_disk():
#     raise NotImplementedError('Not implemented yet. Check old collate.')

# def create_bert_dict(bert_table):
#     # bert_dict = {}
#     bert_dict = {'states':{}, 'attn':{}}
#     # print(self.bert_table.read())
#     for row in bert_table.iterrows():
#         bert_dict['states'][row['id'].decode('utf-8')] = row['hidden_states']
#         bert_dict['attn'][row['id'].decode('utf-8')] = row['attention_mask']
#     print('Table loaded...')
#     return bert_dict

# def make_get_seq_embedding(bert_table, from_disk):
#     if not from_disk:
#         bert_dict = create_bert_dict(bert_table)
#         _get_seq_embedding = functools.partial(_get_seq_embedding_from_ram, bert_dict = bert_dict)
#     else:
#         _get_seq_embedding = functools.partial(_get_seq_embedding_from_disk, bert_table = bert_table)
#     return _get_seq_embedding


# ----------------------------------
# Precomputed amino acids embedding:
# ----------------------------------
class AminoElementPrecompute:
    """Per-element preprocessing for (sequence, molecule, label) triples
    backed by precomputed amino-acid (BERT) embeddings.

    Sequences are turned either into their precomputed
    ``(hidden_states, attention_mask)`` pair, looked up in an in-memory dict
    built from ``bert_table``, or, when ``seq_lookup`` is True, into just their
    ``_seq_id``. In the latter case the embedding is presumably resolved later,
    e.g. with the tables from ``create_seq_embedding_lookups`` -- confirm
    against the caller. Molecule graphs are unwrapped and padded with
    ``pad_graph``.

    Parameters
    ----------
    bert_table : optional
        Object exposing ``iterrows()``. Each row provides ``'id'`` (bytes,
        decoded as UTF-8), ``'hidden_states'`` and ``'attention_mask'``.
        Required unless ``seq_lookup`` is True.
    padding_n_node, padding_n_edge : optional
        Forwarded to ``pad_graph``.
    from_disk : optional
        Disk-backed embeddings are not implemented. A truthy value makes the
        constructor raise ``NotImplementedError``.
    seq_lookup : bool
        If True, ``seq_prepro`` returns only the sequence id.
    """
    def __init__(self,
                bert_table = None,
                padding_n_node = None,
                padding_n_edge = None,
                from_disk = None,
                seq_lookup = False):
        # Moved from collate:
        self.bert_table = bert_table
        self.padding_n_node = padding_n_node
        self.padding_n_edge = padding_n_edge
        self.from_disk = from_disk
        self.seq_lookup = seq_lookup

        # Not filled in by the constructor. create_seq_embedding_lookups()
        # returns objects of these kinds; the caller decides where to keep them.
        self.id_mapping_table = None
        self.seq_embedding_lookup = None

        # Build the preprocessing closures once. make_seq_prepro raises
        # NotImplementedError when from_disk is truthy.
        self.seq_prepro = self.make_seq_prepro()
        self.mol_prepro = self.make_mol_prepro()

    @staticmethod
    def create_bert_dict(bert_table):
        """Load every row of ``bert_table`` into memory.

        Returns
        -------
        dict
            ``{'states': {id: hidden_states}, 'attn': {id: attention_mask}}``,
            keyed by the UTF-8 decoded row id.
        """
        bert_dict = {'states': {}, 'attn': {}}
        for row in bert_table.iterrows():
            # Row ids come back as bytes; decode once and reuse.
            row_id = row['id'].decode('utf-8')
            bert_dict['states'][row_id] = row['hidden_states']
            bert_dict['attn'][row_id] = row['attention_mask']
        print('Table loaded...')
        return bert_dict

    def create_seq_embedding_lookups(self):
        """Build TensorFlow lookups over all rows of ``self.bert_table``.

        Returns
        -------
        id_mapping_table : tf.lookup.StaticHashTable
            Maps a sequence id (string) to its row index (int32). Unknown ids
            map to -1.
        seq_embedding_lookup : callable
            ``seq_embedding_lookup(ids) -> (hidden_states, attention_mask)``,
            which gathers the rows at integer indices ``ids``.
        """
        lookup_keys = []
        lookup_vals = []
        hidden_states_list = []
        attention_mask_list = []
        # Row index i of the stacked tensors belongs to lookup_keys[i].
        for index, row in enumerate(self.bert_table.iterrows()):
            lookup_keys.append(row['id'].decode('utf-8'))
            lookup_vals.append(index)
            hidden_states_list.append(row['hidden_states'])
            attention_mask_list.append(row['attention_mask'])
        print('Table loaded...')

        with tf.device('cpu'):
            # Explicit dtypes match what TF infers for non-empty str/int lists,
            # and keep an empty table typed as string -> int32 instead of
            # float32 -> float32.
            keys_tensor = tf.constant(lookup_keys, dtype=tf.string)
            vals_tensor = tf.constant(lookup_vals, dtype=tf.int32)
            init = tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor)
            id_mapping_table = tf.lookup.StaticHashTable(init, default_value=-1)

            hidden_states_tensor = tf.constant(hidden_states_list)
            attention_mask_tensor = tf.constant(attention_mask_list)

        # NOTE(review): only the constants above are created under the CPU
        # device scope. The gather ops below run wherever the caller invokes
        # this function.
        def seq_embedding_lookup(ids):
            states = tf.nn.embedding_lookup(hidden_states_tensor, ids)
            mask = tf.nn.embedding_lookup(attention_mask_tensor, ids)
            return (states, mask)
        print('Lookups created...')

        return id_mapping_table, seq_embedding_lookup

    def make_seq_prepro(self):
        """Return the per-sequence preprocessing function.

        With ``seq_lookup`` the function returns ``seq['_seq_id']``. Otherwise
        it returns ``(hidden_states, attention_mask)`` from an in-memory dict
        built here from ``self.bert_table``.

        Raises
        ------
        NotImplementedError
            If ``self.from_disk`` is truthy.
        """
        if self.from_disk:
            raise NotImplementedError('Not implemented yet. Check old collate...')

        if self.seq_lookup:
            def seq_prepro(seq):
                return seq['_seq_id']
            return seq_prepro

        bert_dict = self.create_bert_dict(self.bert_table)

        def seq_prepro(seq):
            _id = seq['_seq_id']
            return (bert_dict['states'][_id], bert_dict['attn'][_id])
        return seq_prepro

    def make_mol_prepro(self):
        """Return a function that unwraps a molecule tuple and pads its graph."""
        def mol_prepro(mol):
            """Pad ``mol[0]`` with ``pad_graph`` using the configured padding sizes.

            TODO: efficiency could be improved by padding once in
            ``dataset.read_data``. Currently the same graph is padded every time
            it is sampled.
            """
            # The dataset's __getitem__ returns a tuple even if line_graph = False.
            graph = mol[0]
            return pad_graph(graph, padding_n_node = self.padding_n_node, padding_n_edge = self.padding_n_edge)
        return mol_prepro

    def make_element_preprocess(self):
        """Return ``element_preprocess(seq, mol, label) -> (seq, mol, label)``.

        ``seq`` is converted with ``dict(...)`` and passed through
        ``self.seq_prepro``. ``mol`` goes through ``self.mol_prepro``.
        ``label`` is converted with ``dict(...)``. Both preprocessors are
        looked up on ``self`` at call time.
        """
        def element_preprocess(seq, mol, label):
            seq = self.seq_prepro(dict(seq))
            mol = self.mol_prepro(mol)
            label = dict(label)
            return seq, mol, label
        return element_preprocess



# ----------------------------------------------------
# Precomputed amino acids embedding with masking loss:
# ----------------------------------------------------
class AminoElementPrecomputeMasked(AminoElementPrecompute):
    """``AminoElementPrecompute`` variant that also emits token ids.

    On top of what the parent produces for a sequence, the raw sequence
    ``seq[seq_col]`` is tokenized to ``seq_max_length`` tokens via
    ``_get_seq_tokenized``, and the resulting ``input_ids`` row is appended to
    the output. These are presumably the targets for the masking loss named
    in the section header; confirm against the training code.

    Parameters
    ----------
    tokenizer :
        Callable passed to ``_get_seq_tokenized``.
    seq_max_length : int
        Padding/truncation length for tokenization.
    seq_col :
        Key in each sequence dict holding the raw sequence to tokenize.
    bert_table, padding_n_node, padding_n_edge, from_disk, seq_lookup :
        As for ``AminoElementPrecompute``.
    """
    def __init__(self,
                tokenizer,
                seq_max_length,
                bert_table = None,
                padding_n_node = None,
                padding_n_edge = None,
                from_disk = None,
                seq_lookup = False,
                seq_col = None):
        # Set subclass-specific state first: the parent constructor calls the
        # overridden make_seq_prepro(), whose closures read these attributes.
        self.tokenizer = tokenizer
        self.seq_max_length = seq_max_length
        self.seq_col = seq_col
        super().__init__(bert_table = bert_table,
                         padding_n_node = padding_n_node,
                         padding_n_edge = padding_n_edge,
                         from_disk = from_disk,
                         seq_lookup = seq_lookup)

    def make_seq_prepro(self):
        """Return the per-sequence preprocessing function.

        With ``seq_lookup`` the function returns ``(_seq_id, input_ids)``.
        Otherwise it returns ``(hidden_states, attention_mask, input_ids)``,
        where the first two come from an in-memory dict built here from
        ``self.bert_table``.

        Raises
        ------
        NotImplementedError
            If ``self.from_disk`` is truthy.
        """
        if self.from_disk:
            raise NotImplementedError('Not implemented yet. Check old collate...')

        def tokenize(seq):
            # The tokenizer returns a batch of one; keep that single row of ids.
            tokens = _get_seq_tokenized(seq[self.seq_col], tokenizer = self.tokenizer, seq_max_length = self.seq_max_length)
            return tokens['input_ids'][0, ...]

        if self.seq_lookup:
            def seq_prepro(seq):
                return (seq['_seq_id'], tokenize(seq))
            return seq_prepro

        bert_dict = self.create_bert_dict(self.bert_table)

        def seq_prepro(seq):
            _id = seq['_seq_id']
            return (bert_dict['states'][_id], bert_dict['attn'][_id], tokenize(seq))
        return seq_prepro