load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")

# Package-wide license metadata and default visibility.
licenses(["notice"])

package(
    default_applicable_licenses = ["//:package_license"],
    # Targets are visible only to the package group declared in this file.
    default_visibility = [":distributed_dp_matrix_factorization_packages"],
)

# Library target for accounting_utils.py (no intra-package deps).
py_library(
    name = "accounting_utils",
    srcs = ["accounting_utils.py"],
    # Declared for consistency with the other targets in this package.
    srcs_version = "PY3",
)

# Library target for compression_query.py; builds on :compression_utils.
py_library(
    name = "compression_query",
    srcs = ["compression_query.py"],
    # Declared for consistency with the other targets in this package.
    srcs_version = "PY3",
    deps = [
        ":compression_utils",
    ],
)

# Library target for compression_utils.py (no intra-package deps).
py_library(
    name = "compression_utils",
    srcs = ["compression_utils.py"],
    # Declared for consistency with the other targets in this package.
    srcs_version = "PY3",
)

# Library target for discrete_gaussian_utils.py (no intra-package deps).
py_library(
    name = "discrete_gaussian_utils",
    srcs = ["discrete_gaussian_utils.py"],
    # Declared for consistency with the other targets in this package.
    srcs_version = "PY3",
)

# Every package at or below //distributed_dp_matrix_factorization; referenced
# by this package's default_visibility.
package_group(
    name = "distributed_dp_matrix_factorization_packages",
    packages = [
        "//distributed_dp_matrix_factorization/...",
    ],
)

# Executable built from factorize_prefix_sum.py on top of :initializers and :loops.
py_binary(
    name = "factorize_prefix_sum",
    srcs = [
        "factorize_prefix_sum.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":initializers",
        ":loops",
    ],
)

# Library target for fixed_point_library.py (no intra-package deps).
py_library(
    name = "fixed_point_library",
    srcs = [
        "fixed_point_library.py",
    ],
    srcs_version = "PY3",
)

# Test target exercising :fixed_point_library.
py_test(
    name = "fixed_point_library_test",
    srcs = [
        "fixed_point_library_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":fixed_point_library",
    ],
)

# Library target for initializers.py; builds on :matrix_constructors.
py_library(
    name = "initializers",
    srcs = [
        "initializers.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":matrix_constructors",
    ],
)

# Test target exercising :initializers.
py_test(
    name = "initializers_test",
    srcs = [
        "initializers_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":initializers",
    ],
)

# Library target for loops.py; combines fixed-point, matrix-constructor and
# solver targets from this package.
py_library(
    name = "loops",
    srcs = [
        "loops.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":fixed_point_library",
        ":matrix_constructors",
        ":solvers",
    ],
)

# Test target exercising :loops (also pulls in :constraint_builders and
# :matrix_constructors directly).
py_test(
    name = "loops_test",
    srcs = [
        "loops_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":constraint_builders",
        ":loops",
        ":matrix_constructors",
    ],
)

# Library target for matrix_constructors.py; builds on :constraint_builders
# and :matrix_factorization_query.
py_library(
    name = "matrix_constructors",
    srcs = [
        "matrix_constructors.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":constraint_builders",
        ":matrix_factorization_query",
    ],
)

# Test target exercising :matrix_constructors.
py_test(
    name = "matrix_constructors_test",
    srcs = [
        "matrix_constructors_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":matrix_constructors",
    ],
)

# Library target for matrix_factorization_query.py; builds on
# :discrete_gaussian_utils.
py_library(
    name = "matrix_factorization_query",
    srcs = ["matrix_factorization_query.py"],
    srcs_version = "PY3",
    deps = [
        ":discrete_gaussian_utils",
    ],
)

# Test target exercising :matrix_factorization_query.
py_test(
    name = "matrix_factorization_query_test",
    srcs = [
        "matrix_factorization_query_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":matrix_factorization_query",
    ],
)

# Library target for modular_clipping_factory.py (no intra-package deps).
py_library(
    name = "modular_clipping_factory",
    srcs = ["modular_clipping_factory.py"],
    # Declared for consistency with the other targets in this package.
    srcs_version = "PY3",
)

# Library target for constraint_builders.py (no intra-package deps).
py_library(
    name = "constraint_builders",
    srcs = [
        "constraint_builders.py",
    ],
    srcs_version = "PY3",
)

# Test target exercising :constraint_builders.
py_test(
    name = "constraint_builders_test",
    srcs = [
        "constraint_builders_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":constraint_builders",
    ],
)

# Library target for tff_aggregator.py; aggregates the matrix-factorization,
# compression, modular-clipping and accounting targets of this package.
py_library(
    name = "tff_aggregator",
    srcs = ["tff_aggregator.py"],
    srcs_version = "PY3",
    # Sorted per buildifier convention.
    deps = [
        ":accounting_utils",
        ":compression_query",
        ":matrix_constructors",
        ":matrix_factorization_query",
        ":modular_clipping_factory",
    ],
)

# Test target exercising :tff_aggregator (also pulls in :matrix_constructors
# and :matrix_factorization_query directly).
py_test(
    name = "tff_aggregator_test",
    srcs = [
        "tff_aggregator_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":matrix_constructors",
        ":matrix_factorization_query",
        ":tff_aggregator",
    ],
)

# Library target for solvers.py (no intra-package deps).
py_library(
    name = "solvers",
    srcs = [
        "solvers.py",
    ],
    srcs_version = "PY3",
)

# Test target exercising :solvers (also pulls in :constraint_builders and
# :matrix_constructors directly).
py_test(
    name = "solvers_test",
    srcs = [
        "solvers_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":constraint_builders",
        ":matrix_constructors",
        ":solvers",
    ],
)
