from collections.abc import Iterator

from flax import nnx

from offline.bppo.tc.iql.core import (
    high_level_train_step,
    high_level_train_step_with_target_update,
)
from offline.bppo.tc.iql.types import HLQLearningBatch, HighLevelTrainState
from offline.lbp.tc.types import AssignmentBatch
from offline.modules.base import TargetModel
from offline.modules.critic import VCritic
from offline.modules.mlp import MLP
from offline.utils.logger import Logger
from offline.utils.tqdm import trange


def high_level_train_fn(
    expectile: float,
    gamma: float,
    logger: Logger,
    optimizer_qcritic: nnx.Optimizer,
    optimizer_vcritic: nnx.Optimizer,
    q_learning_iter: Iterator[HLQLearningBatch],
    qcritic: MLP,
    steps: int,
    update_every: int,
    v_learning_iter: Iterator[AssignmentBatch],
    vcritic: VCritic,
) -> tuple[MLP, VCritic]:
    """Run the high-level Q/V critic training loop for ``steps`` iterations.

    The critics, their optimizers and a ``TargetModel`` copy of ``qcritic``
    are bundled into a ``HighLevelTrainState`` and split once into
    ``(graphdef, graphstate)``. The state is threaded functionally through
    the step functions imported from ``offline.bppo.tc.iql.core``.

    Args:
        expectile: Forwarded unchanged to the step function.
        gamma: Forwarded unchanged to the step function.
        logger: Receives ``logger.write(step, **results)`` after every step.
        optimizer_qcritic: Optimizer stored in the train state for ``qcritic``.
        optimizer_vcritic: Optimizer stored in the train state for ``vcritic``.
        q_learning_iter: Yields one ``HLQLearningBatch`` per step. It supplies
            observations, assignments, rewards, next_observations and dones.
        qcritic: Q critic. It also seeds the target model.
        steps: Number of training iterations.
        update_every: On every ``update_every``-th step (counting from 1), the
            ``..._with_target_update`` step variant is used instead of the
            plain one. Must be >= 1.
        v_learning_iter: Yields one ``AssignmentBatch`` per step. It supplies
            the observations and assignments for the V-critic update.
        vcritic: V critic.

    Returns:
        ``(qcritic, vcritic)`` taken from the train state reconstructed with
        ``nnx.merge`` after the final step. Callers should use these returned
        modules.

    Raises:
        ValueError: If ``update_every < 1``.
        StopIteration: If either batch iterator is exhausted before ``steps``
            iterations have run.
    """
    # Fail fast: the modulo below would otherwise raise ZeroDivisionError
    # mid-loop (after setup work) with no hint about which argument is bad.
    if update_every < 1:
        raise ValueError(f"update_every must be >= 1, got {update_every}")

    graphdef, graphstate = nnx.split(
        HighLevelTrainState(
            qcritic=qcritic,
            optimizer_qcritic=optimizer_qcritic,
            optimizer_vcritic=optimizer_vcritic,
            target_qcritic=TargetModel(qcritic),
            vcritic=vcritic,
        )
    )
    for step in trange(steps, desc="HL"):
        batch_qcritic = next(q_learning_iter)
        batch_vcritic = next(v_learning_iter)
        # Both step variants share one signature. Select the function once so
        # the argument list exists in a single place and cannot drift.
        train_step = (
            high_level_train_step_with_target_update
            if (step + 1) % update_every == 0
            else high_level_train_step
        )
        graphstate, results = train_step(
            assignments=batch_qcritic.assignments,
            assignments_vcritic=batch_vcritic.assignments,
            dones=batch_qcritic.dones,
            expectile=expectile,
            gamma=gamma,
            graphdef=graphdef,
            graphstate=graphstate,
            next_observations=batch_qcritic.next_observations,
            observations=batch_qcritic.observations,
            observations_vcritic=batch_vcritic.observations,
            rewards=batch_qcritic.rewards,
        )
        logger.write(step, **results)
    train_state = nnx.merge(graphdef, graphstate)
    return train_state.qcritic, train_state.vcritic
