from collections.abc import Iterator
from dataclasses import dataclass, replace
from typing import TypeVar

from flax import nnx
from jax import Array

from offline import base
from offline.bppo.arguments import Arguments
from offline.bppo.core import (
    sarsa_step,
    sarsa_step_with_target_update,
    train_step,
)
from offline.bppo.modules import BPPOPolicy
from offline.bppo.types import BPPOBatch
from offline.modules.base import TargetModel, TrainState, TrainStateWithTarget
from offline.modules.critic import QCritic
from offline.types import SarsaBatch
from offline.utils.data import DataLoader
from offline.utils.logger import Logger
from offline.utils.tqdm import tqdm, trange


@dataclass(frozen=True)
class TrainerState(base.TrainerState[None]):
    """Immutable snapshot of BPPO policy-training progress.

    The policy's ``TrainState`` is kept in split form (``graphdef`` plus
    ``graphstate``) so that updated states can be swapped in with
    ``dataclasses.replace`` without mutating this object.
    """

    # Source of BPPO batches iterated once per call to ``train_fn``.
    bppo_loader: DataLoader[BPPOBatch]
    # Static graph structure of the TrainState that wraps the BPPO policy.
    graphdef: nnx.GraphDef[TrainState[BPPOPolicy]]
    # State paired with ``graphdef``; the two are recombined via ``nnx.merge``.
    graphstate: nnx.GraphState | nnx.VariableState
    # Key passed unchanged to every ``train_step`` call
    # (presumably a JAX PRNG key -- confirm against core.train_step).
    train_key: Array
    # Running count of gradient steps, carried from one epoch to the next.
    step: int = 0

    @property
    def policy(self) -> BPPOPolicy:
        """Reassemble the train state from its split form and return its model."""
        return nnx.merge(self.graphdef, self.graphstate).model


# Generic "some TrainerState subclass" type, so ``train_fn`` can return the
# same concrete state type it was given.
T = TypeVar("T", bound=TrainerState)


def sarsa_fn(
    gamma: float,
    logger: Logger,
    optimizer: nnx.Optimizer,
    qcritic: QCritic,
    sarsa_iter: Iterator[SarsaBatch],
    steps: int,
    tau: float,
    update_every: int,
) -> QCritic:
    """Fit ``qcritic`` on SARSA transitions and return the trained critic.

    Args:
        gamma: Discount factor forwarded to the SARSA step functions.
        logger: Receives the per-step metrics returned by each step.
        optimizer: Optimizer bundled with the critic in the train state.
        qcritic: Critic to train; also used to build the initial target.
        sarsa_iter: Batch source; ``next`` is called once per step.
        steps: Total number of SARSA update steps.
        tau: Coefficient passed only to the target-updating step variant.
        update_every: Every ``update_every``-th step (1-based) uses
            ``sarsa_step_with_target_update``; all others use ``sarsa_step``.
            A value of 0 raises ``ZeroDivisionError`` on the first step.

    Returns:
        The ``model`` of the final merged train state.
    """
    graphdef, graphstate = nnx.split(
        TrainStateWithTarget(
            model=qcritic,
            optimizer=optimizer,
            target=TargetModel(qcritic),
        )
    )

    for idx in trange(steps, desc="SARSA"):
        batch = next(sarsa_iter)
        # Arguments shared by both step variants.
        common = dict(
            actions=batch.actions,
            dones=batch.dones,
            gamma=gamma,
            graphdef=graphdef,
            graphstate=graphstate,
            next_actions=batch.next_actions,
            next_observations=batch.next_observations,
            observations=batch.observations,
            rewards=batch.rewards,
        )
        refresh_target = (idx + 1) % update_every == 0
        if refresh_target:
            graphstate, metrics = sarsa_step_with_target_update(**common, tau=tau)
        else:
            graphstate, metrics = sarsa_step(**common)
        logger.write(idx, **metrics)

    trained: TrainStateWithTarget[QCritic] = nnx.merge(graphdef, graphstate)
    return trained.model


def train_fn(step: int, args: Arguments, state: T) -> T:
    """Run one epoch of BPPO policy updates over ``state.bppo_loader``.

    Args:
        step: Epoch index, used in the progress-bar label and as the step
            passed to ``args.logger.write``.
        args: Hyperparameters (``clip_epsilon``, ``entropy_weight``,
            ``omega``) and the ``logger``.
        state: Current trainer state; it is not mutated.

    Returns:
        A copy of ``state`` with the updated ``graphstate``. Its ``step`` is
        advanced by exactly the number of batches processed, so it is left
        unchanged when the loader yields no batches.
    """
    graphstate = state.graphstate
    # Global gradient-step counter; incremented once per processed batch so an
    # empty loader does not advance it.
    step_grad = state.step
    for batch in tqdm(state.bppo_loader, desc=f"Epoch {step}", leave=False):
        graphstate, results = train_step(
            clip_epsilon=args.clip_epsilon,
            entropy_weight=args.entropy_weight,
            graphdef=state.graphdef,
            graphstate=graphstate,
            means=batch.means,
            observations=batch.observations,
            omega=args.omega,
            stds=batch.stds,
            step=step_grad,
            # NOTE(review): the same key is passed on every step; presumably
            # train_step derives per-step randomness from `step` -- confirm.
            train_key=state.train_key,
            values=batch.values,
        )
        # NOTE(review): all batches of this epoch are logged under the epoch
        # index `step`, not `step_grad`; confirm this is intended.
        args.logger.write(step, **results)
        step_grad += 1
    return replace(state, graphstate=graphstate, step=step_grad)
