"""
Module that implements the InverseJKOnet model based on the base interface.

The models are implemented using JAX and the FLAX library, following a functional paradigm to
support efficient differentiation and optimization. The core classes include:

- ``InverseJKOnetPotential``: A variant focusing solely on the potential energy term.
"""

from typing import Any, Callable

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core import FrozenDict
from flax.training import train_state
from jax.numpy import ndarray
from ott.neural.networks.icnn import ICNN as ottICNN
from ott.neural.networks.potentials import MLP as ottMLP
from ott.neural.networks.potentials import PotentialMLP

from dataset import PopulationDataset
from models.base import LearningDiffusionModel
from networks.energies import MLP
from networks.icnns import ICNN
from networks.optim import (
    create_train_state,
    create_train_state_from_params,
    get_optimizer,
    penalize_weights_icnn,
)
from utils.mutinfo.estimators.knn import WKL

# Training state of the whole model: one ``TrainState`` per network, ordered as
# (potential, internal, interaction, otmap) -- the order in which ``create_state``
# returns them and in which every method of this module unpacks them.
type InverseJKOnetState = tuple[
    train_state.TrainState, train_state.TrainState, train_state.TrainState, train_state.TrainState
]


class InverseJKOnet(LearningDiffusionModel):
    """
    The full InverseJKOnet model for learning all energy terms.
    """

    def __init__(self, config: dict, data_dim: int, tau: float, dataset) -> None:
        """
        Build the energy networks and the OT map from a configuration.

        Parameters
        ----------
        config : dict
            Configuration dictionary containing model and optimizer settings.
        data_dim : int
            Dimension of the input data.
        tau : float
            Time scale of the diffusion process described by the Fokker-Planck
            equation; it weights the transport cost in the OT-map loss.
        dataset
            Accepted for interface compatibility; this constructor does not read it.
        """
        super().__init__()

        # Plain attributes first: building the OT map reads ``data_dim``.
        self.data_dim = data_dim
        self.tau = tau

        # The configuration has to be loaded before any network is constructed,
        # because both initializers read attributes set by ``_load_config``.
        self._load_config(config)
        self._initialize_energy_model()
        self._initialize_otmap()

    def _load_config(self, config: dict) -> None:
        """
        Extract the settings used by this model from ``config``.

        Sets ``batch_size``, ``layers``, ``config_optimizer``, ``otmap_config`` and
        ``otmap_optimizer`` on the instance.
        """
        # The configured batch size is divided by ``T - 1``
        # (presumably the number of consecutive-snapshot pairs -- TODO confirm).
        total_batch = config["train"]["batch_size"]
        n_transitions = config["T"] - 1
        self.batch_size = int(total_batch / n_transitions)

        energy_cfg = config["energy"]
        self.layers = energy_cfg["model"]["layers"]
        self.config_optimizer = energy_cfg["optim"]

        otmap_cfg = config["otmap"]
        self.otmap_config = otmap_cfg
        self.otmap_optimizer = get_optimizer(otmap_cfg["optim"])

    def _initialize_energy_model(self) -> None:
        """
        Create the three energy networks.

        The potential and interaction networks share the configured layer sizes;
        the internal network is an ``MLP`` with the single layer size ``[1]``.
        """
        shared_layers = self.layers
        self.model_potential = MLP(shared_layers)
        self.model_internal = MLP([1])
        self.model_interaction = MLP(shared_layers)

    def _initialize_otmap(self) -> None:
        """Build the OT map network and store it together with its pushforward factory."""
        otmap_type = self.otmap_config["model"]["type"]
        otmap_model, pushforward_factory = self._prepare_otmap(otmap_type)
        self.model_otmap = otmap_model
        self._get_pushforward = pushforward_factory

    def _prepare_otmap(
        self, model_type: str
    ) -> tuple[
        ICNN | ottICNN | PotentialMLP | ottMLP,
        Callable[[dict, int], Callable[[jnp.ndarray], jnp.ndarray]],
    ]:
        """
        Build the OT map network and the factory of its pushforward function.

        Parameters
        ----------
        model_type : str
            One of ``"ottICNN"``, ``"ICNN"``, ``"PotentialMLP"`` or ``"MLP"``. Its settings
            are read from ``self.otmap_config["models"][model_type]``.

        Returns
        -------
        tuple
            ``(model, pushforward)`` where ``pushforward(params, t)`` returns a function
            mapping a single point ``x`` to its predicted position at the next step.

        Raises
        ------
        ValueError
            If ``model_type`` is not supported or has no (non-empty) configuration.
        """

        def _pushforward_fn(
            otmap_params: dict, t: int, use_grad: bool = False, full_output: bool = False
        ) -> Callable[[jnp.ndarray], jnp.ndarray]:
            """Returns a function that computes the predicted next-step rho."""
            # ``self.model_otmap`` is looked up lazily: it is assigned only after
            # ``_prepare_otmap`` has returned.
            grad_phi = jax.grad(self.model_otmap.apply, argnums=1) if use_grad else None

            def f(x):
                # The network input is the point with the time index appended.
                x_t = jnp.concatenate((x, jnp.array([t])))
                model_output = (
                    # Gradient w.r.t. the full input (point and time); the time entry
                    # is dropped below rather than here.
                    grad_phi({"params": otmap_params}, x_t)
                    if use_grad
                    else self.model_otmap.apply({"params": otmap_params}, x_t)
                )
                if full_output:
                    assert x.shape[0] == model_output.shape[0]
                else:
                    assert x.shape[0] == model_output.shape[0] - 1
                return self.otmap_config["model"]["cvx_reg"] * x + (model_output if full_output else model_output[:-1])

            return f

        # (use_grad, full_output) for every supported architecture.
        pushforward_flags = {
            "ottICNN": (True, False),
            "ICNN": (True, False),
            "PotentialMLP": (False, False),
            "MLP": (False, True),
        }

        model_config = self.otmap_config["models"].get(model_type)
        if not model_config or model_type not in pushforward_flags:
            raise ValueError(f"Unknown model_type: {model_type}!")

        # Work on a copy: resolving names to callables and extending ``dim_hidden``
        # in place would corrupt the shared configuration, so that building a second
        # model from it fails (``getattr`` on a callable) or grows the network.
        model_config = dict(model_config)

        act_fn_name = model_config.get("act_fn")
        if act_fn_name and isinstance(act_fn_name, str):
            model_config["act_fn"] = getattr(nn, act_fn_name)
        init_fn_name = model_config.get("init_fn")
        if init_fn_name and isinstance(init_fn_name, str):
            model_config["init_fn"] = getattr(nn.initializers, init_fn_name)

        if model_type == "ottICNN":
            model = ottICNN(dim_data=self.data_dim, **model_config)
        elif model_type == "ICNN":
            model = ICNN(**model_config)
        elif model_type == "PotentialMLP":
            model = PotentialMLP(**model_config)
        else:  # "MLP": the last layer outputs a vector of the data dimension
            model_config["dim_hidden"] = [*model_config["dim_hidden"], self.data_dim]
            model = ottMLP(**model_config)

        use_grad, full_output = pushforward_flags[model_type]

        def pushforward(params: dict, t: int) -> Callable[[jnp.ndarray], jnp.ndarray]:
            return _pushforward_fn(params, t, use_grad=use_grad, full_output=full_output)

        return model, pushforward

    def load_dataset(self, dataset_name: str) -> PopulationDataset:
        """
        Load the training dataset and estimate the entropy of every snapshot.

        Besides returning the dataset, this stores the per-snapshot entropy
        estimates in ``self.entropies`` and prints them.

        Parameters
        ----------
        dataset_name : str
            The name of the dataset to load.

        Returns
        -------
        PopulationDataset
            The dataset built from ``train_data.npy`` / ``train_sample_labels.npy``.
        """
        # TODO: consider moving entropy estimation into the dataset itself.
        dataset = PopulationDataset(
            dataset_name, self.batch_size, data_file="train_data.npy", labels_file="train_sample_labels.npy"
        )
        estimator = WKL(k_neighbors=5)

        # WARNING: trajectory keys are assumed to be integers; they are used as list
        # positions, so a missing key leaves a ``None`` slot behind.
        n_slots = max(dataset.trajectory.keys()) + 1
        per_snapshot = [None for _ in range(n_slots)]
        for snapshot_label, snapshot in dataset.trajectory.items():
            per_snapshot[snapshot_label] = estimator.entropy(snapshot)

        self.entropies = jnp.array(per_snapshot)
        print(f"Entropies: {self.entropies}")
        return dataset

    def create_state(self, rng: jax.random.PRNGKey) -> InverseJKOnetState:
        """
        Create initial training states for the potential, internal, interaction and OT map models.

        The key is split so that every network is initialised from its own random
        stream; reusing one key would give identically shaped networks identical
        initial weights.

        Parameters
        ----------
        rng : jax.random.PRNGKey
            Random key for initialization.

        Returns
        -------
        InverseJKOnetState
            Tuple containing the training states for the potential, internal, interaction and OT map models.
        """
        rng_potential, rng_internal, rng_interaction, rng_otmap = jax.random.split(rng, 4)

        potential = create_train_state(
            rng_potential, self.model_potential, get_optimizer(self.config_optimizer), self.data_dim
        )
        # The internal network is evaluated on a single constant input (see ``_compute_beta``).
        internal = create_train_state(rng_internal, self.model_internal, get_optimizer(self.config_optimizer), 1)
        interaction = create_train_state(
            rng_interaction, self.model_interaction, get_optimizer(self.config_optimizer), self.data_dim
        )

        # The OT map receives the point concatenated with the time index.
        otmap = create_train_state(rng_otmap, self.model_otmap, self.otmap_optimizer, self.data_dim + 1)

        return potential, internal, interaction, otmap

    def create_state_from_params(
        self, potential_params: dict, internal_params: dict, interaction_params: dict, otmap_params: dict
    ) -> InverseJKOnetState:
        """
        Wrap already existing parameters into training states.

        Parameters
        ----------
        potential_params : dict
            Parameters for the potential model.
        internal_params : dict
            Parameters for the internal model.
        interaction_params : dict
            Parameters for the interaction model.
        otmap_params : dict
            Parameters for the otmap model.

        Returns
        -------
        InverseJKOnetState
            Training states for the potential, internal, interaction and otmap models, in that order.
        """
        energy_networks = (
            (self.model_potential, potential_params),
            (self.model_internal, internal_params),
            (self.model_interaction, interaction_params),
        )
        # Every energy network gets its own freshly built optimizer.
        energy_states = tuple(
            create_train_state_from_params(network, params, get_optimizer(self.config_optimizer))
            for network, params in energy_networks
        )

        otmap_state = create_train_state_from_params(self.model_otmap, otmap_params, self.otmap_optimizer)

        return (*energy_states, otmap_state)

    def get_params(
        self,
        state: InverseJKOnetState,
    ) -> tuple[FrozenDict[str, Any], FrozenDict[str, Any], FrozenDict[str, Any]]:
        """
        Extract the parameters of the three energy networks.

        Parameters
        ----------
        state : InverseJKOnetState
            Training state containing potential, internal, interaction and otmap models.

        Returns
        -------
        tuple[dict, dict, dict]
            Parameters of the potential, internal and interaction models; the OT map
            parameters are not part of the result.
        """
        potential, internal, interaction, _otmap = state
        return tuple(member.params for member in (potential, internal, interaction))

    def _compute_potential(self, rho: jnp.ndarray, potential_params: dict) -> jnp.ndarray:
        """
        Monte-Carlo estimate of the potential energy over the batch ``rho``.

        The potential network is evaluated on every row of ``rho`` and the values
        are averaged.
        """

        def potential_at(x):
            return self.model_potential.apply({"params": potential_params}, x)

        per_sample = jax.vmap(potential_at)(rho)
        return jnp.mean(per_sample)

    def _compute_interaction(self, rho: jnp.ndarray, interaction_params: dict) -> jnp.ndarray:
        """
        Monte-Carlo estimate of the interaction energy over the batch ``rho``.

        The interaction network is evaluated on the difference of every unordered
        pair of distinct samples, symmetrised as ``0.5 * (W(d) + W(-d))``, and averaged.
        """
        n_samples = rho.shape[0]
        all_diffs = rho[:, None, :] - rho[None, :, :]  # shape [N, N, D]

        # TODO: not tested: taking off-diagonal elements
        rows, cols = jnp.tril_indices(n_samples, k=-1)
        pair_diffs = all_diffs[rows, cols]  # shape (N * (N-1) / 2, D)

        def interaction_at(d):
            return self.model_interaction.apply({"params": interaction_params}, d)

        batched_interaction = jax.vmap(interaction_at)
        symmetrised = 0.5 * (batched_interaction(pair_diffs) + batched_interaction(-pair_diffs))

        # One scalar per unordered pair is expected.
        assert symmetrised.shape[0] == n_samples * (n_samples - 1) // 2
        assert symmetrised.ndim == 1

        return jnp.mean(symmetrised)

    def _compute_beta(self, internal_params: dict) -> jnp.ndarray:
        """
        Return the trainable internal-energy coefficient beta.

        It is the absolute value of the internal network evaluated on the constant
        input ``[1]``, which keeps beta non-negative.
        """
        constant_input = jnp.asarray([1])
        raw_beta = self.model_internal.apply({"params": internal_params}, constant_input)
        return jnp.abs(raw_beta)

    def _compute_energy(self, rho: jnp.ndarray, potential_params: dict, interaction_params: dict) -> jnp.ndarray:
        """Return the sum of potential and interaction energy of ``rho`` (no entropy term)."""
        potential_energy = self._compute_potential(rho, potential_params)
        interaction_energy = self._compute_interaction(rho, interaction_params)
        return potential_energy + interaction_energy

    def _compute_entropy_contribution_for_energy_update(
        self, rho: jnp.ndarray, internal_params: dict, otmap_params: dict, t: int
    ) -> jnp.ndarray:
        """
        Entropy term of the energy loss for the transition ``t -> t + 1``.

        Computes ``beta * (-H[t + 1] + H[t] + log_det)`` where ``H`` are the
        precomputed entropies in ``self.entropies`` and ``log_det`` is the mean
        log-determinant of the pushforward Jacobian over ``rho``.
        """
        beta = self._compute_beta(internal_params)
        mean_log_det = self._compute_log_det_pushforward(rho, otmap_params, t)
        entropy_gap = -self.entropies[t + 1] + self.entropies[t]

        return beta * (entropy_gap + mean_log_det)

    def _compute_entropy_contribution_for_otmap_update(
        self, rho: jnp.ndarray, internal_params: dict, otmap_params: dict, t: int
    ) -> jnp.ndarray:
        """
        Entropy term of the OT-map loss at timestep ``t``.

        Computes ``beta * (-H[t] - log_det)`` where ``H[t]`` is the precomputed
        entropy in ``self.entropies`` and ``log_det`` is the mean log-determinant
        of the pushforward Jacobian over ``rho``.
        """
        beta = self._compute_beta(internal_params)
        mean_log_det = self._compute_log_det_pushforward(rho, otmap_params, t)
        negative_entropy = -self.entropies[t]

        return beta * (negative_entropy - mean_log_det)

    def _compute_log_det_pushforward(self, rho: jnp.ndarray, otmap_params: dict, t: int) -> jnp.ndarray:
        """
        Monte-Carlo estimate of the log-determinant of the pushforward Jacobian.

        The Jacobian of the pushforward is computed with forward-mode
        differentiation at every row of ``rho``; the log of the absolute
        determinant is averaged over the batch (the sign is discarded).
        """
        transport = self._get_pushforward(otmap_params, t)
        per_sample_jacobian = jax.vmap(jax.jacfwd(transport))
        jacobians = per_sample_jacobian(rho)

        _signs, log_abs_dets = jax.vmap(jnp.linalg.slogdet)(jacobians)

        return jnp.mean(log_abs_dets)

    def loss_fn_otmap(
        self,
        potential_params: dict,
        internal_params: dict,
        interaction_params: dict,
        otmap_params: dict,
        batch: jnp.ndarray,
        t: int,
    ) -> jnp.ndarray:
        """
        Loss minimised by the OT map for the transition ``t -> t + 1``.

        The loss is the energy of the pushed-forward particles plus the entropy
        term, minus the energy of the observed next snapshot, plus the squared
        displacement weighted by ``1 / (2 * tau)``. For an ``ICNN`` OT map without
        positive-weight constraint a weight penalty is added.

        Parameters
        ----------
        potential_params : dict
            Parameters for the potential model.
        internal_params : dict
            Parameters for the internal model.
        interaction_params : dict
            Parameters for the interaction model.
        otmap_params : dict
            Parameters for the otmap model.
        batch : jnp.ndarray
            Particle trajectories indexed by timestep along the first axis
            (``batch[t]`` is the snapshot at timestep ``t``).
        t : int
            Index of the source snapshot; ``batch[t + 1]`` is the target.

        Returns
        -------
        jnp.ndarray
            Total loss value.
        """
        source = batch[t]
        observed_next = batch[t + 1]
        transport = self._get_pushforward(otmap_params, t)
        predicted_next = jax.vmap(transport)(source)

        energy_predicted = self._compute_energy(predicted_next, potential_params, interaction_params)
        entropy_term = self._compute_entropy_contribution_for_otmap_update(
            source, internal_params, otmap_params, t
        )
        energy_observed = self._compute_energy(observed_next, potential_params, interaction_params)

        # Mean squared displacement of the particles under the map.
        transport_cost = jnp.mean(jnp.sum((predicted_next - source) ** 2, axis=1))
        loss = (energy_predicted + entropy_term) - energy_observed + (1 / (2 * self.tau)) * transport_cost

        otmap_cfg = self.otmap_config
        unconstrained_icnn = otmap_cfg["model"]["type"] == "ICNN" and not otmap_cfg["models"]["ICNN"]["pos_weights"]
        if unconstrained_icnn:
            weight_penalty = penalize_weights_icnn(otmap_params)
            # TODO: scaling the penalty by ``otmap_config["optim"]["beta"]`` was reported not to work.
            loss += otmap_cfg["model"]["icnn_weights_reg"] * weight_penalty

        return loss

    def loss_fn_energy(
        self,
        potential_params: dict,
        internal_params: dict,
        interaction_params: dict,
        otmap_params: dict,
        batch: jnp.ndarray,
        t: int,
    ) -> jnp.ndarray:
        """
        Loss minimised by the energy networks for the transition ``t -> t + 1``.

        It is the energy of the observed next snapshot minus the energy of the
        snapshot predicted by the OT map, plus the entropy term.

        Parameters
        ----------
        potential_params : dict
            Parameters for the potential model.
        internal_params : dict
            Parameters for the internal model.
        interaction_params : dict
            Parameters for the interaction model.
        otmap_params : dict
            Parameters for the otmap model.
        batch : jnp.ndarray
            Particle trajectories indexed by timestep along the first axis
            (``batch[t]`` is the snapshot at timestep ``t``).
        t : int
            Index of the source snapshot; ``batch[t + 1]`` is the target.

        Returns
        -------
        jnp.ndarray
            Total loss value.
        """
        source = batch[t]
        observed_next = batch[t + 1]
        transport = self._get_pushforward(otmap_params, t)
        predicted_next = jax.vmap(transport)(source)

        energy_of_prediction = self._compute_energy(predicted_next, potential_params, interaction_params)
        energy_of_observation = self._compute_energy(observed_next, potential_params, interaction_params)
        entropy_term = self._compute_entropy_contribution_for_energy_update(
            source, internal_params, otmap_params, t
        )

        energy_gap = energy_of_observation - energy_of_prediction
        return energy_gap + entropy_term

    def train_step(
        self, state: InverseJKOnetState, sample: list[jnp.ndarray]
    ) -> tuple[InverseJKOnetState, dict[str, jnp.ndarray]]:
        """
        Perform a single training step.

        The step has two phases. First the OT map is updated
        ``otmap_config["optim"]["inner_iter"]`` times, each update using the OT-map
        gradients accumulated over all consecutive timestep pairs. Then the three
        energy networks receive one update from the energy gradients accumulated
        over the same pairs, evaluated with the updated OT map.

        Parameters
        ----------
        state : InverseJKOnetState
            Training state containing potential, internal, interaction and otmap models.
        sample : list[jnp.ndarray]
            One array per timestep; the arrays are stacked along a new first axis,
            so they must share the same shape.

        Returns
        -------
        tuple[InverseJKOnetState, dict[str, jnp.ndarray]]
            The updated training state and the metrics ``loss_energy`` and
            ``loss_otmap`` (averaged over timesteps and inner iterations).
        """
        # batch[t] is the snapshot at timestep t.
        batch = jnp.stack(sample, axis=0)

        # OT-map loss is differentiated w.r.t. the OT-map parameters (argument 3),
        # energy loss w.r.t. the potential, internal and interaction parameters.
        grad_fn_otmap = jax.value_and_grad(jax.jit(self.loss_fn_otmap), argnums=3)
        grad_fn_energy = jax.value_and_grad(jax.jit(self.loss_fn_energy), argnums=(0, 1, 2))

        def accumulate_otmap_grads(
            carry: tuple[InverseJKOnetState, FrozenDict[str, Any]], timestep: int
        ) -> tuple[tuple[InverseJKOnetState, FrozenDict[str, Any]], dict[str, jnp.ndarray]]:
            # Scan body: adds the OT-map gradient of one transition to the running sum.
            state, accumulated_grads = carry
            potential, internal, interaction, otmap = state

            loss_otmap, grads = grad_fn_otmap(
                potential.params, internal.params, interaction.params, otmap.params, batch, timestep
            )

            new_accumulated_grads = jax.tree.map(lambda g_old, g_new: g_old + g_new, accumulated_grads, grads)
            return (state, new_accumulated_grads), {"loss_otmap": loss_otmap}

        def apply_otmap_updates(state: InverseJKOnetState, _) -> tuple[InverseJKOnetState, dict[str, jnp.ndarray]]:
            # Scan body of the inner loop: one optimizer step of the OT map.
            potential, internal, interaction, otmap = state

            init_otmap_grads = jax.tree.map(lambda p: jnp.zeros_like(p), otmap.params)
            initial_carry = (state, init_otmap_grads)

            # Iterate over all transitions t -> t + 1.
            final_carry, metrics = jax.lax.scan(accumulate_otmap_grads, initial_carry, jnp.arange(batch.shape[0] - 1))

            _, accumulated_otmap_grads = final_carry
            # NOTE(review): gradients are summed over batch.shape[0] - 1 transitions but
            # divided by batch.shape[0] -- confirm this normalisation is intended.
            accumulated_otmap_grads = jax.tree.map(lambda x: x / batch.shape[0], accumulated_otmap_grads)
            otmap = otmap.apply_gradients(grads=accumulated_otmap_grads)
            # Only the OT map changes here; the energy states are passed through untouched.
            return (potential, internal, interaction, otmap), metrics

        new_state, metrics_otmap = jax.lax.scan(
            apply_otmap_updates, state, None, length=self.otmap_config["optim"]["inner_iter"]
        )
        # The nested scans stack the losses per inner iteration and per transition;
        # average over transitions first.
        metrics_otmap = jax.tree.map(
            lambda x: x.mean(axis=1), metrics_otmap
        )  # dict[str, array] array of shape (inner_steps,)
        metrics_otmap = jax.tree.map(jnp.mean, metrics_otmap)  # dict[str, jnp.float]

        def accumulate_energy_grads(
            carry: tuple[InverseJKOnetState, tuple[FrozenDict[str, Any], ...]], timestep: int
        ) -> tuple[tuple[InverseJKOnetState, tuple[FrozenDict[str, Any], ...]], dict[str, jnp.ndarray]]:
            # Scan body: adds the energy gradients of one transition to the running sums
            # (one gradient tree per energy network).
            state, accumulated_grads = carry
            potential, internal, interaction, otmap = state

            loss_energy, grads = grad_fn_energy(
                potential.params, internal.params, interaction.params, otmap.params, batch, timestep
            )

            new_accumulated_grads = tuple(
                jax.tree.map(lambda g_old, g_new: g_old + g_new, g_old, g_new)
                for g_old, g_new in zip(accumulated_grads, grads)
            )
            # The energy loss is reported with flipped sign.
            return (state, new_accumulated_grads), {"loss_energy": -loss_energy}

        # The zero trees are built from the incoming ``state``; the energy parameters
        # were not modified by the OT-map phase, so the structure matches ``new_state``.
        init_energy_grads = tuple(jax.tree.map(lambda p: jnp.zeros_like(p), state[i].params) for i in range(3))
        initial_carry = (new_state, init_energy_grads)

        final_carry, metrics_energy = jax.lax.scan(
            accumulate_energy_grads, initial_carry, jnp.arange(batch.shape[0] - 1)
        )
        metrics_energy = jax.tree.map(jnp.mean, metrics_energy)  # dict[str, jnp.float]

        # Apply a single update for energy gradients
        # NOTE(review): same normalisation as above -- sum over batch.shape[0] - 1
        # transitions divided by batch.shape[0].
        (potential, internal, interaction, otmap), accumulated_energy_grads = final_carry
        potential, internal, interaction = (
            s.apply_gradients(grads=jax.tree.map(lambda x: x / batch.shape[0], g))
            for s, g in zip((potential, internal, interaction), accumulated_energy_grads)
        )

        # Combine metrics and return updated state
        combined_metrics = metrics_energy | metrics_otmap
        return (potential, internal, interaction, otmap), combined_metrics

    def get_potential(self, state: InverseJKOnetState) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """
        Build a callable that evaluates the learned potential.

        Parameters
        ----------
        state : InverseJKOnetState
            Training state containing potential, internal, interaction and otmap models.

        Returns
        -------
        Callable[[jnp.ndarray], jnp.ndarray]
            Function applying the potential network with the parameters stored in ``state``.
        """
        potential_state, _internal, _interaction, _otmap = state

        def potential_fn(x):
            return potential_state.apply_fn({"params": potential_state.params}, x)

        return potential_fn

    def get_interaction(self, state: InverseJKOnetState) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """
        Build a callable that evaluates the learned (symmetrised) interaction.

        Parameters
        ----------
        state : InverseJKOnetState
            Training state containing potential, internal, interaction and otmap models.

        Returns
        -------
        Callable[[jnp.ndarray], jnp.ndarray]
            Function computing ``0.5 * (W(x) + W(-x))`` with the interaction network ``W``.
        """
        _potential, _internal, interaction_state, _otmap = state

        def symmetrised_interaction(x):
            at_x = interaction_state.apply_fn({"params": interaction_state.params}, x)
            at_minus_x = interaction_state.apply_fn({"params": interaction_state.params}, -x)
            return 0.5 * (at_x + at_minus_x)

        return symmetrised_interaction

    def get_beta(self, state: InverseJKOnetState) -> float:
        """
        Read the learned beta coefficient as a Python scalar.

        Parameters
        ----------
        state : InverseJKOnetState
            Training state containing potential, internal, interaction and otmap models.

        Returns
        -------
        float
            The value of ``_compute_beta`` for the internal parameters stored in ``state``.
        """
        _potential, internal_state, _interaction, _otmap = state
        beta = self._compute_beta(internal_state.params)
        return beta.item()


class InverseJKOnetPotentialInternal(InverseJKOnet):
    """
    Variant of the InverseJKOnet model restricted to the potential and internal terms.

    The interaction energy is switched off by overriding it with zero.
    """

    def get_interaction(self, state: InverseJKOnetState) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Return an interaction function that is identically zero."""

        def zero_interaction(x):
            return 0.0

        return zero_interaction

    def _compute_interaction(self, rho: jax.Array, interaction_params: dict) -> float:
        """Return zero: the interaction energy is not part of this variant."""
        return 0.0


class InverseJKOnetPotential(InverseJKOnetPotentialInternal):
    """
    A variant of the InverseJKOnet model to learn only the potential term.

    On top of the disabled interaction inherited from the parent class, every
    entropy-related quantity is overridden with zero.
    """

    def _compute_entropy_contribution_for_energy_update(
        self, rho: jnp.ndarray, internal_params: dict, otmap_params: dict, t: int
    ) -> float:
        """Return zero: there is no entropy term in the energy loss."""
        return 0.0

    def _compute_entropy_contribution_for_otmap_update(
        self, rho: jnp.ndarray, internal_params: dict, otmap_params: dict, t: int
    ) -> float:
        """Return zero: there is no entropy term in the OT-map loss."""
        return 0.0

    def _compute_beta(self, internal_params: dict) -> float:
        """Return zero: the internal coefficient is not learned."""
        return 0.0

    def get_beta(self, state: InverseJKOnetState) -> float:
        """Return zero, independent of ``state``."""
        return 0.0


class InverseJKOnetInteractionInternal(InverseJKOnet):
    """
    Variant of the InverseJKOnet model that learns the interaction and internal terms.

    The potential energy is switched off by overriding it with zero.
    """

    def get_potential(self, state: InverseJKOnetState) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Return a potential function that is identically zero."""

        def zero_potential(x):
            return 0.0

        return zero_potential

    def _compute_potential(self, rho: ndarray, potential_params: dict) -> float:
        """Return zero: the potential energy is not part of this variant."""
        return 0.0


class InverseJKOnetInteraction(InverseJKOnetInteractionInternal):
    """
    A variant of the InverseJKOnet model to learn only the interaction term.

    On top of the disabled potential inherited from the parent class, every
    entropy-related quantity is overridden with zero.
    """

    def _compute_entropy_contribution_for_energy_update(
        self, rho: jnp.ndarray, internal_params: dict, otmap_params: dict, t: int
    ) -> float:
        """Return zero: there is no entropy term in the energy loss."""
        return 0.0

    def _compute_entropy_contribution_for_otmap_update(
        self, rho: jnp.ndarray, internal_params: dict, otmap_params: dict, t: int
    ) -> float:
        """Return zero: there is no entropy term in the OT-map loss."""
        return 0.0

    def _compute_beta(self, internal_params: dict) -> float:
        """Return zero: the internal coefficient is not learned."""
        return 0.0

    def get_beta(self, state: InverseJKOnetState) -> float:
        """Return zero, independent of ``state``."""
        return 0.0


class InverseJKOnetInternal(InverseJKOnet):
    """
    Variant of the InverseJKOnet model restricted to the internal (diffusion) term.

    Both the potential and the interaction energy are overridden with zero.
    """

    def get_potential(self, state: InverseJKOnetState) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Return a potential function that is identically zero."""

        def zero_potential(x):
            return 0.0

        return zero_potential

    def get_interaction(self, state: InverseJKOnetState) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Return an interaction function that is identically zero."""

        def zero_interaction(x):
            return 0.0

        return zero_interaction

    def _compute_potential(self, rho: ndarray, potential_params: dict) -> float:
        """Return zero: the potential energy is not part of this variant."""
        return 0.0

    def _compute_interaction(self, rho: jax.Array, interaction_params: dict) -> float:
        """Return zero: the interaction energy is not part of this variant."""
        return 0.0


class InverseJKOnetPotentialInteraction(InverseJKOnet):
    """
    A specialized variant of the InverseJKOnet model that only considers potential and interaction terms.

    Every entropy-related quantity is overridden with zero.
    """

    def _compute_entropy_contribution_for_energy_update(
        self, rho: jnp.ndarray, internal_params: dict, otmap_params: dict, t: int
    ) -> float:
        """Return zero: there is no entropy term in the energy loss."""
        return 0.0

    def _compute_entropy_contribution_for_otmap_update(
        self, rho: jnp.ndarray, internal_params: dict, otmap_params: dict, t: int
    ) -> float:
        """Return zero: there is no entropy term in the OT-map loss."""
        return 0.0

    def _compute_beta(self, internal_params: dict) -> float:
        """Return zero: the internal coefficient is not learned."""
        return 0.0

    def get_beta(self, state: InverseJKOnetState) -> float:
        """Return zero, independent of ``state``."""
        return 0.0


class InverseJKOnetTimePotential(InverseJKOnetPotential):
    """
    Potential-only variant whose potential network also receives the time index.

    NOTE(review): ``_compute_potential`` below takes ``interaction_params`` and ``t``,
    but the inherited ``_compute_energy`` calls ``_compute_potential(rho, potential_params)``.
    With the inherited loss functions that call raises ``TypeError``; the callers
    need to supply the time index before this variant can be trained.
    """

    def create_state(self, rng: jax.random.PRNGKey) -> InverseJKOnetState:
        """
        Create initial training states for the potential, internal, interaction and OT map models.

        The key is split so that every network is initialised from its own random
        stream instead of sharing a single key.

        Parameters
        ----------
        rng : jax.random.PRNGKey
            Random key for initialization.

        Returns
        -------
        InverseJKOnetState
            Tuple containing the training states for the potential, internal, interaction and OT map models.
        """
        rng_potential, rng_internal, rng_interaction, rng_otmap = jax.random.split(rng, 4)

        # The potential input is the point with the time index appended.
        potential = create_train_state(
            rng_potential, self.model_potential, get_optimizer(self.config_optimizer), self.data_dim + 1
        )
        internal = create_train_state(rng_internal, self.model_internal, get_optimizer(self.config_optimizer), 1)
        # NOTE(review): the interaction input shape differs from the base class
        # (``data_dim``) -- confirm that ``(data_dim, data_dim)`` is intended.
        interaction = create_train_state(
            rng_interaction,
            self.model_interaction,
            get_optimizer(self.config_optimizer),
            (self.data_dim, self.data_dim),
        )

        otmap = create_train_state(rng_otmap, self.model_otmap, self.otmap_optimizer, self.data_dim + 1)

        return potential, internal, interaction, otmap

    def _compute_potential(
        self, rho: jnp.ndarray, potential_params: dict, interaction_params: dict, t: int
    ) -> jnp.ndarray:
        """
        Monte-Carlo estimate of the time-dependent potential energy of ``rho``.

        A column filled with ``t`` is appended to ``rho``, the potential network is
        evaluated on every row and the values are averaged. ``interaction_params``
        is accepted but not used.
        """
        time_column = t * jnp.ones([rho.shape[0], 1])
        rho_with_time = jnp.concat([rho, time_column], axis=1)

        def potential_at(x: jnp.ndarray) -> jnp.ndarray:
            return self.model_potential.apply({"params": potential_params}, x)

        return jnp.mean(jax.vmap(potential_at)(rho_with_time))
