from collections.abc import Iterator
from dataclasses import dataclass, replace
from typing import Generic, TypeVar

from flax import nnx
from jax import Array
import numpy as np
from numpy.typing import NDArray

from offline import base
from offline.diffusion.arguments import Arguments
from offline.diffusion.repaint.core import train_step
from offline.diffusion.repaint.modules import InpaintPolicy
from offline.modules.base import TrainState


# Type of the concrete policy module held inside the split TrainState; restricted
# to InpaintPolicy and its subclasses.
PolicyT = TypeVar("PolicyT", bound=InpaintPolicy)


@dataclass(frozen=True)
class TrainerState(base.TrainerState[tuple[Array, int]], Generic[PolicyT]):
    """Immutable trainer state for an ``InpaintPolicy`` model.

    The model lives in split ``nnx`` form (``graphdef`` + ``graphstate``), so the
    state is a frozen value object: advancing training means building a new
    instance (e.g. with ``dataclasses.replace``) rather than mutating this one.
    """

    # Iterator of training batches (NumPy floating arrays); one batch is pulled per step.
    data_iter: Iterator[NDArray[np.floating]]
    # Structural half of the split TrainState module.
    graphdef: nnx.GraphDef[TrainState[PolicyT]]
    # Variable/parameter half of the split TrainState module.
    graphstate: nnx.GraphState | nnx.VariableState
    # Arrays forwarded to train_step as-is; presumably JAX PRNG keys for noise and
    # timestep sampling respectively -- confirm against repaint.core.train_step.
    train_noise_key: Array
    train_time_key: Array

    @property
    def policy(self) -> PolicyT:
        """Return the ``model`` of the TrainState rebuilt from graphdef/graphstate.

        Every access performs a fresh ``nnx.merge``; hold on to the result if it
        is needed more than once.
        """
        return nnx.merge(self.graphdef, self.graphstate).model


# Any TrainerState (or subclass) parameterised with an InpaintPolicy; used by
# train_fn so that it returns the same concrete state type it was given.
T = TypeVar("T", bound=TrainerState[InpaintPolicy])


def train_fn(step: int, args: Arguments, state: T) -> T:
    """Advance ``state`` by a single training step.

    Pulls the next batch from ``state.data_iter``, passes it to ``train_step``
    together with the split model and keys, logs the returned metrics at
    ``step`` through ``args.logger``, and returns a copy of ``state`` holding the
    updated ``graphstate``. The incoming state object is not modified, but its
    data iterator is advanced.

    Args:
        step: Training step index; forwarded to ``train_step`` and used as the
            logging step.
        args: Run configuration; provides ``diffusion_steps`` and ``logger``.
        state: Current trainer state.

    Returns:
        A new state of the same type with ``graphstate`` replaced.

    Raises:
        StopIteration: If ``state.data_iter`` is exhausted.
    """
    # NOTE(review): the two keys are forwarded unchanged on every call; presumably
    # train_step derives per-step randomness from `step` -- confirm in repaint.core.
    new_graphstate, metrics = train_step(
        batch=next(state.data_iter),
        step=step,
        diffusion_steps=args.diffusion_steps,
        graphdef=state.graphdef,
        graphstate=state.graphstate,
        train_noise_key=state.train_noise_key,
        train_time_key=state.train_time_key,
    )
    args.logger.write(step, **metrics)
    return replace(state, graphstate=new_graphstate)
