from collections.abc import Iterator
from dataclasses import dataclass, replace
from typing import TypeVar

from flax import nnx
from jax import Array

from offline import base
from offline.svr.arguments import Arguments
from offline.svr.core import (
    behavior_cloning_step,
    train_actor_critic_step,
    train_critic_step,
)
from offline.svr.modules import SVRPolicy, SVRTrainState
from offline.svr.types import SVRBatch
from offline.modules.actor.base import DeterministicActor
from offline.modules.base import TrainState
from offline.types import SaBatch
from offline.utils.logger import Logger
from offline.utils.tqdm import trange


@dataclass(frozen=True)
class TrainerState(base.TrainerState[None]):
    """Frozen container for the state threaded through SVR training steps.

    The train state is kept in split form (``graphdef`` + ``graphstate``);
    use :attr:`policy` to obtain the policy from the merged train state.
    """

    # Graph definition half of the split ``SVRTrainState``.
    graphdef: nnx.GraphDef[SVRTrainState]
    # State half of the split ``SVRTrainState``; ``train_fn`` replaces it
    # with the value returned by each training step.
    graphstate: nnx.GraphState | nnx.VariableState
    # Source of training batches; ``train_fn`` pulls one batch per call.
    svr_iter: Iterator[SVRBatch]
    # Forwarded unchanged to the training steps as ``targets_regularizer``.
    targets_regularizer: float
    # Forwarded unchanged to the training steps as ``train_key``
    # (presumably a JAX PRNG key -- confirm against ``offline.svr.core``).
    train_key: Array

    @property
    def policy(self) -> SVRPolicy:
        """Merge ``graphdef`` and ``graphstate`` and return the ``policy``."""
        return nnx.merge(self.graphdef, self.graphstate).policy


# Type variable bound to TrainerState: lets train_fn declare that it returns
# the same (sub)type of trainer state that it was given.
T = TypeVar("T", bound=TrainerState)


def behavior_cloning_fn(
    actor: DeterministicActor,
    logger: Logger,
    optimizer: nnx.Optimizer,
    sa_iter: Iterator[SaBatch],
    steps: int,
) -> DeterministicActor:
    """Run ``steps`` behavior-cloning updates and return the resulting actor.

    Args:
        actor: Actor wrapped, together with ``optimizer``, in a ``TrainState``.
        logger: Receives the results of every step via ``logger.write``.
        optimizer: Optimizer stored alongside ``actor`` in the ``TrainState``.
        sa_iter: Iterator of batches; one batch is consumed per step and its
            ``actions`` and ``observations`` are passed to the update.
        steps: Number of ``behavior_cloning_step`` calls to perform.

    Returns:
        The ``model`` attribute of the train state rebuilt from the final
        graph state.
    """
    bc_train_state = TrainState(model=actor, optimizer=optimizer)
    # Work on the split representation; only the state half changes per step.
    static_def, current_state = nnx.split(bc_train_state)

    for iteration in trange(steps, desc="BC"):
        sample = next(sa_iter)
        current_state, metrics = behavior_cloning_step(
            actions=sample.actions,
            graphdef=static_def,
            graphstate=current_state,
            observations=sample.observations,
        )
        logger.write(iteration, **metrics)

    return nnx.merge(static_def, current_state).model


def train_fn(step: int, args: Arguments, state: T) -> T:
    """Perform one training step and return the updated trainer state.

    One batch is drawn from ``state.svr_iter``. When ``step + 1`` is a
    multiple of ``args.update_every`` the batch is fed to
    ``train_actor_critic_step`` (which additionally receives ``args.alpha``
    and ``args.tau``); on every other step it is fed to
    ``train_critic_step``. The step's results are written to ``args.logger``.

    Args:
        step: Index of the current training step.
        args: Hyper-parameters and logger for the run.
        state: Current trainer state.

    Returns:
        A copy of ``state`` whose ``graphstate`` is the one returned by the
        executed training step; all other fields are unchanged.
    """
    batch = next(state.svr_iter)
    joint_update = (step + 1) % args.update_every == 0

    # Keyword arguments accepted by both kinds of training step.
    # NOTE(review): ``train_key`` is forwarded unchanged on every call;
    # presumably the step functions derive per-step randomness from ``step``
    # -- confirm in ``offline.svr.core``.
    shared_kwargs = dict(
        actions=batch.actions,
        dones=batch.dones,
        gamma=args.gamma,
        graphdef=state.graphdef,
        graphstate=state.graphstate,
        log_likelihoods=batch.log_likelihoods,
        next_observations=batch.next_observations,
        observations=batch.observations,
        regularizer_weight=args.regularizer_weight,
        rewards=batch.rewards,
        sample_std=args.sample_std,
        step=step,
        targets_regularizer=state.targets_regularizer,
        train_key=state.train_key,
    )

    if not joint_update:
        new_graphstate, metrics = train_critic_step(**shared_kwargs)
    else:
        new_graphstate, metrics = train_actor_critic_step(
            alpha=args.alpha,
            tau=args.tau,
            **shared_kwargs,
        )

    args.logger.write(step, **metrics)
    return replace(state, graphstate=new_graphstate)
