from pathlib import Path

import fire
from beartype import beartype
from tensordict import TensorDict

from helpers import logger


@beartype
def log_perf(expert_path: str):
    """Log to stdout the running mean and std of expert demo returns.

    Every immediate sub-directory of ``expert_path`` is processed in sorted
    order. Its ``*.h5`` files are loaded with ``TensorDict.from_h5`` in sorted
    filename order. After each additional file, a line of the form
    ``"demsNN": {"avg": ..., "std": ...},`` is logged at WARN level. The
    statistics cover the ``"return"`` entries of the first NN files; the std
    is the sample std (Bessel-corrected) and is reported as 0 for a single
    file.

    Args:
        expert_path: directory whose sub-directories hold the ``*.h5`` demos.
    """
    logger.configure(directory=None, format_strs=["stdout"])
    logger.set_level(logger.WARN)

    root = Path(expert_path)
    fmt = '"dems{}": {{"avg": {}, "std": {}}},'
    # iterdir() order is filesystem-dependent; sort so the report is reproducible
    for directory in sorted(p for p in root.iterdir() if p.is_dir()):
        logger.warn(f"{directory=}")
        returns = []
        for fpath in sorted(directory.glob("*.h5")):
            td = TensorDict.from_h5(fpath)
            # NOTE(review): td["return"] is presumably a tensor; avg/std are then
            # tensors too and are formatted via str() below — confirm output format
            returns.append(td["return"])
            n = len(returns)
            avg = sum(returns) / n
            # sample std with Bessel correction (divide by n - 1); undefined for a
            # single demo, so 0 is reported instead
            std = (
                (sum((r - avg) ** 2 for r in returns) / (n - 1)) ** 0.5
                if n > 1
                else 0.
            )
            logger.warn(fmt.format(str(n).zfill(2), avg, std))


if __name__ == "__main__":
    # expose log_perf as the CLI: its parameters become command-line arguments
    fire.Fire(component=log_perf)
