# Core package metadata (PEP 621).
[project]
name = "metaworld-algorithms"
version = "0.1.0"
description = "Implementations of Multi-Task and Meta-Learning baselines for the Metaworld benchmark"
readme = "README.md"
requires-python = ">=3.12"
# Runtime requirements, kept in alphabetical order. No jax requirement is
# listed here; jax specifiers appear only under the backend extras in
# [project.optional-dependencies].
dependencies = [
    "distrax>=0.1.5",
    "flax>=0.9.0",
    "gymnasium>=1.0.0",
    "jaxtyping>=0.2.34",
    # Unversioned here; [tool.uv.sources] points uv at a Git repository.
    "metaworld",
    "numpy>=2.1.2",
    # Upper bound only; the reason for the cap is not recorded in this file.
    "orbax-checkpoint<=0.11.11",
    "scipy>=1.15.2",
    "tyro>=0.8.11",
    # Upper bound only; the reason for the cap is not recorded in this file.
    "wandb<=0.19.9",
]

# Extras: four jax backend variants (cpu / metal / cuda12 / tpu) plus test
# tooling. The four backend extras are declared mutually exclusive in the
# `conflicts` setting under [tool.uv].
[project.optional-dependencies]
cpu = [
    "jax>=0.5.0",
]
# jax is capped for the Metal backend; upstream issue:
# https://github.com/jax-ml/jax/issues/27062
metal = [
    "jax<=0.5.0",
    "jax-metal>=0.1.0; sys_platform == 'darwin'",
]
cuda12 = [
    "jax[cuda12]>=0.5.0",
]
tpu = [
    "jax[tpu]>=0.5.0",
]
testing = [
    "chex>=0.1.89",
    "pytest>=8.3.5",
]

# uv resolver settings.
[tool.uv]
# Let uv consider pre-release versions when resolving.
prerelease = "allow"
# Each inner list names extras that cannot be installed together; here the
# four jax backend extras are mutually exclusive.
conflicts = [
    [
        { extra = "cpu" },
        { extra = "metal" },
        { extra = "cuda12" },
        { extra = "tpu" },
    ],
]

[tool.uv.sources]
# Have uv fetch `metaworld` from this Git repository instead of a package index.
# NOTE(review): `rev` looks like a branch name rather than a commit hash —
# confirm that tracking a moving ref is intended.
metaworld = { git = "https://github.com/reginald-mclean/Metaworld.git", rev = "same-step-autoreset" }

# Linter options live under `lint`; setting them directly on [tool.ruff] is
# deprecated by Ruff.
[tool.ruff.lint]
# F722 is "syntax error in forward annotation".
# NOTE(review): presumably disabled because of jaxtyping-style string shape
# annotations — confirm.
ignore = ["F722"]
# NOTE(review): recent Ruff releases mark this option as deprecated — confirm
# against the Ruff version in use before removing it.
ignore-init-module-imports = true

# Limit setuptools package discovery to `metaworld_algorithms` and its
# subpackages.
# NOTE(review): no [build-system] table is visible in this file — confirm that
# setuptools is the intended build backend.
[tool.setuptools.packages.find]
include = ["metaworld_algorithms", "metaworld_algorithms.*"]
