from argparse import ArgumentParser
from pathlib import Path
from typing import Any

import jax
import numpy as np
from tqdm import tqdm

from offline.envs.registration import make_env_and_load_data
from offline.lbp.tc.utils import compute_mask_means_stds_values
from offline.modules.actor.ensemble import GaussianActorEnsemble
from offline.modules.mlp import MLP
from offline.types import OfflineDataWithInfos
from offline.utils.dataset import unsquash_actions
from offline.utils.logger import Logger
from offline.utils.nnx import compute_mlp_outputs, default_nnx_rngs
from offline.utils.parser import RESERVED_KEYWORDS


def build_argument_parser():
    """Create the command-line parser for this script.

    Recognised arguments:
        --cpu: boolean flag, ``False`` unless given.
        -s / --silent: boolean flag, ``False`` unless given.
        paths: one or more positional strings.

    Returns:
        The configured ``ArgumentParser``.
    """
    cli = ArgumentParser()
    # Both optional switches are plain store-true flags; register them
    # in the same order they should appear in the help text.
    for option_names in (("--cpu",), ("-s", "--silent")):
        cli.add_argument(*option_names, action="store_true")
    cli.add_argument("paths", nargs="+")
    return cli


def load_fn(
    action_dim: int,
    hidden_features: int,
    layer_norm: bool,
    num_layers: int,
    logger: Logger,
    observation_dim: int,
    **kwargs,
):
    """Restore the behavior actor, classifier and value critic of a run.

    The codebook size is read from ``codebook_size.toml`` through
    ``logger``; it is used as the actor's ensemble size and as the output
    width of both MLPs.

    Args:
        action_dim: Action dimensionality given to the actor ensemble.
        hidden_features: Hidden width shared by all three networks.
        layer_norm: Layer-norm switch shared by all three networks.
        num_layers: Layer count shared by all three networks.
        logger: Run logger used to load the TOML file and restore models.
        observation_dim: Input dimensionality of all three networks.
        **kwargs: Accepted so callers can splat a full argument dict;
            discarded without being read.

    Returns:
        A ``(behavior_actor, classifier, vcritic)`` tuple of restored
        models.
    """
    del kwargs

    codebook_size: int = logger.load_toml("codebook_size.toml")["codebook_size"]

    # Hyperparameters that every architecture below has in common.
    shared_hparams = dict(
        hidden_features=hidden_features,
        layer_norm=layer_norm,
        num_layers=num_layers,
    )

    def make_actor_ensemble():
        return GaussianActorEnsemble(
            action_dim=action_dim,
            ensemble_size=codebook_size,
            observation_dim=observation_dim,
            out_axis=-2,
            rngs=default_nnx_rngs(0),
            **shared_hparams,
        )

    def make_codebook_mlp():
        return MLP(
            in_features=observation_dim,
            out_features=codebook_size,
            rngs=default_nnx_rngs(0),
            **shared_hparams,
        )

    restore = logger.restore_model
    behavior_actor = restore("behavior_actor", model_fn=make_actor_ensemble)
    # The classifier and the value critic use the same MLP template.
    classifier = restore("classifier", model_fn=make_codebook_mlp)
    vcritic = restore("behavior_vcritic", model_fn=make_codebook_mlp)
    return behavior_actor, classifier, vcritic


def analyze(logger: Logger, arguments: dict[str, Any]) -> dict[str, np.ndarray]:
    """Compute per-transition analysis arrays for a single run.

    The dataset named by ``arguments["dataset"]`` is loaded. Observations
    are standardised with the mean/std stored in ``stats.npz`` when
    ``arguments["normalize_observations"]`` is truthy, and actions go
    through ``unsquash_actions`` when ``arguments["unsquash"]`` is truthy.
    The restored models are then evaluated on the resulting observations.

    Args:
        logger: Run logger used to load saved arrays and restore models.
        arguments: The run's argument dict; also forwarded to ``load_fn``.

    Returns:
        A dict with keys ``assignments``, ``dones`` (element-wise OR of
        the dataset's ``terminals`` and ``dones``), ``logits``, ``mask``,
        ``means``, ``rewards``, ``stds`` and ``values``; ``sources`` is
        added when the dataset infos provide it.
    """
    _, data_with_info, _ = make_env_and_load_data(arguments["dataset"])
    raw = data_with_info.data

    # Start from the raw arrays and transform only when requested.
    observations = raw.observations
    if arguments["normalize_observations"]:
        stats = logger.load_numpy("stats.npz")
        observations = (observations - stats["mean"]) / stats["std"]

    actions = raw.actions
    if arguments["unsquash"]:
        actions = unsquash_actions(actions)

    data = raw._replace(actions=actions, observations=observations)

    behavior_actor, classifier, vcritic = load_fn(
        action_dim=data.actions.shape[1],
        logger=logger,
        observation_dim=data.observations.shape[1],
        **arguments,
    )

    assignments = logger.load_numpy("assignments.npz")["data"]
    logits = compute_mlp_outputs(classifier, data.observations)
    mask, means, stds, values = compute_mask_means_stds_values(
        actor=behavior_actor,
        classifier=classifier,
        critic=vcritic,
        observations=data.observations,
        threshold=arguments["threshold"],
    )

    results = dict(
        assignments=assignments,
        dones=np.logical_or(data.terminals, data.dones),
        logits=logits,
        mask=mask,
        means=means,
        rewards=data.rewards,
        stds=stds,
        values=values,
    )

    # "sources" is optional: a missing key simply leaves it out.
    try:
        sources = data_with_info.infos["sources"]
    except KeyError:
        pass
    else:
        results["sources"] = sources
    return results


def main(cpu: bool, paths: list[str], silent: bool):
    """Analyze every matching run found under the given paths.

    Each path is searched recursively for ``arguments.toml`` files. A run
    is kept when the value stored under ``RESERVED_KEYWORDS.MAIN`` in its
    loaded arguments contains ``".tc.__main__"``. All runs are collected
    before any analysis starts.

    Args:
        cpu: When true, set JAX's default device to the first CPU device.
        paths: Root directories to search.
        silent: When true, results are computed but not written;
            otherwise they are saved as ``results.npz`` via the run's
            logger.
    """
    if cpu:
        cpu_device = jax.devices("cpu")[0]
        jax.config.update("jax_default_device", cpu_device)

    # Discovery pass: pair each selected run's logger with its arguments.
    selected_runs = []
    for root in paths:
        for args_file in Path(root).rglob("arguments.toml"):
            run_logger = Logger(root=args_file.parent)
            run_arguments = run_logger.load_args()
            if ".tc.__main__" not in run_arguments[RESERVED_KEYWORDS.MAIN]:
                continue
            selected_runs.append((run_logger, run_arguments))

    # Analysis pass, with one progress-bar tick per run.
    for run_logger, run_arguments in tqdm(selected_runs):
        results = analyze(run_logger, run_arguments)
        if silent:
            continue
        run_logger.save_numpy("results.npz", **results)


if __name__ == "__main__":
    # Parse the command line, then pass the parsed fields to main() as
    # keyword arguments.
    cli_namespace = build_argument_parser().parse_args()
    main(**vars(cli_namespace))
