from argparse import ArgumentParser
from pathlib import Path
from typing import Any

import gymnasium
import jax
import numpy as np
import tensorflow
from tqdm import tqdm

from offline.lbp.tc.modules.auto_encoder import TCAutoEncoder
from offline.lbp.tc.train import compute_assignments
from offline.lbp.tc.utils import prepare_tc_dataset
from offline.modules.mlp import MLP
from offline.types import OfflineDataWithInfos
from offline.utils.dataset import unsquash_actions
from offline.utils.logger import Logger
from offline.utils.nnx import compute_mlp_outputs, default_nnx_rngs
from offline.utils.parser import RESERVED_KEYWORDS


# The `tensorflow` module object stays loaded after this statement; only the
# module-level name is removed, because nothing below refers to it.
# NOTE(review): the reason this import is required at all (presumably a side
# effect of importing it before JAX/the project modules) is not visible here —
# confirm before removing the import.
del tensorflow  # name is unused below


def build_argument_parser():
    """Create the command-line parser for this analysis script.

    Arguments:
        --cpu: make the first CPU device JAX's default device.
        -s / --silent: run the analysis without writing ``results.npz``.
        paths: one or more directories searched recursively for runs.
    """
    cli = ArgumentParser()
    # Both boolean switches share the same store_true behaviour.
    for flags in (("--cpu",), ("-s", "--silent")):
        cli.add_argument(*flags, action="store_true")
    cli.add_argument("paths", nargs="+")
    return cli


def load_fn(
    action_dim: int,
    deterministic_reward: bool,
    deterministic_transition: bool,
    hidden_features: int,
    latent_dim: int,
    layer_norm: bool,
    num_blocks: int,
    num_layers: int,
    logger: Logger,
    observation_dim: int,
    observation_embedding_dim: int,
    reward_embedding_dim: int,
    reward_weight: float,
    ssm_base_size: int,
    tc_hidden_features: int,
    transition_weight: float,
    use_next_observation: bool,
    **kwargs,
):
    """Restore the classifier, high-level Q-critic and TC auto-encoder of a run.

    The codebook size is read from ``codebook_size.toml`` via ``logger``. It
    sets the output width of both MLPs and the auto-encoder's codebook.
    Any extra keyword arguments are accepted and ignored, so a run's full
    saved-argument dict can be splatted in.

    Returns:
        ``(classifier, qcritic, tc_autoencoder)``, each obtained from
        ``logger.restore_model`` under the names ``"classifier"``,
        ``"high_level_qcritic"`` and ``"tc_autoencoder"``.
    """
    del kwargs  # unrelated run arguments are intentionally ignored

    codebook_size: int = logger.load_toml("codebook_size.toml")["codebook_size"]

    def build_mlp():
        # Classifier and Q-critic share one architecture:
        # observation_dim inputs -> codebook_size outputs.
        return MLP(
            in_features=observation_dim,
            out_features=codebook_size,
            hidden_features=hidden_features,
            num_layers=num_layers,
            layer_norm=layer_norm,
            rngs=default_nnx_rngs(0),
        )

    classifier, qcritic = (
        logger.restore_model(name, model_fn=build_mlp)
        for name in ("classifier", "high_level_qcritic")
    )

    def build_tc_autoencoder():
        # decode_reward / decode_transition are switched on only when the
        # matching *_weight argument is positive.
        # NOTE(review): clip_eigenvalues, decay and the timescale bounds are
        # hard-coded; presumably they must match the training configuration.
        return TCAutoEncoder(
            observation_dim=observation_dim,
            action_dim=action_dim,
            latent_dim=latent_dim,
            codebook_size=codebook_size,
            hidden_features=tc_hidden_features,
            num_blocks=num_blocks,
            ssm_base_size=ssm_base_size,
            observation_embedding_dim=observation_embedding_dim,
            reward_embedding_dim=reward_embedding_dim,
            decode_reward=reward_weight > 0,
            decode_transition=transition_weight > 0,
            deterministic_reward=deterministic_reward,
            deterministic_transition=deterministic_transition,
            use_next_observation=use_next_observation,
            clip_eigenvalues=False,
            decay=0.99,
            min_timescale=0.001,
            max_timescale=0.1,
            rngs=default_nnx_rngs(0),
        )

    tc_autoencoder = logger.restore_model(
        "tc_autoencoder", model_fn=build_tc_autoencoder
    )
    return classifier, qcritic, tc_autoencoder


def analyze(logger: Logger, arguments: dict[str, Any]) -> dict[str, np.ndarray]:
    """Run the restored models of one training run over its offline dataset.

    Args:
        logger: Logger rooted at the run directory. It is used to load
            ``stats.npz`` (when observations are normalized) and, through
            ``load_fn``, the codebook size and the saved models.
        arguments: The run's saved arguments. This function reads ``env``,
            ``normalize_observations`` and ``unsquash``. The whole dict is
            also forwarded to ``load_fn`` as keyword arguments.

    Returns:
        A dict with:
        - ``assignments``: output of ``compute_assignments``;
        - ``dones``: ``terminals OR timeouts``;
        - ``logits``: classifier outputs for every observation;
        - ``qvalues``: Q-critic outputs for every observation;
        - ``rewards``: the dataset rewards;
        - ``sources``: only when the dataset infos contain that key.
    """
    env = gymnasium.make(arguments["env"])
    # The environment is only needed to load the dataset, but it must still be
    # released. main() calls analyze() once per run, so unclosed envs would
    # accumulate. Close it only after every use of the loaded data is done.
    try:
        data_with_info: OfflineDataWithInfos
        data_with_info = env.unwrapped.load_data(normalize=False)  # type: ignore
        raw = data_with_info.data

        if arguments["normalize_observations"]:
            # Re-apply the mean/std statistics saved alongside the run.
            stats = logger.load_numpy("stats.npz")
            mean, std = stats["mean"], stats["std"]
            observations = (raw.observations - mean) / std
        else:
            observations = raw.observations
        # NOTE(review): only `observations` is normalized. If the data tuple
        # also carries next observations, they stay unnormalized; confirm
        # that prepare_tc_dataset does not read them.
        if arguments["unsquash"]:
            actions = unsquash_actions(raw.actions)
        else:
            actions = raw.actions
        data = raw._replace(actions=actions, observations=observations)

        classifier, qcritic, tc_autoencoder = load_fn(
            action_dim=data.actions.shape[1],
            logger=logger,
            observation_dim=data.observations.shape[1],
            **arguments,
        )
        tc_dataset = prepare_tc_dataset(
            ant_maze=arguments["env"].startswith("antmaze"),
            data=data,
            filter_trajectories=False,
        )
        assignments = compute_assignments(
            dataset=tc_dataset, tc_autoencoder=tc_autoencoder
        )
        logits = compute_mlp_outputs(classifier, data.observations)
        qvalues = compute_mlp_outputs(qcritic, data.observations)
        results = {
            "assignments": assignments,
            "dones": np.logical_or(data.terminals, data.timeouts),
            "logits": logits,
            "qvalues": qvalues,
            "rewards": data.rewards,
        }
        try:
            results["sources"] = data_with_info.infos["sources"]
        except KeyError:
            pass  # "sources" is optional; leave it out when absent
        return results
    finally:
        env.close()


def main(cpu: bool, paths: list[str], silent: bool):
    """Analyze every matching training run found below ``paths``.

    Each directory below ``paths`` that contains an ``arguments.toml`` whose
    recorded main module is ``"offline.bppo.tc.__main__"`` is analyzed once
    with ``analyze``. Runs without that key, or with a different main module,
    are skipped. Unless ``silent`` is set, results are written to
    ``results.npz`` through the run's logger.

    Args:
        cpu: If true, make the first CPU device JAX's default device.
        paths: Root directories to search recursively for ``arguments.toml``.
        silent: If true, run the analysis but do not save any results.
    """
    if cpu:
        jax.config.update("jax_default_device", jax.devices("cpu")[0])

    seen_run_dirs: set[Path] = set()
    runs: list[tuple[Logger, dict[str, Any]]] = []
    for path in paths:
        for argfile in Path(path).rglob("arguments.toml"):
            # Overlapping or repeated `paths` would otherwise yield the same
            # run several times and analyze/overwrite it repeatedly.
            run_dir = argfile.parent.resolve()
            if run_dir in seen_run_dirs:
                continue
            seen_run_dirs.add(run_dir)

            logger = Logger(root=argfile.parent)
            arguments = logger.load_args()
            # `.get` lets a foreign arguments.toml without the key be skipped
            # instead of aborting the whole scan with a KeyError.
            # NOTE(review): the filter targets `offline.bppo.tc` while this
            # file imports from `offline.lbp.tc`; confirm this is intended.
            if arguments.get(RESERVED_KEYWORDS.MAIN) == "offline.bppo.tc.__main__":
                runs.append((logger, arguments))

    for logger, arguments in tqdm(runs):
        results = analyze(logger, arguments)
        if not silent:
            logger.save_numpy("results.npz", **results)


if __name__ == "__main__":
    # Parse the CLI and forward each option to main() as a keyword argument.
    cli_args = build_argument_parser().parse_args()
    main(**vars(cli_args))
