[project]
name = "latte_trans"
version = "0.1.0"
# TODO: placeholder text from the PDM project template; replace with a real summary.
description = "Default template for PDM package"
# TODO: placeholder author entry from the PDM project template; fill it in.
authors = [
    {name = "", email = ""},
]
# Source layout: the importable package lives in src/latte_trans. A Poetry-style
# `packages = [...]` key is not part of the standard [project] table (schema
# validators such as validate-pyproject reject it), so it is not declared here;
# pdm-backend picks up a src/ directory automatically if the project is built.
requires-python = "==3.10.*"
readme = "README.md"
license = {text = "MIT"}
# Runtime dependencies only. Development tooling (formatters, linters, type
# checkers) belongs in the dev group under [tool.pdm.dev-dependencies].
dependencies = [
    "jax[cuda12]==0.4.31",
    "tqdm==4.66.2",
    "transformers==4.44.2",
    "datasets==2.16.1",
    "wandb==0.16.4",
    "scikit-learn==1.4.1.post1",
    "scipy==1.12.0",
    "sacremoses==0.1.1",
    "evaluate==0.4.1",
    "nltk==3.8.1",
    "sacrebleu==2.4.1",
    "einops>=0.8.0",
    "torch==2.0.1",
    "torchtext==0.15.2",
    "flax==0.8.4",
    "orbax-checkpoint==0.5.17",
    "matplotlib>=3.9.1",
    "accelerate>=0.33.0",
    "zstandard>=0.23.0",
    # NOTE(review): pip is presumably needed inside the PDM-managed venv for
    # manual install steps (cf. the commented-out pre_install hook) — confirm.
    "pip>=24.2",
    # torchvision 0.15.2 is the release paired with torch 2.0.1; later
    # torchvision releases require a newer torch, so pin it alongside torch.
    "torchvision==0.15.2",
    "tensorflow>=2.17.0",
    "jaxtyping>=0.2.33",
    "sentencepiece>=0.2.0",
    "rouge-score>=0.1.2",
    "seaborn>=0.13.2",
    "deepeval>=1.2.2",
    "mpl-sizes>=0.0.2",
    "latex>=0.7.0",
    "mpu>=0.23.1",
]

[tool.black]
# 88 is black's default; stated explicitly so other tools can be aligned to it.
line-length = 88

# isort is used alongside black (both are in the dev group). isort's default
# wrapping style conflicts with black's output, so use its black profile
# (which also sets isort's line length to 88).
[tool.isort]
profile = "black"

[tool.pdm]
# Treat the project as an application: PDM manages the environment and its
# dependencies but does not build or install latte_trans itself.
distribution = false

[tool.pdm.dev-dependencies]
# Tooling used only during development: type checking, import sorting,
# linting, formatting, unused-import cleanup and interactive debugging.
dev = [
    "pyright>=1.1.349",
    "ipython>=8.20.0",
    "isort>=5.13.2",
    "flake8>=7.0.0",
    "black>=24.1.0",
    "autoflake>=2.2.1",
    "ipdb>=0.13.13",
]

# Extra pip-style "find links" page that hosts JAX CUDA wheels.
# NOTE(review): recent jax[cuda12] releases are published on PyPI, so this
# source may no longer be required for jax==0.4.31 — confirm before removing.
[[tool.pdm.source]]
name = "jax"
type = "find_links"
url = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
verify_ssl = true

[tool.pdm.scripts]
# Optional hook for dependencies that cannot be installed with pip; uncomment
# it to have PDM run the script before it installs packages.
# NOTE(review): this path uses bash/ while prepare_wmt14 below uses bin/ —
# confirm which directory actually holds the helper scripts.
# pre_install = "bash bash/extra_install.sh"

# Every script below is a plain command; extra arguments passed to
# `pdm run <script>` are appended to it. train_lra expects:
#   base_dir    - root directory of the project
#   config_file - config file with the hyperparameters
#   name        - name of the experiment
# NOTE(review): the other experiment modules presumably take the same
# arguments — confirm.
train_lra = { cmd = "python -u -m latte_trans.experiments.lra" }

# The eval_* variants run the same module as their train_* counterpart with
# --evaluate added.
train_lmsh = { cmd = "python -u -m latte_trans.experiments.lmsh" }
eval_lmsh = { cmd = "python -u -m latte_trans.experiments.lmsh --evaluate" }
train_lm = { cmd = "python -u -m latte_trans.experiments.lm" }
eval_lm = { cmd = "python -u -m latte_trans.experiments.lm --evaluate" }

# Multimodal training ("lava-multimodal" style).
train_mulmode = { cmd = "python -u -m latte_trans.experiments.mulmode" }

# Pretraining on top of language modelling, plus inference.
train_lm_pret = { cmd = "python -u -m latte_trans.experiments.lm_pret" }
eval_lm_pret = { cmd = "python -u -m latte_trans.experiments.lm_pret --evaluate" }
infer_lm = { cmd = "python -u -m latte_trans.experiments.inference" }

# NMT; prepare_wmt14 runs bin/download_wmt14.sh to fetch the WMT14 data.
prepare_wmt14 = { cmd = "bash bin/download_wmt14.sh" }
train_nmt = { cmd = "python -u -m latte_trans.experiments.nmt" }
eval_nmt = { cmd = "python -u -m latte_trans.experiments.nmt --evaluate" }

# SCROLLS.
train_scrolls = { cmd = "python -u -m latte_trans.experiments.scrolls" }

# Copy task.
train_copy = { cmd = "python -u -m latte_trans.experiments.copy_syn" }
