import numpy
import tensorflow as tf

from ProtLig_GPCRclassA.base_loader import BaseDataLoader

class AminoLoader(BaseDataLoader):
    """Data loader that forwards batching options to ``BaseDataLoader``.

    Parameters
    ----------
    dataset :
        Dataset forwarded unchanged to ``BaseDataLoader``.
    collate_fn : callable
        Collation function forwarded to ``BaseDataLoader``.
    batch_size : int
        Number of examples per batch (forwarded to ``BaseDataLoader``).
    n_partitions : int
        Number of partitions each batch is split into; 0 disables partitioning.
        When > 0, ``batch_size`` must be divisible by ``n_partitions``.
        (Presumably the number of devices for ``pmap`` — TODO confirm.)
    shuffle : bool
        Forwarded to ``BaseDataLoader``.
    rng :
        Random number generator forwarded to ``BaseDataLoader``.
    drop_last : bool
        Forwarded to ``BaseDataLoader``.
    """
    def __init__(self, dataset, collate_fn,
                 batch_size=1,
                 n_partitions = 0,
                 shuffle=False,
                 rng=None,
                 drop_last=False):

        self.n_partitions = n_partitions
        if n_partitions > 0:
            # Each partition must receive the same number of examples.
            assert batch_size % self.n_partitions == 0, (
                'batch_size (%d) must be divisible by n_partitions (%d)' % (batch_size, n_partitions))

        # Zero-argument super(): `super(self.__class__, self)` would recurse
        # infinitely if AminoLoader were ever subclassed.
        super().__init__(dataset,
                         batch_size = batch_size,
                         shuffle = shuffle,
                         rng = rng,
                         drop_last = drop_last,
                         collate_fn = collate_fn,
                         )
        
# -----------------
# Create TF loader:
# -----------------
def get_tf_loader(jax_dataset, batch_size, use_cache, shuffle, shuffle_buffer_size, drop_last, 
              id_mapping_table = None, seq_embedding_lookup = None,
              n_partitions = 0):
    """Build a batched, prefetching ``tf.data`` pipeline from ``jax_dataset``.

    Stage order: [cache] -> [shuffle] -> [id lookup -> embedding lookup]
    -> batch (twice when ``n_partitions > 0``) -> prefetch.

    Parameters
    ----------
    jax_dataset :
        Object exposing ``tf_Dataset_by_example()``. When the lookups are
        enabled its elements are unpacked as ``(s, g, l)``.
    batch_size : int
        Total number of examples per yielded batch.
    use_cache : bool
        Whether to ``cache()`` the per-example dataset.
    shuffle : bool
        Whether to shuffle examples (reshuffled on every iteration).
    shuffle_buffer_size : int
        Buffer size passed to ``shuffle``.
    drop_last : bool
        Passed as ``drop_remainder`` to every ``batch`` call.
    id_mapping_table, seq_embedding_lookup : optional
        When both are given, ``s`` is mapped through
        ``id_mapping_table.lookup`` and then through ``seq_embedding_lookup``.
    n_partitions : int
        When > 0, examples are first batched by ``batch_size // n_partitions``
        and those batches are then grouped ``n_partitions`` at a time, adding a
        leading partition axis (for pmap).

    Returns
    -------
    tf.data.Dataset
    """
    dataset = jax_dataset.tf_Dataset_by_example()
    if use_cache:
        dataset = dataset.cache()
    if shuffle:
        dataset = dataset.shuffle(buffer_size = shuffle_buffer_size, reshuffle_each_iteration = True)

    if id_mapping_table is not None and seq_embedding_lookup is not None:
        def _to_index(s, g, l):
            return id_mapping_table.lookup(s), g, l

        def _to_embedding(s, g, l):
            return seq_embedding_lookup(s), g, l

        # Two separate map stages: ids -> table values, then values -> embeddings.
        dataset = dataset.map(_to_index).map(_to_embedding)

    if n_partitions > 0:
        # Per-partition batch first, then stack n_partitions of those batches.
        batch_plan = (batch_size // n_partitions, n_partitions)
    else:
        batch_plan = (batch_size,)
    for size in batch_plan:
        dataset = dataset.batch(size, drop_remainder = drop_last, num_parallel_calls = tf.data.AUTOTUNE)

    return dataset.prefetch(buffer_size = tf.data.AUTOTUNE)

# (_id, input_ids['input_ids'][0, ...])
def get_tf_loader_masked(jax_dataset, batch_size, use_cache, shuffle, shuffle_buffer_size, drop_last, 
              id_mapping_table = None, seq_embedding_lookup = None,
              n_partitions = 0):
    """Build a batched ``tf.data`` pipeline whose ``s`` component is a tuple.

    Same stages as ``get_tf_loader``, except only ``s[0]`` goes through the id
    and embedding lookups; the remaining entries of ``s`` are passed through
    unchanged. Presumably ``s`` is ``(_id, input_ids, ...)`` — TODO confirm
    against ``tf_Dataset_by_example``.

    Parameters
    ----------
    jax_dataset :
        Object exposing ``tf_Dataset_by_example()`` whose elements unpack as
        ``(s, g, l)`` with ``s`` a tuple.
    batch_size : int
        Total number of examples per yielded batch.
    use_cache : bool
        Whether to ``cache()`` the per-example dataset.
    shuffle : bool
        Whether to shuffle examples (reshuffled on every iteration).
    shuffle_buffer_size : int
        Buffer size passed to ``shuffle``.
    drop_last : bool
        Passed as ``drop_remainder`` to every ``batch`` call.
    id_mapping_table, seq_embedding_lookup : optional
        When both are given, ``s[0]`` is mapped through
        ``id_mapping_table.lookup`` and then through ``seq_embedding_lookup``.
    n_partitions : int
        When > 0, batch by ``batch_size // n_partitions`` and then group
        ``n_partitions`` batches together (leading partition axis for pmap).

    Returns
    -------
    tf.data.Dataset
    """
    loader = jax_dataset.tf_Dataset_by_example()
    # Cache:
    if use_cache:
        loader = loader.cache()
    # Shuffle:
    if shuffle:
        loader = loader.shuffle(buffer_size = shuffle_buffer_size, reshuffle_each_iteration = True)
    # Sequence embedding lookup on s[0]:
    if id_mapping_table is not None and seq_embedding_lookup is not None:
        def _lookup_ids(s, g, l):
            return (id_mapping_table.lookup(s[0]), ) + s[1:], g, l

        def _lookup_embedding(s, g, l):
            embedded = seq_embedding_lookup(s[0])
            # Previously `seq_embedding_lookup(s[0]) + s[1:]` added a tensor to
            # a tuple instead of rebuilding `s`. Wrap a single result so it
            # replaces s[0]; a lookup that already returns a tuple still works.
            if not isinstance(embedded, tuple):
                embedded = (embedded, )
            return embedded + s[1:], g, l

        loader = loader.map(_lookup_ids)
        loader = loader.map(_lookup_embedding)
    # pmap:
    if n_partitions > 0:
        batch_size_pmap = batch_size // n_partitions
        loader = loader.batch(batch_size_pmap, drop_remainder = drop_last, num_parallel_calls = tf.data.AUTOTUNE)
        loader = loader.batch(n_partitions, drop_remainder = drop_last, num_parallel_calls = tf.data.AUTOTUNE)
    else:
        loader = loader.batch(batch_size, drop_remainder = drop_last, num_parallel_calls = tf.data.AUTOTUNE)
    # Prefetch
    loader = loader.prefetch(buffer_size = tf.data.AUTOTUNE)
    return loader


def get_tf_loader_masked_sample_by_label_distribution(jax_dataset, batch_size, use_cache, shuffle, shuffle_buffer_size, drop_last, 
              id_mapping_table = None, seq_embedding_lookup = None,
              n_partitions = 0):
    """Like ``get_tf_loader_masked``, but interleaves per-label datasets.

    ``jax_dataset`` is split into one dataset per key of
    ``jax_dataset.infer_class_dist()``. Rows are selected where
    ``row['_label']['_main_label'] == key``. The per-label datasets are then
    combined with ``tf.data.Dataset.sample_from_datasets``, using each label's
    share of the class-distribution counts as its sampling weight.

    Parameters
    ----------
    jax_dataset :
        Object exposing ``infer_class_dist()``, ``copy()``, a pandas-like
        ``data`` attribute and ``tf_Dataset_by_example()``.
    batch_size, use_cache, shuffle, drop_last, id_mapping_table,
    seq_embedding_lookup, n_partitions :
        As in ``get_tf_loader_masked``.
    shuffle_buffer_size : int
        Total shuffle budget. Each per-label dataset gets
        ``shuffle_buffer_size * weight`` (at least 1).

    Returns
    -------
    tf.data.Dataset
    """
    print('NOTE: get_tf_loader_masked_sample_by_label_distribution should be merged with get_tf_loader_masked in the future.')
    class_dist = jax_dataset.infer_class_dist()
    weights = numpy.array(list(class_dist.values())) / sum(class_dist.values())
    loaders = []
    for key, weight in zip(class_dist.keys(), weights):
        _single_label_dataset = jax_dataset.copy()
        _label_mask = _single_label_dataset.data.apply(lambda x: x['_label']['_main_label'] == key, axis = 1)
        _single_label_dataset.data = _single_label_dataset.data[_label_mask].copy()
        _single_label_loader = _single_label_dataset.tf_Dataset_by_example()
        # Cache:
        if use_cache:
            _single_label_loader = _single_label_loader.cache()
        # Shuffle:
        if shuffle:
            # Rescale the buffer by the label's weight. Clamp to 1: rare labels
            # would otherwise round down to 0, which tf.data rejects.
            _label_buffer_size = max(1, int(shuffle_buffer_size * weight))
            _single_label_loader = _single_label_loader.shuffle(buffer_size = _label_buffer_size, reshuffle_each_iteration = True)
        loaders.append(_single_label_loader)

    loader = tf.data.Dataset.sample_from_datasets(loaders, weights=weights)

    # Sequence embedding lookup on s[0]:
    if id_mapping_table is not None and seq_embedding_lookup is not None:
        def _lookup_ids(s, g, l):
            return (id_mapping_table.lookup(s[0]), ) + s[1:], g, l

        def _lookup_embedding(s, g, l):
            embedded = seq_embedding_lookup(s[0])
            # Previously `seq_embedding_lookup(s[0]) + s[1:]` added a tensor to
            # a tuple instead of rebuilding `s`. Wrap a single result so it
            # replaces s[0]; a lookup that already returns a tuple still works.
            if not isinstance(embedded, tuple):
                embedded = (embedded, )
            return embedded + s[1:], g, l

        loader = loader.map(_lookup_ids)
        loader = loader.map(_lookup_embedding)
    # pmap:
    if n_partitions > 0:
        batch_size_pmap = batch_size // n_partitions
        loader = loader.batch(batch_size_pmap, drop_remainder = drop_last, num_parallel_calls = tf.data.AUTOTUNE)
        loader = loader.batch(n_partitions, drop_remainder = drop_last, num_parallel_calls = tf.data.AUTOTUNE)
    else:
        loader = loader.batch(batch_size, drop_remainder = drop_last, num_parallel_calls = tf.data.AUTOTUNE)
    # Prefetch
    loader = loader.prefetch(buffer_size = tf.data.AUTOTUNE)
    return loader


def tf_loader_pmap(jax_dataset, batch_size_pmap, use_cache, shuffle, shuffle_buffer_size, drop_last, n_partitions):
    """Deprecated: always raises.

    Partitioned (pmap) batching now lives inside each ``get_tf_loader*``
    function through their ``n_partitions`` argument. The former body after
    the ``raise`` was unreachable and referenced an undefined name
    (``seq_lookup``), so it has been removed. The signature is kept so that
    existing callers get the same error.

    Raises
    ------
    Exception
        Unconditionally.
    """
    raise Exception('Depriciated. Moved to each get_tf_loader_*')
    

# def get_tf_loader(jax_dataset, batch_size, use_cache, shuffle, shuffle_buffer_size, drop_last, n_partitions = 0):
#     raise Exception('Depriciated.')
#     if n_partitions > 0:
#         batch_size_pmap = batch_size // n_partitions
#         loader = tf_loader_pmap(jax_dataset, batch_size_pmap, use_cache, shuffle, shuffle_buffer_size, drop_last, n_partitions)
#     else:
#         loader = tf_loader(jax_dataset, batch_size, use_cache, shuffle, shuffle_buffer_size, drop_last)
#     return loader