from dataclasses import dataclass

from flax import nnx
import jax.random
from numpy.random import default_rng
from optax import cosine_decay_schedule
from scipy.stats import chi2

from offline import base
from offline.lbp import arguments as lbp
from offline.lbp.__main__ import load_fn
from offline.lbp.modules import (
    ActorCriticFilter,
    ActorFilter,
    LBPPolicy,
    LBPTrainState,
)
from offline.lbp.train import TrainerState, pretrain_critic_fn, train_fn
from offline.lbp.utils import (
    get_max_min_reward,
    normalize_rewards,
    prepare_actor_regularizer_qlearning_dataset,
)
from offline.modules.actor.base import GaussianActor
from offline.modules.base import TargetModel, get_optimizer
from offline.modules.critic import VCritic
from offline.types import OfflineData, OfflineDataWithInfos
from offline.utils.data import DataLoader
from offline.utils.logger import Logger
from offline.utils.nnx import default_nnx_rngs


__all__ = ["load_fn"]


@dataclass(frozen=True)
class Arguments(lbp.Arguments):
    """LBP arguments extended with the logger of the parent run.

    Instances are immutable (``frozen=True``).
    """

    # Parent of the CLI ``logger`` (filled in by ``postprocess_args``).
    # ``init_behavior_fn`` restores the behavior actor and value critic from it.
    parent_logger: Logger


def build_argument_parser(**kwargs):
    """Return the LBP argument parser configured for this entry point.

    Thin wrapper over ``lbp.build_argument_parser``. It always passes
    ``extra=True`` and the ``fix_keys`` tuple below, and forwards any other
    keyword arguments unchanged.
    """
    # NOTE(review): presumably these options are pinned to the parent run's
    # configuration, since the behavior models are restored from that run.
    # Confirm against lbp.build_argument_parser.
    fixed_keys = (
        "dataset",
        "hidden_features",
        "layer_norm",
        "normalize_observations",
        "normalize_rewards",
        "num_layers",
        "reward_multiplier",
    )
    return lbp.build_argument_parser(extra=True, fix_keys=fixed_keys, **kwargs)


def init_behavior_fn(args: Arguments, data: OfflineData):
    """Restore the behavior actor and value critic saved by the parent run.

    Both models are restored through ``args.parent_logger.restore_model``,
    under the names ``"behavior_actor"`` and ``"behavior_vcritic"``. Each
    ``model_fn`` builds a freshly initialised module (seed 0) whose
    architecture follows ``args`` and the observation/action widths of
    ``data``. The function waits on the parent logger before returning.

    Args:
        args: Parsed arguments; supplies the architecture and ``parent_logger``.
        data: Offline transitions; only the second dimension of
            ``observations``/``actions`` is read.

    Returns:
        ``(actor, vcritic)``: the restored ``GaussianActor`` and ``VCritic``.
    """
    architecture = {
        "hidden_features": args.hidden_features,
        "layer_norm": args.layer_norm,
        "num_layers": args.num_layers,
    }

    def build_actor():
        return GaussianActor(
            action_dim=data.actions.shape[1],
            observation_dim=data.observations.shape[1],
            rngs=default_nnx_rngs(0),
            **architecture,
        )

    def build_vcritic():
        return VCritic(
            observation_dim=data.observations.shape[1],
            rngs=default_nnx_rngs(0),
            **architecture,
        )

    parent = args.parent_logger
    restored_actor = parent.restore_model("behavior_actor", model_fn=build_actor)
    restored_vcritic = parent.restore_model("behavior_vcritic", model_fn=build_vcritic)
    parent.wait()
    return restored_actor, restored_vcritic


def _rewards_and_bounds(args: Arguments, data: OfflineData):
    """Return ``(data, max_reward, min_reward)`` after optional reward scaling.

    The bounds come from ``get_max_min_reward(args.dataset)``. Any bound it
    reports as ``None`` falls back to the max/min reward observed in ``data``.
    When ``args.normalize_rewards`` is set, the rewards and both bounds are
    multiplied by ``args.reward_multiplier * normalize_rewards(...)``.
    """
    max_reward, min_reward = get_max_min_reward(args.dataset)
    if max_reward is None:
        max_reward = float(data.rewards.max())
    if min_reward is None:
        min_reward = float(data.rewards.min())

    if args.normalize_rewards:
        # The coefficient is computed on the unscaled data, before
        # ``rewards`` is replaced below.
        coefficient = args.reward_multiplier * normalize_rewards(
            data=data, gamma=args.gamma
        )
        data = data._replace(rewards=coefficient * data.rewards)
        max_reward = coefficient * max_reward
        min_reward = coefficient * min_reward

    return data, max_reward, min_reward


def init_fn(args: Arguments, data: OfflineDataWithInfos) -> TrainerState:
    """Build the initial ``TrainerState`` for LBP training from a parent run.

    Steps:

    1. Compute reward bounds and optionally rescale the rewards.
    2. Restore the behavior actor/critic from ``args.parent_logger``. Use them
       to derive the actor, regularizer and Q-learning datasets.
    3. Build the ``LBPPolicy``. If ``args.pretrain_steps > 0``, pretrain its
       critic first. Wrap everything in an ``LBPTrainState`` with a target copy.
    4. Split the train state with ``nnx.split``, then create the infinite
       batch iterators and the PRNG keys.

    Args:
        args: Parsed arguments. Only ``args.unsquash=True`` is supported.
        data: Offline data; only ``data.data`` is used.

    Returns:
        The initial trainer state.

    Raises:
        NotImplementedError: If ``args.unsquash`` is false.
    """
    if not args.unsquash:
        raise NotImplementedError("init_fn only supports args.unsquash=True")

    data_without_info, max_reward, min_reward = _rewards_and_bounds(args, data.data)

    numpy_rng = default_rng(args.seed)
    behavior_actor, behavior_vcritic = init_behavior_fn(
        args=args, data=data_without_info
    )
    actor_dataset, regularizer_dataset, q_learning_dataset = (
        prepare_actor_regularizer_qlearning_dataset(
            actor=behavior_actor,
            critic=behavior_vcritic,
            data=data_without_info,
            offset=0 if args.sparse else (max_reward - min_reward),
        )
    )

    def make_loader(dataset):
        # Every loader shares ``numpy_rng``. The creation and iteration order
        # below matches the original so seeded runs stay reproducible, in case
        # DataLoader draws from the rng.
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            drop_last=True,
            rng=numpy_rng,
        )

    # Created up front because it feeds both critic pretraining and training.
    q_learning_loader = make_loader(q_learning_dataset)

    policy = LBPPolicy(
        action_dim=data_without_info.actions.shape[1],
        behavior_actor=behavior_actor,
        ensemble_size=args.ensemble_size,
        hidden_features=args.hidden_features,
        layer_norm=args.layer_norm,
        num_layers=args.num_layers,
        observation_dim=data_without_info.observations.shape[1],
        rngs=default_nnx_rngs(args.seed),
        zero_mean=args.zero_mean,
    )

    # The schedule applies to the actor optimizer only; the critic optimizers
    # below use the constant ``args.learning_rate``.
    if args.constant_schedule:
        schedule = args.learning_rate
    else:
        schedule = cosine_decay_schedule(
            args.learning_rate, args.total_steps // args.update_every
        )

    if args.pretrain_steps > 0:
        policy.critic = pretrain_critic_fn(
            gamma=args.gamma,
            logger=args.logger,
            optimizer=get_optimizer(
                policy.critic,
                learning_rate=args.learning_rate,
                max_gradient_norm=args.max_gradient_norm,
            ),
            qcritic=policy.critic,
            q_learning_iter=q_learning_loader.repeat_forever(),
            steps=args.pretrain_steps,
            tau=args.tau,
            update_every=args.update_every,
        )

    # Built after pretraining, so the critic optimizer and the target copy
    # both see the (possibly pretrained) critic.
    train_state = LBPTrainState(
        actor_optimizer=get_optimizer(
            policy,
            learning_rate=schedule,
            max_gradient_norm=args.max_gradient_norm,
            wrt=ActorFilter,
        ),
        critic_optimizer=get_optimizer(
            policy.critic,
            learning_rate=args.learning_rate,
            max_gradient_norm=args.max_gradient_norm,
        ),
        policy=policy,
        target_policy=TargetModel(policy, poi=ActorCriticFilter),
    )
    graphdef, graphstate = nnx.split(train_state)
    train_actor_key, train_critic_key, train_regularizer_key = jax.random.split(
        jax.random.key(args.seed), 3
    )

    return TrainerState(
        actor_iter=make_loader(actor_dataset).repeat_forever(),
        eval_state=None,
        graphdef=graphdef,
        graphstate=graphstate,
        # min_reward / (1 - gamma): the discounted return of receiving
        # ``min_reward`` at every step forever.
        min_target=float(min_reward) / (1 - args.gamma),
        # Upper-tail chi-squared quantile at ``args.ood_probability`` with
        # action_dim degrees of freedom.
        ood_threshold=float(
            chi2.isf(args.ood_probability, data_without_info.actions.shape[1])
        ),
        q_learning_iter=q_learning_loader.repeat_forever(),
        regularizer_iter=make_loader(regularizer_dataset).repeat_forever(),
        train_actor_key=train_actor_key,
        train_critic_key=train_critic_key,
        train_regularizer_key=train_regularizer_key,
    )


def postprocess_args(**kwargs):
    """Add ``parent_logger`` (the parent of ``kwargs["logger"]``) to the kwargs.

    Returns:
        The collected keyword arguments with the ``parent_logger`` entry set.

    Raises:
        KeyError: If no ``logger`` keyword argument is given.
    """
    logger = kwargs["logger"]
    kwargs["parent_logger"] = logger.parent
    return kwargs


if __name__ == "__main__":
    # Parse the CLI, derive ``parent_logger``, then hand everything to the
    # shared runner. Rewards are not normalized by the runner here
    # (skip_reward_normalization=True).
    cli_kwargs = postprocess_args(**vars(build_argument_parser().parse_args()))
    base.run(
        arguments_class=Arguments,
        init_fn=init_fn,
        skip_reward_normalization=True,
        train_fn=train_fn,
        **cli_kwargs,
    )
