from dataclasses import dataclass

from flax import nnx
import jax.random
from numpy.random import default_rng
from scipy.stats import chi2

from offline import base
from offline.lbp.finetune import postprocess_args
from offline.lbp.tc import arguments as lbp_tc
from offline.lbp.tc.__main__ import load_fn
from offline.lbp.tc.modules import (
    ActorCriticFilter,
    ActorFilter,
    LBPTCPolicy,
    LBPTCTrainState,
)
from offline.lbp.tc.modules.auto_encoder import TCAutoEncoder
from offline.lbp.tc.train import TrainerState, train_fn
from offline.lbp.tc.utils import prepare_noise_q_learning_dataset
from offline.modules.base import TargetModel, get_optimizer
from offline.modules.mlp import MLP
from offline.types import OfflineData, OfflineDataWithInfos
from offline.utils.data import DataLoader
from offline.utils.logger import Logger
from offline.utils.nnx import default_nnx_rngs


# Only ``load_fn`` (imported above from ``offline.lbp.tc.__main__``) is
# declared public here; presumably so callers reuse the parent module's
# loader through this entry point -- confirm with the consumers of __all__.
__all__ = ["load_fn"]


@dataclass(frozen=True)
class Arguments(lbp_tc.Arguments):
    """``lbp_tc.Arguments`` extended with the logger of the parent run.

    ``init_behavior_fn`` reads ``codebook_size.toml`` from ``parent_logger``
    and restores the pretrained ``classifier``, ``tc_autoencoder`` and
    ``behavior_vcritic`` models through it.
    """

    # Logger of the run that produced the pretrained behavior models.
    parent_logger: Logger


def build_argument_parser(**kwargs):
    """Return the ``lbp_tc`` argument parser configured for this entry point.

    Delegates to ``lbp_tc.build_argument_parser`` with ``extra=True`` and the
    fixed ``fix_keys`` tuple below. Any ``kwargs`` are forwarded unchanged.

    Most of the pinned keys are architecture settings that
    ``init_behavior_fn`` uses to rebuild the pretrained models. Presumably
    they must therefore match the parent run; confirm against
    ``lbp_tc.build_argument_parser``.
    """
    pinned = (
        "deterministic_reward",
        "deterministic_transition",
        "env",
        "hidden_features",
        "latent_dim",
        "layer_norm",
        "normalize_observations",
        "num_blocks",
        "num_layers",
        "observation_embedding_dim",
        "reward_embedding_dim",
        "reward_weight",
        "ssm_base_size",
        "transition_weight",
        "use_next_observation",
    )
    return lbp_tc.build_argument_parser(extra=True, fix_keys=pinned, **kwargs)


def init_behavior_fn(args: Arguments, data: OfflineData):
    """Restore the pretrained behavior models through ``args.parent_logger``.

    The codebook size is read from ``codebook_size.toml``. Three models are
    then restored by name: ``"classifier"``, ``"tc_autoencoder"`` and
    ``"behavior_vcritic"``.

    Each ``restore_model`` call receives a factory that builds a freshly
    initialized module with seed 0. Presumably the logger uses it as a
    structural template and overwrites its parameters from the checkpoint;
    confirm against ``Logger``.

    Args:
        args: Arguments. The architecture fields read below presumably have
            to match the run the models were saved from.
        data: Offline data. Only ``observations.shape[1]`` and
            ``actions.shape[1]`` are read.

    Returns:
        ``(classifier, tc_autoencoder, vcritic)`` as restored by the logger.
    """
    parent = args.parent_logger
    codebook_size: int = parent.load_toml("codebook_size.toml")["codebook_size"]

    def build_codebook_mlp():
        # Shared by the classifier and the behavior V-critic: both use the
        # same MLP with one output per codebook entry.
        return MLP(
            hidden_features=args.hidden_features,
            in_features=data.observations.shape[1],
            layer_norm=args.layer_norm,
            num_layers=args.num_layers,
            out_features=codebook_size,
            rngs=default_nnx_rngs(0),
        )

    def build_tc_autoencoder():
        # The hard-coded settings (decay, timescales, eigenvalue clipping)
        # presumably mirror the configuration the parent run trained with.
        return TCAutoEncoder(
            action_dim=data.actions.shape[1],
            clip_eigenvalues=False,
            codebook_size=codebook_size,
            decay=0.99,
            decode_reward=args.reward_weight > 0,
            decode_transition=args.transition_weight > 0,
            deterministic_reward=args.deterministic_reward,
            deterministic_transition=args.deterministic_transition,
            hidden_features=args.tc_hidden_features,
            latent_dim=args.latent_dim,
            max_timescale=0.1,
            min_timescale=0.001,
            num_blocks=args.num_blocks,
            observation_dim=data.observations.shape[1],
            observation_embedding_dim=args.observation_embedding_dim,
            reward_embedding_dim=args.reward_embedding_dim,
            rngs=default_nnx_rngs(0),
            ssm_base_size=args.ssm_base_size,
            use_next_observation=args.use_next_observation,
        )

    classifier = parent.restore_model("classifier", model_fn=build_codebook_mlp)
    tc_autoencoder = parent.restore_model(
        "tc_autoencoder", model_fn=build_tc_autoencoder
    )
    vcritic = parent.restore_model(
        "behavior_vcritic", model_fn=build_codebook_mlp
    )
    # Presumably blocks until any pending logger work has completed.
    parent.wait()
    return classifier, tc_autoencoder, vcritic


def init_fn(args: Arguments, data: OfflineDataWithInfos) -> TrainerState:
    """Assemble the initial ``TrainerState`` for fine-tuning.

    Pipeline:

    1. Restore the behavior classifier, TC auto-encoder and behavior
       V-critic via ``init_behavior_fn``.
    2. Derive a noise dataset and a Q-learning dataset from the offline data
       with ``prepare_noise_q_learning_dataset``.
    3. Wrap the restored modules in an ``LBPTCPolicy``. Add an actor
       optimizer (filtered by ``ActorFilter``), a critic optimizer on
       ``policy.critic`` and a ``TargetModel``. Then ``nnx.split`` the
       resulting train state into graph definition and state.
    4. Attach endless batch iterators over both datasets, three PRNG keys
       and two scalar constants.

    Args:
        args: Parsed arguments. Only ``unsquash=True`` is supported.
        data: Offline data. Everything here reads ``data.data``.

    Returns:
        The new ``TrainerState``, with ``eval_state`` set to ``None``.

    Raises:
        NotImplementedError: If ``args.unsquash`` is false.
    """
    if not args.unsquash:
        raise NotImplementedError()

    offline = data.data
    numpy_rng = default_rng(args.seed)
    rngs = default_nnx_rngs(args.seed)

    # Pretrained behavior models and the datasets derived from them.
    classifier, tc_autoencoder, behavior_vcritic = init_behavior_fn(
        args=args, data=offline
    )
    noise_dataset, q_learning_dataset = prepare_noise_q_learning_dataset(
        classifier=classifier,
        critic=behavior_vcritic,
        data=offline,
        sparse=args.sparse,
        tc_autoencoder=tc_autoencoder,
        threshold=args.threshold,
    )

    action_dim = offline.actions.shape[1]

    # Policy, optimizers and target network, flattened for the trainer.
    policy = LBPTCPolicy(
        action_dim=action_dim,
        classifier=classifier,
        deterministic=args.deterministic,
        ensemble_size=args.ensemble_size,
        hidden_features=args.hidden_features,
        layer_norm=args.layer_norm,
        num_layers=args.num_layers,
        observation_dim=offline.observations.shape[1],
        rngs=rngs,
        tc_autoencoder=tc_autoencoder,
        threshold=args.threshold,
    )
    actor_optimizer = get_optimizer(
        policy, learning_rate=args.learning_rate, wrt=ActorFilter
    )
    critic_optimizer = get_optimizer(
        policy.critic, learning_rate=args.learning_rate
    )
    target_policy = TargetModel(policy, poi=ActorCriticFilter)
    graphdef, graphstate = nnx.split(
        LBPTCTrainState(
            actor_optimizer=actor_optimizer,
            critic_optimizer=critic_optimizer,
            policy=policy,
            target_policy=target_policy,
        )
    )

    # One key each for the actor, critic and regularizer updates.
    train_actor_key, train_critic_key, train_regularizer_key = jax.random.split(
        jax.random.key(args.seed), 3
    )

    def endless_batches(dataset):
        # Both loaders share ``numpy_rng``. The noise loader is created first,
        # matching the original order in case construction draws from it.
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            drop_last=True,
            rng=numpy_rng,
        ).repeat_forever()

    noise_iter = endless_batches(noise_dataset)
    q_learning_iter = endless_batches(q_learning_dataset)

    # r_min / (1 - gamma) is the discounted sum of an infinite stream of the
    # smallest observed reward.
    # NOTE(review): if rewards are positive and episodes can terminate,
    # actual returns can fall below this value. Confirm how train_fn uses it.
    min_target = float(offline.rewards.min()) / (1 - args.gamma)
    # Upper-tail quantile of a chi-square distribution with ``action_dim``
    # degrees of freedom at probability ``ood_probability``. Presumably this
    # is an out-of-distribution cutoff on a squared action distance; confirm.
    ood_threshold = float(chi2.isf(args.ood_probability, action_dim))

    return TrainerState(
        eval_state=None,
        graphdef=graphdef,
        graphstate=graphstate,
        min_target=min_target,
        noise_iter=noise_iter,
        ood_threshold=ood_threshold,
        q_learning_iter=q_learning_iter,
        train_actor_key=train_actor_key,
        train_critic_key=train_critic_key,
        train_regularizer_key=train_regularizer_key,
    )


if __name__ == "__main__":
    # Parse the CLI flags and pass them through postprocess_args. Then hand
    # them, together with this module's hooks, to the shared driver.
    cli_kwargs = postprocess_args(**vars(build_argument_parser().parse_args()))
    base.run(
        arguments_class=Arguments,
        init_fn=init_fn,
        train_fn=train_fn,
        **cli_kwargs,
    )
