from jax import numpy as jnp
import jax
import copy
import time
import functools

from ProtLig_GPCRclassA.utils import create_line_graph, Sequence, Label, pad_graph
from ProtLig_GPCRclassA.amino_GNN.element import AminoElementPrecompute, AminoElementPrecomputeMasked

# ----------------------------------
# Precomputed amino acids embedding:
# ----------------------------------
class AminoConcentrationElementPrecompute(AminoElementPrecompute):
    """
    Precomputed amino-acid embedding element whose molecule graph also carries
    a concentration value.

    Provides ``make_mol_prepro`` / ``make_element_preprocess`` factories
    (presumably invoked by the ``AminoElementPrecompute`` base class -- confirm)
    so that each sample arrives as ``(seq, mol, conc, label)``. The ``conc``
    value is stored in the padded molecule graph's ``globals`` under the key
    ``'_conc'``.
    """
    def make_mol_prepro(self):
        """Build the per-sample molecule preprocessing closure.

        The closure reads ``self.padding_n_node`` and ``self.padding_n_edge``
        each time it is called, not when it is built.
        """
        def mol_prepro(mol, conc):
            """
            Pad the molecule graph and attach ``conc`` to its globals.

            Args:
                mol: indexable whose first item is the graph. The dataset's
                    ``__getitem__`` returns a tuple even when ``line_graph`` is
                    False, so the graph is taken from index 0.
                conc: concentration value, stored as-is under
                    ``globals['_conc']``.

            Returns:
                The graph padded by ``pad_graph``, with ``'_conc'`` set in its
                globals. Existing globals are copied via ``.copy()`` and updated
                rather than mutated in place (assumed dict-like -- TODO
                confirm). ``conc`` is attached after padding, so ``pad_graph``
                never sees it.

            TODO: the same graph is re-padded every time it is sampled. Doing
            the padding once in the dataset's ``read_data`` preprocessing might
            be more efficient.
            """
            padded = pad_graph(
                mol[0],
                padding_n_node=self.padding_n_node,
                padding_n_edge=self.padding_n_edge,
            )
            existing = padded.globals
            if existing is not None:
                # Copy first so the caller's globals object is left untouched.
                merged = existing.copy()
                merged.update({'_conc': conc})
            else:
                merged = {'_conc': conc}
            return padded._replace(globals=merged)
        return mol_prepro

    def make_element_preprocess(self):
        """Build the closure that preprocesses one ``(seq, mol, conc, label)`` sample."""
        def element_preprocess(seq, mol, conc, label):
            """
            Preprocess a single sample and drop ``conc`` from the output tuple.

            Returns ``(seq, mol, label)``:

            - ``seq``: converted to a dict, then passed through ``self.seq_prepro``.
            - ``mol``: passed, together with ``conc``, through ``self.mol_prepro``.
            - ``label``: converted to a dict.
            """
            prepared_seq = self.seq_prepro(dict(seq))
            prepared_mol = self.mol_prepro(mol, conc)
            return prepared_seq, prepared_mol, dict(label)
        return element_preprocess
    

# ----------------------------------------------------
# Precomputed amino acids embedding with masking loss:
# ----------------------------------------------------
class AminoConcentrationElementPrecomputeMasked(AminoElementPrecomputeMasked):
    """
    Masked-loss variant of the precomputed amino-acid embedding element whose
    molecule graph also carries a concentration value.

    Provides ``make_mol_prepro`` / ``make_element_preprocess`` factories
    (presumably invoked by the ``AminoElementPrecomputeMasked`` base class --
    confirm) so that each sample arrives as ``(seq, mol, conc, label)``. The
    ``conc`` value is stored in the padded molecule graph's ``globals`` under
    the key ``'_conc'``.
    """
    def make_mol_prepro(self):
        """Build the per-sample molecule preprocessing closure.

        The closure reads ``self.padding_n_node`` and ``self.padding_n_edge``
        each time it is called, not when it is built.
        """
        def mol_prepro(mol, conc):
            """
            Pad the molecule graph and attach ``conc`` to its globals.

            Args:
                mol: indexable whose first item is the graph. The dataset's
                    ``__getitem__`` returns a tuple even when ``line_graph`` is
                    False, so the graph is taken from index 0.
                conc: concentration value, stored as-is under
                    ``globals['_conc']``.

            Returns:
                The graph padded by ``pad_graph``, with ``'_conc'`` set in its
                globals. Existing globals are copied via ``.copy()`` and updated
                rather than mutated in place (assumed dict-like -- TODO
                confirm). ``conc`` is attached after padding, so ``pad_graph``
                never sees it.

            TODO: the same graph is re-padded every time it is sampled. Doing
            the padding once in the dataset's ``read_data`` preprocessing might
            be more efficient.
            """
            padded = pad_graph(
                mol[0],
                padding_n_node=self.padding_n_node,
                padding_n_edge=self.padding_n_edge,
            )
            existing = padded.globals
            if existing is not None:
                # Copy first so the caller's globals object is left untouched.
                merged = existing.copy()
                merged.update({'_conc': conc})
            else:
                merged = {'_conc': conc}
            return padded._replace(globals=merged)
        return mol_prepro

    def make_element_preprocess(self):
        """Build the closure that preprocesses one ``(seq, mol, conc, label)`` sample."""
        def element_preprocess(seq, mol, conc, label):
            """
            Preprocess a single sample and drop ``conc`` from the output tuple.

            Returns ``(seq, mol, label)``:

            - ``seq``: converted to a dict, then passed through ``self.seq_prepro``.
            - ``mol``: passed, together with ``conc``, through ``self.mol_prepro``.
            - ``label``: converted to a dict.
            """
            prepared_seq = self.seq_prepro(dict(seq))
            prepared_mol = self.mol_prepro(mol, conc)
            return prepared_seq, prepared_mol, dict(label)
        return element_preprocess