from flax import nnx
import jax.random
from numpy.random import Generator, default_rng

from offline import base
from offline.bppo.tc.__main__ import init_behavior_fn, load_fn
from offline.bppo.tc.iql.arguments import Arguments, build_argument_parser
from offline.bppo.tc.iql.train import high_level_train_fn
from offline.bppo.tc.iql.utils import prepare_high_level_q_learning_dataset
from offline.bppo.tc.modules import BPPOTCPolicy
from offline.bppo.tc.train import TrainerState, train_fn
from offline.bppo.tc.utils import prepare_bppo_tc_dataset
from offline.modules.actor.ensemble import GaussianActorEnsembleWithIndices
from offline.modules.base import TrainState, get_optimizer
from offline.modules.critic import VCritic
from offline.modules.mlp import MLP
from offline.lbp.tc.types import AssignmentBatch
from offline.types import IntArray, OfflineData, OfflineDataWithInfos
from offline.utils.data import DataLoader, Dataset
from offline.utils.logger import Logger
from offline.utils.nnx import default_nnx_rngs


# Public export list for `from <this module> import *`: only `load_fn`, which is
# imported above from offline.bppo.tc.__main__ and re-exported here unchanged.
# NOTE(review): presumably listed so external tooling can import `load_fn` from
# this entry point (and so linters treat that import as used) — confirm.
__all__ = ["load_fn"]


def init_high_level_critic_fn(
    assignments: IntArray,
    batch_size: int,
    codebook_size: int,
    data: OfflineData,
    expectile: float,
    gamma: float,
    hidden_features: int,
    layer_norm: bool,
    learning_rate: float,
    logger: Logger,
    num_layers: int,
    numpy_rng: Generator,
    rngs: nnx.Rngs,
    steps: int,
    update_every: int,
) -> tuple[MLP, VCritic]:
    """Build, train and freeze the high-level Q and V critics.

    The Q network is an ``MLP`` mapping an observation to one output per
    codebook entry (``out_features=codebook_size``); the V network is a
    ``VCritic`` over observations.  Both are trained by
    ``high_level_train_fn`` and returned after ``.eval()`` has been called
    on each.

    Args:
        assignments: Per-sample codebook assignments; used both to build the
            Q-learning dataset and, paired with the observations, as the
            V-learning dataset.
        batch_size: Batch size of both loaders (incomplete final batches are
            dropped).
        codebook_size: Number of Q outputs.
        data: Offline transitions. ``data.observations.shape[1]`` is used as
            the input width, so this assumes 2-D observations — confirm.
        expectile: Forwarded to ``high_level_train_fn`` (presumably the IQL
            expectile for the V loss — confirm).
        gamma: Forwarded to ``high_level_train_fn``.
        hidden_features: Hidden width of both networks.
        layer_norm: Whether both networks use layer norm.
        learning_rate: Learning rate of both optimizers.
        logger: Forwarded to ``high_level_train_fn``.
        num_layers: Depth of both networks.
        numpy_rng: Shared shuffling RNG of both data loaders.
        rngs: NNX RNG streams used to initialise both networks.
        steps: Forwarded to ``high_level_train_fn``.
        update_every: Forwarded to ``high_level_train_fn``.

    Returns:
        ``(qcritic, vcritic)`` as returned by ``high_level_train_fn``, both
        switched to eval mode.
    """
    # Statement order below mirrors the original evaluation order on purpose:
    # `rngs` and `numpy_rng` are stateful, so reordering network or loader
    # construction could change the random streams each one sees.
    q_dataset = prepare_high_level_q_learning_dataset(
        assignments=assignments, data=data
    )
    observations = data.observations
    feature_dim = observations.shape[1]

    # Hyper-parameters shared by both networks.
    net_kwargs = dict(
        hidden_features=hidden_features,
        layer_norm=layer_norm,
        num_layers=num_layers,
        rngs=rngs,
    )
    q_net = MLP(in_features=feature_dim, out_features=codebook_size, **net_kwargs)
    v_net = VCritic(observation_dim=feature_dim, **net_kwargs)

    q_optimizer = get_optimizer(model=q_net, learning_rate=learning_rate)
    v_optimizer = get_optimizer(model=v_net, learning_rate=learning_rate)

    # Both loaders share one shuffling RNG and drop incomplete batches.
    loader_kwargs = dict(batch_size=batch_size, drop_last=True, rng=numpy_rng)
    q_batches = DataLoader(q_dataset, **loader_kwargs).repeat_forever()
    v_batches = DataLoader(
        Dataset(AssignmentBatch(assignments=assignments, observations=observations)),
        **loader_kwargs,
    ).repeat_forever()

    trained_q, trained_v = high_level_train_fn(
        expectile=expectile,
        gamma=gamma,
        logger=logger,
        optimizer_qcritic=q_optimizer,
        optimizer_vcritic=v_optimizer,
        q_learning_iter=q_batches,
        qcritic=q_net,
        steps=steps,
        update_every=update_every,
        v_learning_iter=v_batches,
        vcritic=v_net,
    )
    for critic in (trained_q, trained_v):
        critic.eval()
    return trained_q, trained_v


def init_fn(args: Arguments, data: OfflineDataWithInfos) -> TrainerState:
    """Build the initial BPPO-TC trainer state from parsed arguments and data.

    Steps, in order:
      1. Initialise the behaviour components (actor, assignments, classifier,
         Q critic, TC autoencoder, V critic) via ``init_behavior_fn``.
      2. Train the high-level Q/V critics over those assignments with
         ``init_high_level_critic_fn``; save both when ``args.save`` is set.
      3. Build the BPPO-TC dataset and its ``DataLoader``.
      4. Assemble a ``BPPOTCPolicy``, attach an optimizer over
         ``policy.actor``, and split the resulting ``TrainState`` with
         ``nnx.split``.

    Args:
        args: Parsed run configuration.
        data: Offline dataset wrapper; only ``data.data`` is read here.

    Returns:
        A ``TrainerState`` holding the loader, the split graph definition and
        state, no evaluation state, and a JAX key derived from ``args.seed``.

    Raises:
        NotImplementedError: If ``args.unsquash`` is false.
    """
    if not args.unsquash:
        raise NotImplementedError(
            "init_fn requires args.unsquash=True; "
            "other settings are not implemented."
        )

    # NOTE(review): the same `args.seed` seeds the NumPy generator, the NNX
    # Rngs and `train_key` below. If `default_nnx_rngs` derives its key the
    # same way, `train_key` may duplicate that stream — confirm.
    numpy_rng = default_rng(args.seed)
    rngs = default_nnx_rngs(args.seed)
    offline_data = data.data

    actor, assignments, classifier, qcritic, tc_autoencoder, vcritic = (
        init_behavior_fn(
            args=args, data=offline_data, numpy_rng=numpy_rng, rngs=rngs
        )
    )
    high_level_qcritic, high_level_vcritic = init_high_level_critic_fn(
        assignments=assignments,
        batch_size=args.batch_size,
        codebook_size=tc_autoencoder.codebook_size,
        data=offline_data,
        expectile=args.expectile,
        gamma=args.gamma,
        hidden_features=args.hidden_features,
        layer_norm=args.layer_norm,
        learning_rate=args.learning_rate,
        logger=args.logger,
        num_layers=args.num_layers,
        numpy_rng=numpy_rng,
        rngs=rngs,
        steps=args.high_level_steps,
        update_every=args.high_level_update_every,
    )
    if args.save:
        # The V critic is only persisted here; the policy uses the Q critic.
        args.logger.save_model("high_level_qcritic", model=high_level_qcritic)
        args.logger.save_model("high_level_vcritic", model=high_level_vcritic)

    bppo_tc_dataset = prepare_bppo_tc_dataset(
        actor=actor,
        assignments=assignments,
        observations=offline_data.observations,
        vcritic=vcritic,
    )
    bppo_tc_loader = DataLoader(
        bppo_tc_dataset,
        batch_size=args.batch_size,
        drop_last=True,
        rng=numpy_rng,
    )
    policy = BPPOTCPolicy(
        actor=actor,
        classifier=classifier,
        critic=qcritic,
        high_level_critic=high_level_qcritic,
        threshold=args.threshold,
    )
    # With `stochastic_update` the optimizer applies every step; otherwise
    # `every_k_schedule` is the number of batches in one pass over the loader
    # (presumably gradient accumulation over an epoch, if `get_optimizer`
    # wraps optax.MultiSteps — confirm).
    train_state = TrainState(
        model=policy,
        optimizer=get_optimizer(
            model=policy.actor,
            every_k_schedule=(
                1 if args.stochastic_update else len(bppo_tc_loader)
            ),
            learning_rate=args.learning_rate,
        ),
    )
    graphdef, graphstate = nnx.split(train_state)

    return TrainerState(
        bppo_tc_loader=bppo_tc_loader,
        eval_state=None,
        graphdef=graphdef,
        graphstate=graphstate,
        train_key=jax.random.key(args.seed),
    )


if __name__ == "__main__":
    # Parse the CLI first, then forward every parsed option as a keyword
    # argument to the shared runner alongside this script's init/train hooks.
    cli_options = vars(build_argument_parser().parse_args())
    base.run(
        arguments_class=Arguments,
        init_fn=init_fn,
        train_fn=train_fn,
        **cli_options,
    )
