from dataclasses import dataclass

from flax import nnx
import jax.random
from numpy.random import default_rng

from offline import base
from offline.bppo import arguments as bppo
from offline.bppo.__main__ import load_fn
from offline.bppo.modules import BPPOPolicy
from offline.bppo.train import TrainerState, train_fn
from offline.bppo.utils import prepare_bppo_dataset
from offline.modules.actor.base import GaussianActor
from offline.modules.base import TrainState, get_optimizer
from offline.modules.critic import QCritic, VCritic
from offline.types import OfflineData, OfflineDataWithInfos
from offline.utils.data import DataLoader
from offline.utils.logger import Logger
from offline.utils.nnx import default_nnx_rngs


# ``load_fn`` is imported above from ``offline.bppo.__main__`` and re-exported
# here; it is the only name this module exposes through ``import *``.
__all__ = ["load_fn"]


@dataclass(frozen=True)
class Arguments(bppo.Arguments):
    """BPPO arguments extended with a handle to the parent run's logger.

    Instances are immutable (``frozen=True``). All other fields are inherited
    from ``bppo.Arguments``.
    """

    # Logger of the parent run. ``init_behavior_fn`` restores the behavior
    # actor/critics through it; ``postprocess_args`` fills it from
    # ``logger.parent``.
    parent_logger: Logger


def build_argument_parser(**kwargs):
    """Return the BPPO argument parser configured for this entry point.

    Thin wrapper around ``bppo.build_argument_parser`` that always passes
    ``extra=True`` and a fixed ``fix_keys`` tuple. Any additional keyword
    arguments are forwarded unchanged.

    NOTE(review): ``fix_keys`` presumably marks options that must match the
    parent run (dataset and network shape/normalization settings) -- confirm
    against ``bppo.build_argument_parser``.
    """
    pinned_keys = (
        "dataset",
        "hidden_features",
        "layer_norm",
        "normalize_observations",
        "normalize_rewards",
        "num_layers",
        "reward_multiplier",
    )
    return bppo.build_argument_parser(
        extra=True, fix_keys=pinned_keys, **kwargs
    )


def init_behavior_fn(args: Arguments, data: OfflineData):
    """Restore the behavior actor, Q-critic and V-critic from the parent run.

    Each model is restored through ``args.parent_logger.restore_model`` under
    the names ``behavior_actor``, ``behavior_qcritic`` and
    ``behavior_vcritic``. Each call is given a zero-argument factory that
    builds a network sized from ``args`` and from the second dimension of
    ``data.observations`` / ``data.actions``.

    Args:
        args: Run arguments; supplies the network hyperparameters and the
            parent logger.
        data: Offline dataset; only the array shapes are read here.

    Returns:
        The tuple ``(actor, qcritic, vcritic)`` as returned by
        ``restore_model``.
    """
    parent = args.parent_logger

    def network_kwargs():
        # Evaluated inside each factory call so that ``default_nnx_rngs(0)``
        # is invoked once per constructed model rather than shared.
        return {
            "hidden_features": args.hidden_features,
            "layer_norm": args.layer_norm,
            "num_layers": args.num_layers,
            "observation_dim": data.observations.shape[1],
            "rngs": default_nnx_rngs(0),
        }

    def make_actor():
        return GaussianActor(
            action_dim=data.actions.shape[1], **network_kwargs()
        )

    def make_qcritic():
        return QCritic(action_dim=data.actions.shape[1], **network_kwargs())

    def make_vcritic():
        # The state-value critic takes no action input.
        return VCritic(**network_kwargs())

    factories = (
        ("behavior_actor", make_actor),
        ("behavior_qcritic", make_qcritic),
        ("behavior_vcritic", make_vcritic),
    )
    actor, qcritic, vcritic = [
        parent.restore_model(checkpoint_name, model_fn=factory)
        for checkpoint_name, factory in factories
    ]

    # NOTE(review): presumably blocks until pending restores have completed
    # -- confirm against Logger.wait.
    parent.wait()
    return actor, qcritic, vcritic


def init_fn(args: Arguments, data: OfflineDataWithInfos) -> TrainerState:
    """Build the initial ``TrainerState`` from the restored behavior models.

    The behavior actor and critics are restored via ``init_behavior_fn``. A
    BPPO dataset is then prepared from the actor, the V-critic and the
    dataset observations, and wrapped in a shuffling ``DataLoader``. The
    actor and Q-critic are combined into a ``BPPOPolicy`` whose actor is the
    only module handed to the optimizer.

    Args:
        args: Run arguments (seed, batch size, optimizer settings, parent
            logger, ...).
        data: Offline dataset with infos; only ``data.data`` is used here.

    Returns:
        A ``TrainerState`` holding the data loader, the split nnx
        graph definition / state of the train state, and the training PRNG
        key. ``eval_state`` is left as ``None``.

    Raises:
        NotImplementedError: If ``args.unsquash`` is falsy; only the
            unsquashed setting is supported by this entry point.
    """
    if not args.unsquash:
        raise NotImplementedError(
            "init_fn only supports args.unsquash=True"
        )

    # The NumPy generator (batch sampling) and the JAX key below are both
    # seeded from args.seed.
    numpy_rng = default_rng(args.seed)
    actor, qcritic, vcritic = init_behavior_fn(args=args, data=data.data)
    bppo_dataset = prepare_bppo_dataset(
        actor=actor, observations=data.data.observations, vcritic=vcritic
    )
    bppo_loader = DataLoader(
        bppo_dataset, batch_size=args.batch_size, drop_last=True, rng=numpy_rng
    )
    policy = BPPOPolicy(actor=actor, critic=qcritic)
    train_state = TrainState(
        model=policy,
        optimizer=get_optimizer(
            # Only the actor is optimized; the critic is passed to the policy
            # but not to the optimizer.
            policy.actor,
            # NOTE(review): presumably 1 means an optimizer step per batch,
            # while len(bppo_loader) accumulates over a full pass -- confirm
            # against get_optimizer.
            every_k_schedule=1 if args.stochastic_update else len(bppo_loader),
            learning_rate=args.learning_rate,
            max_gradient_norm=args.max_gradient_norm,
        ),
    )
    # Split into a static graph definition and a state pytree.
    graphdef, graphstate = nnx.split(train_state)

    return TrainerState(
        bppo_loader=bppo_loader,
        eval_state=None,
        graphdef=graphdef,
        graphstate=graphstate,
        train_key=jax.random.key(args.seed),
    )


def postprocess_args(**kwargs):
    """Return ``kwargs`` with ``parent_logger`` set to ``kwargs["logger"].parent``.

    An existing ``parent_logger`` entry is overwritten. Raises ``KeyError``
    when no ``logger`` keyword is supplied.
    """
    return {**kwargs, "parent_logger": kwargs["logger"].parent}


if __name__ == "__main__":
    # Parse the command line, attach the parent logger, then hand everything
    # to the shared runner together with this module's hooks.
    cli_namespace = build_argument_parser().parse_args()
    run_kwargs = postprocess_args(**vars(cli_namespace))
    base.run(
        arguments_class=Arguments,
        init_fn=init_fn,
        skip_reward_normalization=True,
        train_fn=train_fn,
        **run_kwargs,
    )
