from collections.abc import Iterator
from dataclasses import dataclass, replace
from typing import TypeVar

from flax import nnx

from offline import base
from offline.bc.arguments import Arguments
from offline.bc.core import train_step
from offline.bc.modules import BCPolicy
from offline.modules.base import TrainState
from offline.types import SaBatch


@dataclass(frozen=True)
class TrainerState(base.TrainerState[None]):
    """Immutable trainer state for behaviour cloning.

    The ``TrainState[BCPolicy]`` is held in split form (graph definition plus
    graph state). ``train_fn`` advances training by building a new instance
    with :func:`dataclasses.replace` that carries an updated ``graphstate``.
    """

    # Static structure of the split TrainState[BCPolicy] (consumed by nnx.merge).
    graphdef: nnx.GraphDef[TrainState[BCPolicy]]
    # State half of the split TrainState; swapped out after every train step.
    graphstate: nnx.GraphState | nnx.VariableState
    # Source of training batches; each batch exposes `observations` and `actions`.
    sa_iter: Iterator[SaBatch]

    @property
    def policy(self) -> BCPolicy:
        """Rebuild the train state from its split parts and return its model.

        A fresh merge happens on every access, so the returned object reflects
        the current ``graphstate``.
        """
        return nnx.merge(self.graphdef, self.graphstate).model


# Type variable for `TrainerState` and its subclasses: lets `train_fn` promise
# to return the same concrete state type it was given.
T = TypeVar("T", bound=TrainerState)


def train_fn(step: int, args: Arguments, state: T) -> T:
    """Perform a single behaviour-cloning update.

    Args:
        step: Step index passed to the logger alongside the metrics.
        args: Run arguments; only ``args.logger`` is used here.
        state: Current trainer state; its ``sa_iter`` is advanced by one batch.

    Returns:
        A copy of ``state`` (same concrete type) whose ``graphstate`` is the one
        produced by ``train_step``. All other fields are carried over unchanged.

    Raises:
        StopIteration: If ``state.sa_iter`` has no more batches.
    """
    sa_batch = next(state.sa_iter)

    # Bind the operands in the same order the original keyword call evaluated them.
    actions, graphdef, graphstate, observations = (
        sa_batch.actions,
        state.graphdef,
        state.graphstate,
        sa_batch.observations,
    )
    new_graphstate, metrics = train_step(
        observations=observations,
        actions=actions,
        graphstate=graphstate,
        graphdef=graphdef,
    )

    # `metrics` is keyword-expanded, so it must be a mapping with string keys.
    args.logger.write(step, **metrics)
    return replace(state, graphstate=new_graphstate)
