from collections.abc import Iterator
from dataclasses import dataclass, replace
from typing import TypeVar

from flax import nnx
from jax import Array

from offline import base
from offline.hdr.arguments import Arguments
from offline.hdr.types import ActorBatch, QLearningBatch
from offline.hdr.core import train_actor_critic_step, train_critic_step
from offline.hdr.modules import HDRPolicy, HDRTrainState


@dataclass(frozen=True)
class TrainerState(base.TrainerState[None]):
    """Immutable bundle of everything ``train_fn`` needs for one step.

    The model is stored in split form (``graphdef`` + ``graphstate``) and is
    only merged back into an ``HDRTrainState`` when ``policy`` is read.
    """

    # Source of batches consumed on steps that call the actor-critic update.
    actor_iter: Iterator[ActorBatch]
    # Static half of the split HDRTrainState (consumed by nnx.merge).
    graphdef: nnx.GraphDef[HDRTrainState]
    # State half of the split HDRTrainState; the only field train_fn replaces.
    graphstate: nnx.GraphState
    # Forwarded unchanged to both step functions.
    # NOTE(review): presumably an out-of-distribution cutoff -- confirm in core.
    ood_threshold: float
    # Source of batches consumed on every training step.
    q_learning_iter: Iterator[QLearningBatch]
    # Keys forwarded to the actor-critic / critic step functions respectively.
    train_actor_key: Array
    train_critic_key: Array

    @property
    def policy(self) -> HDRPolicy:
        """Return the policy of the merged train state.

        The train state is re-merged from ``graphdef`` and ``graphstate`` on
        every access; nothing is cached on the instance.
        """
        return nnx.merge(self.graphdef, self.graphstate).policy


# Bound type variable: lets train_fn accept any TrainerState subclass and
# advertise that it returns that same subclass.
T = TypeVar("T", bound=TrainerState)


def train_fn(step: int, args: Arguments, state: T) -> T:
    """Run one training iteration and return the resulting trainer state.

    A Q-learning batch is drawn on every call. When ``(step + 1)`` is a
    multiple of ``args.update_every`` an actor batch is drawn as well and
    ``train_actor_critic_step`` is called; on all other steps only
    ``train_critic_step`` is called. The results returned by the step
    function are written to ``args.logger`` under ``step``.

    Args:
        step: Index of the current training step.
        args: Run configuration; ``gamma``, ``tau``, ``update_every`` and
            ``logger`` are read here.
        state: Current trainer state. Its two batch iterators are advanced
            as a side effect.

    Returns:
        A copy of ``state`` with ``graphstate`` replaced by the one returned
        from the step function. All other fields, including both PRNG keys,
        are carried over unchanged.
    """
    q_batch = next(state.q_learning_iter)

    # Draw the actor batch (if due) right after the critic batch so both
    # iterators are advanced before any step function runs.
    is_actor_step = (step + 1) % args.update_every == 0
    policy_batch = next(state.actor_iter) if is_actor_step else None

    # Keyword arguments shared by both step functions.
    shared_kwargs = {
        "actions": q_batch.actions,
        "dones": q_batch.dones,
        "gamma": args.gamma,
        "graphdef": state.graphdef,
        "graphstate": state.graphstate,
        "next_means": q_batch.next_means,
        "next_observations": q_batch.next_observations,
        "next_stds": q_batch.next_stds,
        "next_values": q_batch.next_values,
        "observations": q_batch.observations,
        "ood_threshold": state.ood_threshold,
        "rewards": q_batch.rewards,
        "step": step,
        "train_critic_key": state.train_critic_key,
    }

    if policy_batch is None:
        new_graphstate, metrics = train_critic_step(**shared_kwargs)
    else:
        # The joint step additionally needs the actor batch, the actor key
        # and tau.
        new_graphstate, metrics = train_actor_critic_step(
            **shared_kwargs,
            means_actor=policy_batch.means,
            observations_actor=policy_batch.observations,
            stds_actor=policy_batch.stds,
            tau=args.tau,
            train_actor_key=state.train_actor_key,
        )

    args.logger.write(step, **metrics)
    return replace(state, graphstate=new_graphstate)
