from flax import nnx
import jax.random
from numpy.random import Generator, default_rng

from offline import base
from offline.bppo.tc.arguments import Arguments, build_argument_parser
from offline.bppo.tc.modules import BPPOTCPolicy, QCritic
from offline.bppo.tc.train import (
    TrainerState,
    behavior_cloning_fn,
    high_level_train_fn,
    sarsa_fn,
    train_fn,
)
from offline.bppo.tc.types import BCBatch
from offline.bppo.tc.utils import (
    prepare_bppo_tc_dataset,
    prepare_high_level_q_learning_dataset,
    prepare_sarsa_dataset,
)
from offline.modules.actor.ensemble import GaussianActorEnsembleWithIndices
from offline.modules.base import TrainState, get_optimizer
from offline.modules.mlp import MLP
from offline.lbp.tc.__main__ import (
    init_classifier_fn,
    init_trajectory_clustering_fn,
    init_vcritic_fn,
)
from offline.lbp.tc.train import compute_assignments
from offline.lbp.tc.utils import prepare_tc_dataset
from offline.types import IntArray, OfflineData, OfflineDataWithInfos
from offline.utils.data import DataLoader, Dataset
from offline.utils.logger import Logger
from offline.utils.nnx import default_nnx_rngs


def init_behavior_actor_fn(
    assignments: IntArray,
    batch_size: int,
    codebook_size: int,
    data: OfflineData,
    hidden_features: int,
    layer_norm: bool,
    learning_rate: float,
    logger: Logger,
    num_layers: int,
    numpy_rng: Generator,
    rngs: nnx.Rngs,
    steps: int,
):
    """Create the behavior actor ensemble and fit it by behavior cloning.

    A ``GaussianActorEnsembleWithIndices`` with ``ensemble_size`` equal to
    ``codebook_size`` is built from the dataset dimensions, then handed to
    ``behavior_cloning_fn`` together with an endlessly repeating loader of
    ``BCBatch`` items (actions, assignments, observations).

    Args:
        assignments: Per-sample indices stored in each ``BCBatch``.
        batch_size: Batch size of the cloning data loader.
        codebook_size: Number of ensemble members.
        data: Offline dataset; ``actions`` and ``observations`` are read and
            their second axis gives the action / observation dimension.
        hidden_features: Hidden width forwarded to the actor.
        layer_norm: Layer-norm flag forwarded to the actor.
        learning_rate: Learning rate forwarded to ``get_optimizer``.
        logger: Logger forwarded to ``behavior_cloning_fn``.
        num_layers: Layer count forwarded to the actor.
        numpy_rng: NumPy generator handed to the data loader.
        rngs: NNX RNG streams used to initialise the actor.
        steps: Step budget forwarded to ``behavior_cloning_fn``.

    Returns:
        Whatever ``behavior_cloning_fn`` returns, after ``.eval()`` has been
        called on it.
    """
    action_dim = data.actions.shape[1]
    observation_dim = data.observations.shape[1]

    ensemble = GaussianActorEnsembleWithIndices(
        action_dim=action_dim,
        ensemble_size=codebook_size,
        hidden_features=hidden_features,
        layer_norm=layer_norm,
        observation_dim=observation_dim,
        num_layers=num_layers,
        rngs=rngs,
    )

    # The optimizer is created before the loader, mirroring the order in
    # which the original call evaluated its arguments.
    optimizer = get_optimizer(model=ensemble, learning_rate=learning_rate)
    bc_dataset = Dataset(
        BCBatch(
            actions=data.actions,
            assignments=assignments,
            observations=data.observations,
        )
    )
    bc_loader = DataLoader(
        bc_dataset,
        batch_size=batch_size,
        drop_last=True,
        rng=numpy_rng,
    )

    cloned = behavior_cloning_fn(
        actor=ensemble,
        logger=logger,
        optimizer=optimizer,
        sa_iter=bc_loader.repeat_forever(),
        steps=steps,
    )
    cloned.eval()
    return cloned


def init_behavior_qcritic_fn(
    assignments: IntArray,
    batch_size: int,
    codebook_size: int,
    data: OfflineData,
    gamma: float,
    hidden_features: int,
    layer_norm: bool,
    learning_rate: float,
    logger: Logger,
    num_layers: int,
    numpy_rng: Generator,
    rngs: nnx.Rngs,
    steps: int,
    tau: float,
    update_every: int,
) -> QCritic:
    """Create the behavior Q-critic and fit it with ``sarsa_fn``.

    A ``QCritic`` is built from the dataset dimensions and ``codebook_size``
    and trained on an endlessly repeating loader over the dataset produced by
    ``prepare_sarsa_dataset``.

    Args:
        assignments: Per-sample indices forwarded to
            ``prepare_sarsa_dataset``.
        batch_size: Batch size of the SARSA data loader.
        codebook_size: Codebook size forwarded to ``QCritic``.
        data: Offline dataset; also provides action / observation dimensions
            via the second axis of ``actions`` and ``observations``.
        gamma: Forwarded to ``sarsa_fn``.
        hidden_features: Hidden width forwarded to ``QCritic``.
        layer_norm: Layer-norm flag forwarded to ``QCritic``.
        learning_rate: Learning rate forwarded to ``get_optimizer``.
        logger: Logger forwarded to ``sarsa_fn``.
        num_layers: Layer count forwarded to ``QCritic``.
        numpy_rng: NumPy generator handed to the data loader.
        rngs: NNX RNG streams used to initialise the critic.
        steps: Step budget forwarded to ``sarsa_fn``.
        tau: Forwarded to ``sarsa_fn``.
        update_every: Forwarded to ``sarsa_fn``.

    Returns:
        The critic returned by ``sarsa_fn``, after ``.eval()`` has been
        called on it.
    """
    action_dim = data.actions.shape[1]
    observation_dim = data.observations.shape[1]

    critic = QCritic(
        action_dim=action_dim,
        codebook_size=codebook_size,
        hidden_features=hidden_features,
        layer_norm=layer_norm,
        num_layers=num_layers,
        observation_dim=observation_dim,
        rngs=rngs,
    )

    # Optimizer first, loader second: same evaluation order as the original
    # single-expression call.
    optimizer = get_optimizer(model=critic, learning_rate=learning_rate)
    sarsa_dataset = prepare_sarsa_dataset(assignments=assignments, data=data)
    sarsa_loader = DataLoader(
        sarsa_dataset,
        batch_size=batch_size,
        drop_last=True,
        rng=numpy_rng,
    )

    fitted = sarsa_fn(
        gamma=gamma,
        logger=logger,
        optimizer=optimizer,
        qcritic=critic,
        sarsa_iter=sarsa_loader.repeat_forever(),
        steps=steps,
        tau=tau,
        update_every=update_every,
    )
    fitted.eval()
    return fitted


def init_behavior_fn(
    args: Arguments, data: OfflineData, numpy_rng: Generator, rngs: nnx.Rngs
):
    """Train every behavior-side component, in a fixed order.

    The stages run sequentially and share ``numpy_rng`` and ``rngs``, so their
    order is kept as is: trajectory-clustering autoencoder, assignments,
    behavior actor, classifier, behavior Q-critic, behavior V-critic.

    When ``args.save`` is truthy the autoencoder (plus a
    ``codebook_size.toml`` file), the behavior actor, the classifier and the
    behavior V-critic are written through ``args.logger``. The behavior
    Q-critic is returned but not saved by this function.

    Args:
        args: Run configuration; hyper-parameters, logger and save flag are
            read from it.
        data: Offline dataset used by every stage.
        numpy_rng: NumPy generator shared by all data loaders.
        rngs: NNX RNG streams shared by all model initialisers.

    Returns:
        The tuple ``(actor, assignments, classifier, qcritic, tc_autoencoder,
        vcritic)``.
    """
    # --- Stage 1: trajectory clustering -----------------------------------
    tc_dataset = prepare_tc_dataset(data=data, filter_trajectories=False)
    action_dim = data.actions.shape[1]
    observation_dim = data.observations.shape[1]
    tc_autoencoder = init_trajectory_clustering_fn(
        action_dim=action_dim,
        batch_size=args.tc_batch_size,
        clip_eigenvalues=args.clip_eigenvalues,
        codebook_size=args.codebook_size,
        commitment_cost=args.commitment_cost,
        dataset=tc_dataset,
        decay=args.tc_decay,
        deterministic_reward=args.deterministic_reward,
        deterministic_transition=args.deterministic_transition,
        hidden_features=args.tc_hidden_features,
        latent_dim=args.latent_dim,
        layer_norm=args.layer_norm,
        learning_rate=args.learning_rate,
        logger=args.logger,
        max_timescale=args.max_timescale,
        min_timescale=args.min_timescale,
        num_blocks=args.num_blocks,
        num_layers=args.num_layers,
        numpy_rng=numpy_rng,
        observation_dim=observation_dim,
        observation_embedding_dim=args.observation_embedding_dim,
        phase_split=args.tc_phase_split,
        reward_embedding_dim=args.reward_embedding_dim,
        reward_weight=args.reward_weight,
        reweight=args.tc_reweight,
        rngs=rngs,
        ssm_base_size=args.ssm_base_size,
        steps=args.tc_steps,
        subsample_decodes=args.subsample_decodes,
        subsample_latents=args.subsample_latents,
        threshold=args.tc_threshold,
        transition_weight=args.transition_weight,
        use_next_observation=args.use_next_observation,
    )
    if args.save:
        args.logger.save_model("tc_autoencoder", model=tc_autoencoder)
        # The codebook size is taken from the trained autoencoder rather than
        # from ``args.codebook_size``; ``load_fn`` reads this file back.
        args.logger.save_toml(
            "codebook_size.toml",
            obj={"codebook_size": tc_autoencoder.codebook_size},
        )

    # --- Stage 2: assignments ---------------------------------------------
    assignments = compute_assignments(
        batch_size=args.batch_size,
        dataset=tc_dataset,
        tc_autoencoder=tc_autoencoder,
    )

    # Keyword arguments that every downstream initialiser receives unchanged.
    shared_kwargs = dict(
        assignments=assignments,
        batch_size=args.batch_size,
        codebook_size=tc_autoencoder.codebook_size,
        hidden_features=args.hidden_features,
        layer_norm=args.layer_norm,
        learning_rate=args.learning_rate,
        logger=args.logger,
        num_layers=args.num_layers,
        numpy_rng=numpy_rng,
        rngs=rngs,
    )

    # --- Stage 3: behavior actor ------------------------------------------
    actor = init_behavior_actor_fn(
        data=data, steps=args.bc_steps, **shared_kwargs
    )
    if args.save:
        args.logger.save_model("behavior_actor", model=actor)

    # --- Stage 4: classifier ----------------------------------------------
    classifier = init_classifier_fn(
        classifier_steps=args.classifier_steps,
        observations=data.observations,
        **shared_kwargs,
    )
    if args.save:
        args.logger.save_model("classifier", model=classifier)

    # --- Stages 5 and 6: behavior critics ---------------------------------
    # Both critics take the same extra arguments and differ only in ``steps``.
    critic_kwargs = dict(
        shared_kwargs,
        data=data,
        gamma=args.gamma,
        tau=args.tau,
        update_every=args.update_every,
    )
    qcritic = init_behavior_qcritic_fn(
        steps=args.sarsa_steps, **critic_kwargs
    )
    vcritic = init_vcritic_fn(steps=args.v_learning_steps, **critic_kwargs)
    if args.save:
        args.logger.save_model("behavior_vcritic", model=vcritic)

    return actor, assignments, classifier, qcritic, tc_autoencoder, vcritic


def init_high_level_critic_fn(
    assignments: IntArray,
    batch_size: int,
    codebook_size: int,
    classifier: MLP,
    data: OfflineData,
    gamma: float,
    hidden_features: int,
    layer_norm: bool,
    learning_rate: float,
    logger: Logger,
    num_layers: int,
    numpy_rng: Generator,
    rngs: nnx.Rngs,
    steps: int,
    threshold: float,
    update_every: int,
) -> MLP:
    """Create the high-level Q-critic and fit it with ``high_level_train_fn``.

    The critic is an ``MLP`` taking an observation (``in_features`` is the
    second axis of ``data.observations``) and producing ``codebook_size``
    outputs. Its training data comes from
    ``prepare_high_level_q_learning_dataset``.

    Args:
        assignments: Forwarded to the dataset preparation helper.
        batch_size: Batch size of the Q-learning data loader.
        codebook_size: Number of outputs of the critic.
        classifier: Forwarded to the dataset preparation helper.
        data: Offline dataset; forwarded to the dataset preparation helper.
        gamma: Forwarded to ``high_level_train_fn``.
        hidden_features: Hidden width of the critic.
        layer_norm: Layer-norm flag of the critic.
        learning_rate: Learning rate forwarded to ``get_optimizer``.
        logger: Logger forwarded to ``high_level_train_fn``.
        num_layers: Layer count of the critic.
        numpy_rng: NumPy generator handed to the data loader.
        rngs: NNX RNG streams used to initialise the critic.
        steps: Step budget forwarded to ``high_level_train_fn``.
        threshold: Forwarded to the dataset preparation helper.
        update_every: Forwarded to ``high_level_train_fn``.

    Returns:
        The critic returned by ``high_level_train_fn``, after ``.eval()`` has
        been called on it.
    """
    # The dataset is prepared before the network is created, as in the
    # original statement order.
    transitions = prepare_high_level_q_learning_dataset(
        assignments=assignments,
        classifier=classifier,
        data=data,
        threshold=threshold,
    )

    observation_dim = data.observations.shape[1]
    critic = MLP(
        hidden_features=hidden_features,
        in_features=observation_dim,
        layer_norm=layer_norm,
        num_layers=num_layers,
        out_features=codebook_size,
        rngs=rngs,
    )

    optimizer = get_optimizer(model=critic, learning_rate=learning_rate)
    transition_loader = DataLoader(
        transitions,
        batch_size=batch_size,
        drop_last=True,
        rng=numpy_rng,
    )

    fitted = high_level_train_fn(
        gamma=gamma,
        logger=logger,
        optimizer=optimizer,
        q_learning_iter=transition_loader.repeat_forever(),
        qcritic=critic,
        steps=steps,
        update_every=update_every,
    )
    fitted.eval()
    return fitted


def init_fn(args: Arguments, data: OfflineDataWithInfos) -> TrainerState:
    """Build the initial ``TrainerState`` for the main training loop.

    Runs all behavior-side pre-training via ``init_behavior_fn``, trains the
    high-level Q-critic, prepares the training data loader, assembles the
    ``BPPOTCPolicy`` and splits the resulting ``TrainState`` with
    ``nnx.split``.

    Args:
        args: Run configuration. ``args.seed`` seeds the NumPy generator, the
            NNX RNG streams and the JAX training key.
        data: Offline dataset wrapper; only ``data.data`` is used here.

    Returns:
        A ``TrainerState`` holding the loader, the split graph definition and
        state, the training key, and ``eval_state=None``.

    Raises:
        NotImplementedError: If ``args.unsquash`` is falsy.
    """
    if not args.unsquash:
        raise NotImplementedError()

    numpy_rng = default_rng(args.seed)
    rngs = default_nnx_rngs(args.seed)
    offline_data = data.data

    behavior_parts = init_behavior_fn(
        args=args, data=offline_data, numpy_rng=numpy_rng, rngs=rngs
    )
    actor, assignments, classifier, qcritic, tc_autoencoder, vcritic = (
        behavior_parts
    )

    high_level_qcritic = init_high_level_critic_fn(
        assignments=assignments,
        batch_size=args.batch_size,
        classifier=classifier,
        codebook_size=tc_autoencoder.codebook_size,
        data=offline_data,
        gamma=args.gamma,
        hidden_features=args.hidden_features,
        layer_norm=args.layer_norm,
        learning_rate=args.learning_rate,
        logger=args.logger,
        num_layers=args.num_layers,
        numpy_rng=numpy_rng,
        rngs=rngs,
        steps=args.high_level_steps,
        threshold=args.threshold,
        update_every=args.high_level_update_every,
    )
    if args.save:
        args.logger.save_model("high_level_qcritic", model=high_level_qcritic)

    bppo_tc_loader = DataLoader(
        prepare_bppo_tc_dataset(
            actor=actor,
            assignments=assignments,
            observations=offline_data.observations,
            vcritic=vcritic,
        ),
        batch_size=args.batch_size,
        drop_last=True,
        rng=numpy_rng,
    )

    policy = BPPOTCPolicy(
        actor=actor,
        classifier=classifier,
        critic=qcritic,
        high_level_critic=high_level_qcritic,
        threshold=args.threshold,
    )

    # With stochastic updates the optimizer is scheduled every batch;
    # otherwise ``every_k_schedule`` is set to the loader length. The length
    # is only queried in the latter case.
    if args.stochastic_update:
        every_k = 1
    else:
        every_k = len(bppo_tc_loader)

    # Only the actor sub-module of the policy is handed to the optimizer.
    actor_optimizer = get_optimizer(
        policy.actor,
        every_k_schedule=every_k,
        learning_rate=args.learning_rate,
        max_gradient_norm=args.max_gradient_norm,
    )
    graphdef, graphstate = nnx.split(
        TrainState(model=policy, optimizer=actor_optimizer)
    )

    return TrainerState(
        bppo_tc_loader=bppo_tc_loader,
        eval_state=None,
        graphdef=graphdef,
        graphstate=graphstate,
        train_key=jax.random.key(args.seed),
    )


def load_fn(
    step: int | None,
    action_dim: int,
    hidden_features: int,
    layer_norm: bool,
    num_layers: int,
    logger: Logger,
    observation_dim: int,
    threshold: float,
    **kwargs,
):
    """Rebuild a ``BPPOTCPolicy`` from models stored through ``logger``.

    The codebook size is read from ``codebook_size.toml``. The classifier and
    the high-level critic are restored from ``"classifier"`` and
    ``"high_level_qcritic"``. The actor is restored from ``"actor"`` when
    ``step`` is ``None`` and from ``"checkpoints"/"actor_{step}"`` otherwise.

    Args:
        step: Checkpoint step of the actor to restore, or ``None`` to restore
            the model stored under ``"actor"``.
        action_dim: Action dimension of the actor and critic.
        hidden_features: Hidden width of all rebuilt networks.
        layer_norm: Layer-norm flag of all rebuilt networks.
        num_layers: Layer count of all rebuilt networks.
        logger: Logger used to read the TOML file and restore models.
        observation_dim: Observation dimension of all rebuilt networks.
        threshold: Forwarded to ``BPPOTCPolicy``.
        **kwargs: Accepted and discarded.

    Returns:
        ``(policy, None)``; the returned policy has its ``critic`` attribute
        deleted.
    """
    del kwargs

    saved_sizes = logger.load_toml("codebook_size.toml")
    codebook_size: int = saved_sizes["codebook_size"]

    def build_mlp():
        # Shared skeleton for the classifier and the high-level critic: both
        # map an observation to ``codebook_size`` outputs.
        return MLP(
            hidden_features=hidden_features,
            in_features=observation_dim,
            layer_norm=layer_norm,
            num_layers=num_layers,
            out_features=codebook_size,
            rngs=default_nnx_rngs(0),
        )

    def build_actor():
        return GaussianActorEnsembleWithIndices(
            action_dim=action_dim,
            ensemble_size=codebook_size,
            hidden_features=hidden_features,
            layer_norm=layer_norm,
            num_layers=num_layers,
            observation_dim=observation_dim,
            rngs=default_nnx_rngs(0),
        )

    classifier = logger.restore_model("classifier", model_fn=build_mlp)
    high_level_critic = logger.restore_model(
        "high_level_qcritic", model_fn=build_mlp
    )

    # Pick the stored actor: the one under "actor", or a numbered checkpoint.
    if step is None:
        actor_path = ("actor",)
    else:
        actor_path = ("checkpoints", f"actor_{step}")
    actor = logger.restore_model(*actor_path, model_fn=build_actor)

    # The critic is freshly initialised, not restored, and is removed from
    # the policy right after construction -- presumably it is only needed to
    # satisfy the ``BPPOTCPolicy`` constructor (TODO confirm).
    placeholder_critic = QCritic(
        action_dim=action_dim,
        codebook_size=codebook_size,
        hidden_features=hidden_features,
        layer_norm=layer_norm,
        num_layers=num_layers,
        observation_dim=observation_dim,
        rngs=default_nnx_rngs(0),
    )
    policy = BPPOTCPolicy(
        actor=actor,
        classifier=classifier,
        critic=placeholder_critic,
        high_level_critic=high_level_critic,
        threshold=threshold,
    )
    del policy.critic
    return policy, None


if __name__ == "__main__":
    # Parse the command line into a plain dict and forward every parsed
    # option to the shared runner as a keyword argument.
    cli_options = vars(build_argument_parser().parse_args())
    base.run(
        arguments_class=Arguments,
        init_fn=init_fn,
        train_fn=train_fn,
        **cli_options,
    )
