from collections.abc import Iterator
from dataclasses import dataclass, replace
from typing import Generic, TypeVar

from flax import nnx
from jax import Array

from offline import base
from offline.diffusion.arguments import Arguments
from offline.diffusion.core import train_step
from offline.diffusion.modules import DiffusionPolicy
from offline.modules.base import TrainState
from offline.types import SaBatch


# Concrete DiffusionPolicy (sub)type of the model wrapped by a TrainerState's
# TrainState; returned by TrainerState.policy.
PolicyT = TypeVar("PolicyT", bound=DiffusionPolicy)


@dataclass(frozen=True)
class TrainerState(base.TrainerState[tuple[Array, int]], Generic[PolicyT]):
    """Immutable snapshot of diffusion-policy training state.

    The ``TrainState`` module, whose ``model`` attribute is the policy, is
    stored split into an nnx graph definition plus its variable state.
    Updates are expressed by building a new snapshot with
    ``dataclasses.replace`` rather than mutating this one.
    """

    # Static structure of the TrainState module; paired with ``graphstate``.
    graphdef: nnx.GraphDef[TrainState[PolicyT]]
    # Variable values matching ``graphdef``; swapped out after a train step.
    graphstate: nnx.GraphState | nnx.VariableState
    # Stream of batches exposing ``observations`` and ``actions``. Note that
    # ``dataclasses.replace`` shares (does not copy) this iterator, so every
    # derived snapshot advances the same underlying stream.
    sa_iter: Iterator[SaBatch]
    # Random keys forwarded unchanged to ``train_step`` (presumably JAX PRNG
    # keys for noise and diffusion-time sampling -- confirm in train_step).
    train_noise_key: Array
    train_time_key: Array

    @property
    def policy(self) -> PolicyT:
        """Rebuild the ``TrainState`` from its split form and return its model.

        ``nnx.merge`` runs on every access, so callers that need the policy
        repeatedly should keep a reference instead of re-reading it.
        """
        return nnx.merge(self.graphdef, self.graphstate).model


# Any TrainerState specialised over a DiffusionPolicy; lets train_fn return
# the same concrete state type it was given.
T = TypeVar("T", bound=TrainerState[DiffusionPolicy])


def train_fn(step: int, args: Arguments, state: T) -> T:
    """Run a single diffusion training step and return the updated state.

    The next batch is pulled from ``state.sa_iter`` and passed to
    ``train_step`` together with the split model and the random keys. The
    metrics returned by ``train_step`` are written to ``args.logger`` at
    ``step``.

    Args:
        step: Current training step. It is forwarded to ``train_step`` and
            used as the logging step.
        args: Run configuration. Only ``args.logger`` is read here.
        state: Current trainer state. Its iterator is advanced in place.

    Returns:
        A copy of ``state`` whose ``graphstate`` is the one produced by
        ``train_step``. Every other field is carried over unchanged,
        including the random keys and the (shared) iterator.
    """
    sa_batch = next(state.sa_iter)

    new_graphstate, metrics = train_step(
        graphdef=state.graphdef,
        graphstate=state.graphstate,
        observations=sa_batch.observations,
        actions=sa_batch.actions,
        step=step,
        train_noise_key=state.train_noise_key,
        train_time_key=state.train_time_key,
    )

    args.logger.write(step, **metrics)
    return replace(state, graphstate=new_graphstate)
