from typing import Any, Callable, Sequence, Optional
import jax
from jax import numpy as jnp
import flax

from ProtLig_GPCRclassA.amino_GNN.model.essentials.BasicMPNN import BasicTruncatedNormalDynamicMessagePassing
from ProtLig_GPCRclassA.amino_GNN.model.essentials.embeddings import *
from ProtLig_GPCRclassA.amino_GNN.model.essentials.GIN import GINLayer


class GIN(flax.linen.Module):
    """Stack of ``n_layers`` identically configured ``GINLayer``s.

    Each layer returns a ``(graph, globals)`` pair; only the graph is carried
    forward to the next layer and the globals output is discarded.

    Attributes:
        mlp_num_layers: number of layers in each GIN layer's MLP.
        mlp_hidden_dim: hidden width of each GIN layer's MLP.
        node_embedding_size: node embedding size passed to every GIN layer.
        n_layers: how many GIN layers to apply in sequence.
    """
    mlp_num_layers : int
    mlp_hidden_dim : int
    node_embedding_size : int
    n_layers : int

    @flax.linen.compact
    def __call__(self, inputs, deterministic):
        # `deterministic` is accepted for interface parity with sibling modules;
        # it is not used by this stack.
        layer_config = dict(
            mlp_num_layers = self.mlp_num_layers,
            mlp_hidden_dim = self.mlp_hidden_dim,
            node_embedding_size = self.node_embedding_size,
        )
        graph = inputs
        for _layer_idx in range(self.n_layers):
            graph, _unused_globals = GINLayer(**layer_config)(graph)
        return graph


class FeedForwardNetwork(flax.linen.Module):
    """Position-wise MLP: Dense -> ReLU -> Dropout -> Dense -> Dropout.

    Attributes:
        widening_factor: the hidden layer has ``widening_factor * inputs.shape[-1]`` units.
        dropout_rate: rate used by both dropout layers.
        out_size: width of the output layer; ``None`` keeps the input width
            (``inputs.shape[-1]``).
    """
    widening_factor : int
    dropout_rate : float
    # A layer width, so an int (was mis-annotated as float); None means "same as input".
    out_size : Optional[int] = None

    @flax.linen.compact
    def __call__(self, inputs, deterministic):
        """Apply the MLP to the last axis of ``inputs``.

        Args:
            inputs: array whose last axis is the feature axis.
            deterministic: when True, both dropout layers are disabled.

        Returns:
            Array with the same leading shape as ``inputs`` and last axis of size
            ``out_size`` (or ``inputs.shape[-1]`` when ``out_size`` is None).
        """
        x = inputs
        hidden_size = x.shape[-1]
        out_size = hidden_size if self.out_size is None else self.out_size
        x = flax.linen.Dense(self.widening_factor * hidden_size, use_bias = True)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = flax.linen.Dense(out_size, use_bias = True)(x)
        x = flax.linen.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        return x


class Simple_sum_pooling_with_MLP(flax.linen.Module):
    """Sum node features per graph, then apply one Dense layer.

    Attributes:
        out_size: number of output features of the Dense layer.
    """
    out_size : int

    @flax.linen.compact
    def __call__(self, graph, deterministic):
        """Pool a batched graph to one vector per graph.

        Args:
            graph: 7-field graph tuple unpacked as ``(nodes, edges, receivers,
                senders, globals, n_node, n_edge)``; only ``nodes`` and
                ``n_node`` are used.
            deterministic: unused; kept for interface parity with sibling modules.

        Returns:
            ``Dense(out_size)`` applied to the per-graph node sums, i.e. shape
            ``(n_graph, out_size)`` with ``n_graph = n_node.shape[0]`` (assuming
            ``nodes`` is a single ``(total_nodes, features)`` array).
        """
        nodes, _edges, _receivers, _senders, _globals, n_node, _n_edge = graph
        sum_n_node = jax.tree_util.tree_leaves(nodes)[0].shape[0]
        n_graph = n_node.shape[0]
        # Map every node to the index of the graph it belongs to,
        # e.g. n_node=[1, 2] -> [0, 1, 1]. total_repeat_length gives the
        # result a static length so this works under jit.
        node_gr_idx = jnp.repeat(jnp.arange(n_graph), n_node, axis=0,
                                 total_repeat_length=sum_n_node)
        # jax.tree_util.tree_map replaces the deprecated (and since removed) jax.tree_map.
        node_attributes = jax.tree_util.tree_map(
            lambda n: jax.ops.segment_sum(n, node_gr_idx, num_segments=n_graph), nodes)
        pooled = flax.linen.Dense(self.out_size, use_bias = True)(node_attributes)
        return pooled


class MainCrossBlock(flax.linen.Module):
    """One round of information exchange between molecule nodes (X) and sequence tokens (S).

    Order of operations in ``__call__``:
      1. X attends to S (cross attention), then an FFN.
      2. GIN message passing over the molecular graph, with a residual connection.
      3. S attends to the updated X (cross attention), then an FFN.
      4. Self-attention over S, then an FFN.
    A LayerNorm is applied to the inputs of every attention / FFN sub-layer.

    Attributes:
        node_d_model: width of node features; ``mols.nodes`` is reshaped to
            ``(batch, -1, node_d_model)`` for the attention layers.
    """
    node_d_model : int

    def setup(self):
        # Stage 1: nodes (X) query the sequence (S) -- LayerNorm on both sides,
        # cross attention, then a pre-LN FFN.
        self.LN_cross_mha_XSS_S = flax.linen.LayerNorm()
        self.LN_cross_mha_XSS_X = flax.linen.LayerNorm()
        self.cross_mha_XSS      = flax.linen.MultiHeadDotProductAttention(num_heads = 4, dropout_rate = 0.1)
        # self.cross_mha_aXX      = flax.linen.MultiHeadDotProductAttention(num_heads = 4, dropout_rate = 0.1)
        self.LN_FFN_XSS         = flax.linen.LayerNorm()
        self.FFN_XSS            = FeedForwardNetwork(widening_factor = 2, dropout_rate = 0.2)
        # Stage 2: GIN over the molecular graph (2 GIN layers, each with a
        # 2-layer MLP of width node_d_model).
        self.mpnn_Q             = GIN(mlp_num_layers = 2, 
                                    mlp_hidden_dim = self.node_d_model, 
                                    node_embedding_size = self.node_d_model, 
                                    n_layers = 2)
        # Stage 3: sequence (S) queries the updated nodes (X), then a pre-LN FFN.
        self.LN_cross_mha_SXX_S = flax.linen.LayerNorm()
        self.LN_cross_mha_SXX_X = flax.linen.LayerNorm()
        self.cross_mha_SXX      = flax.linen.MultiHeadDotProductAttention(num_heads = 4, dropout_rate = 0.1)
        # self.cross_mha_aSS      = flax.linen.MultiHeadDotProductAttention(num_heads = 4, dropout_rate = 0.1)
        self.LN_FFN_SXX         = flax.linen.LayerNorm()
        self.FFN_SXX            = FeedForwardNetwork(widening_factor = 2, dropout_rate = 0.2)
        # Output (stage 4): self-attention over S, then a pre-LN FFN.
        self.LN_selfAttn            = flax.linen.LayerNorm()
        self.selfAttn               = flax.linen.SelfAttention(num_heads = 4, dropout_rate = 0.1)
        self.LN_FFN_selfAttn_out    = flax.linen.LayerNorm()
        self.FFN_selfAttn_out       = FeedForwardNetwork(widening_factor = 2, dropout_rate = 0.2)

    def __call__(self, mols, S, mols_padding_mask, seq_attn_mask, attn_mask_XS, attn_mask_XX, attn_mask_SX, attn_mask_SS,  deterministic):
        """Run one X <-> S exchange round.

        Args:
            mols: graph-like namedtuple (has ``.nodes`` and ``._replace``). Its
                nodes are reshaped to ``(batch_size, -1, node_d_model)``, so every
                sample presumably has the same (padded) node count -- TODO confirm.
            S: sequence features; ``S.shape[0]`` is taken as the batch size.
            mols_padding_mask: per-node mask, multiplied into the X cross-attention output.
            seq_attn_mask: per-token mask, multiplied into the S cross-attention output.
            attn_mask_XS: attention mask for X (queries) attending to S.
            attn_mask_XX: currently unused (only referenced by commented-out code).
            attn_mask_SX: attention mask for S (queries) attending to X.
            attn_mask_SS: attention mask for the self-attention over S.
            deterministic: forwarded to the attention layers, FFNs and the GIN.

        Returns:
            ``(mols, S)``: the graph with updated nodes and the updated sequence features.
        """
        batch_size = S.shape[0]
        # --------------------------
        # cross mha XSS:
        X = jnp.reshape(mols.nodes, (batch_size, -1, self.node_d_model))
        H_Q = X
        F_KV = S
        F_KV = self.LN_cross_mha_XSS_S(F_KV)
        H_Q  = self.LN_cross_mha_XSS_X(H_Q)
        H = self.cross_mha_XSS(inputs_q = H_Q, inputs_kv = F_KV, mask = attn_mask_XS, deterministic = deterministic)
        # H = jax.nn.relu(H)
        # H = self.cross_mha_aXX(inputs_q = H, inputs_kv = H_Q, mask = attn_mask_XX, deterministic = deterministic)
        # Zero the attention output at padded node positions.
        H = H * jnp.expand_dims(mols_padding_mask, axis = -1)
        # NOTE(review): H_Q was overwritten with its LayerNorm output above, so this
        # residual adds the *normalized* X. The S branch below adds the raw S
        # (S = F + S). If a pre-LN residual was intended, this should be H + X --
        # confirm before changing, since it alters trained-model outputs.
        X = H + H_Q # residual X_in + X_attn

        # FFN:
        H = self.LN_FFN_XSS(X)
        H = self.FFN_XSS(H, deterministic = deterministic)
        H = H + X # residual X_in + X_attn + X_ffn
        # Flatten back to one row per node before handing the graph to the GIN.
        H = jnp.reshape(H, (-1, self.node_d_model))
        mols = mols._replace(nodes = H)

        # GNN:
        mols = self.mpnn_Q(mols, deterministic = deterministic)
        # Residual around the GIN: H holds the node features from before message passing.
        mols = mols._replace(nodes = mols.nodes + H) # residual X_attn + X_mpnn
        
        # cross mha SXX:
        H_KV = jnp.reshape(mols.nodes, (batch_size, -1, self.node_d_model))
        F_Q = S
        F_Q  = self.LN_cross_mha_SXX_S(F_Q)
        H_KV = self.LN_cross_mha_SXX_X(H_KV)

        F = self.cross_mha_SXX(inputs_q = F_Q, inputs_kv = H_KV, mask = attn_mask_SX, deterministic = deterministic)
        # F = jax.nn.relu(F)
        # F = self.cross_mha_aSS(inputs_q = F, inputs_kv = F_Q, mask = attn_mask_SS, deterministic = deterministic)
        # Zero the attention output at padded sequence positions.
        F = F * jnp.expand_dims(seq_attn_mask, axis = -1)
        S = F + S # residual S_in + S_attn

        # FFN:
        F = self.LN_FFN_SXX(S)
        F = self.FFN_SXX(F, deterministic = deterministic)
        S = S + F # residual S_attn + S_ffn

        # Self-Attn:
        F = self.LN_selfAttn(S)
        F = self.selfAttn(inputs_q = F, mask = attn_mask_SS, deterministic = deterministic)
        S = S + F

        # FFN:
        F = self.LN_FFN_selfAttn_out(S)
        F = self.FFN_selfAttn_out(F, deterministic = deterministic)
        S = S + F

        return mols, S



class ASMI(flax.linen.Module):
    """Joint sequence / molecular-graph model coupled by cross attention.

    Atom (node) and bond (edge) features are embedded and refined by an initial
    message-passing layer. The sequence embedding is compressed to
    ``node_d_model`` and exchanges information with the graph through two
    ``MainCrossBlock``s. Node and sequence tokens are then concatenated, passed
    through self-attention and an FFN, and pooled with a masked softmax
    attention into one vector per sample for the main prediction head. When
    ``vocab_size > 0``, an MLM head is also applied to the sequence tokens.

    Attributes:
        out_features: output units of the main head; ``2`` is treated as binary
            and yields a single logit.
        node_d_model: width of node (and compressed sequence) representations.
        edge_d_model: width of edge representations; must equal ``edge_embedding_size``.
        seq_d_model: width of the incoming sequence embedding; together with
            ``vocab_size > 0`` this enables the learned mask token.
        vocab_size: MLM vocabulary size; ``<= 0`` disables the MLM head.
        edge_embedding_size: edge embedding size for the initial MPNN.
        atom_features: ordered names of the columns of ``mols.nodes``.
        bond_features: ordered names of the columns of ``mols.edges``.
        mpnn_message_activation: message activation of the initial MPNN.
    """
    out_features : int = 3
    node_d_model : int = 256 # 72 # 64
    edge_d_model : int = 36 # 32
    seq_d_model : int = -1
    vocab_size : int = -1
    edge_embedding_size : int = 36 # 128
    atom_features : Sequence = ('AtomicNum', 'ChiralTag', 'Hybridization', 'FormalCharge',
                                'NumImplicitHs', 'ExplicitValence', 'Mass', 'IsAromatic')
    bond_features : Sequence = ('BondType', 'Stereo', 'IsAromatic')
    mpnn_message_activation : Callable = flax.linen.tanh # flax.linen.relu
    # Internal (non-field) attribute.
    # NOTE(review): unused in this class; 10e-7 == 1e-6 -- confirm the intended value.
    _eps = 10e-7

    def setup(self):

        assert self.edge_d_model == self.edge_embedding_size

        # Atom embedding. The feature-position bookkeeping is built per instance
        # here (not kept in class-level containers shared by every ASMI), so
        # instances with different feature lists cannot interfere and no manual
        # clearing between setups is needed.
        atom_embed_funcs = {}
        atom_embed_features_pos = {}
        atom_other_features_pos = []
        for i, name in enumerate(self.atom_features):
            if name == 'AtomicNum':
                self.atomic_num_embed = AtomicNumEmbedding(self.node_d_model)
                atom_embed_funcs[name] = self.atomic_num_embed
            elif name == 'ChiralTag':
                self.chiral_tag_embed = ChiralTagEmbedding(self.node_d_model)
                atom_embed_funcs[name] = self.chiral_tag_embed
            elif name == 'Hybridization':
                self.hybridization_embed = HybridizationEmbedding(self.node_d_model)
                atom_embed_funcs[name] = self.hybridization_embed
            else:
                atom_other_features_pos.append(i)
                continue
            atom_embed_features_pos[name] = i
        self.atom_embed_funcs = atom_embed_funcs
        self.atom_embed_features_pos = atom_embed_features_pos
        self.atom_other_features_pos = atom_other_features_pos

        # Edge embedding (same per-instance bookkeeping as for atoms).
        edge_embed_funcs = {}
        edge_embed_features_pos = {}
        edge_other_features_pos = []
        for i, name in enumerate(self.bond_features):
            if name == 'BondType':
                self.bond_type_embed = BondTypeEmbedding(self.edge_d_model)
                edge_embed_funcs[name] = self.bond_type_embed
            elif name == 'Stereo':
                self.stereo_embed = StereoEmbedding(self.edge_d_model)
                edge_embed_funcs[name] = self.stereo_embed
            else:
                edge_other_features_pos.append(i)
                continue
            edge_embed_features_pos[name] = i
        self.edge_embed_funcs = edge_embed_funcs
        self.edge_embed_features_pos = edge_embed_features_pos
        self.edge_other_features_pos = edge_other_features_pos

        self.X_proj_non_embeded = flax.linen.Dense(self.node_d_model) # kernel_regularizer is missing
        self.E_proj_non_embeded = flax.linen.Dense(self.edge_d_model) # kernel_regularizer is missing

        self.X_proj = flax.linen.Dense(self.node_d_model) # kernel_regularizer is missing
        self.E_proj = flax.linen.Dense(self.edge_d_model) # kernel_regularizer is missing

        self.compress_seq_embedding = flax.linen.Dense(features = self.node_d_model, use_bias = True)

        # Learned replacement vector for masked sequence positions (MLM training).
        if self.seq_d_model > 0 and self.vocab_size > 0:
            self.mask_token_embedding = self.param('mask_token_embedding',
                                    flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0), # init
                                    (1, self.seq_d_model),
                                    jnp.float32)

        self.mpnn_Q_init = BasicTruncatedNormalDynamicMessagePassing(edge_embedding_size=self.edge_embedding_size,
                                                                    message_activation=self.mpnn_message_activation,
                                                                    a=3, b=9, mean=6.0, stddev=1.0)
        # -------------------------------------------------------
        self.main_cross_block = MainCrossBlock(node_d_model = self.node_d_model)

        # -------------------------------------------------------
        self.main_cross_block_1 = MainCrossBlock(node_d_model = self.node_d_model)

        # -------------------------------------------------------
        # Output:
        self.LN_concat_mha_out      = flax.linen.LayerNorm()
        self.concat_mha_out         = flax.linen.SelfAttention(num_heads = 4, dropout_rate = 0.1)
        self.LN_FFN_out             = flax.linen.LayerNorm()
        self.FFN_out                = FeedForwardNetwork(widening_factor = 2, dropout_rate = 0.2, out_size = 64)
        # Pooling:
        self.pool_logits = flax.linen.Dense(features = 1, use_bias = False)
        self.dropout = flax.linen.Dropout(rate = 0.5)
        # Binary tasks (out_features == 2) use a single logit.
        if self.out_features == 2:
            self.out = flax.linen.Dense(features = 1, use_bias = True)
        else:
            self.out = flax.linen.Dense(features = self.out_features, use_bias = True)

        # MLM:
        if self.vocab_size > 0:
            self.mlm_head = flax.linen.Dense(features = self.vocab_size, use_bias = True)

    def _embed_features(self, F, embed_funcs, embed_pos, other_pos, proj_non_embedded, proj, padding_mask):
        """Embed selected feature columns, project the rest, and mix them.

        Each column ``embed_pos[name]`` of ``F`` goes through ``embed_funcs[name]``;
        the columns in ``other_pos`` go through ``proj_non_embedded``. The results
        are concatenated (embeddings in sorted feature-name order, followed by the
        projected remainder), passed through ReLU and ``proj``, and rows where
        ``padding_mask`` is 0 are zeroed.
        """
        # Sorted names reproduce the key order jax.tree_util.tree_leaves uses for
        # dicts, keeping the concatenation layout (and thus `proj` weights) stable.
        embedded = [embed_funcs[name](F[:, embed_pos[name]]) for name in sorted(embed_funcs)]
        other = proj_non_embedded(F[:, list(other_pos)])
        out = jnp.concatenate(embedded + [other], axis = -1)
        out = jax.nn.relu(out)
        out = proj(out)
        # Set padding features back to 0.
        return out * jnp.reshape(padding_mask, (-1, 1))

    def __call__(self, inputs, deterministic, input_ids_mask = None):
        """Predict the main label (and optionally MLM logits) for a batch.

        Args:
            inputs: ``((S, seq_attn_mask), G)``. ``S`` is the sequence embedding
                (batch on axis 0, tokens on axis 1). ``G`` is a graph namedtuple
                whose ``globals`` dict holds ``'node_padding_mask'`` and
                ``'edge_padding_mask'``.
            deterministic: when True, dropout is disabled and ``input_ids_mask``
                is ignored.
            input_ids_mask: optional ``(batch, tokens)`` mask; during training,
                masked positions of ``S`` are replaced by the learned mask token.
                Requires ``seq_d_model > 0`` and ``vocab_size > 0``.

        Returns:
            ``{'_main_label': logits}``, plus ``'_mlm_logits'`` when ``vocab_size > 0``.

        Raises:
            ValueError: if ``input_ids_mask`` is used but no mask token was created.
        """
        (S, seq_attn_mask), mols = inputs

        # Masked sequence (training only):
        if input_ids_mask is not None and not deterministic:
            if not (self.seq_d_model > 0 and self.vocab_size > 0):
                raise ValueError('input_ids_mask was given, but the mask token embedding only exists '
                                 'when seq_d_model > 0 and vocab_size > 0.')
            S_mask_only = jnp.einsum('ij,jk->ijk', input_ids_mask, self.mask_token_embedding)
            input_ids_mask = jnp.expand_dims(input_ids_mask, axis = -1)
            S = S * jnp.logical_not(input_ids_mask) + S_mask_only

        batch_size = S.shape[0]
        n_seq_tokens = S.shape[1]

        # The graph batch holds two graphs per sample (presumably the molecule plus
        # a padding graph -- TODO confirm against the data loader).
        assert 2*batch_size == len(mols.n_node)

        mols_padding_mask = mols.globals['node_padding_mask']
        line_mols_padding_mask = mols.globals['edge_padding_mask']
        mols = mols._replace(globals = None)

        def pair_mask(row_mask, col_mask):
            # Outer product of two per-position masks, with a singleton axis
            # inserted at position 1 (the attention-head axis).
            return jnp.expand_dims(jnp.einsum('bi,bj->bij', row_mask, col_mask), axis = 1)

        # Construct attention masks:
        attn_mask_XX = pair_mask(mols_padding_mask, mols_padding_mask)
        attn_mask_XS = pair_mask(mols_padding_mask, seq_attn_mask)
        attn_mask_SX = pair_mask(seq_attn_mask, mols_padding_mask)
        attn_mask_SS = pair_mask(seq_attn_mask, seq_attn_mask)
        _attn_mask_concat = jnp.concatenate([mols_padding_mask, seq_attn_mask], axis = -1)
        attn_mask_concat = pair_mask(_attn_mask_concat, _attn_mask_concat)

        # Embeddings for atoms and bonds:
        mols = mols._replace(nodes = self._embed_features(
            mols.nodes, self.atom_embed_funcs, self.atom_embed_features_pos, self.atom_other_features_pos,
            self.X_proj_non_embeded, self.X_proj, mols_padding_mask))
        mols = mols._replace(edges = self._embed_features(
            mols.edges, self.edge_embed_funcs, self.edge_embed_features_pos, self.edge_other_features_pos,
            self.E_proj_non_embeded, self.E_proj, line_mols_padding_mask))

        # init seq:
        S = self.compress_seq_embedding(S)

        # init GNN:
        mols = self.mpnn_Q_init(mols, deterministic = deterministic)

        # Two rounds of graph <-> sequence exchange:
        for cross_block in (self.main_cross_block, self.main_cross_block_1):
            mols, S = cross_block(mols, S, mols_padding_mask, seq_attn_mask, attn_mask_XS, attn_mask_XX,
                                  attn_mask_SX, attn_mask_SS, deterministic)

        # Concatenate node tokens and sequence tokens along the token axis:
        H = jnp.reshape(mols.nodes, (batch_size, -1, self.node_d_model))
        C = jnp.concatenate([H, S], axis = 1)
        token_mask = jnp.expand_dims(_attn_mask_concat, axis = -1)

        # self-attention (padding positions of the update are zeroed):
        C = self.LN_concat_mha_out(C)
        _C = self.concat_mha_out(inputs_q = C, mask = attn_mask_concat, deterministic = deterministic)
        C = C + _C * token_mask

        # FFN_out:
        C = self.LN_FFN_out(C)
        C = self.FFN_out(C, deterministic = deterministic)

        # Masked softmax attention pooling: padded tokens get the most negative
        # finite logit (~0 weight after softmax) and are then zeroed exactly.
        a = self.pool_logits(C)
        a = jnp.where(token_mask, a, jnp.finfo(C.dtype).min)
        a = jax.nn.softmax(a, axis = 1).astype(C.dtype)
        a = a * token_mask
        x = jnp.sum(a * C, axis = 1)

        # out
        x = jax.nn.relu(x)
        x = self.dropout(x, deterministic = deterministic)
        outputs = {'_main_label' : self.out(x)}

        if self.vocab_size > 0:
            # The sequence tokens are the last n_seq_tokens positions of C.
            outputs['_mlm_logits'] = self.mlm_head(C[:, -n_seq_tokens:, :])
        return outputs