import os, time
import flax.serialization as flax_ser
import jax.numpy as jnp
from typing import NamedTuple


class Transition(NamedTuple):
    """Named tuple grouping seven fields, each annotated as ``jnp.ndarray``.

    Field order is part of the interface: instances can be built and unpacked
    positionally as ``(done, action, value, reward, log_prob, obs, info)``.

    NOTE(review): the field names suggest per-step rollout data for an
    actor-critic style algorithm, but nothing in this block shows how
    instances are produced or consumed. Shapes, dtypes and exact meanings
    of the fields should be confirmed against the code that constructs them.
    """
    done: jnp.ndarray  # presumably an episode-termination flag -- TODO confirm
    action: jnp.ndarray  # presumably the action taken -- TODO confirm
    value: jnp.ndarray  # presumably a critic/value estimate -- TODO confirm
    reward: jnp.ndarray  # presumably the reward received -- TODO confirm
    log_prob: jnp.ndarray  # presumably log-probability of `action` -- TODO confirm
    obs: jnp.ndarray  # presumably the observation the action was based on -- TODO confirm
    info: jnp.ndarray  # NOTE(review): annotated as an array; verify callers do not pass a dict here


def default_run_dir(cfg):
    """Build a timestamped run-directory path from a config mapping.

    The final path component is ``<ENV_NAME>_rapo_<YYYYmmdd_HHMMSS>``, where
    the timestamp is the current local time. It is joined onto
    ``cfg["LOG_DIR"]`` when that key is present and onto ``"runs"`` otherwise.
    Nothing is created on disk; only the path string is returned.

    Args:
        cfg: Mapping that must contain ``"ENV_NAME"`` and may contain
            ``"LOG_DIR"``.

    Returns:
        The joined path as produced by ``os.path.join``.

    Raises:
        KeyError: If ``"ENV_NAME"`` is missing from ``cfg``.
    """
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    leaf = "{}_rapo_{}".format(cfg["ENV_NAME"], timestamp)
    base_dir = cfg.get("LOG_DIR", "runs")
    return os.path.join(base_dir, leaf)


def save_params(path, params):
    """Serialize ``params`` with Flax and write the bytes to ``path``.

    Args:
        path: Destination file path. Missing parent directories are created.
            A bare filename (no directory component) writes into the current
            working directory.
        params: Object handed to ``flax.serialization.to_bytes``.

    Raises:
        OSError: If a parent directory cannot be created or the file cannot
            be opened or written.
    """
    # Serialize before touching the filesystem: opening with "wb" truncates,
    # so a serialization failure must not be able to wipe an existing file.
    payload = flax_ser.to_bytes(params)

    # dirname is "" for a bare filename and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(path, "wb") as f:
        f.write(payload)