from flax import nnx
import jax.random
from numpy.random import Generator, default_rng
from optax import cosine_decay_schedule
from scipy.stats import chi2

from offline import base
from offline.modules.actor.ensemble import GaussianActorEnsemble
from offline.modules.base import TargetModel, get_optimizer
from offline.modules.mlp import MLP
from offline.lbp.tc.arguments import Arguments, build_argument_parser
from offline.lbp.tc.modules.auto_encoder import TCAutoEncoder
from offline.lbp.tc.modules import (
    ActorCriticFilter,
    ActorFilter,
    LBPTCPolicy,
    LBPTCTrainState,
)
from offline.lbp.tc.train import (
    TrainerState,
    behavior_cloning_fn,
    compute_assignments,
    pretrain_critic_fn,
    train_classifier_fn,
    train_fn,
    train_tcae_fn,
    v_learning_fn,
)
from offline.lbp.tc.types import AssignmentBatch, SaBatch
from offline.lbp.tc.utils import (
    normalize_rewards,
    prepare_actor_q_learning_regularizer_dataset,
    prepare_v_learning_dataset,
    prepare_tc_dataset,
)
from offline.lbp.utils import get_max_min_reward
from offline.types import (
    FloatArray,
    IntArray,
    OfflineData,
    OfflineDataWithInfos,
)
from offline.utils.data import DataLoader, Dataset
from offline.utils.logger import ChildLogger, Logger
from offline.utils.nnx import default_nnx_rngs


def init_behavior_actor_fn(
    actions: FloatArray,
    assignments: IntArray,
    batch_size: int,
    bc_steps: int,
    codebook_size: int,
    hidden_features: int,
    layer_norm: bool,
    learning_rate: float,
    logger: Logger,
    num_layers: int,
    numpy_rng: Generator,
    observations: FloatArray,
    rngs: nnx.Rngs,
):
    """Create a ``GaussianActorEnsemble`` and train it with behavior cloning.

    The ensemble is built with ``ensemble_size=codebook_size``; its input and
    output widths are taken from the second axis of ``observations`` and
    ``actions``. Training batches are drawn from a ``SaBatch`` dataset through
    an endlessly repeating ``DataLoader`` driven by ``numpy_rng``, and
    ``behavior_cloning_fn`` runs for ``bc_steps`` steps.

    Returns:
        The actor returned by ``behavior_cloning_fn``, switched to eval mode.
    """
    sa_batch = SaBatch(
        actions=actions, assignments=assignments, observations=observations
    )
    sa_dataset = Dataset(sa_batch)

    ensemble = GaussianActorEnsemble(
        action_dim=actions.shape[1],
        ensemble_size=codebook_size,
        hidden_features=hidden_features,
        layer_norm=layer_norm,
        num_layers=num_layers,
        observation_dim=observations.shape[1],
        out_axis=-2,
        rngs=rngs,
    )

    # Optimizer first, loader second: same construction order as before.
    bc_optimizer = get_optimizer(model=ensemble, learning_rate=learning_rate)
    sa_loader = DataLoader(
        sa_dataset, batch_size=batch_size, drop_last=True, rng=numpy_rng
    )

    trained_actor = behavior_cloning_fn(
        actor=ensemble,
        logger=logger,
        optimizer=bc_optimizer,
        sa_iter=sa_loader.repeat_forever(),
        steps=bc_steps,
    )
    trained_actor.eval()
    return trained_actor


def init_classifier_fn(
    assignments: IntArray,
    batch_size: int,
    classifier_steps: int,
    codebook_size: int,
    hidden_features: int,
    layer_norm: bool,
    learning_rate: float,
    logger: Logger,
    num_layers: int,
    numpy_rng: Generator,
    observations: FloatArray,
    rngs: nnx.Rngs,
):
    """Create an MLP with ``codebook_size`` outputs and train it as classifier.

    The network takes observations (width ``observations.shape[1]``) as input.
    Training data is an ``AssignmentBatch`` dataset pairing ``observations``
    with ``assignments``, served by an endlessly repeating ``DataLoader``
    driven by ``numpy_rng``; ``train_classifier_fn`` runs for
    ``classifier_steps`` steps.

    Returns:
        The model returned by ``train_classifier_fn``, switched to eval mode.
    """
    labelled_batch = AssignmentBatch(
        assignments=assignments, observations=observations
    )
    labelled_dataset = Dataset(labelled_batch)

    network = MLP(
        hidden_features=hidden_features,
        in_features=observations.shape[1],
        layer_norm=layer_norm,
        num_layers=num_layers,
        out_features=codebook_size,
        rngs=rngs,
    )

    # Optimizer first, loader second: same construction order as before.
    network_optimizer = get_optimizer(
        model=network, learning_rate=learning_rate
    )
    labelled_loader = DataLoader(
        labelled_dataset, batch_size=batch_size, drop_last=True, rng=numpy_rng
    )

    trained_classifier = train_classifier_fn(
        classifier=network,
        logger=logger,
        optimizer=network_optimizer,
        steps=classifier_steps,
        classifier_iter=labelled_loader.repeat_forever(),
    )
    trained_classifier.eval()
    return trained_classifier


def init_vcritic_fn(
    assignments: IntArray,
    batch_size: int,
    codebook_size: int,
    data: OfflineData,
    hidden_features: int,
    gamma: float,
    layer_norm: bool,
    learning_rate: float,
    logger: Logger,
    num_layers: int,
    numpy_rng: Generator,
    rngs: nnx.Rngs,
    steps: int,
):
    """Create an MLP with ``codebook_size`` outputs and train it as V-critic.

    The training set comes from ``prepare_v_learning_dataset`` (built from
    ``assignments``, ``data`` and ``gamma``). The network input width is
    ``data.observations.shape[1]``. ``v_learning_fn`` runs for ``steps`` steps
    over an endlessly repeating ``DataLoader`` driven by ``numpy_rng``.

    Returns:
        The model returned by ``v_learning_fn``, switched to eval mode.
    """
    v_dataset = prepare_v_learning_dataset(
        assignments=assignments, data=data, gamma=gamma
    )

    network = MLP(
        hidden_features=hidden_features,
        in_features=data.observations.shape[1],
        layer_norm=layer_norm,
        num_layers=num_layers,
        out_features=codebook_size,
        rngs=rngs,
    )

    # Optimizer first, loader second: same construction order as before.
    network_optimizer = get_optimizer(
        model=network, learning_rate=learning_rate
    )
    v_loader = DataLoader(
        v_dataset, batch_size=batch_size, drop_last=True, rng=numpy_rng
    )

    trained_vcritic = v_learning_fn(
        logger=logger,
        optimizer=network_optimizer,
        steps=steps,
        v_learning_iter=v_loader.repeat_forever(),
        vcritic=network,
    )
    trained_vcritic.eval()
    return trained_vcritic


def init_trajectory_clustering_fn(
    action_dim: int,
    batch_size: int,
    clip_eigenvalues: bool,
    codebook_size: int,
    commitment_cost: float,
    data: OfflineData,
    decay: float,
    deterministic_reward: bool,
    deterministic_transition: bool,
    hidden_features: int,
    latent_dim: int,
    layer_norm: bool,
    learning_rate: float,
    logger: Logger,
    max_timescale: float,
    min_timescale: float,
    num_blocks: int,
    num_layers: int,
    numpy_rng: Generator,
    observation_dim: int,
    observation_embedding_dim: int,
    phase_split: float,
    reward_embedding_dim: int,
    reward_weight: float,
    reweight: bool,
    rngs: nnx.Rngs,
    ssm_base_size: int,
    steps: int,
    subsample_decodes: int,
    subsample_latents: int,
    threshold: float,
    transition_weight: float,
    use_next_observation: bool,
    **kwargs,
) -> tuple[IntArray, TCAutoEncoder]:
    """Train the trajectory-clustering autoencoder and compute assignments.

    A ``TCAutoEncoder`` is built from the architecture arguments, trained with
    ``train_tcae_fn`` on the dataset produced by ``prepare_tc_dataset`` (with
    ``filter_trajectories=False``), put into eval mode, and then used by
    ``compute_assignments`` on that same dataset.

    Args:
        codebook_size: Codebook size passed to the ``TCAutoEncoder``
            constructor. Callers that need the size after training should
            read ``codebook_size`` from the returned autoencoder.
        reward_weight: Passed to training; the reward decoder is enabled only
            when this is strictly positive.
        transition_weight: Passed to training; the transition decoder is
            enabled only when this is strictly positive.
        batch_size: Used for both training and assignment computation.
        **kwargs: Forwarded unchanged to the ``TCAutoEncoder`` constructor.

    The remaining arguments are forwarded by name to either the
    ``TCAutoEncoder`` constructor or ``train_tcae_fn``.

    Returns:
        ``(assignments, tc_autoencoder)``: the output of
        ``compute_assignments`` and the trained autoencoder in eval mode.
    """
    dataset = prepare_tc_dataset(data=data, filter_trajectories=False)
    tc_autoencoder = TCAutoEncoder(
        action_dim=action_dim,
        clip_eigenvalues=clip_eigenvalues,
        codebook_size=codebook_size,
        decay=decay,
        # A decoder head is only built when its loss term has non-zero weight.
        decode_reward=reward_weight > 0,
        decode_transition=transition_weight > 0,
        deterministic_reward=deterministic_reward,
        deterministic_transition=deterministic_transition,
        hidden_features=hidden_features,
        latent_dim=latent_dim,
        layer_norm=layer_norm,
        max_timescale=max_timescale,
        min_timescale=min_timescale,
        num_blocks=num_blocks,
        num_layers=num_layers,
        observation_dim=observation_dim,
        observation_embedding_dim=observation_embedding_dim,
        reward_embedding_dim=reward_embedding_dim,
        rngs=rngs,
        ssm_base_size=ssm_base_size,
        use_next_observation=use_next_observation,
        **kwargs,
    )
    tc_autoencoder = train_tcae_fn(
        batch_size=batch_size,
        commitment_cost=commitment_cost,
        dataset=dataset,
        logger=logger,
        optimizer=get_optimizer(
            model=tc_autoencoder, learning_rate=learning_rate
        ),
        phase_split=phase_split,
        reward_weight=reward_weight,
        reweight=reweight,
        rng=numpy_rng,
        steps=steps,
        subsample_decodes=subsample_decodes,
        subsample_latents=subsample_latents,
        tc_autoencoder=tc_autoencoder,
        threshold=threshold,
        transition_weight=transition_weight,
    )
    tc_autoencoder.eval()
    # The previous revision rebound the local ``codebook_size`` here, but the
    # value was never read again; the dead store has been removed.
    assignments = compute_assignments(
        batch_size=batch_size,
        dataset=dataset,
        tc_autoencoder=tc_autoencoder,
    )
    return assignments, tc_autoencoder


def init_behavior_fn(
    args: Arguments,
    assignments: IntArray,
    data: OfflineData,
    numpy_rng: Generator,
    rngs: nnx.Rngs,
    tc_autoencoder: TCAutoEncoder,
):
    """Train the behavior actor, the classifier and the behavior V-critic.

    The three models are trained one after another, in that order, sharing
    ``numpy_rng`` and ``rngs``. Each is sized with
    ``tc_autoencoder.codebook_size``. When ``args.save`` is set, each model is
    saved through ``args.logger`` right after it is trained, under the names
    ``"behavior_actor"``, ``"classifier"`` and ``"behavior_vcritic"``.

    Returns:
        ``(actor, classifier, vcritic)``.
    """
    # Keyword arguments accepted identically by all three initialisers.
    common = dict(
        assignments=assignments,
        batch_size=args.batch_size,
        codebook_size=tc_autoencoder.codebook_size,
        hidden_features=args.hidden_features,
        layer_norm=args.layer_norm,
        learning_rate=args.learning_rate,
        logger=args.logger,
        num_layers=args.num_layers,
        numpy_rng=numpy_rng,
        rngs=rngs,
    )

    actor = init_behavior_actor_fn(
        actions=data.actions,
        bc_steps=args.bc_steps,
        observations=data.observations,
        **common,
    )
    if args.save:
        args.logger.save_model("behavior_actor", model=actor)

    classifier = init_classifier_fn(
        classifier_steps=args.classifier_steps,
        observations=data.observations,
        **common,
    )
    if args.save:
        args.logger.save_model("classifier", model=classifier)

    vcritic = init_vcritic_fn(
        data=data,
        gamma=args.gamma,
        steps=args.v_learning_steps,
        **common,
    )
    if args.save:
        args.logger.save_model("behavior_vcritic", model=vcritic)

    return actor, classifier, vcritic


def init_fn(args: Arguments, data: OfflineDataWithInfos) -> TrainerState:
    """Run all pretraining stages and assemble the initial ``TrainerState``.

    Stages, in execution order:

    1. Train the trajectory-clustering autoencoder and compute codebook
       assignments (``init_trajectory_clustering_fn``).
    2. Determine reward bounds and, if ``args.normalize_rewards`` is set,
       rescale the rewards and the bounds by a common coefficient.
    3. Train the behavior actor, classifier and behavior V-critic
       (``init_behavior_fn``).
    4. Build the actor, Q-learning and regularizer datasets.
    5. Build the ``LBPTCPolicy``, optionally pretrain its critic, and create
       the optimizers and target model.
    6. Split the train state with ``nnx.split`` and bundle it with the data
       iterators and four JAX PRNG keys.

    A single ``numpy_rng`` and a single ``rngs`` object, both seeded from
    ``args.seed``, are passed to every stage and every ``DataLoader``, so the
    order of the statements below should not be changed casually.

    Args:
        args: Parsed run configuration.
        data: Offline dataset; only ``data.data`` is used here.

    Returns:
        The ``TrainerState`` consumed by training.

    Raises:
        NotImplementedError: If ``args.unsquash`` is falsy.
    """
    if not args.unsquash:
        raise NotImplementedError()

    # Both RNG sources are derived from the same seed and shared by all stages.
    numpy_rng = default_rng(args.seed)
    rngs = default_nnx_rngs(args.seed)

    # Stage 1: trajectory clustering. Uses its own batch size, hidden width,
    # decay, threshold and step count (the ``tc_*`` arguments).
    assignments, tc_autoencoder = init_trajectory_clustering_fn(
        action_dim=data.data.actions.shape[1],
        batch_size=args.tc_batch_size,
        clip_eigenvalues=args.clip_eigenvalues,
        codebook_size=args.codebook_size,
        commitment_cost=args.commitment_cost,
        data=data.data,
        decay=args.tc_decay,
        deterministic_reward=args.deterministic_reward,
        deterministic_transition=args.deterministic_transition,
        hidden_features=args.tc_hidden_features,
        latent_dim=args.latent_dim,
        layer_norm=args.layer_norm,
        learning_rate=args.learning_rate,
        logger=args.logger,
        max_timescale=args.max_timescale,
        min_timescale=args.min_timescale,
        num_blocks=args.num_blocks,
        num_layers=args.num_layers,
        numpy_rng=numpy_rng,
        observation_dim=data.data.observations.shape[1],
        observation_embedding_dim=args.observation_embedding_dim,
        phase_split=args.tc_phase_split,
        reward_embedding_dim=args.reward_embedding_dim,
        reward_weight=args.reward_weight,
        reweight=args.tc_reweight,
        rngs=rngs,
        ssm_base_size=args.ssm_base_size,
        steps=args.tc_steps,
        subsample_decodes=args.subsample_decodes,
        subsample_latents=args.subsample_latents,
        threshold=args.tc_threshold,
        transition_weight=args.transition_weight,
        use_next_observation=args.use_next_observation,
    )
    # The size is read back from the trained model instead of reusing
    # ``args.codebook_size`` -- presumably training can change it; TODO confirm.
    codebook_size = tc_autoencoder.codebook_size
    if args.save:
        args.logger.save_numpy("assignments.npz", data=assignments)
        # ``load_fn`` reads this file to rebuild models with the right size.
        args.logger.save_toml(
            "codebook_size.toml", obj={"codebook_size": codebook_size}
        )
        args.logger.save_model("tc_autoencoder", model=tc_autoencoder)

    # Stage 2: reward bounds. Use the bounds returned for the named dataset
    # when available, otherwise fall back to the empirical extremes.
    data_without_info = data.data
    max_reward, min_reward = get_max_min_reward(args.dataset)
    if max_reward is None:
        max_reward = float(data_without_info.rewards.max())
    if min_reward is None:
        min_reward = float(data_without_info.rewards.min())

    if args.normalize_rewards:
        # The same coefficient rescales the rewards and both bounds so that
        # they stay consistent with each other.
        coefficient = args.reward_multiplier * normalize_rewards(
            assignments=assignments, data=data.data, gamma=args.gamma
        )
        data_without_info = data_without_info._replace(
            rewards=data_without_info.rewards * coefficient
        )
        max_reward = coefficient * max_reward
        min_reward = coefficient * min_reward

    # Stage 3: behavior models, trained on the (possibly rescaled) data.
    behavior_actor, classifier, behavior_vcritic = init_behavior_fn(
        args=args,
        assignments=assignments,
        data=data_without_info,
        numpy_rng=numpy_rng,
        rngs=rngs,
        tc_autoencoder=tc_autoencoder,
    )
    # Stage 4: datasets for the three training streams. The offset is zero for
    # sparse-reward runs and the reward range otherwise.
    actor_dataset, q_learning_dataset, regularizer_dataset = (
        prepare_actor_q_learning_regularizer_dataset(
            actor=behavior_actor,
            assignments=assignments,
            classifier=classifier,
            critic=behavior_vcritic,
            data=data_without_info,
            offset=0 if args.sparse else (max_reward - min_reward),
            threshold=args.threshold,
        )
    )
    # One loader object backs both critic pretraining and the main
    # Q-learning iterator created further down.
    q_learning_loader = DataLoader(
        q_learning_dataset,
        batch_size=args.batch_size,
        drop_last=True,
        rng=numpy_rng,
    )
    # Stage 5: policy, optimizers and target model.
    policy = LBPTCPolicy(
        action_dim=data_without_info.actions.shape[1],
        behavior_actor=behavior_actor,
        classifier=classifier,
        deltas_multiplier=args.deltas_multiplier,
        ensemble_size=args.ensemble_size,
        hidden_features=args.hidden_features,
        layer_norm=args.layer_norm,
        num_behaviors=codebook_size,
        num_layers=args.num_layers,
        observation_dim=data_without_info.observations.shape[1],
        rngs=rngs,
        threshold=args.threshold,
    )
    # Learning-rate schedule for the actor optimizer only: either the constant
    # rate or a cosine decay over ``total_steps // update_every`` steps.
    if args.constant_schedule:
        schedule = args.learning_rate
    else:
        schedule = cosine_decay_schedule(
            args.learning_rate, args.total_steps // args.update_every
        )
    if args.pretrain_steps > 0:
        # Pretraining gets its own optimizer; a fresh critic optimizer is
        # created for the train state below.
        policy.critic = pretrain_critic_fn(
            gamma=args.gamma,
            logger=args.logger,
            optimizer=get_optimizer(
                policy.critic,
                learning_rate=args.learning_rate,
                max_gradient_norm=args.max_gradient_norm,
            ),
            qcritic=policy.critic,
            q_learning_iter=q_learning_loader.repeat_forever(),
            steps=args.pretrain_steps,
            tau=args.tau,
            update_every=args.update_every,
        )
    train_state = LBPTCTrainState(
        actor_optimizer=get_optimizer(
            policy,
            learning_rate=schedule,
            max_gradient_norm=args.max_gradient_norm,
            wrt=ActorFilter,
        ),
        # The critic optimizer uses the constant learning rate, not `schedule`.
        critic_optimizer=get_optimizer(
            policy.critic,
            learning_rate=args.learning_rate,
            max_gradient_norm=args.max_gradient_norm,
        ),
        policy=policy,
        target_policy=TargetModel(policy, poi=ActorCriticFilter),
    )
    # Stage 6: split the train state into its graph definition and state.
    graphdef, graphstate = nnx.split(train_state)
    # Four independent JAX keys derived from the run seed.
    (
        train_actor_key,
        train_critic_key,
        train_regularizer_key_noise,
        train_regularizer_key_policy,
    ) = jax.random.split(jax.random.key(args.seed), 4)

    return TrainerState(
        actor_iter=DataLoader(
            actor_dataset,
            batch_size=args.batch_size,
            drop_last=True,
            rng=numpy_rng,
        ).repeat_forever(),
        eval_state=None,
        graphdef=graphdef,
        graphstate=graphstate,
        # min_reward / (1 - gamma) is the discounted sum of receiving
        # ``min_reward`` at every step forever. NOTE(review): this divides by
        # zero when gamma == 1.
        min_target=min_reward / (1 - args.gamma),
        # Chi-squared inverse survival function at ``ood_probability`` with
        # degrees of freedom equal to the action dimension.
        ood_threshold=float(
            chi2.isf(args.ood_probability, data_without_info.actions.shape[1])
        ),
        q_learning_iter=q_learning_loader.repeat_forever(),
        regularizer_iter=DataLoader(
            regularizer_dataset,
            batch_size=args.batch_size,
            drop_last=True,
            rng=numpy_rng,
        ).repeat_forever(),
        train_actor_key=train_actor_key,
        train_critic_key=train_critic_key,
        train_regularizer_key_noise=train_regularizer_key_noise,
        train_regularizer_key_policy=train_regularizer_key_policy,
    )


def load_fn(
    step: int | None,
    action_dim: int,
    deltas_multiplier: float,
    ensemble_size: int,
    hidden_features: int,
    layer_norm: bool,
    num_layers: int,
    logger: Logger,
    observation_dim: int,
    threshold: float,
    **kwargs,
):
    """Restore a saved ``LBPTCPolicy`` together with its frozen sub-models.

    The codebook size, the classifier and the behavior actor are read from the
    base logger (``logger.parent`` when ``logger`` is a ``ChildLogger``,
    otherwise ``logger`` itself). The policy is then restored through
    ``base.default_load_fn`` using ``logger`` and ``step``, and the restored
    classifier and behavior actor are attached to it afterwards.

    Extra keyword arguments are accepted and ignored.

    Returns:
        ``(policy, None)``.
    """
    del kwargs

    base_logger = logger.parent if isinstance(logger, ChildLogger) else logger

    num_behaviors: int = base_logger.load_toml("codebook_size.toml")[
        "codebook_size"
    ]

    # Architecture settings shared by every network rebuilt below.
    network_kwargs = dict(
        hidden_features=hidden_features,
        layer_norm=layer_norm,
        num_layers=num_layers,
    )

    def build_classifier():
        return MLP(
            in_features=observation_dim,
            out_features=num_behaviors,
            rngs=default_nnx_rngs(0),
            **network_kwargs,
        )

    def build_behavior_actor():
        return GaussianActorEnsemble(
            action_dim=action_dim,
            ensemble_size=num_behaviors,
            observation_dim=observation_dim,
            out_axis=-2,
            rngs=default_nnx_rngs(0),
            **network_kwargs,
        )

    # Restore order is kept: classifier first, then behavior actor.
    classifier = base_logger.restore_model(
        "classifier", model_fn=build_classifier
    )
    behavior_actor = base_logger.restore_model(
        "behavior_actor", model_fn=build_behavior_actor
    )

    def build_policy():
        return LBPTCPolicy(
            action_dim=action_dim,
            behavior_actor=behavior_actor,
            classifier=classifier,
            deltas_multiplier=deltas_multiplier,
            ensemble_size=ensemble_size,
            num_behaviors=num_behaviors,
            observation_dim=observation_dim,
            rngs=default_nnx_rngs(0),
            threshold=threshold,
            **network_kwargs,
        )

    policy = base.default_load_fn(
        logger=logger, model_fn=build_policy, poi=LBPTCPolicy.POI, step=step
    )
    # Re-attach the separately restored sub-models to the loaded policy.
    policy.behavior_actor = behavior_actor
    policy.classifier = classifier
    return policy, None


if __name__ == "__main__":
    # Parse the command line into a plain dict, then pass every parsed option
    # to the shared runner alongside this module's fixed entry points.
    cli_options = vars(build_argument_parser().parse_args())
    base.run(
        arguments_class=Arguments,
        init_fn=init_fn,
        skip_reward_normalization=True,
        train_fn=train_fn,
        **cli_options,
    )
