from flax import nnx
import jax.random
from numpy.random import default_rng

from offline import base
from offline.modules.base import TargetModel, get_optimizer
from offline.td3bc.arguments import Arguments, build_argument_parser
from offline.td3bc.modules import ActorFilter, TD3BCPolicy, TD3BCTrainState
from offline.td3bc.train import TrainerState, train_fn
from offline.types import OfflineDataWithInfos
from offline.utils.data import DataLoader
from offline.utils.dataset import prepare_sa_dataset, prepare_q_learning_dataset
from offline.utils.logger import Logger
from offline.utils.nnx import default_nnx_rngs


def init_fn(args: Arguments, data: OfflineDataWithInfos) -> TrainerState:
    """Build the initial TD3+BC trainer state from parsed arguments and a dataset.

    Constructs a ``TD3BCPolicy``, two batch loaders over the offline data
    (one for Q-learning tuples, one for state-action pairs), separate actor and
    critic optimizers, and a ``TargetModel`` wrapping the policy. The resulting
    train state is split with ``nnx.split`` into a ``(graphdef, graphstate)``
    pair for the trainer.

    Args:
        args: Parsed run arguments (seed, network sizes, batch size,
            learning rate, ...).
        data: Offline dataset. ``data.data.observations`` and
            ``data.data.actions`` are indexed via ``.shape[1]``, so they must be
            at least 2-D; the second axis is taken as the feature dimension.

    Returns:
        A ``TrainerState`` with ``eval_state=None``, the split train state,
        endless iterators over both loaders, and a JAX PRNG key built from
        ``args.seed``.

    Raises:
        NotImplementedError: if ``args.unsquash`` is set (not supported here).
    """
    if args.unsquash:
        raise NotImplementedError(
            "The `unsquash` option is not supported by TD3+BC."
        )

    observation_dim = data.data.observations.shape[1]
    action_dim = data.data.actions.shape[1]

    policy = TD3BCPolicy(
        action_dim=action_dim,
        hidden_features=args.hidden_features,
        layer_norm=args.layer_norm,
        num_layers=args.num_layers,
        observation_dim=observation_dim,
        rngs=default_nnx_rngs(args.seed),
    )

    # A single NumPy generator seeded from args.seed is shared by both loaders.
    rng = default_rng(args.seed)
    qlearning_loader = DataLoader(
        batch_size=args.batch_size,
        dataset=prepare_q_learning_dataset(data.data),
        drop_last=True,
        rng=rng,
    )
    sa_loader = DataLoader(
        batch_size=args.batch_size,
        dataset=prepare_sa_dataset(data.data),
        drop_last=True,
        rng=rng,
    )

    train_state = TD3BCTrainState(
        # The actor optimizer is built over the whole policy with `wrt=ActorFilter`,
        # presumably restricting updates to actor parameters — confirm in get_optimizer.
        actor_optimizer=get_optimizer(
            policy, learning_rate=args.learning_rate, wrt=ActorFilter
        ),
        critic_optimizer=get_optimizer(
            policy.critic, learning_rate=args.learning_rate
        ),
        policy=policy,
        target_policy=TargetModel(policy),
    )
    graphdef, graphstate = nnx.split(train_state)

    return TrainerState(
        eval_state=None,
        graphdef=graphdef,
        graphstate=graphstate,
        qlearning_iter=qlearning_loader.repeat_forever(),
        sa_iter=sa_loader.repeat_forever(),
        train_key=jax.random.key(args.seed),
    )


def load_fn(
    step: int | None,
    action_dim: int,
    hidden_features: int,
    layer_norm: bool,
    num_layers: int,
    logger: Logger,
    observation_dim: int,
    **kwargs,
):
    """Load a TD3+BC policy through ``base.default_load_fn`` and drop its critic.

    A zero-argument factory that builds a ``TD3BCPolicy`` with the given
    architecture is handed to ``base.default_load_fn`` together with
    ``logger`` and ``step``. Presumably that helper restores saved weights
    into the freshly built model (and ``step=None`` selects a default
    checkpoint) — confirm against ``offline.base``.

    Args:
        step: Forwarded unchanged to ``base.default_load_fn``.
        action_dim, hidden_features, layer_norm, num_layers, observation_dim:
            Architecture settings used to construct the ``TD3BCPolicy``.
        logger: Forwarded unchanged to ``base.default_load_fn``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        A 2-tuple ``(policy, None)`` where ``policy`` has had its ``critic``
        attribute deleted.
    """
    del kwargs  # extra configuration entries are tolerated but unused

    architecture = dict(
        action_dim=action_dim,
        hidden_features=hidden_features,
        layer_norm=layer_norm,
        num_layers=num_layers,
        observation_dim=observation_dim,
    )

    def build_policy():
        # Fixed seed 0 for construction; the loaded values presumably
        # replace these initial parameters.
        return TD3BCPolicy(**architecture, rngs=default_nnx_rngs(0))

    restored = base.default_load_fn(
        logger=logger, model_fn=build_policy, step=step
    )
    # Only the actor part is returned; remove the critic submodule.
    del restored.critic
    return restored, None


if __name__ == "__main__":
    # Parse the command line first, then forward every parsed option as a
    # keyword argument to the shared runner alongside the TD3+BC hooks.
    cli_options = vars(build_argument_parser().parse_args())
    base.run(
        arguments_class=Arguments,
        init_fn=init_fn,
        train_fn=train_fn,
        **cli_options,
    )
