from flax import nnx
import jax.random
from numpy.random import Generator, default_rng
from scipy.stats import chi2

from offline import base
from offline.hdr.arguments import Arguments, build_argument_parser
from offline.hdr.modules import (
    ActorCriticFilter,
    ActorFilter,
    HDRPolicy,
    HDRTrainState,
)
from offline.hdr.train import TrainerState, train_fn
from offline.hdr.utils import prepare_actor_q_learning_dataset
from offline.modules.base import TargetModel, get_optimizer
from offline.lbp.__main__ import (
    init_behavior_actor_fn,
    init_behavior_vcritic_fn,
    load_fn,
)
from offline.types import OfflineData, OfflineDataWithInfos
from offline.utils.data import DataLoader
from offline.utils.nnx import default_nnx_rngs


# ``load_fn`` is imported above from ``offline.lbp.__main__`` and is the only
# name this module declares as public.
__all__ = ["load_fn"]


def init_behavior_fn(
    args: Arguments, data: OfflineData, numpy_rng: Generator, rngs: nnx.Rngs
):
    """Create the behavior actor and the behavior V-critic.

    Both models are produced by the initializers imported from
    ``offline.lbp.__main__``; this function only forwards the relevant
    settings from ``args``. When ``args.save`` is truthy, each model is
    handed to ``args.logger.save_model`` immediately after it is created,
    i.e. the actor is saved before the V-critic is initialized.

    Args:
        args: Run configuration supplying the hyperparameters, the logger
            and the ``save`` flag read below.
        data: Offline dataset forwarded unchanged to both initializers.
        numpy_rng: NumPy generator passed to both initializers (the same
            object, so the call order below is significant).
        rngs: NNX RNG container passed to both initializers.

    Returns:
        The tuple ``(actor, vcritic)``.
    """
    # Keyword arguments that both initializers receive with identical values.
    common = {
        "batch_size": args.batch_size,
        "data": data,
        "hidden_features": args.hidden_features,
        "layer_norm": args.layer_norm,
        "learning_rate": args.learning_rate,
        "logger": args.logger,
        "num_layers": args.num_layers,
        "numpy_rng": numpy_rng,
        "rngs": rngs,
    }

    def save_if_requested(name, model):
        # Saving is opt-in through ``args.save``.
        if args.save:
            args.logger.save_model(name, model=model)

    behavior_actor = init_behavior_actor_fn(bc_steps=args.bc_steps, **common)
    save_if_requested("behavior_actor", behavior_actor)

    behavior_vcritic = init_behavior_vcritic_fn(
        gamma=args.gamma,
        tau=args.tau,
        update_every=args.update_every,
        v_learning_steps=args.v_learning_steps,
        **common,
    )
    save_if_requested("behavior_vcritic", behavior_vcritic)

    return behavior_actor, behavior_vcritic


def init_fn(args: Arguments, data: OfflineDataWithInfos) -> TrainerState:
    """Assemble the initial ``TrainerState`` for an HDR run.

    Steps, in order: seed the NumPy and NNX random sources from
    ``args.seed``; build the behavior actor and V-critic; derive the actor
    and Q-learning datasets from them; construct the ``HDRPolicy`` with its
    optimizers and target model; split the resulting train state with
    ``nnx.split``; and wrap both datasets in endlessly repeating loaders.

    Args:
        args: Run configuration. ``args.unsquash`` must be truthy.
        data: Offline dataset wrapper; only its ``data`` attribute is used.

    Returns:
        A ``TrainerState`` with ``eval_state`` set to ``None``.

    Raises:
        NotImplementedError: If ``args.unsquash`` is falsy.
    """
    # Reject the unsupported configuration before doing any work.
    if not args.unsquash:
        raise NotImplementedError()

    offline_data = data.data
    numpy_rng = default_rng(args.seed)
    rngs = default_nnx_rngs(args.seed)

    behavior_actor, behavior_vcritic = init_behavior_fn(
        args=args, data=offline_data, numpy_rng=numpy_rng, rngs=rngs
    )
    actor_dataset, q_learning_dataset = prepare_actor_q_learning_dataset(
        actor=behavior_actor, critic=behavior_vcritic, data=offline_data
    )

    # Dimensions are read from the second axis of the stored arrays.
    action_dim = offline_data.actions.shape[1]
    observation_dim = offline_data.observations.shape[1]

    policy = HDRPolicy(
        action_dim=action_dim,
        behavior_actor=behavior_actor,
        deterministic=args.deterministic,
        ensemble_size=args.ensemble_size,
        hidden_features=args.hidden_features,
        layer_norm=args.layer_norm,
        num_layers=args.num_layers,
        observation_dim=observation_dim,
        rngs=rngs,
    )
    # The actor optimizer is restricted with ``wrt=ActorFilter``; the critic
    # optimizer is built directly on ``policy.critic``.
    actor_optimizer = get_optimizer(
        policy, learning_rate=args.learning_rate, wrt=ActorFilter
    )
    critic_optimizer = get_optimizer(
        policy.critic, learning_rate=args.learning_rate
    )
    target_policy = TargetModel(policy, poi=ActorCriticFilter)
    train_state = HDRTrainState(
        actor_optimizer=actor_optimizer,
        critic_optimizer=critic_optimizer,
        policy=policy,
        target_policy=target_policy,
    )
    graphdef, graphstate = nnx.split(train_state)

    # The JAX key is created from ``args.seed`` as well, then split in two:
    # one key for actor training and one for critic training.
    train_actor_key, train_critic_key = jax.random.split(
        jax.random.key(args.seed), 2
    )

    def endless_batches(dataset):
        # Both loaders draw from the same ``numpy_rng`` object.
        loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            drop_last=True,
            rng=numpy_rng,
        )
        return loader.repeat_forever()

    actor_iter = endless_batches(actor_dataset)
    # Inverse survival function of a chi-squared distribution whose degrees
    # of freedom equal the action dimension, evaluated at
    # ``args.ood_probability``.
    ood_threshold = float(chi2.isf(args.ood_probability, action_dim))
    q_learning_iter = endless_batches(q_learning_dataset)

    return TrainerState(
        actor_iter=actor_iter,
        eval_state=None,
        graphdef=graphdef,
        graphstate=graphstate,
        ood_threshold=ood_threshold,
        q_learning_iter=q_learning_iter,
        train_actor_key=train_actor_key,
        train_critic_key=train_critic_key,
    )


if __name__ == "__main__":
    # Parse the command line, then forward every parsed option to
    # ``base.run`` as a keyword argument next to the HDR-specific hooks.
    cli_options = vars(build_argument_parser().parse_args())
    base.run(
        arguments_class=Arguments,
        init_fn=init_fn,
        train_fn=train_fn,
        **cli_options,
    )
