from typing import Optional, Sequence

import flax.linen as nn
import jax.numpy as jnp

from reform.utils.typing import Obs, Action, FloatScalar
from reform.utils.networks import MLP


class ActorVectorField(nn.Module):
    """Actor vector field network for flow matching.

    Concatenates the (optionally encoded) observations, the actions and, when
    provided, the flow times along the last axis, then maps the result through
    an ``MLP`` whose layer sizes are ``(*hidden_dims, action_dim)``.

    Attributes:
        hidden_dims: Hidden layer dimensions of the MLP.
        action_dim: Action dimension (size of the MLP's final layer).
        layer_norm: Whether the MLP applies layer normalization.
        encoder: Optional encoder module applied to the observations.
    """

    hidden_dims: Sequence[int]
    action_dim: int
    layer_norm: bool = False
    encoder: Optional[nn.Module] = None

    @nn.compact
    def __call__(
        self,
        observations: Obs,
        actions: Action,
        times: Optional[FloatScalar] = None,
        is_encoded: bool = False,
    ):
        """Return the vectors at the given observations, actions, and (optional) times.

        Args:
            observations: Observations; raw, or already encoded if ``is_encoded``.
            actions: Actions; the last axis is treated as the feature axis.
            times: Optional flow times. Either already carries a trailing
                feature axis (same rank as ``actions``), or is a scalar /
                batch-shaped array of lower rank, in which case a trailing
                axis of size 1 is added and broadcast over the batch
                dimensions of ``actions``.
            is_encoded: Whether the observations are already encoded; if so,
                ``encoder`` is not applied.

        Returns:
            The output of ``MLP((*hidden_dims, action_dim))`` applied to the
            concatenated inputs.
        """
        if not is_encoded and self.encoder is not None:
            observations = self.encoder(observations)

        features = [observations, actions]
        if times is not None:
            times = jnp.asarray(times)
            if times.ndim < jnp.ndim(actions):
                # jnp.concatenate requires all inputs to share a rank, so a
                # scalar (as the `FloatScalar` annotation allows) or a
                # batch-shaped `times` gets a trailing feature axis and is
                # broadcast over the batch dimensions of `actions`.
                batch_shape = jnp.shape(actions)[:-1]
                times = jnp.broadcast_to(times[..., None], (*batch_shape, 1))
            features.append(times)
        inputs = jnp.concatenate(features, axis=-1)

        v = MLP((*self.hidden_dims, self.action_dim), activate_final=False, layer_norm=self.layer_norm)(inputs)

        return v
