import e3x
import flax.linen as nn
import jax
import jax.numpy as jnp
import jraph

from typing import Callable
from .base import FeatureRepresentations
from .utils import get_e3x_activation_fn
from .utils import MLP
from .utils import promote_to_e3x
from .utils import modulate_adaLN
from ..jraph_utils import get_batch_segments, get_number_of_graphs, get_number_of_nodes
from hfm.backbones import embedding
from hfm.backbones.utils import make_to_jraph_batch_fn
from .base import BaseEdgeEmbedding, FeatureRepresentations
from .utils import get_activation_fn
from .utils import MLP
from .utils import promote_to_e3x
from .utils import modulate_adaLN
from ..jraph_utils import get_batch_segments, get_number_of_graphs, get_number_of_nodes

from .base import BaseLayer
from .base import FeatureRepresentations
from .utils import get_activation_fn
from .utils import MLP
from .utils import promote_to_e3x
from .utils import get_max_degree_from_tensor_e3x
from .utils import broadcast_equivariant_multiplication
from .utils import E3MLP
from .utils import EquivariantLayerNorm
from .utils import modulate_adaLN
from .utils import modulate_E3adaLN
from ..jraph_utils import get_number_of_nodes

import flax.linen as nn
import jraph

from jaxtyping import Array
from typing import Optional, Sequence

from .base import BaseReadout
from .base import GenerativeLayer
from .base import BaseNodeEmbedding
from .base import BaseTimeEmbedding
from .base import FeatureRepresentations
from .dit import DiTModel
from hfm.backbones import base


class NequIPInteractionBlock(nn.Module):
    """Equivariant interaction block (named after NequIP) built from e3x layers.

    Node features are updated by an e3x ``MessagePass`` whose per-edge basis
    comes from a small MLP on the edge features. A dense self-interaction, a
    residual connection and a nonlinearity follow. Edge features are returned
    unchanged.

    Note: this is an ``nn.compact`` module, so submodules are auto-named in the
    order they are constructed. Reordering (or un-nesting) the layer
    constructions below changes parameter names and hence checkpoint
    compatibility.
    """

    # Maximum degree of the message-passing output and of the residual branch.
    max_degree: int
    # Whether the message-passing output and the residual branch keep
    # pseudotensor (odd-parity) components.
    include_pseudotensors: bool
    # Name of the activation applied after the residual sum; resolved via
    # `get_e3x_activation_fn`.
    activation_fn: str = 'silu'

    @nn.compact
    def __call__(
            self,
            graph: jraph.GraphsTuple,
            features: FeatureRepresentations,
            **kwargs
    ) -> FeatureRepresentations:
        """Apply one interaction step.

        Args:
            graph: jraph.GraphsTuple; only ``senders`` and ``receivers`` are read.
            features: FeatureRepresentations with e3x-layout ``nodes`` (see the
                shape comments below) and ``edges``. The edges are fed through
                an edge MLP to build the message-passing basis.
            **kwargs: Ignored.

        Returns:
            FeatureRepresentations with updated ``nodes`` and the input ``edges``
            passed through unchanged.
        """

        features_node = features.nodes  # (num_node, 1 or 2, (max_degree_in + 1)**2, num_features)
        features_edges = features.edges  # (num_pairs, 1 or 2, (max_degree_in + 1)**2, num_edge_features)

        num_features = features_node.shape[-1]
        num_nodes = len(features_node)
        activation_fn = get_e3x_activation_fn(self.activation_fn)

        # NOTE(review): e3x MessagePass gathers features from `src_idx` and
        # aggregates onto `dst_idx`, so here messages flow from receivers to
        # senders (the reverse of jraph's naming). Confirm this matches the
        # displacement convention used by the edge embedding.
        src_idx = graph.receivers
        dst_idx = graph.senders

        # Pre convolution self-interaction.
        z = e3x.nn.Dense(num_features)(features_node)

        # The final Dense layer is applied in the MessagePass.
        # NOTE: this edge MLP always uses silu, independent of `self.activation_fn`.
        basis = e3x.nn.silu(
            e3x.nn.Dense(num_features)(
                e3x.nn.silu(e3x.nn.Dense(num_features)(features_edges))
            )
        )

        # Aggregate messages per node; `num_segments` keeps the output aligned
        # with the node axis.
        y = e3x.nn.MessagePass(
            max_degree=self.max_degree,
            include_pseudotensors=self.include_pseudotensors,
        )(
            z,
            basis,
            dst_idx=dst_idx,
            src_idx=src_idx,
            num_segments=num_nodes
        )

        # Post convolution self-interaction.
        y = e3x.nn.Dense(num_features)(y)

        # Make sure that residual connection cannot increase degrees or add
        # pseudotensors.
        features_node = e3x.nn.change_max_degree_or_type(
            features_node,
            max_degree=self.max_degree,
            include_pseudotensors=self.include_pseudotensors,
        )

        # Linear projection of the residual branch before the skip sum.
        features_node = e3x.nn.Dense(num_features)(features_node)

        # Skip connection around mp block.
        y = e3x.nn.add(features_node, y)
        y = activation_fn(y)

        return FeatureRepresentations(
            nodes=y,  # Updates node features
            edges=features.edges  # Basis expansion will be re-used in the next layer.
        )


class EquivariantReadout(BaseReadout):
    """Equivariant readout producing two per-node 3-vectors and a per-graph energy.

    Steps:
    1. Truncate the node features to degree <= 1 without pseudotensors.
    2. Normalise them with an ``EquivariantLayerNorm``.
    3. Modulate them adaLN-style with a scale/shift derived from the time
       features.
    4. Two ``E3MLP`` heads read per-node 3-vectors from the degree-1
       components.
    5. A scalar ``MLP`` head on the degree-0 components gives per-node energy
       contributions, which are summed per graph.

    Note: this is an ``nn.compact`` module, so submodules are auto-named in
    construction order. Keep that order to stay checkpoint-compatible.
    """

    # Name of the activation used in the conditioning path and all MLP heads.
    activation_fn: str

    @nn.compact
    def __call__(
            self,
            graph: jraph.GraphsTuple,
            features: FeatureRepresentations,
            features_time: Array,
            *args,
            **kwargs
    ):
        """Compute the readout.

        Args:
            graph: Batched graph. Used only to obtain the number of graphs and
                the per-node graph segments for the energy sum.
            features: ``features.nodes`` must be a 4-D array in e3x layout
                (asserted), with features on the last axis. ``features.edges``
                is not used.
            features_time: Time-conditioning features. Must have ``num_nodes``
                leading rows, since the derived scale/shift are reshaped to
                ``(num_nodes, ...)``. Presumably ``(num_nodes, 1, 1, F)``; TODO
                confirm against the caller.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            ``(mean_v, mean_f, energy)`` with shapes ``(num_nodes, 3)``,
            ``(num_nodes, 3)`` and ``(num_graphs, 1)``.
        """

        features_nodes = features.nodes
        assert features_nodes.ndim == 4, 'Features are assumed to be in the e3x convention.'

        num_features = features_nodes.shape[-1]
        num_nodes = len(features_nodes)

        # Conditioning signal for the adaLN modulation.
        c = nn.LayerNorm()(features_time)

        # Keep only degree-0 and degree-1 (non-pseudo) components:
        # (num_nodes, 1, 4, num_features).
        y = e3x.nn.change_max_degree_or_type(features_nodes, max_degree=1, include_pseudotensors=False)

        act_fn = get_activation_fn(self.activation_fn)

        # One zero-initialised Dense yields 3*F channels. They are split into
        # 2*F scale entries and F shift entries. `jnp.split` accepts a plain
        # sequence of split indices.
        scale, shift = jnp.split(
            nn.Dense(
                features=3 * num_features,
                kernel_init=jax.nn.initializers.zeros
            )(
                act_fn(c)
            ),
            indices_or_sections=[2 * num_features],
            axis=-1
        )

        # Two scale rows per node, presumably one per degree (0 and 1); confirm
        # against `modulate_E3adaLN`.
        scale = scale.reshape(num_nodes, 1, 2, num_features)
        shift = shift.reshape(num_nodes, 1, 1, num_features)

        y = modulate_E3adaLN(
            x=EquivariantLayerNorm(use_scale=False, use_bias=False)(y),
            scale=scale,
            shift=shift
        )  # (num_nodes, 1, 4, num_features)

        # Vector heads: output channel 0 of the three degree-1 components.
        mean_v = E3MLP(
            num_layers=2,
            activation_fn=self.activation_fn,
            num_features=(num_features, 1)
        )(
            y
        )[:, 0, 1:, 0]  # (num_nodes, 3)

        mean_f = E3MLP(
            num_layers=2,
            activation_fn=self.activation_fn,
            num_features=(num_features, 1)
        )(
            y
        )[:, 0, 1:, 0]  # (num_nodes, 3)

        # Scalar head on the degree-0 components only.
        y = e3x.nn.change_max_degree_or_type(y, max_degree=0, include_pseudotensors=False)  # (num_nodes, 1, 1, num_features)
        y = jnp.squeeze(y, axis=(1, 2))  # (num_nodes, num_features)

        energy = MLP(
            num_layers=2,
            activation_fn=self.activation_fn,
            num_features=(num_features, 1),
            use_bias=True
        )(
            y
        )  # (num_nodes, 1)

        # Sum per-node energy contributions over the nodes of each graph.
        num_graphs = get_number_of_graphs(graph)
        batch_segments = get_batch_segments(graph)
        energy = jax.ops.segment_sum(energy, batch_segments, num_graphs)  # (num_graphs, 1)

        return mean_v, mean_f, energy
    

def make_nequip(
        num_layers: int,
        num_features: int,
        max_degree: int,
        cutoff: float,
        batch_size: int,
        num_atoms: int,
        include_pseudotensors: bool = True,
        rpe_radial_basis: str = 'basic_fourier',
        rpe_num_basis: int = 10,
        activation_fn: str = 'silu',
        scale_pos: float = 1.0,
        scale_mom: float = 1.0,
        name: str = "nequip",
        embed_positions_absolute: bool = False,
        embed_velocities_absolute: bool = True,
        embed_velocities_relative: bool = False,
        embed_masses: bool = False,
        velocities_embed_encode_magnitude: bool = False,
        velocities_embed_num_basis: int = 8,
        velocities_embed_max_frequency: float = 4*jnp.pi,
        force_as_grad: bool = False,
        **kwargs
):
    """Build a ``DiTModel`` whose layers are NequIP interaction blocks.

    Args:
        num_layers: Number of ``NequIPInteractionBlock`` encoders, each wrapped
            in a ``base.GenerativeLayer``.
        num_features: Feature width of the time and node embeddings.
        max_degree: Maximum degree for the edge embedding and every
            interaction block.
        cutoff: Cutoff passed to the edge embedding.
        batch_size: Passed as ``num_graph`` to ``make_to_jraph_batch_fn``.
        num_atoms: Passed as ``num_node`` to ``make_to_jraph_batch_fn``.
        include_pseudotensors: Forwarded to every interaction block.
        rpe_radial_basis: Radial basis name for the edge embedding.
        rpe_num_basis: Number of radial basis functions for the edge embedding.
        activation_fn: Activation name used by the embeddings, blocks and
            readout.
        scale_pos: Forwarded to ``DiTModel``.
        scale_mom: Forwarded to ``DiTModel``.
        name: Name of the returned model.
        embed_positions_absolute: Not implemented; must be ``False``.
        embed_velocities_absolute: Forwarded to the node embedding as
            ``velocities_embedding_bool``.
        embed_velocities_relative: Not implemented; must be ``False``.
        embed_masses: Forwarded to the node embedding as
            ``mass_embedding_bool``.
        velocities_embed_encode_magnitude: Forwarded to the node embedding's
            ``velocities_*`` options.
        velocities_embed_num_basis: Forwarded to the node embedding's
            ``velocities_*`` options.
        velocities_embed_max_frequency: Forwarded to the node embedding's
            ``velocities_*`` options.
        force_as_grad: Forwarded to ``DiTModel``.
        **kwargs: Only checked for the removed ``embed_positions_relative``
            option; otherwise ignored.

    Returns:
        The assembled ``DiTModel``.

    Raises:
        DeprecationWarning: If ``embed_positions_relative`` is passed.
        NotImplementedError: If ``embed_velocities_relative`` or
            ``embed_positions_absolute`` is ``True``.
    """
    if "embed_positions_relative" in kwargs:
        raise DeprecationWarning(
            '`embed_positions_relative` has been deprecated, since this should always be done. At least for the moment.'
        )

    if embed_velocities_relative:
        raise NotImplementedError('`embed_velocities_relative` is not implemented for NequIP yet.')

    if embed_positions_absolute:
        raise NotImplementedError('`embed_positions_absolute` is not implemented for NequIP yet.')

    # Time embedding
    time_embedding = embedding.TimeEmbedding(
        num_features=num_features,
        activation_fn=activation_fn
    )

    # Node embedding. Absolute positional embedding is not wired up; the
    # `embed_positions_absolute` check above rejects it.
    node_embedding = embedding.SO3NodeEmbed(
        num_features=num_features,
        activation_fn=activation_fn,
        mass_embedding_bool=embed_masses,
        velocities_embedding_bool=embed_velocities_absolute,
        velocities_encode_magnitude_bool=velocities_embed_encode_magnitude,
        velocities_num_basis=velocities_embed_num_basis,
        velocities_max_frequency=velocities_embed_max_frequency
    )

    # Edge embedding. Relative-velocity edge features are not wired up; the
    # `embed_velocities_relative` check above rejects them.
    edge_embedding = embedding.SO3EdgeEmbedding(
        max_degree=max_degree,
        activation_fn=activation_fn,
        cutoff=cutoff,
        radial_basis=rpe_radial_basis,
        num_basis=rpe_num_basis,
    )

    # Readout block
    readout_block = EquivariantReadout(
        activation_fn=activation_fn
    )

    layers = [
        base.GenerativeLayer(
            encoder=NequIPInteractionBlock(
                max_degree=max_degree,
                include_pseudotensors=include_pseudotensors,
                activation_fn=activation_fn,
            )
        )
        for _ in range(num_layers)
    ]

    return DiTModel(
        time_embedding=time_embedding,
        node_embedding=node_embedding,
        edge_embedding=edge_embedding,
        layers=layers,
        readout=readout_block,
        name=name,
        batch_fn=make_to_jraph_batch_fn(num_graph=batch_size, num_node=num_atoms),
        scale_pos=scale_pos,
        scale_mom=scale_mom,
        force_as_grad=force_as_grad
    )
