from collections.abc import Iterator
from dataclasses import dataclass, replace
from typing import TypeVar

from flax import nnx
from jax import Array

from offline import base
from offline.modules.actor.ensemble import GaussianActorEnsembleWithIndices
from offline.bppo.tc.arguments import Arguments
from offline.bppo.tc.core import (
    behavior_cloning_step,
    high_level_train_step,
    high_level_train_step_with_target_update,
    sarsa_step,
    sarsa_step_with_target_update,
    train_step,
)
from offline.bppo.tc.modules import BPPOTCPolicy, QCritic
from offline.bppo.tc.types import (
    BCBatch,
    BPPOTCBatch,
    HLQLearningBatch,
    SarsaBatch,
)
from offline.modules.base import TargetModel, TrainState, TrainStateWithTarget
from offline.modules.mlp import MLP
from offline.utils.data import DataLoader
from offline.utils.logger import Logger
from offline.utils.tqdm import tqdm, trange


@dataclass(frozen=True)
class TrainerState(base.TrainerState[None]):
    """Immutable trainer state for the BPPO-TC policy.

    The train state is kept in split form (``graphdef`` + ``graphstate``)
    and is only re-assembled on demand through :attr:`policy`.
    """

    # Loader iterated by ``train_fn`` to obtain the training batches.
    bppo_tc_loader: DataLoader[BPPOTCBatch]
    # Graph definition half of the split ``TrainState``.
    graphdef: nnx.GraphDef[TrainState[BPPOTCPolicy]]
    # State half of the split ``TrainState``.
    graphstate: nnx.GraphState | nnx.VariableState
    # Key forwarded to every ``train_step`` call.
    train_key: Array
    # Gradient-step counter from which ``train_fn`` resumes numbering.
    step: int = 0

    @property
    def policy(self) -> BPPOTCPolicy:
        """Return the model of the train state rebuilt from its two halves."""
        return nnx.merge(self.graphdef, self.graphstate).model


# Type variable bound to TrainerState so that functions taking a state
# (e.g. ``train_fn``) are typed as returning the same subclass they received.
T = TypeVar("T", bound=TrainerState)


def behavior_cloning_fn(
    actor: GaussianActorEnsembleWithIndices,
    logger: Logger,
    optimizer: nnx.Optimizer,
    sa_iter: Iterator[BCBatch],
    steps: int,
) -> GaussianActorEnsembleWithIndices:
    """Run ``steps`` behavior-cloning updates on ``actor``.

    The actor and optimizer are wrapped in a ``TrainState`` and split once;
    each iteration pulls one batch from ``sa_iter``, calls
    ``behavior_cloning_step`` and writes the returned results to ``logger``
    under the iteration index.

    Args:
        actor: Model to train.
        logger: Receives the per-step results.
        optimizer: Optimizer stored alongside the actor in the train state.
        sa_iter: Source of batches; one batch is consumed per step.
        steps: Number of update steps to run.

    Returns:
        The model of the train state merged back after the last step.

    Raises:
        StopIteration: If ``sa_iter`` runs out before ``steps`` batches.
    """
    definition, params = nnx.split(
        TrainState(model=actor, optimizer=optimizer)
    )
    for iteration in trange(steps, desc="BC"):
        sample = next(sa_iter)
        params, metrics = behavior_cloning_step(
            actions=sample.actions,
            assignments=sample.assignments,
            graphdef=definition,
            graphstate=params,
            observations=sample.observations,
        )
        logger.write(iteration, **metrics)
    # Re-assemble the train state only once, after all updates are done.
    return nnx.merge(definition, params).model


def high_level_train_fn(
    gamma: float,
    logger: Logger,
    optimizer: nnx.Optimizer,
    q_learning_iter: Iterator[HLQLearningBatch],
    qcritic: MLP,
    steps: int,
    update_every: int,
) -> MLP:
    """Train the high-level Q critic for ``steps`` updates.

    ``qcritic`` is wrapped, together with ``optimizer`` and a
    ``TargetModel(qcritic)``, in a ``TrainStateWithTarget`` that is split
    once. Each iteration consumes one batch from ``q_learning_iter`` and
    writes the step's results to ``logger`` under the iteration index.

    Args:
        gamma: Forwarded to the step function as ``gamma``.
        logger: Receives the per-step results.
        optimizer: Optimizer stored alongside the critic.
        q_learning_iter: Source of batches; one batch is consumed per step.
        qcritic: Critic model to train.
        steps: Number of update steps to run.
        update_every: Every ``update_every``-th step (counting from 1) uses
            ``high_level_train_step_with_target_update`` instead of
            ``high_level_train_step``.

    Returns:
        The model of the train state merged back after the last step.

    Raises:
        StopIteration: If ``q_learning_iter`` runs out of batches early.
        ZeroDivisionError: If ``update_every`` is 0 and ``steps`` > 0.
    """
    graphdef, graphstate = nnx.split(
        TrainStateWithTarget(
            model=qcritic, optimizer=optimizer, target=TargetModel(qcritic)
        )
    )
    for step in trange(steps, desc="HL"):
        batch = next(q_learning_iter)
        # Both variants take exactly the same keyword arguments, so pick the
        # function first and issue a single call.
        if (step + 1) % update_every == 0:
            step_fn = high_level_train_step_with_target_update
        else:
            step_fn = high_level_train_step
        graphstate, results = step_fn(
            assignments=batch.assignments,
            dones=batch.dones,
            gamma=gamma,
            graphdef=graphdef,
            graphstate=graphstate,
            next_mask=batch.next_mask,
            next_observations=batch.next_observations,
            observations=batch.observations,
            rewards=batch.rewards,
        )
        logger.write(step, **results)
    train_state = nnx.merge(graphdef, graphstate)
    return train_state.model


def sarsa_fn(
    gamma: float,
    logger: Logger,
    optimizer: nnx.Optimizer,
    qcritic: QCritic,
    sarsa_iter: Iterator[SarsaBatch],
    steps: int,
    tau: float,
    update_every: int,
) -> QCritic:
    """Train ``qcritic`` with ``steps`` SARSA updates.

    ``qcritic`` is wrapped, together with ``optimizer`` and a
    ``TargetModel(qcritic)``, in a ``TrainStateWithTarget`` that is split
    once. Each iteration consumes one batch from ``sarsa_iter`` and writes
    the step's results to ``logger`` under the iteration index.

    Args:
        gamma: Forwarded to the step function as ``gamma``.
        logger: Receives the per-step results.
        optimizer: Optimizer stored alongside the critic.
        qcritic: Critic model to train.
        sarsa_iter: Source of batches; one batch is consumed per step.
        steps: Number of update steps to run.
        tau: Passed only to ``sarsa_step_with_target_update``.
        update_every: Every ``update_every``-th step (counting from 1) uses
            ``sarsa_step_with_target_update`` instead of ``sarsa_step``.

    Returns:
        The model of the train state merged back after the last step.

    Raises:
        StopIteration: If ``sarsa_iter`` runs out of batches early.
        ZeroDivisionError: If ``update_every`` is 0 and ``steps`` > 0.
    """
    wrapped = TrainStateWithTarget(
        model=qcritic,
        optimizer=optimizer,
        target=TargetModel(qcritic),
    )
    definition, params = nnx.split(wrapped)
    for iteration in trange(steps, desc="SARSA"):
        sample = next(sarsa_iter)
        refresh_target = (iteration + 1) % update_every == 0
        # Keyword arguments common to both step variants.
        shared = dict(
            actions=sample.actions,
            assignments=sample.assignments,
            dones=sample.dones,
            gamma=gamma,
            graphdef=definition,
            graphstate=params,
            next_actions=sample.next_actions,
            next_observations=sample.next_observations,
            observations=sample.observations,
            rewards=sample.rewards,
        )
        if not refresh_target:
            params, metrics = sarsa_step(**shared)
        else:
            # Only the target-update variant takes ``tau``.
            params, metrics = sarsa_step_with_target_update(
                **shared, tau=tau
            )
        logger.write(iteration, **metrics)
    return nnx.merge(definition, params).model


def train_fn(step: int, args: Arguments, state: T) -> T:
    """Run one pass over ``state.bppo_tc_loader`` and return the new state.

    Every batch is fed to ``train_step`` with a gradient-step index that
    continues from ``state.step``. The results of each batch are written to
    ``args.logger``.

    Args:
        step: Epoch index; used in the progress-bar label and as the index
            passed to ``args.logger.write``.
        args: Supplies ``clip_epsilon``, ``entropy_weight``, ``omega`` and
            ``logger``.
        state: Current trainer state.

    Returns:
        A copy of ``state`` with the updated ``graphstate`` and with ``step``
        advanced by the number of batches processed. If the loader yields
        no batches, ``step`` is left unchanged.
    """
    graphstate = state.graphstate
    # Tracked explicitly so that an empty loader does not advance the
    # counter (``step_grad + 1`` after a loop that never ran would).
    next_step = state.step
    progress = tqdm(state.bppo_tc_loader, desc=f"Epoch {step}", leave=False)
    for step_grad, batch in enumerate(progress, start=state.step):
        graphstate, results = train_step(
            assignments=batch.assignments,
            clip_epsilon=args.clip_epsilon,
            entropy_weight=args.entropy_weight,
            graphdef=state.graphdef,
            graphstate=graphstate,
            means=batch.means,
            observations=batch.observations,
            omega=args.omega,
            stds=batch.stds,
            step=step_grad,
            train_key=state.train_key,
            values=batch.values,
        )
        # NOTE(review): results are logged under the epoch index ``step``,
        # not ``step_grad`` — confirm this is the intended logger key.
        args.logger.write(step, **results)
        next_step = step_grad + 1
    return replace(state, graphstate=graphstate, step=next_step)
