from flax import nnx
from jax import Array

from offline.diffusion.modules.conditional import ConditionalDDPM
from offline.modules.policy import Policy


class DiffusionPolicy(Policy[tuple[Array, int]]):
    """Policy that delegates action generation to a ``ConditionalDDPM``.

    The wrapped model is built with the observation dimension as its
    condition dimension and the action dimension as its sample dimension.
    The policy state is typed as ``tuple[Array, int]``; its contents are
    defined by ``ConditionalDDPM`` and are passed through untouched here.
    """

    def __init__(
        self,
        action_dim: int,
        beta_schedule: str,
        clip_sample: bool,
        diffusion_steps: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        temperature: float,
        time_dim: int,
        **kwargs
    ):
        """Build the underlying conditional diffusion model.

        Args:
            action_dim: Forwarded to ``ConditionalDDPM`` as ``sample_dim``.
            beta_schedule: Forwarded unchanged as ``beta_schedule``.
            clip_sample: Forwarded unchanged as ``clip_sample``.
            diffusion_steps: Forwarded unchanged as ``diffusion_steps``.
            observation_dim: Forwarded to ``ConditionalDDPM`` as
                ``condition_dim``.
            rngs: Forwarded unchanged as ``rngs``.
            temperature: Forwarded unchanged as ``temperature``.
            time_dim: Forwarded unchanged as ``time_dim``.
            **kwargs: Any extra keyword arguments, forwarded verbatim to
                ``ConditionalDDPM``.
        """
        # Observations condition the model; actions are what it samples.
        denoiser = ConditionalDDPM(
            beta_schedule=beta_schedule,
            clip_sample=clip_sample,
            condition_dim=observation_dim,
            diffusion_steps=diffusion_steps,
            rngs=rngs,
            sample_dim=action_dim,
            temperature=temperature,
            time_dim=time_dim,
            **kwargs
        )
        self.diffusion = denoiser

    def __call__(self, observations: Array, state: tuple[Array, int]):
        """Run the diffusion model on ``observations`` with ``state``.

        Args:
            observations: Conditioning input handed to the diffusion model.
            state: Policy state handed to the diffusion model as-is.

        Returns:
            A 3-tuple ``(actions, next_state, {})``: the two values produced
            by the diffusion model followed by an empty dict.
            NOTE(review): the empty dict is presumably the auxiliary-info
            slot of the ``Policy`` interface - confirm against ``Policy``.
        """
        sampled = self.diffusion(observations, state)
        actions, next_state = sampled
        return actions, next_state, {}
