load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")

# Targets in this package are visible to every package under //distrib_robust.
# Visibility specs must be absolute package labels (leading "//"); the previous
# "distrib_robust:__subpackages__" was not a valid absolute package spec.
package(default_visibility = ["//distrib_robust:__subpackages__"])

# Library wrapping trainer_federated.py so that the trainer_federated binary
# can depend on it.
# All cross-package deps use absolute labels ("//pkg:target"), matching the
# //utils deps below; the bare "distrib_robust/..." form is not an absolute label.
py_library(
    name = "trainer_federated_lib",
    srcs = ["trainer_federated.py"],
    srcs_version = "PY3",
    deps = [
        "//distrib_robust/tasks:cifar100_image",
        "//distrib_robust/tasks:emnist_character",
        "//distrib_robust/tasks:shakespeare_character",
        "//distrib_robust/tasks:stackoverflow_word",
        "//distrib_robust/tasks:training_specs",
        "//distrib_robust/utils:fed_avg_schedule",
        "//distrib_robust/utils:federated_training_loop",
        "//distrib_robust/utils:metric_utils",
        "//utils:utils_impl",
        "//utils/optimizers:optimizer_utils",
    ],
)

# Executable entry point for federated training. It lists trainer_federated.py
# as its own source and gets every transitive dependency through
# :trainer_federated_lib.
py_binary(
    name = "trainer_federated",
    srcs = ["trainer_federated.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [":trainer_federated_lib"],
)

# Large, 10-way sharded test for the federated trainer.
# Cross-package deps use absolute labels ("//pkg:target").
# NOTE(review): this test does not depend on :trainer_federated_lib; confirm
# trainer_federated_test.py does not import trainer_federated directly.
py_test(
    name = "trainer_federated_test",
    size = "large",
    srcs = ["trainer_federated_test.py"],
    python_version = "PY3",
    shard_count = 10,
    srcs_version = "PY3",
    deps = [
        "//distrib_robust/tasks:cifar100_image",
        "//distrib_robust/tasks:emnist_character",
        "//distrib_robust/tasks:shakespeare_character",
        "//distrib_robust/tasks:stackoverflow_word",
        "//distrib_robust/tasks:training_specs",
        "//distrib_robust/utils:federated_training_loop",
        "//distrib_robust/utils:metric_utils",
    ],
)

# Library built from eval_metric_distribution.py; its only declared external
# dependency is more_itertools.
py_library(
    name = "eval_metric_distribution",
    srcs = ["eval_metric_distribution.py"],
    srcs_version = "PY3",
    deps = [
        "//third_party/py/more_itertools",
    ],
)

# Unit test for the :eval_metric_distribution library.
py_test(
    name = "eval_metric_distribution_test",
    srcs = ["eval_metric_distribution_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":eval_metric_distribution",
    ],
)

# Library wrapping trainer_centralized.py so that the trainer_centralized binary
# can depend on it.
# All cross-package deps use absolute labels ("//pkg:target"), matching the
# //utils deps below; the bare "distrib_robust/..." form is not an absolute label.
py_library(
    name = "trainer_centralized_lib",
    srcs = ["trainer_centralized.py"],
    srcs_version = "PY3",
    deps = [
        "//distrib_robust/tasks:cifar100_image",
        "//distrib_robust/tasks:emnist_character",
        "//distrib_robust/tasks:shakespeare_character",
        "//distrib_robust/tasks:stackoverflow_word",
        "//distrib_robust/tasks:training_specs",
        "//distrib_robust/utils:centralized_training_loop",
        "//distrib_robust/utils:metric_utils",
        "//utils:utils_impl",
        "//utils/optimizers:optimizer_utils",
    ],
)

# Executable entry point for centralized training. It lists
# trainer_centralized.py as its own source and gets every transitive dependency
# through :trainer_centralized_lib.
py_binary(
    name = "trainer_centralized",
    srcs = ["trainer_centralized.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [":trainer_centralized_lib"],
)

# Large, 10-way sharded test for the centralized trainer.
# Cross-package deps use absolute labels ("//pkg:target").
# NOTE(review): this test does not depend on :trainer_centralized_lib; confirm
# trainer_centralized_test.py does not import trainer_centralized directly.
py_test(
    name = "trainer_centralized_test",
    size = "large",
    srcs = ["trainer_centralized_test.py"],
    python_version = "PY3",
    shard_count = 10,
    srcs_version = "PY3",
    deps = [
        "//distrib_robust/tasks:cifar100_image",
        "//distrib_robust/tasks:emnist_character",
        "//distrib_robust/tasks:shakespeare_character",
        "//distrib_robust/tasks:stackoverflow_word",
        "//distrib_robust/tasks:training_specs",
        "//distrib_robust/utils:centralized_training_loop",
        "//distrib_robust/utils:metric_utils",
    ],
)
