# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Diagnose script for checking OS/hardware/python/pip/verl/network.
The output of this script can be a very good hint to issue/problem.
"""
import subprocess
import psutil
import platform, subprocess, sys, os
import socket, time
try:
    from urllib.request import urlopen
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse
    from urllib2 import urlopen
import argparse
import importlib.metadata
import torch

# Endpoints probed by the network check on every run, as {display name: URL}.
URLS = dict(PYPI='https://pypi.python.org/pypi/pip')

# Extra mirror endpoints per region, keyed by the lowercase region name accepted by
# ``--region``; the network check merges the selected regions' entries into its probe list.
REGIONAL_URLS = dict(
    cn={
        'PYPI(douban)': 'https://pypi.douban.com/',
        'Conda(tsinghua)': 'https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/',
    },
)


def test_connection(name, url, timeout=10):
    """Resolve the host of ``url`` and open ``url``, printing how long each step took.

    Nothing is returned; the timings (or the error that stopped the probe) are printed.

    Args:
        name: label shown in the printed report.
        url: URL to open. The response body is not read.
        timeout: per-request timeout in seconds passed to ``urlopen``. ``None`` or a
            value <= 0 disables the timeout (the request blocks until it completes).
    """
    urlinfo = urlparse(url)
    # ``hostname`` drops any ``:port`` or ``user@`` part that ``netloc`` keeps and that
    # gethostbyname cannot resolve; fall back to netloc for scheme-less input.
    host = urlinfo.hostname or urlinfo.netloc
    if timeout is not None and timeout <= 0:
        # urlopen(timeout=0) would make the socket non-blocking and fail immediately,
        # whereas 0 is meant to disable the limit.
        timeout = None
    start = time.perf_counter()
    try:
        socket.gethostbyname(host)
    except Exception as e:
        print('Error resolving DNS for {}: {}, {}'.format(name, url, e))
        return
    dns_elapsed = time.perf_counter() - start
    start = time.perf_counter()
    try:
        # Use the response as a context manager so the connection is closed, not leaked.
        with urlopen(url, timeout=timeout):
            pass
    except Exception as e:
        print("Error open {}: {}, {}, DNS finished in {} sec.".format(name, url, e, dns_elapsed))
        return
    load_elapsed = time.perf_counter() - start
    print("Timing for {}: {}, DNS: {:.4f} sec, LOAD: {:.4f} sec.".format(name, url, dns_elapsed, load_elapsed))


def check_python():
    """Print the running interpreter's version, compiler, build and architecture."""
    print('----------Python Info----------')
    probes = (
        ('Version      :', platform.python_version),
        ('Compiler     :', platform.python_compiler),
        ('Build        :', platform.python_build),
        ('Arch         :', platform.architecture),
    )
    for label, probe in probes:
        print(label, probe())


def check_pip():
    """Print the version and install directory of the pip importable by this interpreter."""
    print('------------Pip Info-----------')
    try:
        import pip
    except ImportError:
        print('No corresponding pip install for current python.')
        return
    print('Version      :', pip.__version__)
    print('Directory    :', os.path.dirname(pip.__file__))


def _get_current_git_commit():
    try:
        result = subprocess.run(['git', 'rev-parse', 'HEAD'], capture_output=True, text=True, check=True)
        return result.stdout.strip()
    except subprocess.CalledProcessError as e:
        print(f"Error running git command: {e.stderr.strip()}")
        return None
    except FileNotFoundError:
        print("Did not find command: git")
        return None


def check_verl():
    print('----------verl Info-----------')
    try:
        sys.path.insert(0, os.getcwd())
        import verl
        print('Version      :', verl.__version__)
        verl_dir = os.path.dirname(verl.__file__)
        print('Directory    :', verl_dir)
        try:
            commit_hash = _get_current_git_commit()
            print('Commit Hash  :', commit_hash)
        except AttributeError:
            print('Commit hash not found. ')
    except ImportError as e:
        print(f'No verl installed: {e}')
    except Exception as e:
        import traceback
        if not isinstance(e, IOError):
            print("An error occured trying to import verl.")
            print("This is very likely due to missing missing or incompatible library files.")
        print(traceback.format_exc())


def check_os():
    """Print the platform string plus OS name, host name, release and version."""
    print('----------Platform Info----------')
    rows = [
        ('Platform     :', platform.platform()),
        ('system       :', platform.system()),
        ('node         :', platform.node()),
        ('release      :', platform.release()),
        ('version      :', platform.version()),
    ]
    for label, value in rows:
        print(label, value)


def check_hardware():
    """Print the machine type and processor, then OS-specific CPU details.

    On macOS the ``brand_string``/``features`` lines of ``sysctl -a`` are printed; on
    Linux ``lscpu`` and on Windows ``wmic cpu get name`` write directly to the inherited
    stdout. If the platform tool is missing, a short notice is printed instead of
    aborting the whole diagnosis.
    """
    print('----------Hardware Info----------')
    print('machine      :', platform.machine())
    print('processor    :', platform.processor())
    try:
        if sys.platform.startswith('darwin'):
            output = subprocess.run(('sysctl', '-a'), stdout=subprocess.PIPE, check=False).stdout
            for line in output.split(b'\n'):
                if b'brand_string' in line or b'features' in line:
                    # Decode so the report shows text rather than b'...' byte reprs.
                    print(line.strip().decode('utf-8', errors='replace'))
        elif sys.platform.startswith('linux'):
            subprocess.call(['lscpu'])
        elif sys.platform.startswith('win32'):
            subprocess.call(['wmic', 'cpu', 'get', 'name'])
    except OSError as e:
        # e.g. lscpu absent in minimal containers, or wmic removed on newer Windows.
        print(f'Could not query CPU details: {e}')


def check_network(args):
    """Probe the global endpoints plus any regional mirrors selected by ``args.region``.

    Args:
        args: parsed CLI namespace; reads ``timeout`` (seconds, 0 disables the limit)
            and ``region`` (comma-separated region names, case-insensitive).
    """
    print('----------Network Test----------')
    # 0 means "no timeout"; passing 0 through to urlopen would make sockets non-blocking.
    timeout = args.timeout if args.timeout > 0 else None
    if timeout is not None:
        print('Setting timeout: {}'.format(timeout))
        socket.setdefaulttimeout(timeout)
    # Work on a copy so repeated calls don't keep merging mirrors into the global table.
    urls = dict(URLS)
    for region in args.region.strip().split(','):
        r = region.strip().lower()
        if not r:
            continue
        if r in REGIONAL_URLS:
            urls.update(REGIONAL_URLS[r])
        else:
            import warnings
            warnings.warn('Region {} do not need specific test, please refer to global sites.'.format(r))
    for name, url in urls.items():
        test_connection(name, url, timeout)


def check_environment():
    """Print environment variables relevant to verl, OpenMP/MKL threading and compilers.

    Variables whose names start with ``VERL_``, ``OMP_`` or ``KMP_``, plus ``CC`` and
    ``CXX``, are printed as ``NAME="value"``.
    """
    print('----------Environment----------')
    wanted_prefixes = ('VERL_', 'OMP_', 'KMP_')
    wanted_names = ('CC', 'CXX')
    for var_name, var_value in os.environ.items():
        if var_name.startswith(wanted_prefixes) or var_name in wanted_names:
            print(f'{var_name}="{var_value}"')


def check_pip_package_versions(packages=None):
    """Print the installed distribution version of each package, one per line.

    Args:
        packages: iterable of distribution names to look up. ``None`` (the default)
            checks the inference/runtime stack used by verl: vllm, sglang, ray, torch.
    """
    if packages is None:
        packages = ('vllm', 'sglang', 'ray', 'torch')
    for package in packages:
        try:
            version = importlib.metadata.version(package)
        except importlib.metadata.PackageNotFoundError:
            version = 'not found.'
        print(f"{package}\t     : {version}")


def check_cuda_versions():
    """Print the CUDA version torch was built with and the release line of ``nvcc --version``.

    Only runs when ``torch.cuda.is_available()``; otherwise a single notice is printed.
    A missing ``nvcc`` or any other failure is reported rather than raised.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return
    try:
        print(f"CUDA Runtime : {torch.version.cuda}")
        nvcc_text = subprocess.check_output(['nvcc', '--version']).decode('utf-8')
        release_line = None
        for candidate in nvcc_text.splitlines():
            if 'release' in candidate:
                release_line = candidate
                break
        if release_line:
            print(f"CUDA Compiler : {release_line.strip()}")
        else:
            print("Could not determine CUDA compiler version.")
    except FileNotFoundError as err:
        print(f"CUDA compiler : Not found: {err}")
    except Exception as err:
        print(f"An error occurred while checking CUDA versions: {err}")


def _get_cpu_memory():
    """Return total physical memory reported by psutil, in GiB (bytes / 1024**3).

    The report labels this value "GB".
    """
    total_bytes = psutil.virtual_memory().total
    return total_bytes / 1024**3


def _get_gpu_info():
    """
    Get GPU type, GPU memory, and GPU count using nvidia-smi command.
    """
    try:
        result = subprocess.run(['nvidia-smi', '--query-gpu=gpu_name,memory.total', '--format=csv,noheader,nounits'],
                                capture_output=True,
                                text=True,
                                check=True)
        gpu_lines = result.stdout.strip().split('\n')
        gpu_count = len(gpu_lines)
        gpu_info = []
        for line in gpu_lines:
            gpu_name, gpu_memory = line.split(', ')
            gpu_info.append({
                'type': gpu_name,
                'memory': float(gpu_memory) / 1024  # Convert to GB
            })
        return gpu_count, gpu_info
    except subprocess.CalledProcessError:
        print("Failed to execute nvidia-smi command.")
        return 0, []


def _get_system_info():
    """
    Collect CPU memory capacity and GPU count/details into one dict.

    Keys: ``cpu_memory`` (from ``_get_cpu_memory``), ``gpu_count`` and ``gpu_info``
    (both from ``_get_gpu_info``).
    """
    info = {'cpu_memory': _get_cpu_memory()}
    info['gpu_count'], info['gpu_info'] = _get_gpu_info()
    return info


def check_system_info():
    """Print total CPU memory, the GPU count, and each GPU's type and memory."""
    print('----------System Info----------')
    info = _get_system_info()
    print(f"CPU Memory\t: {info['cpu_memory']:.2f} GB")
    print(f"GPU Count\t: {info['gpu_count']}")
    # GPUs are numbered from 1 in the report.
    for idx, gpu in enumerate(info['gpu_info'], start=1):
        print(f"GPU {idx}\tType    : {gpu['type']}")
        print(f"GPU {idx}\tMemory  : {gpu['memory']:.2f} GB")


def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace with one int flag per diagnostic section, plus ``region``
        (str) and ``timeout`` (int seconds, 0 disables the limit).
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     description='Diagnose script for checking the current system.')
    # Sections enabled by default; pass e.g. ``--pip 0`` to skip one.
    choices = ['python', 'pip', 'verl', 'system', 'os', 'environment']
    for choice in choices:
        parser.add_argument('--' + choice, default=1, type=int, help='Diagnose {}.'.format(choice))
    # Opt-in sections (default 0).
    parser.add_argument('--network', default=0, type=int, help='Diagnose network.')
    parser.add_argument('--hardware', default=0, type=int, help='Diagnose hardware.')
    parser.add_argument('--region',
                        default='',
                        type=str,
                        help="Additional sites in which region(s) to test. "
                        "Specify 'cn' for example to test mirror sites in China.")
    parser.add_argument('--timeout', default=10, type=int, help="Connection test timeout threshold, 0 to disable.")
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    # Each entry pairs a CLI flag with the checks it enables; entries run top to bottom.
    sections = (
        (args.python, (check_python,)),
        (args.pip, (check_pip, check_pip_package_versions)),
        (args.verl, (check_verl,)),
        (args.os, (check_os,)),
        (args.hardware, (check_hardware,)),
        (args.network, (lambda: check_network(args),)),
        (args.environment, (check_environment, check_cuda_versions)),
        (args.system, (check_system_info,)),
    )
    for enabled, checks in sections:
        if not enabled:
            continue
        for check in checks:
            check()
