#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import os
import re

import setuptools

# Absolute path of the directory holding this setup script; used to build
# paths that should not depend on the current working directory.
this_dir = os.path.abspath(os.path.dirname(__file__))


def fetch_requirements(path=None):
    """Return the requirement specifiers listed in a requirements file.

    Each line is stripped of surrounding whitespace. Blank lines and full-line
    ``#`` comments are skipped, so they are never handed to setuptools as
    requirement specifiers. An empty file therefore yields ``[]`` rather than
    ``[""]``.

    Args:
        path: requirements file to read. When omitted, ``requirements.txt``
            next to this setup script is used, so the result does not depend
            on the current working directory.

    Returns:
        A list of requirement strings, in file order.
    """
    if path is None:
        path = os.path.join(this_dir, "requirements.txt")
    with open(path, encoding="utf-8") as f:
        # splitlines() also copes with stray "\r" line endings.
        lines = f.read().splitlines()
    stripped = (line.strip() for line in lines)
    return [line for line in stripped if line and not line.startswith("#")]


# https://packaging.python.org/guides/single-sourcing-package-version/
def find_version(version_file_path) -> str:
    """Return the dotted version string declared by ``__version_tuple__``.

    The file must contain a line such as ``__version_tuple__ = (0, 4, 13)``.
    The tuple elements are joined with dots, e.g. ``"0.4.13"``.

    The value is parsed with ``ast.literal_eval`` instead of ``eval``, so
    reading the version file can never execute arbitrary code during a build.

    Args:
        version_file_path: path to the Python file declaring ``__version_tuple__``.

    Returns:
        The dotted version string.

    Raises:
        RuntimeError: if no ``__version_tuple__`` assignment is found, or if
            its value is not a literal tuple/list.
    """
    import ast  # local import: only needed for safe literal parsing here

    with open(version_file_path, encoding="utf-8") as version_file:
        contents = version_file.read()

    version_match = re.search(r"^__version_tuple__[ \t]*=[ \t]*(.*)$", contents, re.M)
    if not version_match:
        raise RuntimeError("Unable to find version tuple.")

    raw_value = version_match.group(1).strip()
    try:
        ver_tup = ast.literal_eval(raw_value)
    except (ValueError, SyntaxError) as err:
        raise RuntimeError(f"Invalid version tuple: {raw_value!r}") from err
    if not isinstance(ver_tup, (tuple, list)):
        raise RuntimeError(f"Invalid version tuple: {raw_value!r}")
    return ".".join(str(x) for x in ver_tup)


# Build configuration passed to setuptools.setup(). By default this is a pure
# Python package; native code is compiled only when BUILD_CUDA_EXTENSIONS=1.
extensions = []
cmdclass = {}
setup_requires = []

if os.getenv("BUILD_CUDA_EXTENSIONS", "0") == "1":
    # torch is only imported when the CUDA build is requested.
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    # NOTE(review): ninja is requested at build time, presumably for torch's
    # extension builder — confirm.
    setup_requires = ["ninja"]

    # Fused Adam optimizer kernel: a C++ binding plus its CUDA implementation.
    extensions.append(
        CUDAExtension(
            name="fairscale.fused_adam_cuda",
            include_dirs=[os.path.join(this_dir, "fairscale/clib/fused_adam_cuda")],
            sources=[
                "fairscale/clib/fused_adam_cuda/fused_adam_cuda.cpp",
                "fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu",
            ],
            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"]},
        )
    )

    # Use torch's build_ext so the .cu sources are compiled with nvcc.
    cmdclass["build_ext"] = BuildExtension


if __name__ == "__main__":
    # Summary shown on the package index page. It is plain prose, which is
    # also valid Markdown, matching the declared content type below.
    _long_description = (
        "FairScale is a PyTorch extension library for high performance and "
        "large scale training on one or multiple machines/nodes. This library "
        "extends basic PyTorch capabilities while adding new experimental ones."
    )
    # Trove classifiers: supported Python versions, license, topic and OS.
    _classifiers = [
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "License :: OSI Approved :: BSD License",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Operating System :: OS Independent",
    ]

    setuptools.setup(
        name="fairscale",
        description="FairScale: A PyTorch library for large-scale and high-performance training.",
        version=find_version("fairscale/version.py"),
        setup_requires=setup_requires,
        install_requires=fetch_requirements(),
        include_package_data=True,
        # Restrict discovery to the fairscale package tree (no tests/benchmarks).
        packages=setuptools.find_packages(include=["fairscale*"]),
        ext_modules=extensions,
        cmdclass=cmdclass,
        python_requires=">=3.8",
        author="Foundational AI Research @ Meta AI",
        author_email="todo@meta.com",
        long_description=_long_description,
        long_description_content_type="text/markdown",
        # Installs a `wgit` command-line tool.
        entry_points={"console_scripts": ["wgit = fairscale.experimental.wgit.__main__:main"]},
        classifiers=_classifiers,
    )
