from flax import nnx
import jax.random
from numpy.random import Generator, default_rng
from optax import cosine_decay_schedule

from offline import base
from offline.modules.actor.base import DeterministicActor
from offline.modules.base import TargetModel, get_optimizer
from offline.lbp.utils import get_max_min_reward
from offline.svr.arguments import Arguments, build_argument_parser
from offline.svr.modules import ActorFilter, SVRPolicy, SVRTrainState
from offline.svr.train import TrainerState, behavior_cloning_fn, train_fn
from offline.svr.utils import prepare_svr_dataset
from offline.types import OfflineData, OfflineDataWithInfos
from offline.utils.data import DataLoader
from offline.utils.dataset import prepare_sa_dataset
from offline.utils.logger import Logger
from offline.utils.nnx import default_nnx_rngs


def init_behavior_actor_fn(
    batch_size: int,
    bc_steps: int,
    data: OfflineData,
    hidden_features: int,
    layer_norm: bool,
    learning_rate: float,
    logger: Logger,
    num_layers: int,
    numpy_rng: Generator,
    rngs: nnx.Rngs,
) -> DeterministicActor:
    """Fit a squashed deterministic actor to the offline data by behavior cloning.

    The actor's input/output widths are taken from ``data.observations.shape[1]``
    and ``data.actions.shape[1]``. It is trained with ``behavior_cloning_fn`` for
    ``bc_steps`` steps on state-action batches of ``batch_size`` drawn from an
    endlessly repeating ``DataLoader`` (``drop_last=True``, shuffled with
    ``numpy_rng`` presumably).

    Returns:
        The trained actor, switched to eval mode.
    """
    observation_dim = data.observations.shape[1]
    action_dim = data.actions.shape[1]
    untrained_actor = DeterministicActor(
        action_dim=action_dim,
        hidden_features=hidden_features,
        layer_norm=layer_norm,
        num_layers=num_layers,
        observation_dim=observation_dim,
        rngs=rngs,
        squash=True,
    )

    # Optimizer first, then the data pipeline -- same construction order as before.
    bc_optimizer = get_optimizer(model=untrained_actor, learning_rate=learning_rate)
    sa_loader = DataLoader(
        prepare_sa_dataset(data=data),
        batch_size=batch_size,
        drop_last=True,
        rng=numpy_rng,
    )

    trained_actor = behavior_cloning_fn(
        actor=untrained_actor,
        logger=logger,
        optimizer=bc_optimizer,
        sa_iter=sa_loader.repeat_forever(),
        steps=bc_steps,
    )
    trained_actor.eval()
    return trained_actor


def init_behavior_fn(
    args: Arguments, data: OfflineData, numpy_rng: Generator, rngs: nnx.Rngs
) -> DeterministicActor:
    """Behavior-clone an actor using the hyper-parameters carried by ``args``.

    The trained actor is also written via ``args.logger.save_model`` under the
    name ``"behavior_actor"`` when ``args.save`` is truthy.

    Returns:
        The trained (eval-mode) behavior actor.
    """
    behavior_actor = init_behavior_actor_fn(
        batch_size=args.batch_size,
        bc_steps=args.bc_steps,
        data=data,
        hidden_features=args.hidden_features,
        layer_norm=args.layer_norm,
        learning_rate=args.learning_rate,
        logger=args.logger,
        num_layers=args.num_layers,
        numpy_rng=numpy_rng,
        rngs=rngs,
    )
    if not args.save:
        return behavior_actor

    args.logger.save_model("behavior_actor", model=behavior_actor)
    return behavior_actor


def _actor_learning_rate(args: Arguments):
    """Return the actor learning rate: a constant, or a cosine decay schedule.

    The cosine horizon is ``args.total_steps // args.update_every`` (presumably the
    number of optimizer updates over training -- confirm against ``train_fn``).
    """
    if args.constant_schedule:
        return args.learning_rate
    return cosine_decay_schedule(
        args.learning_rate, args.total_steps // args.update_every
    )


def _targets_regularizer(args: Arguments, data: OfflineData) -> float:
    """Return ``min_reward / (1 - args.gamma)`` for ``TrainerState``.

    ``min_reward`` comes from ``get_max_min_reward(args.dataset)`` when it is not
    ``None``; otherwise it is the smallest reward present in ``data``.

    Raises:
        ValueError: if ``args.gamma >= 1``; the denominator would then be zero or
            negative, producing a ZeroDivisionError, inf, or a sign-flipped value.
    """
    if args.gamma >= 1:
        raise ValueError(
            f"gamma must be < 1 to compute the targets regularizer, got {args.gamma}"
        )
    _, min_reward = get_max_min_reward(args.dataset)
    if min_reward is None:
        min_reward = float(data.rewards.min())
    return min_reward / (1 - args.gamma)


def _build_train_state(
    args: Arguments, policy: SVRPolicy, actor_learning_rate
) -> SVRTrainState:
    """Bundle ``policy``, its ``TargetModel`` copy and the actor/critic optimizers.

    The actor optimizer (restricted to ``ActorFilter`` parameters) uses
    ``actor_learning_rate``; the critic optimizer always uses the constant
    ``args.learning_rate``. Both clip gradients to ``args.max_gradient_norm``.
    """
    return SVRTrainState(
        actor_optimizer=get_optimizer(
            policy,
            learning_rate=actor_learning_rate,
            max_gradient_norm=args.max_gradient_norm,
            wrt=ActorFilter,
        ),
        critic_optimizer=get_optimizer(
            policy.critic,
            learning_rate=args.learning_rate,
            max_gradient_norm=args.max_gradient_norm,
        ),
        policy=policy,
        target_policy=TargetModel(policy),
    )


def init_fn(args: Arguments, data: OfflineDataWithInfos) -> TrainerState:
    """Build the initial SVR trainer state.

    Pipeline:
      1. Behavior-clone a deterministic actor on ``data.data`` (saved when
         ``args.save``).
      2. Build the SVR dataset from that actor with ``args.sample_std``.
      3. Construct the ``SVRPolicy``, its optimizers and target copy, and split
         the train state into an nnx ``graphdef`` / ``graphstate`` pair.

    Raises:
        NotImplementedError: if ``args.unsquash`` is set.
        ValueError: if ``args.gamma >= 1``.
    """
    if args.unsquash:
        raise NotImplementedError("unsquash=True is not supported by SVR")

    data_without_info = data.data

    # Derive cheap, config-only values *before* behavior cloning so that a bad
    # gamma or LR horizon (optax rejects non-positive decay_steps) fails
    # immediately instead of after the whole BC run.
    targets_regularizer = _targets_regularizer(args, data_without_info)
    actor_learning_rate = _actor_learning_rate(args)

    numpy_rng = default_rng(args.seed)
    rngs = default_nnx_rngs(args.seed)
    behavior_actor = init_behavior_fn(
        args=args, data=data_without_info, numpy_rng=numpy_rng, rngs=rngs
    )
    svr_dataset = prepare_svr_dataset(
        actor=behavior_actor, data=data_without_info, sample_std=args.sample_std
    )

    # NOTE(review): load_fn builds SVRPolicy with behavior_actor=..., but it is not
    # passed here; confirm the saved policy state matches what load_fn expects.
    policy = SVRPolicy(
        action_dim=data_without_info.actions.shape[1],
        ensemble_size=args.ensemble_size,
        hidden_features=args.hidden_features,
        layer_norm=args.layer_norm,
        num_layers=args.num_layers,
        observation_dim=data_without_info.observations.shape[1],
        rngs=rngs,
    )
    train_state = _build_train_state(args, policy, actor_learning_rate)
    graphdef, graphstate = nnx.split(train_state)

    return TrainerState(
        eval_state=None,
        graphdef=graphdef,
        graphstate=graphstate,
        # Uses the same numpy Generator that the behavior-cloning loader used.
        svr_iter=DataLoader(
            svr_dataset,
            batch_size=args.batch_size,
            drop_last=True,
            rng=numpy_rng,
        ).repeat_forever(),
        targets_regularizer=targets_regularizer,
        train_key=jax.random.key(args.seed),
    )


def load_fn(
    step: int | None,
    action_dim: int,
    ensemble_size: int,
    hidden_features: int,
    layer_norm: bool,
    num_layers: int,
    logger: Logger,
    observation_dim: int,
    **kwargs,
):
    """Restore a trained SVR policy from ``logger``, with its critic removed.

    The model saved as ``"behavior_actor"`` is restored first and handed to a
    freshly built ``SVRPolicy`` template. ``base.default_load_fn`` then loads
    the policy for ``step``. Any extra keyword arguments are accepted and
    ignored.

    Returns:
        ``(policy, None)``, where ``policy`` no longer has a ``critic`` attribute.
    """
    del kwargs  # accepted for call-site compatibility, otherwise unused

    # Constructor arguments common to the behavior actor and the SVR policy.
    shared_dims = dict(
        action_dim=action_dim,
        hidden_features=hidden_features,
        layer_norm=layer_norm,
        num_layers=num_layers,
        observation_dim=observation_dim,
    )

    def make_behavior_actor():
        # Fixed-seed template; presumably its weights are replaced on restore.
        return DeterministicActor(
            **shared_dims, rngs=default_nnx_rngs(0), squash=True
        )

    restored_behavior_actor = logger.restore_model(
        "behavior_actor", model_fn=make_behavior_actor
    )

    def make_policy():
        return SVRPolicy(
            **shared_dims,
            behavior_actor=restored_behavior_actor,
            ensemble_size=ensemble_size,
            rngs=default_nnx_rngs(0),
        )

    loaded_policy = base.default_load_fn(
        logger=logger, model_fn=make_policy, step=step
    )
    del loaded_policy.critic
    return loaded_policy, None


if __name__ == "__main__":
    # Parse the command line, then forward every flag as a keyword argument to
    # the shared runner alongside the SVR-specific init/train hooks.
    cli_kwargs = vars(build_argument_parser().parse_args())
    base.run(
        arguments_class=Arguments,
        init_fn=init_fn,
        skip_reward_normalization=True,
        train_fn=train_fn,
        **cli_kwargs,
    )
