from flax import nnx
from jax import Array

from offline.diffusion.ddim.modules.conditional import ConditionalDDIM
from offline.diffusion.modules import DiffusionPolicy


class DDIMPolicy(DiffusionPolicy):
    """Policy that delegates action generation to a ``ConditionalDDIM`` model.

    The wrapped model lives on ``self.diffusion``; calling the policy forwards
    to that model's ``ddim`` method.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        action_dim: int,
        beta_schedule: str,
        clip_sample: bool,
        diffusion_steps: int,
        eta: float,
        inference_steps: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        temperature: float,
        time_dim: int,
        timestep_spacing: str,
        **kwargs
    ):
        """Construct the underlying ``ConditionalDDIM`` and store it.

        ``observation_dim`` is handed to the model as ``condition_dim`` and
        ``action_dim`` as ``sample_dim``; every other named argument is
        forwarded under its own name. Any extra ``**kwargs`` are passed
        through to ``ConditionalDDIM`` unchanged.

        As the pylint pragma on the signature records, the parent class
        initializer is not invoked here.
        """
        # Collected in the same keyword order the model is called with.
        ddim_settings = {
            "beta_schedule": beta_schedule,
            "clip_sample": clip_sample,
            # The policy conditions on observations ...
            "condition_dim": observation_dim,
            "diffusion_steps": diffusion_steps,
            "eta": eta,
            "inference_steps": inference_steps,
            "rngs": rngs,
            # ... and the samples it produces are actions.
            "sample_dim": action_dim,
            "temperature": temperature,
            "time_dim": time_dim,
            "timestep_spacing": timestep_spacing,
        }
        self.diffusion = ConditionalDDIM(**ddim_settings, **kwargs)

    def __call__(self, observations: Array, state: tuple[Array, int]):
        """Run ``self.diffusion.ddim`` on ``observations`` and ``state``.

        Returns a 3-tuple: the two values produced by ``ddim`` (actions and
        the updated state) followed by an empty dict.
        """
        # NOTE(review): the trailing empty dict is presumably an info/metrics
        # slot expected by callers of the policy interface -- confirm.
        sampled_actions, next_state = self.diffusion.ddim(observations, state)
        return sampled_actions, next_state, {}
