from argparse import ArgumentParser
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import Any

from flax import nnx
import gymnasium
import numpy as np

from offline import helper
from offline.envs.custom.utils import save_dataset
from offline.envs.registration import get_ref_scores, make_env_and_load_data
from offline.types import FloatArray
from offline.utils.logger import ChildLogger, Logger
from offline.utils.misc import robustify
from offline.utils.parser import RESERVED_KEYWORDS
from offline.utils.tqdm import tqdm


def build_argument_parser() -> ArgumentParser:
    """Build the command-line parser for this rollout script.

    Every destination produced by the parser matches a keyword parameter of
    ``main`` (``--env`` is stored under ``env_id``), so the parsed namespace
    can be splatted straight into ``main``.
    """
    cli = ArgumentParser()
    # Positional: directory of the training run whose policy is rolled out.
    cli.add_argument("path")
    cli.add_argument("--env", dest="env_id", default="")
    cli.add_argument("-f", "--force", action="store_true")
    cli.add_argument("--label", required=True)
    cli.add_argument("--num-samples", default=1_000_000, type=int)
    cli.add_argument("--seed", default=0, type=int)
    # Converter is wrapped by the project's ``robustify`` helper.
    cli.add_argument("--step", default=None, type=robustify(int))
    cli.add_argument("-s", "--silent", action="store_true")
    return cli


def main(
    env_id: str,
    force: bool,
    label: str,
    num_samples: int,
    path: str,
    seed: int,
    silent: bool,
    step: int | None,
):
    """Roll out a trained policy and, unless ``silent``, save it as a dataset.

    Args:
        env_id: Gymnasium id of the environment to roll out in. When empty,
            the environment is rebuilt from the run's stored ``"dataset"``
            argument, and that dataset's reference scores are forwarded to
            ``save_dataset``.
        force: Forwarded to ``save_dataset``.
        label: Forwarded to ``save_dataset``.
        num_samples: Minimum number of environment steps to collect
            (see ``rollout``).
        path: Directory of the training run whose stored arguments and
            checkpoints are loaded.
        seed: Seed for the first environment reset.
        silent: If true, the rollout is still performed but nothing is saved.
        step: Checkpoint step forwarded to the algorithm's ``load_fn``.

    Raises:
        ValueError: If the environment rebuilt from the dataset has no
            ``EnvSpec`` (its id is needed to save the dataset).
    """
    root = Path(path)
    logger = Logger(root=root)
    arguments = logger.load_args()
    # Runs that record a parent are loaded through a ChildLogger. Only the
    # dictionary lookup is expected to raise KeyError, so the constructor is
    # kept out of the ``try`` body to avoid masking unrelated errors.
    try:
        parent = arguments[RESERVED_KEYWORDS.PARENT]
    except KeyError:
        pass
    else:
        logger = ChildLogger(root=root, parent=Path(parent))

    from_dataset = not env_id
    if from_dataset:
        env, _, _ = make_env_and_load_data(arguments["dataset"])
    else:
        env = gymnasium.make(env_id)

    # Release the environment's resources on every exit path, including the
    # validation error below and failures during the rollout.
    try:
        max_score, min_score = None, None
        if from_dataset:
            if env.spec is None:
                raise ValueError(
                    "Cannot save datasets from envs with no EnvSpecs."
                )
            env_id = env.spec.id
            max_score, min_score = get_ref_scores(arguments["dataset"])

        results = rollout(
            arguments=arguments,
            env=env,
            logger=logger,
            num_samples=num_samples,
            seed=seed,
            step=step,
        )
    finally:
        env.close()

    if not silent:
        save_dataset(
            env_id=env_id,
            force=force,
            label=label,
            max_score=max_score,
            min_score=min_score,
            results=results,
        )


def rollout(
    arguments: dict[str, Any],
    env: gymnasium.Env,
    logger: Logger,
    num_samples: int,
    seed: int,
    step: int | None,
) -> dict[str, np.ndarray]:
    """Run the policy stored under ``logger`` in ``env`` and record each step.

    The module named by ``arguments[RESERVED_KEYWORDS.MAIN]`` must expose a
    ``load_fn`` that returns ``(policy, state)``. The policy is compiled into
    an act function by ``helper.compile_act``, and ``state`` is threaded
    through successive calls.

    Collection always finishes whole episodes. It stops at the first episode
    end at or after ``num_samples`` steps, so the result can be longer than
    ``num_samples``.

    Args:
        arguments: Stored training arguments. They must contain
            ``RESERVED_KEYWORDS.MAIN``, ``"normalize_observations"`` and
            ``"unsquash"``, and are forwarded in full to ``load_fn``.
        env: Environment with ``Box`` action and observation spaces.
        logger: Used to load ``"stats.npz"`` when observations were
            normalized, and passed to ``load_fn``.
        num_samples: Minimum number of environment steps to collect.
        seed: Seed for the first ``env.reset``. Later resets are unseeded.
        step: Checkpoint step forwarded to ``load_fn``.

    Returns:
        A mapping with the keys ``"observations"``, ``"actions"``, ``"dones"``,
        ``"rewards"``, ``"terminals"`` and ``"timeouts"``, plus one
        ``"info/<key>"`` entry per key seen in the act-function or
        environment info dicts. Each value is a NumPy array.

    Raises:
        TypeError: If the action or observation space is not a ``Box``.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(env.action_space, gymnasium.spaces.Box):
        raise TypeError(
            "Expected a Box action space, got "
            f"{type(env.action_space).__name__}."
        )
    if not isinstance(env.observation_space, gymnasium.spaces.Box):
        raise TypeError(
            "Expected a Box observation space, got "
            f"{type(env.observation_space).__name__}."
        )

    load_fn = import_module(arguments[RESERVED_KEYWORDS.MAIN]).load_fn
    if arguments["normalize_observations"]:
        stats = logger.load_numpy("stats.npz")
        mean, std = stats["mean"], stats["std"]
    else:
        # Presumably an identity normalization inside compile_act; confirm.
        mean, std = 0, 1

    # int(...) because np.prod(()) is the float 1.0 for scalar (shape ())
    # spaces, and np.int64 otherwise. load_fn always gets a plain int.
    policy, state = load_fn(
        action_dim=int(np.prod(env.action_space.shape)),
        logger=logger,
        observation_dim=int(np.prod(env.observation_space.shape)),
        step=step,
        **arguments,
    )
    act_fn: helper.ActFunction[Any] = helper.compile_act(
        action_space=env.action_space,
        mean=mean,
        std=std,
        unsquash=arguments["unsquash"],
    )
    results: dict[str, list[Any]] = defaultdict(list)
    samples = 0
    # Split once, outside the loop; act_fn takes graphdef and state separately.
    graphdef, graphstate = nnx.split(policy)
    with tqdm(total=num_samples) as progress_bar:
        observation: FloatArray
        observation, _ = env.reset(seed=seed)
        # NOTE(review): this loop only exits at an episode end, so an
        # environment that never terminates or truncates will run forever.
        while True:
            # Store the observation in its original shape; the policy is fed
            # a flattened view.
            results["observations"].append(np.copy(observation))
            observation = observation.ravel()
            action, state, info = act_fn(
                graphdef, graphstate, observation, state
            )
            # Presumably strips a leading batch axis added by act_fn; confirm.
            action = action[0]
            observation, reward, terminated, truncated, env_info = env.step(
                action
            )
            done = terminated or truncated
            results["actions"].append(np.copy(action))
            results["dones"].append(done)
            results["rewards"].append(reward)
            results["terminals"].append(terminated)
            results["timeouts"].append(truncated)
            # On key collisions, env_info overrides info (dict ``|`` semantics).
            # NOTE(review): a key that is missing on some steps yields an
            # "info/<key>" array shorter than "observations".
            for key, value in (info | env_info).items():
                results[f"info/{key}"].append(np.copy(value))
            samples += 1
            progress_bar.update(1)
            if done:
                if samples >= num_samples:
                    break
                observation, _ = env.reset()
    return {key: np.asarray(value) for key, value in results.items()}


if __name__ == "__main__":
    # Parsed destinations match main()'s keyword parameters one-to-one.
    cli_args = build_argument_parser().parse_args()
    main(**vars(cli_args))
