from collections.abc import Iterator
from dataclasses import dataclass, replace
from typing import TypeVar

from flax import nnx
from jax import Array

from offline import base
from offline.lbp.arguments import Arguments
from offline.lbp.core import (
    behavior_cloning_step,
    pretrain_critic_step,
    pretrain_critic_step_with_target_update,
    train_actor_critic_step,
    train_critic_step,
    v_learning_step,
)
from offline.lbp.modules import LBPPolicy, LBPTrainState
from offline.lbp.types import ActorBatch, RegularizerBatch, QLearningBatch
from offline.modules.actor.base import GaussianActor
from offline.modules.base import TargetModel, TrainState, TrainStateWithTarget
from offline.modules.critic import QCriticEnsemble, VCritic
from offline.types import RegressionBatch, SaBatch
from offline.utils.logger import Logger
from offline.utils.tqdm import trange


@dataclass(frozen=True)
class TrainerState(base.TrainerState[None]):
    """Immutable LBP trainer state threaded through ``train_fn``.

    The train state is stored in split form (``graphdef`` + ``graphstate``)
    so it can be handed to the functional step functions; ``policy``
    re-merges it on demand.
    """

    # Stream of actor batches; consumed only on actor-update steps.
    actor_iter: Iterator[ActorBatch]
    # Static structure of the split ``LBPTrainState``.
    graphdef: nnx.GraphDef[LBPTrainState]
    # Variable half of the split ``LBPTrainState``; the only field
    # ``train_fn`` replaces.
    graphstate: nnx.GraphState | nnx.VariableState
    # Forwarded unchanged to the critic step functions.
    min_target: float
    # Forwarded unchanged to the critic step functions.
    ood_threshold: float
    # Stream of Q-learning batches; one is consumed per training step.
    q_learning_iter: Iterator[QLearningBatch]
    # Stream of regularizer batches; one is consumed per training step.
    regularizer_iter: Iterator[RegularizerBatch]
    # PRNG keys passed as-is to the step functions (never advanced here).
    train_actor_key: Array
    train_critic_key: Array
    train_regularizer_key: Array

    @property
    def policy(self) -> LBPPolicy:
        """Rebuild the train state from its split parts and expose its policy."""
        return nnx.merge(self.graphdef, self.graphstate).policy


# Bound to TrainerState so ``train_fn`` hands back the same (sub)type it got.
T = TypeVar("T", bound=TrainerState)


def behavior_cloning_fn(
    actor: GaussianActor,
    logger: Logger,
    optimizer: nnx.Optimizer,
    sa_iter: Iterator[SaBatch],
    steps: int,
) -> GaussianActor:
    """Train ``actor`` by behavior cloning on state/action batches.

    Args:
        actor: Gaussian policy to train.
        logger: Receives the metrics returned by each step.
        optimizer: Bundled with ``actor`` into a ``TrainState``.
        sa_iter: Batch source; one batch is consumed per step, so it must
            yield at least ``steps`` batches.
        steps: Number of training steps.

    Returns:
        The trained actor, taken from the final merged train state.
    """
    state_def, state_vars = nnx.split(
        TrainState(optimizer=optimizer, model=actor)
    )
    for i in trange(steps, desc="BC"):
        sa = next(sa_iter)
        state_vars, metrics = behavior_cloning_step(
            actions=sa.actions,
            graphdef=state_def,
            graphstate=state_vars,
            observations=sa.observations,
        )
        logger.write(i, **metrics)
    return nnx.merge(state_def, state_vars).model


def pretrain_critic_fn(
    gamma: float,
    logger: Logger,
    optimizer: nnx.Optimizer,
    qcritic: QCriticEnsemble,
    q_learning_iter: Iterator[QLearningBatch],
    steps: int,
    tau: float,
    update_every: int,
) -> QCriticEnsemble:
    """Pretrain ``qcritic`` on Q-learning batches.

    Every step calls either ``pretrain_critic_step`` or, on steps whose
    1-based index is a multiple of ``update_every``,
    ``pretrain_critic_step_with_target_update`` (the only one given ``tau``).

    Args:
        gamma: Forwarded to both step functions.
        logger: Receives the metrics returned by each step.
        optimizer: Bundled with ``qcritic`` into a ``TrainStateWithTarget``.
        qcritic: Critic ensemble to train; also wrapped in ``TargetModel``
            to build the target.
        q_learning_iter: Batch source; one batch is consumed per step, so it
            must yield at least ``steps`` batches.
        steps: Number of training steps.
        tau: Forwarded only to the target-updating step function.
        update_every: Period of the target-updating step; a value of 0
            raises ``ZeroDivisionError`` on the first step.

    Returns:
        The trained online critic (``model``) from the final merged state.
    """
    state_def, state_vars = nnx.split(
        TrainStateWithTarget(
            model=qcritic,
            optimizer=optimizer,
            target=TargetModel(qcritic),
        )
    )
    for i in trange(steps, desc="Pretrain"):
        batch = next(q_learning_iter)
        with_target_update = (i + 1) % update_every == 0
        # Arguments common to both step variants.
        step_kwargs = dict(
            actions=batch.actions,
            dones=batch.dones,
            gamma=gamma,
            graphdef=state_def,
            graphstate=state_vars,
            next_means=batch.next_means,
            next_observations=batch.next_observations,
            observations=batch.observations,
            rewards=batch.rewards,
        )
        if with_target_update:
            state_vars, metrics = pretrain_critic_step_with_target_update(
                tau=tau, **step_kwargs
            )
        else:
            state_vars, metrics = pretrain_critic_step(**step_kwargs)
        logger.write(i, **metrics)
    return nnx.merge(state_def, state_vars).model


def train_fn(step: int, args: Arguments, state: T) -> T:
    """Advance the LBP trainer by one step.

    Each call draws one Q-learning batch and one regularizer batch. On steps
    where ``step + 1`` is a multiple of ``args.update_every`` it also draws an
    actor batch and calls ``train_actor_critic_step``; otherwise it calls
    ``train_critic_step``. The returned metrics go to ``args.logger``.

    Args:
        step: Global step index, forwarded to the step functions and logger.
        args: Hyperparameters plus the logger.
        state: Current trainer state; its iterators are advanced in place.

    Returns:
        A copy of ``state`` with only ``graphstate`` replaced; the PRNG keys
        and scalar fields are carried over unchanged.
    """
    q_batch = next(state.q_learning_iter)
    reg_batch = next(state.regularizer_iter)
    n_noise = int(round(args.batch_size * args.noise_ratio))
    update_actor = (step + 1) % args.update_every == 0
    # Keyword arguments shared by the critic-only and actor+critic steps.
    shared = dict(
        actions=q_batch.actions,
        baseline_regularizer=reg_batch.baseline,
        dones=q_batch.dones,
        gamma=args.gamma,
        graphdef=state.graphdef,
        graphstate=state.graphstate,
        lipschitz_constant=args.lipschitz_constant,
        means_regularizer=reg_batch.means,
        min_target=state.min_target,
        next_means=q_batch.next_means,
        next_observations=q_batch.next_observations,
        noise_batch_size=n_noise,
        observations=q_batch.observations,
        observations_regularizer=reg_batch.observations,
        ood_threshold=state.ood_threshold,
        regularizer_weight=args.regularizer_weight,
        rewards=q_batch.rewards,
        std_multiplier=args.std_multiplier,
        stds_regularizer=reg_batch.stds,
        step=step,
        train_critic_key=state.train_critic_key,
        train_regularizer_key=state.train_regularizer_key,
    )
    if not update_actor:
        new_graphstate, metrics = train_critic_step(**shared)
    else:
        actor_batch = next(state.actor_iter)
        new_graphstate, metrics = train_actor_critic_step(
            means_actor=actor_batch.means,
            observations_actor=actor_batch.observations,
            tau=args.tau,
            train_actor_key=state.train_actor_key,
            **shared,
        )
    args.logger.write(step, **metrics)
    return replace(state, graphstate=new_graphstate)


def v_learning_fn(
    logger: Logger,
    optimizer: nnx.Optimizer,
    steps: int,
    v_learning_iter: Iterator[RegressionBatch],
    vcritic: VCritic,
) -> VCritic:
    """Fit ``vcritic`` to regression targets.

    Args:
        logger: Receives the metrics returned by each step.
        optimizer: Bundled with ``vcritic`` into a ``TrainState``.
        steps: Number of training steps.
        v_learning_iter: Batch source whose ``features`` are fed as
            observations; one batch is consumed per step, so it must yield
            at least ``steps`` batches.
        vcritic: Value network to train.

    Returns:
        The trained value network, taken from the final merged train state.
    """
    state_def, state_vars = nnx.split(
        TrainState(optimizer=optimizer, model=vcritic)
    )
    for i in trange(steps, desc="V"):
        regression = next(v_learning_iter)
        state_vars, metrics = v_learning_step(
            graphdef=state_def,
            graphstate=state_vars,
            observations=regression.features,
            targets=regression.targets,
        )
        logger.write(i, **metrics)
    return nnx.merge(state_def, state_vars).model
