from argparse import ArgumentParser
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import Any

import gymnasium
import numpy as np
from tqdm import tqdm, trange

from offline import helper
from offline.utils.logger import ChildLogger, Logger
from offline.utils.misc import robustify
from offline.utils.parser import MAIN, PARENT


def build_argument_parser():
    """Create the command-line parser for this script.

    Registered arguments:
        paths: one or more positional values.
        --episodes: integer, defaults to 10.
        --seed: integer, defaults to 0.
        --step: converted with ``robustify(int)``, defaults to ``None``.
        -s / --silent: boolean flag, false unless given.

    Returns:
        The configured ``ArgumentParser``.
    """
    # (flags, keyword options) pairs, kept in registration order so the
    # generated usage/help text lists them in the same sequence.
    specifications = (
        (("paths",), {"nargs": "+"}),
        (("--episodes",), {"type": int, "default": 10}),
        (("--seed",), {"type": int, "default": 0}),
        (("--step",), {"type": robustify(int), "default": None}),
        (("-s", "--silent"), {"action": "store_true"}),
    )

    cli = ArgumentParser()
    for flags, options in specifications:
        cli.add_argument(*flags, **options)
    return cli


def rollout(
    arguments: dict[str, Any],
    episodes: int,
    logger: Logger,
    seed: int,
    step: int | None,
) -> dict[str, np.ndarray]:
    """Roll out a saved policy for several episodes and gather the results.

    Args:
        arguments: Saved run arguments. This function reads
            ``arguments[MAIN]`` (name of a module exposing ``load_fn``),
            ``"normalize_observations"``, ``"env"`` and ``"unsquash"``; the
            whole mapping is additionally forwarded to ``load_fn`` as keyword
            arguments.
        episodes: Number of episodes to run.
        logger: Used to load ``stats.npz`` when observations are normalized,
            and handed to ``load_fn``.
        seed: Seed of the first episode; the following episodes use
            ``seed + 1``, ``seed + 2``, ... up to ``seed + episodes - 1``.
        step: Forwarded unchanged to ``load_fn``
            (presumably selects which checkpoint to load -- TODO confirm).

    Returns:
        A dict mapping each key returned by ``helper.rollout`` to the
        ``np.concatenate`` of that key's per-episode values. The dict is
        empty when ``episodes`` is not positive.
    """
    load_fn = import_module(arguments[MAIN]).load_fn
    if arguments["normalize_observations"]:
        stats = logger.load_numpy("stats.npz")
        mean, std = stats["mean"], stats["std"]
    else:
        # Neutral statistics handed to ``helper.compile_act`` when no
        # normalization was requested.
        mean, std = 0, 1

    env = gymnasium.make(arguments["env"])
    # The caller may invoke this function once per path; close the environment
    # on every exit path so environments do not pile up between calls.
    try:
        # Only Box action/observation spaces are supported here.
        assert isinstance(env.action_space, gymnasium.spaces.Box)
        assert isinstance(env.observation_space, gymnasium.spaces.Box)

        policy, state = load_fn(
            action_dim=np.prod(env.action_space.shape),
            logger=logger,
            observation_dim=np.prod(env.observation_space.shape),
            step=step,
            **arguments
        )
        act_fn: helper.ActFunction[Any] = helper.compile_act(
            action_space=env.action_space,
            mean=mean,
            std=std,
            unsquash=arguments["unsquash"],
        )
        total_results = defaultdict(list)
        for seed_ in trange(
            seed, seed + episodes, dynamic_ncols=True, leave=False
        ):
            results = helper.rollout(
                act_fn=act_fn, env=env, policy=policy, seed=seed_, state=state
            )
            for key, value in results.items():
                total_results[key].append(value)
    finally:
        env.close()

    return {key: np.concatenate(value) for key, value in total_results.items()}


def main(
    episodes: int, paths: list[str], seed: int, silent: bool, step: int | None
) -> None:
    """Run rollouts for every given path and optionally save the results.

    Args:
        episodes: Number of episodes passed to ``rollout`` for each path.
        paths: Directories to process; each is used as a ``Logger`` root whose
            saved arguments are loaded with ``load_args``.
        seed: Seed of the first episode, passed to ``rollout``.
        silent: When true, the results are not written to ``rollouts.npz``.
        step: Passed through to ``rollout``.
    """
    for path in tqdm(paths, dynamic_ncols=True):
        root = Path(path)
        logger = Logger(root=root)
        arguments = logger.load_args()
        # Guard only the key lookup: a KeyError raised while constructing the
        # ChildLogger itself is a real error and must not be silently ignored.
        try:
            parent = arguments[PARENT]
        except KeyError:
            # No parent recorded in the saved arguments: keep the plain logger.
            pass
        else:
            logger = ChildLogger(root=root, parent=Path(parent))

        results = rollout(
            arguments=arguments,
            episodes=episodes,
            logger=logger,
            seed=seed,
            step=step,
        )
        if not silent:
            logger.save_numpy("rollouts.npz", **results)


if __name__ == "__main__":
    # The parsed namespace's attributes (paths, episodes, seed, step, silent)
    # are passed to ``main`` as keyword arguments of the same names.
    namespace = build_argument_parser().parse_args()
    main(**vars(namespace))
