from collections.abc import Iterator
from dataclasses import dataclass, replace
from typing import TypeVar

from flax import nnx
from jax import Array

from offline import base
from offline.td3bc.arguments import Arguments
from offline.td3bc.core import train_actor_critic_step, train_critic_step
from offline.td3bc.modules import TD3BCPolicy, TD3BCTrainState
from offline.types import SaBatch, QLearningBatch


@dataclass(frozen=True)
class TrainerState(base.TrainerState[None]):
    """Immutable bundle of everything one TD3BC training step reads.

    The train state is held in split form (``graphdef`` + ``graphstate``)
    and is only re-assembled on demand, e.g. by :attr:`policy`.
    """

    # Graph definition half of the split ``TD3BCTrainState``.
    graphdef: nnx.GraphDef[TD3BCTrainState]
    # State half of the split train state; ``train_fn`` swaps this field
    # for the value returned by the step functions.
    graphstate: nnx.GraphState | nnx.VariableState
    # Source of batches for the actor part of an update.
    sa_iter: Iterator[SaBatch]
    # Source of batches for the critic part of an update.
    qlearning_iter: Iterator[QLearningBatch]
    # Forwarded as ``train_key`` to the step functions
    # (presumably a JAX PRNG key -- confirm against the step functions).
    train_key: Array

    @property
    def policy(self) -> TD3BCPolicy:
        """Merge ``graphdef`` and ``graphstate`` and return the policy.

        The merge is performed on every access; nothing is cached.
        """
        return nnx.merge(self.graphdef, self.graphstate).policy


# Bound type variable so that train_fn's return annotation matches the exact
# TrainerState subclass it was given.
T = TypeVar("T", bound=TrainerState)


def train_fn(step: int, args: Arguments, state: T) -> T:
    """Run a single training step and return the updated trainer state.

    A Q-learning batch is drawn on every call. When ``step + 1`` is a
    multiple of ``args.update_every`` a state-action batch is drawn as well
    and the joint actor-critic step runs; on all other calls only the
    critic step runs. The results returned by the step function are passed
    to ``args.logger.write`` as keyword arguments.

    Args:
        step: Index of the current training step.
        args: Hyperparameters and the logger used for this run.
        state: Current trainer state. It is not modified.

    Returns:
        A copy of ``state`` whose ``graphstate`` is the one produced by the
        step function; all other fields are carried over unchanged.
    """
    q_batch = next(state.qlearning_iter)
    critic_only = (step + 1) % args.update_every != 0

    # NOTE(review): state.train_key is passed unchanged on every call and is
    # never advanced here; presumably the step functions derive per-step
    # randomness from `step` -- confirm.
    if critic_only:
        new_graphstate, metrics = train_critic_step(
            # model
            graphdef=state.graphdef,
            graphstate=state.graphstate,
            # critic transition batch
            observations=q_batch.observations,
            actions=q_batch.actions,
            rewards=q_batch.rewards,
            next_observations=q_batch.next_observations,
            dones=q_batch.dones,
            # hyperparameters
            gamma=args.gamma,
            policy_noise=args.policy_noise,
            noise_clip=args.noise_clip,
            # bookkeeping
            step=step,
            train_key=state.train_key,
        )
    else:
        pi_batch = next(state.sa_iter)
        new_graphstate, metrics = train_actor_critic_step(
            # model
            graphdef=state.graphdef,
            graphstate=state.graphstate,
            # critic transition batch
            observations=q_batch.observations,
            actions=q_batch.actions,
            rewards=q_batch.rewards,
            next_observations=q_batch.next_observations,
            dones=q_batch.dones,
            # actor batch
            observations_actor=pi_batch.observations,
            actions_actor=pi_batch.actions,
            # hyperparameters
            alpha=args.alpha,
            gamma=args.gamma,
            tau=args.tau,
            policy_noise=args.policy_noise,
            noise_clip=args.noise_clip,
            # bookkeeping
            step=step,
            train_key=state.train_key,
        )

    args.logger.write(step, **metrics)
    return replace(state, graphstate=new_graphstate)
