"""
Module that implements the InverseJKOnet model based on the base interface.

The models are implemented using JAX and the FLAX library, following a functional paradigm to
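support efficient differentiation and optimization. The core classes include:

- ``InverseJKOnetMultimap``: The full model, learning the energy terms together with per-step transport maps.
- ``InverseJKOnetMultimapPotentialInternal``: A variant restricted to the potential and internal terms.
- ``InverseJKOnetMultimapPotential``: A variant focusing solely on the potential energy term.
- ``InverseJKOnetMultimapTimePotential``: A potential-only variant whose potential also receives the time label.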
"""

from typing import Any, Callable

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import struct
from flax.core import FrozenDict
from flax.training.train_state import TrainState
from jax import lax
from jax.nn import one_hot
from jax.random import beta as beta_sample
from jax.scipy.stats import beta

from dataset import PopulationDataset
from models.base import LearningDiffusionModel
#from models.inverse_jko_multimap2 import spectral_norm_jacobian_fn
from networks.energies import MLP
from networks.maps import GradientICNNMap, TimeVaryingTransport, TransportMap, UNetMLP
from networks.optim import create_train_state, get_optimizer
from utils.mutinfo.estimators.knn import WKL

def spectral_norm_jacobian_fn(
    j_fn: Callable[[jnp.ndarray], jnp.ndarray],
    x: jnp.ndarray,
    n_power_iter: int = 1,
    eps: float = 1e-12,
    key: "jax.Array | None" = None,
) -> jnp.ndarray:
    """
    Estimate the spectral norm of the Jacobian ∇ₓ J(x) using power iteration.

    The Jacobian of a scalar function is a single row of shape (1, D), so the
    estimate coincides (up to ``eps``) with the Euclidean norm of the gradient.

    Parameters
    ----------
    j_fn : Callable
        Scalar-valued function J(x): ℝ^D → ℝ
    x : jnp.ndarray
        Input of shape (N, D)
    n_power_iter : int
        Number of power iterations
    eps : float
        Numerical stability term added to every normalisation denominator
    key : jax.Array, optional
        PRNG key used to draw the initial power-iteration vectors. Defaults to
        ``jax.random.PRNGKey(0)``, which reproduces the previous hard-coded seed.

    Returns
    -------
    jnp.ndarray
        Spectral norm estimates, shape (N,)

    Raises
    ------
    ValueError
        If ``x`` is not a 2-D array.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if x.ndim != 2:
        raise ValueError(f"Expected `x` of shape (N, D), got an array with shape {x.shape}.")
    if key is None:
        key = jax.random.PRNGKey(0)

    def single_sample(x_i, sample_key):
        # Jacobian of j_fn at x_i, lifted to a row matrix: shape (1, D)
        jac = jax.jacrev(j_fn)(x_i)[None, :]

        # Random unit start vector on the output side: shape (1,)
        u = jax.random.normal(sample_key, (1,))
        u = u / (jnp.linalg.norm(u) + eps)

        def power_iter(u, _):
            v = jnp.matmul(u, jac)  # (D,)
            v = v / (jnp.linalg.norm(v) + eps)
            u_new = jnp.matmul(v, jac.T)  # (1,)
            u_new = u_new / (jnp.linalg.norm(u_new) + eps)
            return u_new, None

        u_final, _ = jax.lax.scan(power_iter, u, None, length=n_power_iter)
        # ||uᵀJ|| for the (approximately) dominant left singular vector u.
        return jnp.linalg.norm(jnp.matmul(u_final, jac))

    # One independent key per sample so the start vectors differ across the batch.
    keys = jax.random.split(key, x.shape[0])
    return jax.vmap(single_sample)(x, keys)  # (N,)


@struct.dataclass
class InverseJKOnetState:
    """Container for everything carried from one training step to the next."""

    # Train state of the potential energy network.
    potential: TrainState
    # Train state of the internal energy network.
    internal: TrainState
    # Train state of the interaction energy network.
    interaction: TrainState
    # Train state of the time-varying transport maps.
    otmaps: TrainState
    # PRNG key; `train_step` splits it and stores the fresh half here.
    key: jax.random.PRNGKey
    # Exponential-moving-average copies of the energy parameters
    # (returned instead of the raw parameters when the model's `use_ema` is set).
    potential_ema_params: Any = None
    internal_ema_params: Any = None
    interaction_ema_params: Any = None


def ema_update(ema_params, new_params, decay):
    """
    Blend two parameter pytrees leaf by leaf into an exponential moving average.

    Every leaf of the result equals ``decay * ema_leaf + (1 - decay) * new_leaf``.
    """

    def blend(ema_leaf, new_leaf):
        return decay * ema_leaf + (1 - decay) * new_leaf

    return jax.tree_util.tree_map(blend, ema_params, new_params)


# Threshold read by `maybe_ema_update`: the EMA is only updated once the step
# counter is strictly greater than this value (1000 was used previously).
EMA_FROM_EPOCH = 100


def maybe_ema_update(step, ema_params, new_params, decay):
    """
    Update the EMA parameters only after the warm-up threshold has passed.

    While ``step <= EMA_FROM_EPOCH`` the current ``ema_params`` are returned
    untouched; afterwards they are blended with ``new_params`` via ``ema_update``.
    ``lax.cond`` is used so the decision also works when ``step`` is a traced value.
    """

    def apply_ema(_):
        return ema_update(ema_params, new_params, decay)

    def keep_current(_):
        return ema_params

    return lax.cond(step > EMA_FROM_EPOCH, apply_ema, keep_current, operand=None)


def beta_weights(T: int, alpha: float = 0.9, beta_param: float = 0.9, eps: float = 1e-3) -> jnp.ndarray:
    """
    Return ``T`` deterministic weights shaped like a Beta density and summing to one.

    The Beta(``alpha``, ``beta_param``) pdf is evaluated on an evenly spaced grid
    over ``[eps, 1 - eps]`` (the end points are kept off 0 and 1) and normalised.
    """
    grid = jnp.linspace(eps, 1 - eps, T)
    density = beta.pdf(grid, a=alpha, b=beta_param)
    return density / jnp.sum(density)


def sample_beta_weights(key: jax.random.PRNGKey, T: int, alpha: float, beta_param: float) -> jnp.ndarray:
    """Draw ``T`` samples from Beta(``alpha``, ``beta_param``) and normalise them to sum to one."""
    draws = beta_sample(key, a=alpha, b=beta_param, shape=(T,))
    total = jnp.sum(draws)
    return draws / total


def random_one_hot_weights(key: jax.random.PRNGKey, T: int) -> jnp.ndarray:
    """Return a length-``T`` one-hot vector whose active position is drawn uniformly from ``[0, T)``."""
    chosen = jax.random.randint(key, shape=(), minval=0, maxval=T)
    return one_hot(chosen, T)  # shape (T,): 1 at `chosen`, 0 elsewhere


def aggregate_losses(
    losses: jnp.ndarray, method: str = "mean", weights: jnp.ndarray = None, alpha: float = 1.0
) -> jnp.ndarray:
    """
    Reduce a vector of per-step losses to a single scalar.

    Parameters
    ----------
    losses : jnp.ndarray
        Loss values to combine.
    method : str
        One of ``"mean"``, ``"softmax"``, ``"weighted"``, ``"random_one"``, ``"max"``,
        ``"norm"`` (Euclidean norm) or ``"infnorm"``.
    weights : jnp.ndarray, optional
        Required by ``"weighted"`` and ``"random_one"``; ignored by every other method.
    alpha : float
        Scale applied to the losses before the softmax (``"softmax"`` only).

    Returns
    -------
    jnp.ndarray
        The aggregated loss.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    # The two weight-driven strategies share one formula and are the only ones needing `weights`.
    if method in {"weighted", "random_one"}:
        assert weights is not None, "Weights must be provided for 'weighted' aggregation."
        return jnp.sum(weights * losses)

    reducers = {
        "mean": jnp.mean,
        # Self-weighted average: larger losses receive larger softmax weights.
        "softmax": lambda values: jnp.sum(jax.nn.softmax(alpha * values) * values),
        "max": jnp.max,
        "norm": lambda values: jnp.linalg.norm(values, ord=None),
        "infnorm": lambda values: jnp.linalg.norm(values, ord=jnp.inf),
    }
    reducer = reducers.get(method)
    if reducer is None:
        raise ValueError(f"Unknown aggregation method: {method}")
    return reducer(losses)


class InverseJKOnetMultimap(LearningDiffusionModel):
    """
    The full InverseJKOnet model for learning all energy terms.
    """

    def __init__(self, config: dict, data_dim: int, tau, dataset) -> None:
        """
        Initialize the InverseJKO model.

        Parameters
        ----------
        config : dict
            Configuration dictionary containing model and optimizer settings.
        data_dim : int
            Dimension of the input data.
        tau : float
            Represents the time scale over which the diffusion process described by the
            Fokker-Planck equation is considered.
        """
        super().__init__()
        # self.tau = tau
        self.tau = tau
        self.data_dim = data_dim
        self.ema_decay = 0.999
        self.use_ema = True
        self.dataset = dataset
        self._load_config(config)
        self._initialize_energy_model()
        self._initialize_otmaps()

    def _load_config(self, config: dict) -> None:
        """Load and set configuration parameters."""
        self.batch_size = int(config["train"]["batch_size"] / (config["T"] - 1))
        self.layers = config["energy"]["model"]["layers"]
        self.config_energy_optimizer = config["energy"]["optim"]
        self.K = config["T"]

        self.otmap_config = config["otmap"]
        self.config_otmap_optimizer = config["otmap"]["optim"]

        self.config = config

    def _initialize_energy_model(self) -> None:
        """Initialize the energy model (potential, internal, and interaction)."""
        self.model_potential = MLP(self.layers)
        self.model_internal = MLP([1])
        self.model_interaction = MLP(self.layers)

        self.energy_gp_reg = self.config["train"]["energy_gp_reg"]

        print(
            f"Initializing energy with {self.config["loss_averaging"]["energy"]} aggregation and gradient-penalty: {self.energy_gp_reg} ..."
        )
        self.get_energy_weights = self._make_loss_weights_fn(strategy=self.config["loss_averaging"]["energy"])
        self.aggregate_energy_losses = self._make_aggregate_fn(method=self.config["loss_averaging"]["energy"])

        self.lr_log_sigma2 = 1e-3
        self.log_sigma2_energy = MLP([1])

    def _initialize_otmaps(self) -> None:
        """Initialize time-varying transport maps."""
        model_type = self.otmap_config["model"]["type"]

        if model_type == "gradICNN":
            net_kwargs_list = []
            for k in range(self.K - 1):
                source_k = jnp.array(self.dataset.trajectory[k])
                target_k = jnp.array(self.dataset.trajectory[k + 1])
                net_kwargs_list.append(
                    {
                        "gaussian_map_samples": (source_k, target_k),
                        "init_fn": nn.initializers.normal(stddev=0.1 + 0.01 * k),
                    }
                )

            self.model_otmaps = TimeVaryingTransport(
                K=self.K,
                layers=self.otmap_config["models"][model_type]["dim_hidden"],
                input_dim=self.data_dim,
                net_cls=GradientICNNMap,
                net_kwargs_list=net_kwargs_list,
            )
        else:
            str_to_class = {"MLP": TransportMap, "UNet": UNetMLP}
            model_cls = str_to_class[model_type]
            net_kwargs_list = [{} for _ in range(self.K - 1)]

            self.model_otmaps = TimeVaryingTransport(
                K=self.K,
                layers=self.otmap_config["models"][model_type]["dim_hidden"],
                input_dim=self.data_dim,
                net_cls=model_cls,
                net_kwargs_list=net_kwargs_list,
                reg=self.otmap_config["model"]["cvx_reg"],
            )

        print(f"Initializing otmaps with {self.config["loss_averaging"]["otmap"]} aggregation...")
        self.get_otmaps_weights = self._make_loss_weights_fn(strategy=self.config["loss_averaging"]["otmap"])
        self.aggregate_otmap_losses = self._make_aggregate_fn(method=self.config["loss_averaging"]["otmap"])

    def _make_loss_weights_fn(self, strategy: str):
        """Return a function to compute weights based on the strategy."""
        if strategy == "weighted":
            return lambda key: sample_beta_weights(key, self.K - 1, **self.config["sampling"]["beta"])
        elif strategy == "random_one":
            return lambda key: random_one_hot_weights(key, self.K - 1)
        else:
            return lambda key: jnp.ones((self.K - 1,))

    def _make_aggregate_fn(self, method: str):
        """Return a function to aggregate losses based on the averaging method."""
        return lambda losses, weights: aggregate_losses(losses, method=method, weights=weights)

    def load_dataset(self, dataset_name: str) -> PopulationDataset:
        """
        Load and return a dataset based on the given dataset name.

        This method creates an instance of the `PopulationDataset` class using the specified dataset name.

        Parameters
        ----------
        dataset_name : str
            The name of the dataset to load. This name is used to locate and initialize the dataset.

        Returns
        -------
        PopulationDataset
            An instance of the `PopulationDataset` class, which contains the loaded dataset.
        """
        # TODO: think may it is better to add entropy estimation for dataset in general
        dataset = PopulationDataset(dataset_name, self.batch_size, "train_data.npy", "train_sample_labels.npy")

        # entropy_estimator = WKL(k_neighbors=5)
        # entropies = [None] * len(dataset.trajectory.keys())  # WARNING: it is assumed that keys are integers
        # for label, rho in dataset.trajectory.items():
        #     entropies[label] = entropy_estimator.entropy(rho)

        # self.entropies = jnp.array(entropies)
        self.dataset = dataset
        # print(f"Entropies: {self.entropies}")
        return dataset

    def create_state(self, rng: jax.random.PRNGKey) -> InverseJKOnetState:
        """
        Create initial training states for the potential, internal, interaction and OT map models.

        Parameters
        ----------
        rng : jax.random.PRNGKey
            Random key for initialization.

        Returns
        -------
        InverseJKOnetState
            Tuple containing the training states for the potential, internal, interaction and OT map models.
        """
        initial_potential_state = create_train_state(
            rng, self.model_potential, get_optimizer(self.config_energy_optimizer), self.data_dim
        )
        initial_internal_state = create_train_state(
            rng, self.model_internal, get_optimizer(self.config_energy_optimizer), 1
        )
        initial_interaction_state = create_train_state(
            rng, self.model_interaction, get_optimizer(self.config_energy_optimizer), self.data_dim
        )
        initial_otmaps_state = create_train_state(
            rng, self.model_otmaps, get_optimizer(self.config_otmap_optimizer), (self.K - 1, self.data_dim)
        )
        state = InverseJKOnetState(
            potential=initial_potential_state,
            internal=initial_internal_state,
            interaction=initial_interaction_state,
            otmaps=initial_otmaps_state,
            potential_ema_params=initial_potential_state.params,
            internal_ema_params=initial_internal_state.params,
            interaction_ema_params=initial_interaction_state.params,
            key=rng,
        )
        return state

    def create_state_from_params(
        self, potential_params: dict, internal_params: dict, interaction_params: dict, otmap_params: dict
    ) -> InverseJKOnetState:
        """
        Create training states from the provided parameters.

        Not supported by this model yet: the method always raises.

        Parameters
        ----------
        potential_params : dict
            Parameters for the potential model.
        internal_params : dict
            Parameters for the internal model.
        interaction_params : dict
            Parameters for the interaction model.
        otmap_params : dict
            Parameters for the otmap model.

        Returns
        -------
        InverseJKOnetState
            Never returned at present; see ``Raises``.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("`create_state_from_params` has not implemented yet!")

    def get_params(
        self,
        state: InverseJKOnetState,
    ) -> tuple[FrozenDict[str, Any], FrozenDict[str, Any], FrozenDict[str, Any]]:
        """
        Get parameters from the training state.

        Parameters
        ----------
        state : InverseJKOnetState
            Training state containing potential, internal, interaction and otmap models.

        Returns
        -------
        tuple[dict, dict, dict]
            Tuple containing the parameters for the potential, internal, and interaction models.
        """
        return (
            (state.potential_ema_params, state.internal_ema_params, state.interaction_ema_params)
            if self.use_ema
            else (
                state.potential.params,
                state.internal.params,
                state.interaction.params,
            )
        )

    def _compute_energy_single_step(
        self, rho: jnp.ndarray, potential_params: dict, interaction_params: dict
    ) -> jnp.ndarray:
        """Computes the energy without entropy."""

        diffs = rho[:, None, :] - rho[None, :, :]  # shape [N, N, D]
        interactions_fn = jax.vmap(jax.vmap(lambda d: self.model_interaction.apply({"params": interaction_params}, d)))
        interaction = interactions_fn(diffs)

        potential_fn = jax.vmap(lambda x: self.model_potential.apply({"params": potential_params}, x))

        return jnp.mean(potential_fn(rho)) + jnp.mean(interaction)

    def loss_fn_otmap(
        self,
        potential_params: dict,
        internal_params: dict,
        interaction_params: dict,
        otmap_params: dict,
        batch: jnp.ndarray,
        epoch: int,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Compute the loss for the otmap.

        Parameters
        ----------
        potential_params : dict
            Parameters for the potential model.
        internal_params : dict
            Parameters for the internal model.
        interaction_params : dict
            Parameters for the interaction model.
        otmap_params : dict
            Parameters for the otmap model.
        batch : jnp.ndarray
            Array of shape (num_timesteps, num_particles, num_features) containing
            the particle trajectories. Each entry in the array represents a particle's
            state at a given timestep.

        Returns
        -------
        jnp.ndarray
            Total loss value.
        """
        rho = batch[: self.K - 1]  # (K - 1, N, D)

        rho_next_gt = batch[1:]  # (K - 1, N, D)
        rho_next = jnp.swapaxes(
            jax.vmap(self.model_otmaps.apply, in_axes=(None, 0))({"params": otmap_params}, jnp.swapaxes(rho, 0, 1)),
            1,
            0,
        )  # (K - 1, N, D)

        energy = jax.vmap(self._compute_energy_single_step, in_axes=(0, None, None))(
            rho_next, potential_params, interaction_params
        )  # (K - 1)
        energy_gt = jax.vmap(self._compute_energy_single_step, in_axes=(0, None, None))(
            rho_next_gt, potential_params, interaction_params
        )  # (K - 1)

        loss_e = energy
        loss_e_gt = energy_gt
        loss_p = jnp.mean(jnp.sum((rho_next - rho) ** 2, axis=2), axis=1)
        base_loss = loss_e - loss_e_gt + (1 / (2 * self.tau)) * loss_p
        #print(self.tau)
        return base_loss  # K - 1

    def loss_fn_energy(
        self,
        potential_params: dict,
        internal_params: dict,
        interaction_params: dict,
        otmap_params: dict,
        batch: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Compute the per-step energy loss: energy of the observed next snapshot minus the energy
        of the snapshot predicted by the transport maps.

        Parameters
        ----------
        potential_params : dict
            Parameters for the potential model.
        internal_params : dict
            Parameters for the internal model.
        interaction_params : dict
            Parameters for the interaction model.
        otmap_params : dict
            Parameters for the otmap model.
        batch : jnp.ndarray
            Array of shape (num_timesteps, num_particles, num_features) containing
            the particle trajectories. Each entry in the array represents a particle's
            state at a given timestep.

        Returns
        -------
        jnp.ndarray
            One loss value per transition; aggregation is left to the caller.
        """
        sources = batch[: self.K - 1]
        targets_gt = batch[1:]

        # The transport model is vmapped over particles: move the particle axis to the front,
        # apply the maps, then restore the (step, particle, feature) layout.
        per_particle_map = jax.vmap(self.model_otmaps.apply, in_axes=(None, 0))
        pushed = per_particle_map({"params": otmap_params}, jnp.swapaxes(sources, 0, 1))
        targets_predicted = jnp.swapaxes(pushed, 1, 0)

        energy_per_step = jax.vmap(self._compute_energy_single_step, in_axes=(0, None, None))
        energy_predicted = energy_per_step(targets_predicted, potential_params, interaction_params)
        energy_gt = energy_per_step(targets_gt, potential_params, interaction_params)

        # NOTE: a spectral-norm gradient penalty scaled by `self.energy_gp_reg` (computed with
        # `spectral_norm_jacobian_fn` on the predicted snapshots) used to be added here; it is
        # currently disabled, so `energy_gp_reg` has no effect on this loss.
        return energy_gt - energy_predicted

    def train_step(
        self,
        state: InverseJKOnetState,
        sample: list[jnp.ndarray],
        epoch: int,
    ) -> tuple[InverseJKOnetState, dict[str, jnp.ndarray]]:
        batch = jnp.stack(sample, axis=0)  # shape: (T, N, D)

        key, key_eval = jax.random.split(state.key)

        def otmap_loss_fn(potential_p, internal_p, interaction_p, otmap_p):
            otmap_losses = self.loss_fn_otmap(potential_p, internal_p, interaction_p, otmap_p, batch, epoch)
            weights = self.get_otmaps_weights(key)
            return self.aggregate_otmap_losses(otmap_losses, weights)

        def energy_loss_fn(potential_p, internal_p, interaction_p, otmap_p):
            energy_losses = self.loss_fn_energy(potential_p, internal_p, interaction_p, otmap_p, batch)
            weights = self.get_energy_weights(key)
            return self.aggregate_energy_losses(energy_losses, weights)

        # Get gradient functions
        grad_fn_otmap = jax.value_and_grad(otmap_loss_fn, argnums=3)
        grad_fn_energy = jax.value_and_grad(energy_loss_fn, argnums=(0, 1, 2))

        # OTMap optimization loop
        def update_otmaps(otmaps_):
            def body(_, otmaps_):
                _, otmap_grads = grad_fn_otmap(
                    state.potential.params, state.internal.params, state.interaction.params, otmaps_.params
                )
                return otmaps_.apply_gradients(grads=otmap_grads)

            otmaps_ = jax.lax.fori_loop(0, self.otmap_config["optim"]["inner_iter"], body, otmaps_)
            # Recompute final loss
            otmap_loss, _ = grad_fn_otmap(
                state.potential.params, state.internal.params, state.interaction.params, otmaps_.params
            )
            return otmaps_, otmap_loss

        new_otmaps, otmap_loss = update_otmaps(state.otmaps)

        # Energy optimization
        energy_loss, energy_grads = grad_fn_energy(
            state.potential.params, state.internal.params, state.interaction.params, new_otmaps.params
        )

        # Apply updates to energy parameters
        new_potential = state.potential.apply_gradients(grads=energy_grads[0])
        new_internal = state.internal.apply_gradients(grads=energy_grads[1])
        new_interaction = state.interaction.apply_gradients(grads=energy_grads[2])

        potential_ema = maybe_ema_update(epoch, state.potential_ema_params, new_potential.params, self.ema_decay)
        internal_ema = maybe_ema_update(epoch, state.internal_ema_params, new_internal.params, self.ema_decay)
        interaction_ema = maybe_ema_update(epoch, state.interaction_ema_params, new_interaction.params, self.ema_decay)

        # Create final updated state
        new_state = InverseJKOnetState(
            potential=new_potential,
            internal=new_internal,
            interaction=new_interaction,
            otmaps=new_otmaps,
            potential_ema_params=potential_ema,
            internal_ema_params=internal_ema,
            interaction_ema_params=interaction_ema,
            key=key_eval,
        )

        metrics = {
            "loss_energy": energy_loss,
            "loss_otmap": otmap_loss,
        }

        return new_state, metrics

    def get_potential(self, state: InverseJKOnetState) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """
        Get the potential function from the model state.

        Parameters
        ----------
        state : InverseJKOnetState
            Training state containing potential, internal, interaction and otmap models.

        Returns
        -------
        Callable[[jnp.ndarray], jnp.ndarray]
            Function that computes the potential.
        """
        params = state.potential_ema_params if self.use_ema else state.potential.params
        return lambda x: state.potential.apply_fn({"params": params}, x)

    def get_interaction(self, state: InverseJKOnetState) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """
        Get the interaction function from the model state.

        Parameters
        ----------
        state : InverseJKOnetState
            Training state containing potential, internal, interaction and otmap models.

        Returns
        -------
        Callable[[jnp.ndarray], jnp.ndarray]
            Function that computes the interaction.
        """
        params = state.interaction_ema_params if self.use_ema else state.interaction.params
        return lambda x: state.interaction.apply_fn({"params": params}, x)

    def get_beta(self, state: InverseJKOnetState) -> float:
        """
        Get the beta value from the model state.

        Parameters
        ----------
        state : InverseJKOnetState
            Training state containing potential, internal, interaction and otmap models.

        Returns
        -------
        float
            The beta value from the internal energy model.
        """
        params = state.internal_ema_params if self.use_ema else state.internal.params
        return abs(state.internal.apply_fn({"params": params}, jnp.asarray([1])).item())


class InverseJKOnetMultimapPotentialInternal(InverseJKOnetMultimap):
    """
    A specialized variant of the InverseJKOnet model that only considers potential and internal terms.

    The interaction network is not evaluated: the energy reduces to the mean potential and the
    reported interaction function is identically zero.
    """

    def _compute_energy_single_step(
        self, rho: jnp.ndarray, potential_params: dict, interaction_params: dict
    ) -> jnp.ndarray:
        """Return the mean potential over the particles in ``rho`` (no entropy, no interaction)."""
        potential_values = jax.vmap(
            lambda x: self.model_potential.apply({"params": potential_params}, x)
        )(rho)
        return jnp.mean(potential_values)

    def get_interaction(self, state: InverseJKOnetState) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Return a function that maps every input to ``0.0``."""

        def zero_interaction(x):
            return 0.0

        return zero_interaction


class InverseJKOnetMultimapPotential(InverseJKOnetMultimapPotentialInternal):
    """
    A variant of the InverseJKOnet model to learn only the potential term.

    Every entropy/internal-energy related quantity is fixed to zero.
    """

    def _compute_entropy_contribution_for_energy_update(
        self, rho: jnp.ndarray, internal_params: dict, otmap_params: dict, t: int
    ) -> float:
        """Return ``0.0``: no entropy contribution to the energy update."""
        # NOTE(review): not called anywhere in this module — presumably a hook used elsewhere; verify.
        return 0.0

    def _compute_entropy_contribution_for_otmap_update(
        self, rho: jnp.ndarray, internal_params: dict, otmap_params: dict, t: int
    ) -> jnp.ndarray:
        """Return ``0.0``: no entropy contribution to the transport-map update."""
        # NOTE(review): not called anywhere in this module — presumably a hook used elsewhere; verify.
        return 0.0

    def get_beta(self, state: InverseJKOnetState) -> float:
        """Return ``0.0``; the internal energy model is never consulted."""
        return 0.0

class InverseJKOnetMultimapTimePotential(InverseJKOnetMultimapPotential):

    def create_state(self, rng: jax.random.PRNGKey) -> InverseJKOnetState:
        """
        Build the initial training state for the time-dependent potential model.

        The potential network is initialised with ``data_dim + 1`` inputs (the extra input is the
        time label appended by ``_compute_potential``); the other networks use their usual sizes.

        Parameters
        ----------
        rng : jax.random.PRNGKey
            Random key for initialization.

        Returns
        -------
        InverseJKOnetState
            State holding one train state per network, EMA parameters initialised to the freshly
            created energy parameters, and ``rng`` as the stored key.
        """
        # NOTE(review): the same `rng` is handed to every initialisation below (unchanged behaviour).

        def energy_train_state(model, input_dim):
            return create_train_state(rng, model, get_optimizer(self.config_energy_optimizer), input_dim)

        potential_state = energy_train_state(self.model_potential, self.data_dim + 1)
        internal_state = energy_train_state(self.model_internal, 1)
        interaction_state = energy_train_state(self.model_interaction, self.data_dim)
        otmaps_state = create_train_state(
            rng, self.model_otmaps, get_optimizer(self.config_otmap_optimizer), (self.K - 1, self.data_dim)
        )

        return InverseJKOnetState(
            potential=potential_state,
            internal=internal_state,
            interaction=interaction_state,
            otmaps=otmaps_state,
            potential_ema_params=potential_state.params,
            internal_ema_params=internal_state.params,
            interaction_ema_params=interaction_state.params,
            key=rng,
        )
    def _compute_potential(
        self, rho: jnp.ndarray, potential_params: dict, interaction_params: dict, t: int
    ) -> jnp.ndarray:
        """
        Return the mean time-dependent potential over the particles in ``rho``.

        The time label ``t`` is appended to every particle as an extra feature column before the
        potential network is evaluated. ``interaction_params`` is accepted but unused.
        """
        time_column = t * jnp.ones([rho.shape[0], 1])
        # `jnp.concatenate` instead of `jnp.concat`: the latter is a recent array-API alias that
        # older JAX releases do not provide; both produce the same result.
        rho_t = jnp.concatenate([rho, time_column], axis=1)

        def potential_at(x: jnp.ndarray) -> jnp.ndarray:
            return self.model_potential.apply({"params": potential_params}, x)

        return jnp.mean(jax.vmap(potential_at)(rho_t))

    def _compute_energy_single_step(
        self, rho: jnp.ndarray, potential_params: dict, interaction_params: dict, t: int
    ) -> jnp.ndarray:
        """Return the energy of one snapshot, which here is just the time-dependent potential at ``t``."""
        return self._compute_potential(rho, potential_params, interaction_params, t)

    def loss_fn_otmap(
        self,
        potential_params: dict,
        internal_params: dict,
        interaction_params: dict,
        otmap_params: dict,
        batch: jnp.ndarray,
        epoch: int,
    ) -> jnp.ndarray:
        """
        Compute the per-step loss for the otmap using the time-dependent potential.

        For every transition the loss is the energy of the transported particles minus the energy
        of the observed next snapshot (both evaluated at the target time label), plus
        ``1 / (2 * tau)`` times the mean squared displacement produced by the transport map.

        Parameters
        ----------
        potential_params : dict
            Parameters for the potential model.
        internal_params : dict
            Parameters for the internal model. Currently unused.
        interaction_params : dict
            Parameters for the interaction model.
        otmap_params : dict
            Parameters for the otmap model.
        batch : jnp.ndarray
            Array of shape (num_timesteps, num_particles, num_features) containing
            the particle trajectories. Each entry in the array represents a particle's
            state at a given timestep.
        epoch : int
            Current epoch. Currently unused.

        Returns
        -------
        jnp.ndarray
            One loss value per transition, shape (K - 1,). Aggregation is left to the caller.
        """
        rho = batch[: self.K - 1]  # (K - 1, N, D)
        rho_next_gt = batch[1:]  # (K - 1, N, D)

        # Time labels of the target snapshots: every sorted trajectory key except the first.
        t = jnp.array(sorted(self.dataset.trajectory.keys())[1:]).reshape(-1, 1)

        # The transport model is vmapped over particles, so the particle axis is moved to the
        # front before the call and moved back afterwards.
        rho_next = jnp.swapaxes(
            jax.vmap(self.model_otmaps.apply, in_axes=(None, 0))({"params": otmap_params}, jnp.swapaxes(rho, 0, 1)),
            1,
            0,
        )  # (K - 1, N, D)

        # Snapshots and time labels are mapped together; the parameters are shared.
        energy_per_step = jax.vmap(self._compute_energy_single_step, in_axes=(0, None, None, 0))
        energy = energy_per_step(rho_next, potential_params, interaction_params, t)  # (K - 1)
        energy_gt = energy_per_step(rho_next_gt, potential_params, interaction_params, t)  # (K - 1)

        # Mean squared displacement of the particles under the transport map.
        transport_cost = jnp.mean(jnp.sum((rho_next - rho) ** 2, axis=2), axis=1)

        return energy - energy_gt + (1 / (2 * self.tau)) * transport_cost  # K - 1

    def loss_fn_energy(
        self,
        potential_params: dict,
        internal_params: dict,
        interaction_params: dict,
        otmap_params: dict,
        batch: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Compute the per-step energy loss with the time-dependent potential: energy of the observed
        next snapshot minus the energy of the snapshot predicted by the transport maps.

        Parameters
        ----------
        potential_params : dict
            Parameters for the potential model.
        internal_params : dict
            Parameters for the internal model.
        interaction_params : dict
            Parameters for the interaction model.
        otmap_params : dict
            Parameters for the otmap model.
        batch : jnp.ndarray
            Array of shape (num_timesteps, num_particles, num_features) containing
            the particle trajectories. Each entry in the array represents a particle's
            state at a given timestep.

        Returns
        -------
        jnp.ndarray
            One loss value per transition; aggregation is left to the caller.
        """
        sources = batch[: self.K - 1]
        targets_gt = batch[1:]

        # Time labels of the target snapshots: every sorted trajectory key except the first.
        target_times = jnp.array(sorted(self.dataset.trajectory.keys())[1:]).reshape(-1, 1)

        # The transport model is vmapped over particles: move the particle axis to the front,
        # apply the maps, then restore the (step, particle, feature) layout.
        per_particle_map = jax.vmap(self.model_otmaps.apply, in_axes=(None, 0))
        pushed = per_particle_map({"params": otmap_params}, jnp.swapaxes(sources, 0, 1))
        targets_predicted = jnp.swapaxes(pushed, 1, 0)

        # Snapshots and time labels are mapped together; the parameters are shared.
        energy_per_step = jax.vmap(self._compute_energy_single_step, in_axes=(0, None, None, 0))
        energy_predicted = energy_per_step(targets_predicted, potential_params, interaction_params, target_times)
        energy_gt = energy_per_step(targets_gt, potential_params, interaction_params, target_times)

        return energy_gt - energy_predicted
