"""
General loader that takes a sequence ID and gets the precomputed value.
Primarily aimed at the precompute loader.
"""
import numpy
import jraph
import jax
from jax import numpy as jnp
import functools
from collections import ChainMap

from ProtLig_GPCRclassA.utils import pad_graph_and_line_graph, pad_graph, Sequence, Label


def transpose_batch(batch):
    """
    Stack corresponding leaves of several pytrees along a new leading axis.

    ``batch`` is a sequence of pytrees that all share the same structure (one
    pytree per partition). The result is a single pytree with that structure in
    which every leaf is the ``jnp.stack`` of the matching leaves, i.e. a pytree
    with (n_parts, n_elements, (batch_size, features, ...)) is changed to
    (n_elements, (n_parts, batch_size, features, ...)).

    Example:
    --------
    List of tuples [(X, Y, Z)] with dim(X) = (batch_size, x_size), dim(Y) = (batch_size, y_size), dim(Z) = (batch_size, z_size)
    is changed to tuple (X', Y', Z') where dim(X') = (1, batch_size, x_size), dim(Y') = (1, batch_size, y_size), dim(Z') = (1, batch_size, z_size).

    Notes:
    ------
    ``batch`` must contain at least one pytree; an empty sequence raises
    ``TypeError`` because ``tree_map`` needs a first tree to define the structure.
    """
    # jax.tree_util.tree_map is the stable location of this function; the
    # top-level ``jax.tree_map`` alias is deprecated and absent in recent JAX.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves, axis = 0), *batch)



class AminoCollate:
    """
    Bookkeeping and clean manipulation with collate function.

    The aim of this class is to ease small modifications of some parts of collate function without
    the need for code redundancy. In Loader use output of make_collate as a collate function.

    Partitioning:
    -------------
    When ``n_partitions > 0`` every ``_*_collate`` method returns a *list* with one
    collated item per partition (``len(batch) // n_partitions`` samples each), and
    ``make_collate`` stacks those partitions along a new leading axis. Samples beyond
    ``n_partitions * (len(batch) // n_partitions)`` are not included in any partition.
    When ``n_partitions == 0`` the whole batch is collated into a single item.
    """
    def __init__(self, tokenizer, padding_n_node, padding_n_edge, line_graph = True, n_partitions = 0, seq_max_length = 2048):
        """
        Parameters:
        -----------
        tokenizer : callable
            Called as ``tokenizer(list_of_sequences, return_tensors='np', padding='max_length',
            max_length=seq_max_length, truncation=True)``; the result is converted with ``dict``.
            Presumably a HuggingFace-style tokenizer - confirm against the caller.
        padding_n_node, padding_n_edge :
            Forwarded unchanged to ``pad_graph`` / ``pad_graph_and_line_graph``.
        line_graph : bool
            Selects the graph collate that also pads and batches the line graph.
        n_partitions : int
            Number of partitions per batch; 0 disables partitioning.
        seq_max_length : int
            ``max_length`` passed to the tokenizer.
        """
        self.tokenizer = tokenizer
        self.padding_n_node = padding_n_node
        self.padding_n_edge = padding_n_edge
        self.n_partitions = n_partitions
        self.seq_max_length = seq_max_length

        if line_graph:
            graph_collate = self._graph_collate_with_line_graph
        else:
            graph_collate = self._graph_collate_without_line_graph
        self._graph_collate = functools.partial(graph_collate, padding_n_node = padding_n_node, padding_n_edge = padding_n_edge, n_partitions = n_partitions)

    @staticmethod
    def _partition_slices(n_items, n_partitions):
        """
        Return ``n_partitions`` consecutive slices of size ``n_items // n_partitions``.

        Items past the last full partition are covered by no slice.
        """
        partition_size = n_items // n_partitions
        return [slice(i * partition_size, (i + 1) * partition_size) for i in range(n_partitions)]

    def _seq_collate_seqs(self, batch):
        """
        Tokenize a list of raw sequences.

        Returns a dict of tokenizer outputs, with ``position_ids`` added (0..length-1
        broadcast to the shape of ``input_ids``) when the tokenizer does not provide
        them. With ``n_partitions > 0`` a list of such dicts is returned, one per
        partition, each holding the corresponding rows.
        """
        n_partitions = self.n_partitions
        seqs = dict(self.tokenizer(batch, return_tensors='np', padding = 'max_length', max_length = self.seq_max_length, truncation = True))
        if 'position_ids' not in seqs:
            input_ids = seqs['input_ids']
            seqs['position_ids'] = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
        if n_partitions > 0:
            # Partition size is derived from the number of input sequences.
            return [{key: value[rows] for key, value in seqs.items()}
                    for rows in self._partition_slices(len(batch), n_partitions)]
        return seqs

    def _seq_collate(self, batch):
        """
        Collate a batch of ``Sequence`` samples by tokenizing their ``seq`` attribute.
        """
        return self._seq_collate_seqs([s.seq for s in batch])

    @staticmethod
    def _graph_collate_with_line_graph(batch, padding_n_node, padding_n_edge, n_partitions):
        """
        Pad and batch (molecule graph, line graph) pairs.

        Every pair is padded with ``pad_graph_and_line_graph`` using the same
        ``padding_n_node`` / ``padding_n_edge`` and the padded graphs are merged with
        ``jraph.batch``. Returns ``(mols, line_mols)``, or a list of such tuples (one
        per partition) when ``n_partitions > 0``.

        For n_edges in a padding graph for line graph see https://stackoverflow.com/questions/6548283/max-number-of-paths-of-length-2-in-a-graph-with-n-nodes

        Raises:
        -------
        ValueError
            If a molecule graph or its line graph has no senders or no receivers.
            The offending graph is printed before raising.

        Notes:
        ------
        NOTE(review): an earlier note described two padding branches (small / big
        graphs); no such branching is visible in this method. Whether padding that
        is too small raises is decided inside ``pad_graph_and_line_graph`` - confirm there.
        """
        mols = []
        line_mols = []
        for mol, line_mol in batch:
            if len(mol.senders) == 0 or len(mol.receivers) == 0:
                print(mol)
                raise ValueError('Molecule with no bonds encountered.')
            if len(line_mol.senders) == 0 or len(line_mol.receivers) == 0:
                print(line_mol)
                raise ValueError('Molecule with no edges encountered (line molecule with no bonds).')
            padded_mol, padded_line_mol = pad_graph_and_line_graph(mol,
                                                                line_mol,
                                                                padding_n_node = padding_n_node,
                                                                padding_n_edge = padding_n_edge)
            mols.append(padded_mol)
            line_mols.append(padded_line_mol)
        if n_partitions > 0:
            return [(jraph.batch(mols[rows]), jraph.batch(line_mols[rows]))
                    for rows in AminoCollate._partition_slices(len(batch), n_partitions)]
        return (jraph.batch(mols), jraph.batch(line_mols))

    @staticmethod
    def _graph_collate_without_line_graph(batch, padding_n_node, padding_n_edge, n_partitions):
        """
        Pad and batch molecule graphs, ignoring any line graph.

        Each sample is a tuple whose first element is the molecule graph. Every graph
        is padded with ``pad_graph`` using the same ``padding_n_node`` /
        ``padding_n_edge`` and merged with ``jraph.batch``. Returns the batched graph,
        or a list of batched graphs (one per partition) when ``n_partitions > 0``.

        Raises:
        -------
        ValueError
            If a molecule graph has no senders or no receivers. The offending graph
            is printed before raising.

        Notes:
        ------
        NOTE(review): an earlier note described two padding branches (small / big
        graphs); no such branching is visible in this method.
        """
        mols = []
        for sample in batch:
            mol = sample[0] # Output of dataset __getitem__ is a tuple even if line_graph = False.
            if len(mol.senders) == 0 or len(mol.receivers) == 0:
                print(mol)
                raise ValueError('Molecule with no bonds encountered.')
            mols.append(pad_graph(mol,
                                  padding_n_node = padding_n_node,
                                  padding_n_edge = padding_n_edge))
        if n_partitions > 0:
            return [jraph.batch(mols[rows])
                    for rows in AminoCollate._partition_slices(len(batch), n_partitions)]
        return jraph.batch(mols)

    def _label_collate(self, batch):
        """
        Stack every column of a batch of ``Label`` samples.

        Returns a dict ``column -> stacked array``. With ``n_partitions > 0`` a list
        of such dicts is returned, one per partition. This matches the layout of the
        other collate methods, which ``make_collate`` relies on when it zips the
        per-partition results and stacks them with ``transpose_batch`` (a dict of
        lists would be iterated by key there and the label values would be lost).
        """
        n_partitions = self.n_partitions
        columns = list(batch[0].keys())
        if n_partitions > 0:
            return [{col: jnp.stack([x[col] for x in batch[rows]]) for col in columns}
                    for rows in self._partition_slices(len(batch), n_partitions)]
        return {col: jnp.stack([x[col] for x in batch]) for col in columns}

    def _numeric_collate(self, batch):
        """
        Stack a batch of numeric scalars into one array (or one array per partition).
        """
        n_partitions = self.n_partitions
        if n_partitions > 0:
            return [jnp.stack(batch[rows]) for rows in self._partition_slices(len(batch), n_partitions)]
        return jnp.stack(batch)

    def make_collate(self):
        """
        Create collate function that is the input to Loader.

        The returned function dispatches on the type of the first sample:
        ``Sequence`` -> ``_seq_collate``, ``Label`` -> ``_label_collate``, numpy
        integer/floating scalar -> ``_numeric_collate``, tuple/list starting with a
        ``jraph.GraphsTuple`` -> ``_graph_collate``. Any other tuple/list is
        transposed and each field is collated recursively. With ``n_partitions > 0``
        the per-partition results are finally stacked by ``transpose_batch``.

        Raises (from the returned function):
        ------------------------------------
        ValueError
            If the first sample has an unsupported type.
        """
        n_partitions = self.n_partitions

        def _collate(batch):
            first = batch[0]
            if isinstance(first, Sequence): # Sequence is a subclass of dict()
                return self._seq_collate(batch)
            if isinstance(first, Label): # Label is a subclass of dict()
                return self._label_collate(batch)
            if isinstance(first, (numpy.integer, numpy.floating)):
                return self._numeric_collate(batch)
            if isinstance(first, (tuple, list)):
                if isinstance(first[0], jraph.GraphsTuple):
                    return self._graph_collate(batch)
                # Generic record: collate field by field.
                fields = tuple(_collate(samples) for samples in zip(*batch))
                if n_partitions > 0:
                    # Each field is a list of partitions; regroup as one record per partition.
                    return tuple(zip(*fields))
                return fields
            raise ValueError('Unexpected type passed from dataset to loader: {}'.format(type(first)))

        def collate(batch):
            batch = _collate(batch)
            if n_partitions > 0:
                batch = transpose_batch(batch)
            return batch

        return collate


# ----------------------------------
# Precomputed amino acids embedding:
# ----------------------------------
class AminoCollatePrecompute(AminoCollate):
    """
    Collate that replaces tokenization of sequences by a lookup of precomputed values.

    For every sample the ``'_seq_id'`` entry is used to fetch ``hidden_states`` and
    ``attention_mask`` from ``bert_table``, either from an in-memory copy of the
    table or by querying the table for every batch.

    Notes:
    ------
    The parent initializer is not called: this class has no tokenizer, so
    ``tokenizer`` and ``seq_max_length`` are not set and ``_seq_collate`` is
    installed as an instance attribute instead of the inherited method.
    Partitioned sequence collation (``n_partitions > 0``) is not implemented.
    """
    def __init__(self, bert_table, padding_n_node, padding_n_edge, n_partitions, line_graph = True, from_disk = False):
        """
        Parameters:
        -----------
        bert_table :
            Table with ``id`` (bytes), ``hidden_states`` and ``attention_mask`` columns
            that supports ``iterrows()`` and ``where(condition)``. Presumably a
            PyTables table - confirm against the caller.
        padding_n_node, padding_n_edge :
            Forwarded unchanged to the graph padding helpers.
        n_partitions : int
            Number of partitions per batch; 0 disables partitioning.
        line_graph : bool
            Selects the graph collate that also pads and batches the line graph.
        from_disk : bool
            If False the whole table is read into memory here; otherwise the table
            is queried on every batch.
        """
        self.bert_table = bert_table
        self.padding_n_node = padding_n_node
        self.padding_n_edge = padding_n_edge
        self.n_partitions = n_partitions

        if line_graph:
            self._graph_collate = functools.partial(self._graph_collate_with_line_graph, padding_n_node = padding_n_node, padding_n_edge = padding_n_edge, n_partitions = n_partitions)
        else:
            self._graph_collate = functools.partial(self._graph_collate_without_line_graph, padding_n_node = padding_n_node, padding_n_edge = padding_n_edge, n_partitions = n_partitions)

        if not from_disk:
            bert_dict = {'states':{}, 'attn':{}}
            for row in self.bert_table.iterrows():
                # IDs are stored as bytes; samples carry them as str.
                seq_id = row['id'].decode('utf-8')
                bert_dict['states'][seq_id] = row['hidden_states']
                bert_dict['attn'][seq_id] = row['attention_mask']
            print('Table loaded...')
            self._seq_collate = functools.partial(self._seq_collate_from_ram, bert_dict = bert_dict, n_partitions = n_partitions)
        else:
            self._seq_collate = functools.partial(self._seq_collate_from_disk, bert_table = bert_table, n_partitions = n_partitions)

    @staticmethod
    def _seq_collate_from_disk(batch, bert_table, n_partitions):
        """
        Look up precomputed values for a batch by querying ``bert_table``.

        Each distinct ``'_seq_id'`` in the batch is queried once. Returns a tuple
        ``(hidden_states, attention_mask)`` of stacked arrays in batch order.

        Raises:
        -------
        NotImplementedError
            If ``n_partitions > 0`` (checked before any table access).
        ValueError
            If an ID has no record or more than one record in the table.
        """
        if n_partitions > 0:
            raise NotImplementedError('needs to be checked...')

        hidden_states = {}
        attention_masks = {}
        for _id in {s['_seq_id'] for s in batch}:
            n_matches = 0
            for x in bert_table.where('(id == b"{}")'.format(_id)):
                hidden_states[_id] = x['hidden_states']
                attention_masks[_id] = x['attention_mask']
                n_matches += 1

            if n_matches == 0:
                raise ValueError('No record found in bert_table for id: {}'.format(_id))
            elif n_matches > 1:
                raise ValueError('Multiple records found in bert_table for id: {}'.format(_id))

        seqs_hidden_states = [hidden_states[s['_seq_id']] for s in batch]
        seqs_attn = [attention_masks[s['_seq_id']] for s in batch]
        return (jnp.stack(seqs_hidden_states), jnp.stack(seqs_attn))

    @staticmethod
    def _seq_collate_from_ram(batch, bert_dict, n_partitions):
        """
        Look up precomputed values for a batch in the in-memory ``bert_dict``.

        ``bert_dict`` maps ``'states'`` and ``'attn'`` to dicts keyed by sequence ID.
        Returns a tuple ``(hidden_states, attention_mask)`` of stacked arrays in
        batch order.

        Raises:
        -------
        NotImplementedError
            If ``n_partitions > 0`` (checked before any lookup).
        KeyError
            If a sample's ``'_seq_id'`` is not present in ``bert_dict``.
        """
        if n_partitions > 0:
            raise NotImplementedError('needs to be checked...')

        seqs_hidden_states = []
        seqs_attn = []
        for _seq in batch:
            _id = _seq['_seq_id']
            seqs_hidden_states.append(bert_dict['states'][_id])
            seqs_attn.append(bert_dict['attn'][_id])
        return (jnp.stack(seqs_hidden_states), jnp.stack(seqs_attn))




# -----------------------------------------------
# Precomputed amino acids embedding with masking:
# -----------------------------------------------
class AminoCollatePrecomputeMasked(AminoCollatePrecompute):
    """
    Precompute collate that additionally returns tokenized ``input_ids``.

    Sequences are looked up by ``'_seq_id'`` exactly as in the parent class, and
    the raw sequence stored under ``seq_col`` in every sample is tokenized with
    ``tokenizer``; the resulting ``input_ids`` are appended to the lookup result.

    Notes:
    ------
    The parent initializers are not called; all attributes are assigned here.
    The lookup function is stored as ``_seq_collate_ids`` so that the
    ``_seq_collate`` method defined on this class stays in effect.
    """
    def __init__(self, tokenizer, seq_max_length, bert_table, padding_n_node, padding_n_edge, n_partitions, line_graph = True, from_disk = False, seq_col = None):
        # Tokenizer-related state read by the inherited _seq_collate_seqs.
        self.tokenizer = tokenizer
        self.seq_max_length = seq_max_length

        # State shared with the precompute collate.
        self.bert_table = bert_table
        self.padding_n_node = padding_n_node
        self.padding_n_edge = padding_n_edge
        self.n_partitions = n_partitions
        self.seq_col = seq_col

        graph_fn = self._graph_collate_with_line_graph if line_graph else self._graph_collate_without_line_graph
        self._graph_collate = functools.partial(
            graph_fn,
            padding_n_node = padding_n_node,
            padding_n_edge = padding_n_edge,
            n_partitions = n_partitions,
        )

        if from_disk:
            # Query the table on every batch.
            self._seq_collate_ids = functools.partial(self._seq_collate_from_disk, bert_table = bert_table, n_partitions = n_partitions)
            return

        # Read the whole table into memory once, keyed by decoded sequence ID.
        states = {}
        attn = {}
        for row in self.bert_table.iterrows():
            seq_id = row['id'].decode('utf-8')
            states[seq_id] = row['hidden_states']
            attn[seq_id] = row['attention_mask']
        print('Table loaded...')
        self._seq_collate_ids = functools.partial(
            self._seq_collate_from_ram,
            bert_dict = {'states': states, 'attn': attn},
            n_partitions = n_partitions,
        )

    def _seq_collate(self, batch):
        """
        Return the precomputed lookup tuple for ``batch`` extended by ``input_ids``.

        The sequences under ``self.seq_col`` are tokenized first; the lookup by ID
        is performed afterwards on the unmodified samples.
        """
        raw_sequences = [sample[self.seq_col] for sample in batch]
        input_ids = self._seq_collate_seqs(raw_sequences)['input_ids']
        precomputed = self._seq_collate_ids(batch)
        return precomputed + (input_ids, )